导入最近PR

This commit is contained in:
ATTomatoo 2025-07-05 02:48:59 +08:00
parent daad4c718b
commit 4ec6f39af1
66 changed files with 2527 additions and 2305 deletions

View File

@ -16,6 +16,7 @@ from zhenxun.models.sign_user import SignUser
from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger
from zhenxun.utils.decorator.shop import shop_register
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.manager.resource_manager import ResourceManager
from zhenxun.utils.platform import PlatformUtils
@ -70,7 +71,7 @@ from public.bag_users t1
"""
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
await ResourceManager.init_resources()
"""签到与用户的数据迁移"""

View File

@ -26,6 +26,21 @@ __plugin_meta__ = PluginMetadata(
_matcher = on_alconna(Alconna("关于"), priority=5, block=True, rule=to_me())
QQ_INFO = """
绪山真寻Bot
版本{version}
简介基于Nonebot2开发支持多平台是一个非常可爱的Bot呀希望与大家要好好相处
""".strip()
INFO = """
绪山真寻Bot
版本{version}
简介基于Nonebot2开发支持多平台是一个非常可爱的Bot呀希望与大家要好好相处
项目地址https://github.com/zhenxun-org/zhenxun_bot
文档地址https://zhenxun-org.github.io/zhenxun_bot/
""".strip()
@_matcher.handle()
async def _(session: Uninfo, arparma: Arparma):
ver_file = Path() / "__version__"
@ -35,25 +50,11 @@ async def _(session: Uninfo, arparma: Arparma):
if text := await f.read():
version = text.split(":")[-1].strip()
if PlatformUtils.is_qbot(session):
info: list[str | Path] = [
f"""
绪山真寻Bot
版本{version}
简介基于Nonebot2开发支持多平台是一个非常可爱的Bot呀希望与大家要好好相处
""".strip()
]
result: list[str | Path] = [QQ_INFO.format(version=version)]
path = DATA_PATH / "about.png"
if path.exists():
info.append(path)
result.append(path)
await MessageUtils.build_message(result).send() # type: ignore
else:
info = [
f"""
绪山真寻Bot
版本{version}
简介基于Nonebot2开发支持多平台是一个非常可爱的Bot呀希望与大家要好好相处
项目地址https://github.com/HibiKier/zhenxun_bot
文档地址https://hibikier.github.io/zhenxun_bot/
""".strip()
]
await MessageUtils.build_message(info).send() # type: ignore
logger.info("查看关于", arparma.header_result, session=session)
await MessageUtils.build_message(INFO.format(version=version)).send()
logger.info("查看关于", arparma.header_result, session=session)

View File

@ -8,6 +8,7 @@ from zhenxun.models.level_user import LevelUser
from zhenxun.services.log import logger
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
async def call_ban(user_id: str):
"""调用ban
@ -18,7 +19,6 @@ async def call_ban(user_id: str):
logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id)
class BanManage:
@classmethod
async def build_ban_image(

View File

@ -4,7 +4,8 @@ from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo
from zhenxun.utils.enum import BlockType, PluginType
from zhenxun.services.cache import Cache
from zhenxun.utils.enum import BlockType, CacheType, PluginType
from zhenxun.utils.exception import GroupInfoNotFound
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
@ -245,10 +246,12 @@ class PluginManage:
参数:
group_id: 群组id
"""
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
status=False
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
group.status = False
await group.save(update_fields=["status"])
@classmethod
async def wake(cls, group_id: str):
"""醒来
@ -256,9 +259,11 @@ class PluginManage:
参数:
group_id: 群组id
"""
await GroupConsole.filter(group_id=group_id, channel_id__isnull=True).update(
status=True
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
group.status = True
await group.save(update_fields=["status"])
@classmethod
async def block(cls, module: str):
@ -267,7 +272,9 @@ class PluginManage:
参数:
module: 模块名
"""
await PluginInfo.filter(module=module).update(status=False)
if plugin := await PluginInfo.get_plugin(module=module):
plugin.status = False
await plugin.save(update_fields=["status"])
@classmethod
async def unblock(cls, module: str):
@ -276,7 +283,9 @@ class PluginManage:
参数:
module: 模块名
"""
await PluginInfo.filter(module=module).update(status=True)
if plugin := await PluginInfo.get_plugin(module=module):
plugin.status = True
await plugin.save(update_fields=["status"])
@classmethod
async def block_group_plugin(cls, plugin_name: str, group_id: str) -> str:

View File

@ -14,6 +14,7 @@ from zhenxun.services.log import logger
from zhenxun.utils._build_image import BuildImage
from zhenxun.utils._image_template import ImageTemplate
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
BASE_PATH = DATA_PATH / "welcome_message"
@ -91,7 +92,7 @@ def migrate(path: Path):
json.dump(new_data, f, ensure_ascii=False, indent=4)
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
def _():
"""数据迁移

View File

@ -40,7 +40,13 @@ __plugin_meta__ = PluginMetadata(
value="zhenxun",
help="帮助图片样式 [normal, HTML, zhenxun]",
default_value="zhenxun",
)
),
RegisterConfig(
key="detail_type",
value="zhenxun",
help="帮助详情图片样式 ['normal', 'zhenxun']",
default_value="zhenxun",
),
],
).to_dict(),
)

View File

