mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
Compare commits
4 Commits
b5d6fe30aa
...
4fbb99f20e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fbb99f20e | ||
|
|
52d6d1337f | ||
|
|
e5b2a872d3 | ||
|
|
68460d18cc |
3
.github/workflows/codeql.yml
vendored
3
.github/workflows/codeql.yml
vendored
@ -45,12 +45,9 @@ jobs:
|
||||
include:
|
||||
- language: python
|
||||
build-mode: none
|
||||
- language: javascript-typescript
|
||||
build-mode: none
|
||||
# CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
|
||||
# Use `c-cpp` to analyze code written in C, C++ or both
|
||||
# Use 'java-kotlin' to analyze code written in Java, Kotlin or both
|
||||
# Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
|
||||
# To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
|
||||
# see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
|
||||
# If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
|
||||
|
||||
@ -7,7 +7,7 @@ ci:
|
||||
autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks"
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.3
|
||||
rev: v0.14.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.group_plugin_setting import GroupPluginSetting
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
@ -23,6 +24,11 @@ def register_cache_types():
|
||||
CacheRegistry.register(CacheType.GROUPS, GroupConsole)
|
||||
CacheRegistry.register(CacheType.BOT, BotConsole)
|
||||
CacheRegistry.register(CacheType.USERS, UserConsole)
|
||||
CacheRegistry.register(
|
||||
CacheType.GROUP_PLUGIN_SETTINGS,
|
||||
GroupPluginSetting,
|
||||
key_format="{group_id}_{plugin_name}_{key}",
|
||||
)
|
||||
CacheRegistry.register(
|
||||
CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}"
|
||||
)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import (
|
||||
@ -58,7 +60,12 @@ __plugin_meta__ = PluginMetadata(
|
||||
llm_cmd = on_alconna(
|
||||
Alconna(
|
||||
"llm",
|
||||
Subcommand("list", alias=["ls"], help_text="查看模型列表"),
|
||||
Subcommand(
|
||||
"list",
|
||||
Option("--text", action=store_true, help_text="以纯文本格式输出模型列表"),
|
||||
alias=["ls"],
|
||||
help_text="查看模型列表",
|
||||
),
|
||||
Subcommand("info", Args["model_name", str], help_text="查看模型详情"),
|
||||
Subcommand("default", Args["model_name?", str], help_text="查看或设置默认模型"),
|
||||
Subcommand(
|
||||
@ -80,13 +87,36 @@ llm_cmd = on_alconna(
|
||||
|
||||
|
||||
@llm_cmd.assign("list")
|
||||
async def handle_list(arp: Arparma, show_all: Query[bool] = Query("all")):
|
||||
async def handle_list(
|
||||
arp: Arparma,
|
||||
show_all: Query[bool] = Query("all"),
|
||||
text_mode: Query[bool] = Query("list.text.value", False),
|
||||
):
|
||||
"""处理 'llm list' 命令"""
|
||||
logger.info("获取LLM模型列表", command="LLM Manage", session=arp.header_result)
|
||||
models = await DataSource.get_model_list(show_all=show_all.result)
|
||||
|
||||
image = await Presenters.format_model_list_as_image(models, show_all.result)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image))
|
||||
if text_mode.result:
|
||||
if not models:
|
||||
await llm_cmd.finish("当前没有配置任何LLM模型。")
|
||||
|
||||
grouped_models = defaultdict(list)
|
||||
for model in models:
|
||||
grouped_models[model["provider_name"]].append(model)
|
||||
|
||||
response_parts = ["可用的LLM模型列表:"]
|
||||
for provider, model_list in grouped_models.items():
|
||||
response_parts.append(f"\n{provider}:")
|
||||
for model in model_list:
|
||||
response_parts.append(
|
||||
f" {model['provider_name']}/{model['model_name']}"
|
||||
)
|
||||
|
||||
response_text = "\n".join(response_parts)
|
||||
await llm_cmd.finish(response_text)
|
||||
else:
|
||||
image = await Presenters.format_model_list_as_image(models, show_all.result)
|
||||
await llm_cmd.finish(MessageUtils.build_message(image))
|
||||
|
||||
|
||||
@llm_cmd.assign("info")
|
||||
@ -114,7 +144,7 @@ async def handle_default(arp: Arparma, model_name: Match[str]):
|
||||
command="LLM Manage",
|
||||
session=arp.header_result,
|
||||
)
|
||||
success, message = await DataSource.set_default_model(model_name.result)
|
||||
_success, message = await DataSource.set_default_model(model_name.result)
|
||||
await llm_cmd.finish(message)
|
||||
else:
|
||||
logger.info("查看默认模型", command="LLM Manage", session=arp.header_result)
|
||||
@ -132,7 +162,7 @@ async def handle_test(arp: Arparma, model_name: Match[str]):
|
||||
)
|
||||
await llm_cmd.send(f"正在测试模型 '{model_name.result}',请稍候...")
|
||||
|
||||
success, message = await DataSource.test_model_connectivity(model_name.result)
|
||||
_success, message = await DataSource.test_model_connectivity(model_name.result)
|
||||
await llm_cmd.finish(message)
|
||||
|
||||
|
||||
@ -167,5 +197,5 @@ async def handle_reset_key(
|
||||
)
|
||||
logger.info(log_msg, command="LLM Manage", session=arp.header_result)
|
||||
|
||||
success, message = await DataSource.reset_key(provider_name.result, key_to_reset)
|
||||
_success, message = await DataSource.reset_key(provider_name.result, key_to_reset)
|
||||
await llm_cmd.finish(message)
|
||||
|
||||
@ -17,6 +17,8 @@ from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
|
||||
from zhenxun.models.event_log import EventLog
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.tags import tag_manager
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import EventLogType, PluginType
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
@ -135,6 +137,11 @@ async def _(
|
||||
await EventLog.create(
|
||||
user_id=user_id, group_id=group_id, event_type=EventLogType.KICK_BOT
|
||||
)
|
||||
await tag_manager.remove_group_from_all_tags(group_id)
|
||||
logger.info(
|
||||
f"机器人被移出群聊,已自动从所有静态标签中移除群组 {group_id}",
|
||||
"群组标签管理",
|
||||
)
|
||||
elif event.sub_type in ["leave", "kick"]:
|
||||
if event.sub_type == "leave":
|
||||
"""主动退群"""
|
||||
|
||||
@ -2,6 +2,7 @@ import nonebot
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.tags import tag_manager
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
|
||||
@ -37,3 +38,20 @@ async def _():
|
||||
f"Bot: {bot.self_id} 自动更新好友信息错误", "自动更新好友", e=e
|
||||
)
|
||||
logger.info("自动更新好友信息成功...")
|
||||
|
||||
|
||||
# 自动清理静态标签中的无效群组
|
||||
@scheduler.scheduled_job(
|
||||
"cron",
|
||||
hour=23,
|
||||
minute=30,
|
||||
)
|
||||
async def _prune_stale_tags():
|
||||
deleted_count = await tag_manager.prune_stale_group_links()
|
||||
if deleted_count > 0:
|
||||
logger.info(
|
||||
f"定时任务:成功清理了 {deleted_count} 个无效的群组标签关联。",
|
||||
"群组标签管理",
|
||||
)
|
||||
else:
|
||||
logger.debug("定时任务:未发现无效的群组标签关联。", "群组标签管理")
|
||||
|
||||
@ -28,7 +28,8 @@ from nonebot_plugin_alconna.uniseg.segment import (
|
||||
)
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData, Task
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
@ -45,34 +46,52 @@ __plugin_meta__ = PluginMetadata(
|
||||
name="广播",
|
||||
description="昭告天下!",
|
||||
usage="""
|
||||
广播 [消息内容]
|
||||
- 直接发送消息到除当前群组外的所有群组
|
||||
- 支持文本、图片、@、表情、视频等多种消息类型
|
||||
- 示例:广播 你们好!
|
||||
- 示例:广播 [图片] 新活动开始啦!
|
||||
向所有群组或指定标签的群组发送广播消息。
|
||||
|
||||
广播 + 引用消息
|
||||
- 将引用的消息作为广播内容发送
|
||||
- 支持引用普通消息或合并转发消息
|
||||
- 示例:(引用一条消息) 广播
|
||||
**基础用法**
|
||||
- `广播 [消息内容]`:向所有群组发送广播。
|
||||
- `广播` (并引用一条消息):将引用的消息作为内容进行广播。
|
||||
|
||||
广播撤回
|
||||
- 撤回最近一次由您触发的广播消息
|
||||
- 仅能撤回短时间内的消息
|
||||
- 示例:广播撤回
|
||||
**高级定向广播**
|
||||
- `广播 -t <标签名> [消息内容]`:向指定标签下的所有群组广播。
|
||||
- `广播到 <标签名> [消息内容]`:与 `-t` 等效的快捷方式。
|
||||
|
||||
特性:
|
||||
- 在群组中使用广播时,不会将消息发送到当前群组
|
||||
- 在私聊中使用广播时,会发送到所有群组
|
||||
**标签可以是静态的,也可以是动态的,例如:**
|
||||
- `广播到 核心群 通知:...`
|
||||
- `广播到 成员数>500的群 通知:...`
|
||||
|
||||
别名:
|
||||
- bc (广播的简写)
|
||||
- recall (广播撤回的别名)
|
||||
**其他命令**
|
||||
- `广播撤回` (别名: `recall`):撤回最近一次发送的广播。
|
||||
|
||||
特性:
|
||||
- 在群组中使用广播时,不会将消息发送到当前群组
|
||||
- 在私聊中使用广播时,会发送到所有群组
|
||||
|
||||
别名:
|
||||
- bc (广播的简写)
|
||||
- recall (广播撤回的别名)
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="1.2",
|
||||
version="1.3",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
module="_task",
|
||||
key="DEFAULT_BROADCAST",
|
||||
value=True,
|
||||
help="被动 广播 进群默认开关状态",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
),
|
||||
RegisterConfig(
|
||||
module="_task",
|
||||
key="BROADCAST_CONCURRENCY_LIMIT",
|
||||
value=10,
|
||||
help="广播时的最大并发任务数,以避免API速率限制",
|
||||
default_value=10,
|
||||
),
|
||||
],
|
||||
tasks=[Task(module="broadcast", name="广播")],
|
||||
).to_dict(),
|
||||
)
|
||||
@ -103,6 +122,9 @@ _matcher = on_alconna(
|
||||
Alconna(
|
||||
"广播",
|
||||
Args["content?", AllParam],
|
||||
alc.Option(
|
||||
"-t|--tag", Args["tag_name_bc", str], help_text="向指定标签的群组广播"
|
||||
),
|
||||
),
|
||||
aliases={"bc"},
|
||||
priority=1,
|
||||
@ -112,6 +134,8 @@ _matcher = on_alconna(
|
||||
use_origin=False,
|
||||
)
|
||||
|
||||
_matcher.shortcut("广播到 {tag}", command="广播 -t {tag} {%*}")
|
||||
|
||||
_recall_matcher = on_alconna(
|
||||
Alconna("广播撤回"),
|
||||
aliases={"recall"},
|
||||
@ -128,23 +152,59 @@ async def handle_broadcast(
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
arp: alc.Arparma,
|
||||
tag_name_match: alc.Match[str] = alc.AlconnaMatch("tag_name_bc"),
|
||||
):
|
||||
broadcast_content_msg = await _extract_broadcast_content(bot, event, arp, session)
|
||||
if not broadcast_content_msg:
|
||||
return
|
||||
|
||||
target_groups, enabled_groups = await get_broadcast_target_groups(bot, session)
|
||||
if not target_groups or not enabled_groups:
|
||||
tag_name_to_broadcast = None
|
||||
force_send = False
|
||||
|
||||
if tag_name_match.available:
|
||||
tag_name_to_broadcast = tag_name_match.result
|
||||
force_send = True
|
||||
|
||||
mode_desc = "强制发送到标签" if force_send else "普通发送"
|
||||
logger.debug(
|
||||
f"广播模式: {mode_desc}, 标签名: {tag_name_to_broadcast}",
|
||||
"广播",
|
||||
)
|
||||
|
||||
target_groups_console, groups_to_actually_send = await get_broadcast_target_groups(
|
||||
bot, session, tag_name_to_broadcast, force_send
|
||||
)
|
||||
|
||||
if not target_groups_console:
|
||||
if tag_name_to_broadcast:
|
||||
await MessageUtils.build_message(
|
||||
f"标签 '{tag_name_to_broadcast}' 中没有群组或标签不存在。"
|
||||
).send(reply_to=True)
|
||||
return
|
||||
|
||||
if not groups_to_actually_send:
|
||||
if not force_send and target_groups_console:
|
||||
await MessageUtils.build_message(
|
||||
"没有启用了广播功能的目标群组可供立即发送。"
|
||||
).send(reply_to=True)
|
||||
return
|
||||
|
||||
try:
|
||||
await send_broadcast_and_notify(
|
||||
bot, event, broadcast_content_msg, enabled_groups, target_groups, session
|
||||
bot,
|
||||
event,
|
||||
broadcast_content_msg,
|
||||
groups_to_actually_send,
|
||||
target_groups_console,
|
||||
session,
|
||||
force_send,
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = "发送广播失败"
|
||||
BroadcastManager.log_error(error_msg, e, session)
|
||||
await MessageUtils.build_message(f"{error_msg}。").send(reply_to=True)
|
||||
await bot.send_private_msg(
|
||||
user_id=str(event.get_user_id()), message=f"{error_msg}。"
|
||||
)
|
||||
|
||||
|
||||
@_recall_matcher.handle()
|
||||
@ -178,5 +238,6 @@ async def handle_broadcast_recall(
|
||||
except Exception as e:
|
||||
error_msg = "撤回广播消息失败"
|
||||
BroadcastManager.log_error(error_msg, e, session)
|
||||
user_id = str(event.get_user_id())
|
||||
await bot.send_private_msg(user_id=user_id, message=f"{error_msg}。")
|
||||
await bot.send_private_msg(
|
||||
user_id=str(event.get_user_id()), message=f"{error_msg}。"
|
||||
)
|
||||
|
||||
@ -5,11 +5,12 @@ from typing import ClassVar
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters.onebot.v11 import Bot as V11Bot
|
||||
from nonebot.exception import ActionFailed
|
||||
from nonebot.exception import ActionFailed, AdapterException
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from nonebot_plugin_alconna.uniseg import Receipt, Reference
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
@ -18,6 +19,8 @@ from zhenxun.utils.platform import PlatformUtils
|
||||
from .models import BroadcastDetailResult, BroadcastResult
|
||||
from .utils import custom_nodes_to_v11_nodes, uni_message_to_v11_list_of_dicts
|
||||
|
||||
BROADCAST_SEND_DELAY_RANGE = (1, 3)
|
||||
|
||||
|
||||
class BroadcastManager:
|
||||
"""广播管理器"""
|
||||
@ -92,8 +95,16 @@ class BroadcastManager:
|
||||
logger.debug("清空上一次的广播消息ID记录", "广播", session=session)
|
||||
cls.clear_last_broadcast_msg_ids()
|
||||
|
||||
concurrency_limit = Config.get_config(
|
||||
"_task",
|
||||
"BROADCAST_CONCURRENCY_LIMIT",
|
||||
10,
|
||||
)
|
||||
|
||||
all_groups, _ = await cls.get_all_groups(bot)
|
||||
return await cls.send_to_specific_groups(bot, message, all_groups, session)
|
||||
return await cls.send_to_specific_groups(
|
||||
bot, message, all_groups, session, concurrency_limit=concurrency_limit
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def send_to_specific_groups(
|
||||
@ -102,14 +113,17 @@ class BroadcastManager:
|
||||
message: UniMessage,
|
||||
target_groups: list[GroupConsole],
|
||||
session_info: EventSession | str | None = None,
|
||||
force_send: bool = False,
|
||||
concurrency_limit: int = 10,
|
||||
) -> BroadcastResult:
|
||||
"""发送广播到指定群组"""
|
||||
log_session = session_info or bot.self_id
|
||||
logger.debug(
|
||||
f"开始广播,目标 {len(target_groups)} 个群组,Bot ID: {bot.self_id}",
|
||||
"广播",
|
||||
session=log_session,
|
||||
target_count = len(target_groups)
|
||||
log_message = (
|
||||
f"开始广播,目标 {target_count} 个群组 (并发数: {concurrency_limit}),"
|
||||
f"Bot ID: {bot.self_id}, ForceSend: {force_send}"
|
||||
)
|
||||
logger.info(log_message, "广播", session=log_session)
|
||||
|
||||
if not target_groups:
|
||||
logger.debug("目标群组列表为空,广播结束", "广播", session=log_session)
|
||||
@ -165,7 +179,12 @@ class BroadcastManager:
|
||||
)
|
||||
return 0, len(target_groups)
|
||||
success_count, error_count, skip_count = await cls._broadcast_forward(
|
||||
bot, log_session, target_groups, v11_nodes
|
||||
bot,
|
||||
log_session,
|
||||
target_groups,
|
||||
v11_nodes,
|
||||
force_send,
|
||||
concurrency_limit,
|
||||
)
|
||||
else:
|
||||
if is_forward_broadcast:
|
||||
@ -175,7 +194,12 @@ class BroadcastManager:
|
||||
session=log_session,
|
||||
)
|
||||
success_count, error_count, skip_count = await cls._broadcast_normal(
|
||||
bot, log_session, target_groups, message
|
||||
bot,
|
||||
log_session,
|
||||
target_groups,
|
||||
message,
|
||||
force_send,
|
||||
concurrency_limit,
|
||||
)
|
||||
|
||||
total = len(target_groups)
|
||||
@ -287,11 +311,16 @@ class BroadcastManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _check_group_availability(cls, bot: Bot, group: GroupConsole) -> bool:
|
||||
async def _check_group_availability(
|
||||
cls, bot: Bot, group: GroupConsole, force_send: bool = False
|
||||
) -> bool:
|
||||
"""检查群组是否可用"""
|
||||
if not group.group_id:
|
||||
return False
|
||||
|
||||
if force_send:
|
||||
return True
|
||||
|
||||
if await CommonUtils.task_is_block(bot, "broadcast", group.group_id):
|
||||
return False
|
||||
|
||||
@ -304,54 +333,69 @@ class BroadcastManager:
|
||||
session_info: EventSession | str,
|
||||
group_list: list[GroupConsole],
|
||||
v11_nodes: list[dict],
|
||||
force_send: bool = False,
|
||||
concurrency_limit: int = 10,
|
||||
) -> BroadcastDetailResult:
|
||||
"""发送合并转发"""
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
skip_count = 0
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
msg_id_lock = asyncio.Lock()
|
||||
|
||||
for _, group in enumerate(group_list):
|
||||
async def send_to_group(group: GroupConsole) -> GroupConsole:
|
||||
group_key = group.group_id or group.channel_id
|
||||
async with semaphore:
|
||||
try:
|
||||
result = await bot.send_group_forward_msg(
|
||||
group_id=int(group.group_id), messages=v11_nodes
|
||||
)
|
||||
async with msg_id_lock:
|
||||
await cls._extract_message_id_from_result(
|
||||
result, group_key, session_info, "合并转发"
|
||||
)
|
||||
await asyncio.sleep(random.uniform(*BROADCAST_SEND_DELAY_RANGE))
|
||||
return group
|
||||
except (ActionFailed, AdapterException) as ae:
|
||||
logger.error(
|
||||
f"发送失败(合并转发) to {group_key}: {ae}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=ae,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送失败(合并转发) to {group_key}: {e}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=e,
|
||||
)
|
||||
raise
|
||||
|
||||
if not await cls._check_group_availability(bot, group):
|
||||
skip_count += 1
|
||||
continue
|
||||
tasks: list[asyncio.Task] = []
|
||||
skipped_groups: list[GroupConsole] = []
|
||||
for group in group_list:
|
||||
if await cls._check_group_availability(bot, group, force_send):
|
||||
tasks.append(asyncio.create_task(send_to_group(group)))
|
||||
else:
|
||||
skipped_groups.append(group)
|
||||
|
||||
try:
|
||||
result = await bot.send_group_forward_msg(
|
||||
group_id=int(group.group_id), messages=v11_nodes
|
||||
)
|
||||
if skipped_groups:
|
||||
logger.info(
|
||||
f"跳过 {len(skipped_groups)} 个不符合条件的群组",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"合并转发消息发送结果: {result}, 类型: {type(result)}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
if not tasks:
|
||||
return 0, 0, len(skipped_groups)
|
||||
|
||||
await cls._extract_message_id_from_result(
|
||||
result, group_key, session_info, "合并转发"
|
||||
)
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
success_count += 1
|
||||
await asyncio.sleep(random.randint(1, 3))
|
||||
except ActionFailed as af_e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"发送失败(合并转发) to {group_key}: {af_e}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=af_e,
|
||||
)
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"发送失败(合并转发) to {group_key}: {e}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=e,
|
||||
)
|
||||
success_count = sum(
|
||||
1 for result in results if not isinstance(result, Exception)
|
||||
)
|
||||
error_count = len(results) - success_count
|
||||
|
||||
return success_count, error_count, skip_count
|
||||
return success_count, error_count, len(skipped_groups)
|
||||
|
||||
@classmethod
|
||||
async def _broadcast_normal(
|
||||
@ -360,58 +404,83 @@ class BroadcastManager:
|
||||
session_info: EventSession | str,
|
||||
group_list: list[GroupConsole],
|
||||
message: UniMessage,
|
||||
force_send: bool = False,
|
||||
concurrency_limit: int = 10,
|
||||
) -> BroadcastDetailResult:
|
||||
"""发送普通消息"""
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
skip_count = 0
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
msg_id_lock = asyncio.Lock()
|
||||
|
||||
for _, group in enumerate(group_list):
|
||||
async def send_to_group(group: GroupConsole) -> GroupConsole:
|
||||
group_key = (
|
||||
f"{group.group_id}:{group.channel_id}"
|
||||
if group.channel_id
|
||||
else str(group.group_id)
|
||||
)
|
||||
|
||||
if not await cls._check_group_availability(bot, group):
|
||||
skip_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
target = PlatformUtils.get_target(
|
||||
group_id=group.group_id, channel_id=group.channel_id
|
||||
)
|
||||
|
||||
if target:
|
||||
receipt: Receipt = await message.send(target, bot=bot)
|
||||
|
||||
logger.debug(
|
||||
f"广播消息发送结果: {receipt}, 类型: {type(receipt)}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
await cls._extract_message_id_from_result(
|
||||
receipt, group_key, session_info
|
||||
)
|
||||
|
||||
success_count += 1
|
||||
await asyncio.sleep(random.randint(1, 3))
|
||||
else:
|
||||
logger.warning(
|
||||
"target为空", "广播", session=session_info, target=group_key
|
||||
)
|
||||
skip_count += 1
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"发送失败(普通) to {group_key}: {e}",
|
||||
target = PlatformUtils.get_target(
|
||||
group_id=group.group_id, channel_id=group.channel_id
|
||||
)
|
||||
if not target:
|
||||
logger.warning(
|
||||
"target为空",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=e,
|
||||
target=group_key,
|
||||
)
|
||||
raise ValueError(f"无法为群组 {group_key} 创建发送目标")
|
||||
|
||||
return success_count, error_count, skip_count
|
||||
async with semaphore:
|
||||
try:
|
||||
receipt: Receipt = await message.send(target, bot=bot)
|
||||
async with msg_id_lock:
|
||||
await cls._extract_message_id_from_result(
|
||||
receipt, group_key, session_info
|
||||
)
|
||||
await asyncio.sleep(random.uniform(*BROADCAST_SEND_DELAY_RANGE))
|
||||
return group
|
||||
except (ActionFailed, AdapterException) as ae:
|
||||
logger.error(
|
||||
f"发送失败(普通) to {group_key}: {ae}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=ae,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送失败(普通) to {group_key}: {e}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=e,
|
||||
)
|
||||
raise
|
||||
|
||||
tasks: list[asyncio.Task] = []
|
||||
skipped_groups: list[GroupConsole] = []
|
||||
for group in group_list:
|
||||
if await cls._check_group_availability(bot, group, force_send):
|
||||
tasks.append(asyncio.create_task(send_to_group(group)))
|
||||
else:
|
||||
skipped_groups.append(group)
|
||||
|
||||
if skipped_groups:
|
||||
logger.info(
|
||||
f"跳过 {len(skipped_groups)} 个不符合条件的群组",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
if not tasks:
|
||||
return 0, 0, len(skipped_groups)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
success_count = sum(
|
||||
1 for result in results if not isinstance(result, Exception)
|
||||
)
|
||||
error_count = len(results) - success_count
|
||||
|
||||
return success_count, error_count, len(skipped_groups)
|
||||
|
||||
@classmethod
|
||||
async def recall_last_broadcast(
|
||||
|
||||
@ -21,8 +21,11 @@ from nonebot_plugin_alconna.uniseg.segment import (
|
||||
from nonebot_plugin_alconna.uniseg.tools import reply_fetch
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.tags import tag_manager as TagManager
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from .broadcast_manager import BroadcastManager
|
||||
@ -399,22 +402,29 @@ async def _process_v11_segment(
|
||||
elif target_qq:
|
||||
result.append(At(flag="user", target=target_qq))
|
||||
elif seg_type == "video":
|
||||
video_seg = None
|
||||
if data_dict.get("url"):
|
||||
video_seg = Video(url=data_dict["url"])
|
||||
elif data_dict.get("file"):
|
||||
file_val = data_dict["file"]
|
||||
if url := data_dict.get("url"):
|
||||
try:
|
||||
logger.debug(f"[D{depth}] 正在下载视频用于广播: {url}", "广播")
|
||||
video_bytes = await AsyncHttpx.get_content(url)
|
||||
video_seg = Video(raw=video_bytes)
|
||||
logger.debug(
|
||||
f"[D{depth}] 视频下载成功, 大小: {len(video_bytes)} bytes",
|
||||
"广播",
|
||||
)
|
||||
result.append(video_seg)
|
||||
except Exception as e:
|
||||
logger.error(f"[D{depth}] 广播时下载视频失败: {url}", "广播", e=e)
|
||||
result.append(Text(f"[视频下载失败: {url}]"))
|
||||
elif file_val := data_dict.get("file"):
|
||||
if isinstance(file_val, str) and file_val.startswith("base64://"):
|
||||
b64_data = file_val[9:]
|
||||
raw_bytes = base64.b64decode(b64_data)
|
||||
video_seg = Video(raw=raw_bytes)
|
||||
result.append(video_seg)
|
||||
else:
|
||||
video_seg = Video(path=file_val)
|
||||
if video_seg:
|
||||
result.append(video_seg)
|
||||
logger.debug(f"[Depth {depth}] 处理视频消息成功", "广播")
|
||||
else:
|
||||
logger.warning(f"[Depth {depth}] V11 视频 {index} 缺少URL/文件", "广播")
|
||||
result.append(video_seg)
|
||||
return result
|
||||
elif seg_type == "forward":
|
||||
nested_forward_id = data_dict.get("id") or data_dict.get("resid")
|
||||
nested_forward_content = data_dict.get("content")
|
||||
@ -515,70 +525,129 @@ async def _extract_content_from_message(
|
||||
|
||||
|
||||
async def get_broadcast_target_groups(
|
||||
bot: Bot, session: EventSession
|
||||
bot: Bot,
|
||||
session: EventSession,
|
||||
tag_name: str | None = None,
|
||||
force_send: bool = False,
|
||||
) -> tuple[list, list]:
|
||||
"""获取广播目标群组和启用了广播功能的群组"""
|
||||
target_groups = []
|
||||
all_groups, _ = await BroadcastManager.get_all_groups(bot)
|
||||
target_groups_console: list[GroupConsole] = []
|
||||
|
||||
current_group_id = None
|
||||
if hasattr(session, "id2") and session.id2:
|
||||
current_group_id = session.id2
|
||||
current_group_raw = getattr(session, "id2", None) or getattr(
|
||||
session, "group_id", None
|
||||
)
|
||||
current_group_id = str(current_group_raw) if current_group_raw else None
|
||||
|
||||
if current_group_id:
|
||||
target_groups = [
|
||||
group for group in all_groups if group.group_id != current_group_id
|
||||
]
|
||||
logger.info(
|
||||
f"向除当前群组({current_group_id})外的所有群组广播", "广播", session=session
|
||||
)
|
||||
logger.debug(f"当前群组ID: {current_group_id}", "广播")
|
||||
|
||||
if tag_name:
|
||||
tagged_group_ids = await TagManager.resolve_tag_to_group_ids(tag_name, bot=bot)
|
||||
if not tagged_group_ids:
|
||||
return [], []
|
||||
|
||||
valid_groups = await GroupConsole.filter(group_id__in=tagged_group_ids)
|
||||
|
||||
if current_group_id:
|
||||
target_groups_console = [
|
||||
group
|
||||
for group in valid_groups
|
||||
if str(group.group_id) != current_group_id
|
||||
]
|
||||
excluded_msg = (
|
||||
f",已排除当前群组({current_group_id})"
|
||||
if any(
|
||||
str(group.group_id) == current_group_id for group in valid_groups
|
||||
)
|
||||
else ""
|
||||
)
|
||||
broadcast_msg = (
|
||||
f"向标签 '{tag_name}' 中的 {len(target_groups_console)} 个群组广播 "
|
||||
f"(ForceSend: {force_send}){excluded_msg}"
|
||||
)
|
||||
logger.info(broadcast_msg, "广播", session=session)
|
||||
else:
|
||||
target_groups_console = valid_groups
|
||||
broadcast_msg = (
|
||||
f"向标签 '{tag_name}' 中的 {len(target_groups_console)} 个群组广播 "
|
||||
f"(ForceSend: {force_send})"
|
||||
)
|
||||
logger.info(broadcast_msg, "广播", session=session)
|
||||
else:
|
||||
target_groups = all_groups
|
||||
logger.info("向所有群组广播", "广播", session=session)
|
||||
all_groups, _ = await BroadcastManager.get_all_groups(bot)
|
||||
|
||||
if not target_groups:
|
||||
await MessageUtils.build_message("没有找到符合条件的广播目标群组。").send(
|
||||
reply_to=True
|
||||
)
|
||||
if current_group_id:
|
||||
target_groups_console = [
|
||||
group for group in all_groups if str(group.group_id) != current_group_id
|
||||
]
|
||||
logger.info(
|
||||
(
|
||||
f"向除当前群组({current_group_id})外的所有群组广播 "
|
||||
f"(ForceSend: {force_send})"
|
||||
),
|
||||
"广播",
|
||||
session=session,
|
||||
)
|
||||
else:
|
||||
target_groups_console = all_groups
|
||||
logger.info(
|
||||
f"向所有群组广播 (ForceSend: {force_send})", "广播", session=session
|
||||
)
|
||||
|
||||
if not target_groups_console:
|
||||
if not tag_name:
|
||||
await MessageUtils.build_message("没有找到符合条件的广播目标群组。").send(
|
||||
reply_to=True
|
||||
)
|
||||
return [], []
|
||||
|
||||
enabled_groups = []
|
||||
for group in target_groups:
|
||||
if not await CommonUtils.task_is_block(bot, "broadcast", group.group_id):
|
||||
enabled_groups.append(group)
|
||||
groups_to_actually_send = []
|
||||
if force_send:
|
||||
groups_to_actually_send = target_groups_console
|
||||
logger.debug(
|
||||
f"强制发送模式,将向 {len(groups_to_actually_send)} 个目标群组尝试发送。",
|
||||
"广播",
|
||||
)
|
||||
else:
|
||||
for group in target_groups_console:
|
||||
if not await CommonUtils.task_is_block(bot, "broadcast", group.group_id):
|
||||
groups_to_actually_send.append(group)
|
||||
logger.debug(
|
||||
f"普通发送模式,筛选后将向 {len(groups_to_actually_send)} "
|
||||
f"个目标群组尝试发送",
|
||||
"广播",
|
||||
)
|
||||
|
||||
if not enabled_groups:
|
||||
await MessageUtils.build_message(
|
||||
"没有启用了广播功能的目标群组可供立即发送。"
|
||||
).send(reply_to=True)
|
||||
return target_groups, []
|
||||
|
||||
return target_groups, enabled_groups
|
||||
return target_groups_console, groups_to_actually_send
|
||||
|
||||
|
||||
async def send_broadcast_and_notify(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
message: UniMessage,
|
||||
enabled_groups: list,
|
||||
target_groups: list,
|
||||
groups_to_send: list,
|
||||
all_target_groups_for_stats: list,
|
||||
session: EventSession,
|
||||
force_send: bool = False,
|
||||
) -> None:
|
||||
"""发送广播并通知结果"""
|
||||
BroadcastManager.clear_last_broadcast_msg_ids()
|
||||
count, error_count = await BroadcastManager.send_to_specific_groups(
|
||||
bot, message, enabled_groups, session
|
||||
bot, message, groups_to_send, session, force_send
|
||||
)
|
||||
|
||||
result = f"成功广播 {count} 个群组"
|
||||
if error_count:
|
||||
result += f"\n发送失败 {error_count} 个群组"
|
||||
result += f"\n有效: {len(enabled_groups)} / 总计: {len(target_groups)}"
|
||||
|
||||
effective_sent_count = len(groups_to_send)
|
||||
total_considered_count = len(all_target_groups_for_stats)
|
||||
|
||||
result += f"\n有效: {effective_sent_count} / 总计目标: {total_considered_count}"
|
||||
|
||||
user_id = str(event.get_user_id())
|
||||
await bot.send_private_msg(user_id=user_id, message=f"发送广播完成!\n{result}")
|
||||
|
||||
BroadcastManager.log_info(
|
||||
f"广播完成,有效/总计: {len(enabled_groups)}/{len(target_groups)}",
|
||||
f"广播完成,有效/总计目标: {effective_sent_count}/{total_considered_count}",
|
||||
session,
|
||||
)
|
||||
|
||||
@ -59,7 +59,7 @@ def uni_segment_to_v11_segment_dict(
|
||||
logger.warning(f"无法处理 Video.raw 的类型: {type(raw_data)}", "广播")
|
||||
elif getattr(seg, "path", None):
|
||||
logger.warning(
|
||||
f"在合并转发中使用了本地视频路径,可能无法显示: {seg.path}", "广播"
|
||||
f"在合并转发中使用了本地视频路径,可能无法发送: {seg.path}", "广播"
|
||||
)
|
||||
return {"type": "video", "data": {"file": f"file:///{seg.path}"}}
|
||||
else:
|
||||
|
||||
581
zhenxun/builtin_plugins/superuser/plugin_config_manager.py
Normal file
581
zhenxun/builtin_plugins/superuser/plugin_config_manager.py
Normal file
@ -0,0 +1,581 @@
|
||||
from typing import Any
|
||||
|
||||
from arclet.alconna.typing import KeyWordVar
|
||||
import nonebot
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.compat import model_fields
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
MultiVar,
|
||||
Option,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
store_true,
|
||||
)
|
||||
from nonebot_plugin_session import EventSession
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.services import group_settings_service, renderer_service
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.tags import tag_manager
|
||||
from zhenxun.ui import builders as ui
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
from zhenxun.utils.pydantic_compat import parse_as
|
||||
from zhenxun.utils.rules import admin_check
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="插件配置管理",
|
||||
description="一个统一的命令,用于管理所有插件的分群配置",
|
||||
usage="""
|
||||
### ⚙️ 插件配置管理 (pconf)
|
||||
---
|
||||
一个统一的命令,用于管理所有插件的分群或全局配置。
|
||||
|
||||
#### **📖 命令格式**
|
||||
`pconf <子命令> [参数] [选项]`
|
||||
|
||||
#### **🎯 目标选项 (互斥)**
|
||||
- `-g, --group <群号...>`: 指定一个或多个群组ID **(SUPERUSER)**
|
||||
- `-t, --tag <标签名>`: 指定一个群组标签 **(SUPERUSER)**
|
||||
- `--all`: 对当前Bot所在的所有群组执行操作 **(SUPERUSER)**
|
||||
- `--global`: 操作全局配置 (config.yaml) **(SUPERUSER)**
|
||||
- **(无)**: 在群聊中操作时,默认目标为当前群。
|
||||
|
||||
#### **📋 子命令列表**
|
||||
* **`list` (或 `ls`)**: 查看列表
|
||||
* `pconf list`: 查看所有支持分群配置的插件。
|
||||
* `pconf list -p <插件名>`: 查看指定插件的所有分群可配置项。
|
||||
* `pconf list -p <插件名> --all`: 查看所有群组对该插件的配置。
|
||||
* `pconf list -p <插件名> --global`: 查看指定插件的全局可配置项。
|
||||
|
||||
* **`get <配置项>`**: 获取配置值
|
||||
* `pconf get <配置项> -p <插件名>`: 获取当前群的配置值。
|
||||
* `pconf get <配置项> -p <插件名> -g <群号>`: 获取指定群的配置值。
|
||||
|
||||
* **`set <key=value...>`**: 设置一个或多个配置值
|
||||
* `pconf set key1=value1 key2=value2 -p <插件名>`
|
||||
|
||||
* **`reset [配置项]`**: 重置配置为默认值
|
||||
* `pconf reset -p <插件名>`: 重置当前群该插件的所有配置。
|
||||
* `pconf reset <配置项> -p <插件名>`: 重置当前群该插件的指定配置项。
|
||||
""",
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="1.0",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
module="plugin_config_manager",
|
||||
key="PCONF_ADMIN_LEVEL",
|
||||
value=5,
|
||||
help="管理分群配置的基础权限等级",
|
||||
default_value=5,
|
||||
type=int,
|
||||
),
|
||||
RegisterConfig(
|
||||
module="plugin_config_manager",
|
||||
key="SHOW_DEFAULT_CONFIG_IN_ALL",
|
||||
value=False,
|
||||
help="在使用 --all 查询时,是否显示配置为默认值的群组",
|
||||
default_value=False,
|
||||
type=bool,
|
||||
),
|
||||
],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
|
||||
pconf_cmd = on_alconna(
|
||||
Alconna(
|
||||
"pconf",
|
||||
Subcommand(
|
||||
"list",
|
||||
alias=["ls"],
|
||||
help_text="查看插件或配置项列表",
|
||||
),
|
||||
Subcommand(
|
||||
"get",
|
||||
Args["key", str],
|
||||
help_text="获取配置值",
|
||||
),
|
||||
Subcommand(
|
||||
"set",
|
||||
Args["settings", MultiVar(KeyWordVar(Any))],
|
||||
help_text="设置配置值",
|
||||
),
|
||||
Subcommand(
|
||||
"reset",
|
||||
Args["key?", str],
|
||||
help_text="重置配置",
|
||||
),
|
||||
Option("-p|--plugin", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g|--group", Args["group_ids", MultiVar(str)], help_text="指定群组ID"),
|
||||
Option("-t|--tag", Args["tag_name", str], help_text="指定群组标签"),
|
||||
Option("--all", action=store_true, help_text="操作所有群组"),
|
||||
Option("--global", action=store_true, help_text="操作全局配置"),
|
||||
),
|
||||
rule=admin_check("plugin_config_manager", "PCONF_ADMIN_LEVEL"),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
|
||||
async def get_plugin_config_model(plugin_name: str) -> type[BaseModel] | None:
|
||||
"""通过插件名查找其注册的分群配置模型"""
|
||||
for p in nonebot.get_loaded_plugins():
|
||||
if p.name == plugin_name and p.metadata and p.metadata.extra:
|
||||
extra = PluginExtraData(**p.metadata.extra)
|
||||
if extra.group_config_model:
|
||||
return extra.group_config_model
|
||||
return None
|
||||
|
||||
|
||||
def truncate_text(text: str, max_len: int) -> str:
|
||||
"""截断文本,过长时添加省略号"""
|
||||
if len(text) > max_len:
|
||||
return text[: max_len - 3] + "..."
|
||||
return text
|
||||
|
||||
|
||||
async def GetTargets(
|
||||
bot: Bot, event: Event, session: EventSession, arp: Arparma
|
||||
) -> list[str]:
|
||||
"""
|
||||
依赖注入,根据 -g, -t, --all 或当前会话解析目标群组ID列表,并进行权限检查。
|
||||
"""
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
|
||||
if group_ids_match := arp.query[list[str]]("group.group_ids"):
|
||||
if not is_superuser:
|
||||
logger.warning(f"非超级用户 {session.id1} 尝试使用 -g 参数。")
|
||||
raise SkippedException("权限不足")
|
||||
return group_ids_match
|
||||
|
||||
if tag_name_match := arp.query[str]("tag.tag_name"):
|
||||
if not is_superuser:
|
||||
logger.warning(f"非超级用户 {session.id1} 尝试使用 -t 参数。")
|
||||
raise SkippedException("权限不足")
|
||||
|
||||
resolved_groups = await tag_manager.resolve_tag_to_group_ids(
|
||||
tag_name_match, bot=bot
|
||||
)
|
||||
if not resolved_groups:
|
||||
await pconf_cmd.finish(f"标签 '{tag_name_match}' 没有匹配到任何群组。")
|
||||
return resolved_groups
|
||||
|
||||
if arp.find("all"):
|
||||
if not is_superuser:
|
||||
logger.warning(f"非超级用户 {session.id1} 尝试使用 --all 参数。")
|
||||
raise SkippedException("权限不足")
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
all_groups, _ = await PlatformUtils.get_group_list(bot)
|
||||
return [g.group_id for g in all_groups]
|
||||
|
||||
if gid := session.id3 or session.id2:
|
||||
return [gid]
|
||||
|
||||
if not is_superuser:
|
||||
logger.warning(f"管理员 {session.id1} 尝试在私聊中操作分群配置。")
|
||||
raise SkippedException("权限不足")
|
||||
|
||||
await pconf_cmd.finish(
|
||||
"超级用户在私聊中操作时,必须使用 -g <群号>、-t <标签名> 或 --all 指定目标群组"
|
||||
)
|
||||
|
||||
|
||||
@pconf_cmd.assign("list")
|
||||
async def handle_list(arp: Arparma, bot: Bot, event: Event):
|
||||
"""处理 list 子命令"""
|
||||
plugin_name_str = None
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
if arp.find("plugin"):
|
||||
plugin_name_str = arp.query[str]("plugin.plugin_name")
|
||||
|
||||
if plugin_name_str:
|
||||
is_global = arp.find("global")
|
||||
is_all_groups = arp.find("all")
|
||||
|
||||
if is_all_groups and not is_global:
|
||||
if not is_superuser:
|
||||
await MessageUtils.build_message(
|
||||
"只有超级用户才能查看所有群的配置。"
|
||||
).finish()
|
||||
|
||||
model = await get_plugin_config_model(plugin_name_str)
|
||||
model_fields_list = model_fields(model) if model else []
|
||||
if not model_fields_list:
|
||||
await MessageUtils.build_message(
|
||||
f"插件 '{plugin_name_str}' 不支持分群配置。"
|
||||
).finish()
|
||||
|
||||
all_groups, _ = await PlatformUtils.get_group_list(bot)
|
||||
if not all_groups:
|
||||
await MessageUtils.build_message("机器人未加入任何群组。").finish()
|
||||
|
||||
model_fields_dict = {field.name: field for field in model_fields_list}
|
||||
config_keys = list(model_fields_dict.keys())
|
||||
headers = ["群号", "群名称", *config_keys]
|
||||
rows = []
|
||||
|
||||
for group in all_groups:
|
||||
settings_dict = await group_settings_service.get_all_for_plugin(
|
||||
group.group_id, plugin_name_str
|
||||
)
|
||||
row_data = [group.group_id, truncate_text(group.group_name, 10)]
|
||||
for key in config_keys:
|
||||
value = settings_dict.get(key)
|
||||
default_value = model_fields_dict[key].field_info.default
|
||||
|
||||
if value == default_value:
|
||||
value_str = "默认"
|
||||
else:
|
||||
value_str = str(value) if value is not None else "N/A"
|
||||
|
||||
row_data.append(truncate_text(value_str, 20))
|
||||
|
||||
show_default = Config.get_config(
|
||||
"plugin_config_manager", "SHOW_DEFAULT_CONFIG_IN_ALL", False
|
||||
)
|
||||
if not show_default:
|
||||
is_all_default = all(val == "默认" for val in row_data[2:])
|
||||
if is_all_default:
|
||||
continue
|
||||
|
||||
rows.append(row_data)
|
||||
|
||||
builder = ui.TableBuilder(
|
||||
title=f"插件 '{plugin_name_str}' 全群配置",
|
||||
tip=f"共查询 {len(rows)} 个群组",
|
||||
)
|
||||
builder.set_headers(headers).add_rows(rows)
|
||||
|
||||
viewport_width = 300 + len(config_keys) * 280
|
||||
img = await renderer_service.render(
|
||||
builder.build(), viewport={"width": viewport_width, "height": 10}
|
||||
)
|
||||
await MessageUtils.build_message(img).finish()
|
||||
|
||||
if is_global:
|
||||
if not is_superuser:
|
||||
await MessageUtils.build_message(
|
||||
"只有超级用户才能查看全局配置。"
|
||||
).finish()
|
||||
config_group = Config.get(plugin_name_str)
|
||||
if not config_group or not config_group.configs:
|
||||
await MessageUtils.build_message(
|
||||
f"插件 '{plugin_name_str}' 没有可配置的全局项。"
|
||||
).finish()
|
||||
|
||||
builder = ui.TableBuilder(
|
||||
title=f"插件 '{plugin_name_str}' 全局可配置项",
|
||||
tip=(
|
||||
f"位于 config.yaml, 使用 pconf set <key>=<value> "
|
||||
f"-p {plugin_name_str} --global 进行设置"
|
||||
),
|
||||
)
|
||||
builder.set_headers(["配置项", "当前值", "类型", "描述"])
|
||||
|
||||
for key, config_model in config_group.configs.items():
|
||||
type_name = getattr(
|
||||
config_model.type, "__name__", str(config_model.type)
|
||||
)
|
||||
builder.add_row(
|
||||
[
|
||||
key,
|
||||
truncate_text(str(config_model.value), 20),
|
||||
type_name,
|
||||
truncate_text(config_model.help or "无", 20),
|
||||
]
|
||||
)
|
||||
|
||||
img = await renderer_service.render(builder.build())
|
||||
await MessageUtils.build_message(img).finish()
|
||||
else:
|
||||
model = await get_plugin_config_model(plugin_name_str)
|
||||
model_fields_list = model_fields(model) if model else []
|
||||
if not model_fields_list:
|
||||
await MessageUtils.build_message(
|
||||
f"插件 '{plugin_name_str}' 不支持分群配置。"
|
||||
).finish()
|
||||
|
||||
builder = ui.TableBuilder(
|
||||
title=f"插件 '{plugin_name_str}' 可配置项",
|
||||
tip=f"使用 pconf set <key>=<value> -p {plugin_name_str} 进行设置",
|
||||
)
|
||||
builder.set_headers(["配置项", "类型", "描述", "默认值"])
|
||||
|
||||
for field in model_fields_list:
|
||||
type_name = getattr(field.annotation, "__name__", str(field.annotation))
|
||||
description = field.field_info.description or "无"
|
||||
default_value = (
|
||||
str(field.get_default())
|
||||
if field.field_info.default is not None
|
||||
else "无"
|
||||
)
|
||||
builder.add_row([field.name, type_name, description, default_value])
|
||||
|
||||
img = await renderer_service.render(builder.build())
|
||||
await MessageUtils.build_message(img).finish()
|
||||
|
||||
else:
|
||||
configurable_plugins = []
|
||||
for p in nonebot.get_loaded_plugins():
|
||||
if p.metadata and p.metadata.extra:
|
||||
extra = PluginExtraData(**p.metadata.extra)
|
||||
if extra.group_config_model:
|
||||
configurable_plugins.append(p.name)
|
||||
|
||||
if not configurable_plugins:
|
||||
await MessageUtils.build_message("当前没有插件支持分群配置。").finish()
|
||||
|
||||
await MessageUtils.build_message(
|
||||
"支持分群配置的插件列表:\n"
|
||||
+ "\n".join(f"- {name}" for name in configurable_plugins)
|
||||
).finish()
|
||||
|
||||
|
||||
@pconf_cmd.assign("get")
|
||||
async def handle_get(
|
||||
arp: Arparma,
|
||||
key: Match[str],
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
):
|
||||
if not arp.find("plugin"):
|
||||
await pconf_cmd.finish("必须使用 -p <插件名> 指定要操作的插件。")
|
||||
plugin_name_str = arp.query[str]("plugin.plugin_name")
|
||||
if not plugin_name_str:
|
||||
await pconf_cmd.finish("插件名不能为空。")
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
|
||||
if arp.find("global"):
|
||||
if not is_superuser:
|
||||
await MessageUtils.build_message("只有超级用户才能获取全局配置。").finish()
|
||||
value = Config.get_config(plugin_name_str, key.result)
|
||||
await MessageUtils.build_message(
|
||||
f"全局配置项 '{key.result}' 的值为: {value}"
|
||||
).finish()
|
||||
else:
|
||||
target_group_ids = await GetTargets(bot, event, session, arp)
|
||||
target_group_id = target_group_ids[0]
|
||||
value = await group_settings_service.get(
|
||||
target_group_id, plugin_name_str, key.result
|
||||
)
|
||||
await MessageUtils.build_message(
|
||||
f"群组 {target_group_id} 的配置项 '{key.result}' 的值为: {value}"
|
||||
).finish()
|
||||
|
||||
|
||||
@pconf_cmd.assign("set")
|
||||
async def handle_set(
|
||||
arp: Arparma,
|
||||
settings: Match[dict],
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
):
|
||||
if not arp.find("plugin"):
|
||||
await pconf_cmd.finish("必须使用 -p <插件名> 指定要操作的插件。")
|
||||
plugin_name_str = arp.query[str]("plugin.plugin_name")
|
||||
if not plugin_name_str:
|
||||
await pconf_cmd.finish("插件名不能为空。")
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
|
||||
is_global = arp.find("global")
|
||||
|
||||
if is_global:
|
||||
if not is_superuser:
|
||||
await MessageUtils.build_message("只有超级用户才能设置全局配置。").finish()
|
||||
config_group = Config.get(plugin_name_str)
|
||||
if not config_group or not config_group.configs:
|
||||
await MessageUtils.build_message(
|
||||
f"插件 '{plugin_name_str}' 没有可配置的全局项。"
|
||||
).finish()
|
||||
|
||||
changes_made = False
|
||||
success_messages = []
|
||||
for key, value_str in settings.result.items():
|
||||
config_model = config_group.configs.get(key.upper())
|
||||
if not config_model:
|
||||
await MessageUtils.build_message(
|
||||
f"❌ 全局配置项 '{key}' 不存在。"
|
||||
).send()
|
||||
continue
|
||||
|
||||
target_type = config_model.type
|
||||
if target_type is None:
|
||||
if config_model.default_value is not None:
|
||||
target_type = type(config_model.default_value)
|
||||
elif config_model.value is not None:
|
||||
target_type = type(config_model.value)
|
||||
|
||||
converted_value: Any = value_str
|
||||
if target_type and value_str is not None:
|
||||
try:
|
||||
converted_value = parse_as(target_type, value_str)
|
||||
except (ValidationError, TypeError, ValueError) as e:
|
||||
type_name = getattr(target_type, "__name__", str(target_type))
|
||||
await MessageUtils.build_message(
|
||||
f"❌ 配置项 '{key}' 的值 '{value_str}' "
|
||||
f"无法转换为期望的类型 '{type_name}': {e}"
|
||||
).send()
|
||||
continue
|
||||
|
||||
Config.set_config(plugin_name_str, key.upper(), converted_value)
|
||||
success_messages.append(f" - 配置项 '{key}' 已设置为: `{converted_value}`")
|
||||
changes_made = True
|
||||
|
||||
if changes_made:
|
||||
Config.save(save_simple_data=True)
|
||||
response_msg = (
|
||||
f"✅ 插件 '{plugin_name_str}' 的全局配置已更新:\n"
|
||||
+ "\n".join(success_messages)
|
||||
)
|
||||
await MessageUtils.build_message(response_msg).finish()
|
||||
else:
|
||||
model = await get_plugin_config_model(plugin_name_str)
|
||||
if not model:
|
||||
await MessageUtils.build_message(
|
||||
f"插件 '{plugin_name_str}' 不支持分群配置。"
|
||||
).finish()
|
||||
|
||||
target_group_ids = await GetTargets(bot, event, session, arp)
|
||||
model_fields_map = {field.name: field for field in model_fields(model)}
|
||||
|
||||
success_groups = []
|
||||
failed_groups = []
|
||||
update_details = []
|
||||
|
||||
for group_id in target_group_ids:
|
||||
for key, value_str in settings.result.items():
|
||||
field = model_fields_map.get(key)
|
||||
if not field:
|
||||
await MessageUtils.build_message(
|
||||
f"配置项 '{key}' 在插件 '{plugin_name_str}' 中不存在。"
|
||||
).finish()
|
||||
|
||||
try:
|
||||
validated_value = (
|
||||
parse_as(field.annotation, value_str)
|
||||
if field.annotation is not None
|
||||
else value_str
|
||||
)
|
||||
await group_settings_service.set_key_value(
|
||||
group_id, plugin_name_str, key, validated_value
|
||||
)
|
||||
if group_id not in success_groups:
|
||||
success_groups.append(group_id)
|
||||
|
||||
if (key, validated_value) not in update_details:
|
||||
update_details.append((key, validated_value))
|
||||
except (ValidationError, TypeError, ValueError) as e:
|
||||
failed_groups.append(
|
||||
(group_id, f"配置项 '{key}' 值 '{value_str}' 类型错误: {e}")
|
||||
)
|
||||
except Exception as e:
|
||||
failed_groups.append((group_id, f"内部错误: {e}"))
|
||||
|
||||
if len(target_group_ids) == 1:
|
||||
group_id = target_group_ids[0]
|
||||
if group_id in success_groups and group_id not in [
|
||||
g[0] for g in failed_groups
|
||||
]:
|
||||
settings_summary = [
|
||||
f" - '{k}' 已设置为: `{v}`" for k, v in update_details
|
||||
]
|
||||
msg = (
|
||||
f"✅ 群组 {group_id} 插件 '{plugin_name_str}' 配置更新成功:\n"
|
||||
+ "\n".join(settings_summary)
|
||||
)
|
||||
else:
|
||||
errors = [f[1] for f in failed_groups if f[0] == group_id]
|
||||
msg = (
|
||||
f"❌ 群组 {group_id} 插件 '{plugin_name_str}' 配置更新失败:\n"
|
||||
+ "\n".join(errors)
|
||||
)
|
||||
else:
|
||||
settings_count = len(settings.result)
|
||||
msg = (
|
||||
f"✅ 批量为 {len(success_groups)} 个群组设置了 "
|
||||
f"{settings_count} 个配置项。"
|
||||
)
|
||||
if failed_groups:
|
||||
failed_count = len({g[0] for g in failed_groups})
|
||||
msg += f"\n❌ 其中 {failed_count} 个群组部分或全部设置失败。"
|
||||
|
||||
await MessageUtils.build_message(msg).finish()
|
||||
|
||||
|
||||
@pconf_cmd.assign("reset")
|
||||
async def handle_reset(
|
||||
arp: Arparma,
|
||||
key: Match[str],
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
):
|
||||
if not arp.find("plugin"):
|
||||
await pconf_cmd.finish("必须使用 -p <插件名> 指定要操作的插件。")
|
||||
plugin_name_str = arp.query[str]("plugin.plugin_name")
|
||||
if not plugin_name_str:
|
||||
await pconf_cmd.finish("插件名不能为空。")
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
|
||||
if arp.find("global"):
|
||||
if not is_superuser:
|
||||
await MessageUtils.build_message("只有超级用户才能重置全局配置。").finish()
|
||||
await MessageUtils.build_message("全局配置重置功能暂未实现。").finish()
|
||||
else:
|
||||
target_group_ids = await GetTargets(bot, event, session, arp)
|
||||
key_str = key.result if key.available else None
|
||||
|
||||
success_groups = []
|
||||
failed_groups = []
|
||||
|
||||
for group_id in target_group_ids:
|
||||
try:
|
||||
if key_str:
|
||||
await group_settings_service.reset_key(
|
||||
group_id, plugin_name_str, key_str
|
||||
)
|
||||
else:
|
||||
await group_settings_service.reset_all_for_plugin(
|
||||
group_id, plugin_name_str
|
||||
)
|
||||
success_groups.append(group_id)
|
||||
except Exception as e:
|
||||
failed_groups.append((group_id, str(e)))
|
||||
|
||||
action = f"配置项 '{key_str}'" if key_str else "所有配置"
|
||||
|
||||
if len(target_group_ids) == 1:
|
||||
if success_groups:
|
||||
msg = (
|
||||
f"✅ 群组 {target_group_ids[0]} 中插件 '{plugin_name_str}' "
|
||||
f"的 {action} 已成功重置。"
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f"❌ 群组 {target_group_ids[0]} 中插件 '{plugin_name_str}' "
|
||||
f"的 {action} 重置失败: {failed_groups[0][1]}"
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f"✅ 批量操作完成: 成功为 {len(success_groups)} 个群组重置了 {action}。"
|
||||
)
|
||||
if failed_groups:
|
||||
failed_count = len({g[0] for g in failed_groups})
|
||||
msg += f"\n❌ 其中 {failed_count} 个群组操作失败。"
|
||||
await MessageUtils.build_message(msg).finish()
|
||||
@ -176,12 +176,30 @@ tag_cmd = on_alconna(
|
||||
help_text="删除标签",
|
||||
),
|
||||
Subcommand("clear", help_text="清空所有标签"),
|
||||
Subcommand("prune", alias=["check", "清理"], help_text="清理无效的群组关联"),
|
||||
Subcommand(
|
||||
"clone",
|
||||
Args["source_name", str]["new_name", str],
|
||||
Option("--add", Args["add_groups", MultiVar(str)]),
|
||||
Option("--remove", Args["remove_groups", MultiVar(str)]),
|
||||
Option("--as-dynamic", action=store_true),
|
||||
Option("--desc", Args["description", str]),
|
||||
Option("--mode", Args["mode", ["black", "white"]]),
|
||||
help_text="克隆标签",
|
||||
),
|
||||
),
|
||||
permission=SUPERUSER,
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
tag_cmd.shortcut(
|
||||
"清理标签",
|
||||
command="tag",
|
||||
arguments=["prune"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
|
||||
@tag_cmd.assign("list")
|
||||
async def handle_list():
|
||||
@ -269,17 +287,24 @@ async def handle_create(
|
||||
).finish()
|
||||
|
||||
try:
|
||||
gids_to_create = None
|
||||
unique_gids_count = 0
|
||||
if group_ids.available:
|
||||
unique_gids = list(dict.fromkeys(group_ids.result))
|
||||
gids_to_create = unique_gids
|
||||
unique_gids_count = len(unique_gids)
|
||||
|
||||
tag = await tag_manager.create_tag(
|
||||
name=name.result,
|
||||
is_blacklist=blacklist.result,
|
||||
description=description.result if description.available else None,
|
||||
group_ids=group_ids.result if group_ids.available else None,
|
||||
group_ids=gids_to_create,
|
||||
tag_type=ttype,
|
||||
dynamic_rule=rule.result if rule.available else None,
|
||||
)
|
||||
msg = f"标签 '{tag.name}' 创建成功!"
|
||||
if group_ids.available:
|
||||
msg += f"\n已同时关联 {len(group_ids.result)} 个群组。"
|
||||
msg += f"\n已同时关联 {unique_gids_count} 个群组。"
|
||||
await MessageUtils.build_message(msg).finish()
|
||||
except IntegrityError:
|
||||
await MessageUtils.build_message(
|
||||
@ -411,3 +436,48 @@ async def handle_clear():
|
||||
await MessageUtils.build_message(f"操作完成,已清空 {count} 个标签。").finish()
|
||||
else:
|
||||
await MessageUtils.build_message("操作已取消。").finish()
|
||||
|
||||
|
||||
@tag_cmd.assign("clone")
|
||||
async def handle_clone(
|
||||
bot: Bot,
|
||||
source_name: Match[str],
|
||||
new_name: Match[str],
|
||||
add_groups: Query[list[str] | None] = AlconnaQuery("clone.add.add_groups", None),
|
||||
remove_groups: Query[list[str] | None] = AlconnaQuery(
|
||||
"clone.remove.remove_groups", None
|
||||
),
|
||||
as_dynamic: Query[bool] = AlconnaQuery("clone.as-dynamic.value", False),
|
||||
description: Query[str | None] = AlconnaQuery("clone.desc.description", None),
|
||||
mode: Query[str | None] = AlconnaQuery("clone.mode.mode", None),
|
||||
):
|
||||
try:
|
||||
new_tag = await tag_manager.clone_tag(
|
||||
source_name=source_name.result,
|
||||
new_name=new_name.result,
|
||||
bot=bot,
|
||||
add_groups=add_groups.result,
|
||||
remove_groups=remove_groups.result,
|
||||
as_dynamic=as_dynamic.result,
|
||||
description=description.result,
|
||||
mode=mode.result,
|
||||
)
|
||||
|
||||
tag_type_str = "动态" if new_tag.tag_type == "DYNAMIC" else "静态"
|
||||
group_count = 0
|
||||
if new_tag.tag_type == "STATIC":
|
||||
group_count = await new_tag.groups.all().count()
|
||||
|
||||
msg = f"✅ 成功克隆标签!\n- 新标签: {new_tag.name}\n- 类型: {tag_type_str}"
|
||||
if new_tag.tag_type == "STATIC":
|
||||
msg += f" (含 {group_count} 个群组)"
|
||||
await MessageUtils.build_message(msg).finish()
|
||||
except (ValueError, IntegrityError) as e:
|
||||
await MessageUtils.build_message(f"克隆失败: {e}").finish()
|
||||
|
||||
|
||||
@tag_cmd.assign("prune")
|
||||
async def handle_prune():
|
||||
deleted_count = await tag_manager.prune_stale_group_links()
|
||||
msg = f"清理完成!共移除了 {deleted_count} 个无效的群组关联。"
|
||||
await MessageUtils.build_message(msg).finish()
|
||||
|
||||
@ -270,3 +270,9 @@ class PluginExtraData(BaseModel):
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
|
||||
group_config_model: type[BaseModel] | None = None
|
||||
"""插件的分群配置模型"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
29
zhenxun/models/group_plugin_setting.py
Normal file
29
zhenxun/models/group_plugin_setting.py
Normal file
@ -0,0 +1,29 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class GroupPluginSetting(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增ID"""
|
||||
group_id = fields.CharField(max_length=255, indexed=True, description="群组ID")
|
||||
"""群组ID"""
|
||||
plugin_name = fields.CharField(
|
||||
max_length=255, indexed=True, description="插件模块名"
|
||||
)
|
||||
"""插件模块名"""
|
||||
settings = fields.JSONField(description="插件的完整配置 (JSON)")
|
||||
"""插件的完整配置 (JSON)"""
|
||||
updated_at = fields.DatetimeField(auto_now=True, description="最后更新时间")
|
||||
"""最后更新时间"""
|
||||
|
||||
cache_type = CacheType.GROUP_PLUGIN_SETTINGS
|
||||
"""缓存类型"""
|
||||
cache_key_field = ("group_id", "plugin_name")
|
||||
"""缓存键字段"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "group_plugin_settings"
|
||||
table_description = "插件分群通用配置表"
|
||||
unique_together = ("group_id", "plugin_name")
|
||||
@ -20,6 +20,7 @@ require("nonebot_plugin_waiter")
|
||||
|
||||
from .avatar_service import avatar_service
|
||||
from .db_context import Model, disconnect, with_db_timeout
|
||||
from .group_settings_service import group_settings_service
|
||||
from .llm import (
|
||||
AI,
|
||||
AIConfig,
|
||||
@ -77,6 +78,7 @@ __all__ = [
|
||||
"generate_structured",
|
||||
"get_cache_stats",
|
||||
"get_model_instance",
|
||||
"group_settings_service",
|
||||
"list_available_models",
|
||||
"list_embedding_models",
|
||||
"logger",
|
||||
|
||||
223
zhenxun/services/group_settings_service.py
Normal file
223
zhenxun/services/group_settings_service.py
Normal file
@ -0,0 +1,223 @@
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
import ujson as json
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.group_plugin_setting import GroupPluginSetting
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.services.data_access import DataAccess
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.pydantic_compat import model_dump, model_validate, parse_as
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class GroupSettingsService:
|
||||
"""
|
||||
一个用于管理插件分群配置的服务。
|
||||
集成了聚合缓存、批量操作和版本迁移功能。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.dao = DataAccess(GroupPluginSetting)
|
||||
self._cache = Cache[dict]("group_plugin_settings")
|
||||
|
||||
async def set(
|
||||
self, group_id: str, plugin_name: str, settings_model: BaseModel
|
||||
) -> None:
|
||||
"""
|
||||
为一个插件在指定群组中设置完整的配置模型。
|
||||
|
||||
参数:
|
||||
group_id: 目标群组ID。
|
||||
plugin_name: 插件的模块名。
|
||||
settings_model: 包含完整配置的Pydantic模型实例。
|
||||
"""
|
||||
settings_dict = model_dump(settings_model)
|
||||
json_value = json.dumps(settings_dict, ensure_ascii=False)
|
||||
|
||||
await self.dao.update_or_create(
|
||||
defaults={"settings": json_value}, # type: ignore
|
||||
group_id=group_id,
|
||||
plugin_name=plugin_name,
|
||||
)
|
||||
|
||||
await self.dao.clear_cache(group_id=group_id, plugin_name=plugin_name)
|
||||
|
||||
async def set_key_value(
|
||||
self, group_id: str, plugin_name: str, key: str, value: Any
|
||||
) -> None:
|
||||
"""为一个插件在指定群组中设置单个配置项的值。"""
|
||||
setting_entry, _ = await GroupPluginSetting.get_or_create(
|
||||
defaults={"settings": {}},
|
||||
group_id=group_id,
|
||||
plugin_name=plugin_name,
|
||||
)
|
||||
|
||||
if not isinstance(setting_entry.settings, dict):
|
||||
setting_entry.settings = {}
|
||||
|
||||
setting_entry.settings[key] = value
|
||||
await setting_entry.save(update_fields=["settings"])
|
||||
await self.dao.clear_cache(group_id=group_id, plugin_name=plugin_name)
|
||||
|
||||
async def reset_key(self, group_id: str, plugin_name: str, key: str) -> bool:
|
||||
"""重置单个配置项"""
|
||||
setting = await self.dao.get_or_none(group_id=group_id, plugin_name=plugin_name)
|
||||
if setting and isinstance(setting.settings, dict) and key in setting.settings:
|
||||
del setting.settings[key]
|
||||
if not setting.settings:
|
||||
await setting.delete()
|
||||
else:
|
||||
await setting.save(update_fields=["settings"])
|
||||
await self.dao.clear_cache(group_id=group_id, plugin_name=plugin_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get(
|
||||
self, group_id: str, plugin_name: str, key: str, default: Any = None
|
||||
) -> Any:
|
||||
"""
|
||||
获取一个分群配置项的值,如果群组未单独设置,则回退到全局默认值。
|
||||
|
||||
参数:
|
||||
group_id: 目标群组ID。
|
||||
plugin_name: 插件的模块名。
|
||||
key: 配置项的键。
|
||||
default: 如果找不到配置项,返回的默认值。
|
||||
|
||||
返回:
|
||||
配置项的值。
|
||||
"""
|
||||
full_settings = await self.get_all_for_plugin(group_id, plugin_name)
|
||||
return full_settings.get(key, default)
|
||||
|
||||
async def reset_all_for_plugin(self, group_id: str, plugin_name: str) -> bool:
|
||||
"""
|
||||
重置一个插件在指定群组的配置,使其回退到全局默认值。
|
||||
这通过删除数据库中的对应记录来实现。
|
||||
|
||||
参数:
|
||||
group_id: 目标群组ID。
|
||||
plugin_name: 插件的模块名。
|
||||
|
||||
返回:
|
||||
bool: 如果成功删除了一个条目,则返回 True,否则返回 False。
|
||||
"""
|
||||
deleted_count = await self.dao.delete(
|
||||
group_id=group_id, plugin_name=plugin_name
|
||||
)
|
||||
|
||||
if deleted_count > 0:
|
||||
await self.dao.clear_cache(group_id=group_id, plugin_name=plugin_name)
|
||||
logger.debug(f"已重置插件 '{plugin_name}' 在群组 '{group_id}' 的配置。")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@overload
|
||||
async def get_all_for_plugin(
|
||||
self, group_id: str, plugin_name: str, *, parse_model: type[T]
|
||||
) -> T: ...
|
||||
|
||||
@overload
|
||||
async def get_all_for_plugin(
|
||||
self, group_id: str, plugin_name: str, *, parse_model: None = None
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
async def get_all_for_plugin(
|
||||
self, group_id: str, plugin_name: str, *, parse_model: type[T] | None = None
|
||||
) -> T | dict[str, Any]:
|
||||
"""
|
||||
获取一个插件在指定群组中的完整配置,应用了“继承与覆盖”逻辑。
|
||||
它首先获取全局默认配置,然后用数据库中存储的群组特定配置覆盖它。
|
||||
|
||||
参数:
|
||||
group_id: 目标群组ID。
|
||||
plugin_name: 插件的模块名。
|
||||
parse_model: (可选) Pydantic模型,用于解析和验证配置。
|
||||
"""
|
||||
cache_key = f"{group_id}:{plugin_name}"
|
||||
cached_settings = await self._cache.get(cache_key)
|
||||
if cached_settings is not None:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
if parse_model:
|
||||
try:
|
||||
return parse_as(parse_model, cached_settings)
|
||||
except (ValidationError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"缓存数据 '{cache_key}' 与模型 '{parse_model.__name__}' "
|
||||
f"不匹配: {e}。将从数据库重新加载。"
|
||||
)
|
||||
else:
|
||||
return cached_settings
|
||||
|
||||
logger.debug(f"缓存未命中: {cache_key},从数据库加载。")
|
||||
|
||||
global_config_group = Config.get(plugin_name)
|
||||
final_settings_dict = {
|
||||
key: global_config_group.get(key, build_model=False)
|
||||
for key in global_config_group.configs.keys()
|
||||
}
|
||||
|
||||
group_setting_entry = await self.dao.get_or_none(
|
||||
group_id=group_id, plugin_name=plugin_name
|
||||
)
|
||||
if group_setting_entry:
|
||||
try:
|
||||
group_specific_settings = group_setting_entry.settings
|
||||
if isinstance(group_specific_settings, dict):
|
||||
final_settings_dict.update(group_specific_settings)
|
||||
else:
|
||||
logger.warning(
|
||||
f"群组 {group_id} 插件 '{plugin_name}' 的配置格式不正确"
|
||||
f"(不是字典),已忽略。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"加载群组 {group_id} 插件 '{plugin_name}' 的特定配置时出错: {e}"
|
||||
)
|
||||
|
||||
await self._cache.set(cache_key, final_settings_dict)
|
||||
|
||||
if parse_model:
|
||||
try:
|
||||
return parse_as(parse_model, final_settings_dict)
|
||||
except (ValidationError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"插件 '{plugin_name}' 的配置无法解析为 '{parse_model.__name__}'。"
|
||||
f"值: {final_settings_dict}, 错误: {e}。将返回一个默认模型实例。"
|
||||
)
|
||||
return parse_as(parse_model, {})
|
||||
|
||||
return final_settings_dict
|
||||
|
||||
async def set_bulk(
|
||||
self, group_ids: list[str], plugin_name: str, key: str, value: Any
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
为多个群组批量设置同一个配置项。
|
||||
|
||||
参数:
|
||||
group_ids: 目标群组ID列表。
|
||||
plugin_name: 插件模块名。
|
||||
key: 配置项的键。
|
||||
value: 要设置的值。
|
||||
|
||||
返回:
|
||||
一个元组 (updated_count, created_count)。
|
||||
"""
|
||||
if not group_ids:
|
||||
return 0, 0
|
||||
|
||||
for group_id in group_ids:
|
||||
current_settings = await self.get_all_for_plugin(group_id, plugin_name)
|
||||
current_settings[key] = value
|
||||
await self.set(
|
||||
group_id, plugin_name, model_validate(BaseModel, current_settings)
|
||||
)
|
||||
return len(group_ids), 0
|
||||
|
||||
|
||||
group_settings_service = GroupSettingsService()
|
||||
@ -354,6 +354,24 @@ class GeminiAdapter(BaseAdapter):
|
||||
|
||||
return safety_settings if safety_settings else None
|
||||
|
||||
def validate_response(self, response_json: dict[str, Any]) -> None:
|
||||
"""验证 Gemini API 响应,增加对 promptFeedback 的检查"""
|
||||
super().validate_response(response_json)
|
||||
|
||||
if prompt_feedback := response_json.get("promptFeedback"):
|
||||
if block_reason := prompt_feedback.get("blockReason"):
|
||||
logger.warning(
|
||||
f"Gemini 内容因 promptFeedback 被安全过滤: {block_reason}"
|
||||
)
|
||||
raise LLMException(
|
||||
f"内容被安全过滤: {block_reason}",
|
||||
code=LLMErrorCode.CONTENT_FILTERED,
|
||||
details={
|
||||
"block_reason": block_reason,
|
||||
"safety_ratings": prompt_feedback.get("safetyRatings"),
|
||||
},
|
||||
)
|
||||
|
||||
def parse_response(
|
||||
self,
|
||||
model: "LLMModel",
|
||||
|
||||
@ -192,10 +192,20 @@ def get_default_providers() -> list[dict[str, Any]]:
|
||||
"api_base": "https://generativelanguage.googleapis.com",
|
||||
"api_type": "gemini",
|
||||
"models": [
|
||||
{"model_name": "gemini-2.0-flash"},
|
||||
{"model_name": "gemini-2.5-flash"},
|
||||
{"model_name": "gemini-2.5-pro"},
|
||||
{"model_name": "gemini-2.5-flash-lite-preview-06-17"},
|
||||
{"model_name": "gemini-2.5-flash-lite"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "OpenRouter",
|
||||
"api_key": "YOUR_OPENROUTER_API_KEY",
|
||||
"api_base": "https://openrouter.ai/api",
|
||||
"api_type": "openrouter",
|
||||
"models": [
|
||||
{"model_name": "google/gemini-2.5-pro"},
|
||||
{"model_name": "google/gemini-2.5-flash"},
|
||||
{"model_name": "x-ai/grok-4"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@ -9,6 +9,8 @@ import fnmatch
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
|
||||
class ModelModality(str, Enum):
|
||||
TEXT = "text"
|
||||
@ -50,6 +52,46 @@ GEMINI_IMAGE_GEN_CAPABILITIES = ModelCapabilities(
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
GPT_MULTIMODAL_IO_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.AUDIO, ModelModality.IMAGE},
|
||||
output_modalities={ModelModality.TEXT, ModelModality.AUDIO},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
GPT_IMAGE_GENERATION_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||
output_modalities={ModelModality.IMAGE},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
GPT_VIDEO_GENERATION_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
|
||||
output_modalities={ModelModality.VIDEO},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
DEFAULT_PERMISSIVE_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={
|
||||
ModelModality.TEXT,
|
||||
ModelModality.IMAGE,
|
||||
ModelModality.AUDIO,
|
||||
ModelModality.VIDEO,
|
||||
},
|
||||
output_modalities={
|
||||
ModelModality.TEXT,
|
||||
ModelModality.IMAGE,
|
||||
ModelModality.AUDIO,
|
||||
ModelModality.VIDEO,
|
||||
},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
|
||||
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
|
||||
@ -91,11 +133,8 @@ MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
|
||||
is_embedding_model=True,
|
||||
),
|
||||
"*gemini-*-image-preview*": GEMINI_IMAGE_GEN_CAPABILITIES,
|
||||
"gemini-2.5-pro*": GEMINI_CAPABILITIES,
|
||||
"gemini-1.5-pro*": GEMINI_CAPABILITIES,
|
||||
"gemini-2.5-flash*": GEMINI_CAPABILITIES,
|
||||
"gemini-2.0-flash*": GEMINI_CAPABILITIES,
|
||||
"gemini-1.5-flash*": GEMINI_CAPABILITIES,
|
||||
"gemini-*-pro*": GEMINI_CAPABILITIES,
|
||||
"gemini-*-flash*": GEMINI_CAPABILITIES,
|
||||
"GLM-4V-Flash": ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
@ -112,6 +151,13 @@ MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
|
||||
"doubao-1-5-thinking-vision-pro": DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES,
|
||||
"deepseek-chat": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||
"deepseek-reasoner": STANDARD_TEXT_TOOL_CAPABILITIES,
|
||||
"gpt-5*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
|
||||
"gpt-4.1*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
|
||||
"gpt-4o*": GPT_MULTIMODAL_IO_CAPABILITIES,
|
||||
"o3*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
|
||||
"o4-mini*": GPT_ADVANCED_TEXT_IMAGE_CAPABILITIES,
|
||||
"gpt image*": GPT_IMAGE_GENERATION_CAPABILITIES,
|
||||
"sora*": GPT_VIDEO_GENERATION_CAPABILITIES,
|
||||
}
|
||||
|
||||
|
||||
@ -126,11 +172,25 @@ def get_model_capabilities(model_name: str) -> ModelCapabilities:
|
||||
canonical_name = c_name
|
||||
break
|
||||
|
||||
if canonical_name in MODEL_CAPABILITIES_REGISTRY:
|
||||
return MODEL_CAPABILITIES_REGISTRY[canonical_name]
|
||||
parts = canonical_name.split("/")
|
||||
names_to_check = ["/".join(parts[i:]) for i in range(len(parts))]
|
||||
|
||||
for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
|
||||
if "*" in pattern and fnmatch.fnmatch(model_name, pattern):
|
||||
return capabilities
|
||||
logger.trace(f"为 '{model_name}' 生成的检查列表: {names_to_check}")
|
||||
|
||||
return ModelCapabilities()
|
||||
for name in names_to_check:
|
||||
if name in MODEL_CAPABILITIES_REGISTRY:
|
||||
logger.debug(f"模型 '{model_name}' 通过精确匹配 '{name}' 找到能力定义。")
|
||||
return MODEL_CAPABILITIES_REGISTRY[name]
|
||||
|
||||
for pattern, capabilities in MODEL_CAPABILITIES_REGISTRY.items():
|
||||
if "*" in pattern and fnmatch.fnmatch(name, pattern):
|
||||
logger.debug(
|
||||
f"模型 '{model_name}' 通过通配符匹配 '{name}'(pattern: '{pattern}')"
|
||||
f"找到能力定义。"
|
||||
)
|
||||
return capabilities
|
||||
|
||||
logger.warning(
|
||||
f"模型 '{model_name}' 的能力定义未在注册表中找到,将使用默认的'全功能'回退配置"
|
||||
)
|
||||
return DEFAULT_PERMISSIVE_CAPABILITIES
|
||||
|
||||
@ -40,7 +40,7 @@ class Renderable(ABC):
|
||||
@abstractmethod
|
||||
def get_children(self) -> Iterable["Renderable"]:
|
||||
"""
|
||||
[新增] 返回一个包含所有直接子组件的可迭代对象。
|
||||
返回一个包含所有直接子组件的可迭代对象。
|
||||
|
||||
这使得渲染服务能够递归地遍历整个组件树,以执行依赖收集(CSS、JS)等任务。
|
||||
非容器组件应返回一个空列表。
|
||||
|
||||
@ -75,6 +75,7 @@ class RendererService:
|
||||
self._custom_globals: dict[str, Callable] = {}
|
||||
|
||||
self.filter("dump_json")(self._pydantic_tojson_filter)
|
||||
self.global_function("inline_asset")(self._inline_asset_global)
|
||||
|
||||
def _create_jinja_env(self) -> Environment:
|
||||
"""
|
||||
@ -176,9 +177,24 @@ class RendererService:
|
||||
|
||||
return decorator
|
||||
|
||||
async def _inline_asset_global(self, namespaced_path: str) -> str:
|
||||
"""
|
||||
一个Jinja2全局函数,用于读取并内联一个已注册命名空间下的资源文件内容。
|
||||
主要用于内联SVG,以解决浏览器的跨域安全问题。
|
||||
"""
|
||||
if not self._jinja_env or not self._jinja_env.loader:
|
||||
return f"<!-- Error: Jinja env not ready for {namespaced_path} -->"
|
||||
try:
|
||||
source, _, _ = self._jinja_env.loader.get_source(
|
||||
self._jinja_env, namespaced_path
|
||||
)
|
||||
return source
|
||||
except TemplateNotFound:
|
||||
return f"<!-- Asset not found: {namespaced_path} -->"
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
[新增] 延迟初始化方法,在 on_startup 钩子中调用。
|
||||
延迟初始化方法,在 on_startup 钩子中调用。
|
||||
|
||||
负责初始化截图引擎和主题管理器,确保在首次渲染前所有依赖都已准备就绪。
|
||||
使用锁来防止并发初始化。
|
||||
@ -223,27 +239,36 @@ class RendererService:
|
||||
)
|
||||
|
||||
style_paths_to_load = []
|
||||
if manifest and "styles" in manifest:
|
||||
styles = (
|
||||
[manifest["styles"]]
|
||||
if isinstance(manifest["styles"], str)
|
||||
else manifest["styles"]
|
||||
)
|
||||
for style_path in styles:
|
||||
full_style_path = str(Path(component_path_base) / style_path).replace(
|
||||
"\\", "/"
|
||||
if manifest and manifest.get("styles"):
|
||||
styles = manifest["styles"]
|
||||
styles = [styles] if isinstance(styles, str) else styles
|
||||
|
||||
resolution_base_path = Path(component_path_base)
|
||||
if variant:
|
||||
skin_manifest_path = str(Path(component_path_base) / "skins" / variant)
|
||||
skin_manifest = await context.theme_manager._load_single_manifest(
|
||||
skin_manifest_path
|
||||
)
|
||||
style_paths_to_load.append(full_style_path)
|
||||
if skin_manifest and "styles" in skin_manifest:
|
||||
resolution_base_path = Path(skin_manifest_path)
|
||||
|
||||
style_paths_to_load.extend(
|
||||
str(resolution_base_path / style).replace("\\", "/") for style in styles
|
||||
)
|
||||
else:
|
||||
resolved_template_name = (
|
||||
base_template_path = (
|
||||
await context.theme_manager._resolve_component_template(
|
||||
component, context
|
||||
)
|
||||
)
|
||||
conventional_style_path = str(
|
||||
Path(resolved_template_name).with_name("style.css")
|
||||
base_style_path = str(
|
||||
Path(base_template_path).with_name("style.css")
|
||||
).replace("\\", "/")
|
||||
style_paths_to_load.append(conventional_style_path)
|
||||
style_paths_to_load.append(base_style_path)
|
||||
|
||||
if variant:
|
||||
skin_style_path = f"{component_path_base}/skins/{variant}/style.css"
|
||||
style_paths_to_load.append(skin_style_path)
|
||||
|
||||
for css_template_path in style_paths_to_load:
|
||||
try:
|
||||
|
||||
@ -172,24 +172,45 @@ class ResourceResolver:
|
||||
|
||||
if asset_path.startswith("@"):
|
||||
try:
|
||||
full_asset_path = self.theme_manager.jinja_env.join_path(
|
||||
asset_path, current_template_name
|
||||
)
|
||||
_source, file_abs_path, _uptodate = (
|
||||
self.theme_manager.jinja_env.loader.get_source(
|
||||
self.theme_manager.jinja_env, full_asset_path
|
||||
if "/" not in asset_path:
|
||||
raise TemplateNotFound(f"无效的命名空间路径: {asset_path}")
|
||||
|
||||
namespace, rel_path = asset_path.split("/", 1)
|
||||
|
||||
loader = self.theme_manager.jinja_env.loader
|
||||
if (
|
||||
isinstance(loader, ChoiceLoader)
|
||||
and loader.loaders
|
||||
and isinstance(loader.loaders[0], PrefixLoader)
|
||||
):
|
||||
prefix_loader = loader.loaders[0]
|
||||
if namespace in prefix_loader.mapping:
|
||||
loader_for_namespace = prefix_loader.mapping[namespace]
|
||||
if isinstance(loader_for_namespace, FileSystemLoader):
|
||||
base_path = Path(loader_for_namespace.searchpath[0])
|
||||
file_abs_path = (base_path / rel_path).resolve()
|
||||
|
||||
if file_abs_path.is_file():
|
||||
logger.debug(
|
||||
f"Resolved namespaced asset"
|
||||
f" '{asset_path}' -> '{file_abs_path}'"
|
||||
)
|
||||
return file_abs_path.as_uri()
|
||||
else:
|
||||
raise TemplateNotFound(asset_path)
|
||||
else:
|
||||
raise TemplateNotFound(
|
||||
f"Unsupported loader type for namespace '{namespace}'."
|
||||
)
|
||||
else:
|
||||
raise TemplateNotFound(f"Namespace '{namespace}' not found.")
|
||||
else:
|
||||
raise TemplateNotFound(
|
||||
f"无法解析命名空间资源 '{asset_path}',加载器结构不符合预期。"
|
||||
)
|
||||
)
|
||||
if file_abs_path:
|
||||
logger.debug(
|
||||
f"Jinja Loader resolved asset '{asset_path}'->'{file_abs_path}'"
|
||||
)
|
||||
return Path(file_abs_path).absolute().as_uri()
|
||||
|
||||
except TemplateNotFound:
|
||||
logger.warning(
|
||||
f"资源文件在命名空间中未找到: '{asset_path}'"
|
||||
f"(在模板 '{current_template_name}' 中引用)"
|
||||
)
|
||||
logger.warning(f"资源文件在命名空间中未找到: '{asset_path}'")
|
||||
return ""
|
||||
|
||||
search_paths: list[tuple[str, Path]] = []
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import Any, ClassVar
|
||||
|
||||
from aiocache import Cache, cached
|
||||
from arclet.alconna import Alconna, Args
|
||||
import nonebot
|
||||
from nonebot.adapters import Bot
|
||||
from tortoise.exceptions import IntegrityError
|
||||
from tortoise.expressions import Q
|
||||
@ -156,8 +157,9 @@ class TagManager:
|
||||
dynamic_rule=dynamic_rule,
|
||||
)
|
||||
if group_ids:
|
||||
unique_group_ids = list(dict.fromkeys(group_ids))
|
||||
await GroupTagLink.bulk_create(
|
||||
[GroupTagLink(tag=tag, group_id=gid) for gid in group_ids]
|
||||
[GroupTagLink(tag=tag, group_id=gid) for gid in unique_group_ids]
|
||||
)
|
||||
return tag
|
||||
|
||||
@ -175,6 +177,49 @@ class TagManager:
|
||||
deleted_count = await GroupTag.filter(name=name).delete()
|
||||
return deleted_count > 0
|
||||
|
||||
@invalidate_on_change
|
||||
async def remove_group_from_all_tags(self, group_id: str) -> int:
|
||||
"""
|
||||
从所有静态标签中移除一个指定的群组ID。
|
||||
主要用于机器人退群时的实时清理。
|
||||
|
||||
参数:
|
||||
group_id: 要移除的群组ID。
|
||||
|
||||
返回:
|
||||
被删除的关联数量。
|
||||
"""
|
||||
deleted_count = await GroupTagLink.filter(group_id=group_id).delete()
|
||||
if deleted_count > 0:
|
||||
logger.info(f"已从 {deleted_count} 个标签中移除群组 {group_id} 的关联。")
|
||||
return deleted_count
|
||||
|
||||
@invalidate_on_change
|
||||
async def prune_stale_group_links(self) -> int:
|
||||
"""
|
||||
清理所有静态标签中无效的群组关联。
|
||||
无效指的是机器人已不再任何一个已连接的Bot的群组列表中。
|
||||
|
||||
返回:
|
||||
被清理的无效关联的总数。
|
||||
"""
|
||||
all_bot_group_ids = set()
|
||||
for bot in nonebot.get_bots().values():
|
||||
groups, _ = await PlatformUtils.get_group_list(bot)
|
||||
all_bot_group_ids.update(g.group_id for g in groups if g.group_id)
|
||||
|
||||
all_static_links = await GroupTagLink.filter(tag__tag_type="STATIC").all()
|
||||
|
||||
stale_link_ids = [
|
||||
link.id
|
||||
for link in all_static_links
|
||||
if link.group_id not in all_bot_group_ids
|
||||
]
|
||||
|
||||
if stale_link_ids:
|
||||
return await GroupTagLink.filter(id__in=stale_link_ids).delete()
|
||||
return 0
|
||||
|
||||
@invalidate_on_change
|
||||
async def add_groups_to_tag(self, name: str, group_ids: list[str]) -> int: # type: ignore
|
||||
"""
|
||||
@ -186,11 +231,12 @@ class TagManager:
|
||||
if tag.tag_type == "DYNAMIC":
|
||||
raise ValueError("不能向动态标签手动添加群组。")
|
||||
|
||||
unique_group_ids = list(dict.fromkeys(group_ids))
|
||||
await GroupTagLink.bulk_create(
|
||||
[GroupTagLink(tag=tag, group_id=gid) for gid in group_ids],
|
||||
[GroupTagLink(tag=tag, group_id=gid) for gid in unique_group_ids],
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
return len(group_ids)
|
||||
return len(unique_group_ids)
|
||||
|
||||
@invalidate_on_change
|
||||
async def remove_groups_from_tag(self, name: str, group_ids: list[str]) -> int:
|
||||
@ -205,6 +251,72 @@ class TagManager:
|
||||
).delete()
|
||||
return deleted_count
|
||||
|
||||
@invalidate_on_change
|
||||
async def clone_tag(
|
||||
self,
|
||||
source_name: str,
|
||||
new_name: str,
|
||||
bot: Bot,
|
||||
add_groups: list[str] | None = None,
|
||||
remove_groups: list[str] | None = None,
|
||||
as_dynamic: bool = False,
|
||||
description: str | None = None,
|
||||
mode: str | None = None,
|
||||
) -> GroupTag:
|
||||
"""
|
||||
克隆一个标签,支持动态转静态、修改群组等。
|
||||
"""
|
||||
source_tag = await GroupTag.get_or_none(name=source_name)
|
||||
if not source_tag:
|
||||
raise ValueError(f"源标签 '{source_name}' 不存在。")
|
||||
|
||||
if await GroupTag.exists(name=new_name):
|
||||
raise IntegrityError(f"目标标签 '{new_name}' 已存在。")
|
||||
|
||||
tag_type = "STATIC"
|
||||
group_ids_to_set: list[str] | None = None
|
||||
dynamic_rule: str | dict | None = None
|
||||
|
||||
if source_tag.tag_type == "STATIC":
|
||||
if as_dynamic:
|
||||
raise ValueError("不能将静态标签克隆为动态标签。")
|
||||
group_ids_to_set = await GroupTagLink.filter(tag=source_tag).values_list( # type: ignore
|
||||
"group_id", flat=True
|
||||
)
|
||||
else:
|
||||
if as_dynamic:
|
||||
tag_type = "DYNAMIC"
|
||||
dynamic_rule = source_tag.dynamic_rule
|
||||
if add_groups or remove_groups:
|
||||
raise ValueError(
|
||||
"克隆为动态标签时,不支持 --add 或 --remove 操作。"
|
||||
)
|
||||
else:
|
||||
group_ids_to_set = await self.resolve_tag_to_group_ids(
|
||||
source_name, bot=bot
|
||||
)
|
||||
|
||||
if group_ids_to_set is not None:
|
||||
final_group_set = set(group_ids_to_set)
|
||||
if add_groups:
|
||||
final_group_set.update(add_groups)
|
||||
if remove_groups:
|
||||
final_group_set.difference_update(remove_groups)
|
||||
group_ids_to_set = list(final_group_set)
|
||||
|
||||
is_blacklist = (
|
||||
(mode == "black") if mode is not None else source_tag.is_blacklist
|
||||
)
|
||||
|
||||
return await self.create_tag(
|
||||
name=new_name,
|
||||
is_blacklist=is_blacklist,
|
||||
description=description,
|
||||
group_ids=group_ids_to_set,
|
||||
tag_type=tag_type,
|
||||
dynamic_rule=dynamic_rule,
|
||||
)
|
||||
|
||||
async def list_tags_with_counts(self) -> list[dict]:
|
||||
"""列出所有标签及其关联的群组数量。"""
|
||||
tags = await GroupTag.all().prefetch_related("groups")
|
||||
@ -514,11 +626,13 @@ class TagManager:
|
||||
raise ValueError("不能为动态标签设置静态群组列表。")
|
||||
async with in_transaction():
|
||||
await GroupTagLink.filter(tag=tag).delete()
|
||||
await GroupTagLink.bulk_create(
|
||||
[GroupTagLink(tag=tag, group_id=gid) for gid in group_ids],
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
return len(group_ids)
|
||||
unique_group_ids = list(dict.fromkeys(group_ids))
|
||||
if unique_group_ids:
|
||||
await GroupTagLink.bulk_create(
|
||||
[GroupTagLink(tag=tag, group_id=gid) for gid in unique_group_ids],
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
return len(unique_group_ids)
|
||||
|
||||
@invalidate_on_change
|
||||
async def clear_all_tags(self) -> int:
|
||||
|
||||
@ -64,13 +64,13 @@ class RenderableComponent(BaseModel, Renderable):
|
||||
|
||||
@compat_computed_field
|
||||
def inline_style_str(self) -> str:
|
||||
"""[新增] 一个辅助属性,将内联样式字典转换为CSS字符串"""
|
||||
"""一个辅助属性,将内联样式字典转换为CSS字符串"""
|
||||
if not self.inline_style:
|
||||
return ""
|
||||
return "; ".join(f"{k}: {v}" for k, v in self.inline_style.items())
|
||||
|
||||
def get_extra_css(self, context: Any) -> str | Awaitable[str]:
|
||||
return ""
|
||||
return self.component_css or ""
|
||||
|
||||
|
||||
class ContainerComponent(RenderableComponent, ABC):
|
||||
@ -86,7 +86,7 @@ class ContainerComponent(RenderableComponent, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_required_scripts(self) -> list[str]:
|
||||
"""[新增] 聚合所有子组件的脚本依赖。"""
|
||||
"""聚合所有子组件的脚本依赖。"""
|
||||
scripts = set(super().get_required_scripts())
|
||||
for child in self.get_children():
|
||||
if child:
|
||||
@ -94,7 +94,7 @@ class ContainerComponent(RenderableComponent, ABC):
|
||||
return list(scripts)
|
||||
|
||||
def get_required_styles(self) -> list[str]:
|
||||
"""[新增] 聚合所有子组件的样式依赖。"""
|
||||
"""聚合所有子组件的样式依赖。"""
|
||||
styles = set(super().get_required_styles())
|
||||
for child in self.get_children():
|
||||
if child:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.internal.params import Depends
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import Command
|
||||
@ -9,6 +10,7 @@ from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.services import group_settings_service
|
||||
from zhenxun.utils.limiters import ConcurrencyLimiter, FreqLimiter, RateLimiter
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.time_utils import TimeUtils
|
||||
@ -249,6 +251,34 @@ def GetConfig(
|
||||
return Depends(dependency)
|
||||
|
||||
|
||||
def GetGroupConfig(model: type[Any]):
|
||||
"""
|
||||
依赖注入函数,用于获取并解析插件的分群配置。
|
||||
"""
|
||||
|
||||
async def dependency(matcher: Matcher, session: EventSession):
|
||||
"""
|
||||
实际的依赖注入逻辑。
|
||||
"""
|
||||
plugin_name = matcher.plugin_name
|
||||
group_id = session.id3 or session.id2
|
||||
|
||||
if not plugin_name:
|
||||
raise SkippedException("无法确定插件名称以获取配置")
|
||||
|
||||
if not group_id:
|
||||
try:
|
||||
return model()
|
||||
except Exception:
|
||||
raise SkippedException("在私聊中无法获取分群配置")
|
||||
|
||||
return await group_settings_service.get_all_for_plugin(
|
||||
group_id, plugin_name, parse_model=model
|
||||
)
|
||||
|
||||
return Depends(dependency)
|
||||
|
||||
|
||||
def CheckConfig(
|
||||
module: str | None = None,
|
||||
config: str | list[str] = "",
|
||||
|
||||
@ -53,6 +53,8 @@ class CacheType(StrEnum):
|
||||
"""全局全部插件"""
|
||||
GROUPS = "GLOBAL_ALL_GROUPS"
|
||||
"""全局全部群组"""
|
||||
GROUP_PLUGIN_SETTINGS = "GROUP_PLUGIN_SETTINGS"
|
||||
"""插件分群配置"""
|
||||
USERS = "GLOBAL_ALL_USERS"
|
||||
"""全部用户"""
|
||||
BAN = "GLOBAL_ALL_BAN"
|
||||
|
||||
@ -6,7 +6,7 @@ import random
|
||||
import re
|
||||
|
||||
import imagehash
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from nonebot.utils import is_coroutine_callable, run_sync
|
||||
from PIL import Image
|
||||
|
||||
from zhenxun.configs.path_config import TEMP_PATH
|
||||
@ -378,7 +378,9 @@ async def get_download_image_hash(url: str, mark: str, use_proxy: bool = False)
|
||||
if await AsyncHttpx.download_file(
|
||||
url, TEMP_PATH / f"compare_download_{mark}_img.jpg", use_proxy=use_proxy
|
||||
):
|
||||
img_hash = get_img_hash(TEMP_PATH / f"compare_download_{mark}_img.jpg")
|
||||
img_hash = await run_sync(get_img_hash)(
|
||||
TEMP_PATH / f"compare_download_{mark}_img.jpg"
|
||||
)
|
||||
return str(img_hash)
|
||||
except Exception as e:
|
||||
logger.warning("下载读取图片Hash出错", e=e)
|
||||
|
||||
@ -24,6 +24,7 @@ __all__ = [
|
||||
"_is_pydantic_type",
|
||||
"compat_computed_field",
|
||||
"dump_json_safely",
|
||||
"model_construct",
|
||||
"model_copy",
|
||||
"model_dump",
|
||||
"model_json_schema",
|
||||
@ -45,6 +46,16 @@ def model_copy(
|
||||
return model.copy(update=update_dict, deep=deep)
|
||||
|
||||
|
||||
def model_construct(model_class: type[T], **kwargs: Any) -> T:
|
||||
"""
|
||||
Pydantic `model_construct` (v2) 与 `construct` (v1) 的兼容函数。
|
||||
"""
|
||||
if PYDANTIC_V2:
|
||||
return model_class.model_construct(**kwargs)
|
||||
else:
|
||||
return model_class.construct(**kwargs)
|
||||
|
||||
|
||||
def model_validate(model_class: type[T], obj: Any) -> T:
|
||||
"""
|
||||
Pydantic `model_validate` (v2) 与 `parse_obj` (v1) 的兼容函数。
|
||||
|
||||
Loading…
Reference in New Issue
Block a user