mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⭐导入最近PR
This commit is contained in:
parent
daad4c718b
commit
4ec6f39af1
@ -16,6 +16,7 @@ from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.decorator.shop import shop_register
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.manager.resource_manager import ResourceManager
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
@ -70,7 +71,7 @@ from public.bag_users t1
|
||||
"""
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
await ResourceManager.init_resources()
|
||||
"""签到与用户的数据迁移"""
|
||||
|
||||
@ -26,6 +26,21 @@ __plugin_meta__ = PluginMetadata(
|
||||
_matcher = on_alconna(Alconna("关于"), priority=5, block=True, rule=to_me())
|
||||
|
||||
|
||||
QQ_INFO = """
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
""".strip()
|
||||
|
||||
INFO = """
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
项目地址:https://github.com/zhenxun-org/zhenxun_bot
|
||||
文档地址:https://zhenxun-org.github.io/zhenxun_bot/
|
||||
""".strip()
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(session: Uninfo, arparma: Arparma):
|
||||
ver_file = Path() / "__version__"
|
||||
@ -35,25 +50,11 @@ async def _(session: Uninfo, arparma: Arparma):
|
||||
if text := await f.read():
|
||||
version = text.split(":")[-1].strip()
|
||||
if PlatformUtils.is_qbot(session):
|
||||
info: list[str | Path] = [
|
||||
f"""
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
""".strip()
|
||||
]
|
||||
result: list[str | Path] = [QQ_INFO.format(version=version)]
|
||||
path = DATA_PATH / "about.png"
|
||||
if path.exists():
|
||||
info.append(path)
|
||||
result.append(path)
|
||||
await MessageUtils.build_message(result).send() # type: ignore
|
||||
else:
|
||||
info = [
|
||||
f"""
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
项目地址:https://github.com/HibiKier/zhenxun_bot
|
||||
文档地址:https://hibikier.github.io/zhenxun_bot/
|
||||
""".strip()
|
||||
]
|
||||
await MessageUtils.build_message(info).send() # type: ignore
|
||||
logger.info("查看关于", arparma.header_result, session=session)
|
||||
await MessageUtils.build_message(INFO.format(version=version)).send()
|
||||
logger.info("查看关于", arparma.header_result, session=session)
|
||||
|
||||
@ -8,6 +8,7 @@ from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
||||
|
||||
|
||||
async def call_ban(user_id: str):
|
||||
"""调用ban
|
||||
|
||||
@ -18,7 +19,6 @@ async def call_ban(user_id: str):
|
||||
logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id)
|
||||
|
||||
|
||||
|
||||
class BanManage:
|
||||
@classmethod
|
||||
async def build_ban_image(
|
||||
|
||||
@ -4,7 +4,8 @@ from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||
from zhenxun.utils.exception import GroupInfoNotFound
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||
|
||||
@ -245,10 +246,12 @@ class PluginManage:
|
||||
参数:
|
||||
group_id: 群组id
|
||||
"""
|
||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
||||
status=False
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
|
||||
group.status = False
|
||||
await group.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def wake(cls, group_id: str):
|
||||
"""醒来
|
||||
@ -256,9 +259,11 @@ class PluginManage:
|
||||
参数:
|
||||
group_id: 群组id
|
||||
"""
|
||||
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
|
||||
status=True
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group.status = True
|
||||
await group.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def block(cls, module: str):
|
||||
@ -267,7 +272,9 @@ class PluginManage:
|
||||
参数:
|
||||
module: 模块名
|
||||
"""
|
||||
await PluginInfo.filter(module=module).update(status=False)
|
||||
if plugin := await PluginInfo.get_plugin(module=module):
|
||||
plugin.status = False
|
||||
await plugin.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def unblock(cls, module: str):
|
||||
@ -276,7 +283,9 @@ class PluginManage:
|
||||
参数:
|
||||
module: 模块名
|
||||
"""
|
||||
await PluginInfo.filter(module=module).update(status=True)
|
||||
if plugin := await PluginInfo.get_plugin(module=module):
|
||||
plugin.status = True
|
||||
await plugin.save(update_fields=["status"])
|
||||
|
||||
@classmethod
|
||||
async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str:
|
||||
|
||||
@ -14,6 +14,7 @@ from zhenxun.services.log import logger
|
||||
from zhenxun.utils._build_image import BuildImage
|
||||
from zhenxun.utils._image_template import ImageTemplate
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
BASE_PATH = DATA_PATH / "welcome_message"
|
||||
@ -91,7 +92,7 @@ def migrate(path: Path):
|
||||
json.dump(new_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
def _():
|
||||
"""数据迁移
|
||||
|
||||
|
||||
@ -40,7 +40,13 @@ __plugin_meta__ = PluginMetadata(
|
||||
value="zhenxun",
|
||||
help="帮助图片样式 [normal, HTML, zhenxun]",
|
||||
default_value="zhenxun",
|
||||
)
|
||||
),
|
||||
RegisterConfig(
|
||||
key="detail_type",
|
||||
value="zhenxun",
|
||||
help="帮助详情图片样式 ['normal', 'zhenxun']",
|
||||
default_value="zhenxun",
|
||||
),
|
||||
],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
@ -1,13 +1,19 @@
|
||||
from pathlib import Path
|
||||
|
||||
import nonebot
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_htmlrender import template_to_pic
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.path_config import IMAGE_PATH
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.path_config import IMAGE_PATH, TEMPLATE_PATH
|
||||
from zhenxun.configs.utils import PluginExtraData
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.statistics import Statistics
|
||||
from zhenxun.utils._image_template import ImageTemplate
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
||||
from zhenxun.utils.image_utils import BuildImage
|
||||
|
||||
from ._config import (
|
||||
GROUP_HELP_PATH,
|
||||
@ -80,9 +86,96 @@ async def get_user_allow_help(user_id: str) -> list[PluginType]:
|
||||
return type_list
|
||||
|
||||
|
||||
async def get_plugin_help(
|
||||
user_id: str, name: str, is_superuser: bool
|
||||
) -> str | BuildImage:
|
||||
async def get_normal_help(
|
||||
metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool
|
||||
) -> str | bytes:
|
||||
"""构建默认帮助详情
|
||||
|
||||
参数:
|
||||
metadata: PluginMetadata
|
||||
extra: PluginExtraData
|
||||
is_superuser: 是否超级用户帮助
|
||||
|
||||
返回:
|
||||
str | bytes: 返回信息
|
||||
"""
|
||||
items = None
|
||||
if is_superuser:
|
||||
if usage := extra.superuser_help:
|
||||
items = {
|
||||
"简介": metadata.description,
|
||||
"用法": usage,
|
||||
}
|
||||
else:
|
||||
items = {
|
||||
"简介": metadata.description,
|
||||
"用法": metadata.usage,
|
||||
}
|
||||
if items:
|
||||
return (await ImageTemplate.hl_page(metadata.name, items)).pic2bytes()
|
||||
return "该功能没有帮助信息"
|
||||
|
||||
|
||||
def min_leading_spaces(str_list: list[str]) -> int:
|
||||
min_spaces = 9999
|
||||
|
||||
for s in str_list:
|
||||
leading_spaces = len(s) - len(s.lstrip(" "))
|
||||
|
||||
if leading_spaces < min_spaces:
|
||||
min_spaces = leading_spaces
|
||||
|
||||
return min_spaces if min_spaces != 9999 else 0
|
||||
|
||||
|
||||
def split_text(text: str):
|
||||
split_text = text.split("\n")
|
||||
min_spaces = min_leading_spaces(split_text)
|
||||
if min_spaces > 0:
|
||||
split_text = [s[min_spaces:] for s in split_text]
|
||||
return [s.replace(" ", " ") for s in split_text]
|
||||
|
||||
|
||||
async def get_zhenxun_help(
|
||||
module: str, metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool
|
||||
) -> str | bytes:
|
||||
"""构建ZhenXun帮助详情
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
metadata: PluginMetadata
|
||||
extra: PluginExtraData
|
||||
is_superuser: 是否超级用户帮助
|
||||
|
||||
返回:
|
||||
str | bytes: 返回信息
|
||||
"""
|
||||
call_count = await Statistics.filter(plugin_name=module).count()
|
||||
usage = metadata.usage
|
||||
if is_superuser:
|
||||
if not extra.superuser_help:
|
||||
return "该功能没有超级用户帮助信息"
|
||||
usage = extra.superuser_help
|
||||
return await template_to_pic(
|
||||
template_path=str((TEMPLATE_PATH / "help_detail").absolute()),
|
||||
template_name="main.html",
|
||||
templates={
|
||||
"title": metadata.name,
|
||||
"author": extra.author,
|
||||
"version": extra.version,
|
||||
"call_count": call_count,
|
||||
"descriptions": split_text(metadata.description),
|
||||
"usages": split_text(usage),
|
||||
},
|
||||
pages={
|
||||
"viewport": {"width": 824, "height": 590},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
)
|
||||
|
||||
|
||||
async def get_plugin_help(user_id: str, name: str, is_superuser: bool) -> str | bytes:
|
||||
"""获取功能的帮助信息
|
||||
|
||||
参数:
|
||||
@ -100,20 +193,12 @@ async def get_plugin_help(
|
||||
if plugin:
|
||||
_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||
if _plugin and _plugin.metadata:
|
||||
items = None
|
||||
if is_superuser:
|
||||
extra = _plugin.metadata.extra
|
||||
if usage := extra.get("superuser_help"):
|
||||
items = {
|
||||
"简介": _plugin.metadata.description,
|
||||
"用法": usage,
|
||||
}
|
||||
extra_data = PluginExtraData(**_plugin.metadata.extra)
|
||||
if Config.get_config("help", "detail_type") == "zhenxun":
|
||||
return await get_zhenxun_help(
|
||||
plugin.module, _plugin.metadata, extra_data, is_superuser
|
||||
)
|
||||
else:
|
||||
items = {
|
||||
"简介": _plugin.metadata.description,
|
||||
"用法": _plugin.metadata.usage,
|
||||
}
|
||||
if items:
|
||||
return await ImageTemplate.hl_page(plugin.name, items)
|
||||
return await get_normal_help(_plugin.metadata, extra_data, is_superuser)
|
||||
return "糟糕! 该功能没有帮助喔..."
|
||||
return "没有查找到这个功能噢..."
|
||||
|
||||
@ -21,7 +21,7 @@ from zhenxun.utils.message import MessageUtils
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="笨蛋检测",
|
||||
description="功能名称当命令检测",
|
||||
usage="""被动""".strip(),
|
||||
usage="""当一些笨蛋直接输入功能名称时,提示笨蛋使用帮助指令查看功能帮助""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1",
|
||||
|
||||
@ -53,7 +53,7 @@ Config.add_plugin_config(
|
||||
"hook",
|
||||
"RECORD_BOT_SENT_MESSAGES",
|
||||
True,
|
||||
help="记录bot消息校内",
|
||||
help="记录bot消息发送",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
|
||||
@ -14,9 +14,7 @@ from .auth_checker import LimitManager, auth
|
||||
|
||||
# # 权限检测
|
||||
@run_preprocessor
|
||||
async def _(
|
||||
matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg
|
||||
):
|
||||
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
|
||||
start_time = time.time()
|
||||
await auth(
|
||||
matcher,
|
||||
|
||||
@ -20,6 +20,7 @@ from zhenxun.utils.enum import (
|
||||
PluginLimitType,
|
||||
PluginType,
|
||||
)
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from .manager import manager
|
||||
|
||||
@ -95,7 +96,7 @@ async def _handle_setting(
|
||||
)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
"""
|
||||
初始化插件数据配置
|
||||
|
||||
@ -10,6 +10,7 @@ from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
driver: Driver = nonebot.get_driver()
|
||||
|
||||
@ -132,7 +133,7 @@ async def create_schedule(task: Task):
|
||||
logger.error(f"动态创建定时任务 {task.name}({task.module}) 失败", e=e)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
"""
|
||||
初始化插件数据配置
|
||||
|
||||
@ -95,13 +95,6 @@ _matcher = on_alconna(
|
||||
block=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"1111",
|
||||
command="mahiro-bank",
|
||||
arguments=["test"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"存款\s*(?P<amount>\d+)?",
|
||||
command="mahiro-bank",
|
||||
|
||||
@ -241,7 +241,7 @@ class BankManager:
|
||||
@classmethod
|
||||
async def get_bank_info(cls) -> bytes:
|
||||
now = datetime.now()
|
||||
now_start = datetime.now() - timedelta(
|
||||
now_start = now - timedelta(
|
||||
hours=now.hour, minutes=now.minute, seconds=now.second
|
||||
)
|
||||
(
|
||||
@ -255,7 +255,9 @@ class BankManager:
|
||||
MahiroBank.annotate(
|
||||
amount_sum=Sum("amount"), user_count=Count("id")
|
||||
).values("amount_sum", "user_count"),
|
||||
MahiroBankLog.filter(create_time__gt=now_start).count(),
|
||||
MahiroBankLog.filter(
|
||||
create_time__gt=now_start, handle_type=BankHandleType.DEPOSIT
|
||||
).count(),
|
||||
MahiroBankLog.filter(handle_type=BankHandleType.INTEREST)
|
||||
.annotate(amount_sum=Sum("amount"))
|
||||
.values("amount_sum"),
|
||||
|
||||
@ -1,17 +1,12 @@
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
from nonebot import on_regex
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.params import Depends, RegexGroup
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot.rule import to_me
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
Args,
|
||||
Arparma,
|
||||
CommandMeta,
|
||||
Option,
|
||||
on_alconna,
|
||||
store_true,
|
||||
)
|
||||
from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
@ -59,22 +54,15 @@ __plugin_meta__ = PluginMetadata(
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
_nickname_matcher = on_alconna(
|
||||
Alconna(
|
||||
"re:(?:以后)?(?:叫我|请叫我|称呼我)",
|
||||
Args["name?", str],
|
||||
meta=CommandMeta(compact=True),
|
||||
),
|
||||
_nickname_matcher = on_regex(
|
||||
"(?:以后)?(?:叫我|请叫我|称呼我)(.*)",
|
||||
rule=to_me(),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_global_nickname_matcher = on_alconna(
|
||||
Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)),
|
||||
rule=to_me(),
|
||||
priority=5,
|
||||
block=True,
|
||||
_global_nickname_matcher = on_regex(
|
||||
"设置全局昵称(.*)", rule=to_me(), priority=5, block=True
|
||||
)
|
||||
|
||||
_matcher = on_alconna(
|
||||
@ -129,32 +117,34 @@ CANCEL = [
|
||||
]
|
||||
|
||||
|
||||
async def CheckNickname(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
params: Arparma,
|
||||
):
|
||||
def CheckNickname():
|
||||
"""
|
||||
检查名称是否合法
|
||||
"""
|
||||
black_word = Config.get_config("nickname", "BLACK_WORD")
|
||||
name = params.query("name")
|
||||
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
|
||||
if not name:
|
||||
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if session.user.id in bot.config.superusers:
|
||||
logger.debug(
|
||||
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
|
||||
)
|
||||
else:
|
||||
|
||||
async def dependency(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
black_word = Config.get_config("nickname", "BLACK_WORD")
|
||||
(name,) = reg_group
|
||||
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
|
||||
if not name:
|
||||
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if session.user.id in bot.config.superusers:
|
||||
logger.debug(
|
||||
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
|
||||
)
|
||||
return
|
||||
if len(name) > 20:
|
||||
await MessageUtils.build_message("昵称可不能超过20个字!").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if name in bot.config.nickname:
|
||||
await MessageUtils.build_message("笨蛋!休想占用我的名字! ").finish(
|
||||
await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if black_word:
|
||||
@ -172,17 +162,17 @@ async def CheckNickname(
|
||||
await MessageUtils.build_message(
|
||||
f"字符 [{word}] 为禁止字符!"
|
||||
).finish(at_sender=True)
|
||||
return name
|
||||
|
||||
return Depends(dependency)
|
||||
|
||||
|
||||
@_nickname_matcher.handle()
|
||||
@_nickname_matcher.handle(parameterless=[CheckNickname()])
|
||||
async def _(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
name_: Arparma,
|
||||
uname: str = UserName(),
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
name = await CheckNickname(bot, session, name_)
|
||||
(name,) = reg_group
|
||||
if len(name) < 5 and random.random() < 0.3:
|
||||
name = "~".join(name)
|
||||
group_id = None
|
||||
@ -210,14 +200,13 @@ async def _(
|
||||
)
|
||||
|
||||
|
||||
@_global_nickname_matcher.handle()
|
||||
@_global_nickname_matcher.handle(parameterless=[CheckNickname()])
|
||||
async def _(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
name_: Arparma,
|
||||
nickname: str = UserName(),
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
name = await CheckNickname(bot, session, name_)
|
||||
(name,) = reg_group
|
||||
await FriendUser.set_user_nickname(
|
||||
session.user.id,
|
||||
name,
|
||||
@ -238,14 +227,15 @@ async def _(session: Uninfo, uname: str = UserName()):
|
||||
group_id = session.group.parent.id if session.group.parent else session.group.id
|
||||
if group_id:
|
||||
nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id)
|
||||
card = uname
|
||||
else:
|
||||
nickname = await FriendUser.get_user_nickname(session.user.id)
|
||||
card = uname
|
||||
if nickname:
|
||||
await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish(
|
||||
reply_to=True
|
||||
)
|
||||
else:
|
||||
card = uname
|
||||
await MessageUtils.build_message(
|
||||
random.choice(
|
||||
[
|
||||
|
||||
@ -147,7 +147,7 @@ class GroupManager:
|
||||
e=e,
|
||||
)
|
||||
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
||||
# await GroupConsole.filter(group_id=group_id).delete()
|
||||
#await GroupConsole.filter(group_id=group_id).delete()
|
||||
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
||||
else:
|
||||
await cls.__handle_add_group(bot, group_id, group)
|
||||
|
||||
@ -31,4 +31,4 @@ async def _(session: Uninfo):
|
||||
await FriendUser.create(
|
||||
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
|
||||
)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
@ -10,13 +10,4 @@ DEFAULT_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins/tree/ma
|
||||
EXTRA_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/tree/index"
|
||||
"""插件库索引github仓库地址"""
|
||||
|
||||
GITEE_RAW_URL = "https://gitee.com/two_Dimension/zhenxun_bot_plugins/raw/main"
|
||||
"""GITEE仓库文件内容"""
|
||||
|
||||
GITEE_CONTENTS_URL = (
|
||||
"https://gitee.com/api/v5/repos/two_Dimension/zhenxun_bot_plugins/contents"
|
||||
)
|
||||
"""GITEE仓库文件列表获取"""
|
||||
|
||||
|
||||
LOG_COMMAND = "插件商店"
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
from nonebot.plugin import PluginMetadata
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData
|
||||
from zhenxun.utils.enum import PluginType
|
||||
|
||||
from . import command # noqa: F401
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="定时任务管理",
|
||||
description="查看和管理由 SchedulerManager 控制的定时任务。",
|
||||
usage="""
|
||||
📋 定时任务管理 - 支持群聊和私聊操作
|
||||
|
||||
🔍 查看任务:
|
||||
定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>]
|
||||
• 群聊中: 查看本群任务
|
||||
• 私聊中: 必须使用 -g <群号> 或 -all 选项 (SUPERUSER)
|
||||
|
||||
📊 任务状态:
|
||||
定时任务 状态 <任务ID> 或 任务状态 <任务ID>
|
||||
• 查看单个任务的详细信息和状态
|
||||
|
||||
⚙️ 任务管理 (SUPERUSER):
|
||||
定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>]
|
||||
定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 执行 <任务ID>
|
||||
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
|
||||
|
||||
📝 时间选项 (三选一):
|
||||
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
|
||||
--interval <时间间隔> # 例: --interval 30m, 2h, 10s
|
||||
--date "<YYYY-MM-DD HH:MM:SS>" # 例: --date "2024-01-01 08:00:00"
|
||||
--daily "<HH:MM>" # 例: --daily "08:30"
|
||||
|
||||
📚 其他功能:
|
||||
定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER)
|
||||
|
||||
🏷️ 别名支持:
|
||||
查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消
|
||||
暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info
|
||||
更新: update, modify, 修改 | 插件列表: plugins
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1.2",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
is_show=False,
|
||||
).to_dict(),
|
||||
)
|
||||
@ -1,836 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
AlconnaMatch,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.utils._image_template import ImageTemplate
|
||||
from zhenxun.utils.manager.schedule_manager import scheduler_manager
|
||||
|
||||
|
||||
def _get_type_name(annotation) -> str:
|
||||
"""获取类型注解的名称"""
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
elif hasattr(annotation, "_name"):
|
||||
return annotation._name
|
||||
else:
|
||||
return str(annotation)
|
||||
|
||||
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.rules import admin_check
|
||||
|
||||
|
||||
def _format_trigger(schedule_status: dict) -> str:
|
||||
"""将触发器配置格式化为人类可读的字符串"""
|
||||
trigger_type = schedule_status["trigger_type"]
|
||||
config = schedule_status["trigger_config"]
|
||||
|
||||
if trigger_type == "cron":
|
||||
minute = config.get("minute", "*")
|
||||
hour = config.get("hour", "*")
|
||||
day = config.get("day", "*")
|
||||
month = config.get("month", "*")
|
||||
day_of_week = config.get("day_of_week", "*")
|
||||
|
||||
if day == "*" and month == "*" and day_of_week == "*":
|
||||
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
|
||||
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
|
||||
return f"每天 {formatted_hour}:{formatted_minute}"
|
||||
else:
|
||||
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
|
||||
elif trigger_type == "interval":
|
||||
seconds = config.get("seconds", 0)
|
||||
minutes = config.get("minutes", 0)
|
||||
hours = config.get("hours", 0)
|
||||
days = config.get("days", 0)
|
||||
if days:
|
||||
trigger_str = f"每 {days} 天"
|
||||
elif hours:
|
||||
trigger_str = f"每 {hours} 小时"
|
||||
elif minutes:
|
||||
trigger_str = f"每 {minutes} 分钟"
|
||||
else:
|
||||
trigger_str = f"每 {seconds} 秒"
|
||||
elif trigger_type == "date":
|
||||
run_date = config.get("run_date", "未知时间")
|
||||
trigger_str = f"在 {run_date}"
|
||||
else:
|
||||
trigger_str = f"{trigger_type}: {config}"
|
||||
|
||||
return trigger_str
|
||||
|
||||
|
||||
def _format_params(schedule_status: dict) -> str:
|
||||
"""将任务参数格式化为人类可读的字符串"""
|
||||
if kwargs := schedule_status.get("job_kwargs"):
|
||||
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
|
||||
return kwargs_str
|
||||
return "-"
|
||||
|
||||
|
||||
def _parse_interval(interval_str: str) -> dict:
|
||||
"""增强版解析器,支持 d(天)"""
|
||||
match = re.match(r"(\d+)([smhd])", interval_str.lower())
|
||||
if not match:
|
||||
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
|
||||
|
||||
value, unit = int(match.group(1)), match.group(2)
|
||||
if unit == "s":
|
||||
return {"seconds": value}
|
||||
if unit == "m":
|
||||
return {"minutes": value}
|
||||
if unit == "h":
|
||||
return {"hours": value}
|
||||
if unit == "d":
|
||||
return {"days": value}
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_daily_time(time_str: str) -> dict:
|
||||
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
|
||||
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
|
||||
hour, minute, second = match.groups()
|
||||
hour, minute = int(hour), int(minute)
|
||||
|
||||
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||
raise ValueError("小时或分钟数值超出范围。")
|
||||
|
||||
cron_config = {
|
||||
"minute": str(minute),
|
||||
"hour": str(hour),
|
||||
"day": "*",
|
||||
"month": "*",
|
||||
"day_of_week": "*",
|
||||
}
|
||||
if second is not None:
|
||||
if not (0 <= int(second) <= 59):
|
||||
raise ValueError("秒数值超出范围。")
|
||||
cron_config["second"] = str(second)
|
||||
|
||||
return cron_config
|
||||
else:
|
||||
raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。")
|
||||
|
||||
|
||||
async def GetBotId(
|
||||
bot: Bot,
|
||||
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
|
||||
) -> str:
|
||||
"""获取要操作的Bot ID"""
|
||||
if bot_id_match.available:
|
||||
return bot_id_match.result
|
||||
return bot.self_id
|
||||
|
||||
|
||||
class ScheduleTarget:
|
||||
"""定时任务操作目标的基类"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TargetByID(ScheduleTarget):
|
||||
"""按任务ID操作"""
|
||||
|
||||
def __init__(self, id: int):
|
||||
self.id = id
|
||||
|
||||
|
||||
class TargetByPlugin(ScheduleTarget):
|
||||
"""按插件名操作"""
|
||||
|
||||
def __init__(
|
||||
self, plugin: str, group_id: str | None = None, all_groups: bool = False
|
||||
):
|
||||
self.plugin = plugin
|
||||
self.group_id = group_id
|
||||
self.all_groups = all_groups
|
||||
|
||||
|
||||
class TargetAll(ScheduleTarget):
|
||||
"""操作所有任务"""
|
||||
|
||||
def __init__(self, for_group: str | None = None):
|
||||
self.for_group = for_group
|
||||
|
||||
|
||||
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
|
||||
|
||||
|
||||
def create_target_parser(subcommand_name: str):
|
||||
"""
|
||||
创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
event: Event,
|
||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
|
||||
) -> TargetScope:
|
||||
if schedule_id.available:
|
||||
return TargetByID(schedule_id.result)
|
||||
|
||||
if plugin_name.available:
|
||||
p_name = plugin_name.result
|
||||
if all_enabled.available:
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
elif group_id.available:
|
||||
gid = group_id.result
|
||||
if gid.lower() == "all":
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
return TargetByPlugin(plugin=p_name, group_id=gid)
|
||||
else:
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if current_group_id:
|
||||
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
|
||||
)
|
||||
|
||||
if all_enabled.available:
|
||||
return TargetAll(for_group=group_id.result if group_id.available else None)
|
||||
|
||||
return None
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
schedule_cmd = on_alconna(
|
||||
Alconna(
|
||||
"定时任务",
|
||||
Subcommand(
|
||||
"查看",
|
||||
Option("-g", Args["target_group_id", str]),
|
||||
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
|
||||
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
|
||||
Option("--page", Args["page", int, 1], help_text="指定页码"),
|
||||
alias=["ls", "list"],
|
||||
help_text="查看定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"设置",
|
||||
Args["plugin_name", str],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="设置每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
|
||||
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["add", "开启"],
|
||||
help_text="设置/开启一个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"删除",
|
||||
Args["schedule_id?", int],
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID"),
|
||||
Option("-all", help_text="对所有群生效"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["del", "rm", "remove", "关闭", "取消"],
|
||||
help_text="删除一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"暂停",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["pause"],
|
||||
help_text="暂停一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"恢复",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["resume"],
|
||||
help_text="恢复一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"执行",
|
||||
Args["schedule_id", int],
|
||||
alias=["trigger", "run"],
|
||||
help_text="立即执行一次任务",
|
||||
),
|
||||
Subcommand(
|
||||
"更新",
|
||||
Args["schedule_id", int],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="更新每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
|
||||
alias=["update", "modify", "修改"],
|
||||
help_text="更新任务配置",
|
||||
),
|
||||
Subcommand(
|
||||
"状态",
|
||||
Args["schedule_id", int],
|
||||
alias=["status", "info"],
|
||||
help_text="查看单个任务的详细状态",
|
||||
),
|
||||
Subcommand(
|
||||
"插件列表",
|
||||
alias=["plugins"],
|
||||
help_text="列出所有可用的插件",
|
||||
),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
rule=admin_check(1),
|
||||
)
|
||||
|
||||
schedule_cmd.shortcut(
|
||||
"任务状态",
|
||||
command="定时任务",
|
||||
arguments=["状态", "{%0}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.handle()
|
||||
async def _handle_time_options_mutex(arp: Arparma):
|
||||
time_options = ["cron", "interval", "date", "daily"]
|
||||
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
|
||||
if len(provided_options) > 1:
|
||||
await schedule_cmd.finish(
|
||||
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("查看")
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
|
||||
all_groups: Query[bool] = Query("查看.all"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
page: Match[int] = AlconnaMatch("page"),
|
||||
):
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
schedules = []
|
||||
title = ""
|
||||
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if not (all_groups.available or target_group_id.available) and not current_group_id:
|
||||
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
|
||||
|
||||
if all_groups.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
|
||||
schedules = await scheduler_manager.get_all_schedules()
|
||||
title = "所有群组的定时任务"
|
||||
elif target_group_id.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
|
||||
gid = target_group_id.result
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = f"群 {gid} 的定时任务"
|
||||
else:
|
||||
gid = str(current_group_id)
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = "本群的定时任务"
|
||||
|
||||
if plugin_name.available:
|
||||
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
|
||||
title += f" [插件: {plugin_name.result}]"
|
||||
|
||||
if not schedules:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
page_size = 15
|
||||
current_page = page.result
|
||||
total_items = len(schedules)
|
||||
total_pages = (total_items + page_size - 1) // page_size
|
||||
start_index = (current_page - 1) * page_size
|
||||
end_index = start_index + page_size
|
||||
paginated_schedules = schedules[start_index:end_index]
|
||||
|
||||
if not paginated_schedules:
|
||||
await schedule_cmd.finish("这一页没有内容了哦~")
|
||||
|
||||
status_tasks = [
|
||||
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
|
||||
]
|
||||
all_statuses = await asyncio.gather(*status_tasks)
|
||||
data_list = [
|
||||
[
|
||||
s["id"],
|
||||
s["plugin_name"],
|
||||
s.get("bot_id") or "N/A",
|
||||
s["group_id"] or "全局",
|
||||
s["next_run_time"],
|
||||
_format_trigger(s),
|
||||
_format_params(s),
|
||||
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
|
||||
]
|
||||
for s in all_statuses
|
||||
if s
|
||||
]
|
||||
|
||||
if not data_list:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
img = await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
|
||||
column_name=[
|
||||
"ID",
|
||||
"插件",
|
||||
"Bot ID",
|
||||
"群组/目标",
|
||||
"下次运行",
|
||||
"触发规则",
|
||||
"参数",
|
||||
"状态",
|
||||
],
|
||||
data_list=data_list,
|
||||
column_space=20,
|
||||
)
|
||||
await MessageUtils.build_message(img).send(reply_to=True)
|
||||
|
||||
|
||||
@schedule_cmd.assign("设置")
|
||||
async def _(
|
||||
event: Event,
|
||||
plugin_name: str,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
group_id: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
all_enabled: Query[bool] = Query("设置.all"),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if plugin_name not in scheduler_manager._registered_tasks:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
|
||||
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
|
||||
)
|
||||
|
||||
trigger_type = ""
|
||||
trigger_config = {}
|
||||
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||
)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = {}
|
||||
if kwargs_str:
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
if not params_model:
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||
)
|
||||
return
|
||||
|
||||
target_group_id: str | None
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
|
||||
if group_id and group_id.lower() == "all":
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif all_enabled.available:
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif group_id:
|
||||
target_group_id = group_id
|
||||
elif current_group_id:
|
||||
target_group_id = str(current_group_id)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
|
||||
)
|
||||
return
|
||||
|
||||
success, msg = await scheduler_manager.add_schedule(
|
||||
plugin_name,
|
||||
target_group_id,
|
||||
trigger_type,
|
||||
trigger_config,
|
||||
job_kwargs,
|
||||
bot_id=bot_id_to_operate,
|
||||
)
|
||||
|
||||
if target_group_id == "__ALL_GROUPS__":
|
||||
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
|
||||
elif target_group_id is None:
|
||||
target_desc = "全局"
|
||||
else:
|
||||
target_desc = f"群组 {target_group_id}"
|
||||
|
||||
if success:
|
||||
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
|
||||
else:
|
||||
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
|
||||
|
||||
|
||||
@schedule_cmd.assign("删除")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("删除")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
removed_count = await scheduler_manager.remove_schedule_for_all(
|
||||
p_name, bot_id=bot_id_to_operate
|
||||
)
|
||||
message = (
|
||||
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
|
||||
if removed_count > 0
|
||||
else f"没有找到插件 '{p_name}' 的定时任务。"
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_schedule(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.remove_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("暂停")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("暂停")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.pause_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.pause_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("恢复")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("恢复")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.resume_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.resume_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("执行")
|
||||
async def _(schedule_id: int):
|
||||
_, message = await scheduler_manager.trigger_now(schedule_id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("更新")
|
||||
async def _(
|
||||
schedule_id: int,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
):
|
||||
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
|
||||
await schedule_cmd.finish(
|
||||
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
|
||||
)
|
||||
|
||||
trigger_config = None
|
||||
trigger_type = None
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = None
|
||||
if kwargs_str:
|
||||
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
|
||||
|
||||
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
|
||||
if not task_meta or not (params_model := task_meta.get("model")):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
|
||||
)
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
|
||||
)
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump(exclude_unset=True)
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
|
||||
return
|
||||
|
||||
_, message = await scheduler_manager.update_schedule(
|
||||
schedule_id, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("插件列表")
|
||||
async def _():
|
||||
registered_plugins = scheduler_manager.get_registered_plugins()
|
||||
if not registered_plugins:
|
||||
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
|
||||
|
||||
message_parts = ["📋 已注册的定时任务插件:"]
|
||||
for i, plugin_name in enumerate(registered_plugins, 1):
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
|
||||
if not params_model:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
continue
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
|
||||
continue
|
||||
|
||||
model_fields = getattr(params_model, "model_fields", None)
|
||||
if model_fields:
|
||||
param_info = ", ".join(
|
||||
f"{field_name}({_get_type_name(field_info.annotation)})"
|
||||
for field_name, field_info in model_fields.items()
|
||||
)
|
||||
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
|
||||
else:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
|
||||
await schedule_cmd.finish("\n".join(message_parts))
|
||||
|
||||
|
||||
@schedule_cmd.assign("状态")
|
||||
async def _(schedule_id: int):
|
||||
status = await scheduler_manager.get_schedule_status(schedule_id)
|
||||
if not status:
|
||||
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
|
||||
|
||||
info_lines = [
|
||||
f"📋 定时任务详细信息 (ID: {schedule_id})",
|
||||
"--------------------",
|
||||
f"▫️ 插件: {status['plugin_name']}",
|
||||
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
|
||||
f"▫️ 目标: {status['group_id'] or '全局'}",
|
||||
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
|
||||
f"▫️ 下次运行: {status['next_run_time']}",
|
||||
f"▫️ 触发规则: {_format_trigger(status)}",
|
||||
f"▫️ 任务参数: {_format_params(status)}",
|
||||
]
|
||||
await schedule_cmd.finish("\n".join(info_lines))
|
||||
@ -1,67 +1,8 @@
|
||||
from asyncio.exceptions import TimeoutError
|
||||
|
||||
import aiofiles
|
||||
import nonebot
|
||||
from nonebot.drivers import Driver
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
import ujson as json
|
||||
|
||||
from zhenxun.configs.path_config import TEXT_PATH
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
|
||||
driver: Driver = nonebot.get_driver()
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def update_city():
|
||||
"""
|
||||
部分插件需要中国省份城市
|
||||
这里直接更新,避免插件内代码重复
|
||||
"""
|
||||
china_city = TEXT_PATH / "china_city.json"
|
||||
if not china_city.exists():
|
||||
data = {}
|
||||
try:
|
||||
logger.debug("开始更新城市列表...")
|
||||
res = await AsyncHttpx.get(
|
||||
"http://www.weather.com.cn/data/city3jdata/china.html", timeout=5
|
||||
)
|
||||
res.encoding = "utf8"
|
||||
provinces_data = json.loads(res.text)
|
||||
for province in provinces_data.keys():
|
||||
data[provinces_data[province]] = []
|
||||
res = await AsyncHttpx.get(
|
||||
f"http://www.weather.com.cn/data/city3jdata/provshi/{province}.html",
|
||||
timeout=5,
|
||||
)
|
||||
res.encoding = "utf8"
|
||||
city_data = json.loads(res.text)
|
||||
for city in city_data.keys():
|
||||
data[provinces_data[province]].append(city_data[city])
|
||||
async with aiofiles.open(china_city, "w", encoding="utf8") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
logger.info("自动更新城市列表完成.....")
|
||||
except TimeoutError as e:
|
||||
logger.warning("自动更新城市列表超时...", e=e)
|
||||
except ValueError as e:
|
||||
logger.warning("自动城市列表失败.....", e=e)
|
||||
except Exception as e:
|
||||
logger.error("自动城市列表未知错误", e=e)
|
||||
|
||||
|
||||
# 自动更新城市列表
|
||||
@scheduler.scheduled_job(
|
||||
"cron",
|
||||
hour=6,
|
||||
minute=1,
|
||||
)
|
||||
async def _():
|
||||
await update_city()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
"""开启/禁用插件格式修改"""
|
||||
_, is_create = await GroupConsole.get_or_create(group_id=133133133)
|
||||
|
||||
@ -5,7 +5,9 @@ from nonebot_plugin_alconna import (
|
||||
AlconnaQuery,
|
||||
Args,
|
||||
Arparma,
|
||||
At,
|
||||
Match,
|
||||
MultiVar,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
@ -33,6 +35,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
usage="""
|
||||
商品操作
|
||||
指令:
|
||||
商店
|
||||
我的金币
|
||||
我的道具
|
||||
使用道具 [名称/Id]
|
||||
@ -46,6 +49,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
plugin_type=PluginType.NORMAL,
|
||||
menu_type="商店",
|
||||
commands=[
|
||||
Command(command="商店"),
|
||||
Command(command="我的金币"),
|
||||
Command(command="我的道具"),
|
||||
Command(command="购买道具"),
|
||||
@ -74,13 +78,21 @@ _matcher = on_alconna(
|
||||
Subcommand("my-cost", help_text="我的金币"),
|
||||
Subcommand("my-props", help_text="我的道具"),
|
||||
Subcommand("buy", Args["name?", str]["num?", int], help_text="购买道具"),
|
||||
Subcommand("use", Args["name?", str]["num?", int], help_text="使用道具"),
|
||||
Subcommand("gold-list", Args["num?", int], help_text="金币排行"),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_use_matcher = on_alconna(
|
||||
Alconna(
|
||||
"使用道具",
|
||||
Args["name?", str]["num?", int]["at_users?", MultiVar(At)],
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
"我的金币",
|
||||
command="商店",
|
||||
@ -102,13 +114,6 @@ _matcher.shortcut(
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
"使用道具(?P<name>.*?)",
|
||||
command="商店",
|
||||
arguments=["use", "{name}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
"金币排行",
|
||||
command="商店",
|
||||
@ -172,7 +177,7 @@ async def _(
|
||||
await MessageUtils.build_message(result).send(reply_to=True)
|
||||
|
||||
|
||||
@_matcher.assign("use")
|
||||
@_use_matcher.handle()
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
@ -181,6 +186,7 @@ async def _(
|
||||
arparma: Arparma,
|
||||
name: Match[str],
|
||||
num: Query[int] = AlconnaQuery("num", 1),
|
||||
at_users: Query[list[At]] = AlconnaQuery("at_users", []),
|
||||
):
|
||||
if not name.available:
|
||||
await MessageUtils.build_message(
|
||||
@ -188,7 +194,7 @@ async def _(
|
||||
).finish(reply_to=True)
|
||||
try:
|
||||
result = await ShopManage.use(
|
||||
bot, event, session, message, name.result, num.result, ""
|
||||
bot, event, session, message, name.result, num.result, "", at_users.result
|
||||
)
|
||||
logger.info(
|
||||
f"使用道具 {name.result}, 数量: {num.result}",
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Any, Literal
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.compat import model_dump
|
||||
from nonebot_plugin_alconna import UniMessage, UniMsg
|
||||
from nonebot_plugin_alconna import At, UniMessage, UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from tortoise.expressions import Q
|
||||
@ -48,6 +48,10 @@ class Goods(BaseModel):
|
||||
"""model"""
|
||||
session: Uninfo | None = None
|
||||
"""Uninfo"""
|
||||
at_user: str | None = None
|
||||
"""At对象"""
|
||||
at_users: list[str] = []
|
||||
"""At对象列表"""
|
||||
|
||||
|
||||
class ShopParam(BaseModel):
|
||||
@ -73,6 +77,10 @@ class ShopParam(BaseModel):
|
||||
"""Uninfo"""
|
||||
message: UniMsg
|
||||
"""UniMessage"""
|
||||
at_user: str | None = None
|
||||
"""At对象"""
|
||||
at_users: list[str] = []
|
||||
"""At对象列表"""
|
||||
extra_data: dict[str, Any] = Field(default_factory=dict)
|
||||
"""额外数据"""
|
||||
|
||||
@ -156,6 +164,7 @@ class ShopManage:
|
||||
goods: Goods,
|
||||
num: int,
|
||||
text: str,
|
||||
at_users: list[str] = [],
|
||||
) -> tuple[ShopParam, dict[str, Any]]:
|
||||
"""构造参数
|
||||
|
||||
@ -165,6 +174,7 @@ class ShopManage:
|
||||
goods_name: 商品名称
|
||||
num: 数量
|
||||
text: 其他信息
|
||||
at_users: at用户
|
||||
"""
|
||||
group_id = None
|
||||
if session.group:
|
||||
@ -172,6 +182,7 @@ class ShopManage:
|
||||
session.group.parent.id if session.group.parent else session.group.id
|
||||
)
|
||||
_kwargs = goods.params
|
||||
at_user = at_users[0] if at_users else None
|
||||
model = goods.model(
|
||||
**{
|
||||
"goods_name": goods.name,
|
||||
@ -183,6 +194,8 @@ class ShopManage:
|
||||
"text": text,
|
||||
"session": session,
|
||||
"message": message,
|
||||
"at_user": at_user,
|
||||
"at_users": at_users,
|
||||
}
|
||||
)
|
||||
return model, {
|
||||
@ -194,6 +207,8 @@ class ShopManage:
|
||||
"num": num,
|
||||
"text": text,
|
||||
"goods_name": goods.name,
|
||||
"at_user": at_user,
|
||||
"at_users": at_users,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -223,6 +238,7 @@ class ShopManage:
|
||||
**param.extra_data,
|
||||
"session": session,
|
||||
"message": message,
|
||||
"shop_param": ShopParam,
|
||||
}
|
||||
for key in list(param_json.keys()):
|
||||
if key not in args:
|
||||
@ -308,6 +324,7 @@ class ShopManage:
|
||||
goods_name: str,
|
||||
num: int,
|
||||
text: str,
|
||||
at_users: list[At] = [],
|
||||
) -> str | UniMessage | None:
|
||||
"""使用道具
|
||||
|
||||
@ -319,6 +336,7 @@ class ShopManage:
|
||||
goods_name: 商品名称
|
||||
num: 使用数量
|
||||
text: 其他信息
|
||||
at_users: at用户
|
||||
|
||||
返回:
|
||||
str | MessageFactory | None: 使用完成后返回信息
|
||||
@ -339,8 +357,9 @@ class ShopManage:
|
||||
goods = cls.uuid2goods.get(goods_info.uuid)
|
||||
if not goods or not goods.func:
|
||||
return f"{goods_info.goods_name} 未注册使用函数, 无法使用..."
|
||||
at_user_ids = [at.target for at in at_users]
|
||||
param, kwargs = cls.__build_params(
|
||||
bot, event, session, message, goods, num, text
|
||||
bot, event, session, message, goods, num, text, at_user_ids
|
||||
)
|
||||
if num > param.max_num_limit:
|
||||
return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..."
|
||||
@ -480,10 +499,13 @@ class ShopManage:
|
||||
if not user.props:
|
||||
return None
|
||||
|
||||
user.props = {uuid: count for uuid, count in user.props.items() if count > 0}
|
||||
|
||||
goods_list = await GoodsInfo.filter(uuid__in=user.props.keys()).all()
|
||||
goods_by_uuid = {item.uuid: item for item in goods_list}
|
||||
user.props = {
|
||||
uuid: count
|
||||
for uuid, count in user.props.items()
|
||||
if count > 0 and goods_by_uuid.get(uuid)
|
||||
}
|
||||
|
||||
table_rows = []
|
||||
for i, prop_uuid in enumerate(user.props):
|
||||
|
||||
@ -10,7 +10,6 @@ from nonebot_plugin_alconna import (
|
||||
store_true,
|
||||
)
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.utils import (
|
||||
Command,
|
||||
@ -23,7 +22,7 @@ from zhenxun.utils.depends import UserName
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from ._data_source import SignManage
|
||||
from .goods_register import driver # noqa: F401
|
||||
from .goods_register import Uninfo
|
||||
from .utils import clear_sign_data_pic
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from decimal import Decimal
|
||||
|
||||
import nonebot
|
||||
from nonebot.drivers import Driver
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
@ -9,14 +8,7 @@ from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.utils.decorator.shop import shop_register
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
driver: Driver = nonebot.get_driver()
|
||||
|
||||
|
||||
# @driver.on_startup
|
||||
# async def _():
|
||||
# """
|
||||
# 导入内置的三个商品
|
||||
# """
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@shop_register(
|
||||
|
||||
@ -16,6 +16,7 @@ from zhenxun.models.sign_log import SignLog
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.image_utils import BuildImage
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
from .config import (
|
||||
@ -54,7 +55,7 @@ LG_MESSAGE = [
|
||||
]
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def init_image():
|
||||
SIGN_RESOURCE_PATH.mkdir(parents=True, exist_ok=True)
|
||||
SIGN_TODAY_CARD_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@ -1,32 +1,77 @@
|
||||
from typing import Annotated
|
||||
|
||||
from nonebot import on_command
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.params import Command
|
||||
from arclet.alconna import AllParam
|
||||
from nepattern import UnionPattern
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot.rule import to_me
|
||||
from nonebot_plugin_alconna import Text as alcText
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
import nonebot_plugin_alconna as alc
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
Args,
|
||||
on_alconna,
|
||||
)
|
||||
from nonebot_plugin_alconna.uniseg.segment import (
|
||||
At,
|
||||
AtAll,
|
||||
Audio,
|
||||
Button,
|
||||
Emoji,
|
||||
File,
|
||||
Hyper,
|
||||
Image,
|
||||
Keyboard,
|
||||
Reference,
|
||||
Reply,
|
||||
Text,
|
||||
Video,
|
||||
Voice,
|
||||
)
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from ._data_source import BroadcastManage
|
||||
from .broadcast_manager import BroadcastManager
|
||||
from .message_processor import (
|
||||
_extract_broadcast_content,
|
||||
get_broadcast_target_groups,
|
||||
send_broadcast_and_notify,
|
||||
)
|
||||
|
||||
BROADCAST_SEND_DELAY_RANGE = (1, 3)
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="广播",
|
||||
description="昭告天下!",
|
||||
usage="""
|
||||
广播 [消息] [图片]
|
||||
示例:广播 你们好!
|
||||
广播 [消息内容]
|
||||
- 直接发送消息到除当前群组外的所有群组
|
||||
- 支持文本、图片、@、表情、视频等多种消息类型
|
||||
- 示例:广播 你们好!
|
||||
- 示例:广播 [图片] 新活动开始啦!
|
||||
|
||||
广播 + 引用消息
|
||||
- 将引用的消息作为广播内容发送
|
||||
- 支持引用普通消息或合并转发消息
|
||||
- 示例:(引用一条消息) 广播
|
||||
|
||||
广播撤回
|
||||
- 撤回最近一次由您触发的广播消息
|
||||
- 仅能撤回短时间内的消息
|
||||
- 示例:广播撤回
|
||||
|
||||
特性:
|
||||
- 在群组中使用广播时,不会将消息发送到当前群组
|
||||
- 在私聊中使用广播时,会发送到所有群组
|
||||
|
||||
别名:
|
||||
- bc (广播的简写)
|
||||
- recall (广播撤回的别名)
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1",
|
||||
version="1.2",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
@ -42,26 +87,106 @@ __plugin_meta__ = PluginMetadata(
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
_matcher = on_command(
|
||||
"广播", priority=1, permission=SUPERUSER, block=True, rule=to_me()
|
||||
AnySeg = (
|
||||
UnionPattern(
|
||||
[
|
||||
Text,
|
||||
Image,
|
||||
At,
|
||||
AtAll,
|
||||
Audio,
|
||||
Video,
|
||||
File,
|
||||
Emoji,
|
||||
Reply,
|
||||
Reference,
|
||||
Hyper,
|
||||
Button,
|
||||
Keyboard,
|
||||
Voice,
|
||||
]
|
||||
)
|
||||
@ "AnySeg"
|
||||
)
|
||||
|
||||
_matcher = on_alconna(
|
||||
Alconna(
|
||||
"广播",
|
||||
Args["content?", AllParam],
|
||||
),
|
||||
aliases={"bc"},
|
||||
priority=1,
|
||||
permission=SUPERUSER,
|
||||
block=True,
|
||||
rule=to_me(),
|
||||
use_origin=False,
|
||||
)
|
||||
|
||||
_recall_matcher = on_alconna(
|
||||
Alconna("广播撤回"),
|
||||
aliases={"recall"},
|
||||
priority=1,
|
||||
permission=SUPERUSER,
|
||||
block=True,
|
||||
rule=to_me(),
|
||||
)
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(
|
||||
async def handle_broadcast(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
message: UniMsg,
|
||||
command: Annotated[tuple[str, ...], Command()],
|
||||
arp: alc.Arparma,
|
||||
):
|
||||
for msg in message:
|
||||
if isinstance(msg, alcText) and msg.text.strip().startswith(command[0]):
|
||||
msg.text = msg.text.replace(command[0], "", 1).strip()
|
||||
break
|
||||
await MessageUtils.build_message("正在发送..请等一下哦!").send()
|
||||
count, error_count = await BroadcastManage.send(bot, message, session)
|
||||
result = f"成功广播 {count} 个群组"
|
||||
if error_count:
|
||||
result += f"\n广播失败 {error_count} 个群组"
|
||||
await MessageUtils.build_message(f"发送广播完成!\n{result}").send(reply_to=True)
|
||||
logger.info(f"发送广播信息: {message}", "广播", session=session)
|
||||
broadcast_content_msg = await _extract_broadcast_content(bot, event, arp, session)
|
||||
if not broadcast_content_msg:
|
||||
return
|
||||
|
||||
target_groups, enabled_groups = await get_broadcast_target_groups(bot, session)
|
||||
if not target_groups or not enabled_groups:
|
||||
return
|
||||
|
||||
try:
|
||||
await send_broadcast_and_notify(
|
||||
bot, event, broadcast_content_msg, enabled_groups, target_groups, session
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = "发送广播失败"
|
||||
BroadcastManager.log_error(error_msg, e, session)
|
||||
await MessageUtils.build_message(f"{error_msg}。").send(reply_to=True)
|
||||
|
||||
|
||||
@_recall_matcher.handle()
|
||||
async def handle_broadcast_recall(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
):
|
||||
"""处理广播撤回命令"""
|
||||
await MessageUtils.build_message("正在尝试撤回最近一次广播...").send()
|
||||
|
||||
try:
|
||||
success_count, error_count = await BroadcastManager.recall_last_broadcast(
|
||||
bot, session
|
||||
)
|
||||
|
||||
user_id = str(event.get_user_id())
|
||||
if success_count == 0 and error_count == 0:
|
||||
await bot.send_private_msg(
|
||||
user_id=user_id,
|
||||
message="没有找到最近的广播消息记录,可能已经撤回或超过可撤回时间。",
|
||||
)
|
||||
else:
|
||||
result = f"广播撤回完成!\n成功撤回 {success_count} 条消息"
|
||||
if error_count:
|
||||
result += f"\n撤回失败 {error_count} 条消息 (可能已过期或无权限)"
|
||||
await bot.send_private_msg(user_id=user_id, message=result)
|
||||
BroadcastManager.log_info(
|
||||
f"广播撤回完成: 成功 {success_count}, 失败 {error_count}", session
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = "撤回广播消息失败"
|
||||
BroadcastManager.log_error(error_msg, e, session)
|
||||
user_id = str(event.get_user_id())
|
||||
await bot.send_private_msg(user_id=user_id, message=f"{error_msg}。")
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
import nonebot_plugin_alconna as alc
|
||||
from nonebot_plugin_alconna import Image, UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
|
||||
class BroadcastManage:
|
||||
@classmethod
|
||||
async def send(
|
||||
cls, bot: Bot, message: UniMsg, session: EventSession
|
||||
) -> tuple[int, int]:
|
||||
"""发送广播消息
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
message: 消息内容
|
||||
session: Session
|
||||
|
||||
返回:
|
||||
tuple[int, int]: 发送成功的群组数量, 发送失败的群组数量
|
||||
"""
|
||||
message_list = []
|
||||
for msg in message:
|
||||
if isinstance(msg, alc.Image) and msg.url:
|
||||
message_list.append(Image(url=msg.url))
|
||||
elif isinstance(msg, alc.Text):
|
||||
message_list.append(msg.text)
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
if group_list:
|
||||
error_count = 0
|
||||
for group in group_list:
|
||||
try:
|
||||
if not await CommonUtils.task_is_block(
|
||||
bot,
|
||||
"broadcast", # group.channel_id
|
||||
group.group_id,
|
||||
):
|
||||
target = PlatformUtils.get_target(
|
||||
group_id=group.group_id, channel_id=group.channel_id
|
||||
)
|
||||
if target:
|
||||
await MessageUtils.build_message(message_list).send(
|
||||
target, bot
|
||||
)
|
||||
logger.debug(
|
||||
"发送成功",
|
||||
"广播",
|
||||
session=session,
|
||||
target=f"{group.group_id}:{group.channel_id}",
|
||||
)
|
||||
await asyncio.sleep(random.randint(1, 3))
|
||||
else:
|
||||
logger.warning("target为空", "广播", session=session)
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
"发送失败",
|
||||
"广播",
|
||||
session=session,
|
||||
target=f"{group.group_id}:{group.channel_id}",
|
||||
e=e,
|
||||
)
|
||||
return len(group_list) - error_count, error_count
|
||||
return 0, 0
|
||||
490
zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py
Normal file
490
zhenxun/builtin_plugins/superuser/broadcast/broadcast_manager.py
Normal file
@ -0,0 +1,490 @@
|
||||
import asyncio
|
||||
import random
|
||||
import traceback
|
||||
from typing import ClassVar
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters.onebot.v11 import Bot as V11Bot
|
||||
from nonebot.exception import ActionFailed
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from nonebot_plugin_alconna.uniseg import Receipt, Reference
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
from .models import BroadcastDetailResult, BroadcastResult
|
||||
from .utils import custom_nodes_to_v11_nodes, uni_message_to_v11_list_of_dicts
|
||||
|
||||
|
||||
class BroadcastManager:
|
||||
"""广播管理器"""
|
||||
|
||||
_last_broadcast_msg_ids: ClassVar[dict[str, int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_session_info(session: EventSession | None) -> str:
|
||||
"""获取会话信息字符串"""
|
||||
if not session:
|
||||
return ""
|
||||
|
||||
try:
|
||||
platform = getattr(session, "platform", "unknown")
|
||||
session_id = str(session)
|
||||
return f"[{platform}:{session_id}]"
|
||||
except Exception:
|
||||
return "[session-info-error]"
|
||||
|
||||
@staticmethod
|
||||
def log_error(
|
||||
message: str, error: Exception, session: EventSession | None = None, **kwargs
|
||||
):
|
||||
"""记录错误日志"""
|
||||
session_info = BroadcastManager._get_session_info(session)
|
||||
error_type = type(error).__name__
|
||||
stack_trace = traceback.format_exc()
|
||||
error_details = f"\n类型: {error_type}\n信息: {error!s}\n堆栈: {stack_trace}"
|
||||
|
||||
logger.error(
|
||||
f"{session_info} {message}{error_details}", "广播", e=error, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_warning(message: str, session: EventSession | None = None, **kwargs):
|
||||
"""记录警告级别日志"""
|
||||
session_info = BroadcastManager._get_session_info(session)
|
||||
logger.warning(f"{session_info} {message}", "广播", **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def log_info(message: str, session: EventSession | None = None, **kwargs):
|
||||
"""记录信息级别日志"""
|
||||
session_info = BroadcastManager._get_session_info(session)
|
||||
logger.info(f"{session_info} {message}", "广播", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_last_broadcast_msg_ids(cls) -> dict[str, int]:
|
||||
"""获取最近广播消息ID"""
|
||||
return cls._last_broadcast_msg_ids.copy()
|
||||
|
||||
@classmethod
|
||||
def clear_last_broadcast_msg_ids(cls) -> None:
|
||||
"""清空消息ID记录"""
|
||||
cls._last_broadcast_msg_ids.clear()
|
||||
|
||||
@classmethod
|
||||
async def get_all_groups(cls, bot: Bot) -> tuple[list[GroupConsole], str]:
|
||||
"""获取群组列表"""
|
||||
return await PlatformUtils.get_group_list(bot)
|
||||
|
||||
@classmethod
|
||||
async def send(
|
||||
cls, bot: Bot, message: UniMessage, session: EventSession
|
||||
) -> BroadcastResult:
|
||||
"""发送广播到所有群组"""
|
||||
logger.debug(
|
||||
f"开始广播(send - 广播到所有群组),Bot ID: {bot.self_id}",
|
||||
"广播",
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.debug("清空上一次的广播消息ID记录", "广播", session=session)
|
||||
cls.clear_last_broadcast_msg_ids()
|
||||
|
||||
all_groups, _ = await cls.get_all_groups(bot)
|
||||
return await cls.send_to_specific_groups(bot, message, all_groups, session)
|
||||
|
||||
@classmethod
|
||||
async def send_to_specific_groups(
|
||||
cls,
|
||||
bot: Bot,
|
||||
message: UniMessage,
|
||||
target_groups: list[GroupConsole],
|
||||
session_info: EventSession | str | None = None,
|
||||
) -> BroadcastResult:
|
||||
"""发送广播到指定群组"""
|
||||
log_session = session_info or bot.self_id
|
||||
logger.debug(
|
||||
f"开始广播,目标 {len(target_groups)} 个群组,Bot ID: {bot.self_id}",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
|
||||
if not target_groups:
|
||||
logger.debug("目标群组列表为空,广播结束", "广播", session=log_session)
|
||||
return 0, 0
|
||||
|
||||
platform = PlatformUtils.get_platform(bot)
|
||||
is_forward_broadcast = any(
|
||||
isinstance(seg, Reference) and getattr(seg, "nodes", None)
|
||||
for seg in message
|
||||
)
|
||||
|
||||
if platform == "qq" and isinstance(bot, V11Bot) and is_forward_broadcast:
|
||||
if (
|
||||
len(message) == 1
|
||||
and isinstance(message[0], Reference)
|
||||
and getattr(message[0], "nodes", None)
|
||||
):
|
||||
nodes_list = getattr(message[0], "nodes", [])
|
||||
v11_nodes = custom_nodes_to_v11_nodes(nodes_list)
|
||||
node_count = len(v11_nodes)
|
||||
logger.debug(
|
||||
f"从 UniMessage<Reference> 构造转发节点数: {node_count}",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"广播消息包含合并转发段和其他段,将尝试打平成一个节点发送",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
v11_content_list = uni_message_to_v11_list_of_dicts(message)
|
||||
v11_nodes = (
|
||||
[
|
||||
{
|
||||
"type": "node",
|
||||
"data": {
|
||||
"user_id": bot.self_id,
|
||||
"nickname": "广播",
|
||||
"content": v11_content_list,
|
||||
},
|
||||
}
|
||||
]
|
||||
if v11_content_list
|
||||
else []
|
||||
)
|
||||
|
||||
if not v11_nodes:
|
||||
logger.warning(
|
||||
"构造出的 V11 合并转发节点为空,无法发送",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
return 0, len(target_groups)
|
||||
success_count, error_count, skip_count = await cls._broadcast_forward(
|
||||
bot, log_session, target_groups, v11_nodes
|
||||
)
|
||||
else:
|
||||
if is_forward_broadcast:
|
||||
logger.warning(
|
||||
f"合并转发消息在适配器 ({platform}) 不支持,将作为普通消息发送",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
success_count, error_count, skip_count = await cls._broadcast_normal(
|
||||
bot, log_session, target_groups, message
|
||||
)
|
||||
|
||||
total = len(target_groups)
|
||||
stats = f"成功: {success_count}, 失败: {error_count}"
|
||||
stats += f", 跳过: {skip_count}, 总计: {total}"
|
||||
logger.debug(
|
||||
f"广播统计 - {stats}",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
|
||||
msg_ids = cls.get_last_broadcast_msg_ids()
|
||||
if msg_ids:
|
||||
id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids.items()])
|
||||
logger.debug(
|
||||
f"广播结束,记录了 {len(msg_ids)} 条消息ID: {id_list_str}",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"广播结束,但没有记录任何消息ID",
|
||||
"广播",
|
||||
session=log_session,
|
||||
)
|
||||
|
||||
return success_count, error_count
|
||||
|
||||
@classmethod
|
||||
async def _extract_message_id_from_result(
|
||||
cls,
|
||||
result: dict | Receipt,
|
||||
group_key: str,
|
||||
session_info: EventSession | str,
|
||||
msg_type: str = "普通",
|
||||
) -> None:
|
||||
"""提取消息ID并记录"""
|
||||
if isinstance(result, dict) and "message_id" in result:
|
||||
msg_id = result["message_id"]
|
||||
try:
|
||||
msg_id_int = int(msg_id)
|
||||
cls._last_broadcast_msg_ids[group_key] = msg_id_int
|
||||
logger.debug(
|
||||
f"记录群 {group_key} 的{msg_type}消息ID: {msg_id_int}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"{msg_type}结果中的 message_id 不是有效整数: {msg_id}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
elif isinstance(result, Receipt) and result.msg_ids:
|
||||
try:
|
||||
first_id_info = result.msg_ids[0]
|
||||
msg_id = None
|
||||
if isinstance(first_id_info, dict) and "message_id" in first_id_info:
|
||||
msg_id = first_id_info["message_id"]
|
||||
logger.debug(
|
||||
f"从 Receipt.msg_ids[0] 提取到 ID: {msg_id}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
elif isinstance(first_id_info, int | str):
|
||||
msg_id = first_id_info
|
||||
logger.debug(
|
||||
f"从 Receipt.msg_ids[0] 提取到原始ID: {msg_id}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
if msg_id is not None:
|
||||
try:
|
||||
msg_id_int = int(msg_id)
|
||||
cls._last_broadcast_msg_ids[group_key] = msg_id_int
|
||||
logger.debug(
|
||||
f"记录群 {group_key} 的消息ID: {msg_id_int}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"提取的ID ({msg_id}) 不是有效整数",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
else:
|
||||
info_str = str(first_id_info)
|
||||
logger.warning(
|
||||
f"无法从 Receipt.msg_ids[0] 提取ID: {info_str}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
except IndexError:
|
||||
logger.warning("Receipt.msg_ids 为空", "广播", session=session_info)
|
||||
except Exception as e_extract:
|
||||
logger.error(
|
||||
f"从 Receipt 提取 msg_id 时出错: {e_extract}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=e_extract,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"发送成功但无法从结果获取消息 ID. 结果: {result}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _check_group_availability(cls, bot: Bot, group: GroupConsole) -> bool:
|
||||
"""检查群组是否可用"""
|
||||
if not group.group_id:
|
||||
return False
|
||||
|
||||
if await CommonUtils.task_is_block(bot, "broadcast", group.group_id):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def _broadcast_forward(
|
||||
cls,
|
||||
bot: V11Bot,
|
||||
session_info: EventSession | str,
|
||||
group_list: list[GroupConsole],
|
||||
v11_nodes: list[dict],
|
||||
) -> BroadcastDetailResult:
|
||||
"""发送合并转发"""
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
skip_count = 0
|
||||
|
||||
for _, group in enumerate(group_list):
|
||||
group_key = group.group_id or group.channel_id
|
||||
|
||||
if not await cls._check_group_availability(bot, group):
|
||||
skip_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await bot.send_group_forward_msg(
|
||||
group_id=int(group.group_id), messages=v11_nodes
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"合并转发消息发送结果: {result}, 类型: {type(result)}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
await cls._extract_message_id_from_result(
|
||||
result, group_key, session_info, "合并转发"
|
||||
)
|
||||
|
||||
success_count += 1
|
||||
await asyncio.sleep(random.randint(1, 3))
|
||||
except ActionFailed as af_e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"发送失败(合并转发) to {group_key}: {af_e}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=af_e,
|
||||
)
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"发送失败(合并转发) to {group_key}: {e}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=e,
|
||||
)
|
||||
|
||||
return success_count, error_count, skip_count
|
||||
|
||||
@classmethod
|
||||
async def _broadcast_normal(
|
||||
cls,
|
||||
bot: Bot,
|
||||
session_info: EventSession | str,
|
||||
group_list: list[GroupConsole],
|
||||
message: UniMessage,
|
||||
) -> BroadcastDetailResult:
|
||||
"""发送普通消息"""
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
skip_count = 0
|
||||
|
||||
for _, group in enumerate(group_list):
|
||||
group_key = (
|
||||
f"{group.group_id}:{group.channel_id}"
|
||||
if group.channel_id
|
||||
else str(group.group_id)
|
||||
)
|
||||
|
||||
if not await cls._check_group_availability(bot, group):
|
||||
skip_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
target = PlatformUtils.get_target(
|
||||
group_id=group.group_id, channel_id=group.channel_id
|
||||
)
|
||||
|
||||
if target:
|
||||
receipt: Receipt = await message.send(target, bot=bot)
|
||||
|
||||
logger.debug(
|
||||
f"广播消息发送结果: {receipt}, 类型: {type(receipt)}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
await cls._extract_message_id_from_result(
|
||||
receipt, group_key, session_info
|
||||
)
|
||||
|
||||
success_count += 1
|
||||
await asyncio.sleep(random.randint(1, 3))
|
||||
else:
|
||||
logger.warning(
|
||||
"target为空", "广播", session=session_info, target=group_key
|
||||
)
|
||||
skip_count += 1
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"发送失败(普通) to {group_key}: {e}",
|
||||
"广播",
|
||||
session=session_info,
|
||||
e=e,
|
||||
)
|
||||
|
||||
return success_count, error_count, skip_count
|
||||
|
||||
@classmethod
|
||||
async def recall_last_broadcast(
|
||||
cls, bot: Bot, session_info: EventSession | str
|
||||
) -> BroadcastResult:
|
||||
"""撤回最近广播"""
|
||||
msg_ids_to_recall = cls.get_last_broadcast_msg_ids()
|
||||
|
||||
if not msg_ids_to_recall:
|
||||
logger.warning(
|
||||
"没有找到最近的广播消息ID记录", "广播撤回", session=session_info
|
||||
)
|
||||
return 0, 0
|
||||
|
||||
id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids_to_recall.items()])
|
||||
logger.debug(
|
||||
f"找到 {len(msg_ids_to_recall)} 条广播消息ID记录: {id_list_str}",
|
||||
"广播撤回",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
logger.info(
|
||||
f"准备撤回 {len(msg_ids_to_recall)} 条广播消息",
|
||||
"广播撤回",
|
||||
session=session_info,
|
||||
)
|
||||
|
||||
for group_key, msg_id in msg_ids_to_recall.items():
|
||||
try:
|
||||
logger.debug(
|
||||
f"尝试撤回消息 (ID: {msg_id}) in {group_key}",
|
||||
"广播撤回",
|
||||
session=session_info,
|
||||
)
|
||||
await bot.call_api("delete_msg", message_id=msg_id)
|
||||
success_count += 1
|
||||
except ActionFailed as af_e:
|
||||
retcode = getattr(af_e, "retcode", None)
|
||||
wording = getattr(af_e, "wording", "")
|
||||
if retcode == 100 and "MESSAGE_NOT_FOUND" in wording.upper():
|
||||
logger.warning(
|
||||
f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}",
|
||||
"广播撤回",
|
||||
session=session_info,
|
||||
)
|
||||
elif retcode == 300 and "delete message" in wording.lower():
|
||||
logger.warning(
|
||||
f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}",
|
||||
"广播撤回",
|
||||
session=session_info,
|
||||
)
|
||||
else:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"撤回消息失败 (ID: {msg_id}) in {group_key}: {af_e}",
|
||||
"广播撤回",
|
||||
session=session_info,
|
||||
e=af_e,
|
||||
)
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"撤回消息时发生未知错误 (ID: {msg_id}) in {group_key}: {e}",
|
||||
"广播撤回",
|
||||
session=session_info,
|
||||
e=e,
|
||||
)
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.debug("撤回操作完成,清空消息ID记录", "广播撤回", session=session_info)
|
||||
cls.clear_last_broadcast_msg_ids()
|
||||
|
||||
return success_count, error_count
|
||||
584
zhenxun/builtin_plugins/superuser/broadcast/message_processor.py
Normal file
584
zhenxun/builtin_plugins/superuser/broadcast/message_processor.py
Normal file
@ -0,0 +1,584 @@
|
||||
import base64
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.adapters.onebot.v11 import Message as V11Message
|
||||
from nonebot.adapters.onebot.v11 import MessageSegment as V11MessageSegment
|
||||
from nonebot.exception import ActionFailed
|
||||
import nonebot_plugin_alconna as alc
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from nonebot_plugin_alconna.uniseg.segment import (
|
||||
At,
|
||||
AtAll,
|
||||
CustomNode,
|
||||
Image,
|
||||
Reference,
|
||||
Reply,
|
||||
Text,
|
||||
Video,
|
||||
)
|
||||
from nonebot_plugin_alconna.uniseg.tools import reply_fetch
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from .broadcast_manager import BroadcastManager
|
||||
|
||||
MAX_FORWARD_DEPTH = 3
|
||||
|
||||
|
||||
async def _process_forward_content(
|
||||
forward_content: Any, forward_id: str | None, bot: Bot, depth: int
|
||||
) -> list[CustomNode]:
|
||||
"""处理转发消息内容"""
|
||||
nodes_for_alc = []
|
||||
content_parsed = False
|
||||
|
||||
if forward_content:
|
||||
nodes_from_content = None
|
||||
if isinstance(forward_content, list):
|
||||
nodes_from_content = forward_content
|
||||
elif isinstance(forward_content, str):
|
||||
try:
|
||||
parsed_content = json.loads(forward_content)
|
||||
if isinstance(parsed_content, list):
|
||||
nodes_from_content = parsed_content
|
||||
except Exception as json_e:
|
||||
logger.debug(
|
||||
f"[Depth {depth}] JSON解析失败: {json_e}",
|
||||
"广播",
|
||||
)
|
||||
|
||||
if nodes_from_content is not None:
|
||||
logger.debug(
|
||||
f"[D{depth}] 节点数: {len(nodes_from_content)}",
|
||||
"广播",
|
||||
)
|
||||
content_parsed = True
|
||||
for node_data in nodes_from_content:
|
||||
node = await _create_custom_node_from_data(node_data, bot, depth + 1)
|
||||
if node:
|
||||
nodes_for_alc.append(node)
|
||||
|
||||
if not content_parsed and forward_id:
|
||||
logger.debug(
|
||||
f"[D{depth}] 尝试API调用ID: {forward_id}",
|
||||
"广播",
|
||||
)
|
||||
try:
|
||||
forward_data = await bot.call_api("get_forward_msg", id=forward_id)
|
||||
nodes_list = None
|
||||
|
||||
if isinstance(forward_data, dict) and "messages" in forward_data:
|
||||
nodes_list = forward_data["messages"]
|
||||
elif (
|
||||
isinstance(forward_data, dict)
|
||||
and "data" in forward_data
|
||||
and isinstance(forward_data["data"], dict)
|
||||
and "message" in forward_data["data"]
|
||||
):
|
||||
nodes_list = forward_data["data"]["message"]
|
||||
elif isinstance(forward_data, list):
|
||||
nodes_list = forward_data
|
||||
|
||||
if nodes_list:
|
||||
node_count = len(nodes_list)
|
||||
logger.debug(
|
||||
f"[D{depth + 1}] 节点:{node_count}",
|
||||
"广播",
|
||||
)
|
||||
for node_data in nodes_list:
|
||||
node = await _create_custom_node_from_data(
|
||||
node_data, bot, depth + 1
|
||||
)
|
||||
if node:
|
||||
nodes_for_alc.append(node)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[D{depth + 1}] ID:{forward_id}无节点",
|
||||
"广播",
|
||||
)
|
||||
nodes_for_alc.append(
|
||||
CustomNode(
|
||||
uid="0",
|
||||
name="错误",
|
||||
content="[嵌套转发消息获取失败]",
|
||||
)
|
||||
)
|
||||
except ActionFailed as af_e:
|
||||
logger.error(
|
||||
f"[D{depth + 1}] API失败: {af_e}",
|
||||
"广播",
|
||||
e=af_e,
|
||||
)
|
||||
nodes_for_alc.append(
|
||||
CustomNode(
|
||||
uid="0",
|
||||
name="错误",
|
||||
content="[嵌套转发消息获取失败]",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[D{depth + 1}] 处理出错: {e}",
|
||||
"广播",
|
||||
e=e,
|
||||
)
|
||||
nodes_for_alc.append(
|
||||
CustomNode(
|
||||
uid="0",
|
||||
name="错误",
|
||||
content="[处理嵌套转发时出错]",
|
||||
)
|
||||
)
|
||||
elif not content_parsed and not forward_id:
|
||||
logger.warning(
|
||||
f"[D{depth}] 转发段无内容也无ID",
|
||||
"广播",
|
||||
)
|
||||
nodes_for_alc.append(
|
||||
CustomNode(
|
||||
uid="0",
|
||||
name="错误",
|
||||
content="[嵌套转发消息无法解析]",
|
||||
)
|
||||
)
|
||||
elif content_parsed and not nodes_for_alc:
|
||||
logger.warning(
|
||||
f"[D{depth}] 解析成功但无有效节点",
|
||||
"广播",
|
||||
)
|
||||
nodes_for_alc.append(
|
||||
CustomNode(
|
||||
uid="0",
|
||||
name="信息",
|
||||
content="[嵌套转发内容为空]",
|
||||
)
|
||||
)
|
||||
|
||||
return nodes_for_alc
|
||||
|
||||
|
||||
async def _create_custom_node_from_data(
|
||||
node_data: dict, bot: Bot, depth: int
|
||||
) -> CustomNode | None:
|
||||
"""从节点数据创建CustomNode"""
|
||||
node_content_raw = node_data.get("message") or node_data.get("content")
|
||||
if not node_content_raw:
|
||||
logger.warning(f"[D{depth}] 节点缺少消息内容", "广播")
|
||||
return None
|
||||
|
||||
sender = node_data.get("sender", {})
|
||||
uid = str(sender.get("user_id", "10000"))
|
||||
name = sender.get("nickname", f"用户{uid[:4]}")
|
||||
|
||||
extracted_uni_msg = await _extract_content_from_message(
|
||||
node_content_raw, bot, depth
|
||||
)
|
||||
if not extracted_uni_msg:
|
||||
return None
|
||||
|
||||
return CustomNode(uid=uid, name=name, content=extracted_uni_msg)
|
||||
|
||||
|
||||
async def _extract_broadcast_content(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
arp: alc.Arparma,
|
||||
session: EventSession,
|
||||
) -> UniMessage | None:
|
||||
"""从命令参数或引用消息中提取广播内容"""
|
||||
broadcast_content_msg: UniMessage | None = None
|
||||
|
||||
command_content_list = arp.all_matched_args.get("content", [])
|
||||
|
||||
processed_command_list = []
|
||||
has_command_content = False
|
||||
|
||||
if command_content_list:
|
||||
for item in command_content_list:
|
||||
if isinstance(item, alc.Segment):
|
||||
processed_command_list.append(item)
|
||||
if not (isinstance(item, Text) and not item.text.strip()):
|
||||
has_command_content = True
|
||||
elif isinstance(item, str):
|
||||
if item.strip():
|
||||
processed_command_list.append(Text(item.strip()))
|
||||
has_command_content = True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected type in command content: {type(item)}", "广播"
|
||||
)
|
||||
|
||||
if has_command_content:
|
||||
logger.debug("检测到命令参数内容,优先使用参数内容", "广播", session=session)
|
||||
broadcast_content_msg = UniMessage(processed_command_list)
|
||||
|
||||
if not broadcast_content_msg.filter(
|
||||
lambda x: not (isinstance(x, Text) and not x.text.strip())
|
||||
):
|
||||
logger.warning(
|
||||
"命令参数内容解析后为空或只包含空白", "广播", session=session
|
||||
)
|
||||
broadcast_content_msg = None
|
||||
|
||||
if not broadcast_content_msg:
|
||||
reply_segment_obj: Reply | None = await reply_fetch(event, bot)
|
||||
if (
|
||||
reply_segment_obj
|
||||
and hasattr(reply_segment_obj, "msg")
|
||||
and reply_segment_obj.msg
|
||||
):
|
||||
logger.debug(
|
||||
"未检测到有效命令参数,检测到引用消息", "广播", session=session
|
||||
)
|
||||
raw_quoted_content = reply_segment_obj.msg
|
||||
is_forward = False
|
||||
forward_id = None
|
||||
|
||||
if isinstance(raw_quoted_content, V11Message):
|
||||
for seg in raw_quoted_content:
|
||||
if isinstance(seg, V11MessageSegment):
|
||||
if seg.type == "forward":
|
||||
forward_id = seg.data.get("id")
|
||||
is_forward = bool(forward_id)
|
||||
break
|
||||
elif seg.type == "json":
|
||||
try:
|
||||
json_data_str = seg.data.get("data", "{}")
|
||||
if isinstance(json_data_str, str):
|
||||
import json
|
||||
|
||||
json_data = json.loads(json_data_str)
|
||||
if (
|
||||
json_data.get("app") == "com.tencent.multimsg"
|
||||
or json_data.get("view") == "Forward"
|
||||
) and json_data.get("meta", {}).get(
|
||||
"detail", {}
|
||||
).get("resid"):
|
||||
forward_id = json_data["meta"]["detail"][
|
||||
"resid"
|
||||
]
|
||||
is_forward = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_forward and forward_id:
|
||||
logger.info(
|
||||
f"尝试获取并构造合并转发内容 (ID: {forward_id})",
|
||||
"广播",
|
||||
session=session,
|
||||
)
|
||||
nodes_to_forward: list[CustomNode] = []
|
||||
try:
|
||||
forward_data = await bot.call_api("get_forward_msg", id=forward_id)
|
||||
nodes_list = None
|
||||
if isinstance(forward_data, dict) and "messages" in forward_data:
|
||||
nodes_list = forward_data["messages"]
|
||||
elif (
|
||||
isinstance(forward_data, dict)
|
||||
and "data" in forward_data
|
||||
and isinstance(forward_data["data"], dict)
|
||||
and "message" in forward_data["data"]
|
||||
):
|
||||
nodes_list = forward_data["data"]["message"]
|
||||
elif isinstance(forward_data, list):
|
||||
nodes_list = forward_data
|
||||
|
||||
if nodes_list is not None:
|
||||
for node_data in nodes_list:
|
||||
node_sender = node_data.get("sender", {})
|
||||
node_user_id = str(node_sender.get("user_id", "10000"))
|
||||
node_nickname = node_sender.get(
|
||||
"nickname", f"用户{node_user_id[:4]}"
|
||||
)
|
||||
node_content_raw = node_data.get(
|
||||
"message"
|
||||
) or node_data.get("content")
|
||||
if node_content_raw:
|
||||
extracted_node_uni_msg = (
|
||||
await _extract_content_from_message(
|
||||
node_content_raw, bot
|
||||
)
|
||||
)
|
||||
if extracted_node_uni_msg:
|
||||
nodes_to_forward.append(
|
||||
CustomNode(
|
||||
uid=node_user_id,
|
||||
name=node_nickname,
|
||||
content=extracted_node_uni_msg,
|
||||
)
|
||||
)
|
||||
if nodes_to_forward:
|
||||
broadcast_content_msg = UniMessage(
|
||||
Reference(nodes=nodes_to_forward)
|
||||
)
|
||||
except ActionFailed:
|
||||
await MessageUtils.build_message(
|
||||
"获取合并转发消息失败,可能不支持此 API。"
|
||||
).send(reply_to=True)
|
||||
return None
|
||||
except Exception as api_e:
|
||||
logger.error(f"处理合并转发时出错: {api_e}", "广播", e=api_e)
|
||||
await MessageUtils.build_message(
|
||||
"处理合并转发消息时发生内部错误。"
|
||||
).send(reply_to=True)
|
||||
return None
|
||||
else:
|
||||
broadcast_content_msg = await _extract_content_from_message(
|
||||
raw_quoted_content, bot
|
||||
)
|
||||
else:
|
||||
logger.debug("未检测到命令参数和引用消息", "广播", session=session)
|
||||
await MessageUtils.build_message("请提供广播内容或引用要广播的消息").send(
|
||||
reply_to=True
|
||||
)
|
||||
return None
|
||||
|
||||
if not broadcast_content_msg:
|
||||
logger.error(
|
||||
"未能从命令参数或引用消息中获取有效的广播内容", "广播", session=session
|
||||
)
|
||||
await MessageUtils.build_message("错误:未能获取有效的广播内容。").send(
|
||||
reply_to=True
|
||||
)
|
||||
return None
|
||||
|
||||
return broadcast_content_msg
|
||||
|
||||
|
||||
async def _process_v11_segment(
|
||||
seg_obj: V11MessageSegment | dict, depth: int, index: int, bot: Bot
|
||||
) -> list[alc.Segment]:
|
||||
"""处理V11消息段"""
|
||||
result = []
|
||||
seg_type = None
|
||||
data_dict = None
|
||||
|
||||
if isinstance(seg_obj, V11MessageSegment):
|
||||
seg_type = seg_obj.type
|
||||
data_dict = seg_obj.data
|
||||
elif isinstance(seg_obj, dict):
|
||||
seg_type = seg_obj.get("type")
|
||||
data_dict = seg_obj.get("data")
|
||||
else:
|
||||
return result
|
||||
|
||||
if not (seg_type and data_dict is not None):
|
||||
logger.warning(f"[D{depth}] 跳过无效数据: {type(seg_obj)}", "广播")
|
||||
return result
|
||||
|
||||
if seg_type == "text":
|
||||
text_content = data_dict.get("text", "")
|
||||
if isinstance(text_content, str) and text_content.strip():
|
||||
result.append(Text(text_content))
|
||||
elif seg_type == "image":
|
||||
img_seg = None
|
||||
if data_dict.get("url"):
|
||||
img_seg = Image(url=data_dict["url"])
|
||||
elif data_dict.get("file"):
|
||||
file_val = data_dict["file"]
|
||||
if isinstance(file_val, str) and file_val.startswith("base64://"):
|
||||
b64_data = file_val[9:]
|
||||
raw_bytes = base64.b64decode(b64_data)
|
||||
img_seg = Image(raw=raw_bytes)
|
||||
else:
|
||||
img_seg = Image(path=file_val)
|
||||
if img_seg:
|
||||
result.append(img_seg)
|
||||
else:
|
||||
logger.warning(f"[Depth {depth}] V11 图片 {index} 缺少URL/文件", "广播")
|
||||
elif seg_type == "at":
|
||||
target_qq = data_dict.get("qq", "")
|
||||
if target_qq.lower() == "all":
|
||||
result.append(AtAll())
|
||||
elif target_qq:
|
||||
result.append(At(flag="user", target=target_qq))
|
||||
elif seg_type == "video":
|
||||
video_seg = None
|
||||
if data_dict.get("url"):
|
||||
video_seg = Video(url=data_dict["url"])
|
||||
elif data_dict.get("file"):
|
||||
file_val = data_dict["file"]
|
||||
if isinstance(file_val, str) and file_val.startswith("base64://"):
|
||||
b64_data = file_val[9:]
|
||||
raw_bytes = base64.b64decode(b64_data)
|
||||
video_seg = Video(raw=raw_bytes)
|
||||
else:
|
||||
video_seg = Video(path=file_val)
|
||||
if video_seg:
|
||||
result.append(video_seg)
|
||||
logger.debug(f"[Depth {depth}] 处理视频消息成功", "广播")
|
||||
else:
|
||||
logger.warning(f"[Depth {depth}] V11 视频 {index} 缺少URL/文件", "广播")
|
||||
elif seg_type == "forward":
|
||||
nested_forward_id = data_dict.get("id") or data_dict.get("resid")
|
||||
nested_forward_content = data_dict.get("content")
|
||||
|
||||
logger.debug(f"[D{depth}] 嵌套转发ID: {nested_forward_id}", "广播")
|
||||
|
||||
nested_nodes = await _process_forward_content(
|
||||
nested_forward_content, nested_forward_id, bot, depth
|
||||
)
|
||||
|
||||
if nested_nodes:
|
||||
result.append(Reference(nodes=nested_nodes))
|
||||
else:
|
||||
logger.warning(f"[D{depth}] 跳过类型: {seg_type}", "广播")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _extract_content_from_message(
|
||||
message_content: Any, bot: Bot, depth: int = 0
|
||||
) -> UniMessage:
|
||||
"""提取消息内容到UniMessage"""
|
||||
temp_msg = UniMessage()
|
||||
input_type_str = str(type(message_content))
|
||||
|
||||
if depth >= MAX_FORWARD_DEPTH:
|
||||
logger.warning(
|
||||
f"[Depth {depth}] 达到最大递归深度 {MAX_FORWARD_DEPTH},停止解析嵌套转发。",
|
||||
"广播",
|
||||
)
|
||||
temp_msg.append(Text("[嵌套转发层数过多,内容已省略]"))
|
||||
return temp_msg
|
||||
|
||||
segments_to_process = []
|
||||
|
||||
if isinstance(message_content, UniMessage):
|
||||
segments_to_process = list(message_content)
|
||||
elif isinstance(message_content, V11Message):
|
||||
segments_to_process = list(message_content)
|
||||
elif isinstance(message_content, list):
|
||||
segments_to_process = message_content
|
||||
elif (
|
||||
isinstance(message_content, dict)
|
||||
and "type" in message_content
|
||||
and "data" in message_content
|
||||
):
|
||||
segments_to_process = [message_content]
|
||||
elif isinstance(message_content, str):
|
||||
if message_content.strip():
|
||||
temp_msg.append(Text(message_content))
|
||||
return temp_msg
|
||||
else:
|
||||
logger.warning(f"[Depth {depth}] 无法处理的输入类型: {input_type_str}", "广播")
|
||||
return temp_msg
|
||||
|
||||
if segments_to_process:
|
||||
for index, seg_obj in enumerate(segments_to_process):
|
||||
try:
|
||||
if isinstance(seg_obj, Text):
|
||||
text_content = getattr(seg_obj, "text", None)
|
||||
if isinstance(text_content, str) and text_content.strip():
|
||||
temp_msg.append(seg_obj)
|
||||
elif isinstance(seg_obj, Image):
|
||||
if (
|
||||
getattr(seg_obj, "url", None)
|
||||
or getattr(seg_obj, "path", None)
|
||||
or getattr(seg_obj, "raw", None)
|
||||
):
|
||||
temp_msg.append(seg_obj)
|
||||
elif isinstance(seg_obj, At):
|
||||
temp_msg.append(seg_obj)
|
||||
elif isinstance(seg_obj, AtAll):
|
||||
temp_msg.append(seg_obj)
|
||||
elif isinstance(seg_obj, Video):
|
||||
if (
|
||||
getattr(seg_obj, "url", None)
|
||||
or getattr(seg_obj, "path", None)
|
||||
or getattr(seg_obj, "raw", None)
|
||||
):
|
||||
temp_msg.append(seg_obj)
|
||||
logger.debug(f"[D{depth}] 处理Video对象成功", "广播")
|
||||
else:
|
||||
processed_segments = await _process_v11_segment(
|
||||
seg_obj, depth, index, bot
|
||||
)
|
||||
temp_msg.extend(processed_segments)
|
||||
except Exception as e_conv_seg:
|
||||
logger.warning(
|
||||
f"[D{depth}] 处理段 {index} 出错: {e_conv_seg}",
|
||||
"广播",
|
||||
e=e_conv_seg,
|
||||
)
|
||||
|
||||
if not temp_msg and message_content:
|
||||
logger.warning(f"未能从类型 {input_type_str} 中提取内容", "广播")
|
||||
|
||||
return temp_msg
|
||||
|
||||
|
||||
async def get_broadcast_target_groups(
|
||||
bot: Bot, session: EventSession
|
||||
) -> tuple[list, list]:
|
||||
"""获取广播目标群组和启用了广播功能的群组"""
|
||||
target_groups = []
|
||||
all_groups, _ = await BroadcastManager.get_all_groups(bot)
|
||||
|
||||
current_group_id = None
|
||||
if hasattr(session, "id2") and session.id2:
|
||||
current_group_id = session.id2
|
||||
|
||||
if current_group_id:
|
||||
target_groups = [
|
||||
group for group in all_groups if group.group_id != current_group_id
|
||||
]
|
||||
logger.info(
|
||||
f"向除当前群组({current_group_id})外的所有群组广播", "广播", session=session
|
||||
)
|
||||
else:
|
||||
target_groups = all_groups
|
||||
logger.info("向所有群组广播", "广播", session=session)
|
||||
|
||||
if not target_groups:
|
||||
await MessageUtils.build_message("没有找到符合条件的广播目标群组。").send(
|
||||
reply_to=True
|
||||
)
|
||||
return [], []
|
||||
|
||||
enabled_groups = []
|
||||
for group in target_groups:
|
||||
if not await CommonUtils.task_is_block(bot, "broadcast", group.group_id):
|
||||
enabled_groups.append(group)
|
||||
|
||||
if not enabled_groups:
|
||||
await MessageUtils.build_message(
|
||||
"没有启用了广播功能的目标群组可供立即发送。"
|
||||
).send(reply_to=True)
|
||||
return target_groups, []
|
||||
|
||||
return target_groups, enabled_groups
|
||||
|
||||
|
||||
async def send_broadcast_and_notify(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
message: UniMessage,
|
||||
enabled_groups: list,
|
||||
target_groups: list,
|
||||
session: EventSession,
|
||||
) -> None:
|
||||
"""发送广播并通知结果"""
|
||||
BroadcastManager.clear_last_broadcast_msg_ids()
|
||||
count, error_count = await BroadcastManager.send_to_specific_groups(
|
||||
bot, message, enabled_groups, session
|
||||
)
|
||||
|
||||
result = f"成功广播 {count} 个群组"
|
||||
if error_count:
|
||||
result += f"\n发送失败 {error_count} 个群组"
|
||||
result += f"\n有效: {len(enabled_groups)} / 总计: {len(target_groups)}"
|
||||
|
||||
user_id = str(event.get_user_id())
|
||||
await bot.send_private_msg(user_id=user_id, message=f"发送广播完成!\n{result}")
|
||||
|
||||
BroadcastManager.log_info(
|
||||
f"广播完成,有效/总计: {len(enabled_groups)}/{len(target_groups)}",
|
||||
session,
|
||||
)
|
||||
64
zhenxun/builtin_plugins/superuser/broadcast/models.py
Normal file
64
zhenxun/builtin_plugins/superuser/broadcast/models.py
Normal file
@ -0,0 +1,64 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
|
||||
GroupKey = str
|
||||
MessageID = int
|
||||
BroadcastResult = tuple[int, int]
|
||||
BroadcastDetailResult = tuple[int, int, int]
|
||||
|
||||
|
||||
class BroadcastTarget:
|
||||
"""广播目标"""
|
||||
|
||||
def __init__(self, group_id: str, channel_id: str | None = None):
|
||||
self.group_id = group_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
def to_dict(self) -> dict[str, str | None]:
|
||||
"""转换为字典格式"""
|
||||
return {"group_id": self.group_id, "channel_id": self.channel_id}
|
||||
|
||||
@classmethod
|
||||
def from_group_console(cls, group: GroupConsole) -> "BroadcastTarget":
|
||||
"""从 GroupConsole 对象创建"""
|
||||
return cls(group_id=group.group_id, channel_id=group.channel_id)
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""获取群组的唯一标识"""
|
||||
if self.channel_id:
|
||||
return f"{self.group_id}:{self.channel_id}"
|
||||
return str(self.group_id)
|
||||
|
||||
|
||||
class BroadcastTask:
|
||||
"""广播任务"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_id: str,
|
||||
message: UniMessage,
|
||||
targets: list[BroadcastTarget],
|
||||
scheduled_time: datetime | None = None,
|
||||
task_id: str | None = None,
|
||||
):
|
||||
self.bot_id = bot_id
|
||||
self.message = message
|
||||
self.targets = targets
|
||||
self.scheduled_time = scheduled_time
|
||||
self.task_id = task_id
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式,用于序列化"""
|
||||
return {
|
||||
"bot_id": self.bot_id,
|
||||
"targets": [t.to_dict() for t in self.targets],
|
||||
"scheduled_time": self.scheduled_time.isoformat()
|
||||
if self.scheduled_time
|
||||
else None,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
175
zhenxun/builtin_plugins/superuser/broadcast/utils.py
Normal file
175
zhenxun/builtin_plugins/superuser/broadcast/utils.py
Normal file
@ -0,0 +1,175 @@
|
||||
import base64
|
||||
|
||||
import nonebot_plugin_alconna as alc
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from nonebot_plugin_alconna.uniseg import Reference
|
||||
from nonebot_plugin_alconna.uniseg.segment import CustomNode, Video
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
|
||||
def uni_segment_to_v11_segment_dict(
|
||||
seg: alc.Segment, depth: int = 0
|
||||
) -> dict | list[dict] | None:
|
||||
"""UniSeg段转V11字典"""
|
||||
if isinstance(seg, alc.Text):
|
||||
return {"type": "text", "data": {"text": seg.text}}
|
||||
elif isinstance(seg, alc.Image):
|
||||
if getattr(seg, "url", None):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {"file": seg.url},
|
||||
}
|
||||
elif getattr(seg, "raw", None):
|
||||
raw_data = seg.raw
|
||||
if isinstance(raw_data, str):
|
||||
if len(raw_data) >= 9 and raw_data[:9] == "base64://":
|
||||
return {"type": "image", "data": {"file": raw_data}}
|
||||
elif isinstance(raw_data, bytes):
|
||||
b64_str = base64.b64encode(raw_data).decode()
|
||||
return {"type": "image", "data": {"file": f"base64://{b64_str}"}}
|
||||
else:
|
||||
logger.warning(f"无法处理 Image.raw 的类型: {type(raw_data)}", "广播")
|
||||
elif getattr(seg, "path", None):
|
||||
logger.warning(
|
||||
f"在合并转发中使用了本地图片路径,可能无法显示: {seg.path}", "广播"
|
||||
)
|
||||
return {"type": "image", "data": {"file": f"file:///{seg.path}"}}
|
||||
else:
|
||||
logger.warning(f"alc.Image 缺少有效数据,无法转换为 V11 段: {seg}", "广播")
|
||||
elif isinstance(seg, alc.At):
|
||||
return {"type": "at", "data": {"qq": seg.target}}
|
||||
elif isinstance(seg, alc.AtAll):
|
||||
return {"type": "at", "data": {"qq": "all"}}
|
||||
elif isinstance(seg, Video):
|
||||
if getattr(seg, "url", None):
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {"file": seg.url},
|
||||
}
|
||||
elif getattr(seg, "raw", None):
|
||||
raw_data = seg.raw
|
||||
if isinstance(raw_data, str):
|
||||
if len(raw_data) >= 9 and raw_data[:9] == "base64://":
|
||||
return {"type": "video", "data": {"file": raw_data}}
|
||||
elif isinstance(raw_data, bytes):
|
||||
b64_str = base64.b64encode(raw_data).decode()
|
||||
return {"type": "video", "data": {"file": f"base64://{b64_str}"}}
|
||||
else:
|
||||
logger.warning(f"无法处理 Video.raw 的类型: {type(raw_data)}", "广播")
|
||||
elif getattr(seg, "path", None):
|
||||
logger.warning(
|
||||
f"在合并转发中使用了本地视频路径,可能无法显示: {seg.path}", "广播"
|
||||
)
|
||||
return {"type": "video", "data": {"file": f"file:///{seg.path}"}}
|
||||
else:
|
||||
logger.warning(f"Video 缺少有效数据,无法转换为 V11 段: {seg}", "广播")
|
||||
elif isinstance(seg, Reference) and getattr(seg, "nodes", None):
|
||||
if depth >= 3:
|
||||
logger.warning(
|
||||
f"嵌套转发深度超过限制 (depth={depth}),不再继续解析", "广播"
|
||||
)
|
||||
return {"type": "text", "data": {"text": "[嵌套转发层数过多,内容已省略]"}}
|
||||
|
||||
nested_v11_content_list = []
|
||||
nodes_list = getattr(seg, "nodes", [])
|
||||
for node in nodes_list:
|
||||
if isinstance(node, CustomNode):
|
||||
node_v11_content = []
|
||||
if isinstance(node.content, UniMessage):
|
||||
for nested_seg in node.content:
|
||||
converted_dict = uni_segment_to_v11_segment_dict(
|
||||
nested_seg, depth + 1
|
||||
)
|
||||
if isinstance(converted_dict, list):
|
||||
node_v11_content.extend(converted_dict)
|
||||
elif converted_dict:
|
||||
node_v11_content.append(converted_dict)
|
||||
elif isinstance(node.content, str):
|
||||
node_v11_content.append(
|
||||
{"type": "text", "data": {"text": node.content}}
|
||||
)
|
||||
if node_v11_content:
|
||||
separator = {
|
||||
"type": "text",
|
||||
"data": {
|
||||
"text": f"\n--- 来自 {node.name} ({node.uid}) 的消息 ---\n"
|
||||
},
|
||||
}
|
||||
nested_v11_content_list.insert(0, separator)
|
||||
nested_v11_content_list.extend(node_v11_content)
|
||||
nested_v11_content_list.append(
|
||||
{"type": "text", "data": {"text": "\n---\n"}}
|
||||
)
|
||||
|
||||
return nested_v11_content_list
|
||||
|
||||
else:
|
||||
logger.warning(f"广播时跳过不支持的 UniSeg 段类型: {type(seg)}", "广播")
|
||||
return None
|
||||
|
||||
|
||||
def uni_message_to_v11_list_of_dicts(uni_msg: UniMessage | str | list) -> list[dict]:
|
||||
"""UniMessage转V11字典列表"""
|
||||
try:
|
||||
if isinstance(uni_msg, str):
|
||||
return [{"type": "text", "data": {"text": uni_msg}}]
|
||||
|
||||
if isinstance(uni_msg, list):
|
||||
if not uni_msg:
|
||||
return []
|
||||
|
||||
if all(isinstance(item, str) for item in uni_msg):
|
||||
return [{"type": "text", "data": {"text": item}} for item in uni_msg]
|
||||
|
||||
result = []
|
||||
for item in uni_msg:
|
||||
if hasattr(item, "__iter__") and not isinstance(item, str | bytes):
|
||||
result.extend(uni_message_to_v11_list_of_dicts(item))
|
||||
elif hasattr(item, "text") and not isinstance(item, str | bytes):
|
||||
text_value = getattr(item, "text", "")
|
||||
result.append({"type": "text", "data": {"text": str(text_value)}})
|
||||
elif hasattr(item, "url") and not isinstance(item, str | bytes):
|
||||
url_value = getattr(item, "url", "")
|
||||
if isinstance(item, Video):
|
||||
result.append(
|
||||
{"type": "video", "data": {"file": str(url_value)}}
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
{"type": "image", "data": {"file": str(url_value)}}
|
||||
)
|
||||
else:
|
||||
try:
|
||||
result.append({"type": "text", "data": {"text": str(item)}})
|
||||
except Exception as e:
|
||||
logger.warning(f"无法转换列表元素: {item}, 错误: {e}", "广播")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"消息转换过程中出错: {e}", "广播")
|
||||
|
||||
return [{"type": "text", "data": {"text": str(uni_msg)}}]
|
||||
|
||||
|
||||
def custom_nodes_to_v11_nodes(custom_nodes: list[CustomNode]) -> list[dict]:
|
||||
"""CustomNode列表转V11节点"""
|
||||
v11_nodes = []
|
||||
for node in custom_nodes:
|
||||
v11_content_list = uni_message_to_v11_list_of_dicts(node.content)
|
||||
|
||||
if v11_content_list:
|
||||
v11_nodes.append(
|
||||
{
|
||||
"type": "node",
|
||||
"data": {
|
||||
"user_id": str(node.uid),
|
||||
"nickname": node.name,
|
||||
"content": v11_content_list,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"CustomNode (uid={node.uid}) 内容转换后为空,跳过此节点", "广播"
|
||||
)
|
||||
return v11_nodes
|
||||
@ -10,7 +10,9 @@ from zhenxun.configs.config import Config as gConfig
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.services.log import logger, logger_
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from .api.configure import router as configure_router
|
||||
from .api.logs import router as ws_log_routes
|
||||
from .api.logs.log_manager import LOG_STORAGE
|
||||
from .api.menu import router as menu_router
|
||||
@ -81,6 +83,7 @@ BaseApiRouter.include_router(database_router)
|
||||
BaseApiRouter.include_router(plugin_router)
|
||||
BaseApiRouter.include_router(system_router)
|
||||
BaseApiRouter.include_router(menu_router)
|
||||
BaseApiRouter.include_router(configure_router)
|
||||
|
||||
WsApiRouter = APIRouter(prefix="/zhenxun/socket")
|
||||
|
||||
@ -89,7 +92,7 @@ WsApiRouter.include_router(status_routes)
|
||||
WsApiRouter.include_router(chat_routes)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=0)
|
||||
async def _():
|
||||
try:
|
||||
# 存储任务引用的列表,防止任务被垃圾回收
|
||||
|
||||
133
zhenxun/builtin_plugins/web_ui/api/configure/__init__.py
Normal file
133
zhenxun/builtin_plugins/web_ui/api/configure/__init__.py
Normal file
@ -0,0 +1,133 @@
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
import nonebot
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
|
||||
from ...base_model import Result
|
||||
from .data_source import test_db_connection
|
||||
from .model import Setting
|
||||
|
||||
router = APIRouter(prefix="/configure")
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
port = driver.config.port
|
||||
|
||||
BAT_FILE = Path() / "win启动.bat"
|
||||
|
||||
FILE_NAME = ".configure_restart"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/set_configure",
|
||||
response_model=Result,
|
||||
response_class=JSONResponse,
|
||||
description="设置基础配置",
|
||||
)
|
||||
async def _(setting: Setting) -> Result:
|
||||
global port
|
||||
password = Config.get_config("web-ui", "password")
|
||||
if password or BotConfig.db_url:
|
||||
return Result.fail("配置已存在,请先删除DB_URL内容和前端密码再进行设置。")
|
||||
env_file = Path() / ".env.dev"
|
||||
if not env_file.exists():
|
||||
return Result.fail("配置文件.env.dev不存在。")
|
||||
env_text = env_file.read_text(encoding="utf-8")
|
||||
if setting.db_url:
|
||||
if setting.db_url.startswith("sqlite"):
|
||||
base_dir = Path().resolve()
|
||||
# 清理和验证数据库路径
|
||||
db_path_str = setting.db_url.split(":")[-1].strip()
|
||||
# 移除任何可能的路径遍历尝试
|
||||
db_path_str = re.sub(r"[\\/]\.\.[\\/]", "", db_path_str)
|
||||
# 规范化路径
|
||||
db_path = Path(db_path_str).resolve()
|
||||
parent_path = db_path.parent
|
||||
|
||||
# 验证路径是否在项目根目录内
|
||||
try:
|
||||
if not parent_path.absolute().is_relative_to(base_dir):
|
||||
return Result.fail("数据库路径不在项目根目录内。")
|
||||
except ValueError:
|
||||
return Result.fail("无效的数据库路径。")
|
||||
|
||||
# 创建目录
|
||||
try:
|
||||
parent_path.mkdir(parents=True, exist_ok=True)
|
||||
except Exception as e:
|
||||
return Result.fail(f"创建数据库目录失败: {e!s}")
|
||||
|
||||
env_text = env_text.replace('DB_URL = ""', f'DB_URL = "{setting.db_url}"')
|
||||
if setting.superusers:
|
||||
superusers = ", ".join([f'"{s}"' for s in setting.superusers])
|
||||
env_text = re.sub(r"SUPERUSERS=\[.*?\]", f"SUPERUSERS=[{superusers}]", env_text)
|
||||
if setting.host:
|
||||
env_text = env_text.replace("HOST = 127.0.0.1", f"HOST = {setting.host}")
|
||||
if setting.port:
|
||||
env_text = env_text.replace("PORT = 8080", f"PORT = {setting.port}")
|
||||
port = setting.port
|
||||
if setting.username:
|
||||
Config.set_config("web-ui", "username", setting.username)
|
||||
Config.set_config("web-ui", "password", setting.password, True)
|
||||
env_file.write_text(env_text, encoding="utf-8")
|
||||
if BAT_FILE.exists():
|
||||
for file in os.listdir(Path()):
|
||||
if file.startswith(FILE_NAME):
|
||||
Path(file).unlink()
|
||||
flag_file = Path() / f"{FILE_NAME}_{int(time.time())}"
|
||||
flag_file.touch()
|
||||
return Result.ok(BAT_FILE.exists(), info="设置成功,请重启真寻以完成配置!")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/test_db",
|
||||
response_model=Result,
|
||||
response_class=JSONResponse,
|
||||
description="设置基础配置",
|
||||
)
|
||||
async def _(db_url: str) -> Result:
|
||||
result = await test_db_connection(db_url)
|
||||
if isinstance(result, str):
|
||||
return Result.fail(result)
|
||||
return Result.ok(info="数据库连接成功!")
|
||||
|
||||
|
||||
async def run_restart_command(bat_path: Path, port: int):
|
||||
"""在后台执行重启命令"""
|
||||
await asyncio.sleep(1) # 确保 FastAPI 已返回响应
|
||||
subprocess.Popen([bat_path, str(port)], shell=True) # noqa: ASYNC220
|
||||
sys.exit(0) # 退出当前进程
|
||||
|
||||
|
||||
@router.post(
|
||||
"/restart",
|
||||
response_model=Result,
|
||||
response_class=JSONResponse,
|
||||
description="重启",
|
||||
)
|
||||
async def _() -> Result:
|
||||
if not BAT_FILE.exists():
|
||||
return Result.fail("自动重启仅支持意见整合包,请尝试手动重启")
|
||||
flag_file = next(
|
||||
(Path() / file for file in os.listdir(Path()) if file.startswith(FILE_NAME)),
|
||||
None,
|
||||
)
|
||||
if not flag_file or not flag_file.exists():
|
||||
return Result.fail("重启标志文件不存在...")
|
||||
set_time = flag_file.name.split("_")[-1]
|
||||
if time.time() - float(set_time) > 10 * 60:
|
||||
return Result.fail("重启标志文件已过期,请重新设置配置。")
|
||||
flag_file.unlink()
|
||||
try:
|
||||
return Result.ok(info="执行重启命令成功")
|
||||
finally:
|
||||
asyncio.create_task(run_restart_command(BAT_FILE, port)) # noqa: RUF006
|
||||
18
zhenxun/builtin_plugins/web_ui/api/configure/data_source.py
Normal file
18
zhenxun/builtin_plugins/web_ui/api/configure/data_source.py
Normal file
@ -0,0 +1,18 @@
|
||||
from tortoise import Tortoise
|
||||
|
||||
|
||||
async def test_db_connection(db_url: str) -> bool | str:
|
||||
try:
|
||||
# 初始化 Tortoise ORM
|
||||
await Tortoise.init(
|
||||
db_url=db_url,
|
||||
modules={"models": ["__main__"]}, # 这里不需要实际模型
|
||||
)
|
||||
# 测试连接
|
||||
await Tortoise.get_connection("default").execute_query("SELECT 1")
|
||||
return True
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
finally:
|
||||
# 关闭连接
|
||||
await Tortoise.close_connections()
|
||||
16
zhenxun/builtin_plugins/web_ui/api/configure/model.py
Normal file
16
zhenxun/builtin_plugins/web_ui/api/configure/model.py
Normal file
@ -0,0 +1,16 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Setting(BaseModel):
|
||||
superusers: list[str]
|
||||
"""超级用户列表"""
|
||||
db_url: str
|
||||
"""数据库地址"""
|
||||
host: str
|
||||
"""主机地址"""
|
||||
port: int
|
||||
"""端口"""
|
||||
username: str
|
||||
"""前端用户名"""
|
||||
password: str
|
||||
"""前端密码"""
|
||||
@ -5,54 +5,63 @@ from zhenxun.services.log import logger
|
||||
|
||||
from .model import MenuData, MenuItem
|
||||
|
||||
default_menus = [
|
||||
MenuItem(
|
||||
name="仪表盘",
|
||||
module="dashboard",
|
||||
router="/dashboard",
|
||||
icon="dashboard",
|
||||
default=True,
|
||||
),
|
||||
MenuItem(
|
||||
name="真寻控制台",
|
||||
module="command",
|
||||
router="/command",
|
||||
icon="command",
|
||||
),
|
||||
MenuItem(name="插件列表", module="plugin", router="/plugin", icon="plugin"),
|
||||
MenuItem(name="插件商店", module="store", router="/store", icon="store"),
|
||||
MenuItem(name="好友/群组", module="manage", router="/manage", icon="user"),
|
||||
MenuItem(
|
||||
name="数据库管理",
|
||||
module="database",
|
||||
router="/database",
|
||||
icon="database",
|
||||
),
|
||||
MenuItem(name="系统信息", module="system", router="/system", icon="system"),
|
||||
MenuItem(name="关于我们", module="about", router="/about", icon="about"),
|
||||
]
|
||||
|
||||
class MenuManage:
|
||||
|
||||
class MenuManager:
|
||||
def __init__(self) -> None:
|
||||
self.file = DATA_PATH / "web_ui" / "menu.json"
|
||||
self.menu = []
|
||||
if self.file.exists():
|
||||
try:
|
||||
temp_menu = []
|
||||
self.menu = json.load(self.file.open(encoding="utf8"))
|
||||
self_menu_name = [menu["name"] for menu in self.menu]
|
||||
for module in [m.module for m in default_menus]:
|
||||
if module in self_menu_name:
|
||||
temp_menu.append(
|
||||
MenuItem(
|
||||
**next(m for m in self.menu if m["module"] == module)
|
||||
)
|
||||
)
|
||||
else:
|
||||
temp_menu.append(self.__get_menu_model(module))
|
||||
self.menu = temp_menu
|
||||
except Exception as e:
|
||||
logger.warning("菜单文件损坏,已重新生成...", "WebUi", e=e)
|
||||
if not self.menu:
|
||||
self.menu = [
|
||||
MenuItem(
|
||||
name="仪表盘",
|
||||
module="dashboard",
|
||||
router="/dashboard",
|
||||
icon="dashboard",
|
||||
default=True,
|
||||
),
|
||||
MenuItem(
|
||||
name="真寻控制台",
|
||||
module="command",
|
||||
router="/command",
|
||||
icon="command",
|
||||
),
|
||||
MenuItem(
|
||||
name="插件列表", module="plugin", router="/plugin", icon="plugin"
|
||||
),
|
||||
MenuItem(
|
||||
name="插件商店", module="store", router="/store", icon="store"
|
||||
),
|
||||
MenuItem(
|
||||
name="好友/群组", module="manage", router="/manage", icon="user"
|
||||
),
|
||||
MenuItem(
|
||||
name="数据库管理",
|
||||
module="database",
|
||||
router="/database",
|
||||
icon="database",
|
||||
),
|
||||
MenuItem(
|
||||
name="文件管理", module="system", router="/system", icon="system"
|
||||
),
|
||||
MenuItem(
|
||||
name="关于我们", module="about", router="/about", icon="about"
|
||||
),
|
||||
]
|
||||
self.save()
|
||||
self.menu = default_menus
|
||||
self.save()
|
||||
|
||||
def __get_menu_model(self, module: str):
|
||||
return default_menus[
|
||||
next(i for i, m in enumerate(default_menus) if m.module == module)
|
||||
]
|
||||
|
||||
def get_menus(self):
|
||||
return MenuData(menus=self.menu)
|
||||
@ -64,4 +73,4 @@ class MenuManage:
|
||||
json.dump(temp, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
menu_manage = MenuManage()
|
||||
menu_manage = MenuManager()
|
||||
|
||||
@ -13,6 +13,7 @@ from zhenxun.models.bot_connect_log import BotConnectLog
|
||||
from zhenxun.models.chat_history import ChatHistory
|
||||
from zhenxun.models.statistics import Statistics
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
from ....base_model import BaseResultModel, QueryModel
|
||||
@ -31,7 +32,7 @@ driver: Driver = nonebot.get_driver()
|
||||
CONNECT_TIME = 0
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
global CONNECT_TIME
|
||||
CONNECT_TIME = int(time.time())
|
||||
|
||||
@ -8,6 +8,7 @@ from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
from ....base_model import BaseResultModel, QueryModel, Result
|
||||
from ....utils import authentication
|
||||
@ -21,7 +22,7 @@ router = APIRouter(prefix="/database")
|
||||
driver: Driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
for plugin in nonebot.get_loaded_plugins():
|
||||
module = plugin.name
|
||||
|
||||
@ -119,7 +119,7 @@ class ApiDataSource:
|
||||
(await PlatformUtils.get_friend_list(select_bot.bot))[0]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("获取bot好友/群组信息失败...", "WebUi", e=e)
|
||||
logger.warning("获取bot好友/群组数量失败...", "WebUi", e=e)
|
||||
select_bot.group_count = 0
|
||||
select_bot.friend_count = 0
|
||||
select_bot.status = await BotConsole.get_bot_status(select_bot.self_id)
|
||||
|
||||
@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse
|
||||
from zhenxun.utils._build_image import BuildImage
|
||||
|
||||
from ....base_model import Result, SystemFolderSize
|
||||
from ....utils import authentication, get_system_disk
|
||||
from ....utils import authentication, get_system_disk, validate_path
|
||||
from .model import AddFile, DeleteFile, DirFile, RenameFile, SaveFile
|
||||
|
||||
router = APIRouter(prefix="/system")
|
||||
@ -25,22 +25,30 @@ IMAGE_TYPE = ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg"]
|
||||
description="获取文件列表",
|
||||
)
|
||||
async def _(path: str | None = None) -> Result[list[DirFile]]:
|
||||
base_path = Path(path) if path else Path()
|
||||
data_list = []
|
||||
for file in os.listdir(base_path):
|
||||
file_path = base_path / file
|
||||
is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE)
|
||||
data_list.append(
|
||||
DirFile(
|
||||
is_file=not file_path.is_dir(),
|
||||
is_image=is_image,
|
||||
name=file,
|
||||
parent=path,
|
||||
size=None if file_path.is_dir() else file_path.stat().st_size,
|
||||
mtime=file_path.stat().st_mtime,
|
||||
try:
|
||||
base_path, error = validate_path(path)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not base_path:
|
||||
return Result.fail("无效的路径")
|
||||
data_list = []
|
||||
for file in os.listdir(base_path):
|
||||
file_path = base_path / file
|
||||
is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE)
|
||||
data_list.append(
|
||||
DirFile(
|
||||
is_file=not file_path.is_dir(),
|
||||
is_image=is_image,
|
||||
name=file,
|
||||
parent=path,
|
||||
size=None if file_path.is_dir() else file_path.stat().st_size,
|
||||
mtime=file_path.stat().st_mtime,
|
||||
)
|
||||
)
|
||||
)
|
||||
return Result.ok(data_list)
|
||||
sorted(data_list, key=lambda f: f.name)
|
||||
return Result.ok(data_list)
|
||||
except Exception as e:
|
||||
return Result.fail(f"获取文件列表失败: {e!s}")
|
||||
|
||||
|
||||
@router.get(
|
||||
@ -62,8 +70,12 @@ async def _(full_path: str | None = None) -> Result[list[SystemFolderSize]]:
|
||||
description="删除文件",
|
||||
)
|
||||
async def _(param: DeleteFile) -> Result:
|
||||
path = Path(param.full_path)
|
||||
if not path or not path.exists():
|
||||
path, error = validate_path(param.full_path)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not path:
|
||||
return Result.fail("无效的路径")
|
||||
if not path.exists():
|
||||
return Result.warning_("文件不存在...")
|
||||
try:
|
||||
path.unlink()
|
||||
@ -80,8 +92,12 @@ async def _(param: DeleteFile) -> Result:
|
||||
description="删除文件夹",
|
||||
)
|
||||
async def _(param: DeleteFile) -> Result:
|
||||
path = Path(param.full_path)
|
||||
if not path or not path.exists() or path.is_file():
|
||||
path, error = validate_path(param.full_path)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not path:
|
||||
return Result.fail("无效的路径")
|
||||
if not path.exists() or path.is_file():
|
||||
return Result.warning_("文件夹不存在...")
|
||||
try:
|
||||
shutil.rmtree(path.absolute())
|
||||
@ -98,10 +114,14 @@ async def _(param: DeleteFile) -> Result:
|
||||
description="重命名文件",
|
||||
)
|
||||
async def _(param: RenameFile) -> Result:
|
||||
path = (
|
||||
(Path(param.parent) / param.old_name) if param.parent else Path(param.old_name)
|
||||
)
|
||||
if not path or not path.exists():
|
||||
parent_path, error = validate_path(param.parent)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not parent_path:
|
||||
return Result.fail("无效的路径")
|
||||
|
||||
path = (parent_path / param.old_name) if param.parent else Path(param.old_name)
|
||||
if not path.exists():
|
||||
return Result.warning_("文件不存在...")
|
||||
try:
|
||||
path.rename(path.parent / param.name)
|
||||
@ -118,10 +138,14 @@ async def _(param: RenameFile) -> Result:
|
||||
description="重命名文件夹",
|
||||
)
|
||||
async def _(param: RenameFile) -> Result:
|
||||
path = (
|
||||
(Path(param.parent) / param.old_name) if param.parent else Path(param.old_name)
|
||||
)
|
||||
if not path or not path.exists() or path.is_file():
|
||||
parent_path, error = validate_path(param.parent)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not parent_path:
|
||||
return Result.fail("无效的路径")
|
||||
|
||||
path = (parent_path / param.old_name) if param.parent else Path(param.old_name)
|
||||
if not path.exists() or path.is_file():
|
||||
return Result.warning_("文件夹不存在...")
|
||||
try:
|
||||
new_path = path.parent / param.name
|
||||
@ -139,7 +163,13 @@ async def _(param: RenameFile) -> Result:
|
||||
description="新建文件",
|
||||
)
|
||||
async def _(param: AddFile) -> Result:
|
||||
path = (Path(param.parent) / param.name) if param.parent else Path(param.name)
|
||||
parent_path, error = validate_path(param.parent)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not parent_path:
|
||||
return Result.fail("无效的路径")
|
||||
|
||||
path = (parent_path / param.name) if param.parent else Path(param.name)
|
||||
if path.exists():
|
||||
return Result.warning_("文件已存在...")
|
||||
try:
|
||||
@ -157,7 +187,13 @@ async def _(param: AddFile) -> Result:
|
||||
description="新建文件夹",
|
||||
)
|
||||
async def _(param: AddFile) -> Result:
|
||||
path = (Path(param.parent) / param.name) if param.parent else Path(param.name)
|
||||
parent_path, error = validate_path(param.parent)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not parent_path:
|
||||
return Result.fail("无效的路径")
|
||||
|
||||
path = (parent_path / param.name) if param.parent else Path(param.name)
|
||||
if path.exists():
|
||||
return Result.warning_("文件夹已存在...")
|
||||
try:
|
||||
@ -175,7 +211,11 @@ async def _(param: AddFile) -> Result:
|
||||
description="读取文件",
|
||||
)
|
||||
async def _(full_path: str) -> Result:
|
||||
path = Path(full_path)
|
||||
path, error = validate_path(full_path)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not path:
|
||||
return Result.fail("无效的路径")
|
||||
if not path.exists():
|
||||
return Result.warning_("文件不存在...")
|
||||
try:
|
||||
@ -193,9 +233,13 @@ async def _(full_path: str) -> Result:
|
||||
description="读取文件",
|
||||
)
|
||||
async def _(param: SaveFile) -> Result[str]:
|
||||
path = Path(param.full_path)
|
||||
path, error = validate_path(param.full_path)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not path:
|
||||
return Result.fail("无效的路径")
|
||||
try:
|
||||
async with aiofiles.open(path, "w", encoding="utf-8") as f:
|
||||
async with aiofiles.open(str(path), "w", encoding="utf-8") as f:
|
||||
await f.write(param.content)
|
||||
return Result.ok("更新成功!")
|
||||
except Exception as e:
|
||||
@ -210,7 +254,11 @@ async def _(param: SaveFile) -> Result[str]:
|
||||
description="读取图片base64",
|
||||
)
|
||||
async def _(full_path: str) -> Result[str]:
|
||||
path = Path(full_path)
|
||||
path, error = validate_path(full_path)
|
||||
if error:
|
||||
return Result.fail(error)
|
||||
if not path:
|
||||
return Result.fail("无效的路径")
|
||||
if not path.exists():
|
||||
return Result.warning_("文件不存在...")
|
||||
try:
|
||||
|
||||
@ -1,6 +1,12 @@
|
||||
import sys
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import nonebot
|
||||
from strenum import StrEnum
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
from strenum import StrEnum
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ async def update_webui_assets():
|
||||
download_url = await GithubUtils.parse_github_url(
|
||||
WEBUI_DIST_GITHUB_URL
|
||||
).get_archive_download_urls()
|
||||
logger.info("开始下载 webui_assets 资源...", COMMAND_NAME)
|
||||
if await AsyncHttpx.download_file(
|
||||
download_url, webui_assets_path, follow_redirects=True
|
||||
):
|
||||
|
||||
@ -2,6 +2,7 @@ import contextlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@ -28,6 +29,45 @@ if token_file.exists():
|
||||
token_data = json.load(open(token_file, encoding="utf8"))
|
||||
|
||||
|
||||
def validate_path(path_str: str | None) -> tuple[Path | None, str | None]:
|
||||
"""验证路径是否安全
|
||||
|
||||
参数:
|
||||
path_str: 用户输入的路径
|
||||
|
||||
返回:
|
||||
tuple[Path | None, str | None]: (验证后的路径, 错误信息)
|
||||
"""
|
||||
try:
|
||||
if not path_str:
|
||||
return Path().resolve(), None
|
||||
|
||||
# 1. 移除任何可能的路径遍历尝试
|
||||
path_str = re.sub(r"[\\/]\.\.[\\/]", "", path_str)
|
||||
|
||||
# 2. 规范化路径并转换为绝对路径
|
||||
path = Path(path_str).resolve()
|
||||
|
||||
# 3. 获取项目根目录
|
||||
root_dir = Path().resolve()
|
||||
|
||||
# 4. 验证路径是否在项目根目录内
|
||||
try:
|
||||
if not path.is_relative_to(root_dir):
|
||||
return None, "访问路径超出允许范围"
|
||||
except ValueError:
|
||||
return None, "无效的路径格式"
|
||||
|
||||
# 5. 验证路径是否包含任何危险字符
|
||||
if any(c in str(path) for c in ["..", "~", "*", "?", ">", "<", "|", '"']):
|
||||
return None, "路径包含非法字符"
|
||||
|
||||
# 6. 验证路径长度是否合理
|
||||
return (None, "路径长度超出限制") if len(str(path)) > 4096 else (path, None)
|
||||
except Exception as e:
|
||||
return None, f"路径验证失败: {e!s}"
|
||||
|
||||
|
||||
GROUP_HELP_PATH = DATA_PATH / "group_help"
|
||||
SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png"
|
||||
SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png"
|
||||
|
||||
@ -29,7 +29,7 @@ class BotConsole(Model):
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "bot_console"
|
||||
table_description = "Bot数据表"
|
||||
|
||||
|
||||
cache_type = CacheType.BOT
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -3,7 +3,6 @@ from tortoise import fields
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class LevelUser(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
@ -56,7 +55,7 @@ class LevelUser(Model):
|
||||
level: 权限等级
|
||||
group_flag: 是否被自动更新刷新权限 0:是, 1:否.
|
||||
"""
|
||||
if await cls.exists(user_id=user_id, group_id=group_id, user_level=level):
|
||||
if await cls.exists(user_id=user_id, group_id=group_id, level=level):
|
||||
# 权限相同时跳过
|
||||
return
|
||||
await cls.update_or_create(
|
||||
|
||||
@ -58,7 +58,7 @@ class PluginInfo(Model):
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "plugin_info"
|
||||
table_description = "插件基本信息"
|
||||
|
||||
|
||||
cache_type = CacheType.PLUGINS
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
|
||||
class ScheduleInfo(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
bot_id = fields.CharField(
|
||||
255, null=True, default=None, description="任务关联的Bot ID"
|
||||
)
|
||||
"""任务关联的Bot ID"""
|
||||
plugin_name = fields.CharField(255, description="插件模块名")
|
||||
"""插件模块名"""
|
||||
group_id = fields.CharField(
|
||||
255,
|
||||
null=True,
|
||||
description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务",
|
||||
)
|
||||
"""群组ID, 为空表示全局任务"""
|
||||
trigger_type = fields.CharField(
|
||||
max_length=20, default="cron", description="触发器类型 (cron, interval, date)"
|
||||
)
|
||||
"""触发器类型 (cron, interval, date)"""
|
||||
trigger_config = fields.JSONField(description="触发器具体配置")
|
||||
"""触发器具体配置"""
|
||||
job_kwargs = fields.JSONField(
|
||||
default=dict, description="传递给任务函数的额外关键字参数"
|
||||
)
|
||||
"""传递给任务函数的额外关键字参数"""
|
||||
is_enabled = fields.BooleanField(default=True, description="是否启用")
|
||||
"""是否启用"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "schedule_info"
|
||||
table_description = "通用定时任务表"
|
||||
0
zhenxun/plugins/__init__.py
Normal file
0
zhenxun/plugins/__init__.py
Normal file
@ -12,7 +12,7 @@ from aiocache.serializers import JsonSerializer
|
||||
import nonebot
|
||||
from nonebot.compat import model_dump
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
@ -121,7 +121,7 @@ class CacheData(BaseModel):
|
||||
lazy_load: bool = True # 默认延迟加载
|
||||
result_model: type | None = None
|
||||
_keys: set[str] = set() # 存储所有缓存键
|
||||
_cache: BaseCache | AioCache
|
||||
_cache: BaseCache | AioCache = PrivateAttr()
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@ -2,7 +2,6 @@ from asyncio import Semaphore
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, ClassVar
|
||||
from typing_extensions import Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import nonebot
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
@ -32,17 +31,20 @@ def _():
|
||||
global CACHE_FLAG
|
||||
CACHE_FLAG = True
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
class Model(TortoiseModel):
|
||||
"""
|
||||
自动添加模块
|
||||
|
||||
Args:
|
||||
Model_: Model
|
||||
"""
|
||||
|
||||
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
|
||||
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
if cls.__module__ not in MODELS:
|
||||
MODELS.append(cls.__module__)
|
||||
MODELS.append(cls.__module__)
|
||||
|
||||
if func := getattr(cls, "_run_script", None):
|
||||
SCRIPT_METHOD.append((cls.__module__, func))
|
||||
@ -169,7 +171,7 @@ class Model(TortoiseModel):
|
||||
await CacheRoot.reload(cache_type)
|
||||
|
||||
|
||||
class DbUrlIsNode(HookPriorityException):
|
||||
class DbUrlMissing(Exception):
|
||||
"""
|
||||
数据库链接地址为空
|
||||
"""
|
||||
@ -188,6 +190,7 @@ class DbConnectError(Exception):
|
||||
@PriorityLifecycle.on_startup(priority=1)
|
||||
async def init():
|
||||
if not BotConfig.db_url:
|
||||
# raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...")
|
||||
error = f"""
|
||||
**********************************************************************
|
||||
🌟 **************************** 配置为空 ************************* 🌟
|
||||
@ -196,74 +199,13 @@ async def init():
|
||||
***********************************************************************
|
||||
***********************************************************************
|
||||
"""
|
||||
raise DbUrlIsNode("\n" + error.strip())
|
||||
raise DbUrlMissing("\n" + error.strip())
|
||||
try:
|
||||
db_url = BotConfig.db_url
|
||||
url_scheme = db_url.split(":", 1)[0]
|
||||
config = None
|
||||
|
||||
if url_scheme in ("postgres", "postgresql"):
|
||||
# 解析 db_url
|
||||
url = urlparse(db_url)
|
||||
credentials = {
|
||||
"host": url.hostname,
|
||||
"port": url.port or 5432,
|
||||
"user": url.username,
|
||||
"password": url.password,
|
||||
"database": url.path.lstrip("/"),
|
||||
"minsize": 1,
|
||||
"maxsize": 50,
|
||||
}
|
||||
config = {
|
||||
"connections": {
|
||||
"default": {
|
||||
"engine": "tortoise.backends.asyncpg",
|
||||
"credentials": credentials,
|
||||
}
|
||||
},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": MODELS,
|
||||
"default_connection": "default",
|
||||
}
|
||||
},
|
||||
}
|
||||
elif url_scheme in ("mysql", "mysql+aiomysql"):
|
||||
url = urlparse(db_url)
|
||||
credentials = {
|
||||
"host": url.hostname,
|
||||
"port": url.port or 3306,
|
||||
"user": url.username,
|
||||
"password": url.password,
|
||||
"database": url.path.lstrip("/"),
|
||||
"minsize": 1,
|
||||
"maxsize": 50,
|
||||
}
|
||||
config = {
|
||||
"connections": {
|
||||
"default": {
|
||||
"engine": "tortoise.backends.mysql",
|
||||
"credentials": credentials,
|
||||
}
|
||||
},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": MODELS,
|
||||
"default_connection": "default",
|
||||
}
|
||||
},
|
||||
}
|
||||
else:
|
||||
# sqlite 或其它,直接用 db_url
|
||||
await Tortoise.init(
|
||||
db_url=db_url,
|
||||
modules={"models": MODELS},
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
if config:
|
||||
await Tortoise.init(config=config)
|
||||
|
||||
await Tortoise.init(
|
||||
db_url=BotConfig.db_url,
|
||||
modules={"models": MODELS},
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
if SCRIPT_METHOD:
|
||||
db = Tortoise.get_connection("default")
|
||||
logger.debug(
|
||||
@ -282,6 +224,7 @@ async def init():
|
||||
logger.debug(f"执行SQL: {sql}")
|
||||
try:
|
||||
await db.execute_query_dict(sql)
|
||||
# await TestSQL.raw(sql)
|
||||
except Exception as e:
|
||||
logger.debug(f"执行SQL: {sql} 错误...", e=e)
|
||||
if sql_list:
|
||||
@ -293,4 +236,4 @@ async def init():
|
||||
|
||||
|
||||
async def disconnect():
|
||||
await connections.close_all()
|
||||
await connections.close_all()
|
||||
|
||||
@ -6,6 +6,7 @@ from nonebot.utils import is_coroutine_callable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
@ -100,6 +101,6 @@ class PluginInitManager:
|
||||
logger.error(f"执行: {module_path}:remove 失败", e=e)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=5)
|
||||
async def _():
|
||||
await PluginInitManager.install_all()
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
import random
|
||||
import sys
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from strenum import StrEnum
|
||||
|
||||
from ._build_image import BuildImage
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
from strenum import StrEnum
|
||||
|
||||
|
||||
class MatType(StrEnum):
|
||||
LINE = "LINE"
|
||||
|
||||
@ -3,9 +3,12 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
from nonebot_plugin_htmlrender import md_to_pic, template_to_pic
|
||||
from PIL.ImageFont import FreeTypeFont
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.configs.path_config import TEMPLATE_PATH
|
||||
|
||||
from ._build_image import BuildImage
|
||||
|
||||
|
||||
@ -283,3 +286,191 @@ class ImageTemplate:
|
||||
width = max(width, w)
|
||||
height += h
|
||||
return width, height
|
||||
|
||||
|
||||
class MarkdownTable:
|
||||
def __init__(self, headers: list[str], rows: list[list[str]]):
|
||||
self.headers = headers
|
||||
self.rows = rows
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
"""将表格转换为Markdown格式"""
|
||||
header_row = "| " + " | ".join(self.headers) + " |"
|
||||
separator_row = "| " + " | ".join(["---"] * len(self.headers)) + " |"
|
||||
data_rows = "\n".join(
|
||||
"| " + " | ".join(map(str, row)) + " |" for row in self.rows
|
||||
)
|
||||
return f"{header_row}\n{separator_row}\n{data_rows}"
|
||||
|
||||
|
||||
class Markdown:
|
||||
def __init__(self, data: list[str] | None = None):
|
||||
if data is None:
|
||||
data = []
|
||||
self._data = data
|
||||
|
||||
def text(self, text: str) -> "Markdown":
|
||||
"""添加Markdown文本"""
|
||||
self._data.append(text)
|
||||
return self
|
||||
|
||||
def head(self, text: str, level: int = 1) -> "Markdown":
|
||||
"""添加Markdown标题"""
|
||||
if level < 1 or level > 6:
|
||||
raise ValueError("标题级别必须在1到6之间")
|
||||
self._data.append(f"{'#' * level} {text}")
|
||||
return self
|
||||
|
||||
def image(self, content: str | Path, add_empty_line: bool = True) -> "Markdown":
|
||||
"""添加Markdown图片
|
||||
|
||||
参数:
|
||||
content: 图片内容,可以是url地址,图片路径或base64字符串.
|
||||
add_empty_line: 默认添加换行.
|
||||
|
||||
返回:
|
||||
Markdown: Markdown
|
||||
"""
|
||||
if isinstance(content, Path):
|
||||
content = str(content.absolute())
|
||||
if content.startswith("base64"):
|
||||
content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}"
|
||||
self._data.append(f"")
|
||||
if add_empty_line:
|
||||
self._add_empty_line()
|
||||
return self
|
||||
|
||||
def quote(self, text: str | list[str]) -> "Markdown":
|
||||
"""添加Markdown引用文本
|
||||
|
||||
参数:
|
||||
text: 引用文本内容,可以是字符串或字符串列表.
|
||||
如果是列表,则每个元素都会被单独引用。
|
||||
|
||||
返回:
|
||||
Markdown: Markdown
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
self._data.append(f"> {text}")
|
||||
elif isinstance(text, list):
|
||||
for t in text:
|
||||
self._data.append(f"> {t}")
|
||||
self._add_empty_line()
|
||||
return self
|
||||
|
||||
def code(self, code: str, language: str = "python") -> "Markdown":
|
||||
"""添加Markdown代码块"""
|
||||
self._data.append(f"```{language}\n{code}\n```")
|
||||
return self
|
||||
|
||||
def table(self, headers: list[str], rows: list[list[str]]) -> "Markdown":
|
||||
"""添加Markdown表格"""
|
||||
table = MarkdownTable(headers, rows)
|
||||
self._data.append(table.to_markdown())
|
||||
return self
|
||||
|
||||
def list(self, items: list[str | list[str]]) -> "Markdown":
|
||||
"""添加Markdown列表"""
|
||||
self._add_empty_line()
|
||||
_text = "\n".join(
|
||||
f"- {item}"
|
||||
if isinstance(item, str)
|
||||
else "\n".join(f"- {sub_item}" for sub_item in item)
|
||||
for item in items
|
||||
)
|
||||
self._data.append(_text)
|
||||
return self
|
||||
|
||||
def _add_empty_line(self):
|
||||
"""添加空行"""
|
||||
self._data.append("")
|
||||
|
||||
async def build(self, width: int = 800, css_path: Path | None = None) -> bytes:
|
||||
"""构建Markdown文本"""
|
||||
if css_path is not None:
|
||||
return await md_to_pic(
|
||||
md="\n".join(self._data), width=width, css_path=str(css_path.absolute())
|
||||
)
|
||||
return await md_to_pic(md="\n".join(self._data), width=width)
|
||||
|
||||
|
||||
class Notebook:
|
||||
def __init__(self, data: list[dict] | None = None):
|
||||
self._data = data if data is not None else []
|
||||
|
||||
def text(self, text: str) -> "Notebook":
|
||||
"""添加Notebook文本"""
|
||||
self._data.append({"type": "paragraph", "text": text})
|
||||
return self
|
||||
|
||||
def head(self, text: str, level: int = 1) -> "Notebook":
|
||||
"""添加Notebook标题"""
|
||||
if not 1 <= level <= 4:
|
||||
raise ValueError("标题级别必须在1-4之间")
|
||||
self._data.append({"type": "heading", "text": text, "level": level})
|
||||
return self
|
||||
|
||||
def image(
|
||||
self,
|
||||
content: str | Path,
|
||||
caption: str | None = None,
|
||||
) -> "Notebook":
|
||||
"""添加Notebook图片
|
||||
|
||||
参数:
|
||||
content: 图片内容,可以是url地址,图片路径或base64字符串.
|
||||
caption: 图片说明.
|
||||
|
||||
返回:
|
||||
Notebook: Notebook
|
||||
"""
|
||||
if isinstance(content, Path):
|
||||
content = str(content.absolute())
|
||||
if content.startswith("base64"):
|
||||
content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}"
|
||||
self._data.append({"type": "image", "src": content, "caption": caption})
|
||||
return self
|
||||
|
||||
def quote(self, text: str | list[str]) -> "Notebook":
|
||||
"""添加Notebook引用文本
|
||||
|
||||
参数:
|
||||
text: 引用文本内容,可以是字符串或字符串列表.
|
||||
如果是列表,则每个元素都会被单独引用。
|
||||
|
||||
返回:
|
||||
Notebook: Notebook
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
self._data.append({"type": "blockquote", "text": text})
|
||||
elif isinstance(text, list):
|
||||
for t in text:
|
||||
self._data.append({"type": "blockquote", "text": text})
|
||||
return self
|
||||
|
||||
def code(self, code: str, language: str = "python") -> "Notebook":
|
||||
"""添加Notebook代码块"""
|
||||
self._data.append({"type": "code", "code": code, "language": language})
|
||||
return self
|
||||
|
||||
def list(self, items: list[str], ordered: bool = False) -> "Notebook":
|
||||
"""添加Notebook列表"""
|
||||
self._data.append({"type": "list", "data": items, "ordered": ordered})
|
||||
return self
|
||||
|
||||
def add_divider(self) -> None:
|
||||
"""添加分隔线"""
|
||||
self._data.append({"type": "divider"})
|
||||
|
||||
async def build(self) -> bytes:
|
||||
"""构建Notebook"""
|
||||
return await template_to_pic(
|
||||
template_path=str((TEMPLATE_PATH / "notebook").absolute()),
|
||||
template_name="main.html",
|
||||
templates={"elements": self._data},
|
||||
pages={
|
||||
"viewport": {"width": 700, "height": 1000},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
)
|
||||
|
||||
@ -10,8 +10,6 @@ class HookPriorityException(BaseException):
|
||||
return self.info
|
||||
|
||||
|
||||
|
||||
|
||||
class NotFoundError(Exception):
|
||||
"""
|
||||
未发现
|
||||
|
||||
@ -21,6 +21,9 @@ CACHED_API_TTL = 300
|
||||
RAW_CONTENT_FORMAT = "https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
"""raw content格式"""
|
||||
|
||||
GITEE_RAW_CONTENT_FORMAT = "https://gitee.com/{owner}/{repo}/raw/main/{path}"
|
||||
"""gitee raw content格式"""
|
||||
|
||||
ARCHIVE_URL_FORMAT = "https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
"""archive url格式"""
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from zhenxun.utils.http_utils import AsyncHttpx
|
||||
|
||||
from .const import (
|
||||
ARCHIVE_URL_FORMAT,
|
||||
GITEE_RAW_CONTENT_FORMAT,
|
||||
RAW_CONTENT_FORMAT,
|
||||
RELEASE_ASSETS_FORMAT,
|
||||
RELEASE_SOURCE_FORMAT,
|
||||
@ -21,15 +22,11 @@ async def __get_fastest_formats(formats: dict[str, str]) -> list[str]:
|
||||
async def get_fastest_raw_formats() -> list[str]:
|
||||
"""获取最快的raw下载地址格式"""
|
||||
formats: dict[str, str] = {
|
||||
"https://gitee.com/": GITEE_RAW_CONTENT_FORMAT,
|
||||
"https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT,
|
||||
"https://ghproxy.cc/": f"https://ghproxy.cc/{RAW_CONTENT_FORMAT}",
|
||||
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RAW_CONTENT_FORMAT}",
|
||||
"https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}",
|
||||
"https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}",
|
||||
#"https://raw.gitcode.com/": "https://raw.gitcode.com/qq_41605780/{repo}/raw/{branch}/{path}", # ✅ 新增 GitCode raw 格式
|
||||
#"https://raw.gitcode.com/": "https://raw.gitcode.com/ATTomato/{repo}/raw/{branch}/{path}"
|
||||
#"https://raw.gitcode.com/": "https://raw.gitcode.com/{owner}/{repo}/raw/{branch}/{path}"
|
||||
|
||||
}
|
||||
return await __get_fastest_formats(formats)
|
||||
|
||||
@ -40,7 +37,6 @@ async def get_fastest_archive_formats() -> list[str]:
|
||||
formats: dict[str, str] = {
|
||||
"https://github.com/": ARCHIVE_URL_FORMAT,
|
||||
"https://ghproxy.cc/": f"https://ghproxy.cc/{ARCHIVE_URL_FORMAT}",
|
||||
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}",
|
||||
"https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}",
|
||||
}
|
||||
return await __get_fastest_formats(formats)
|
||||
@ -52,7 +48,6 @@ async def get_fastest_release_formats() -> list[str]:
|
||||
formats: dict[str, str] = {
|
||||
"https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT,
|
||||
"https://ghproxy.cc/": f"https://ghproxy.cc/{RELEASE_ASSETS_FORMAT}",
|
||||
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}",
|
||||
"https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}",
|
||||
}
|
||||
return await __get_fastest_formats(formats)
|
||||
@ -65,4 +60,4 @@ async def get_fastest_release_source_formats() -> list[str]:
|
||||
"https://codeload.github.com/": RELEASE_SOURCE_FORMAT,
|
||||
"https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}",
|
||||
}
|
||||
return await __get_fastest_formats(formats)
|
||||
return await __get_fastest_formats(formats)
|
||||
|
||||
@ -1,13 +1,18 @@
|
||||
import contextlib
|
||||
import sys
|
||||
from typing import Protocol
|
||||
|
||||
from aiocache import cached
|
||||
from nonebot.compat import model_dump
|
||||
from pydantic import BaseModel, Field
|
||||
from strenum import StrEnum
|
||||
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
from strenum import StrEnum
|
||||
|
||||
from .const import (
|
||||
CACHED_API_TTL,
|
||||
GIT_API_COMMIT_FORMAT,
|
||||
|
||||
1
zhenxun/utils/html_template/__init__.py
Normal file
1
zhenxun/utils/html_template/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
36
zhenxun/utils/html_template/component.py
Normal file
36
zhenxun/utils/html_template/component.py
Normal file
@ -0,0 +1,36 @@
|
||||
from abc import ABC
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Style(BaseModel):
|
||||
"""常用样式"""
|
||||
|
||||
padding: str = "0px"
|
||||
margin: str = "0px"
|
||||
border: str = "0px"
|
||||
border_radius: str = "0px"
|
||||
text_align: Literal["left", "right", "center"] = "left"
|
||||
color: str = "#000"
|
||||
font_size: str = "16px"
|
||||
|
||||
|
||||
class Component(ABC):
|
||||
def __init__(self, background_color: str = "#fff", is_container: bool = False):
|
||||
self.extra_style = []
|
||||
self.style = Style()
|
||||
self.background_color = background_color
|
||||
self.is_container = is_container
|
||||
self.children = []
|
||||
|
||||
def add_child(self, child: "Component | str"):
|
||||
self.children.append(child)
|
||||
|
||||
def set_style(self, style: Style):
|
||||
self.style = style
|
||||
|
||||
def add_style(self, style: str):
|
||||
self.extra_style.append(style)
|
||||
|
||||
def to_html(self) -> str: ...
|
||||
15
zhenxun/utils/html_template/components/title.py
Normal file
15
zhenxun/utils/html_template/components/title.py
Normal file
@ -0,0 +1,15 @@
|
||||
from ..component import Component, Style
|
||||
from ..container import Row
|
||||
|
||||
|
||||
class Title(Component):
|
||||
def __init__(self, text: str, color: str = "#000"):
|
||||
self.text = text
|
||||
self.color = color
|
||||
|
||||
def build(self):
|
||||
row = Row()
|
||||
style = Style(font_size="36px", color=self.color)
|
||||
row.set_style(style)
|
||||
|
||||
# def
|
||||
31
zhenxun/utils/html_template/container.py
Normal file
31
zhenxun/utils/html_template/container.py
Normal file
@ -0,0 +1,31 @@
|
||||
from .component import Component
|
||||
|
||||
|
||||
class Row(Component):
|
||||
def __init__(self, background_color: str = "#fff"):
|
||||
super().__init__(background_color, True)
|
||||
|
||||
|
||||
class Col(Component):
|
||||
def __init__(self, background_color: str = "#fff"):
|
||||
super().__init__(background_color, True)
|
||||
|
||||
|
||||
class Container(Component):
|
||||
def __init__(self, background_color: str = "#fff"):
|
||||
super().__init__(background_color, True)
|
||||
self.children = []
|
||||
|
||||
|
||||
class GlobalOverview:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.class_name: dict[str, list[str]] = {}
|
||||
self.content = None
|
||||
|
||||
def set_content(self, content: Container):
|
||||
self.content = content
|
||||
|
||||
def add_class(self, class_name: str, contents: list[str]):
|
||||
"""全局样式"""
|
||||
self.class_name[class_name] = contents
|
||||
@ -22,6 +22,4 @@ class MessageManager:
|
||||
|
||||
@classmethod
|
||||
def get(cls, uid: str) -> list[str]:
|
||||
if uid in cls.data:
|
||||
return cls.data[uid]
|
||||
return []
|
||||
return cls.data[uid] if uid in cls.data else []
|
||||
|
||||
@ -1,810 +0,0 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
import copy
|
||||
import inspect
|
||||
import random
|
||||
from typing import ClassVar
|
||||
|
||||
import nonebot
|
||||
from nonebot import get_bots
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.schedule_info import ScheduleInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
|
||||
|
||||
|
||||
class SchedulerManager:
|
||||
"""
|
||||
一个通用的、持久化的定时任务管理器,供所有插件使用。
|
||||
"""
|
||||
|
||||
_registered_tasks: ClassVar[
|
||||
dict[str, dict[str, Callable | type[BaseModel] | None]]
|
||||
] = {}
|
||||
_JOB_PREFIX = "zhenxun_schedule_"
|
||||
_running_tasks: ClassVar[set] = set()
|
||||
|
||||
def register(
|
||||
self, plugin_name: str, params_model: type[BaseModel] | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
注册一个可调度的任务函数。
|
||||
被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)`
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件的唯一名称 (通常是模块名)。
|
||||
params_model (type[BaseModel], optional): 一个 Pydantic BaseModel 类,
|
||||
用于定义和验证任务函数接受的额外参数。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
|
||||
if plugin_name in self._registered_tasks:
|
||||
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
|
||||
self._registered_tasks[plugin_name] = {
|
||||
"func": func,
|
||||
"model": params_model,
|
||||
}
|
||||
model_name = params_model.__name__ if params_model else "无"
|
||||
logger.debug(
|
||||
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_registered_plugins(self) -> list[str]:
|
||||
"""获取所有已注册定时任务的插件列表。"""
|
||||
return list(self._registered_tasks.keys())
|
||||
|
||||
def _get_job_id(self, schedule_id: int) -> str:
|
||||
"""根据数据库ID生成唯一的 APScheduler Job ID。"""
|
||||
return f"{self._JOB_PREFIX}{schedule_id}"
|
||||
|
||||
async def _execute_job(self, schedule_id: int):
|
||||
"""
|
||||
APScheduler 调度的入口函数。
|
||||
根据 schedule_id 处理特定任务、所有群组任务或全局任务。
|
||||
"""
|
||||
schedule = await ScheduleInfo.get_or_none(id=schedule_id)
|
||||
if not schedule or not schedule.is_enabled:
|
||||
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
|
||||
return
|
||||
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
task_meta = self._registered_tasks.get(plugin_name)
|
||||
if not task_meta:
|
||||
logger.error(
|
||||
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
|
||||
)
|
||||
schedule.is_enabled = False
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
self._remove_aps_job(schedule.id)
|
||||
return
|
||||
|
||||
try:
|
||||
if schedule.bot_id:
|
||||
bot = nonebot.get_bot(schedule.bot_id)
|
||||
else:
|
||||
bot = nonebot.get_bot()
|
||||
logger.debug(
|
||||
f"任务 {schedule_id} 未关联特定Bot,使用默认Bot {bot.self_id}"
|
||||
)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
|
||||
f"不在线,本次执行跳过。"
|
||||
)
|
||||
return
|
||||
except ValueError:
|
||||
logger.warning(f"当前没有Bot在线,定时任务 {schedule_id} 跳过。")
|
||||
return
|
||||
|
||||
if schedule.group_id == "__ALL_GROUPS__":
|
||||
await self._execute_for_all_groups(schedule, task_meta, bot)
|
||||
else:
|
||||
await self._execute_for_single_target(schedule, task_meta, bot)
|
||||
|
||||
async def _execute_for_all_groups(
|
||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
||||
):
|
||||
"""为所有群组执行任务,并处理优先级覆盖。"""
|
||||
plugin_name = schedule.plugin_name
|
||||
|
||||
concurrency_limit = Config.get_config(
|
||||
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
|
||||
)
|
||||
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
|
||||
logger.warning(
|
||||
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
|
||||
)
|
||||
concurrency_limit = 5
|
||||
|
||||
logger.info(
|
||||
f"开始执行针对 [所有群组] 的任务 "
|
||||
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id}),"
|
||||
f"并发限制: {concurrency_limit}"
|
||||
)
|
||||
|
||||
all_gids = set()
|
||||
try:
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
all_gids.update(
|
||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"为 'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||
return
|
||||
|
||||
specific_tasks_gids = set(
|
||||
await ScheduleInfo.filter(
|
||||
plugin_name=plugin_name, group_id__in=list(all_gids)
|
||||
).values_list("group_id", flat=True)
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def worker(gid: str):
|
||||
"""使用 Semaphore 包装单个群组的任务执行"""
|
||||
async with semaphore:
|
||||
temp_schedule = copy.deepcopy(schedule)
|
||||
temp_schedule.group_id = gid
|
||||
await self._execute_for_single_target(temp_schedule, task_meta, bot)
|
||||
await asyncio.sleep(random.uniform(0.1, 0.5))
|
||||
|
||||
tasks_to_run = []
|
||||
for gid in all_gids:
|
||||
if gid in specific_tasks_gids:
|
||||
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
|
||||
continue
|
||||
tasks_to_run.append(worker(gid))
|
||||
|
||||
if tasks_to_run:
|
||||
await asyncio.gather(*tasks_to_run)
|
||||
|
||||
async def _execute_for_single_target(
|
||||
self, schedule: ScheduleInfo, task_meta: dict, bot
|
||||
):
|
||||
"""为单个目标(具体群组或全局)执行任务。"""
|
||||
plugin_name = schedule.plugin_name
|
||||
group_id = schedule.group_id
|
||||
|
||||
try:
|
||||
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
|
||||
if is_blocked:
|
||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
|
||||
"因功能被禁用而跳过执行。"
|
||||
)
|
||||
return
|
||||
|
||||
task_func = task_meta["func"]
|
||||
job_kwargs = schedule.job_kwargs
|
||||
if not isinstance(job_kwargs, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
|
||||
)
|
||||
return
|
||||
|
||||
sig = inspect.signature(task_func)
|
||||
if "bot" in sig.parameters:
|
||||
job_kwargs["bot"] = bot
|
||||
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
|
||||
f"执行定时任务 (ID: {schedule.id})。"
|
||||
)
|
||||
task = asyncio.create_task(task_func(group_id, **job_kwargs))
|
||||
self._running_tasks.add(task)
|
||||
task.add_done_callback(self._running_tasks.discard)
|
||||
await task
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
|
||||
f"目标: {group_id or '全局'}) 时发生异常",
|
||||
e=e,
|
||||
)
|
||||
|
||||
def _validate_and_prepare_kwargs(
|
||||
self, plugin_name: str, job_kwargs: dict | None
|
||||
) -> tuple[bool, str | dict]:
|
||||
"""验证并准备任务参数,应用默认值"""
|
||||
task_meta = self._registered_tasks.get(plugin_name)
|
||||
if not task_meta:
|
||||
return False, f"插件 '{plugin_name}' 未注册。"
|
||||
|
||||
params_model = task_meta.get("model")
|
||||
job_kwargs = job_kwargs if job_kwargs is not None else {}
|
||||
|
||||
if not params_model:
|
||||
if job_kwargs:
|
||||
logger.warning(
|
||||
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
|
||||
)
|
||||
return True, job_kwargs
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
|
||||
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
|
||||
|
||||
validated_model = model_validate(job_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
|
||||
|
||||
return True, model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||
return False, msg
|
||||
|
||||
def _add_aps_job(self, schedule: ScheduleInfo):
|
||||
"""根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。"""
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not isinstance(schedule.trigger_config, dict):
|
||||
logger.error(
|
||||
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
|
||||
f"{type(schedule.trigger_config)}"
|
||||
)
|
||||
return
|
||||
|
||||
scheduler.add_job(
|
||||
self._execute_job,
|
||||
trigger=schedule.trigger_type,
|
||||
id=job_id,
|
||||
misfire_grace_time=300,
|
||||
args=[schedule.id],
|
||||
**schedule.trigger_config,
|
||||
)
|
||||
logger.debug(
|
||||
f"已在 APScheduler 中添加/更新任务: {job_id} "
|
||||
f"with trigger: {schedule.trigger_config}"
|
||||
)
|
||||
|
||||
def _remove_aps_job(self, schedule_id: int):
|
||||
"""移除一个 APScheduler 任务。"""
|
||||
job_id = self._get_job_id(schedule_id)
|
||||
try:
|
||||
scheduler.remove_job(job_id)
|
||||
logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def add_schedule(
|
||||
self,
|
||||
plugin_name: str,
|
||||
group_id: str | None,
|
||||
trigger_type: str,
|
||||
trigger_config: dict,
|
||||
job_kwargs: dict | None = None,
|
||||
bot_id: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
添加或更新一个定时任务。
|
||||
"""
|
||||
if plugin_name not in self._registered_tasks:
|
||||
return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"
|
||||
|
||||
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
|
||||
if not is_valid:
|
||||
return False, str(result)
|
||||
|
||||
validated_job_kwargs = result
|
||||
|
||||
effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None
|
||||
|
||||
search_kwargs = {
|
||||
"plugin_name": plugin_name,
|
||||
"group_id": group_id,
|
||||
}
|
||||
if effective_bot_id:
|
||||
search_kwargs["bot_id"] = effective_bot_id
|
||||
else:
|
||||
search_kwargs["bot_id__isnull"] = True
|
||||
|
||||
defaults = {
|
||||
"trigger_type": trigger_type,
|
||||
"trigger_config": trigger_config,
|
||||
"job_kwargs": validated_job_kwargs,
|
||||
"is_enabled": True,
|
||||
}
|
||||
|
||||
schedule = await ScheduleInfo.filter(**search_kwargs).first()
|
||||
created = False
|
||||
|
||||
if schedule:
|
||||
for key, value in defaults.items():
|
||||
setattr(schedule, key, value)
|
||||
await schedule.save()
|
||||
else:
|
||||
creation_kwargs = {
|
||||
"plugin_name": plugin_name,
|
||||
"group_id": group_id,
|
||||
"bot_id": effective_bot_id,
|
||||
**defaults,
|
||||
}
|
||||
schedule = await ScheduleInfo.create(**creation_kwargs)
|
||||
created = True
|
||||
self._add_aps_job(schedule)
|
||||
action = "设置" if created else "更新"
|
||||
return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"
|
||||
|
||||
async def add_schedule_for_all(
|
||||
self,
|
||||
plugin_name: str,
|
||||
trigger_type: str,
|
||||
trigger_config: dict,
|
||||
job_kwargs: dict | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""为所有机器人所在的群组添加定时任务。"""
|
||||
if plugin_name not in self._registered_tasks:
|
||||
raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
|
||||
|
||||
groups = set()
|
||||
for bot in get_bots().values():
|
||||
try:
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
groups.update(
|
||||
g.group_id for g in group_list if g.group_id and not g.channel_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
for gid in groups:
|
||||
try:
|
||||
success, _ = await self.add_schedule(
|
||||
plugin_name, gid, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
|
||||
fail_count += 1
|
||||
await asyncio.sleep(0.05)
|
||||
return success_count, fail_count
|
||||
|
||||
async def update_schedule(
|
||||
self,
|
||||
schedule_id: int,
|
||||
trigger_type: str | None = None,
|
||||
trigger_config: dict | None = None,
|
||||
job_kwargs: dict | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""部分更新一个已存在的定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的任务。"
|
||||
|
||||
updated_fields = []
|
||||
if trigger_config is not None:
|
||||
schedule.trigger_config = trigger_config
|
||||
updated_fields.append("trigger_config")
|
||||
|
||||
if trigger_type is not None and schedule.trigger_type != trigger_type:
|
||||
schedule.trigger_type = trigger_type
|
||||
updated_fields.append("trigger_type")
|
||||
|
||||
if job_kwargs is not None:
|
||||
if not isinstance(schedule.job_kwargs, dict):
|
||||
return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
|
||||
|
||||
merged_kwargs = schedule.job_kwargs.copy()
|
||||
merged_kwargs.update(job_kwargs)
|
||||
|
||||
is_valid, result = self._validate_and_prepare_kwargs(
|
||||
schedule.plugin_name, merged_kwargs
|
||||
)
|
||||
if not is_valid:
|
||||
return False, str(result)
|
||||
|
||||
schedule.job_kwargs = result # type: ignore
|
||||
updated_fields.append("job_kwargs")
|
||||
|
||||
if not updated_fields:
|
||||
return True, "没有任何需要更新的配置。"
|
||||
|
||||
await schedule.save(update_fields=updated_fields)
|
||||
self._add_aps_job(schedule)
|
||||
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
|
||||
|
||||
async def remove_schedule(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""移除指定插件和群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
msg = (
|
||||
f"未找到与 Bot {bot_id} 相关的群 {group_id} "
|
||||
f"的插件 '{plugin_name}' 定时任务。"
|
||||
)
|
||||
return (False, msg)
|
||||
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
|
||||
target_desc = f"群 {group_id}" if group_id else "全局"
|
||||
msg = (
|
||||
f"已取消 Bot {bot_id} 在 [{target_desc}] "
|
||||
f"的插件 '{plugin_name}' 所有定时任务。"
|
||||
)
|
||||
return (True, msg)
|
||||
|
||||
async def remove_schedule_for_all(
|
||||
self, plugin_name: str, bot_id: str | None = None
|
||||
) -> int:
|
||||
"""移除指定插件在所有群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules_to_delete = await ScheduleInfo.filter(**query).all()
|
||||
if not schedules_to_delete:
|
||||
return 0
|
||||
|
||||
for schedule in schedules_to_delete:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return len(schedules_to_delete)
|
||||
|
||||
async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
|
||||
"""移除指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id)
|
||||
if not schedules:
|
||||
return False, f"群 {group_id} 没有任何定时任务。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return True, f"已成功移除群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
||||
"""暂停指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, f"群 {group_id} 没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功暂停群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
|
||||
"""恢复指定群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, f"群 {group_id} 没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功恢复群 {group_id} 的 {count} 个定时任务。"
|
||||
|
||||
async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
||||
"""暂停指定插件在所有群组的定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return (
|
||||
count,
|
||||
f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
|
||||
"""恢复指定插件在所有群组的定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return (
|
||||
count,
|
||||
f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def pause_schedule_by_plugin_group(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""暂停指定插件在指定群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
return (
|
||||
False,
|
||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
|
||||
)
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已成功暂停群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def resume_schedule_by_plugin_group(
|
||||
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""恢复指定插件在指定群组的定时任务。"""
|
||||
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
|
||||
if bot_id:
|
||||
query["bot_id"] = bot_id
|
||||
|
||||
schedules = await ScheduleInfo.filter(**query)
|
||||
if not schedules:
|
||||
return (
|
||||
False,
|
||||
f"群 {group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
|
||||
)
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已成功恢复群 {group_id} 的插件 '{plugin_name}' 共 {count} 个定时任务。",
|
||||
)
|
||||
|
||||
async def remove_all_schedules(self) -> tuple[int, str]:
|
||||
"""移除所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.all()
|
||||
if not schedules:
|
||||
return 0, "当前没有任何定时任务。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功移除所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def pause_all_schedules(self) -> tuple[int, str]:
|
||||
"""暂停所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(is_enabled=True)
|
||||
if not schedules:
|
||||
return 0, "当前没有正在运行的定时任务可暂停。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.pause_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功暂停所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def resume_all_schedules(self) -> tuple[int, str]:
|
||||
"""恢复所有群组的所有定时任务。"""
|
||||
schedules = await ScheduleInfo.filter(is_enabled=False)
|
||||
if not schedules:
|
||||
return 0, "当前没有已暂停的定时任务可恢复。"
|
||||
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
success, _ = await self.resume_schedule(schedule.id)
|
||||
if success:
|
||||
count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return count, f"已成功恢复所有群组的 {count} 个定时任务。"
|
||||
|
||||
async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""通过ID移除指定的定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
||||
|
||||
self._remove_aps_job(schedule.id)
|
||||
await schedule.delete()
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
|
||||
"""通过ID获取定时任务信息。"""
|
||||
return await ScheduleInfo.get_or_none(id=schedule_id)
|
||||
|
||||
async def get_schedules(
|
||||
self, plugin_name: str, group_id: str | None
|
||||
) -> list[ScheduleInfo]:
|
||||
"""获取特定群组特定插件的所有定时任务。"""
|
||||
return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)
|
||||
|
||||
async def get_schedule(
|
||||
self, plugin_name: str, group_id: str | None
|
||||
) -> ScheduleInfo | None:
|
||||
"""获取特定群组的定时任务信息。"""
|
||||
return await ScheduleInfo.get_or_none(
|
||||
plugin_name=plugin_name, group_id=group_id
|
||||
)
|
||||
|
||||
async def get_all_schedules(
|
||||
self, plugin_name: str | None = None
|
||||
) -> list[ScheduleInfo]:
|
||||
"""获取所有定时任务信息,可按插件名过滤。"""
|
||||
if plugin_name:
|
||||
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
|
||||
return await ScheduleInfo.all()
|
||||
|
||||
async def get_schedule_status(self, schedule_id: int) -> dict | None:
|
||||
"""获取任务的详细状态。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return None
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
job = scheduler.get_job(job_id)
|
||||
|
||||
status = {
|
||||
"id": schedule.id,
|
||||
"bot_id": schedule.bot_id,
|
||||
"plugin_name": schedule.plugin_name,
|
||||
"group_id": schedule.group_id,
|
||||
"is_enabled": schedule.is_enabled,
|
||||
"trigger_type": schedule.trigger_type,
|
||||
"trigger_config": schedule.trigger_config,
|
||||
"job_kwargs": schedule.job_kwargs,
|
||||
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if job and job.next_run_time
|
||||
else "N/A",
|
||||
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
|
||||
}
|
||||
return status
|
||||
|
||||
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""暂停一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule or not schedule.is_enabled:
|
||||
return False, "任务不存在或已暂停。"
|
||||
|
||||
schedule.is_enabled = False
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.pause_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""恢复一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule or schedule.is_enabled:
|
||||
return False, "任务不存在或已启用。"
|
||||
|
||||
schedule.is_enabled = True
|
||||
await schedule.save(update_fields=["is_enabled"])
|
||||
|
||||
job_id = self._get_job_id(schedule.id)
|
||||
try:
|
||||
scheduler.resume_job(job_id)
|
||||
except Exception:
|
||||
self._add_aps_job(schedule)
|
||||
|
||||
return (
|
||||
True,
|
||||
f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
|
||||
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
|
||||
"""手动触发一个定时任务。"""
|
||||
schedule = await self.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
|
||||
|
||||
if schedule.plugin_name not in self._registered_tasks:
|
||||
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
|
||||
|
||||
try:
|
||||
await self._execute_job(schedule.id)
|
||||
return (
|
||||
True,
|
||||
f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
|
||||
f"的定时任务 (ID: {schedule.id})。",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"手动触发任务失败: {e}")
|
||||
return False, f"手动触发任务失败: {e}"
|
||||
|
||||
|
||||
scheduler_manager = SchedulerManager()
|
||||
|
||||
|
||||
@PriorityLifecycle.on_startup(priority=90)
|
||||
async def _load_schedules_from_db():
|
||||
"""在服务启动时从数据库加载并调度所有任务。"""
|
||||
Config.add_plugin_config(
|
||||
"SchedulerManager",
|
||||
SCHEDULE_CONCURRENCY_KEY,
|
||||
5,
|
||||
help="“所有群组”类型定时任务的并发执行数量限制",
|
||||
type=int,
|
||||
)
|
||||
|
||||
logger.info("正在从数据库加载并调度所有定时任务...")
|
||||
schedules = await ScheduleInfo.filter(is_enabled=True).all()
|
||||
count = 0
|
||||
for schedule in schedules:
|
||||
if schedule.plugin_name in scheduler_manager._registered_tasks:
|
||||
scheduler_manager._add_aps_job(schedule)
|
||||
count += 1
|
||||
else:
|
||||
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
|
||||
logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")
|
||||
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
import random
|
||||
from typing import Literal
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
import nonebot
|
||||
@ -227,7 +227,7 @@ class PlatformUtils:
|
||||
url = None
|
||||
if platform == "qq":
|
||||
if user_id.isdigit():
|
||||
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640"
|
||||
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160"
|
||||
else:
|
||||
url = f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640"
|
||||
return await AsyncHttpx.get_content(url) if url else None
|
||||
@ -486,15 +486,134 @@ class PlatformUtils:
|
||||
return target
|
||||
|
||||
|
||||
class BroadcastEngine:
|
||||
def __init__(
|
||||
self,
|
||||
message: str | UniMessage,
|
||||
bot: Bot | list[Bot] | None = None,
|
||||
bot_id: str | set[str] | None = None,
|
||||
ignore_group: list[str] | None = None,
|
||||
check_func: Callable[[Bot, str], Awaitable] | None = None,
|
||||
log_cmd: str | None = None,
|
||||
platform: str | None = None,
|
||||
):
|
||||
"""广播引擎
|
||||
|
||||
参数:
|
||||
message: 广播消息内容
|
||||
bot: 指定bot对象.
|
||||
bot_id: 指定bot id.
|
||||
ignore_group: 忽略群聊列表.
|
||||
check_func: 发送前对群聊检测方法,判断是否发送.
|
||||
log_cmd: 日志标记.
|
||||
platform: 指定平台.
|
||||
|
||||
异常:
|
||||
ValueError: 没有可用的Bot对象
|
||||
"""
|
||||
if ignore_group is None:
|
||||
ignore_group = []
|
||||
self.message = MessageUtils.build_message(message)
|
||||
self.ignore_group = ignore_group
|
||||
self.check_func = check_func
|
||||
self.log_cmd = log_cmd
|
||||
self.platform = platform
|
||||
self.bot_list = []
|
||||
self.count = 0
|
||||
if bot:
|
||||
self.bot_list = [bot] if isinstance(bot, Bot) else bot
|
||||
if isinstance(bot_id, str):
|
||||
bot_id = set(bot_id)
|
||||
if bot_id:
|
||||
for i in bot_id:
|
||||
try:
|
||||
self.bot_list.append(nonebot.get_bot(i))
|
||||
except KeyError:
|
||||
logger.warning(f"Bot:{i} 对象未连接或不存在")
|
||||
if not self.bot_list:
|
||||
raise ValueError("当前没有可用的Bot对象...", log_cmd)
|
||||
|
||||
async def call_check(self, bot: Bot, group_id: str) -> bool:
|
||||
"""运行发送检测函数
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
group_id: 群组id
|
||||
|
||||
返回:
|
||||
bool: 是否发送
|
||||
"""
|
||||
if not self.check_func:
|
||||
return True
|
||||
if is_coroutine_callable(self.check_func):
|
||||
is_run = await self.check_func(bot, group_id)
|
||||
else:
|
||||
is_run = self.check_func(bot, group_id)
|
||||
return cast(bool, is_run)
|
||||
|
||||
async def __send_message(self, bot: Bot, group: GroupConsole):
|
||||
"""群组发送消息
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
group: GroupConsole
|
||||
"""
|
||||
key = f"{group.group_id}:{group.channel_id}"
|
||||
if not await self.call_check(bot, group.group_id):
|
||||
logger.debug(
|
||||
"广播方法检测运行方法为 False, 已跳过该群组...",
|
||||
self.log_cmd,
|
||||
group_id=group.group_id,
|
||||
)
|
||||
return
|
||||
if target := PlatformUtils.get_target(
|
||||
group_id=group.group_id,
|
||||
channel_id=group.channel_id,
|
||||
):
|
||||
self.ignore_group.append(key)
|
||||
await MessageUtils.build_message(self.message).send(target, bot)
|
||||
logger.debug("广播消息发送成功...", self.log_cmd, target=key)
|
||||
else:
|
||||
logger.warning("广播消息获取Target失败...", self.log_cmd, target=key)
|
||||
|
||||
async def broadcast(self) -> int:
|
||||
"""广播消息
|
||||
|
||||
返回:
|
||||
int: 成功发送次数
|
||||
"""
|
||||
for bot in self.bot_list:
|
||||
if self.platform and self.platform != PlatformUtils.get_platform(bot):
|
||||
continue
|
||||
group_list, _ = await PlatformUtils.get_group_list(bot)
|
||||
if not group_list:
|
||||
continue
|
||||
for group in group_list:
|
||||
if (
|
||||
group.group_id in self.ignore_group
|
||||
or group.channel_id in self.ignore_group
|
||||
):
|
||||
continue
|
||||
try:
|
||||
await self.__send_message(bot, group)
|
||||
await asyncio.sleep(random.randint(1, 3))
|
||||
self.count += 1
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"广播消息发送失败", self.log_cmd, target=group.group_id, e=e
|
||||
)
|
||||
return self.count
|
||||
|
||||
|
||||
async def broadcast_group(
|
||||
message: str | UniMessage,
|
||||
bot: Bot | list[Bot] | None = None,
|
||||
bot_id: str | set[str] | None = None,
|
||||
ignore_group: set[int] | None = None,
|
||||
ignore_group: list[str] = [],
|
||||
check_func: Callable[[Bot, str], Awaitable] | None = None,
|
||||
log_cmd: str | None = None,
|
||||
platform: Literal["qq", "dodo", "kaiheila"] | None = None,
|
||||
):
|
||||
platform: str | None = None,
|
||||
) -> int:
|
||||
"""获取所有Bot或指定Bot对象广播群聊
|
||||
|
||||
参数:
|
||||
@ -505,81 +624,18 @@ async def broadcast_group(
|
||||
check_func: 发送前对群聊检测方法,判断是否发送.
|
||||
log_cmd: 日志标记.
|
||||
platform: 指定平台
|
||||
|
||||
返回:
|
||||
int: 成功发送次数
|
||||
"""
|
||||
if platform and platform not in ["qq", "dodo", "kaiheila"]:
|
||||
raise ValueError("指定平台不支持")
|
||||
if not message:
|
||||
raise ValueError("群聊广播消息不能为空")
|
||||
bot_dict = nonebot.get_bots()
|
||||
bot_list: list[Bot] = []
|
||||
if bot:
|
||||
if isinstance(bot, list):
|
||||
bot_list = bot
|
||||
else:
|
||||
bot_list.append(bot)
|
||||
elif bot_id:
|
||||
_bot_id_list = bot_id
|
||||
if isinstance(bot_id, str):
|
||||
_bot_id_list = [bot_id]
|
||||
for id_ in _bot_id_list:
|
||||
if bot_id in bot_dict:
|
||||
bot_list.append(bot_dict[bot_id])
|
||||
else:
|
||||
logger.warning(f"Bot:{id_} 对象未连接或不存在")
|
||||
else:
|
||||
bot_list = list(bot_dict.values())
|
||||
_used_group = []
|
||||
for _bot in bot_list:
|
||||
try:
|
||||
if platform and platform != PlatformUtils.get_platform(_bot):
|
||||
continue
|
||||
group_list, _ = await PlatformUtils.get_group_list(_bot)
|
||||
if group_list:
|
||||
for group in group_list:
|
||||
key = f"{group.group_id}:{group.channel_id}"
|
||||
try:
|
||||
if (
|
||||
ignore_group
|
||||
and (
|
||||
group.group_id in ignore_group
|
||||
or group.channel_id in ignore_group
|
||||
)
|
||||
) or key in _used_group:
|
||||
logger.debug(
|
||||
"广播方法群组重复, 已跳过...",
|
||||
log_cmd,
|
||||
group_id=group.group_id,
|
||||
)
|
||||
continue
|
||||
is_run = False
|
||||
if check_func:
|
||||
if is_coroutine_callable(check_func):
|
||||
is_run = await check_func(_bot, group.group_id)
|
||||
else:
|
||||
is_run = check_func(_bot, group.group_id)
|
||||
if not is_run:
|
||||
logger.debug(
|
||||
"广播方法检测运行方法为 False, 已跳过...",
|
||||
log_cmd,
|
||||
group_id=group.group_id,
|
||||
)
|
||||
continue
|
||||
target = PlatformUtils.get_target(
|
||||
user_id=None,
|
||||
group_id=group.group_id,
|
||||
channel_id=group.channel_id,
|
||||
)
|
||||
if target:
|
||||
_used_group.append(key)
|
||||
message_list = message
|
||||
await MessageUtils.build_message(message_list).send(
|
||||
target, _bot
|
||||
)
|
||||
logger.debug("发送成功", log_cmd, target=key)
|
||||
await asyncio.sleep(random.randint(1, 3))
|
||||
else:
|
||||
logger.warning("target为空", log_cmd, target=key)
|
||||
except Exception as e:
|
||||
logger.error("发送失败", log_cmd, target=key, e=e)
|
||||
except Exception as e:
|
||||
logger.error(f"Bot: {_bot.self_id} 获取群聊列表失败", command=log_cmd, e=e)
|
||||
if not message.strip():
|
||||
raise ValueError("群聊广播消息不能为空...")
|
||||
return await BroadcastEngine(
|
||||
message=message,
|
||||
bot=bot,
|
||||
bot_id=bot_id,
|
||||
ignore_group=ignore_group,
|
||||
check_func=check_func,
|
||||
log_cmd=log_cmd,
|
||||
platform=platform,
|
||||
).broadcast()
|
||||
|
||||
@ -244,10 +244,8 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
|
||||
|
||||
def get_entity_ids(session: Uninfo) -> EntityIDs:
|
||||
"""获取用户id,群组id,频道id
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
|
||||
返回:
|
||||
EntityIDs: 用户id,群组id,频道id
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user