@ -1,13 +1,19 @@
from pathlib import Path
import nonebot
from nonebot.plugin import PluginMetadata
from nonebot_plugin_htmlrender import template_to_pic
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.path_config import IMAGE_PATH
from zhenxun.configs.config import Config
from zhenxun.configs.path_config import IMAGE_PATH, TEMPLATE_PATH
from zhenxun.configs.utils import PluginExtraData
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.statistics import Statistics
from zhenxun.utils._image_template import ImageTemplate
from zhenxun.utils.enum import PluginType
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
from zhenxun.utils.image_utils import BuildImage
from ._config import (
GROUP_HELP_PATH,
@ -80,9 +86,96 @@ async def get_user_allow_help(user_id: str) -> list[PluginType]:
return type_list
async def get_plugin_help(
user_id: str, name: str, is_superuser: bool
) -> str | BuildImage:
async def get_normal_help(
metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool
) -> str | bytes:
"""构建默认帮助详情
参数:
metadata: PluginMetadata
extra: PluginExtraData
is_superuser: 是否超级用户帮助
返回:
str | bytes: 返回信息
"""
items = None
if is_superuser:
if usage := extra.superuser_help:
items = {
"简介": metadata.description,
"用法": usage,
}
else:
items = {
"简介": metadata.description,
"用法": metadata.usage,
}
if items:
return (await ImageTemplate.hl_page(metadata.name, items)).pic2bytes()
return "该功能没有帮助信息"
def min_leading_spaces(str_list: list[str]) -> int:
min_spaces = 9999
for s in str_list:
leading_spaces = len(s) - len(s.lstrip(" "))
if leading_spaces < min_spaces:
min_spaces = leading_spaces
return min_spaces if min_spaces != 9999 else 0
def split_text(text: str):
split_text = text.split("\n")
min_spaces = min_leading_spaces(split_text)
if min_spaces > 0:
split_text = [s[min_spaces:] for s in split_text]
return [s.replace(" ", "&nbsp;") for s in split_text]
async def get_zhenxun_help(
module: str, metadata: PluginMetadata, extra: PluginExtraData, is_superuser: bool
) -> str | bytes:
"""构建ZhenXun帮助详情
参数:
module: 模块名
metadata: PluginMetadata
extra: PluginExtraData
is_superuser: 是否超级用户帮助
返回:
str | bytes: 返回信息
"""
call_count = await Statistics.filter(plugin_name=module).count()
usage = metadata.usage
if is_superuser:
if not extra.superuser_help:
return "该功能没有超级用户帮助信息"
usage = extra.superuser_help
return await template_to_pic(
template_path=str((TEMPLATE_PATH / "help_detail").absolute()),
template_name="main.html",
templates={
"title": metadata.name,
"author": extra.author,
"version": extra.version,
"call_count": call_count,
"descriptions": split_text(metadata.description),
"usages": split_text(usage),
},
pages={
"viewport": {"width": 824, "height": 590},
"base_url": f"file://{TEMPLATE_PATH}",
},
wait=2,
)
async def get_plugin_help(user_id: str, name: str, is_superuser: bool) -> str | bytes:
"""获取功能的帮助信息
参数:
@ -100,20 +193,12 @@ async def get_plugin_help(
if plugin:
_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
if _plugin and _plugin.metadata:
items = None
if is_superuser:
extra = _plugin.metadata.extra
if usage := extra.get("superuser_help"):
items = {
"简介": _plugin.metadata.description,
"用法": usage,
}
extra_data = PluginExtraData(**_plugin.metadata.extra)
if Config.get_config("help", "detail_type") == "zhenxun":
return await get_zhenxun_help(
plugin.module, _plugin.metadata, extra_data, is_superuser
)
else:
items = {
"简介": _plugin.metadata.description,
"用法": _plugin.metadata.usage,
}
if items:
return await ImageTemplate.hl_page(plugin.name, items)
return await get_normal_help(_plugin.metadata, extra_data, is_superuser)
return "糟糕! 该功能没有帮助喔..."
return "没有查找到这个功能噢..."

View File

@ -21,7 +21,7 @@ from zhenxun.utils.message import MessageUtils
__plugin_meta__ = PluginMetadata(
name="笨蛋检测",
description="功能名称当命令检测",
usage="""被动""".strip(),
usage="""当一些笨蛋直接输入功能名称时,提示笨蛋使用帮助指令查看功能帮助""".strip(),
extra=PluginExtraData(
author="HibiKier",
version="0.1",

View File

@ -53,7 +53,7 @@ Config.add_plugin_config(
"hook",
"RECORD_BOT_SENT_MESSAGES",
True,
help="记录bot消息校内",
help="记录bot消息发送",
default_value=True,
type=bool,
)

View File

@ -14,9 +14,7 @@ from .auth_checker import LimitManager, auth
# # 权限检测
@run_preprocessor
async def _(
matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg
):
async def _(matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg):
start_time = time.time()
await auth(
matcher,

View File

@ -20,6 +20,7 @@ from zhenxun.utils.enum import (
PluginLimitType,
PluginType,
)
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .manager import manager
@ -95,7 +96,7 @@ async def _handle_setting(
)
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
"""
初始化插件数据配置

View File

@ -10,6 +10,7 @@ from zhenxun.models.group_console import GroupConsole
from zhenxun.models.task_info import TaskInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
driver: Driver = nonebot.get_driver()
@ -132,7 +133,7 @@ async def create_schedule(task: Task):
logger.error(f"动态创建定时任务 {task.name}({task.module}) 失败", e=e)
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
"""
初始化插件数据配置

View File

@ -95,13 +95,6 @@ _matcher = on_alconna(
block=True,
)
_matcher.shortcut(
r"1111",
command="mahiro-bank",
arguments=["test"],
prefix=True,
)
_matcher.shortcut(
r"存款\s*(?P<amount>\d+)?",
command="mahiro-bank",

View File

@ -241,7 +241,7 @@ class BankManager:
@classmethod
async def get_bank_info(cls) -> bytes:
now = datetime.now()
now_start = datetime.now() - timedelta(
now_start = now - timedelta(
hours=now.hour, minutes=now.minute, seconds=now.second
)
(
@ -255,7 +255,9 @@ class BankManager:
MahiroBank.annotate(
amount_sum=Sum("amount"), user_count=Count("id")
).values("amount_sum", "user_count"),
MahiroBankLog.filter(create_time__gt=now_start).count(),
MahiroBankLog.filter(
create_time__gt=now_start, handle_type=BankHandleType.DEPOSIT
).count(),
MahiroBankLog.filter(handle_type=BankHandleType.INTEREST)
.annotate(amount_sum=Sum("amount"))
.values("amount_sum"),

View File

@ -1,17 +1,12 @@
import random
from typing import Any
from nonebot import on_regex
from nonebot.adapters import Bot
from nonebot.params import Depends, RegexGroup
from nonebot.plugin import PluginMetadata
from nonebot.rule import to_me
from nonebot_plugin_alconna import (
Alconna,
Args,
Arparma,
CommandMeta,
Option,
on_alconna,
store_true,
)
from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.config import BotConfig, Config
@ -59,22 +54,15 @@ __plugin_meta__ = PluginMetadata(
).to_dict(),
)
_nickname_matcher = on_alconna(
Alconna(
"re:(?:以后)?(?:叫我|请叫我|称呼我)",
Args["name?", str],
meta=CommandMeta(compact=True),
),
_nickname_matcher = on_regex(
"(?:以后)?(?:叫我|请叫我|称呼我)(.*)",
rule=to_me(),
priority=5,
block=True,
)
_global_nickname_matcher = on_alconna(
Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)),
rule=to_me(),
priority=5,
block=True,
_global_nickname_matcher = on_regex(
"设置全局昵称(.*)", rule=to_me(), priority=5, block=True
)
_matcher = on_alconna(
@ -129,32 +117,34 @@ CANCEL = [
]
async def CheckNickname(
bot: Bot,
session: Uninfo,
params: Arparma,
):
def CheckNickname():
"""
检查名称是否合法
"""
black_word = Config.get_config("nickname", "BLACK_WORD")
name = params.query("name")
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
if not name:
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
at_sender=True
)
if session.user.id in bot.config.superusers:
logger.debug(
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
)
else:
async def dependency(
bot: Bot,
session: Uninfo,
reg_group: tuple[Any, ...] = RegexGroup(),
):
black_word = Config.get_config("nickname", "BLACK_WORD")
(name,) = reg_group
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
if not name:
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
at_sender=True
)
if session.user.id in bot.config.superusers:
logger.debug(
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
)
return
if len(name) > 20:
await MessageUtils.build_message("昵称可不能超过20个字!").finish(
at_sender=True
)
if name in bot.config.nickname:
await MessageUtils.build_message("笨蛋!休想占用我的名字! ").finish(
await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish(
at_sender=True
)
if black_word:
@ -172,17 +162,17 @@ async def CheckNickname(
await MessageUtils.build_message(
f"字符 [{word}] 为禁止字符!"
).finish(at_sender=True)
return name
return Depends(dependency)
@_nickname_matcher.handle()
@_nickname_matcher.handle(parameterless=[CheckNickname()])
async def _(
bot: Bot,
session: Uninfo,
name_: Arparma,
uname: str = UserName(),
reg_group: tuple[Any, ...] = RegexGroup(),
):
name = await CheckNickname(bot, session, name_)
(name,) = reg_group
if len(name) < 5 and random.random() < 0.3:
name = "~".join(name)
group_id = None
@ -210,14 +200,13 @@ async def _(
)
@_global_nickname_matcher.handle()
@_global_nickname_matcher.handle(parameterless=[CheckNickname()])
async def _(
bot: Bot,
session: Uninfo,
name_: Arparma,
nickname: str = UserName(),
reg_group: tuple[Any, ...] = RegexGroup(),
):
name = await CheckNickname(bot, session, name_)
(name,) = reg_group
await FriendUser.set_user_nickname(
session.user.id,
name,
@ -238,14 +227,15 @@ async def _(session: Uninfo, uname: str = UserName()):
group_id = session.group.parent.id if session.group.parent else session.group.id
if group_id:
nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id)
card = uname
else:
nickname = await FriendUser.get_user_nickname(session.user.id)
card = uname
if nickname:
await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish(
reply_to=True
)
else:
card = uname
await MessageUtils.build_message(
random.choice(
[

View File

@ -147,7 +147,7 @@ class GroupManager:
e=e,
)
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
# await GroupConsole.filter(group_id=group_id).delete()
#await GroupConsole.filter(group_id=group_id).delete()
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
else:
await cls.__handle_add_group(bot, group_id, group)

View File

@ -31,4 +31,4 @@ async def _(session: Uninfo):
await FriendUser.create(
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
)
logger.info("添加当前好友用户信息", "", session=session)
logger.info("添加当前好友用户信息", "", session=session)

View File

@ -10,13 +10,4 @@ DEFAULT_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins/tree/ma
EXTRA_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/tree/index"
"""插件库索引github仓库地址"""
GITEE_RAW_URL = "https://gitee.com/two_Dimension/zhenxun_bot_plugins/raw/main"
"""GITEE仓库文件内容"""
GITEE_CONTENTS_URL = (
"https://gitee.com/api/v5/repos/two_Dimension/zhenxun_bot_plugins/contents"
)
"""GITEE仓库文件列表获取"""
LOG_COMMAND = "插件商店"

View File

@ -1,51 +0,0 @@
from nonebot.plugin import PluginMetadata
from zhenxun.configs.utils import PluginExtraData
from zhenxun.utils.enum import PluginType
from . import command # noqa: F401
__plugin_meta__ = PluginMetadata(
name="定时任务管理",
description="查看和管理由 SchedulerManager 控制的定时任务。",
usage="""
📋 定时任务管理 - 支持群聊和私聊操作
🔍 查看任务:
定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>]
群聊中: 查看本群任务
私聊中: 必须使用 -g <群号> -all 选项 (SUPERUSER)
📊 任务状态:
定时任务 状态 <任务ID> 任务状态 <任务ID>
查看单个任务的详细信息和状态
任务管理 (SUPERUSER):
定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>]
定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
定时任务 执行 <任务ID>
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
📝 时间选项 (三选一):
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
--interval <时间间隔> # 例: --interval 30m, 2h, 10s
--date "<YYYY-MM-DD HH:MM:SS>" # 例: --date "2024-01-01 08:00:00"
--daily "<HH:MM>" # 例: --daily "08:30"
📚 其他功能:
定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER)
🏷 别名支持:
查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消
暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info
更新: update, modify, 修改 | 插件列表: plugins
""".strip(),
extra=PluginExtraData(
author="HibiKier",
version="0.1.2",
plugin_type=PluginType.SUPERUSER,
is_show=False,
).to_dict(),
)

View File

@ -1,836 +0,0 @@
import asyncio
from datetime import datetime
import re
from nonebot.adapters import Event
from nonebot.adapters.onebot.v11 import Bot
from nonebot.params import Depends
from nonebot.permission import SUPERUSER
from nonebot_plugin_alconna import (
Alconna,
AlconnaMatch,
Args,
Arparma,
Match,
Option,
Query,
Subcommand,
on_alconna,
)
from pydantic import BaseModel, ValidationError
from zhenxun.utils._image_template import ImageTemplate
from zhenxun.utils.manager.schedule_manager import scheduler_manager
def _get_type_name(annotation) -> str:
"""获取类型注解的名称"""
if hasattr(annotation, "__name__"):
return annotation.__name__
elif hasattr(annotation, "_name"):
return annotation._name
else:
return str(annotation)
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.rules import admin_check
def _format_trigger(schedule_status: dict) -> str:
"""将触发器配置格式化为人类可读的字符串"""
trigger_type = schedule_status["trigger_type"]
config = schedule_status["trigger_config"]
if trigger_type == "cron":
minute = config.get("minute", "*")
hour = config.get("hour", "*")
day = config.get("day", "*")
month = config.get("month", "*")
day_of_week = config.get("day_of_week", "*")
if day == "*" and month == "*" and day_of_week == "*":
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
return f"每天 {formatted_hour}:{formatted_minute}"
else:
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
elif trigger_type == "interval":
seconds = config.get("seconds", 0)
minutes = config.get("minutes", 0)
hours = config.get("hours", 0)
days = config.get("days", 0)
if days:
trigger_str = f"{days}"
elif hours:
trigger_str = f"{hours} 小时"
elif minutes:
trigger_str = f"{minutes} 分钟"
else:
trigger_str = f"{seconds}"
elif trigger_type == "date":
run_date = config.get("run_date", "未知时间")
trigger_str = f"{run_date}"
else:
trigger_str = f"{trigger_type}: {config}"
return trigger_str
def _format_params(schedule_status: dict) -> str:
"""将任务参数格式化为人类可读的字符串"""
if kwargs := schedule_status.get("job_kwargs"):
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
return kwargs_str
return "-"
def _parse_interval(interval_str: str) -> dict:
"""增强版解析器,支持 d(天)"""
match = re.match(r"(\d+)([smhd])", interval_str.lower())
if not match:
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
value, unit = int(match.group(1)), match.group(2)
if unit == "s":
return {"seconds": value}
if unit == "m":
return {"minutes": value}
if unit == "h":
return {"hours": value}
if unit == "d":
return {"days": value}
return {}
def _parse_daily_time(time_str: str) -> dict:
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
hour, minute, second = match.groups()
hour, minute = int(hour), int(minute)
if not (0 <= hour <= 23 and 0 <= minute <= 59):
raise ValueError("小时或分钟数值超出范围。")
cron_config = {
"minute": str(minute),
"hour": str(hour),
"day": "*",
"month": "*",
"day_of_week": "*",
}
if second is not None:
if not (0 <= int(second) <= 59):
raise ValueError("秒数值超出范围。")
cron_config["second"] = str(second)
return cron_config
else:
raise ValueError("时间格式错误,请使用 'HH:MM''HH:MM:SS' 格式。")
async def GetBotId(
bot: Bot,
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
) -> str:
"""获取要操作的Bot ID"""
if bot_id_match.available:
return bot_id_match.result
return bot.self_id
class ScheduleTarget:
"""定时任务操作目标的基类"""
pass
class TargetByID(ScheduleTarget):
"""按任务ID操作"""
def __init__(self, id: int):
self.id = id
class TargetByPlugin(ScheduleTarget):
"""按插件名操作"""
def __init__(
self, plugin: str, group_id: str | None = None, all_groups: bool = False
):
self.plugin = plugin
self.group_id = group_id
self.all_groups = all_groups
class TargetAll(ScheduleTarget):
"""操作所有任务"""
def __init__(self, for_group: str | None = None):
self.for_group = for_group
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
def create_target_parser(subcommand_name: str):
"""
创建一个依赖注入函数用于解析删除暂停恢复等命令的操作目标
"""
async def dependency(
event: Event,
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
group_id: Match[str] = AlconnaMatch("group_id"),
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
) -> TargetScope:
if schedule_id.available:
return TargetByID(schedule_id.result)
if plugin_name.available:
p_name = plugin_name.result
if all_enabled.available:
return TargetByPlugin(plugin=p_name, all_groups=True)
elif group_id.available:
gid = group_id.result
if gid.lower() == "all":
return TargetByPlugin(plugin=p_name, all_groups=True)
return TargetByPlugin(plugin=p_name, group_id=gid)
else:
current_group_id = getattr(event, "group_id", None)
if current_group_id:
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
else:
await schedule_cmd.finish(
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
)
if all_enabled.available:
return TargetAll(for_group=group_id.result if group_id.available else None)
return None
return dependency
schedule_cmd = on_alconna(
Alconna(
"定时任务",
Subcommand(
"查看",
Option("-g", Args["target_group_id", str]),
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
Option("--page", Args["page", int, 1], help_text="指定页码"),
alias=["ls", "list"],
help_text="查看定时任务",
),
Subcommand(
"设置",
Args["plugin_name", str],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="设置每天执行的时间 (如 08:20)",
),
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["add", "开启"],
help_text="设置/开启一个定时任务",
),
Subcommand(
"删除",
Args["schedule_id?", int],
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID"),
Option("-all", help_text="对所有群生效"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["del", "rm", "remove", "关闭", "取消"],
help_text="删除一个或多个定时任务",
),
Subcommand(
"暂停",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["pause"],
help_text="暂停一个或多个定时任务",
),
Subcommand(
"恢复",
Args["schedule_id?", int],
Option("-all", help_text="对当前群所有任务生效"),
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
Option(
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
),
alias=["resume"],
help_text="恢复一个或多个定时任务",
),
Subcommand(
"执行",
Args["schedule_id", int],
alias=["trigger", "run"],
help_text="立即执行一次任务",
),
Subcommand(
"更新",
Args["schedule_id", int],
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
Option(
"--daily",
Args["daily_expr", str],
help_text="更新每天执行的时间 (如 08:20)",
),
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
alias=["update", "modify", "修改"],
help_text="更新任务配置",
),
Subcommand(
"状态",
Args["schedule_id", int],
alias=["status", "info"],
help_text="查看单个任务的详细状态",
),
Subcommand(
"插件列表",
alias=["plugins"],
help_text="列出所有可用的插件",
),
),
priority=5,
block=True,
rule=admin_check(1),
)
schedule_cmd.shortcut(
"任务状态",
command="定时任务",
arguments=["状态", "{%0}"],
prefix=True,
)
@schedule_cmd.handle()
async def _handle_time_options_mutex(arp: Arparma):
time_options = ["cron", "interval", "date", "daily"]
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
if len(provided_options) > 1:
await schedule_cmd.finish(
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
)
@schedule_cmd.assign("查看")
async def _(
bot: Bot,
event: Event,
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
all_groups: Query[bool] = Query("查看.all"),
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
page: Match[int] = AlconnaMatch("page"),
):
is_superuser = await SUPERUSER(bot, event)
schedules = []
title = ""
current_group_id = getattr(event, "group_id", None)
if not (all_groups.available or target_group_id.available) and not current_group_id:
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
if all_groups.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
schedules = await scheduler_manager.get_all_schedules()
title = "所有群组的定时任务"
elif target_group_id.available:
if not is_superuser:
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
gid = target_group_id.result
schedules = [
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
]
title = f"{gid} 的定时任务"
else:
gid = str(current_group_id)
schedules = [
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
]
title = "本群的定时任务"
if plugin_name.available:
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
title += f" [插件: {plugin_name.result}]"
if not schedules:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
page_size = 15
current_page = page.result
total_items = len(schedules)
total_pages = (total_items + page_size - 1) // page_size
start_index = (current_page - 1) * page_size
end_index = start_index + page_size
paginated_schedules = schedules[start_index:end_index]
if not paginated_schedules:
await schedule_cmd.finish("这一页没有内容了哦~")
status_tasks = [
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
]
all_statuses = await asyncio.gather(*status_tasks)
data_list = [
[
s["id"],
s["plugin_name"],
s.get("bot_id") or "N/A",
s["group_id"] or "全局",
s["next_run_time"],
_format_trigger(s),
_format_params(s),
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
]
for s in all_statuses
if s
]
if not data_list:
await schedule_cmd.finish("没有找到任何相关的定时任务。")
img = await ImageTemplate.table_page(
head_text=title,
tip_text=f"{current_page}/{total_pages} 页,共 {total_items} 条任务",
column_name=[
"ID",
"插件",
"Bot ID",
"群组/目标",
"下次运行",
"触发规则",
"参数",
"状态",
],
data_list=data_list,
column_space=20,
)
await MessageUtils.build_message(img).send(reply_to=True)
@schedule_cmd.assign("设置")
async def _(
event: Event,
plugin_name: str,
cron_expr: str | None = None,
interval_expr: str | None = None,
date_expr: str | None = None,
daily_expr: str | None = None,
group_id: str | None = None,
kwargs_str: str | None = None,
all_enabled: Query[bool] = Query("设置.all"),
bot_id_to_operate: str = Depends(GetBotId),
):
if plugin_name not in scheduler_manager._registered_tasks:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
)
trigger_type = ""
trigger_config = {}
try:
if cron_expr:
trigger_type = "cron"
parts = cron_expr.split()
if len(parts) != 5:
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
trigger_config = dict(zip(cron_keys, parts))
elif interval_expr:
trigger_type = "interval"
trigger_config = _parse_interval(interval_expr)
elif date_expr:
trigger_type = "date"
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
elif daily_expr:
trigger_type = "cron"
trigger_config = _parse_daily_time(daily_expr)
else:
await schedule_cmd.finish(
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = {}
if kwargs_str:
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
if not params_model:
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
raw_kwargs = {}
try:
for item in kwargs_str.split(","):
key, value = item.strip().split("=", 1)
raw_kwargs[key.strip()] = value
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的参数模型不支持验证。"
)
return
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的参数模型不支持导出。"
)
return
job_kwargs = model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
await schedule_cmd.finish(
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
)
return
target_group_id: str | None
current_group_id = getattr(event, "group_id", None)
if group_id and group_id.lower() == "all":
target_group_id = "__ALL_GROUPS__"
elif all_enabled.available:
target_group_id = "__ALL_GROUPS__"
elif group_id:
target_group_id = group_id
elif current_group_id:
target_group_id = str(current_group_id)
else:
await schedule_cmd.finish(
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
)
return
success, msg = await scheduler_manager.add_schedule(
plugin_name,
target_group_id,
trigger_type,
trigger_config,
job_kwargs,
bot_id=bot_id_to_operate,
)
if target_group_id == "__ALL_GROUPS__":
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
elif target_group_id is None:
target_desc = "全局"
else:
target_desc = f"群组 {target_group_id}"
if success:
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
else:
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
@schedule_cmd.assign("删除")
async def _(
target: TargetScope = Depends(create_target_parser("删除")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
removed_count = await scheduler_manager.remove_schedule_for_all(
p_name, bot_id=bot_id_to_operate
)
message = (
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
if removed_count > 0
else f"没有找到插件 '{p_name}' 的定时任务。"
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.remove_schedule(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.remove_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.remove_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish(
"删除任务失败请提供任务ID或通过 -p <插件> 或 -all 指定要删除的任务。"
)
@schedule_cmd.assign("暂停")
async def _(
target: TargetScope = Depends(create_target_parser("暂停")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.pause_schedule(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.pause_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.pause_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("恢复")
async def _(
target: TargetScope = Depends(create_target_parser("恢复")),
bot_id_to_operate: str = Depends(GetBotId),
):
if isinstance(target, TargetByID):
_, message = await scheduler_manager.resume_schedule(target.id)
await schedule_cmd.finish(message)
elif isinstance(target, TargetByPlugin):
p_name = target.plugin
if p_name not in scheduler_manager.get_registered_plugins():
await schedule_cmd.finish(f"未找到插件 '{p_name}'")
if target.all_groups:
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
p_name, target.group_id, bot_id=bot_id_to_operate
)
await schedule_cmd.finish(message)
elif isinstance(target, TargetAll):
if target.for_group:
_, message = await scheduler_manager.resume_schedules_by_group(
target.for_group
)
await schedule_cmd.finish(message)
else:
_, message = await scheduler_manager.resume_all_schedules()
await schedule_cmd.finish(message)
else:
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
@schedule_cmd.assign("执行")
async def _(schedule_id: int):
_, message = await scheduler_manager.trigger_now(schedule_id)
await schedule_cmd.finish(message)
@schedule_cmd.assign("更新")
async def _(
schedule_id: int,
cron_expr: str | None = None,
interval_expr: str | None = None,
date_expr: str | None = None,
daily_expr: str | None = None,
kwargs_str: str | None = None,
):
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
await schedule_cmd.finish(
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
)
trigger_config = None
trigger_type = None
try:
if cron_expr:
trigger_type = "cron"
parts = cron_expr.split()
if len(parts) != 5:
raise ValueError("Cron 表达式必须有5个部分")
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
trigger_config = dict(zip(cron_keys, parts))
elif interval_expr:
trigger_type = "interval"
trigger_config = _parse_interval(interval_expr)
elif date_expr:
trigger_type = "date"
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
elif daily_expr:
trigger_type = "cron"
trigger_config = _parse_daily_time(daily_expr)
except ValueError as e:
await schedule_cmd.finish(f"时间参数解析错误: {e}")
job_kwargs = None
if kwargs_str:
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
if not schedule:
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
if not task_meta or not (params_model := task_meta.get("model")):
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
)
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
)
raw_kwargs = {}
try:
for item in kwargs_str.split(","):
key, value = item.strip().split("=", 1)
raw_kwargs[key.strip()] = value
except Exception as e:
await schedule_cmd.finish(
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
)
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
)
return
validated_model = model_validate(raw_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
await schedule_cmd.finish(
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
)
return
job_kwargs = model_dump(exclude_unset=True)
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
return
_, message = await scheduler_manager.update_schedule(
schedule_id, trigger_type, trigger_config, job_kwargs
)
await schedule_cmd.finish(message)
@schedule_cmd.assign("插件列表")
async def _():
registered_plugins = scheduler_manager.get_registered_plugins()
if not registered_plugins:
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
message_parts = ["📋 已注册的定时任务插件:"]
for i, plugin_name in enumerate(registered_plugins, 1):
task_meta = scheduler_manager._registered_tasks[plugin_name]
params_model = task_meta.get("model")
if not params_model:
message_parts.append(f"{i}. {plugin_name} - 无参数")
continue
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
continue
model_fields = getattr(params_model, "model_fields", None)
if model_fields:
param_info = ", ".join(
f"{field_name}({_get_type_name(field_info.annotation)})"
for field_name, field_info in model_fields.items()
)
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
else:
message_parts.append(f"{i}. {plugin_name} - 无参数")
await schedule_cmd.finish("\n".join(message_parts))
@schedule_cmd.assign("状态")
async def _(schedule_id: int):
status = await scheduler_manager.get_schedule_status(schedule_id)
if not status:
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
info_lines = [
f"📋 定时任务详细信息 (ID: {schedule_id})",
"--------------------",
f"▫️ 插件: {status['plugin_name']}",
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
f"▫️ 目标: {status['group_id'] or '全局'}",
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
f"▫️ 下次运行: {status['next_run_time']}",
f"▫️ 触发规则: {_format_trigger(status)}",
f"▫️ 任务参数: {_format_params(status)}",
]
await schedule_cmd.finish("\n".join(info_lines))

View File

@ -1,67 +1,8 @@
from asyncio.exceptions import TimeoutError
import aiofiles
import nonebot
from nonebot.drivers import Driver
from nonebot_plugin_apscheduler import scheduler
import ujson as json
from zhenxun.configs.path_config import TEXT_PATH
from zhenxun.models.group_console import GroupConsole
from zhenxun.services.log import logger
from zhenxun.utils.http_utils import AsyncHttpx
driver: Driver = nonebot.get_driver()
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
@driver.on_startup
async def update_city():
"""
部分插件需要中国省份城市
这里直接更新避免插件内代码重复
"""
china_city = TEXT_PATH / "china_city.json"
if not china_city.exists():
data = {}
try:
logger.debug("开始更新城市列表...")
res = await AsyncHttpx.get(
"http://www.weather.com.cn/data/city3jdata/china.html", timeout=5
)
res.encoding = "utf8"
provinces_data = json.loads(res.text)
for province in provinces_data.keys():
data[provinces_data[province]] = []
res = await AsyncHttpx.get(
f"http://www.weather.com.cn/data/city3jdata/provshi/{province}.html",
timeout=5,
)
res.encoding = "utf8"
city_data = json.loads(res.text)
for city in city_data.keys():
data[provinces_data[province]].append(city_data[city])
async with aiofiles.open(china_city, "w", encoding="utf8") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
logger.info("自动更新城市列表完成.....")
except TimeoutError as e:
logger.warning("自动更新城市列表超时...", e=e)
except ValueError as e:
logger.warning("自动城市列表失败.....", e=e)
except Exception as e:
logger.error("自动城市列表未知错误", e=e)
# 自动更新城市列表
@scheduler.scheduled_job(
"cron",
hour=6,
minute=1,
)
async def _():
await update_city()
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
"""开启/禁用插件格式修改"""
_, is_create = await GroupConsole.get_or_create(group_id=133133133)

View File

@ -5,7 +5,9 @@ from nonebot_plugin_alconna import (
AlconnaQuery,
Args,
Arparma,
At,
Match,
MultiVar,
Option,
Query,
Subcommand,
@ -33,6 +35,7 @@ __plugin_meta__ = PluginMetadata(
usage="""
商品操作
指令
商店
我的金币
我的道具
使用道具 [名称/Id]
@ -46,6 +49,7 @@ __plugin_meta__ = PluginMetadata(
plugin_type=PluginType.NORMAL,
menu_type="商店",
commands=[
Command(command="商店"),
Command(command="我的金币"),
Command(command="我的道具"),
Command(command="购买道具"),
@ -74,13 +78,21 @@ _matcher = on_alconna(
Subcommand("my-cost", help_text="我的金币"),
Subcommand("my-props", help_text="我的道具"),
Subcommand("buy", Args["name?", str]["num?", int], help_text="购买道具"),
Subcommand("use", Args["name?", str]["num?", int], help_text="使用道具"),
Subcommand("gold-list", Args["num?", int], help_text="金币排行"),
),
priority=5,
block=True,
)
_use_matcher = on_alconna(
Alconna(
"使用道具",
Args["name?", str]["num?", int]["at_users?", MultiVar(At)],
),
priority=5,
block=True,
)
_matcher.shortcut(
"我的金币",
command="商店",
@ -102,13 +114,6 @@ _matcher.shortcut(
prefix=True,
)
_matcher.shortcut(
"使用道具(?P<name>.*?)",
command="商店",
arguments=["use", "{name}"],
prefix=True,
)
_matcher.shortcut(
"金币排行",
command="商店",
@ -172,7 +177,7 @@ async def _(
await MessageUtils.build_message(result).send(reply_to=True)
@_matcher.assign("use")
@_use_matcher.handle()
async def _(
bot: Bot,
event: Event,
@ -181,6 +186,7 @@ async def _(
arparma: Arparma,
name: Match[str],
num: Query[int] = AlconnaQuery("num", 1),
at_users: Query[list[At]] = AlconnaQuery("at_users", []),
):
if not name.available:
await MessageUtils.build_message(
@ -188,7 +194,7 @@ async def _(
).finish(reply_to=True)
try:
result = await ShopManage.use(
bot, event, session, message, name.result, num.result, ""
bot, event, session, message, name.result, num.result, "", at_users.result
)
logger.info(
f"使用道具 {name.result}, 数量: {num.result}",

View File

@ -8,7 +8,7 @@ from typing import Any, Literal
from nonebot.adapters import Bot, Event
from nonebot.compat import model_dump
from nonebot_plugin_alconna import UniMessage, UniMsg
from nonebot_plugin_alconna import At, UniMessage, UniMsg
from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel, Field, create_model
from tortoise.expressions import Q
@ -48,6 +48,10 @@ class Goods(BaseModel):
"""model"""
session: Uninfo | None = None
"""Uninfo"""
at_user: str | None = None
"""At对象"""
at_users: list[str] = []
"""At对象列表"""
class ShopParam(BaseModel):
@ -73,6 +77,10 @@ class ShopParam(BaseModel):
"""Uninfo"""
message: UniMsg
"""UniMessage"""
at_user: str | None = None
"""At对象"""
at_users: list[str] = []
"""At对象列表"""
extra_data: dict[str, Any] = Field(default_factory=dict)
"""额外数据"""
@ -156,6 +164,7 @@ class ShopManage:
goods: Goods,
num: int,
text: str,
at_users: list[str] = [],
) -> tuple[ShopParam, dict[str, Any]]:
"""构造参数
@ -165,6 +174,7 @@ class ShopManage:
goods_name: 商品名称
num: 数量
text: 其他信息
at_users: at用户
"""
group_id = None
if session.group:
@ -172,6 +182,7 @@ class ShopManage:
session.group.parent.id if session.group.parent else session.group.id
)
_kwargs = goods.params
at_user = at_users[0] if at_users else None
model = goods.model(
**{
"goods_name": goods.name,
@ -183,6 +194,8 @@ class ShopManage:
"text": text,
"session": session,
"message": message,
"at_user": at_user,
"at_users": at_users,
}
)
return model, {
@ -194,6 +207,8 @@ class ShopManage:
"num": num,
"text": text,
"goods_name": goods.name,
"at_user": at_user,
"at_users": at_users,
}
@classmethod
@ -223,6 +238,7 @@ class ShopManage:
**param.extra_data,
"session": session,
"message": message,
"shop_param": ShopParam,
}
for key in list(param_json.keys()):
if key not in args:
@ -308,6 +324,7 @@ class ShopManage:
goods_name: str,
num: int,
text: str,
at_users: list[At] = [],
) -> str | UniMessage | None:
"""使用道具
@ -319,6 +336,7 @@ class ShopManage:
goods_name: 商品名称
num: 使用数量
text: 其他信息
at_users: at用户
返回:
str | MessageFactory | None: 使用完成后返回信息
@ -339,8 +357,9 @@ class ShopManage:
goods = cls.uuid2goods.get(goods_info.uuid)
if not goods or not goods.func:
return f"{goods_info.goods_name} 未注册使用函数, 无法使用..."
at_user_ids = [at.target for at in at_users]
param, kwargs = cls.__build_params(
bot, event, session, message, goods, num, text
bot, event, session, message, goods, num, text, at_user_ids
)
if num > param.max_num_limit:
return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..."
@ -480,10 +499,13 @@ class ShopManage:
if not user.props:
return None
user.props = {uuid: count for uuid, count in user.props.items() if count > 0}
goods_list = await GoodsInfo.filter(uuid__in=user.props.keys()).all()
goods_by_uuid = {item.uuid: item for item in goods_list}
user.props = {
uuid: count
for uuid, count in user.props.items()
if count > 0 and goods_by_uuid.get(uuid)
}
table_rows = []
for i, prop_uuid in enumerate(user.props):

View File

@ -10,7 +10,6 @@ from nonebot_plugin_alconna import (
store_true,
)
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.utils import (
Command,
@ -23,7 +22,7 @@ from zhenxun.utils.depends import UserName
from zhenxun.utils.message import MessageUtils
from ._data_source import SignManage
from .goods_register import driver # noqa: F401
from .goods_register import Uninfo
from .utils import clear_sign_data_pic
__plugin_meta__ = PluginMetadata(

View File

@ -1,7 +1,6 @@
from decimal import Decimal
import nonebot
from nonebot.drivers import Driver
from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.sign_user import SignUser
@ -9,14 +8,7 @@ from zhenxun.models.user_console import UserConsole
from zhenxun.utils.decorator.shop import shop_register
from zhenxun.utils.platform import PlatformUtils
driver: Driver = nonebot.get_driver()
# @driver.on_startup
# async def _():
# """
# 导入内置的三个商品
# """
driver = nonebot.get_driver()
@shop_register(

View File

@ -16,6 +16,7 @@ from zhenxun.models.sign_log import SignLog
from zhenxun.models.sign_user import SignUser
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.image_utils import BuildImage
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
from .config import (
@ -54,7 +55,7 @@ LG_MESSAGE = [
]
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def init_image():
SIGN_RESOURCE_PATH.mkdir(parents=True, exist_ok=True)
SIGN_TODAY_CARD_PATH.mkdir(exist_ok=True, parents=True)

View File

@ -1,32 +1,77 @@
from typing import Annotated
from nonebot import on_command
from nonebot.adapters import Bot
from nonebot.params import Command
from arclet.alconna import AllParam
from nepattern import UnionPattern
from nonebot.adapters import Bot, Event
from nonebot.permission import SUPERUSER
from nonebot.plugin import PluginMetadata
from nonebot.rule import to_me
from nonebot_plugin_alconna import Text as alcText
from nonebot_plugin_alconna import UniMsg
import nonebot_plugin_alconna as alc
from nonebot_plugin_alconna import (
Alconna,
Args,
on_alconna,
)
from nonebot_plugin_alconna.uniseg.segment import (
At,
AtAll,
Audio,
Button,
Emoji,
File,
Hyper,
Image,
Keyboard,
Reference,
Reply,
Text,
Video,
Voice,
)
from nonebot_plugin_session import EventSession
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
from zhenxun.services.log import logger
from zhenxun.utils.enum import PluginType
from zhenxun.utils.message import MessageUtils
from ._data_source import BroadcastManage
from .broadcast_manager import BroadcastManager
from .message_processor import (
_extract_broadcast_content,
get_broadcast_target_groups,
send_broadcast_and_notify,
)
BROADCAST_SEND_DELAY_RANGE = (1, 3)
__plugin_meta__ = PluginMetadata(
name="广播",
description="昭告天下!",
usage="""
广播 [消息] [图片]
示例广播 你们好
广播 [消息内容]
- 直接发送消息到除当前群组外的所有群组
- 支持文本图片@表情视频等多种消息类型
- 示例广播 你们好
- 示例广播 [图片] 新活动开始啦
广播 + 引用消息
- 将引用的消息作为广播内容发送
- 支持引用普通消息或合并转发消息
- 示例(引用一条消息) 广播
广播撤回
- 撤回最近一次由您触发的广播消息
- 仅能撤回短时间内的消息
- 示例广播撤回
特性
- 在群组中使用广播时不会将消息发送到当前群组
- 在私聊中使用广播时会发送到所有群组
别名
- bc (广播的简写)
- recall (广播撤回的别名)
""".strip(),
extra=PluginExtraData(
author="HibiKier",
version="0.1",
version="1.2",
plugin_type=PluginType.SUPERUSER,
configs=[
RegisterConfig(
@ -42,26 +87,106 @@ __plugin_meta__ = PluginMetadata(
).to_dict(),
)
_matcher = on_command(
"广播", priority=1, permission=SUPERUSER, block=True, rule=to_me()
AnySeg = (
UnionPattern(
[
Text,
Image,
At,
AtAll,
Audio,
Video,
File,
Emoji,
Reply,
Reference,
Hyper,
Button,
Keyboard,
Voice,
]
)
@ "AnySeg"
)
_matcher = on_alconna(
Alconna(
"广播",
Args["content?", AllParam],
),
aliases={"bc"},
priority=1,
permission=SUPERUSER,
block=True,
rule=to_me(),
use_origin=False,
)
_recall_matcher = on_alconna(
Alconna("广播撤回"),
aliases={"recall"},
priority=1,
permission=SUPERUSER,
block=True,
rule=to_me(),
)
@_matcher.handle()
async def _(
async def handle_broadcast(
bot: Bot,
event: Event,
session: EventSession,
message: UniMsg,
command: Annotated[tuple[str, ...], Command()],
arp: alc.Arparma,
):
for msg in message:
if isinstance(msg, alcText) and msg.text.strip().startswith(command[0]):
msg.text = msg.text.replace(command[0], "", 1).strip()
break
await MessageUtils.build_message("正在发送..请等一下哦!").send()
count, error_count = await BroadcastManage.send(bot, message, session)
result = f"成功广播 {count} 个群组"
if error_count:
result += f"\n广播失败 {error_count} 个群组"
await MessageUtils.build_message(f"发送广播完成!\n{result}").send(reply_to=True)
logger.info(f"发送广播信息: {message}", "广播", session=session)
broadcast_content_msg = await _extract_broadcast_content(bot, event, arp, session)
if not broadcast_content_msg:
return
target_groups, enabled_groups = await get_broadcast_target_groups(bot, session)
if not target_groups or not enabled_groups:
return
try:
await send_broadcast_and_notify(
bot, event, broadcast_content_msg, enabled_groups, target_groups, session
)
except Exception as e:
error_msg = "发送广播失败"
BroadcastManager.log_error(error_msg, e, session)
await MessageUtils.build_message(f"{error_msg}").send(reply_to=True)
@_recall_matcher.handle()
async def handle_broadcast_recall(
bot: Bot,
event: Event,
session: EventSession,
):
"""处理广播撤回命令"""
await MessageUtils.build_message("正在尝试撤回最近一次广播...").send()
try:
success_count, error_count = await BroadcastManager.recall_last_broadcast(
bot, session
)
user_id = str(event.get_user_id())
if success_count == 0 and error_count == 0:
await bot.send_private_msg(
user_id=user_id,
message="没有找到最近的广播消息记录,可能已经撤回或超过可撤回时间。",
)
else:
result = f"广播撤回完成!\n成功撤回 {success_count} 条消息"
if error_count:
result += f"\n撤回失败 {error_count} 条消息 (可能已过期或无权限)"
await bot.send_private_msg(user_id=user_id, message=result)
BroadcastManager.log_info(
f"广播撤回完成: 成功 {success_count}, 失败 {error_count}", session
)
except Exception as e:
error_msg = "撤回广播消息失败"
BroadcastManager.log_error(error_msg, e, session)
user_id = str(event.get_user_id())
await bot.send_private_msg(user_id=user_id, message=f"{error_msg}")

View File

@ -1,72 +0,0 @@
import asyncio
import random
from nonebot.adapters import Bot
import nonebot_plugin_alconna as alc
from nonebot_plugin_alconna import Image, UniMsg
from nonebot_plugin_session import EventSession
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.message import MessageUtils
from zhenxun.utils.platform import PlatformUtils
class BroadcastManage:
@classmethod
async def send(
cls, bot: Bot, message: UniMsg, session: EventSession
) -> tuple[int, int]:
"""发送广播消息
参数:
bot: Bot
message: 消息内容
session: Session
返回:
tuple[int, int]: 发送成功的群组数量, 发送失败的群组数量
"""
message_list = []
for msg in message:
if isinstance(msg, alc.Image) and msg.url:
message_list.append(Image(url=msg.url))
elif isinstance(msg, alc.Text):
message_list.append(msg.text)
group_list, _ = await PlatformUtils.get_group_list(bot)
if group_list:
error_count = 0
for group in group_list:
try:
if not await CommonUtils.task_is_block(
bot,
"broadcast", # group.channel_id
group.group_id,
):
target = PlatformUtils.get_target(
group_id=group.group_id, channel_id=group.channel_id
)
if target:
await MessageUtils.build_message(message_list).send(
target, bot
)
logger.debug(
"发送成功",
"广播",
session=session,
target=f"{group.group_id}:{group.channel_id}",
)
await asyncio.sleep(random.randint(1, 3))
else:
logger.warning("target为空", "广播", session=session)
except Exception as e:
error_count += 1
logger.error(
"发送失败",
"广播",
session=session,
target=f"{group.group_id}:{group.channel_id}",
e=e,
)
return len(group_list) - error_count, error_count
return 0, 0

View File

@ -0,0 +1,490 @@
import asyncio
import random
import traceback
from typing import ClassVar
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Bot as V11Bot
from nonebot.exception import ActionFailed
from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_alconna.uniseg import Receipt, Reference
from nonebot_plugin_session import EventSession
from zhenxun.models.group_console import GroupConsole
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.platform import PlatformUtils
from .models import BroadcastDetailResult, BroadcastResult
from .utils import custom_nodes_to_v11_nodes, uni_message_to_v11_list_of_dicts
class BroadcastManager:
"""广播管理器"""
_last_broadcast_msg_ids: ClassVar[dict[str, int]] = {}
@staticmethod
def _get_session_info(session: EventSession | None) -> str:
"""获取会话信息字符串"""
if not session:
return ""
try:
platform = getattr(session, "platform", "unknown")
session_id = str(session)
return f"[{platform}:{session_id}]"
except Exception:
return "[session-info-error]"
@staticmethod
def log_error(
message: str, error: Exception, session: EventSession | None = None, **kwargs
):
"""记录错误日志"""
session_info = BroadcastManager._get_session_info(session)
error_type = type(error).__name__
stack_trace = traceback.format_exc()
error_details = f"\n类型: {error_type}\n信息: {error!s}\n堆栈: {stack_trace}"
logger.error(
f"{session_info} {message}{error_details}", "广播", e=error, **kwargs
)
@staticmethod
def log_warning(message: str, session: EventSession | None = None, **kwargs):
"""记录警告级别日志"""
session_info = BroadcastManager._get_session_info(session)
logger.warning(f"{session_info} {message}", "广播", **kwargs)
@staticmethod
def log_info(message: str, session: EventSession | None = None, **kwargs):
"""记录信息级别日志"""
session_info = BroadcastManager._get_session_info(session)
logger.info(f"{session_info} {message}", "广播", **kwargs)
@classmethod
def get_last_broadcast_msg_ids(cls) -> dict[str, int]:
"""获取最近广播消息ID"""
return cls._last_broadcast_msg_ids.copy()
@classmethod
def clear_last_broadcast_msg_ids(cls) -> None:
"""清空消息ID记录"""
cls._last_broadcast_msg_ids.clear()
@classmethod
async def get_all_groups(cls, bot: Bot) -> tuple[list[GroupConsole], str]:
"""获取群组列表"""
return await PlatformUtils.get_group_list(bot)
@classmethod
async def send(
cls, bot: Bot, message: UniMessage, session: EventSession
) -> BroadcastResult:
"""发送广播到所有群组"""
logger.debug(
f"开始广播(send - 广播到所有群组)Bot ID: {bot.self_id}",
"广播",
session=session,
)
logger.debug("清空上一次的广播消息ID记录", "广播", session=session)
cls.clear_last_broadcast_msg_ids()
all_groups, _ = await cls.get_all_groups(bot)
return await cls.send_to_specific_groups(bot, message, all_groups, session)
@classmethod
async def send_to_specific_groups(
cls,
bot: Bot,
message: UniMessage,
target_groups: list[GroupConsole],
session_info: EventSession | str | None = None,
) -> BroadcastResult:
"""发送广播到指定群组"""
log_session = session_info or bot.self_id
logger.debug(
f"开始广播,目标 {len(target_groups)} 个群组Bot ID: {bot.self_id}",
"广播",
session=log_session,
)
if not target_groups:
logger.debug("目标群组列表为空,广播结束", "广播", session=log_session)
return 0, 0
platform = PlatformUtils.get_platform(bot)
is_forward_broadcast = any(
isinstance(seg, Reference) and getattr(seg, "nodes", None)
for seg in message
)
if platform == "qq" and isinstance(bot, V11Bot) and is_forward_broadcast:
if (
len(message) == 1
and isinstance(message[0], Reference)
and getattr(message[0], "nodes", None)
):
nodes_list = getattr(message[0], "nodes", [])
v11_nodes = custom_nodes_to_v11_nodes(nodes_list)
node_count = len(v11_nodes)
logger.debug(
f"从 UniMessage<Reference> 构造转发节点数: {node_count}",
"广播",
session=log_session,
)
else:
logger.warning(
"广播消息包含合并转发段和其他段,将尝试打平成一个节点发送",
"广播",
session=log_session,
)
v11_content_list = uni_message_to_v11_list_of_dicts(message)
v11_nodes = (
[
{
"type": "node",
"data": {
"user_id": bot.self_id,
"nickname": "广播",
"content": v11_content_list,
},
}
]
if v11_content_list
else []
)
if not v11_nodes:
logger.warning(
"构造出的 V11 合并转发节点为空,无法发送",
"广播",
session=log_session,
)
return 0, len(target_groups)
success_count, error_count, skip_count = await cls._broadcast_forward(
bot, log_session, target_groups, v11_nodes
)
else:
if is_forward_broadcast:
logger.warning(
f"合并转发消息在适配器 ({platform}) 不支持,将作为普通消息发送",
"广播",
session=log_session,
)
success_count, error_count, skip_count = await cls._broadcast_normal(
bot, log_session, target_groups, message
)
total = len(target_groups)
stats = f"成功: {success_count}, 失败: {error_count}"
stats += f", 跳过: {skip_count}, 总计: {total}"
logger.debug(
f"广播统计 - {stats}",
"广播",
session=log_session,
)
msg_ids = cls.get_last_broadcast_msg_ids()
if msg_ids:
id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids.items()])
logger.debug(
f"广播结束,记录了 {len(msg_ids)} 条消息ID: {id_list_str}",
"广播",
session=log_session,
)
else:
logger.warning(
"广播结束但没有记录任何消息ID",
"广播",
session=log_session,
)
return success_count, error_count
@classmethod
async def _extract_message_id_from_result(
cls,
result: dict | Receipt,
group_key: str,
session_info: EventSession | str,
msg_type: str = "普通",
) -> None:
"""提取消息ID并记录"""
if isinstance(result, dict) and "message_id" in result:
msg_id = result["message_id"]
try:
msg_id_int = int(msg_id)
cls._last_broadcast_msg_ids[group_key] = msg_id_int
logger.debug(
f"记录群 {group_key}{msg_type}消息ID: {msg_id_int}",
"广播",
session=session_info,
)
except (ValueError, TypeError):
logger.warning(
f"{msg_type}结果中的 message_id 不是有效整数: {msg_id}",
"广播",
session=session_info,
)
elif isinstance(result, Receipt) and result.msg_ids:
try:
first_id_info = result.msg_ids[0]
msg_id = None
if isinstance(first_id_info, dict) and "message_id" in first_id_info:
msg_id = first_id_info["message_id"]
logger.debug(
f"从 Receipt.msg_ids[0] 提取到 ID: {msg_id}",
"广播",
session=session_info,
)
elif isinstance(first_id_info, int | str):
msg_id = first_id_info
logger.debug(
f"从 Receipt.msg_ids[0] 提取到原始ID: {msg_id}",
"广播",
session=session_info,
)
if msg_id is not None:
try:
msg_id_int = int(msg_id)
cls._last_broadcast_msg_ids[group_key] = msg_id_int
logger.debug(
f"记录群 {group_key} 的消息ID: {msg_id_int}",
"广播",
session=session_info,
)
except (ValueError, TypeError):
logger.warning(
f"提取的ID ({msg_id}) 不是有效整数",
"广播",
session=session_info,
)
else:
info_str = str(first_id_info)
logger.warning(
f"无法从 Receipt.msg_ids[0] 提取ID: {info_str}",
"广播",
session=session_info,
)
except IndexError:
logger.warning("Receipt.msg_ids 为空", "广播", session=session_info)
except Exception as e_extract:
logger.error(
f"从 Receipt 提取 msg_id 时出错: {e_extract}",
"广播",
session=session_info,
e=e_extract,
)
else:
logger.warning(
f"发送成功但无法从结果获取消息 ID. 结果: {result}",
"广播",
session=session_info,
)
@classmethod
async def _check_group_availability(cls, bot: Bot, group: GroupConsole) -> bool:
"""检查群组是否可用"""
if not group.group_id:
return False
if await CommonUtils.task_is_block(bot, "broadcast", group.group_id):
return False
return True
@classmethod
async def _broadcast_forward(
cls,
bot: V11Bot,
session_info: EventSession | str,
group_list: list[GroupConsole],
v11_nodes: list[dict],
) -> BroadcastDetailResult:
"""发送合并转发"""
success_count = 0
error_count = 0
skip_count = 0
for _, group in enumerate(group_list):
group_key = group.group_id or group.channel_id
if not await cls._check_group_availability(bot, group):
skip_count += 1
continue
try:
result = await bot.send_group_forward_msg(
group_id=int(group.group_id), messages=v11_nodes
)
logger.debug(
f"合并转发消息发送结果: {result}, 类型: {type(result)}",
"广播",
session=session_info,
)
await cls._extract_message_id_from_result(
result, group_key, session_info, "合并转发"
)
success_count += 1
await asyncio.sleep(random.randint(1, 3))
except ActionFailed as af_e:
error_count += 1
logger.error(
f"发送失败(合并转发) to {group_key}: {af_e}",
"广播",
session=session_info,
e=af_e,
)
except Exception as e:
error_count += 1
logger.error(
f"发送失败(合并转发) to {group_key}: {e}",
"广播",
session=session_info,
e=e,
)
return success_count, error_count, skip_count
@classmethod
async def _broadcast_normal(
cls,
bot: Bot,
session_info: EventSession | str,
group_list: list[GroupConsole],
message: UniMessage,
) -> BroadcastDetailResult:
"""发送普通消息"""
success_count = 0
error_count = 0
skip_count = 0
for _, group in enumerate(group_list):
group_key = (
f"{group.group_id}:{group.channel_id}"
if group.channel_id
else str(group.group_id)
)
if not await cls._check_group_availability(bot, group):
skip_count += 1
continue
try:
target = PlatformUtils.get_target(
group_id=group.group_id, channel_id=group.channel_id
)
if target:
receipt: Receipt = await message.send(target, bot=bot)
logger.debug(
f"广播消息发送结果: {receipt}, 类型: {type(receipt)}",
"广播",
session=session_info,
)
await cls._extract_message_id_from_result(
receipt, group_key, session_info
)
success_count += 1
await asyncio.sleep(random.randint(1, 3))
else:
logger.warning(
"target为空", "广播", session=session_info, target=group_key
)
skip_count += 1
except Exception as e:
error_count += 1
logger.error(
f"发送失败(普通) to {group_key}: {e}",
"广播",
session=session_info,
e=e,
)
return success_count, error_count, skip_count
@classmethod
async def recall_last_broadcast(
cls, bot: Bot, session_info: EventSession | str
) -> BroadcastResult:
"""撤回最近广播"""
msg_ids_to_recall = cls.get_last_broadcast_msg_ids()
if not msg_ids_to_recall:
logger.warning(
"没有找到最近的广播消息ID记录", "广播撤回", session=session_info
)
return 0, 0
id_list_str = ", ".join([f"{k}:{v}" for k, v in msg_ids_to_recall.items()])
logger.debug(
f"找到 {len(msg_ids_to_recall)} 条广播消息ID记录: {id_list_str}",
"广播撤回",
session=session_info,
)
success_count = 0
error_count = 0
logger.info(
f"准备撤回 {len(msg_ids_to_recall)} 条广播消息",
"广播撤回",
session=session_info,
)
for group_key, msg_id in msg_ids_to_recall.items():
try:
logger.debug(
f"尝试撤回消息 (ID: {msg_id}) in {group_key}",
"广播撤回",
session=session_info,
)
await bot.call_api("delete_msg", message_id=msg_id)
success_count += 1
except ActionFailed as af_e:
retcode = getattr(af_e, "retcode", None)
wording = getattr(af_e, "wording", "")
if retcode == 100 and "MESSAGE_NOT_FOUND" in wording.upper():
logger.warning(
f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}",
"广播撤回",
session=session_info,
)
elif retcode == 300 and "delete message" in wording.lower():
logger.warning(
f"消息 (ID: {msg_id}) 可能已被撤回或不存在于 {group_key}",
"广播撤回",
session=session_info,
)
else:
error_count += 1
logger.error(
f"撤回消息失败 (ID: {msg_id}) in {group_key}: {af_e}",
"广播撤回",
session=session_info,
e=af_e,
)
except Exception as e:
error_count += 1
logger.error(
f"撤回消息时发生未知错误 (ID: {msg_id}) in {group_key}: {e}",
"广播撤回",
session=session_info,
e=e,
)
await asyncio.sleep(0.2)
logger.debug("撤回操作完成清空消息ID记录", "广播撤回", session=session_info)
cls.clear_last_broadcast_msg_ids()
return success_count, error_count

View File

@ -0,0 +1,584 @@
import base64
import json
from typing import Any
from nonebot.adapters import Bot, Event
from nonebot.adapters.onebot.v11 import Message as V11Message
from nonebot.adapters.onebot.v11 import MessageSegment as V11MessageSegment
from nonebot.exception import ActionFailed
import nonebot_plugin_alconna as alc
from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_alconna.uniseg.segment import (
At,
AtAll,
CustomNode,
Image,
Reference,
Reply,
Text,
Video,
)
from nonebot_plugin_alconna.uniseg.tools import reply_fetch
from nonebot_plugin_session import EventSession
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.message import MessageUtils
from .broadcast_manager import BroadcastManager
MAX_FORWARD_DEPTH = 3
async def _process_forward_content(
forward_content: Any, forward_id: str | None, bot: Bot, depth: int
) -> list[CustomNode]:
"""处理转发消息内容"""
nodes_for_alc = []
content_parsed = False
if forward_content:
nodes_from_content = None
if isinstance(forward_content, list):
nodes_from_content = forward_content
elif isinstance(forward_content, str):
try:
parsed_content = json.loads(forward_content)
if isinstance(parsed_content, list):
nodes_from_content = parsed_content
except Exception as json_e:
logger.debug(
f"[Depth {depth}] JSON解析失败: {json_e}",
"广播",
)
if nodes_from_content is not None:
logger.debug(
f"[D{depth}] 节点数: {len(nodes_from_content)}",
"广播",
)
content_parsed = True
for node_data in nodes_from_content:
node = await _create_custom_node_from_data(node_data, bot, depth + 1)
if node:
nodes_for_alc.append(node)
if not content_parsed and forward_id:
logger.debug(
f"[D{depth}] 尝试API调用ID: {forward_id}",
"广播",
)
try:
forward_data = await bot.call_api("get_forward_msg", id=forward_id)
nodes_list = None
if isinstance(forward_data, dict) and "messages" in forward_data:
nodes_list = forward_data["messages"]
elif (
isinstance(forward_data, dict)
and "data" in forward_data
and isinstance(forward_data["data"], dict)
and "message" in forward_data["data"]
):
nodes_list = forward_data["data"]["message"]
elif isinstance(forward_data, list):
nodes_list = forward_data
if nodes_list:
node_count = len(nodes_list)
logger.debug(
f"[D{depth + 1}] 节点:{node_count}",
"广播",
)
for node_data in nodes_list:
node = await _create_custom_node_from_data(
node_data, bot, depth + 1
)
if node:
nodes_for_alc.append(node)
else:
logger.warning(
f"[D{depth + 1}] ID:{forward_id}无节点",
"广播",
)
nodes_for_alc.append(
CustomNode(
uid="0",
name="错误",
content="[嵌套转发消息获取失败]",
)
)
except ActionFailed as af_e:
logger.error(
f"[D{depth + 1}] API失败: {af_e}",
"广播",
e=af_e,
)
nodes_for_alc.append(
CustomNode(
uid="0",
name="错误",
content="[嵌套转发消息获取失败]",
)
)
except Exception as e:
logger.error(
f"[D{depth + 1}] 处理出错: {e}",
"广播",
e=e,
)
nodes_for_alc.append(
CustomNode(
uid="0",
name="错误",
content="[处理嵌套转发时出错]",
)
)
elif not content_parsed and not forward_id:
logger.warning(
f"[D{depth}] 转发段无内容也无ID",
"广播",
)
nodes_for_alc.append(
CustomNode(
uid="0",
name="错误",
content="[嵌套转发消息无法解析]",
)
)
elif content_parsed and not nodes_for_alc:
logger.warning(
f"[D{depth}] 解析成功但无有效节点",
"广播",
)
nodes_for_alc.append(
CustomNode(
uid="0",
name="信息",
content="[嵌套转发内容为空]",
)
)
return nodes_for_alc
async def _create_custom_node_from_data(
node_data: dict, bot: Bot, depth: int
) -> CustomNode | None:
"""从节点数据创建CustomNode"""
node_content_raw = node_data.get("message") or node_data.get("content")
if not node_content_raw:
logger.warning(f"[D{depth}] 节点缺少消息内容", "广播")
return None
sender = node_data.get("sender", {})
uid = str(sender.get("user_id", "10000"))
name = sender.get("nickname", f"用户{uid[:4]}")
extracted_uni_msg = await _extract_content_from_message(
node_content_raw, bot, depth
)
if not extracted_uni_msg:
return None
return CustomNode(uid=uid, name=name, content=extracted_uni_msg)
async def _extract_broadcast_content(
bot: Bot,
event: Event,
arp: alc.Arparma,
session: EventSession,
) -> UniMessage | None:
"""从命令参数或引用消息中提取广播内容"""
broadcast_content_msg: UniMessage | None = None
command_content_list = arp.all_matched_args.get("content", [])
processed_command_list = []
has_command_content = False
if command_content_list:
for item in command_content_list:
if isinstance(item, alc.Segment):
processed_command_list.append(item)
if not (isinstance(item, Text) and not item.text.strip()):
has_command_content = True
elif isinstance(item, str):
if item.strip():
processed_command_list.append(Text(item.strip()))
has_command_content = True
else:
logger.warning(
f"Unexpected type in command content: {type(item)}", "广播"
)
if has_command_content:
logger.debug("检测到命令参数内容,优先使用参数内容", "广播", session=session)
broadcast_content_msg = UniMessage(processed_command_list)
if not broadcast_content_msg.filter(
lambda x: not (isinstance(x, Text) and not x.text.strip())
):
logger.warning(
"命令参数内容解析后为空或只包含空白", "广播", session=session
)
broadcast_content_msg = None
if not broadcast_content_msg:
reply_segment_obj: Reply | None = await reply_fetch(event, bot)
if (
reply_segment_obj
and hasattr(reply_segment_obj, "msg")
and reply_segment_obj.msg
):
logger.debug(
"未检测到有效命令参数,检测到引用消息", "广播", session=session
)
raw_quoted_content = reply_segment_obj.msg
is_forward = False
forward_id = None
if isinstance(raw_quoted_content, V11Message):
for seg in raw_quoted_content:
if isinstance(seg, V11MessageSegment):
if seg.type == "forward":
forward_id = seg.data.get("id")
is_forward = bool(forward_id)
break
elif seg.type == "json":
try:
json_data_str = seg.data.get("data", "{}")
if isinstance(json_data_str, str):
import json
json_data = json.loads(json_data_str)
if (
json_data.get("app") == "com.tencent.multimsg"
or json_data.get("view") == "Forward"
) and json_data.get("meta", {}).get(
"detail", {}
).get("resid"):
forward_id = json_data["meta"]["detail"][
"resid"
]
is_forward = True
break
except Exception:
pass
if is_forward and forward_id:
logger.info(
f"尝试获取并构造合并转发内容 (ID: {forward_id})",
"广播",
session=session,
)
nodes_to_forward: list[CustomNode] = []
try:
forward_data = await bot.call_api("get_forward_msg", id=forward_id)
nodes_list = None
if isinstance(forward_data, dict) and "messages" in forward_data:
nodes_list = forward_data["messages"]
elif (
isinstance(forward_data, dict)
and "data" in forward_data
and isinstance(forward_data["data"], dict)
and "message" in forward_data["data"]
):
nodes_list = forward_data["data"]["message"]
elif isinstance(forward_data, list):
nodes_list = forward_data
if nodes_list is not None:
for node_data in nodes_list:
node_sender = node_data.get("sender", {})
node_user_id = str(node_sender.get("user_id", "10000"))
node_nickname = node_sender.get(
"nickname", f"用户{node_user_id[:4]}"
)
node_content_raw = node_data.get(
"message"
) or node_data.get("content")
if node_content_raw:
extracted_node_uni_msg = (
await _extract_content_from_message(
node_content_raw, bot
)
)
if extracted_node_uni_msg:
nodes_to_forward.append(
CustomNode(
uid=node_user_id,
name=node_nickname,
content=extracted_node_uni_msg,
)
)
if nodes_to_forward:
broadcast_content_msg = UniMessage(
Reference(nodes=nodes_to_forward)
)
except ActionFailed:
await MessageUtils.build_message(
"获取合并转发消息失败,可能不支持此 API。"
).send(reply_to=True)
return None
except Exception as api_e:
logger.error(f"处理合并转发时出错: {api_e}", "广播", e=api_e)
await MessageUtils.build_message(
"处理合并转发消息时发生内部错误。"
).send(reply_to=True)
return None
else:
broadcast_content_msg = await _extract_content_from_message(
raw_quoted_content, bot
)
else:
logger.debug("未检测到命令参数和引用消息", "广播", session=session)
await MessageUtils.build_message("请提供广播内容或引用要广播的消息").send(
reply_to=True
)
return None
if not broadcast_content_msg:
logger.error(
"未能从命令参数或引用消息中获取有效的广播内容", "广播", session=session
)
await MessageUtils.build_message("错误:未能获取有效的广播内容。").send(
reply_to=True
)
return None
return broadcast_content_msg
async def _process_v11_segment(
seg_obj: V11MessageSegment | dict, depth: int, index: int, bot: Bot
) -> list[alc.Segment]:
"""处理V11消息段"""
result = []
seg_type = None
data_dict = None
if isinstance(seg_obj, V11MessageSegment):
seg_type = seg_obj.type
data_dict = seg_obj.data
elif isinstance(seg_obj, dict):
seg_type = seg_obj.get("type")
data_dict = seg_obj.get("data")
else:
return result
if not (seg_type and data_dict is not None):
logger.warning(f"[D{depth}] 跳过无效数据: {type(seg_obj)}", "广播")
return result
if seg_type == "text":
text_content = data_dict.get("text", "")
if isinstance(text_content, str) and text_content.strip():
result.append(Text(text_content))
elif seg_type == "image":
img_seg = None
if data_dict.get("url"):
img_seg = Image(url=data_dict["url"])
elif data_dict.get("file"):
file_val = data_dict["file"]
if isinstance(file_val, str) and file_val.startswith("base64://"):
b64_data = file_val[9:]
raw_bytes = base64.b64decode(b64_data)
img_seg = Image(raw=raw_bytes)
else:
img_seg = Image(path=file_val)
if img_seg:
result.append(img_seg)
else:
logger.warning(f"[Depth {depth}] V11 图片 {index} 缺少URL/文件", "广播")
elif seg_type == "at":
target_qq = data_dict.get("qq", "")
if target_qq.lower() == "all":
result.append(AtAll())
elif target_qq:
result.append(At(flag="user", target=target_qq))
elif seg_type == "video":
video_seg = None
if data_dict.get("url"):
video_seg = Video(url=data_dict["url"])
elif data_dict.get("file"):
file_val = data_dict["file"]
if isinstance(file_val, str) and file_val.startswith("base64://"):
b64_data = file_val[9:]
raw_bytes = base64.b64decode(b64_data)
video_seg = Video(raw=raw_bytes)
else:
video_seg = Video(path=file_val)
if video_seg:
result.append(video_seg)
logger.debug(f"[Depth {depth}] 处理视频消息成功", "广播")
else:
logger.warning(f"[Depth {depth}] V11 视频 {index} 缺少URL/文件", "广播")
elif seg_type == "forward":
nested_forward_id = data_dict.get("id") or data_dict.get("resid")
nested_forward_content = data_dict.get("content")
logger.debug(f"[D{depth}] 嵌套转发ID: {nested_forward_id}", "广播")
nested_nodes = await _process_forward_content(
nested_forward_content, nested_forward_id, bot, depth
)
if nested_nodes:
result.append(Reference(nodes=nested_nodes))
else:
logger.warning(f"[D{depth}] 跳过类型: {seg_type}", "广播")
return result
async def _extract_content_from_message(
message_content: Any, bot: Bot, depth: int = 0
) -> UniMessage:
"""提取消息内容到UniMessage"""
temp_msg = UniMessage()
input_type_str = str(type(message_content))
if depth >= MAX_FORWARD_DEPTH:
logger.warning(
f"[Depth {depth}] 达到最大递归深度 {MAX_FORWARD_DEPTH},停止解析嵌套转发。",
"广播",
)
temp_msg.append(Text("[嵌套转发层数过多,内容已省略]"))
return temp_msg
segments_to_process = []
if isinstance(message_content, UniMessage):
segments_to_process = list(message_content)
elif isinstance(message_content, V11Message):
segments_to_process = list(message_content)
elif isinstance(message_content, list):
segments_to_process = message_content
elif (
isinstance(message_content, dict)
and "type" in message_content
and "data" in message_content
):
segments_to_process = [message_content]
elif isinstance(message_content, str):
if message_content.strip():
temp_msg.append(Text(message_content))
return temp_msg
else:
logger.warning(f"[Depth {depth}] 无法处理的输入类型: {input_type_str}", "广播")
return temp_msg
if segments_to_process:
for index, seg_obj in enumerate(segments_to_process):
try:
if isinstance(seg_obj, Text):
text_content = getattr(seg_obj, "text", None)
if isinstance(text_content, str) and text_content.strip():
temp_msg.append(seg_obj)
elif isinstance(seg_obj, Image):
if (
getattr(seg_obj, "url", None)
or getattr(seg_obj, "path", None)
or getattr(seg_obj, "raw", None)
):
temp_msg.append(seg_obj)
elif isinstance(seg_obj, At):
temp_msg.append(seg_obj)
elif isinstance(seg_obj, AtAll):
temp_msg.append(seg_obj)
elif isinstance(seg_obj, Video):
if (
getattr(seg_obj, "url", None)
or getattr(seg_obj, "path", None)
or getattr(seg_obj, "raw", None)
):
temp_msg.append(seg_obj)
logger.debug(f"[D{depth}] 处理Video对象成功", "广播")
else:
processed_segments = await _process_v11_segment(
seg_obj, depth, index, bot
)
temp_msg.extend(processed_segments)
except Exception as e_conv_seg:
logger.warning(
f"[D{depth}] 处理段 {index} 出错: {e_conv_seg}",
"广播",
e=e_conv_seg,
)
if not temp_msg and message_content:
logger.warning(f"未能从类型 {input_type_str} 中提取内容", "广播")
return temp_msg
async def get_broadcast_target_groups(
bot: Bot, session: EventSession
) -> tuple[list, list]:
"""获取广播目标群组和启用了广播功能的群组"""
target_groups = []
all_groups, _ = await BroadcastManager.get_all_groups(bot)
current_group_id = None
if hasattr(session, "id2") and session.id2:
current_group_id = session.id2
if current_group_id:
target_groups = [
group for group in all_groups if group.group_id != current_group_id
]
logger.info(
f"向除当前群组({current_group_id})外的所有群组广播", "广播", session=session
)
else:
target_groups = all_groups
logger.info("向所有群组广播", "广播", session=session)
if not target_groups:
await MessageUtils.build_message("没有找到符合条件的广播目标群组。").send(
reply_to=True
)
return [], []
enabled_groups = []
for group in target_groups:
if not await CommonUtils.task_is_block(bot, "broadcast", group.group_id):
enabled_groups.append(group)
if not enabled_groups:
await MessageUtils.build_message(
"没有启用了广播功能的目标群组可供立即发送。"
).send(reply_to=True)
return target_groups, []
return target_groups, enabled_groups
async def send_broadcast_and_notify(
bot: Bot,
event: Event,
message: UniMessage,
enabled_groups: list,
target_groups: list,
session: EventSession,
) -> None:
"""发送广播并通知结果"""
BroadcastManager.clear_last_broadcast_msg_ids()
count, error_count = await BroadcastManager.send_to_specific_groups(
bot, message, enabled_groups, session
)
result = f"成功广播 {count} 个群组"
if error_count:
result += f"\n发送失败 {error_count} 个群组"
result += f"\n有效: {len(enabled_groups)} / 总计: {len(target_groups)}"
user_id = str(event.get_user_id())
await bot.send_private_msg(user_id=user_id, message=f"发送广播完成!\n{result}")
BroadcastManager.log_info(
f"广播完成,有效/总计: {len(enabled_groups)}/{len(target_groups)}",
session,
)

View File

@ -0,0 +1,64 @@
from datetime import datetime
from typing import Any
from nonebot_plugin_alconna import UniMessage
from zhenxun.models.group_console import GroupConsole
GroupKey = str
MessageID = int
BroadcastResult = tuple[int, int]
BroadcastDetailResult = tuple[int, int, int]
class BroadcastTarget:
"""广播目标"""
def __init__(self, group_id: str, channel_id: str | None = None):
self.group_id = group_id
self.channel_id = channel_id
def to_dict(self) -> dict[str, str | None]:
"""转换为字典格式"""
return {"group_id": self.group_id, "channel_id": self.channel_id}
@classmethod
def from_group_console(cls, group: GroupConsole) -> "BroadcastTarget":
"""从 GroupConsole 对象创建"""
return cls(group_id=group.group_id, channel_id=group.channel_id)
@property
def key(self) -> str:
"""获取群组的唯一标识"""
if self.channel_id:
return f"{self.group_id}:{self.channel_id}"
return str(self.group_id)
class BroadcastTask:
"""广播任务"""
def __init__(
self,
bot_id: str,
message: UniMessage,
targets: list[BroadcastTarget],
scheduled_time: datetime | None = None,
task_id: str | None = None,
):
self.bot_id = bot_id
self.message = message
self.targets = targets
self.scheduled_time = scheduled_time
self.task_id = task_id
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式,用于序列化"""
return {
"bot_id": self.bot_id,
"targets": [t.to_dict() for t in self.targets],
"scheduled_time": self.scheduled_time.isoformat()
if self.scheduled_time
else None,
"task_id": self.task_id,
}

View File

@ -0,0 +1,175 @@
import base64
import nonebot_plugin_alconna as alc
from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_alconna.uniseg import Reference
from nonebot_plugin_alconna.uniseg.segment import CustomNode, Video
from zhenxun.services.log import logger
def uni_segment_to_v11_segment_dict(
seg: alc.Segment, depth: int = 0
) -> dict | list[dict] | None:
"""UniSeg段转V11字典"""
if isinstance(seg, alc.Text):
return {"type": "text", "data": {"text": seg.text}}
elif isinstance(seg, alc.Image):
if getattr(seg, "url", None):
return {
"type": "image",
"data": {"file": seg.url},
}
elif getattr(seg, "raw", None):
raw_data = seg.raw
if isinstance(raw_data, str):
if len(raw_data) >= 9 and raw_data[:9] == "base64://":
return {"type": "image", "data": {"file": raw_data}}
elif isinstance(raw_data, bytes):
b64_str = base64.b64encode(raw_data).decode()
return {"type": "image", "data": {"file": f"base64://{b64_str}"}}
else:
logger.warning(f"无法处理 Image.raw 的类型: {type(raw_data)}", "广播")
elif getattr(seg, "path", None):
logger.warning(
f"在合并转发中使用了本地图片路径,可能无法显示: {seg.path}", "广播"
)
return {"type": "image", "data": {"file": f"file:///{seg.path}"}}
else:
logger.warning(f"alc.Image 缺少有效数据,无法转换为 V11 段: {seg}", "广播")
elif isinstance(seg, alc.At):
return {"type": "at", "data": {"qq": seg.target}}
elif isinstance(seg, alc.AtAll):
return {"type": "at", "data": {"qq": "all"}}
elif isinstance(seg, Video):
if getattr(seg, "url", None):
return {
"type": "video",
"data": {"file": seg.url},
}
elif getattr(seg, "raw", None):
raw_data = seg.raw
if isinstance(raw_data, str):
if len(raw_data) >= 9 and raw_data[:9] == "base64://":
return {"type": "video", "data": {"file": raw_data}}
elif isinstance(raw_data, bytes):
b64_str = base64.b64encode(raw_data).decode()
return {"type": "video", "data": {"file": f"base64://{b64_str}"}}
else:
logger.warning(f"无法处理 Video.raw 的类型: {type(raw_data)}", "广播")
elif getattr(seg, "path", None):
logger.warning(
f"在合并转发中使用了本地视频路径,可能无法显示: {seg.path}", "广播"
)
return {"type": "video", "data": {"file": f"file:///{seg.path}"}}
else:
logger.warning(f"Video 缺少有效数据,无法转换为 V11 段: {seg}", "广播")
elif isinstance(seg, Reference) and getattr(seg, "nodes", None):
if depth >= 3:
logger.warning(
f"嵌套转发深度超过限制 (depth={depth}),不再继续解析", "广播"
)
return {"type": "text", "data": {"text": "[嵌套转发层数过多,内容已省略]"}}
nested_v11_content_list = []
nodes_list = getattr(seg, "nodes", [])
for node in nodes_list:
if isinstance(node, CustomNode):
node_v11_content = []
if isinstance(node.content, UniMessage):
for nested_seg in node.content:
converted_dict = uni_segment_to_v11_segment_dict(
nested_seg, depth + 1
)
if isinstance(converted_dict, list):
node_v11_content.extend(converted_dict)
elif converted_dict:
node_v11_content.append(converted_dict)
elif isinstance(node.content, str):
node_v11_content.append(
{"type": "text", "data": {"text": node.content}}
)
if node_v11_content:
separator = {
"type": "text",
"data": {
"text": f"\n--- 来自 {node.name} ({node.uid}) 的消息 ---\n"
},
}
nested_v11_content_list.insert(0, separator)
nested_v11_content_list.extend(node_v11_content)
nested_v11_content_list.append(
{"type": "text", "data": {"text": "\n---\n"}}
)
return nested_v11_content_list
else:
logger.warning(f"广播时跳过不支持的 UniSeg 段类型: {type(seg)}", "广播")
return None
def uni_message_to_v11_list_of_dicts(uni_msg: UniMessage | str | list) -> list[dict]:
"""UniMessage转V11字典列表"""
try:
if isinstance(uni_msg, str):
return [{"type": "text", "data": {"text": uni_msg}}]
if isinstance(uni_msg, list):
if not uni_msg:
return []
if all(isinstance(item, str) for item in uni_msg):
return [{"type": "text", "data": {"text": item}} for item in uni_msg]
result = []
for item in uni_msg:
if hasattr(item, "__iter__") and not isinstance(item, str | bytes):
result.extend(uni_message_to_v11_list_of_dicts(item))
elif hasattr(item, "text") and not isinstance(item, str | bytes):
text_value = getattr(item, "text", "")
result.append({"type": "text", "data": {"text": str(text_value)}})
elif hasattr(item, "url") and not isinstance(item, str | bytes):
url_value = getattr(item, "url", "")
if isinstance(item, Video):
result.append(
{"type": "video", "data": {"file": str(url_value)}}
)
else:
result.append(
{"type": "image", "data": {"file": str(url_value)}}
)
else:
try:
result.append({"type": "text", "data": {"text": str(item)}})
except Exception as e:
logger.warning(f"无法转换列表元素: {item}, 错误: {e}", "广播")
return result
except Exception as e:
logger.warning(f"消息转换过程中出错: {e}", "广播")
return [{"type": "text", "data": {"text": str(uni_msg)}}]
def custom_nodes_to_v11_nodes(custom_nodes: list[CustomNode]) -> list[dict]:
"""CustomNode列表转V11节点"""
v11_nodes = []
for node in custom_nodes:
v11_content_list = uni_message_to_v11_list_of_dicts(node.content)
if v11_content_list:
v11_nodes.append(
{
"type": "node",
"data": {
"user_id": str(node.uid),
"nickname": node.name,
"content": v11_content_list,
},
}
)
else:
logger.warning(
f"CustomNode (uid={node.uid}) 内容转换后为空,跳过此节点", "广播"
)
return v11_nodes

View File

@ -10,7 +10,9 @@ from zhenxun.configs.config import Config as gConfig
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
from zhenxun.services.log import logger, logger_
from zhenxun.utils.enum import PluginType
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .api.configure import router as configure_router
from .api.logs import router as ws_log_routes
from .api.logs.log_manager import LOG_STORAGE
from .api.menu import router as menu_router
@ -81,6 +83,7 @@ BaseApiRouter.include_router(database_router)
BaseApiRouter.include_router(plugin_router)
BaseApiRouter.include_router(system_router)
BaseApiRouter.include_router(menu_router)
BaseApiRouter.include_router(configure_router)
WsApiRouter = APIRouter(prefix="/zhenxun/socket")
@ -89,7 +92,7 @@ WsApiRouter.include_router(status_routes)
WsApiRouter.include_router(chat_routes)
@driver.on_startup
@PriorityLifecycle.on_startup(priority=0)
async def _():
try:
# 存储任务引用的列表,防止任务被垃圾回收

View File

@ -0,0 +1,133 @@
import asyncio
import os
from pathlib import Path
import re
import subprocess
import sys
import time
from fastapi import APIRouter
from fastapi.responses import JSONResponse
import nonebot
from zhenxun.configs.config import BotConfig, Config
from ...base_model import Result
from .data_source import test_db_connection
from .model import Setting
router = APIRouter(prefix="/configure")
driver = nonebot.get_driver()
port = driver.config.port
BAT_FILE = Path() / "win启动.bat"
FILE_NAME = ".configure_restart"
@router.post(
"/set_configure",
response_model=Result,
response_class=JSONResponse,
description="设置基础配置",
)
async def _(setting: Setting) -> Result:
global port
password = Config.get_config("web-ui", "password")
if password or BotConfig.db_url:
return Result.fail("配置已存在请先删除DB_URL内容和前端密码再进行设置。")
env_file = Path() / ".env.dev"
if not env_file.exists():
return Result.fail("配置文件.env.dev不存在。")
env_text = env_file.read_text(encoding="utf-8")
if setting.db_url:
if setting.db_url.startswith("sqlite"):
base_dir = Path().resolve()
# 清理和验证数据库路径
db_path_str = setting.db_url.split(":")[-1].strip()
# 移除任何可能的路径遍历尝试
db_path_str = re.sub(r"[\\/]\.\.[\\/]", "", db_path_str)
# 规范化路径
db_path = Path(db_path_str).resolve()
parent_path = db_path.parent
# 验证路径是否在项目根目录内
try:
if not parent_path.absolute().is_relative_to(base_dir):
return Result.fail("数据库路径不在项目根目录内。")
except ValueError:
return Result.fail("无效的数据库路径。")
# 创建目录
try:
parent_path.mkdir(parents=True, exist_ok=True)
except Exception as e:
return Result.fail(f"创建数据库目录失败: {e!s}")
env_text = env_text.replace('DB_URL = ""', f'DB_URL = "{setting.db_url}"')
if setting.superusers:
superusers = ", ".join([f'"{s}"' for s in setting.superusers])
env_text = re.sub(r"SUPERUSERS=\[.*?\]", f"SUPERUSERS=[{superusers}]", env_text)
if setting.host:
env_text = env_text.replace("HOST = 127.0.0.1", f"HOST = {setting.host}")
if setting.port:
env_text = env_text.replace("PORT = 8080", f"PORT = {setting.port}")
port = setting.port
if setting.username:
Config.set_config("web-ui", "username", setting.username)
Config.set_config("web-ui", "password", setting.password, True)
env_file.write_text(env_text, encoding="utf-8")
if BAT_FILE.exists():
for file in os.listdir(Path()):
if file.startswith(FILE_NAME):
Path(file).unlink()
flag_file = Path() / f"{FILE_NAME}_{int(time.time())}"
flag_file.touch()
return Result.ok(BAT_FILE.exists(), info="设置成功,请重启真寻以完成配置!")
@router.get(
"/test_db",
response_model=Result,
response_class=JSONResponse,
description="设置基础配置",
)
async def _(db_url: str) -> Result:
result = await test_db_connection(db_url)
if isinstance(result, str):
return Result.fail(result)
return Result.ok(info="数据库连接成功!")
async def run_restart_command(bat_path: Path, port: int):
"""在后台执行重启命令"""
await asyncio.sleep(1) # 确保 FastAPI 已返回响应
subprocess.Popen([bat_path, str(port)], shell=True) # noqa: ASYNC220
sys.exit(0) # 退出当前进程
@router.post(
"/restart",
response_model=Result,
response_class=JSONResponse,
description="重启",
)
async def _() -> Result:
if not BAT_FILE.exists():
return Result.fail("自动重启仅支持意见整合包,请尝试手动重启")
flag_file = next(
(Path() / file for file in os.listdir(Path()) if file.startswith(FILE_NAME)),
None,
)
if not flag_file or not flag_file.exists():
return Result.fail("重启标志文件不存在...")
set_time = flag_file.name.split("_")[-1]
if time.time() - float(set_time) > 10 * 60:
return Result.fail("重启标志文件已过期,请重新设置配置。")
flag_file.unlink()
try:
return Result.ok(info="执行重启命令成功")
finally:
asyncio.create_task(run_restart_command(BAT_FILE, port)) # noqa: RUF006

View File

@ -0,0 +1,18 @@
from tortoise import Tortoise
async def test_db_connection(db_url: str) -> bool | str:
try:
# 初始化 Tortoise ORM
await Tortoise.init(
db_url=db_url,
modules={"models": ["__main__"]}, # 这里不需要实际模型
)
# 测试连接
await Tortoise.get_connection("default").execute_query("SELECT 1")
return True
except Exception as e:
return str(e)
finally:
# 关闭连接
await Tortoise.close_connections()

View File

@ -0,0 +1,16 @@
from pydantic import BaseModel
class Setting(BaseModel):
superusers: list[str]
"""超级用户列表"""
db_url: str
"""数据库地址"""
host: str
"""主机地址"""
port: int
"""端口"""
username: str
"""前端用户名"""
password: str
"""前端密码"""

View File

@ -5,54 +5,63 @@ from zhenxun.services.log import logger
from .model import MenuData, MenuItem
default_menus = [
MenuItem(
name="仪表盘",
module="dashboard",
router="/dashboard",
icon="dashboard",
default=True,
),
MenuItem(
name="真寻控制台",
module="command",
router="/command",
icon="command",
),
MenuItem(name="插件列表", module="plugin", router="/plugin", icon="plugin"),
MenuItem(name="插件商店", module="store", router="/store", icon="store"),
MenuItem(name="好友/群组", module="manage", router="/manage", icon="user"),
MenuItem(
name="数据库管理",
module="database",
router="/database",
icon="database",
),
MenuItem(name="系统信息", module="system", router="/system", icon="system"),
MenuItem(name="关于我们", module="about", router="/about", icon="about"),
]
class MenuManage:
class MenuManager:
def __init__(self) -> None:
self.file = DATA_PATH / "web_ui" / "menu.json"
self.menu = []
if self.file.exists():
try:
temp_menu = []
self.menu = json.load(self.file.open(encoding="utf8"))
self_menu_name = [menu["name"] for menu in self.menu]
for module in [m.module for m in default_menus]:
if module in self_menu_name:
temp_menu.append(
MenuItem(
**next(m for m in self.menu if m["module"] == module)
)
)
else:
temp_menu.append(self.__get_menu_model(module))
self.menu = temp_menu
except Exception as e:
logger.warning("菜单文件损坏,已重新生成...", "WebUi", e=e)
if not self.menu:
self.menu = [
MenuItem(
name="仪表盘",
module="dashboard",
router="/dashboard",
icon="dashboard",
default=True,
),
MenuItem(
name="真寻控制台",
module="command",
router="/command",
icon="command",
),
MenuItem(
name="插件列表", module="plugin", router="/plugin", icon="plugin"
),
MenuItem(
name="插件商店", module="store", router="/store", icon="store"
),
MenuItem(
name="好友/群组", module="manage", router="/manage", icon="user"
),
MenuItem(
name="数据库管理",
module="database",
router="/database",
icon="database",
),
MenuItem(
name="文件管理", module="system", router="/system", icon="system"
),
MenuItem(
name="关于我们", module="about", router="/about", icon="about"
),
]
self.save()
self.menu = default_menus
self.save()
def __get_menu_model(self, module: str):
return default_menus[
next(i for i, m in enumerate(default_menus) if m.module == module)
]
def get_menus(self):
return MenuData(menus=self.menu)
@ -64,4 +73,4 @@ class MenuManage:
json.dump(temp, f, ensure_ascii=False, indent=4)
menu_manage = MenuManage()
menu_manage = MenuManager()

View File

@ -13,6 +13,7 @@ from zhenxun.models.bot_connect_log import BotConnectLog
from zhenxun.models.chat_history import ChatHistory
from zhenxun.models.statistics import Statistics
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
from ....base_model import BaseResultModel, QueryModel
@ -31,7 +32,7 @@ driver: Driver = nonebot.get_driver()
CONNECT_TIME = 0
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
global CONNECT_TIME
CONNECT_TIME = int(time.time())

View File

@ -8,6 +8,7 @@ from zhenxun.configs.config import BotConfig
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ....base_model import BaseResultModel, QueryModel, Result
from ....utils import authentication
@ -21,7 +22,7 @@ router = APIRouter(prefix="/database")
driver: Driver = nonebot.get_driver()
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
for plugin in nonebot.get_loaded_plugins():
module = plugin.name

View File

@ -119,7 +119,7 @@ class ApiDataSource:
(await PlatformUtils.get_friend_list(select_bot.bot))[0]
)
except Exception as e:
logger.warning("获取bot好友/群组信息失败...", "WebUi", e=e)
logger.warning("获取bot好友/群组数量失败...", "WebUi", e=e)
select_bot.group_count = 0
select_bot.friend_count = 0
select_bot.status = await BotConsole.get_bot_status(select_bot.self_id)

View File

@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse
from zhenxun.utils._build_image import BuildImage
from ....base_model import Result, SystemFolderSize
from ....utils import authentication, get_system_disk
from ....utils import authentication, get_system_disk, validate_path
from .model import AddFile, DeleteFile, DirFile, RenameFile, SaveFile
router = APIRouter(prefix="/system")
@ -25,22 +25,30 @@ IMAGE_TYPE = ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg"]
description="获取文件列表",
)
async def _(path: str | None = None) -> Result[list[DirFile]]:
base_path = Path(path) if path else Path()
data_list = []
for file in os.listdir(base_path):
file_path = base_path / file
is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE)
data_list.append(
DirFile(
is_file=not file_path.is_dir(),
is_image=is_image,
name=file,
parent=path,
size=None if file_path.is_dir() else file_path.stat().st_size,
mtime=file_path.stat().st_mtime,
try:
base_path, error = validate_path(path)
if error:
return Result.fail(error)
if not base_path:
return Result.fail("无效的路径")
data_list = []
for file in os.listdir(base_path):
file_path = base_path / file
is_image = any(file.endswith(f".{t}") for t in IMAGE_TYPE)
data_list.append(
DirFile(
is_file=not file_path.is_dir(),
is_image=is_image,
name=file,
parent=path,
size=None if file_path.is_dir() else file_path.stat().st_size,
mtime=file_path.stat().st_mtime,
)
)
)
return Result.ok(data_list)
sorted(data_list, key=lambda f: f.name)
return Result.ok(data_list)
except Exception as e:
return Result.fail(f"获取文件列表失败: {e!s}")
@router.get(
@ -62,8 +70,12 @@ async def _(full_path: str | None = None) -> Result[list[SystemFolderSize]]:
description="删除文件",
)
async def _(param: DeleteFile) -> Result:
path = Path(param.full_path)
if not path or not path.exists():
path, error = validate_path(param.full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists():
return Result.warning_("文件不存在...")
try:
path.unlink()
@ -80,8 +92,12 @@ async def _(param: DeleteFile) -> Result:
description="删除文件夹",
)
async def _(param: DeleteFile) -> Result:
path = Path(param.full_path)
if not path or not path.exists() or path.is_file():
path, error = validate_path(param.full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists() or path.is_file():
return Result.warning_("文件夹不存在...")
try:
shutil.rmtree(path.absolute())
@ -98,10 +114,14 @@ async def _(param: DeleteFile) -> Result:
description="重命名文件",
)
async def _(param: RenameFile) -> Result:
path = (
(Path(param.parent) / param.old_name) if param.parent else Path(param.old_name)
)
if not path or not path.exists():
parent_path, error = validate_path(param.parent)
if error:
return Result.fail(error)
if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.old_name) if param.parent else Path(param.old_name)
if not path.exists():
return Result.warning_("文件不存在...")
try:
path.rename(path.parent / param.name)
@ -118,10 +138,14 @@ async def _(param: RenameFile) -> Result:
description="重命名文件夹",
)
async def _(param: RenameFile) -> Result:
path = (
(Path(param.parent) / param.old_name) if param.parent else Path(param.old_name)
)
if not path or not path.exists() or path.is_file():
parent_path, error = validate_path(param.parent)
if error:
return Result.fail(error)
if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.old_name) if param.parent else Path(param.old_name)
if not path.exists() or path.is_file():
return Result.warning_("文件夹不存在...")
try:
new_path = path.parent / param.name
@ -139,7 +163,13 @@ async def _(param: RenameFile) -> Result:
description="新建文件",
)
async def _(param: AddFile) -> Result:
path = (Path(param.parent) / param.name) if param.parent else Path(param.name)
parent_path, error = validate_path(param.parent)
if error:
return Result.fail(error)
if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.name) if param.parent else Path(param.name)
if path.exists():
return Result.warning_("文件已存在...")
try:
@ -157,7 +187,13 @@ async def _(param: AddFile) -> Result:
description="新建文件夹",
)
async def _(param: AddFile) -> Result:
path = (Path(param.parent) / param.name) if param.parent else Path(param.name)
parent_path, error = validate_path(param.parent)
if error:
return Result.fail(error)
if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.name) if param.parent else Path(param.name)
if path.exists():
return Result.warning_("文件夹已存在...")
try:
@ -175,7 +211,11 @@ async def _(param: AddFile) -> Result:
description="读取文件",
)
async def _(full_path: str) -> Result:
path = Path(full_path)
path, error = validate_path(full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists():
return Result.warning_("文件不存在...")
try:
@ -193,9 +233,13 @@ async def _(full_path: str) -> Result:
description="读取文件",
)
async def _(param: SaveFile) -> Result[str]:
path = Path(param.full_path)
path, error = validate_path(param.full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
try:
async with aiofiles.open(path, "w", encoding="utf-8") as f:
async with aiofiles.open(str(path), "w", encoding="utf-8") as f:
await f.write(param.content)
return Result.ok("更新成功!")
except Exception as e:
@ -210,7 +254,11 @@ async def _(param: SaveFile) -> Result[str]:
description="读取图片base64",
)
async def _(full_path: str) -> Result[str]:
path = Path(full_path)
path, error = validate_path(full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists():
return Result.warning_("文件不存在...")
try:

View File

@ -1,6 +1,12 @@
import sys
from fastapi.middleware.cors import CORSMiddleware
import nonebot
from strenum import StrEnum
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH

View File

@ -18,6 +18,7 @@ async def update_webui_assets():
download_url = await GithubUtils.parse_github_url(
WEBUI_DIST_GITHUB_URL
).get_archive_download_urls()
logger.info("开始下载 webui_assets 资源...", COMMAND_NAME)
if await AsyncHttpx.download_file(
download_url, webui_assets_path, follow_redirects=True
):

View File

@ -2,6 +2,7 @@ import contextlib
from datetime import datetime, timedelta, timezone
import os
from pathlib import Path
import re
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
@ -28,6 +29,45 @@ if token_file.exists():
token_data = json.load(open(token_file, encoding="utf8"))
def validate_path(path_str: str | None) -> tuple[Path | None, str | None]:
"""验证路径是否安全
参数:
path_str: 用户输入的路径
返回:
tuple[Path | None, str | None]: (验证后的路径, 错误信息)
"""
try:
if not path_str:
return Path().resolve(), None
# 1. 移除任何可能的路径遍历尝试
path_str = re.sub(r"[\\/]\.\.[\\/]", "", path_str)
# 2. 规范化路径并转换为绝对路径
path = Path(path_str).resolve()
# 3. 获取项目根目录
root_dir = Path().resolve()
# 4. 验证路径是否在项目根目录内
try:
if not path.is_relative_to(root_dir):
return None, "访问路径超出允许范围"
except ValueError:
return None, "无效的路径格式"
# 5. 验证路径是否包含任何危险字符
if any(c in str(path) for c in ["..", "~", "*", "?", ">", "<", "|", '"']):
return None, "路径包含非法字符"
# 6. 验证路径长度是否合理
return (None, "路径长度超出限制") if len(str(path)) > 4096 else (path, None)
except Exception as e:
return None, f"路径验证失败: {e!s}"
GROUP_HELP_PATH = DATA_PATH / "group_help"
SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png"
SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png"

View File

@ -29,7 +29,7 @@ class BotConsole(Model):
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "bot_console"
table_description = "Bot数据表"
cache_type = CacheType.BOT
@staticmethod

View File

@ -3,7 +3,6 @@ from tortoise import fields
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import CacheType
class LevelUser(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
"""自增id"""
@ -56,7 +55,7 @@ class LevelUser(Model):
level: 权限等级
group_flag: 是否被自动更新刷新权限 0:, 1:.
"""
if await cls.exists(user_id=user_id, group_id=group_id, user_level=level):
if await cls.exists(user_id=user_id, group_id=group_id, level=level):
# 权限相同时跳过
return
await cls.update_or_create(

View File

@ -58,7 +58,7 @@ class PluginInfo(Model):
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "plugin_info"
table_description = "插件基本信息"
cache_type = CacheType.PLUGINS
@classmethod

View File

@ -1,38 +0,0 @@
from tortoise import fields
from zhenxun.services.db_context import Model
class ScheduleInfo(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
"""自增id"""
bot_id = fields.CharField(
255, null=True, default=None, description="任务关联的Bot ID"
)
"""任务关联的Bot ID"""
plugin_name = fields.CharField(255, description="插件模块名")
"""插件模块名"""
group_id = fields.CharField(
255,
null=True,
description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务",
)
"""群组ID, 为空表示全局任务"""
trigger_type = fields.CharField(
max_length=20, default="cron", description="触发器类型 (cron, interval, date)"
)
"""触发器类型 (cron, interval, date)"""
trigger_config = fields.JSONField(description="触发器具体配置")
"""触发器具体配置"""
job_kwargs = fields.JSONField(
default=dict, description="传递给任务函数的额外关键字参数"
)
"""传递给任务函数的额外关键字参数"""
is_enabled = fields.BooleanField(default=True, description="是否启用")
"""是否启用"""
create_time = fields.DatetimeField(auto_now_add=True)
"""创建时间"""
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "schedule_info"
table_description = "通用定时任务表"

View File

View File

@ -12,7 +12,7 @@ from aiocache.serializers import JsonSerializer
import nonebot
from nonebot.compat import model_dump
from nonebot.utils import is_coroutine_callable
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr
from zhenxun.services.log import logger
@ -121,7 +121,7 @@ class CacheData(BaseModel):
lazy_load: bool = True # 默认延迟加载
result_model: type | None = None
_keys: set[str] = set() # 存储所有缓存键
_cache: BaseCache | AioCache
_cache: BaseCache | AioCache = PrivateAttr()
class Config:
arbitrary_types_allowed = True

View File

@ -2,7 +2,6 @@ from asyncio import Semaphore
from collections.abc import Iterable
from typing import Any, ClassVar
from typing_extensions import Self
from urllib.parse import urlparse
import nonebot
from nonebot.utils import is_coroutine_callable
@ -32,17 +31,20 @@ def _():
global CACHE_FLAG
CACHE_FLAG = True
driver = nonebot.get_driver()
class Model(TortoiseModel):
"""
自动添加模块
Args:
Model_: Model
"""
sem_data: ClassVar[dict[str, dict[str, Semaphore]]] = {}
def __init_subclass__(cls, **kwargs):
if cls.__module__ not in MODELS:
MODELS.append(cls.__module__)
MODELS.append(cls.__module__)
if func := getattr(cls, "_run_script", None):
SCRIPT_METHOD.append((cls.__module__, func))
@ -169,7 +171,7 @@ class Model(TortoiseModel):
await CacheRoot.reload(cache_type)
class DbUrlIsNode(HookPriorityException):
class DbUrlMissing(Exception):
"""
数据库链接地址为空
"""
@ -188,6 +190,7 @@ class DbConnectError(Exception):
@PriorityLifecycle.on_startup(priority=1)
async def init():
if not BotConfig.db_url:
# raise DbUrlMissing("数据库配置为空,请在.env.dev中配置DB_URL...")
error = f"""
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
@ -196,74 +199,13 @@ async def init():
***********************************************************************
***********************************************************************
"""
raise DbUrlIsNode("\n" + error.strip())
raise DbUrlMissing("\n" + error.strip())
try:
db_url = BotConfig.db_url
url_scheme = db_url.split(":", 1)[0]
config = None
if url_scheme in ("postgres", "postgresql"):
# 解析 db_url
url = urlparse(db_url)
credentials = {
"host": url.hostname,
"port": url.port or 5432,
"user": url.username,
"password": url.password,
"database": url.path.lstrip("/"),
"minsize": 1,
"maxsize": 50,
}
config = {
"connections": {
"default": {
"engine": "tortoise.backends.asyncpg",
"credentials": credentials,
}
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
}
elif url_scheme in ("mysql", "mysql+aiomysql"):
url = urlparse(db_url)
credentials = {
"host": url.hostname,
"port": url.port or 3306,
"user": url.username,
"password": url.password,
"database": url.path.lstrip("/"),
"minsize": 1,
"maxsize": 50,
}
config = {
"connections": {
"default": {
"engine": "tortoise.backends.mysql",
"credentials": credentials,
}
},
"apps": {
"models": {
"models": MODELS,
"default_connection": "default",
}
},
}
else:
# sqlite 或其它,直接用 db_url
await Tortoise.init(
db_url=db_url,
modules={"models": MODELS},
timezone="Asia/Shanghai",
)
if config:
await Tortoise.init(config=config)
await Tortoise.init(
db_url=BotConfig.db_url,
modules={"models": MODELS},
timezone="Asia/Shanghai",
)
if SCRIPT_METHOD:
db = Tortoise.get_connection("default")
logger.debug(
@ -282,6 +224,7 @@ async def init():
logger.debug(f"执行SQL: {sql}")
try:
await db.execute_query_dict(sql)
# await TestSQL.raw(sql)
except Exception as e:
logger.debug(f"执行SQL: {sql} 错误...", e=e)
if sql_list:
@ -293,4 +236,4 @@ async def init():
async def disconnect():
await connections.close_all()
await connections.close_all()

View File

@ -6,6 +6,7 @@ from nonebot.utils import is_coroutine_callable
from pydantic import BaseModel
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
driver = nonebot.get_driver()
@ -100,6 +101,6 @@ class PluginInitManager:
logger.error(f"执行: {module_path}:remove 失败", e=e)
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
await PluginInitManager.install_all()

View File

@ -1,12 +1,17 @@
from io import BytesIO
from pathlib import Path
import random
import sys
from pydantic import BaseModel, Field
from strenum import StrEnum
from ._build_image import BuildImage
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
class MatType(StrEnum):
LINE = "LINE"

View File

@ -3,9 +3,12 @@ from io import BytesIO
from pathlib import Path
import random
from nonebot_plugin_htmlrender import md_to_pic, template_to_pic
from PIL.ImageFont import FreeTypeFont
from pydantic import BaseModel
from zhenxun.configs.path_config import TEMPLATE_PATH
from ._build_image import BuildImage
@ -283,3 +286,191 @@ class ImageTemplate:
width = max(width, w)
height += h
return width, height
class MarkdownTable:
def __init__(self, headers: list[str], rows: list[list[str]]):
self.headers = headers
self.rows = rows
def to_markdown(self) -> str:
"""将表格转换为Markdown格式"""
header_row = "| " + " | ".join(self.headers) + " |"
separator_row = "| " + " | ".join(["---"] * len(self.headers)) + " |"
data_rows = "\n".join(
"| " + " | ".join(map(str, row)) + " |" for row in self.rows
)
return f"{header_row}\n{separator_row}\n{data_rows}"
class Markdown:
def __init__(self, data: list[str] | None = None):
if data is None:
data = []
self._data = data
def text(self, text: str) -> "Markdown":
"""添加Markdown文本"""
self._data.append(text)
return self
def head(self, text: str, level: int = 1) -> "Markdown":
"""添加Markdown标题"""
if level < 1 or level > 6:
raise ValueError("标题级别必须在1到6之间")
self._data.append(f"{'#' * level} {text}")
return self
def image(self, content: str | Path, add_empty_line: bool = True) -> "Markdown":
"""添加Markdown图片
参数:
content: 图片内容可以是url地址图片路径或base64字符串.
add_empty_line: 默认添加换行.
返回:
Markdown: Markdown
"""
if isinstance(content, Path):
content = str(content.absolute())
if content.startswith("base64"):
content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}"
self._data.append(f"![image]({content})")
if add_empty_line:
self._add_empty_line()
return self
def quote(self, text: str | list[str]) -> "Markdown":
"""添加Markdown引用文本
参数:
text: 引用文本内容可以是字符串或字符串列表.
如果是列表则每个元素都会被单独引用
返回:
Markdown: Markdown
"""
if isinstance(text, str):
self._data.append(f"> {text}")
elif isinstance(text, list):
for t in text:
self._data.append(f"> {t}")
self._add_empty_line()
return self
def code(self, code: str, language: str = "python") -> "Markdown":
"""添加Markdown代码块"""
self._data.append(f"```{language}\n{code}\n```")
return self
def table(self, headers: list[str], rows: list[list[str]]) -> "Markdown":
"""添加Markdown表格"""
table = MarkdownTable(headers, rows)
self._data.append(table.to_markdown())
return self
def list(self, items: list[str | list[str]]) -> "Markdown":
"""添加Markdown列表"""
self._add_empty_line()
_text = "\n".join(
f"- {item}"
if isinstance(item, str)
else "\n".join(f"- {sub_item}" for sub_item in item)
for item in items
)
self._data.append(_text)
return self
def _add_empty_line(self):
"""添加空行"""
self._data.append("")
async def build(self, width: int = 800, css_path: Path | None = None) -> bytes:
"""构建Markdown文本"""
if css_path is not None:
return await md_to_pic(
md="\n".join(self._data), width=width, css_path=str(css_path.absolute())
)
return await md_to_pic(md="\n".join(self._data), width=width)
class Notebook:
def __init__(self, data: list[dict] | None = None):
self._data = data if data is not None else []
def text(self, text: str) -> "Notebook":
"""添加Notebook文本"""
self._data.append({"type": "paragraph", "text": text})
return self
def head(self, text: str, level: int = 1) -> "Notebook":
"""添加Notebook标题"""
if not 1 <= level <= 4:
raise ValueError("标题级别必须在1-4之间")
self._data.append({"type": "heading", "text": text, "level": level})
return self
def image(
self,
content: str | Path,
caption: str | None = None,
) -> "Notebook":
"""添加Notebook图片
参数:
content: 图片内容可以是url地址图片路径或base64字符串.
caption: 图片说明.
返回:
Notebook: Notebook
"""
if isinstance(content, Path):
content = str(content.absolute())
if content.startswith("base64"):
content = f"data:image/png;base64,{content.split('base64://', 1)[-1]}"
self._data.append({"type": "image", "src": content, "caption": caption})
return self
def quote(self, text: str | list[str]) -> "Notebook":
"""添加Notebook引用文本
参数:
text: 引用文本内容可以是字符串或字符串列表.
如果是列表则每个元素都会被单独引用
返回:
Notebook: Notebook
"""
if isinstance(text, str):
self._data.append({"type": "blockquote", "text": text})
elif isinstance(text, list):
for t in text:
self._data.append({"type": "blockquote", "text": text})
return self
def code(self, code: str, language: str = "python") -> "Notebook":
"""添加Notebook代码块"""
self._data.append({"type": "code", "code": code, "language": language})
return self
def list(self, items: list[str], ordered: bool = False) -> "Notebook":
"""添加Notebook列表"""
self._data.append({"type": "list", "data": items, "ordered": ordered})
return self
def add_divider(self) -> None:
"""添加分隔线"""
self._data.append({"type": "divider"})
async def build(self) -> bytes:
"""构建Notebook"""
return await template_to_pic(
template_path=str((TEMPLATE_PATH / "notebook").absolute()),
template_name="main.html",
templates={"elements": self._data},
pages={
"viewport": {"width": 700, "height": 1000},
"base_url": f"file://{TEMPLATE_PATH}",
},
wait=2,
)

View File

@ -10,8 +10,6 @@ class HookPriorityException(BaseException):
return self.info
class NotFoundError(Exception):
"""
未发现

View File

@ -21,6 +21,9 @@ CACHED_API_TTL = 300
RAW_CONTENT_FORMAT = "https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
"""raw content格式"""
GITEE_RAW_CONTENT_FORMAT = "https://gitee.com/{owner}/{repo}/raw/main/{path}"
"""gitee raw content格式"""
ARCHIVE_URL_FORMAT = "https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
"""archive url格式"""

View File

@ -4,6 +4,7 @@ from zhenxun.utils.http_utils import AsyncHttpx
from .const import (
ARCHIVE_URL_FORMAT,
GITEE_RAW_CONTENT_FORMAT,
RAW_CONTENT_FORMAT,
RELEASE_ASSETS_FORMAT,
RELEASE_SOURCE_FORMAT,
@ -21,15 +22,11 @@ async def __get_fastest_formats(formats: dict[str, str]) -> list[str]:
async def get_fastest_raw_formats() -> list[str]:
"""获取最快的raw下载地址格式"""
formats: dict[str, str] = {
"https://gitee.com/": GITEE_RAW_CONTENT_FORMAT,
"https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT,
"https://ghproxy.cc/": f"https://ghproxy.cc/{RAW_CONTENT_FORMAT}",
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RAW_CONTENT_FORMAT}",
"https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}",
"https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}",
#"https://raw.gitcode.com/": "https://raw.gitcode.com/qq_41605780/{repo}/raw/{branch}/{path}", # ✅ 新增 GitCode raw 格式
#"https://raw.gitcode.com/": "https://raw.gitcode.com/ATTomato/{repo}/raw/{branch}/{path}"
#"https://raw.gitcode.com/": "https://raw.gitcode.com/{owner}/{repo}/raw/{branch}/{path}"
}
return await __get_fastest_formats(formats)
@ -40,7 +37,6 @@ async def get_fastest_archive_formats() -> list[str]:
formats: dict[str, str] = {
"https://github.com/": ARCHIVE_URL_FORMAT,
"https://ghproxy.cc/": f"https://ghproxy.cc/{ARCHIVE_URL_FORMAT}",
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}",
"https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}",
}
return await __get_fastest_formats(formats)
@ -52,7 +48,6 @@ async def get_fastest_release_formats() -> list[str]:
formats: dict[str, str] = {
"https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT,
"https://ghproxy.cc/": f"https://ghproxy.cc/{RELEASE_ASSETS_FORMAT}",
"https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}",
"https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}",
}
return await __get_fastest_formats(formats)
@ -65,4 +60,4 @@ async def get_fastest_release_source_formats() -> list[str]:
"https://codeload.github.com/": RELEASE_SOURCE_FORMAT,
"https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}",
}
return await __get_fastest_formats(formats)
return await __get_fastest_formats(formats)

View File

@ -1,13 +1,18 @@
import contextlib
import sys
from typing import Protocol
from aiocache import cached
from nonebot.compat import model_dump
from pydantic import BaseModel, Field
from strenum import StrEnum
from zhenxun.utils.http_utils import AsyncHttpx
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
from .const import (
CACHED_API_TTL,
GIT_API_COMMIT_FORMAT,

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,36 @@
from abc import ABC
from typing import Literal
from pydantic import BaseModel
class Style(BaseModel):
"""常用样式"""
padding: str = "0px"
margin: str = "0px"
border: str = "0px"
border_radius: str = "0px"
text_align: Literal["left", "right", "center"] = "left"
color: str = "#000"
font_size: str = "16px"
class Component(ABC):
def __init__(self, background_color: str = "#fff", is_container: bool = False):
self.extra_style = []
self.style = Style()
self.background_color = background_color
self.is_container = is_container
self.children = []
def add_child(self, child: "Component | str"):
self.children.append(child)
def set_style(self, style: Style):
self.style = style
def add_style(self, style: str):
self.extra_style.append(style)
def to_html(self) -> str: ...

View File

@ -0,0 +1,15 @@
from ..component import Component, Style
from ..container import Row
class Title(Component):
def __init__(self, text: str, color: str = "#000"):
self.text = text
self.color = color
def build(self):
row = Row()
style = Style(font_size="36px", color=self.color)
row.set_style(style)
# def

View File

@ -0,0 +1,31 @@
from .component import Component
class Row(Component):
def __init__(self, background_color: str = "#fff"):
super().__init__(background_color, True)
class Col(Component):
def __init__(self, background_color: str = "#fff"):
super().__init__(background_color, True)
class Container(Component):
def __init__(self, background_color: str = "#fff"):
super().__init__(background_color, True)
self.children = []
class GlobalOverview:
def __init__(self, name: str):
self.name = name
self.class_name: dict[str, list[str]] = {}
self.content = None
def set_content(self, content: Container):
self.content = content
def add_class(self, class_name: str, contents: list[str]):
"""全局样式"""
self.class_name[class_name] = contents

View File

@ -22,6 +22,4 @@ class MessageManager:
@classmethod
def get(cls, uid: str) -> list[str]:
if uid in cls.data:
return cls.data[uid]
return []
return cls.data[uid] if uid in cls.data else []

View File

@ -1,810 +0,0 @@
import asyncio
from collections.abc import Callable, Coroutine
import copy
import inspect
import random
from typing import ClassVar
import nonebot
from nonebot import get_bots
from nonebot_plugin_apscheduler import scheduler
from pydantic import BaseModel, ValidationError
from zhenxun.configs.config import Config
from zhenxun.models.schedule_info import ScheduleInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
SCHEDULE_CONCURRENCY_KEY = "all_groups_concurrency_limit"
class SchedulerManager:
"""
一个通用的持久化的定时任务管理器供所有插件使用
"""
_registered_tasks: ClassVar[
dict[str, dict[str, Callable | type[BaseModel] | None]]
] = {}
_JOB_PREFIX = "zhenxun_schedule_"
_running_tasks: ClassVar[set] = set()
def register(
self, plugin_name: str, params_model: type[BaseModel] | None = None
) -> Callable:
"""
注册一个可调度的任务函数
被装饰的函数签名应为 `async def func(group_id: str | None, **kwargs)`
Args:
plugin_name (str): 插件的唯一名称 (通常是模块名)
params_model (type[BaseModel], optional): 一个 Pydantic BaseModel
用于定义和验证任务函数接受的额外参数
"""
def decorator(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
if plugin_name in self._registered_tasks:
logger.warning(f"插件 '{plugin_name}' 的定时任务已被重复注册。")
self._registered_tasks[plugin_name] = {
"func": func,
"model": params_model,
}
model_name = params_model.__name__ if params_model else ""
logger.debug(
f"插件 '{plugin_name}' 的定时任务已注册,参数模型: {model_name}"
)
return func
return decorator
def get_registered_plugins(self) -> list[str]:
"""获取所有已注册定时任务的插件列表。"""
return list(self._registered_tasks.keys())
def _get_job_id(self, schedule_id: int) -> str:
"""根据数据库ID生成唯一的 APScheduler Job ID。"""
return f"{self._JOB_PREFIX}{schedule_id}"
async def _execute_job(self, schedule_id: int):
"""
APScheduler 调度的入口函数
根据 schedule_id 处理特定任务所有群组任务或全局任务
"""
schedule = await ScheduleInfo.get_or_none(id=schedule_id)
if not schedule or not schedule.is_enabled:
logger.warning(f"定时任务 {schedule_id} 不存在或已禁用,跳过执行。")
return
plugin_name = schedule.plugin_name
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
logger.error(
f"无法执行定时任务:插件 '{plugin_name}' 未注册或已卸载。将禁用该任务。"
)
schedule.is_enabled = False
await schedule.save(update_fields=["is_enabled"])
self._remove_aps_job(schedule.id)
return
try:
if schedule.bot_id:
bot = nonebot.get_bot(schedule.bot_id)
else:
bot = nonebot.get_bot()
logger.debug(
f"任务 {schedule_id} 未关联特定Bot使用默认Bot {bot.self_id}"
)
except KeyError:
logger.warning(
f"定时任务 {schedule_id} 需要的 Bot {schedule.bot_id} "
f"不在线,本次执行跳过。"
)
return
except ValueError:
logger.warning(f"当前没有Bot在线定时任务 {schedule_id} 跳过。")
return
if schedule.group_id == "__ALL_GROUPS__":
await self._execute_for_all_groups(schedule, task_meta, bot)
else:
await self._execute_for_single_target(schedule, task_meta, bot)
async def _execute_for_all_groups(
self, schedule: ScheduleInfo, task_meta: dict, bot
):
"""为所有群组执行任务,并处理优先级覆盖。"""
plugin_name = schedule.plugin_name
concurrency_limit = Config.get_config(
"SchedulerManager", SCHEDULE_CONCURRENCY_KEY, 5
)
if not isinstance(concurrency_limit, int) or concurrency_limit <= 0:
logger.warning(
f"无效的定时任务并发限制配置 '{concurrency_limit}',将使用默认值 5。"
)
concurrency_limit = 5
logger.info(
f"开始执行针对 [所有群组] 的任务 "
f"(ID: {schedule.id}, 插件: {plugin_name}, Bot: {bot.self_id})"
f"并发限制: {concurrency_limit}"
)
all_gids = set()
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
all_gids.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"'all' 任务获取 Bot {bot.self_id} 的群列表失败", e=e)
return
specific_tasks_gids = set(
await ScheduleInfo.filter(
plugin_name=plugin_name, group_id__in=list(all_gids)
).values_list("group_id", flat=True)
)
semaphore = asyncio.Semaphore(concurrency_limit)
async def worker(gid: str):
"""使用 Semaphore 包装单个群组的任务执行"""
async with semaphore:
temp_schedule = copy.deepcopy(schedule)
temp_schedule.group_id = gid
await self._execute_for_single_target(temp_schedule, task_meta, bot)
await asyncio.sleep(random.uniform(0.1, 0.5))
tasks_to_run = []
for gid in all_gids:
if gid in specific_tasks_gids:
logger.debug(f"群组 {gid} 已有特定任务,跳过 'all' 任务的执行。")
continue
tasks_to_run.append(worker(gid))
if tasks_to_run:
await asyncio.gather(*tasks_to_run)
async def _execute_for_single_target(
self, schedule: ScheduleInfo, task_meta: dict, bot
):
"""为单个目标(具体群组或全局)执行任务。"""
plugin_name = schedule.plugin_name
group_id = schedule.group_id
try:
is_blocked = await CommonUtils.task_is_block(bot, plugin_name, group_id)
if is_blocked:
target_desc = f"{group_id}" if group_id else "全局"
logger.info(
f"插件 '{plugin_name}' 的定时任务在目标 [{target_desc}]"
"因功能被禁用而跳过执行。"
)
return
task_func = task_meta["func"]
job_kwargs = schedule.job_kwargs
if not isinstance(job_kwargs, dict):
logger.error(
f"任务 {schedule.id} 的 job_kwargs 不是字典类型: {type(job_kwargs)}"
)
return
sig = inspect.signature(task_func)
if "bot" in sig.parameters:
job_kwargs["bot"] = bot
logger.info(
f"插件 '{plugin_name}' 开始为目标 [{group_id or '全局'}] "
f"执行定时任务 (ID: {schedule.id})。"
)
task = asyncio.create_task(task_func(group_id, **job_kwargs))
self._running_tasks.add(task)
task.add_done_callback(self._running_tasks.discard)
await task
except Exception as e:
logger.error(
f"执行定时任务 (ID: {schedule.id}, 插件: {plugin_name}, "
f"目标: {group_id or '全局'}) 时发生异常",
e=e,
)
def _validate_and_prepare_kwargs(
self, plugin_name: str, job_kwargs: dict | None
) -> tuple[bool, str | dict]:
"""验证并准备任务参数,应用默认值"""
task_meta = self._registered_tasks.get(plugin_name)
if not task_meta:
return False, f"插件 '{plugin_name}' 未注册。"
params_model = task_meta.get("model")
job_kwargs = job_kwargs if job_kwargs is not None else {}
if not params_model:
if job_kwargs:
logger.warning(
f"插件 '{plugin_name}' 未定义参数模型,但收到了参数: {job_kwargs}"
)
return True, job_kwargs
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
logger.error(f"插件 '{plugin_name}' 的参数模型不是有效的 BaseModel 类")
return False, f"插件 '{plugin_name}' 的参数模型配置错误"
try:
model_validate = getattr(params_model, "model_validate", None)
if not model_validate:
return False, f"插件 '{plugin_name}' 的参数模型不支持验证"
validated_model = model_validate(job_kwargs)
model_dump = getattr(validated_model, "model_dump", None)
if not model_dump:
return False, f"插件 '{plugin_name}' 的参数模型不支持导出"
return True, model_dump()
except ValidationError as e:
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
error_str = "\n".join(errors)
msg = f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
return False, msg
def _add_aps_job(self, schedule: ScheduleInfo):
"""根据 ScheduleInfo 对象添加或更新一个 APScheduler 任务。"""
job_id = self._get_job_id(schedule.id)
try:
scheduler.remove_job(job_id)
except Exception:
pass
if not isinstance(schedule.trigger_config, dict):
logger.error(
f"任务 {schedule.id} 的 trigger_config 不是字典类型: "
f"{type(schedule.trigger_config)}"
)
return
scheduler.add_job(
self._execute_job,
trigger=schedule.trigger_type,
id=job_id,
misfire_grace_time=300,
args=[schedule.id],
**schedule.trigger_config,
)
logger.debug(
f"已在 APScheduler 中添加/更新任务: {job_id} "
f"with trigger: {schedule.trigger_config}"
)
def _remove_aps_job(self, schedule_id: int):
"""移除一个 APScheduler 任务。"""
job_id = self._get_job_id(schedule_id)
try:
scheduler.remove_job(job_id)
logger.debug(f"已从 APScheduler 中移除任务: {job_id}")
except Exception:
pass
async def add_schedule(
self,
plugin_name: str,
group_id: str | None,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
bot_id: str | None = None,
) -> tuple[bool, str]:
"""
添加或更新一个定时任务
"""
if plugin_name not in self._registered_tasks:
return False, f"插件 '{plugin_name}' 没有注册可用的定时任务。"
is_valid, result = self._validate_and_prepare_kwargs(plugin_name, job_kwargs)
if not is_valid:
return False, str(result)
validated_job_kwargs = result
effective_bot_id = bot_id if group_id == "__ALL_GROUPS__" else None
search_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
}
if effective_bot_id:
search_kwargs["bot_id"] = effective_bot_id
else:
search_kwargs["bot_id__isnull"] = True
defaults = {
"trigger_type": trigger_type,
"trigger_config": trigger_config,
"job_kwargs": validated_job_kwargs,
"is_enabled": True,
}
schedule = await ScheduleInfo.filter(**search_kwargs).first()
created = False
if schedule:
for key, value in defaults.items():
setattr(schedule, key, value)
await schedule.save()
else:
creation_kwargs = {
"plugin_name": plugin_name,
"group_id": group_id,
"bot_id": effective_bot_id,
**defaults,
}
schedule = await ScheduleInfo.create(**creation_kwargs)
created = True
self._add_aps_job(schedule)
action = "设置" if created else "更新"
return True, f"已成功{action}插件 '{plugin_name}' 的定时任务。"
async def add_schedule_for_all(
self,
plugin_name: str,
trigger_type: str,
trigger_config: dict,
job_kwargs: dict | None = None,
) -> tuple[int, int]:
"""为所有机器人所在的群组添加定时任务。"""
if plugin_name not in self._registered_tasks:
raise ValueError(f"插件 '{plugin_name}' 没有注册可用的定时任务。")
groups = set()
for bot in get_bots().values():
try:
group_list, _ = await PlatformUtils.get_group_list(bot)
groups.update(
g.group_id for g in group_list if g.group_id and not g.channel_id
)
except Exception as e:
logger.error(f"获取 Bot {bot.self_id} 的群列表失败", e=e)
success_count = 0
fail_count = 0
for gid in groups:
try:
success, _ = await self.add_schedule(
plugin_name, gid, trigger_type, trigger_config, job_kwargs
)
if success:
success_count += 1
else:
fail_count += 1
except Exception as e:
logger.error(f"为群 {gid} 添加定时任务失败: {e}", e=e)
fail_count += 1
await asyncio.sleep(0.05)
return success_count, fail_count
async def update_schedule(
self,
schedule_id: int,
trigger_type: str | None = None,
trigger_config: dict | None = None,
job_kwargs: dict | None = None,
) -> tuple[bool, str]:
"""部分更新一个已存在的定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的任务。"
updated_fields = []
if trigger_config is not None:
schedule.trigger_config = trigger_config
updated_fields.append("trigger_config")
if trigger_type is not None and schedule.trigger_type != trigger_type:
schedule.trigger_type = trigger_type
updated_fields.append("trigger_type")
if job_kwargs is not None:
if not isinstance(schedule.job_kwargs, dict):
return False, f"任务 {schedule_id} 的 job_kwargs 数据格式错误。"
merged_kwargs = schedule.job_kwargs.copy()
merged_kwargs.update(job_kwargs)
is_valid, result = self._validate_and_prepare_kwargs(
schedule.plugin_name, merged_kwargs
)
if not is_valid:
return False, str(result)
schedule.job_kwargs = result # type: ignore
updated_fields.append("job_kwargs")
if not updated_fields:
return True, "没有任何需要更新的配置。"
await schedule.save(update_fields=updated_fields)
self._add_aps_job(schedule)
return True, f"成功更新了任务 ID: {schedule_id} 的配置。"
async def remove_schedule(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""移除指定插件和群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
msg = (
f"未找到与 Bot {bot_id} 相关的群 {group_id} "
f"的插件 '{plugin_name}' 定时任务。"
)
return (False, msg)
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
target_desc = f"{group_id}" if group_id else "全局"
msg = (
f"已取消 Bot {bot_id} 在 [{target_desc}] "
f"的插件 '{plugin_name}' 所有定时任务。"
)
return (True, msg)
async def remove_schedule_for_all(
self, plugin_name: str, bot_id: str | None = None
) -> int:
"""移除指定插件在所有群组的定时任务。"""
query = {"plugin_name": plugin_name}
if bot_id:
query["bot_id"] = bot_id
schedules_to_delete = await ScheduleInfo.filter(**query).all()
if not schedules_to_delete:
return 0
for schedule in schedules_to_delete:
self._remove_aps_job(schedule.id)
await schedule.delete()
await asyncio.sleep(0.01)
return len(schedules_to_delete)
async def remove_schedules_by_group(self, group_id: str) -> tuple[bool, str]:
"""移除指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id)
if not schedules:
return False, f"{group_id} 没有任何定时任务。"
count = 0
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
count += 1
await asyncio.sleep(0.01)
return True, f"已成功移除群 {group_id}{count} 个定时任务。"
async def pause_schedules_by_group(self, group_id: str) -> tuple[int, str]:
"""暂停指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=True)
if not schedules:
return 0, f"{group_id} 没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功暂停群 {group_id}{count} 个定时任务。"
async def resume_schedules_by_group(self, group_id: str) -> tuple[int, str]:
"""恢复指定群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(group_id=group_id, is_enabled=False)
if not schedules:
return 0, f"{group_id} 没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功恢复群 {group_id}{count} 个定时任务。"
async def pause_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
"""暂停指定插件在所有群组的定时任务。"""
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=True)
if not schedules:
return 0, f"插件 '{plugin_name}' 没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return (
count,
f"已成功暂停插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
)
async def resume_schedules_by_plugin(self, plugin_name: str) -> tuple[int, str]:
"""恢复指定插件在所有群组的定时任务。"""
schedules = await ScheduleInfo.filter(plugin_name=plugin_name, is_enabled=False)
if not schedules:
return 0, f"插件 '{plugin_name}' 没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return (
count,
f"已成功恢复插件 '{plugin_name}' 在所有群组的 {count} 个定时任务。",
)
async def pause_schedule_by_plugin_group(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""暂停指定插件在指定群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": True}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
return (
False,
f"{group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已暂停。",
)
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
return (
True,
f"已成功暂停群 {group_id} 的插件 '{plugin_name}'{count} 个定时任务。",
)
async def resume_schedule_by_plugin_group(
self, plugin_name: str, group_id: str | None, bot_id: str | None = None
) -> tuple[bool, str]:
"""恢复指定插件在指定群组的定时任务。"""
query = {"plugin_name": plugin_name, "group_id": group_id, "is_enabled": False}
if bot_id:
query["bot_id"] = bot_id
schedules = await ScheduleInfo.filter(**query)
if not schedules:
return (
False,
f"{group_id} 未设置插件 '{plugin_name}' 的定时任务或任务已启用。",
)
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
return (
True,
f"已成功恢复群 {group_id} 的插件 '{plugin_name}'{count} 个定时任务。",
)
async def remove_all_schedules(self) -> tuple[int, str]:
"""移除所有群组的所有定时任务。"""
schedules = await ScheduleInfo.all()
if not schedules:
return 0, "当前没有任何定时任务。"
count = 0
for schedule in schedules:
self._remove_aps_job(schedule.id)
await schedule.delete()
count += 1
await asyncio.sleep(0.01)
return count, f"已成功移除所有群组的 {count} 个定时任务。"
async def pause_all_schedules(self) -> tuple[int, str]:
"""暂停所有群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(is_enabled=True)
if not schedules:
return 0, "当前没有正在运行的定时任务可暂停。"
count = 0
for schedule in schedules:
success, _ = await self.pause_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功暂停所有群组的 {count} 个定时任务。"
async def resume_all_schedules(self) -> tuple[int, str]:
"""恢复所有群组的所有定时任务。"""
schedules = await ScheduleInfo.filter(is_enabled=False)
if not schedules:
return 0, "当前没有已暂停的定时任务可恢复。"
count = 0
for schedule in schedules:
success, _ = await self.resume_schedule(schedule.id)
if success:
count += 1
await asyncio.sleep(0.01)
return count, f"已成功恢复所有群组的 {count} 个定时任务。"
async def remove_schedule_by_id(self, schedule_id: int) -> tuple[bool, str]:
"""通过ID移除指定的定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
self._remove_aps_job(schedule.id)
await schedule.delete()
return (
True,
f"已删除插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def get_schedule_by_id(self, schedule_id: int) -> ScheduleInfo | None:
"""通过ID获取定时任务信息。"""
return await ScheduleInfo.get_or_none(id=schedule_id)
async def get_schedules(
self, plugin_name: str, group_id: str | None
) -> list[ScheduleInfo]:
"""获取特定群组特定插件的所有定时任务。"""
return await ScheduleInfo.filter(plugin_name=plugin_name, group_id=group_id)
async def get_schedule(
self, plugin_name: str, group_id: str | None
) -> ScheduleInfo | None:
"""获取特定群组的定时任务信息。"""
return await ScheduleInfo.get_or_none(
plugin_name=plugin_name, group_id=group_id
)
async def get_all_schedules(
self, plugin_name: str | None = None
) -> list[ScheduleInfo]:
"""获取所有定时任务信息,可按插件名过滤。"""
if plugin_name:
return await ScheduleInfo.filter(plugin_name=plugin_name).all()
return await ScheduleInfo.all()
async def get_schedule_status(self, schedule_id: int) -> dict | None:
"""获取任务的详细状态。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return None
job_id = self._get_job_id(schedule.id)
job = scheduler.get_job(job_id)
status = {
"id": schedule.id,
"bot_id": schedule.bot_id,
"plugin_name": schedule.plugin_name,
"group_id": schedule.group_id,
"is_enabled": schedule.is_enabled,
"trigger_type": schedule.trigger_type,
"trigger_config": schedule.trigger_config,
"job_kwargs": schedule.job_kwargs,
"next_run_time": job.next_run_time.strftime("%Y-%m-%d %H:%M:%S")
if job and job.next_run_time
else "N/A",
"is_paused_in_scheduler": not bool(job.next_run_time) if job else "N/A",
}
return status
async def pause_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""暂停一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule or not schedule.is_enabled:
return False, "任务不存在或已暂停。"
schedule.is_enabled = False
await schedule.save(update_fields=["is_enabled"])
job_id = self._get_job_id(schedule.id)
try:
scheduler.pause_job(job_id)
except Exception:
pass
return (
True,
f"已暂停插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def resume_schedule(self, schedule_id: int) -> tuple[bool, str]:
"""恢复一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule or schedule.is_enabled:
return False, "任务不存在或已启用。"
schedule.is_enabled = True
await schedule.save(update_fields=["is_enabled"])
job_id = self._get_job_id(schedule.id)
try:
scheduler.resume_job(job_id)
except Exception:
self._add_aps_job(schedule)
return (
True,
f"已恢复插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
async def trigger_now(self, schedule_id: int) -> tuple[bool, str]:
"""手动触发一个定时任务。"""
schedule = await self.get_schedule_by_id(schedule_id)
if not schedule:
return False, f"未找到 ID 为 {schedule_id} 的定时任务。"
if schedule.plugin_name not in self._registered_tasks:
return False, f"插件 '{schedule.plugin_name}' 没有注册可用的定时任务。"
try:
await self._execute_job(schedule.id)
return (
True,
f"已手动触发插件 '{schedule.plugin_name}' 在群 {schedule.group_id} "
f"的定时任务 (ID: {schedule.id})。",
)
except Exception as e:
logger.error(f"手动触发任务失败: {e}")
return False, f"手动触发任务失败: {e}"
scheduler_manager = SchedulerManager()
@PriorityLifecycle.on_startup(priority=90)
async def _load_schedules_from_db():
"""在服务启动时从数据库加载并调度所有任务。"""
Config.add_plugin_config(
"SchedulerManager",
SCHEDULE_CONCURRENCY_KEY,
5,
help="“所有群组”类型定时任务的并发执行数量限制",
type=int,
)
logger.info("正在从数据库加载并调度所有定时任务...")
schedules = await ScheduleInfo.filter(is_enabled=True).all()
count = 0
for schedule in schedules:
if schedule.plugin_name in scheduler_manager._registered_tasks:
scheduler_manager._add_aps_job(schedule)
count += 1
else:
logger.warning(f"跳过加载定时任务:插件 '{schedule.plugin_name}' 未注册。")
logger.info(f"定时任务加载完成,共成功加载 {count} 个任务。")

View File

@ -1,7 +1,7 @@
import asyncio
from collections.abc import Awaitable, Callable
import random
from typing import Literal
from typing import cast
import httpx
import nonebot
@ -227,7 +227,7 @@ class PlatformUtils:
url = None
if platform == "qq":
if user_id.isdigit():
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640"
url = f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160"
else:
url = f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640"
return await AsyncHttpx.get_content(url) if url else None
@ -486,15 +486,134 @@ class PlatformUtils:
return target
class BroadcastEngine:
def __init__(
self,
message: str | UniMessage,
bot: Bot | list[Bot] | None = None,
bot_id: str | set[str] | None = None,
ignore_group: list[str] | None = None,
check_func: Callable[[Bot, str], Awaitable] | None = None,
log_cmd: str | None = None,
platform: str | None = None,
):
"""广播引擎
参数:
message: 广播消息内容
bot: 指定bot对象.
bot_id: 指定bot id.
ignore_group: 忽略群聊列表.
check_func: 发送前对群聊检测方法判断是否发送.
log_cmd: 日志标记.
platform: 指定平台.
异常:
ValueError: 没有可用的Bot对象
"""
if ignore_group is None:
ignore_group = []
self.message = MessageUtils.build_message(message)
self.ignore_group = ignore_group
self.check_func = check_func
self.log_cmd = log_cmd
self.platform = platform
self.bot_list = []
self.count = 0
if bot:
self.bot_list = [bot] if isinstance(bot, Bot) else bot
if isinstance(bot_id, str):
bot_id = set(bot_id)
if bot_id:
for i in bot_id:
try:
self.bot_list.append(nonebot.get_bot(i))
except KeyError:
logger.warning(f"Bot:{i} 对象未连接或不存在")
if not self.bot_list:
raise ValueError("当前没有可用的Bot对象...", log_cmd)
async def call_check(self, bot: Bot, group_id: str) -> bool:
"""运行发送检测函数
参数:
bot: Bot
group_id: 群组id
返回:
bool: 是否发送
"""
if not self.check_func:
return True
if is_coroutine_callable(self.check_func):
is_run = await self.check_func(bot, group_id)
else:
is_run = self.check_func(bot, group_id)
return cast(bool, is_run)
async def __send_message(self, bot: Bot, group: GroupConsole):
"""群组发送消息
参数:
bot: Bot
group: GroupConsole
"""
key = f"{group.group_id}:{group.channel_id}"
if not await self.call_check(bot, group.group_id):
logger.debug(
"广播方法检测运行方法为 False, 已跳过该群组...",
self.log_cmd,
group_id=group.group_id,
)
return
if target := PlatformUtils.get_target(
group_id=group.group_id,
channel_id=group.channel_id,
):
self.ignore_group.append(key)
await MessageUtils.build_message(self.message).send(target, bot)
logger.debug("广播消息发送成功...", self.log_cmd, target=key)
else:
logger.warning("广播消息获取Target失败...", self.log_cmd, target=key)
async def broadcast(self) -> int:
"""广播消息
返回:
int: 成功发送次数
"""
for bot in self.bot_list:
if self.platform and self.platform != PlatformUtils.get_platform(bot):
continue
group_list, _ = await PlatformUtils.get_group_list(bot)
if not group_list:
continue
for group in group_list:
if (
group.group_id in self.ignore_group
or group.channel_id in self.ignore_group
):
continue
try:
await self.__send_message(bot, group)
await asyncio.sleep(random.randint(1, 3))
self.count += 1
except Exception as e:
logger.warning(
"广播消息发送失败", self.log_cmd, target=group.group_id, e=e
)
return self.count
async def broadcast_group(
message: str | UniMessage,
bot: Bot | list[Bot] | None = None,
bot_id: str | set[str] | None = None,
ignore_group: set[int] | None = None,
ignore_group: list[str] = [],
check_func: Callable[[Bot, str], Awaitable] | None = None,
log_cmd: str | None = None,
platform: Literal["qq", "dodo", "kaiheila"] | None = None,
):
platform: str | None = None,
) -> int:
"""获取所有Bot或指定Bot对象广播群聊
参数:
@ -505,81 +624,18 @@ async def broadcast_group(
check_func: 发送前对群聊检测方法判断是否发送.
log_cmd: 日志标记.
platform: 指定平台
返回:
int: 成功发送次数
"""
if platform and platform not in ["qq", "dodo", "kaiheila"]:
raise ValueError("指定平台不支持")
if not message:
raise ValueError("群聊广播消息不能为空")
bot_dict = nonebot.get_bots()
bot_list: list[Bot] = []
if bot:
if isinstance(bot, list):
bot_list = bot
else:
bot_list.append(bot)
elif bot_id:
_bot_id_list = bot_id
if isinstance(bot_id, str):
_bot_id_list = [bot_id]
for id_ in _bot_id_list:
if bot_id in bot_dict:
bot_list.append(bot_dict[bot_id])
else:
logger.warning(f"Bot:{id_} 对象未连接或不存在")
else:
bot_list = list(bot_dict.values())
_used_group = []
for _bot in bot_list:
try:
if platform and platform != PlatformUtils.get_platform(_bot):
continue
group_list, _ = await PlatformUtils.get_group_list(_bot)
if group_list:
for group in group_list:
key = f"{group.group_id}:{group.channel_id}"
try:
if (
ignore_group
and (
group.group_id in ignore_group
or group.channel_id in ignore_group
)
) or key in _used_group:
logger.debug(
"广播方法群组重复, 已跳过...",
log_cmd,
group_id=group.group_id,
)
continue
is_run = False
if check_func:
if is_coroutine_callable(check_func):
is_run = await check_func(_bot, group.group_id)
else:
is_run = check_func(_bot, group.group_id)
if not is_run:
logger.debug(
"广播方法检测运行方法为 False, 已跳过...",
log_cmd,
group_id=group.group_id,
)
continue
target = PlatformUtils.get_target(
user_id=None,
group_id=group.group_id,
channel_id=group.channel_id,
)
if target:
_used_group.append(key)
message_list = message
await MessageUtils.build_message(message_list).send(
target, _bot
)
logger.debug("发送成功", log_cmd, target=key)
await asyncio.sleep(random.randint(1, 3))
else:
logger.warning("target为空", log_cmd, target=key)
except Exception as e:
logger.error("发送失败", log_cmd, target=key, e=e)
except Exception as e:
logger.error(f"Bot: {_bot.self_id} 获取群聊列表失败", command=log_cmd, e=e)
if not message.strip():
raise ValueError("群聊广播消息不能为空...")
return await BroadcastEngine(
message=message,
bot=bot,
bot_id=bot_id,
ignore_group=ignore_group,
check_func=check_func,
log_cmd=log_cmd,
platform=platform,
).broadcast()

View File

@ -244,10 +244,8 @@ def is_valid_date(date_text: str, separator: str = "-") -> bool:
def get_entity_ids(session: Uninfo) -> EntityIDs:
"""获取用户id群组id频道id
参数:
session: Uninfo
返回:
EntityIDs: 用户id群组id频道id
"""