Merge branch 'main' into feature/support-gitcode

This commit is contained in:
xuanerwa 2025-06-20 18:38:43 +08:00 committed by GitHub
commit 36edd32f53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 1378 additions and 740 deletions

116
README.md
View File

@ -112,7 +112,7 @@ AccessToken: PUBLIC_ZHENXUN_TEST
| [插件库](https://github.com/zhenxun-org/zhenxun_bot_plugins) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 原 plugins 文件夹插件 | | [插件库](https://github.com/zhenxun-org/zhenxun_bot_plugins) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 原 plugins 文件夹插件 |
| [插件索引库](https://github.com/zhenxun-org/zhenxun_bot_plugins_index) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 扩展插件索引库 | | [插件索引库](https://github.com/zhenxun-org/zhenxun_bot_plugins_index) | 插件 | [zhenxun-org](https://github.com/zhenxun-org) | 扩展插件索引库 |
| [一键安装](https://github.com/soloxiaoye2022/zhenxun_bot-deploy) | 安装 | [soloxiaoye2022](https://github.com/soloxiaoye2022) | 第三方 | | [一键安装](https://github.com/soloxiaoye2022/zhenxun_bot-deploy) | 安装 | [soloxiaoye2022](https://github.com/soloxiaoye2022) | 第三方 |
| [WebUi](https://github.com/HibiKier/zhenxun_bot_webui) | 管理 | [hibikier](https://github.com/HibiKier) | 基于真寻 WebApi 的 webui 实现 [预览](#-webui界面展示) | | [WebUi](https://github.com/zhenxun-org/zhenxun_bot) | 管理 | [hibikier](https://github.com/HibiKier) | 基于真寻 WebApi 的 webui 实现 [预览](#-webui界面展示) |
| [安卓 app(WebUi)](https://github.com/YuS1aN/zhenxun_bot_android_ui) | 安装 | [YuS1aN](https://github.com/YuS1aN) | 第三方 | | [安卓 app(WebUi)](https://github.com/YuS1aN/zhenxun_bot_android_ui) | 安装 | [YuS1aN](https://github.com/YuS1aN) | 第三方 |
</div> </div>
@ -126,6 +126,28 @@ AccessToken: PUBLIC_ZHENXUN_TEST
- 提供了 cd阻塞每日次数等限制仅仅通过简单的属性就可以生成一个限制例如`PluginCdBlock` 等 - 提供了 cd阻塞每日次数等限制仅仅通过简单的属性就可以生成一个限制例如`PluginCdBlock` 等
- **更多详细请通过 [传送门](https://zhenxun-org.github.io/zhenxun_bot/) 查看文档!** - **更多详细请通过 [传送门](https://zhenxun-org.github.io/zhenxun_bot/) 查看文档!**
## 🐣 小白整合
如果你系统是 **Windows** 且不想下载 Python
可以使用整合包Python3.10+zhenxun+webui
文档地址:[整合包文档](https://hibikier.github.io/zhenxun_bot/beginner/)
<details>
<summary>下载地址</summary>
- **百度云:**
https://pan.baidu.com/s/1ph4yzx1vdNbkxm9VBKDdgQ?pwd=971j
- **天翼云:**
https://cloud.189.cn/web/share?code=jq67r2i2E7Fb
访问码8wxm
- **Google Drive**
https://drive.google.com/file/d/1cc3Dqjk0x5hWGLNeMkrFwWl8BvsK6KfD/view?usp=drive_link
</details>
## 🛠️ 简单部署 ## 🛠️ 简单部署
```bash ```bash
@ -150,7 +172,7 @@ poetry run python bot.py
1.在 .env.dev 文件中填写你的机器人配置项 1.在 .env.dev 文件中填写你的机器人配置项
2.在 configs/config.yaml 文件中修改你需要修改的插件配置项 2.在 data/config.yaml 文件中修改你需要修改的插件配置项
<details> <details>
<summary>数据库地址DB_URL配置说明</summary> <summary>数据库地址DB_URL配置说明</summary>
@ -272,12 +294,12 @@ DB_URL 是基于 Tortoise ORM 的数据库连接字符串,用于指定项目
## ❔ 需要帮助? ## ❔ 需要帮助?
> [!TIP] > [!TIP]
> 发起 [issue](https://github.com/HibiKier/zhenxun_bot/issues/new/choose) 前,我们希望你能够阅读过或者了解 [提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md) > 发起 [issue](https://github.com/zhenxun-org/zhenxun_bot/issues/new/choose) 前,我们希望你能够阅读过或者了解 [提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md)
> >
> - 善用[搜索引擎](https://www.google.com/) > - 善用[搜索引擎](https://www.google.com/)
> - 查阅 issue 中是否有类似问题,如果没有请按照模板发起 issue > - 查阅 issue 中是否有类似问题,如果没有请按照模板发起 issue
欢迎前往 [issue](https://github.com/HibiKier/zhenxun_bot/issues/new/choose) 中提出你遇到的问题,或者加入我们的 [用户群](https://qm.qq.com/q/mRNtLSl6uc) 或 [技术群](https://qm.qq.com/q/YYYt5rkMYc)与我们联系 欢迎前往 [issue](https://github.com/zhenxun-org/zhenxun_bot/issues/new/choose) 中提出你遇到的问题,或者加入我们的 [用户群](https://qm.qq.com/q/mRNtLSl6uc) 或 [技术群](https://qm.qq.com/q/YYYt5rkMYc)与我们联系
## 🛠️ 进度追踪 ## 🛠️ 进度追踪
@ -287,6 +309,8 @@ Project [zhenxun_bot](https://github.com/users/HibiKier/projects/2)
首席设计师:[酥酥/coldly-ss](https://github.com/coldly-ss) 首席设计师:[酥酥/coldly-ss](https://github.com/coldly-ss)
LOGO 设计:[FrostN0v0](https://github.com/FrostN0v0)
## 🙏 感谢 ## 🙏 感谢
[botuniverse / onebot](https://github.com/botuniverse/onebot) :超棒的机器人协议 [botuniverse / onebot](https://github.com/botuniverse/onebot) :超棒的机器人协议
@ -326,34 +350,68 @@ Project [zhenxun_bot](https://github.com/users/HibiKier/projects/2)
<img src="https://contrib.rocks/image?repo=HibiKier/zhenxun_bot&max=1000" alt="contributors"/> <img src="https://contrib.rocks/image?repo=HibiKier/zhenxun_bot&max=1000" alt="contributors"/>
</a> </a>
## 📸 WebUI 界面展示 ## 📸 WebUI 界面展示(仅展示默认主题下的 pc 端)
<div style="display: flex; flex-wrap: wrap; justify-content: space-between;"> <div style="display: flex; flex-wrap: wrap; justify-content: space-between;">
<div style="width: 48%; margin-bottom: 10px;">
<img src="./docs_image/webui00.png" alt="webui00" style="width: 100%; height: auto;">
</div>
<div style="width: 48%; margin-bottom: 10px;">
<img src="./docs_image/webui01.png" alt="webui01" style="width: 100%; height: auto;">
</div>
<div style="width: 48%; margin-bottom: 10px;"> #### 登录界面
<img src="./docs_image/webui02.png" alt="webui02" style="width: 100%; height: auto;">
</div>
<div style="width: 48%; margin-bottom: 10px;">
<img src="./docs_image/webui03.png" alt="webui03" style="width: 100%; height: auto;">
</div>
<div style="width: 48%; margin-bottom: 10px;"> ![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-login.jpg)
<img src="./docs_image/webui04.png" alt="webui04" style="width: 100%; height: auto;">
</div> #### API 设置
<div style="width: 48%; margin-bottom: 10px;">
<img src="./docs_image/webui05.png" alt="webui05" style="width: 100%; height: auto;"> ![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-api.jpg)
</div>
#### 仪表盘
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-dashboard.jpg)
#### 仪表盘(展开)
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-dashboard1.jpg)
#### 控制台
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-command.jpg)
#### 插件列表
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-plugin.jpg)
#### 插件列表(配置项)
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-plugin1.jpg)
#### 插件商店
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-store.jpg)
#### 好友/群组管理
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-manage.jpg)
#### 请求管理
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-manage1.jpg)
#### 数据库管理
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-database.jpg)
### 文件管理
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system.jpg)
### 文件管理(文本查看)
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system1.jpg)
### 文件管理(图片查看)
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-system2.jpg)
### 关于
![x](https://github.com/zhenxun-org/zhenxun_bot/blob/main/docs_image/pc-about.jpg)
<div style="width: 48%; margin-bottom: 10px;">
<img src="./docs_image/webui06.png" alt="webui06" style="width: 100%; height: auto;">
</div>
<div style="width: 48%; margin-bottom: 10px;">
<img src="./docs_image/webui07.png" alt="webui07" style="width: 100%; height: auto;">
</div>
</div> </div>

4
bot.py
View File

@ -14,9 +14,9 @@ driver.register_adapter(OneBotV11Adapter)
# driver.register_adapter(DoDoAdapter) # driver.register_adapter(DoDoAdapter)
# driver.register_adapter(DiscordAdapter) # driver.register_adapter(DiscordAdapter)
from zhenxun.services.db_context import disconnect, init from zhenxun.services.db_context import disconnect
driver.on_startup(init) # driver.on_startup(init)
driver.on_shutdown(disconnect) driver.on_shutdown(disconnect)
# nonebot.load_builtin_plugins("echo") # nonebot.load_builtin_plugins("echo")

BIN
docs_image/pc-about.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 388 KiB

BIN
docs_image/pc-api.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 315 KiB

BIN
docs_image/pc-command.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 630 KiB

BIN
docs_image/pc-dashboard.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 708 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 598 KiB

BIN
docs_image/pc-database.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 405 KiB

BIN
docs_image/pc-login.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 250 KiB

BIN
docs_image/pc-manage.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 504 KiB

BIN
docs_image/pc-manage1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 423 KiB

BIN
docs_image/pc-plugin.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 551 KiB

BIN
docs_image/pc-plugin1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 453 KiB

BIN
docs_image/pc-store.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 400 KiB

BIN
docs_image/pc-system.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 336 KiB

BIN
docs_image/pc-system1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 152 KiB

BIN
docs_image/pc-system2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 315 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 352 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 279 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 182 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 228 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 200 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 201 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 193 KiB

View File

@ -16,6 +16,7 @@ from zhenxun.models.sign_user import SignUser
from zhenxun.models.user_console import UserConsole from zhenxun.models.user_console import UserConsole
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.decorator.shop import shop_register from zhenxun.utils.decorator.shop import shop_register
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.manager.resource_manager import ResourceManager from zhenxun.utils.manager.resource_manager import ResourceManager
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
@ -70,7 +71,7 @@ from public.bag_users t1
""" """
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def _(): async def _():
await ResourceManager.init_resources() await ResourceManager.init_resources()
"""签到与用户的数据迁移""" """签到与用户的数据迁移"""

View File

@ -14,6 +14,7 @@ from zhenxun.services.log import logger
from zhenxun.utils._build_image import BuildImage from zhenxun.utils._build_image import BuildImage
from zhenxun.utils._image_template import ImageTemplate from zhenxun.utils._image_template import ImageTemplate
from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
BASE_PATH = DATA_PATH / "welcome_message" BASE_PATH = DATA_PATH / "welcome_message"
@ -91,7 +92,7 @@ def migrate(path: Path):
json.dump(new_data, f, ensure_ascii=False, indent=4) json.dump(new_data, f, ensure_ascii=False, indent=4)
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
def _(): def _():
"""数据迁移 """数据迁移

View File

@ -37,8 +37,8 @@ __plugin_meta__ = PluginMetadata(
configs=[ configs=[
RegisterConfig( RegisterConfig(
key="type", key="type",
value="normal", value="zhenxun",
help="帮助图片样式 ['normal', 'HTML', 'zhenxun']", help="帮助图片样式 [normal, HTML, zhenxun]",
default_value="zhenxun", default_value="zhenxun",
) )
], ],

View File

@ -49,4 +49,14 @@ Config.add_plugin_config(
type=bool, type=bool,
) )
Config.add_plugin_config(
"hook",
"RECORD_BOT_SENT_MESSAGES",
True,
help="记录bot消息发送",
default_value=True,
type=bool,
)
nonebot.load_plugins(str(Path(__file__).parent.resolve())) nonebot.load_plugins(str(Path(__file__).parent.resolve()))

View File

@ -1,23 +1,85 @@
from typing import Any from typing import Any
from nonebot.adapters import Bot from nonebot.adapters import Bot, Message
from zhenxun.configs.config import Config
from zhenxun.models.bot_message_store import BotMessageStore
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.enum import BotSentType
from zhenxun.utils.manager.message_manager import MessageManager from zhenxun.utils.manager.message_manager import MessageManager
from zhenxun.utils.platform import PlatformUtils
def replace_message(message: Message) -> str:
"""将消息中的at、image、record、face替换为字符串
参数:
message: Message
返回:
str: 文本消息
"""
result = ""
for msg in message:
if isinstance(msg, str):
result += msg
elif msg.type == "at":
result += f"@{msg.data['qq']}"
elif msg.type == "image":
result += "[image]"
elif msg.type == "record":
result += "[record]"
elif msg.type == "face":
result += f"[face:{msg.data['id']}]"
elif msg.type == "reply":
result += ""
else:
result += str(msg)
return result
@Bot.on_called_api @Bot.on_called_api
async def handle_api_result( async def handle_api_result(
bot: Bot, exception: Exception | None, api: str, data: dict[str, Any], result: Any bot: Bot, exception: Exception | None, api: str, data: dict[str, Any], result: Any
): ):
if not exception and api == "send_msg": if exception or api != "send_msg":
return
user_id = data.get("user_id")
group_id = data.get("group_id")
message_id = result.get("message_id")
message: Message = data.get("message", "")
message_type = data.get("message_type")
try: try:
if (uid := data.get("user_id")) and (msg_id := result.get("message_id")): # 记录消息id
MessageManager.add(str(uid), str(msg_id)) if user_id and message_id:
MessageManager.add(str(user_id), str(message_id))
logger.debug( logger.debug(
f"收集消息iduser_id: {uid}, msg_id: {msg_id}", "msg_hook" f"收集消息iduser_id: {user_id}, msg_id: {message_id}", "msg_hook"
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
) )
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
return
try:
await BotMessageStore.create(
bot_id=bot.self_id,
user_id=user_id,
group_id=group_id,
sent_type=BotSentType.GROUP
if message_type == "group"
else BotSentType.PRIVATE,
text=replace_message(message),
plain_text=message.extract_plain_text()
if isinstance(message, Message)
else replace_message(message),
platform=PlatformUtils.get_platform(bot),
)
logger.debug(f"消息发送记录message: {message}")
except Exception as e:
logger.warning(
f"消息发送记录发生错误...data: {data}, result: {result}",
"msg_hook",
e=e,
)

View File

@ -11,6 +11,7 @@ from zhenxun.configs.config import Config
from zhenxun.configs.path_config import DATA_PATH from zhenxun.configs.path_config import DATA_PATH
from zhenxun.configs.utils import RegisterConfig from zhenxun.configs.utils import RegisterConfig
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
_yaml = YAML(pure=True) _yaml = YAML(pure=True)
_yaml.allow_unicode = True _yaml.allow_unicode = True
@ -102,7 +103,7 @@ def _generate_simple_config(exists_module: list[str]):
temp_file.unlink() temp_file.unlink()
@driver.on_startup @PriorityLifecycle.on_startup(priority=0)
def _(): def _():
""" """
初始化插件数据配置 初始化插件数据配置
@ -125,3 +126,4 @@ def _():
with plugins2config_file.open("w", encoding="utf8") as wf: with plugins2config_file.open("w", encoding="utf8") as wf:
_yaml.dump(_data, wf) _yaml.dump(_data, wf)
_generate_simple_config(exists_module) _generate_simple_config(exists_module)
Config.reload()

View File

@ -20,6 +20,7 @@ from zhenxun.utils.enum import (
PluginLimitType, PluginLimitType,
PluginType, PluginType,
) )
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .manager import manager from .manager import manager
@ -95,7 +96,7 @@ async def _handle_setting(
) )
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def _(): async def _():
""" """
初始化插件数据配置 初始化插件数据配置

View File

@ -10,6 +10,7 @@ from zhenxun.models.group_console import GroupConsole
from zhenxun.models.task_info import TaskInfo from zhenxun.models.task_info import TaskInfo
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
driver: Driver = nonebot.get_driver() driver: Driver = nonebot.get_driver()
@ -132,7 +133,7 @@ async def create_schedule(task: Task):
logger.error(f"动态创建定时任务 {task.name}({task.module}) 失败", e=e) logger.error(f"动态创建定时任务 {task.name}({task.module}) 失败", e=e)
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def _(): async def _():
""" """
初始化插件数据配置 初始化插件数据配置

View File

@ -23,6 +23,12 @@ from .config import (
LOG_COMMAND, LOG_COMMAND,
) )
BAT_FILE = Path() / "win启动.bat"
WIN_COMMAND = ["./Python310/python.exe", "-m", "pip", "install", "-r"]
DEFAULT_COMMAND = ["poetry", "run", "pip", "install", "-r"]
def row_style(column: str, text: str) -> RowStyle: def row_style(column: str, text: str) -> RowStyle:
"""被动技能文本风格 """被动技能文本风格
@ -50,6 +56,33 @@ def install_requirement(plugin_path: Path):
VirtualEnvPackageManager.install_requirement(existing_requirements) VirtualEnvPackageManager.install_requirement(existing_requirements)
if not existing_requirements:
logger.debug(
f"No requirement.txt found for plugin: {plugin_path.name}", "插件管理"
)
return
try:
command = WIN_COMMAND if BAT_FILE.exists() else DEFAULT_COMMAND
command.append(str(existing_requirements))
result = subprocess.run(
command,
check=True,
capture_output=True,
text=True,
)
logger.debug(
"Successfully installed dependencies for"
f" plugin: {plugin_path.name}. Output:\n{result.stdout}",
"插件管理",
)
except subprocess.CalledProcessError:
logger.error(
f"Failed to install dependencies for plugin: {plugin_path.name}. "
" Error:\n{e.stderr}"
)
class StoreManager: class StoreManager:
@classmethod @classmethod
async def get_github_plugins(cls) -> list[StorePluginInfo]: async def get_github_plugins(cls) -> list[StorePluginInfo]:

View File

@ -1,12 +1,8 @@
import nonebot
from nonebot.drivers import Driver
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
driver: Driver = nonebot.get_driver()
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def _(): async def _():
"""开启/禁用插件格式修改""" """开启/禁用插件格式修改"""
_, is_create = await GroupConsole.get_or_create(group_id=133133133) _, is_create = await GroupConsole.get_or_create(group_id=133133133)

View File

@ -5,7 +5,9 @@ from nonebot_plugin_alconna import (
AlconnaQuery, AlconnaQuery,
Args, Args,
Arparma, Arparma,
At,
Match, Match,
MultiVar,
Option, Option,
Query, Query,
Subcommand, Subcommand,
@ -47,6 +49,7 @@ __plugin_meta__ = PluginMetadata(
plugin_type=PluginType.NORMAL, plugin_type=PluginType.NORMAL,
menu_type="商店", menu_type="商店",
commands=[ commands=[
Command(command="商店"),
Command(command="我的金币"), Command(command="我的金币"),
Command(command="我的道具"), Command(command="我的道具"),
Command(command="购买道具"), Command(command="购买道具"),
@ -75,13 +78,21 @@ _matcher = on_alconna(
Subcommand("my-cost", help_text="我的金币"), Subcommand("my-cost", help_text="我的金币"),
Subcommand("my-props", help_text="我的道具"), Subcommand("my-props", help_text="我的道具"),
Subcommand("buy", Args["name?", str]["num?", int], help_text="购买道具"), Subcommand("buy", Args["name?", str]["num?", int], help_text="购买道具"),
Subcommand("use", Args["name?", str]["num?", int], help_text="使用道具"),
Subcommand("gold-list", Args["num?", int], help_text="金币排行"), Subcommand("gold-list", Args["num?", int], help_text="金币排行"),
), ),
priority=5, priority=5,
block=True, block=True,
) )
_use_matcher = on_alconna(
Alconna(
"使用道具",
Args["name?", str]["num?", int]["at_users?", MultiVar(At)],
),
priority=5,
block=True,
)
_matcher.shortcut( _matcher.shortcut(
"我的金币", "我的金币",
command="商店", command="商店",
@ -103,13 +114,6 @@ _matcher.shortcut(
prefix=True, prefix=True,
) )
_matcher.shortcut(
"使用道具(?P<name>.*?)",
command="商店",
arguments=["use", "{name}"],
prefix=True,
)
_matcher.shortcut( _matcher.shortcut(
"金币排行", "金币排行",
command="商店", command="商店",
@ -173,7 +177,7 @@ async def _(
await MessageUtils.build_message(result).send(reply_to=True) await MessageUtils.build_message(result).send(reply_to=True)
@_matcher.assign("use") @_use_matcher.handle()
async def _( async def _(
bot: Bot, bot: Bot,
event: Event, event: Event,
@ -182,6 +186,7 @@ async def _(
arparma: Arparma, arparma: Arparma,
name: Match[str], name: Match[str],
num: Query[int] = AlconnaQuery("num", 1), num: Query[int] = AlconnaQuery("num", 1),
at_users: Query[list[At]] = AlconnaQuery("at_users", []),
): ):
if not name.available: if not name.available:
await MessageUtils.build_message( await MessageUtils.build_message(
@ -189,7 +194,7 @@ async def _(
).finish(reply_to=True) ).finish(reply_to=True)
try: try:
result = await ShopManage.use( result = await ShopManage.use(
bot, event, session, message, name.result, num.result, "" bot, event, session, message, name.result, num.result, "", at_users.result
) )
logger.info( logger.info(
f"使用道具 {name.result}, 数量: {num.result}", f"使用道具 {name.result}, 数量: {num.result}",

View File

@ -8,7 +8,7 @@ from typing import Any, Literal
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.compat import model_dump from nonebot.compat import model_dump
from nonebot_plugin_alconna import UniMessage, UniMsg from nonebot_plugin_alconna import At, UniMessage, UniMsg
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
from tortoise.expressions import Q from tortoise.expressions import Q
@ -48,6 +48,10 @@ class Goods(BaseModel):
"""model""" """model"""
session: Uninfo | None = None session: Uninfo | None = None
"""Uninfo""" """Uninfo"""
at_user: str | None = None
"""At对象"""
at_users: list[str] = []
"""At对象列表"""
class ShopParam(BaseModel): class ShopParam(BaseModel):
@ -73,6 +77,10 @@ class ShopParam(BaseModel):
"""Uninfo""" """Uninfo"""
message: UniMsg message: UniMsg
"""UniMessage""" """UniMessage"""
at_user: str | None = None
"""At对象"""
at_users: list[str] = []
"""At对象列表"""
extra_data: dict[str, Any] = Field(default_factory=dict) extra_data: dict[str, Any] = Field(default_factory=dict)
"""额外数据""" """额外数据"""
@ -156,6 +164,7 @@ class ShopManage:
goods: Goods, goods: Goods,
num: int, num: int,
text: str, text: str,
at_users: list[str] = [],
) -> tuple[ShopParam, dict[str, Any]]: ) -> tuple[ShopParam, dict[str, Any]]:
"""构造参数 """构造参数
@ -165,6 +174,7 @@ class ShopManage:
goods_name: 商品名称 goods_name: 商品名称
num: 数量 num: 数量
text: 其他信息 text: 其他信息
at_users: at用户
""" """
group_id = None group_id = None
if session.group: if session.group:
@ -172,6 +182,7 @@ class ShopManage:
session.group.parent.id if session.group.parent else session.group.id session.group.parent.id if session.group.parent else session.group.id
) )
_kwargs = goods.params _kwargs = goods.params
at_user = at_users[0] if at_users else None
model = goods.model( model = goods.model(
**{ **{
"goods_name": goods.name, "goods_name": goods.name,
@ -183,6 +194,8 @@ class ShopManage:
"text": text, "text": text,
"session": session, "session": session,
"message": message, "message": message,
"at_user": at_user,
"at_users": at_users,
} }
) )
return model, { return model, {
@ -194,6 +207,8 @@ class ShopManage:
"num": num, "num": num,
"text": text, "text": text,
"goods_name": goods.name, "goods_name": goods.name,
"at_user": at_user,
"at_users": at_users,
} }
@classmethod @classmethod
@ -223,6 +238,7 @@ class ShopManage:
**param.extra_data, **param.extra_data,
"session": session, "session": session,
"message": message, "message": message,
"shop_param": ShopParam,
} }
for key in list(param_json.keys()): for key in list(param_json.keys()):
if key not in args: if key not in args:
@ -308,6 +324,7 @@ class ShopManage:
goods_name: str, goods_name: str,
num: int, num: int,
text: str, text: str,
at_users: list[At] = [],
) -> str | UniMessage | None: ) -> str | UniMessage | None:
"""使用道具 """使用道具
@ -319,6 +336,7 @@ class ShopManage:
goods_name: 商品名称 goods_name: 商品名称
num: 使用数量 num: 使用数量
text: 其他信息 text: 其他信息
at_users: at用户
返回: 返回:
str | MessageFactory | None: 使用完成后返回信息 str | MessageFactory | None: 使用完成后返回信息
@ -339,8 +357,9 @@ class ShopManage:
goods = cls.uuid2goods.get(goods_info.uuid) goods = cls.uuid2goods.get(goods_info.uuid)
if not goods or not goods.func: if not goods or not goods.func:
return f"{goods_info.goods_name} 未注册使用函数, 无法使用..." return f"{goods_info.goods_name} 未注册使用函数, 无法使用..."
at_user_ids = [at.target for at in at_users]
param, kwargs = cls.__build_params( param, kwargs = cls.__build_params(
bot, event, session, message, goods, num, text bot, event, session, message, goods, num, text, at_user_ids
) )
if num > param.max_num_limit: if num > param.max_num_limit:
return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..." return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..."
@ -480,10 +499,13 @@ class ShopManage:
if not user.props: if not user.props:
return None return None
user.props = {uuid: count for uuid, count in user.props.items() if count > 0}
goods_list = await GoodsInfo.filter(uuid__in=user.props.keys()).all() goods_list = await GoodsInfo.filter(uuid__in=user.props.keys()).all()
goods_by_uuid = {item.uuid: item for item in goods_list} goods_by_uuid = {item.uuid: item for item in goods_list}
user.props = {
uuid: count
for uuid, count in user.props.items()
if count > 0 and goods_by_uuid.get(uuid)
}
table_rows = [] table_rows = []
for i, prop_uuid in enumerate(user.props): for i, prop_uuid in enumerate(user.props):

View File

@ -10,7 +10,6 @@ from nonebot_plugin_alconna import (
store_true, store_true,
) )
from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_uninfo import Uninfo
from zhenxun.configs.utils import ( from zhenxun.configs.utils import (
Command, Command,
@ -23,7 +22,7 @@ from zhenxun.utils.depends import UserName
from zhenxun.utils.message import MessageUtils from zhenxun.utils.message import MessageUtils
from ._data_source import SignManage from ._data_source import SignManage
from .goods_register import driver # noqa: F401 from .goods_register import Uninfo
from .utils import clear_sign_data_pic from .utils import clear_sign_data_pic
__plugin_meta__ = PluginMetadata( __plugin_meta__ = PluginMetadata(

View File

@ -1,7 +1,6 @@
from decimal import Decimal from decimal import Decimal
import nonebot import nonebot
from nonebot.drivers import Driver
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.sign_user import SignUser from zhenxun.models.sign_user import SignUser
@ -9,14 +8,7 @@ from zhenxun.models.user_console import UserConsole
from zhenxun.utils.decorator.shop import shop_register from zhenxun.utils.decorator.shop import shop_register
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
driver: Driver = nonebot.get_driver() driver = nonebot.get_driver()
# @driver.on_startup
# async def _():
# """
# 导入内置的三个商品
# """
@shop_register( @shop_register(

View File

@ -16,6 +16,7 @@ from zhenxun.models.sign_log import SignLog
from zhenxun.models.sign_user import SignUser from zhenxun.models.sign_user import SignUser
from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.image_utils import BuildImage from zhenxun.utils.image_utils import BuildImage
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
from .config import ( from .config import (
@ -54,7 +55,7 @@ LG_MESSAGE = [
] ]
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def init_image(): async def init_image():
SIGN_RESOURCE_PATH.mkdir(parents=True, exist_ok=True) SIGN_RESOURCE_PATH.mkdir(parents=True, exist_ok=True)
SIGN_TODAY_CARD_PATH.mkdir(exist_ok=True, parents=True) SIGN_TODAY_CARD_PATH.mkdir(exist_ok=True, parents=True)

View File

@ -53,10 +53,7 @@ async def _(
) )
@scheduler.scheduled_job( @scheduler.scheduled_job("interval", minutes=1, max_instances=5)
"interval",
minutes=1,
)
async def _(): async def _():
try: try:
call_list = TEMP_LIST.copy() call_list = TEMP_LIST.copy()

View File

@ -10,7 +10,9 @@ from zhenxun.configs.config import Config as gConfig
from zhenxun.configs.utils import PluginExtraData, RegisterConfig from zhenxun.configs.utils import PluginExtraData, RegisterConfig
from zhenxun.services.log import logger, logger_ from zhenxun.services.log import logger, logger_
from zhenxun.utils.enum import PluginType from zhenxun.utils.enum import PluginType
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .api.configure import router as configure_router
from .api.logs import router as ws_log_routes from .api.logs import router as ws_log_routes
from .api.logs.log_manager import LOG_STORAGE from .api.logs.log_manager import LOG_STORAGE
from .api.menu import router as menu_router from .api.menu import router as menu_router
@ -81,6 +83,7 @@ BaseApiRouter.include_router(database_router)
BaseApiRouter.include_router(plugin_router) BaseApiRouter.include_router(plugin_router)
BaseApiRouter.include_router(system_router) BaseApiRouter.include_router(system_router)
BaseApiRouter.include_router(menu_router) BaseApiRouter.include_router(menu_router)
BaseApiRouter.include_router(configure_router)
WsApiRouter = APIRouter(prefix="/zhenxun/socket") WsApiRouter = APIRouter(prefix="/zhenxun/socket")
@ -89,7 +92,7 @@ WsApiRouter.include_router(status_routes)
WsApiRouter.include_router(chat_routes) WsApiRouter.include_router(chat_routes)
@driver.on_startup @PriorityLifecycle.on_startup(priority=0)
async def _(): async def _():
try: try:
# 存储任务引用的列表,防止任务被垃圾回收 # 存储任务引用的列表,防止任务被垃圾回收

View File

@ -0,0 +1,133 @@
import asyncio
import os
from pathlib import Path
import re
import subprocess
import sys
import time
from fastapi import APIRouter
from fastapi.responses import JSONResponse
import nonebot
from zhenxun.configs.config import BotConfig, Config
from ...base_model import Result
from .data_source import test_db_connection
from .model import Setting
router = APIRouter(prefix="/configure")
driver = nonebot.get_driver()
port = driver.config.port
BAT_FILE = Path() / "win启动.bat"
FILE_NAME = ".configure_restart"
@router.post(
"/set_configure",
response_model=Result,
response_class=JSONResponse,
description="设置基础配置",
)
async def _(setting: Setting) -> Result:
global port
password = Config.get_config("web-ui", "password")
if password or BotConfig.db_url:
return Result.fail("配置已存在请先删除DB_URL内容和前端密码再进行设置。")
env_file = Path() / ".env.dev"
if not env_file.exists():
return Result.fail("配置文件.env.dev不存在。")
env_text = env_file.read_text(encoding="utf-8")
if setting.db_url:
if setting.db_url.startswith("sqlite"):
base_dir = Path().resolve()
# 清理和验证数据库路径
db_path_str = setting.db_url.split(":")[-1].strip()
# 移除任何可能的路径遍历尝试
db_path_str = re.sub(r"[\\/]\.\.[\\/]", "", db_path_str)
# 规范化路径
db_path = Path(db_path_str).resolve()
parent_path = db_path.parent
# 验证路径是否在项目根目录内
try:
if not parent_path.absolute().is_relative_to(base_dir):
return Result.fail("数据库路径不在项目根目录内。")
except ValueError:
return Result.fail("无效的数据库路径。")
# 创建目录
try:
parent_path.mkdir(parents=True, exist_ok=True)
except Exception as e:
return Result.fail(f"创建数据库目录失败: {e!s}")
env_text = env_text.replace('DB_URL = ""', f'DB_URL = "{setting.db_url}"')
if setting.superusers:
superusers = ", ".join([f'"{s}"' for s in setting.superusers])
env_text = re.sub(r"SUPERUSERS=\[.*?\]", f"SUPERUSERS=[{superusers}]", env_text)
if setting.host:
env_text = env_text.replace("HOST = 127.0.0.1", f"HOST = {setting.host}")
if setting.port:
env_text = env_text.replace("PORT = 8080", f"PORT = {setting.port}")
port = setting.port
if setting.username:
Config.set_config("web-ui", "username", setting.username)
Config.set_config("web-ui", "password", setting.password, True)
env_file.write_text(env_text, encoding="utf-8")
if BAT_FILE.exists():
for file in os.listdir(Path()):
if file.startswith(FILE_NAME):
Path(file).unlink()
flag_file = Path() / f"{FILE_NAME}_{int(time.time())}"
flag_file.touch()
return Result.ok(BAT_FILE.exists(), info="设置成功,请重启真寻以完成配置!")
@router.get(
"/test_db",
response_model=Result,
response_class=JSONResponse,
description="设置基础配置",
)
async def _(db_url: str) -> Result:
result = await test_db_connection(db_url)
if isinstance(result, str):
return Result.fail(result)
return Result.ok(info="数据库连接成功!")
async def run_restart_command(bat_path: Path, port: int):
"""在后台执行重启命令"""
await asyncio.sleep(1) # 确保 FastAPI 已返回响应
subprocess.Popen([bat_path, str(port)], shell=True) # noqa: ASYNC220
sys.exit(0) # 退出当前进程
@router.post(
"/restart",
response_model=Result,
response_class=JSONResponse,
description="重启",
)
async def _() -> Result:
if not BAT_FILE.exists():
return Result.fail("自动重启仅支持意见整合包,请尝试手动重启")
flag_file = next(
(Path() / file for file in os.listdir(Path()) if file.startswith(FILE_NAME)),
None,
)
if not flag_file or not flag_file.exists():
return Result.fail("重启标志文件不存在...")
set_time = flag_file.name.split("_")[-1]
if time.time() - float(set_time) > 10 * 60:
return Result.fail("重启标志文件已过期,请重新设置配置。")
flag_file.unlink()
try:
return Result.ok(info="执行重启命令成功")
finally:
asyncio.create_task(run_restart_command(BAT_FILE, port)) # noqa: RUF006

View File

@ -0,0 +1,18 @@
from tortoise import Tortoise
async def test_db_connection(db_url: str) -> bool | str:
try:
# 初始化 Tortoise ORM
await Tortoise.init(
db_url=db_url,
modules={"models": ["__main__"]}, # 这里不需要实际模型
)
# 测试连接
await Tortoise.get_connection("default").execute_query("SELECT 1")
return True
except Exception as e:
return str(e)
finally:
# 关闭连接
await Tortoise.close_connections()

View File

@ -0,0 +1,16 @@
from pydantic import BaseModel
class Setting(BaseModel):
superusers: list[str]
"""超级用户列表"""
db_url: str
"""数据库地址"""
host: str
"""主机地址"""
port: int
"""端口"""
username: str
"""前端用户名"""
password: str
"""前端密码"""

View File

@ -5,18 +5,7 @@ from zhenxun.services.log import logger
from .model import MenuData, MenuItem from .model import MenuData, MenuItem
default_menus = [
class MenuManage:
def __init__(self) -> None:
self.file = DATA_PATH / "web_ui" / "menu.json"
self.menu = []
if self.file.exists():
try:
self.menu = json.load(self.file.open(encoding="utf8"))
except Exception as e:
logger.warning("菜单文件损坏,已重新生成...", "WebUi", e=e)
if not self.menu:
self.menu = [
MenuItem( MenuItem(
name="仪表盘", name="仪表盘",
module="dashboard", module="dashboard",
@ -30,30 +19,50 @@ class MenuManage:
router="/command", router="/command",
icon="command", icon="command",
), ),
MenuItem( MenuItem(name="插件列表", module="plugin", router="/plugin", icon="plugin"),
name="插件列表", module="plugin", router="/plugin", icon="plugin" MenuItem(name="插件商店", module="store", router="/store", icon="store"),
), MenuItem(name="好友/群组", module="manage", router="/manage", icon="user"),
MenuItem(
name="插件商店", module="store", router="/store", icon="store"
),
MenuItem(
name="好友/群组", module="manage", router="/manage", icon="user"
),
MenuItem( MenuItem(
name="数据库管理", name="数据库管理",
module="database", module="database",
router="/database", router="/database",
icon="database", icon="database",
), ),
MenuItem( MenuItem(name="系统信息", module="system", router="/system", icon="system"),
name="文件管理", module="system", router="/system", icon="system" MenuItem(name="关于我们", module="about", router="/about", icon="about"),
),
MenuItem(
name="关于我们", module="about", router="/about", icon="about"
),
] ]
class MenuManager:
def __init__(self) -> None:
self.file = DATA_PATH / "web_ui" / "menu.json"
self.menu = []
if self.file.exists():
try:
temp_menu = []
self.menu = json.load(self.file.open(encoding="utf8"))
self_menu_name = [menu["name"] for menu in self.menu]
for module in [m.module for m in default_menus]:
if module in self_menu_name:
temp_menu.append(
MenuItem(
**next(m for m in self.menu if m["module"] == module)
)
)
else:
temp_menu.append(self.__get_menu_model(module))
self.menu = temp_menu
except Exception as e:
logger.warning("菜单文件损坏,已重新生成...", "WebUi", e=e)
if not self.menu:
self.menu = default_menus
self.save() self.save()
def __get_menu_model(self, module: str):
return default_menus[
next(i for i, m in enumerate(default_menus) if m.module == module)
]
def get_menus(self): def get_menus(self):
return MenuData(menus=self.menu) return MenuData(menus=self.menu)
@ -64,4 +73,4 @@ class MenuManage:
json.dump(temp, f, ensure_ascii=False, indent=4) json.dump(temp, f, ensure_ascii=False, indent=4)
menu_manage = MenuManage() menu_manage = MenuManager()

View File

@ -13,6 +13,7 @@ from zhenxun.models.bot_connect_log import BotConnectLog
from zhenxun.models.chat_history import ChatHistory from zhenxun.models.chat_history import ChatHistory
from zhenxun.models.statistics import Statistics from zhenxun.models.statistics import Statistics
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
from ....base_model import BaseResultModel, QueryModel from ....base_model import BaseResultModel, QueryModel
@ -31,7 +32,7 @@ driver: Driver = nonebot.get_driver()
CONNECT_TIME = 0 CONNECT_TIME = 0
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def _(): async def _():
global CONNECT_TIME global CONNECT_TIME
CONNECT_TIME = int(time.time()) CONNECT_TIME = int(time.time())

View File

@ -8,6 +8,7 @@ from zhenxun.configs.config import BotConfig
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo from zhenxun.models.task_info import TaskInfo
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from ....base_model import BaseResultModel, QueryModel, Result from ....base_model import BaseResultModel, QueryModel, Result
from ....utils import authentication from ....utils import authentication
@ -21,7 +22,7 @@ router = APIRouter(prefix="/database")
driver: Driver = nonebot.get_driver() driver: Driver = nonebot.get_driver()
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def _(): async def _():
for plugin in nonebot.get_loaded_plugins(): for plugin in nonebot.get_loaded_plugins():
module = plugin.name module = plugin.name

View File

@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse
from zhenxun.utils._build_image import BuildImage from zhenxun.utils._build_image import BuildImage
from ....base_model import Result, SystemFolderSize from ....base_model import Result, SystemFolderSize
from ....utils import authentication, get_system_disk from ....utils import authentication, get_system_disk, validate_path
from .model import AddFile, DeleteFile, DirFile, RenameFile, SaveFile from .model import AddFile, DeleteFile, DirFile, RenameFile, SaveFile
router = APIRouter(prefix="/system") router = APIRouter(prefix="/system")
@ -25,7 +25,12 @@ IMAGE_TYPE = ["jpg", "jpeg", "png", "gif", "bmp", "webp", "svg"]
description="获取文件列表", description="获取文件列表",
) )
async def _(path: str | None = None) -> Result[list[DirFile]]: async def _(path: str | None = None) -> Result[list[DirFile]]:
base_path = Path(path) if path else Path() try:
base_path, error = validate_path(path)
if error:
return Result.fail(error)
if not base_path:
return Result.fail("无效的路径")
data_list = [] data_list = []
for file in os.listdir(base_path): for file in os.listdir(base_path):
file_path = base_path / file file_path = base_path / file
@ -41,6 +46,8 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
) )
) )
return Result.ok(data_list) return Result.ok(data_list)
except Exception as e:
return Result.fail(f"获取文件列表失败: {e!s}")
@router.get( @router.get(
@ -62,8 +69,12 @@ async def _(full_path: str | None = None) -> Result[list[SystemFolderSize]]:
description="删除文件", description="删除文件",
) )
async def _(param: DeleteFile) -> Result: async def _(param: DeleteFile) -> Result:
path = Path(param.full_path) path, error = validate_path(param.full_path)
if not path or not path.exists(): if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists():
return Result.warning_("文件不存在...") return Result.warning_("文件不存在...")
try: try:
path.unlink() path.unlink()
@ -80,8 +91,12 @@ async def _(param: DeleteFile) -> Result:
description="删除文件夹", description="删除文件夹",
) )
async def _(param: DeleteFile) -> Result: async def _(param: DeleteFile) -> Result:
path = Path(param.full_path) path, error = validate_path(param.full_path)
if not path or not path.exists() or path.is_file(): if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists() or path.is_file():
return Result.warning_("文件夹不存在...") return Result.warning_("文件夹不存在...")
try: try:
shutil.rmtree(path.absolute()) shutil.rmtree(path.absolute())
@ -98,10 +113,14 @@ async def _(param: DeleteFile) -> Result:
description="重命名文件", description="重命名文件",
) )
async def _(param: RenameFile) -> Result: async def _(param: RenameFile) -> Result:
path = ( parent_path, error = validate_path(param.parent)
(Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) if error:
) return Result.fail(error)
if not path or not path.exists(): if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.old_name) if param.parent else Path(param.old_name)
if not path.exists():
return Result.warning_("文件不存在...") return Result.warning_("文件不存在...")
try: try:
path.rename(path.parent / param.name) path.rename(path.parent / param.name)
@ -118,10 +137,14 @@ async def _(param: RenameFile) -> Result:
description="重命名文件夹", description="重命名文件夹",
) )
async def _(param: RenameFile) -> Result: async def _(param: RenameFile) -> Result:
path = ( parent_path, error = validate_path(param.parent)
(Path(param.parent) / param.old_name) if param.parent else Path(param.old_name) if error:
) return Result.fail(error)
if not path or not path.exists() or path.is_file(): if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.old_name) if param.parent else Path(param.old_name)
if not path.exists() or path.is_file():
return Result.warning_("文件夹不存在...") return Result.warning_("文件夹不存在...")
try: try:
new_path = path.parent / param.name new_path = path.parent / param.name
@ -139,7 +162,13 @@ async def _(param: RenameFile) -> Result:
description="新建文件", description="新建文件",
) )
async def _(param: AddFile) -> Result: async def _(param: AddFile) -> Result:
path = (Path(param.parent) / param.name) if param.parent else Path(param.name) parent_path, error = validate_path(param.parent)
if error:
return Result.fail(error)
if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.name) if param.parent else Path(param.name)
if path.exists(): if path.exists():
return Result.warning_("文件已存在...") return Result.warning_("文件已存在...")
try: try:
@ -157,7 +186,13 @@ async def _(param: AddFile) -> Result:
description="新建文件夹", description="新建文件夹",
) )
async def _(param: AddFile) -> Result: async def _(param: AddFile) -> Result:
path = (Path(param.parent) / param.name) if param.parent else Path(param.name) parent_path, error = validate_path(param.parent)
if error:
return Result.fail(error)
if not parent_path:
return Result.fail("无效的路径")
path = (parent_path / param.name) if param.parent else Path(param.name)
if path.exists(): if path.exists():
return Result.warning_("文件夹已存在...") return Result.warning_("文件夹已存在...")
try: try:
@ -175,7 +210,11 @@ async def _(param: AddFile) -> Result:
description="读取文件", description="读取文件",
) )
async def _(full_path: str) -> Result: async def _(full_path: str) -> Result:
path = Path(full_path) path, error = validate_path(full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists(): if not path.exists():
return Result.warning_("文件不存在...") return Result.warning_("文件不存在...")
try: try:
@ -193,9 +232,13 @@ async def _(full_path: str) -> Result:
description="读取文件", description="读取文件",
) )
async def _(param: SaveFile) -> Result[str]: async def _(param: SaveFile) -> Result[str]:
path = Path(param.full_path) path, error = validate_path(param.full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
try: try:
async with aiofiles.open(path, "w", encoding="utf-8") as f: async with aiofiles.open(str(path), "w", encoding="utf-8") as f:
await f.write(param.content) await f.write(param.content)
return Result.ok("更新成功!") return Result.ok("更新成功!")
except Exception as e: except Exception as e:
@ -210,7 +253,11 @@ async def _(param: SaveFile) -> Result[str]:
description="读取图片base64", description="读取图片base64",
) )
async def _(full_path: str) -> Result[str]: async def _(full_path: str) -> Result[str]:
path = Path(full_path) path, error = validate_path(full_path)
if error:
return Result.fail(error)
if not path:
return Result.fail("无效的路径")
if not path.exists(): if not path.exists():
return Result.warning_("文件不存在...") return Result.warning_("文件不存在...")
try: try:

View File

@ -1,5 +1,11 @@
import sys
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import nonebot import nonebot
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum from strenum import StrEnum
from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH from zhenxun.configs.path_config import DATA_PATH, TEMP_PATH

View File

@ -18,6 +18,7 @@ async def update_webui_assets():
download_url = await GithubUtils.parse_github_url( download_url = await GithubUtils.parse_github_url(
WEBUI_DIST_GITHUB_URL WEBUI_DIST_GITHUB_URL
).get_archive_download_urls() ).get_archive_download_urls()
logger.info("开始下载 webui_assets 资源...", COMMAND_NAME)
if await AsyncHttpx.download_file( if await AsyncHttpx.download_file(
download_url, webui_assets_path, follow_redirects=True download_url, webui_assets_path, follow_redirects=True
): ):

View File

@ -2,6 +2,7 @@ import contextlib
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import os import os
from pathlib import Path from pathlib import Path
import re
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
@ -28,6 +29,45 @@ if token_file.exists():
token_data = json.load(open(token_file, encoding="utf8")) token_data = json.load(open(token_file, encoding="utf8"))
def validate_path(path_str: str | None) -> tuple[Path | None, str | None]:
"""验证路径是否安全
参数:
path_str: 用户输入的路径
返回:
tuple[Path | None, str | None]: (验证后的路径, 错误信息)
"""
try:
if not path_str:
return Path().resolve(), None
# 1. 移除任何可能的路径遍历尝试
path_str = re.sub(r"[\\/]\.\.[\\/]", "", path_str)
# 2. 规范化路径并转换为绝对路径
path = Path(path_str).resolve()
# 3. 获取项目根目录
root_dir = Path().resolve()
# 4. 验证路径是否在项目根目录内
try:
if not path.is_relative_to(root_dir):
return None, "访问路径超出允许范围"
except ValueError:
return None, "无效的路径格式"
# 5. 验证路径是否包含任何危险字符
if any(c in str(path) for c in ["..", "~", "*", "?", ">", "<", "|", '"']):
return None, "路径包含非法字符"
# 6. 验证路径长度是否合理
return (None, "路径长度超出限制") if len(str(path)) > 4096 else (path, None)
except Exception as e:
return None, f"路径验证失败: {e!s}"
GROUP_HELP_PATH = DATA_PATH / "group_help" GROUP_HELP_PATH = DATA_PATH / "group_help"
SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png" SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png"
SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png" SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png"

View File

@ -0,0 +1,29 @@
from tortoise import fields
from zhenxun.services.db_context import Model
from zhenxun.utils.enum import BotSentType
class BotMessageStore(Model):
id = fields.IntField(pk=True, generated=True, auto_increment=True)
"""自增id"""
bot_id = fields.CharField(255, null=True)
"""bot id"""
user_id = fields.CharField(255, null=True)
"""目标id"""
group_id = fields.CharField(255, null=True)
"""群组id"""
sent_type = fields.CharEnumField(BotSentType)
"""类型"""
text = fields.TextField(null=True)
"""文本内容"""
plain_text = fields.TextField(null=True)
"""纯文本"""
platform = fields.CharField(255, null=True)
"""平台"""
create_time = fields.DatetimeField(auto_now_add=True)
"""创建时间"""
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
table = "bot_message_store"
table_description = "Bot发送消息列表"

View File

@ -1,9 +1,12 @@
import nonebot
from nonebot.utils import is_coroutine_callable from nonebot.utils import is_coroutine_callable
from tortoise import Tortoise from tortoise import Tortoise
from tortoise.connection import connections from tortoise.connection import connections
from tortoise.models import Model as Model_ from tortoise.models import Model as Model_
from zhenxun.configs.config import BotConfig from zhenxun.configs.config import BotConfig
from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .log import logger from .log import logger
@ -11,6 +14,9 @@ SCRIPT_METHOD = []
MODELS: list[str] = [] MODELS: list[str] = []
driver = nonebot.get_driver()
class Model(Model_): class Model(Model_):
""" """
自动添加模块 自动添加模块
@ -26,7 +32,7 @@ class Model(Model_):
SCRIPT_METHOD.append((cls.__module__, func)) SCRIPT_METHOD.append((cls.__module__, func))
class DbUrlIsNode(Exception): class DbUrlIsNode(HookPriorityException):
""" """
数据库链接地址为空 数据库链接地址为空
""" """
@ -42,9 +48,19 @@ class DbConnectError(Exception):
pass pass
@PriorityLifecycle.on_startup(priority=1)
async def init(): async def init():
if not BotConfig.db_url: if not BotConfig.db_url:
raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...") # raise DbUrlIsNode("数据库配置为空,请在.env.dev中配置DB_URL...")
error = f"""
**********************************************************************
🌟 **************************** 配置为空 ************************* 🌟
🚀 请打开 WebUi 进行基础配置 🚀
🌐 配置地址http://{driver.config.host}:{driver.config.port}/#/configure 🌐
***********************************************************************
***********************************************************************
"""
raise DbUrlIsNode("\n" + error.strip())
try: try:
await Tortoise.init( await Tortoise.init(
db_url=BotConfig.db_url, db_url=BotConfig.db_url,

View File

@ -1,4 +1,4 @@
from datetime import datetime, timedelta from datetime import timedelta
from typing import Any, overload from typing import Any, overload
import nonebot import nonebot
@ -17,7 +17,7 @@ driver = nonebot.get_driver()
log_level = driver.config.log_level or "INFO" log_level = driver.config.log_level or "INFO"
logger_.add( logger_.add(
LOG_PATH / f"{datetime.now().date()}.log", LOG_PATH / "{time:YYYY-MM-DD}.log",
level=log_level, level=log_level,
rotation="00:00", rotation="00:00",
format=default_format, format=default_format,
@ -26,7 +26,7 @@ logger_.add(
) )
logger_.add( logger_.add(
LOG_PATH / f"error_{datetime.now().date()}.log", LOG_PATH / "error_{time:YYYY-MM-DD}.log",
level="ERROR", level="ERROR",
rotation="00:00", rotation="00:00",
format=default_format, format=default_format,
@ -36,13 +36,9 @@ logger_.add(
class logger: class logger:
TEMPLATE_A = "Adapter[{}] {}" """
TEMPLATE_B = "Adapter[{}] [<u><c>{}</c></u>]: {}" 一个经过优化的支持多种上下文和格式的日志记录器
TEMPLATE_C = "Adapter[{}] 用户[<u><e>{}</e></u>] 触发 [<u><c>{}</c></u>]: {}" """
TEMPLATE_D = "Adapter[{}] 群聊[<u><e>{}</e></u>] 用户[<u><e>{}</e></u>] 触发"
" [<u><c>{}</c></u>]: {}"
TEMPLATE_E = "Adapter[{}] 群聊[<u><e>{}</e></u>] 用户[<u><e>{}</e></u>] 触发"
" [<u><c>{}</c></u>] [Target](<u><e>{}</e></u>): {}"
TEMPLATE_ADAPTER = "Adapter[<m>{}</m>]" TEMPLATE_ADAPTER = "Adapter[<m>{}</m>]"
TEMPLATE_USER = "用户[<u><e>{}</e></u>]" TEMPLATE_USER = "用户[<u><e>{}</e></u>]"
@ -50,12 +46,82 @@ class logger:
TEMPLATE_COMMAND = "CMD[<u><c>{}</c></u>]" TEMPLATE_COMMAND = "CMD[<u><c>{}</c></u>]"
TEMPLATE_PLATFORM = "平台[<u><m>{}</m></u>]" TEMPLATE_PLATFORM = "平台[<u><m>{}</m></u>]"
TEMPLATE_TARGET = "[Target]([<u><e>{}</e></u>])" TEMPLATE_TARGET = "[Target]([<u><e>{}</e></u>])"
SUCCESS_TEMPLATE = "[<u><c>{}</c></u>]: {} | 参数[{}] 返回: [<y>{}</y>]" SUCCESS_TEMPLATE = "[<u><c>{}</c></u>]: {} | 参数[{}] 返回: [<y>{}</y>]"
WARNING_TEMPLATE = "[<u><y>{}</y></u>]: {}" @classmethod
def __parser_template(
cls,
info: str,
command: str | None = None,
user_id: int | str | None = None,
group_id: int | str | None = None,
adapter: str | None = None,
target: Any = None,
platform: str | None = None,
) -> str:
"""
优化后的模板解析器构建并连接日志信息片段
"""
parts = []
if adapter:
parts.append(cls.TEMPLATE_ADAPTER.format(adapter))
if platform:
parts.append(cls.TEMPLATE_PLATFORM.format(platform))
if group_id:
parts.append(cls.TEMPLATE_GROUP.format(group_id))
if user_id:
parts.append(cls.TEMPLATE_USER.format(user_id))
if command:
parts.append(cls.TEMPLATE_COMMAND.format(command))
if target:
parts.append(cls.TEMPLATE_TARGET.format(target))
ERROR_TEMPLATE = "[<u><r>{}</r></u>]: {}" parts.append(info)
return " ".join(parts)
@classmethod
def _log(
cls,
level: str,
info: str,
command: str | None = None,
session: int | str | Session | uninfoSession | None = None,
group_id: int | str | None = None,
adapter: str | None = None,
target: Any = None,
platform: str | None = None,
e: Exception | None = None,
):
"""
核心日志处理方法处理所有日志级别的通用逻辑
"""
user_id: str | None = str(session) if isinstance(session, int | str) else None
if isinstance(session, Session):
user_id = session.id1
adapter = session.bot_type
group_id = f"{session.id3}:{session.id2}" if session.id3 else session.id2
platform = platform or session.platform
elif isinstance(session, uninfoSession):
user_id = session.user.id
adapter = session.adapter
if session.group:
group_id = session.group.id
platform = session.basic.get("scope")
template = cls.__parser_template(
info, command, user_id, group_id, adapter, target, platform
)
if e:
template += f" || 错误 <r>{type(e).__name__}: {e}</r>"
try:
log_func = getattr(logger_.opt(colors=True), level)
log_func(template)
except Exception:
log_func_fallback = getattr(logger_, level)
log_func_fallback(template)
@overload @overload
@classmethod @classmethod
@ -70,7 +136,6 @@ class logger:
target: Any = None, target: Any = None,
platform: str | None = None, platform: str | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def info( def info(
@ -82,7 +147,6 @@ class logger:
target: Any = None, target: Any = None,
platform: str | None = None, platform: str | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def info( def info(
@ -107,28 +171,16 @@ class logger:
target: Any = None, target: Any = None,
platform: str | None = None, platform: str | None = None,
): ):
user_id: str | None = session # type: ignore cls._log(
if isinstance(session, Session): "info",
user_id = session.id1 info=info,
adapter = session.bot_type command=command,
if session.id3: session=session,
group_id = f"{session.id3}:{session.id2}" group_id=group_id,
elif session.id2: adapter=adapter,
group_id = f"{session.id2}" target=target,
platform = platform or session.platform platform=platform,
elif isinstance(session, uninfoSession):
user_id = session.user.id
adapter = session.adapter
if session.group:
group_id = session.group.id
platform = session.basic["scope"]
template = cls.__parser_template(
info, command, user_id, group_id, adapter, target, platform
) )
try:
logger_.opt(colors=True).info(template)
except Exception:
logger_.info(template)
@classmethod @classmethod
def success( def success(
@ -138,9 +190,11 @@ class logger:
param: dict[str, Any] | None = None, param: dict[str, Any] | None = None,
result: str = "", result: str = "",
): ):
param_str = "" param_str = (
if param: ",".join([f"<m>{k}</m>:<g>{v}</g>" for k, v in param.items()])
param_str = ",".join([f"<m>{k}</m>:<g>{v}</g>" for k, v in param.items()]) if param
else ""
)
logger_.opt(colors=True).success( logger_.opt(colors=True).success(
cls.SUCCESS_TEMPLATE.format(command, info, param_str, result) cls.SUCCESS_TEMPLATE.format(command, info, param_str, result)
) )
@ -159,7 +213,6 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def warning( def warning(
@ -168,12 +221,10 @@ class logger:
command: str | None = None, command: str | None = None,
*, *,
session: Session | None = None, session: Session | None = None,
adapter: str | None = None,
target: Any = None, target: Any = None,
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def warning( def warning(
@ -182,7 +233,6 @@ class logger:
command: str | None = None, command: str | None = None,
*, *,
session: uninfoSession | None = None, session: uninfoSession | None = None,
adapter: str | None = None,
target: Any = None, target: Any = None,
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
@ -201,30 +251,17 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ):
user_id: str | None = session # type: ignore cls._log(
if isinstance(session, Session): "warning",
user_id = session.id1 info=info,
adapter = session.bot_type command=command,
if session.id3: session=session,
group_id = f"{session.id3}:{session.id2}" group_id=group_id,
elif session.id2: adapter=adapter,
group_id = f"{session.id2}" target=target,
platform = platform or session.platform platform=platform,
elif isinstance(session, uninfoSession): e=e,
user_id = session.user.id
adapter = session.adapter
if session.group:
group_id = session.group.id
platform = session.basic["scope"]
template = cls.__parser_template(
info, command, user_id, group_id, adapter, target, platform
) )
if e:
template += f" || 错误<r>{type(e)}: {e}</r>"
try:
logger_.opt(colors=True).warning(template)
except Exception as e:
logger_.warning(template)
@overload @overload
@classmethod @classmethod
@ -240,7 +277,6 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def error( def error(
@ -253,7 +289,6 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def error( def error(
@ -280,30 +315,17 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ):
user_id: str | None = session # type: ignore cls._log(
if isinstance(session, Session): "error",
user_id = session.id1 info=info,
adapter = session.bot_type command=command,
if session.id3: session=session,
group_id = f"{session.id3}:{session.id2}" group_id=group_id,
elif session.id2: adapter=adapter,
group_id = f"{session.id2}" target=target,
platform = platform or session.platform platform=platform,
elif isinstance(session, uninfoSession): e=e,
user_id = session.user.id
adapter = session.adapter
if session.group:
group_id = session.group.id
platform = session.basic["scope"]
template = cls.__parser_template(
info, command, user_id, group_id, adapter, target, platform
) )
if e:
template += f" || 错误 <r>{type(e)}: {e}</r>"
try:
logger_.opt(colors=True).error(template)
except Exception as e:
logger_.error(template)
@overload @overload
@classmethod @classmethod
@ -319,7 +341,6 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def debug( def debug(
@ -332,7 +353,6 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ... ): ...
@overload @overload
@classmethod @classmethod
def debug( def debug(
@ -359,62 +379,78 @@ class logger:
platform: str | None = None, platform: str | None = None,
e: Exception | None = None, e: Exception | None = None,
): ):
user_id: str | None = session # type: ignore cls._log(
if isinstance(session, Session): "debug",
user_id = session.id1 info=info,
adapter = session.bot_type command=command,
if session.id3: session=session,
group_id = f"{session.id3}:{session.id2}" group_id=group_id,
elif session.id2: adapter=adapter,
group_id = f"{session.id2}" target=target,
platform = platform or session.platform platform=platform,
elif isinstance(session, uninfoSession): e=e,
user_id = session.user.id
adapter = session.adapter
if session.group:
group_id = session.group.id
platform = session.basic["scope"]
template = cls.__parser_template(
info, command, user_id, group_id, adapter, target, platform
) )
if e:
template += f" || 错误 <r>{type(e)}: {e}</r>"
try:
logger_.opt(colors=True).debug(template)
except Exception as e:
logger_.debug(template)
@overload
@classmethod @classmethod
def __parser_template( def trace(
cls, cls,
info: str, info: str,
command: str | None = None, command: str | None = None,
user_id: int | str | None = None, *,
session: int | str | None = None,
group_id: int | str | None = None, group_id: int | str | None = None,
adapter: str | None = None, adapter: str | None = None,
target: Any = None, target: Any = None,
platform: str | None = None, platform: str | None = None,
) -> str: e: Exception | None = None,
arg_list = [] ): ...
template = "" @overload
if adapter is not None: @classmethod
template += cls.TEMPLATE_ADAPTER def trace(
arg_list.append(adapter) cls,
if platform is not None: info: str,
template += cls.TEMPLATE_PLATFORM command: str | None = None,
arg_list.append(platform) *,
if group_id is not None: session: Session | None = None,
template += cls.TEMPLATE_GROUP target: Any = None,
arg_list.append(group_id) platform: str | None = None,
if user_id is not None: e: Exception | None = None,
template += cls.TEMPLATE_USER ): ...
arg_list.append(user_id) @overload
if command is not None: @classmethod
template += cls.TEMPLATE_COMMAND def trace(
arg_list.append(command) cls,
if target is not None: info: str,
template += cls.TEMPLATE_TARGET command: str | None = None,
arg_list.append(target) *,
arg_list.append(info) session: uninfoSession | None = None,
template += "{}" target: Any = None,
return template.format(*arg_list) platform: str | None = None,
e: Exception | None = None,
): ...
@classmethod
def trace(
cls,
info: str,
command: str | None = None,
*,
session: int | str | Session | uninfoSession | None = None,
group_id: int | str | None = None,
adapter: str | None = None,
target: Any = None,
platform: str | None = None,
e: Exception | None = None,
):
cls._log(
"trace",
info=info,
command=command,
session=session,
group_id=group_id,
adapter=adapter,
target=target,
platform=platform,
e=e,
)

View File

@ -6,6 +6,7 @@ from nonebot.utils import is_coroutine_callable
from pydantic import BaseModel from pydantic import BaseModel
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
driver = nonebot.get_driver() driver = nonebot.get_driver()
@ -100,6 +101,6 @@ class PluginInitManager:
logger.error(f"执行: {module_path}:remove 失败", e=e) logger.error(f"执行: {module_path}:remove 失败", e=e)
@driver.on_startup @PriorityLifecycle.on_startup(priority=5)
async def _(): async def _():
await PluginInitManager.install_all() await PluginInitManager.install_all()

View File

@ -1,12 +1,17 @@
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
import random import random
import sys
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from strenum import StrEnum
from ._build_image import BuildImage from ._build_image import BuildImage
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
class MatType(StrEnum): class MatType(StrEnum):
LINE = "LINE" LINE = "LINE"

View File

@ -1,6 +1,23 @@
import sys
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum from strenum import StrEnum
class PriorityLifecycleType(StrEnum):
STARTUP = "STARTUP"
"""启动"""
SHUTDOWN = "SHUTDOWN"
"""关闭"""
class BotSentType(StrEnum):
GROUP = "GROUP"
PRIVATE = "PRIVATE"
class BankHandleType(StrEnum): class BankHandleType(StrEnum):
DEPOSIT = "DEPOSIT" DEPOSIT = "DEPOSIT"
"""存款""" """存款"""

View File

@ -1,3 +1,15 @@
class HookPriorityException(BaseException):
"""
钩子优先级异常
"""
def __init__(self, info: str = "") -> None:
self.info = info
def __str__(self) -> str:
return self.info
class NotFoundError(Exception): class NotFoundError(Exception):
""" """
未发现 未发现

View File

@ -1,13 +1,18 @@
import contextlib import contextlib
import sys
from typing import Protocol from typing import Protocol
from aiocache import cached from aiocache import cached
from nonebot.compat import model_dump from nonebot.compat import model_dump
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from strenum import StrEnum
from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.http_utils import AsyncHttpx
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from strenum import StrEnum
from .const import ( from .const import (
CACHED_API_TTL, CACHED_API_TTL,
GIT_API_COMMIT_FORMAT, GIT_API_COMMIT_FORMAT,

View File

@ -1,217 +1,223 @@
import asyncio import asyncio
from asyncio.exceptions import TimeoutError from collections.abc import AsyncGenerator, Sequence
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
import time import time
from typing import Any, ClassVar, Literal from typing import Any, ClassVar, Literal, cast
import aiofiles import aiofiles
from anyio import EndOfStream
import httpx import httpx
from httpx import ConnectTimeout, HTTPStatusError, Response from httpx import AsyncHTTPTransport, HTTPStatusError, Proxy, Response
from nonebot_plugin_alconna import UniMessage from nonebot_plugin_alconna import UniMessage
from nonebot_plugin_htmlrender import get_browser from nonebot_plugin_htmlrender import get_browser
from playwright.async_api import Page from playwright.async_api import Page
import rich from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TextColumn,
TransferSpeedColumn,
)
from zhenxun.configs.config import BotConfig from zhenxun.configs.config import BotConfig
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.message import MessageUtils from zhenxun.utils.message import MessageUtils
from zhenxun.utils.user_agent import get_user_agent from zhenxun.utils.user_agent import get_user_agent
# from .browser import get_browser CLIENT_KEY = ["use_proxy", "proxies", "proxy", "verify", "headers"]
def get_async_client(
proxies: dict[str, str] | None = None,
proxy: str | None = None,
verify: bool = False,
**kwargs,
) -> httpx.AsyncClient:
transport = kwargs.pop("transport", None) or AsyncHTTPTransport(verify=verify)
if proxies:
http_proxy = proxies.get("http://")
https_proxy = proxies.get("https://")
return httpx.AsyncClient(
mounts={
"http://": AsyncHTTPTransport(
proxy=Proxy(http_proxy) if http_proxy else None
),
"https://": AsyncHTTPTransport(
proxy=Proxy(https_proxy) if https_proxy else None
),
},
transport=transport,
**kwargs,
)
elif proxy:
return httpx.AsyncClient(
mounts={
"http://": AsyncHTTPTransport(proxy=Proxy(proxy)),
"https://": AsyncHTTPTransport(proxy=Proxy(proxy)),
},
transport=transport,
**kwargs,
)
return httpx.AsyncClient(transport=transport, **kwargs)
class AsyncHttpx: class AsyncHttpx:
proxy: ClassVar[dict[str, str | None]] = { default_proxy: ClassVar[dict[str, str] | None] = (
{
"http://": BotConfig.system_proxy, "http://": BotConfig.system_proxy,
"https://": BotConfig.system_proxy, "https://": BotConfig.system_proxy,
} }
if BotConfig.system_proxy
else None
)
@classmethod
@asynccontextmanager
async def _create_client(
cls,
*,
use_proxy: bool = True,
proxies: dict[str, str] | None = None,
proxy: str | None = None,
headers: dict[str, str] | None = None,
verify: bool = False,
**kwargs,
) -> AsyncGenerator[httpx.AsyncClient, None]:
"""创建一个私有的、配置好的 httpx.AsyncClient 上下文管理器。
说明:
此方法用于内部统一创建客户端处理代理和请求头逻辑减少代码重复
参数:
use_proxy: 是否使用在类中定义的默认代理
proxies: 手动指定的代理会覆盖默认代理
proxy: 单个代理,用于兼容旧版本不再使用
headers: 需要合并到客户端的自定义请求头
verify: 是否验证 SSL 证书
**kwargs: 其他所有传递给 httpx.AsyncClient 的参数
返回:
AsyncGenerator[httpx.AsyncClient, None]: 生成器
"""
proxies_to_use = proxies or (cls.default_proxy if use_proxy else None)
final_headers = get_user_agent()
if headers:
final_headers.update(headers)
async with get_async_client(
proxies=proxies_to_use,
proxy=proxy,
verify=verify,
headers=final_headers,
**kwargs,
) as client:
yield client
@classmethod @classmethod
async def get( async def get(
cls, cls,
url: str | list[str], url: str | list[str],
*, *,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
verify: bool = True,
use_proxy: bool = True,
proxy: dict[str, str] | None = None,
timeout: int = 30, # noqa: ASYNC109
check_status_code: int | None = None, check_status_code: int | None = None,
**kwargs, **kwargs,
) -> Response: ) -> Response: # sourcery skip: use-assigned-variable
"""Get """发送 GET 请求,并返回第一个成功的响应。
说明:
本方法是 httpx.get 的高级包装增加了多链接尝试自动重试和统一的代理管理
如果提供 URL 列表它将依次尝试直到成功为止
参数: 参数:
url: url url: 单个请求 URL 或一个 URL 列表
params: params check_status_code: (可选) 若提供将检查响应状态码是否匹配否则抛出异常
headers: 请求头 **kwargs: 其他所有传递给 httpx.get 的参数
cookies: cookies ( `params`, `headers`, `timeout`)
verify: verify
use_proxy: 使用默认代理 返回:
proxy: 指定代理 Response: Response
timeout: 超时时间
check_status_code: 检查状态码
""" """
urls = [url] if isinstance(url, str) else url urls = [url] if isinstance(url, str) else url
return await cls._get_first_successful(
urls,
params=params,
headers=headers,
cookies=cookies,
verify=verify,
use_proxy=use_proxy,
proxy=proxy,
timeout=timeout,
check_status_code=check_status_code,
**kwargs,
)
@classmethod
async def _get_first_successful(
cls,
urls: list[str],
check_status_code: int | None = None,
**kwargs,
) -> Response:
last_exception = None last_exception = None
for url in urls: for current_url in urls:
try: try:
logger.info(f"开始获取 {url}..") logger.info(f"开始获取 {current_url}..")
response = await cls._get_single(url, **kwargs) client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
response = await client.get(current_url, **kwargs)
if check_status_code and response.status_code != check_status_code: if check_status_code and response.status_code != check_status_code:
status_code = response.status_code raise HTTPStatusError(
raise Exception(f"状态码错误:{status_code}!={check_status_code}") f"状态码错误: {response.status_code}!={check_status_code}",
request=response.request,
response=response,
)
return response return response
except Exception as e: except Exception as e:
last_exception = e last_exception = e
if url != urls[-1]: if current_url != urls[-1]:
logger.warning(f"获取 {url} 失败, 尝试下一个") logger.warning(f"获取 {current_url} 失败, 尝试下一个", e=e)
raise last_exception or Exception("All URLs failed")
raise last_exception or Exception("所有URL都获取失败")
@classmethod @classmethod
async def _get_single( async def head(cls, url: str, **kwargs) -> Response:
cls, """发送 HEAD 请求。
url: str,
*,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
verify: bool = True,
use_proxy: bool = True,
proxy: dict[str, str] | None = None,
timeout: int = 30, # noqa: ASYNC109
**kwargs,
) -> Response:
if not headers:
headers = get_user_agent()
_proxy = proxy or (cls.proxy if use_proxy else None)
async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore
return await client.get(
url,
params=params,
headers=headers,
cookies=cookies,
timeout=timeout,
**kwargs,
)
@classmethod
async def head(
cls,
url: str,
*,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
verify: bool = True,
use_proxy: bool = True,
proxy: dict[str, str] | None = None,
timeout: int = 30, # noqa: ASYNC109
**kwargs,
) -> Response:
"""Get
参数:
url: url
params: params
headers: 请求头
cookies: cookies
verify: verify
use_proxy: 使用默认代理
proxy: 指定代理
timeout: 超时时间
"""
if not headers:
headers = get_user_agent()
_proxy = proxy or (cls.proxy if use_proxy else None)
async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore
return await client.head(
url,
params=params,
headers=headers,
cookies=cookies,
timeout=timeout,
**kwargs,
)
@classmethod
async def post(
cls,
url: str,
*,
data: dict[str, Any] | None = None,
content: Any = None,
files: Any = None,
verify: bool = True,
use_proxy: bool = True,
proxy: dict[str, str] | None = None,
json: dict[str, Any] | None = None,
params: dict[str, str] | None = None,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
timeout: int = 30, # noqa: ASYNC109
**kwargs,
) -> Response:
"""
说明: 说明:
Post 本方法是对 httpx.head 的封装通常用于检查资源的元信息如大小类型
参数: 参数:
url: url url: 请求的 URL
data: data **kwargs: 其他所有传递给 httpx.head 的参数
content: content ( `headers`, `timeout`, `allow_redirects`)
files: files
use_proxy: 是否默认代理 返回:
proxy: 指定代理 Response: Response
json: json
params: params
headers: 请求头
cookies: cookies
timeout: 超时时间
""" """
if not headers: client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
headers = get_user_agent() for key in CLIENT_KEY:
_proxy = proxy or (cls.proxy if use_proxy else None) kwargs.pop(key, None)
async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore async with cls._create_client(**client_kwargs) as client:
return await client.post( return await client.head(url, **kwargs)
url,
content=content, @classmethod
data=data, async def post(cls, url: str, **kwargs) -> Response:
files=files, """发送 POST 请求。
json=json,
params=params, 说明:
headers=headers, 本方法是对 httpx.post 的封装提供了统一的代理和客户端管理
cookies=cookies,
timeout=timeout, 参数:
**kwargs, url: 请求的 URL
) **kwargs: 其他所有传递给 httpx.post 的参数
( `data`, `json`, `content` )
返回:
Response: Response
"""
client_kwargs = {k: v for k, v in kwargs.items() if k in CLIENT_KEY}
for key in CLIENT_KEY:
kwargs.pop(key, None)
async with cls._create_client(**client_kwargs) as client:
return await client.post(url, **kwargs)
@classmethod @classmethod
async def get_content(cls, url: str, **kwargs) -> bytes: async def get_content(cls, url: str, **kwargs) -> bytes:
"""获取指定 URL 的二进制内容。
说明:
这是一个便捷方法等同于调用 get() 后再访问 .content 属性
参数:
url: 请求的 URL
**kwargs: 所有传递给 get() 方法的参数
返回:
bytes: 响应内容的二进制字节流 (bytes)
"""
res = await cls.get(url, **kwargs) res = await cls.get(url, **kwargs)
return res.content return res.content
@ -221,195 +227,143 @@ class AsyncHttpx:
url: str | list[str], url: str | list[str],
path: str | Path, path: str | Path,
*, *,
params: dict[str, str] | None = None,
verify: bool = True,
use_proxy: bool = True,
proxy: dict[str, str] | None = None,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
timeout: int = 30, # noqa: ASYNC109
stream: bool = False, stream: bool = False,
follow_redirects: bool = True,
**kwargs, **kwargs,
) -> bool: ) -> bool:
"""下载文件 """下载文件到指定路径。
说明:
支持多链接尝试和流式下载带进度条
参数: 参数:
url: url url: 单个文件 URL 或一个备用 URL 列表
path: 存储路径 path: 文件保存的本地路径
params: params stream: (可选) 是否使用流式下载适用于大文件默认为 False
verify: verify **kwargs: 其他所有传递给 get() 方法或 httpx.stream() 的参数
use_proxy: 使用代理
proxy: 指定代理 返回:
headers: 请求头 bool: 是否下载成功
cookies: cookies
timeout: 超时时间
stream: 是否使用流式下载流式写入+进度条适用于下载大文件
""" """
if isinstance(path, str):
path = Path(path) path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
try:
for _ in range(3): urls = [url] if isinstance(url, str) else url
if not isinstance(url, list):
url = [url] for current_url in urls:
for u in url:
try: try:
if not stream: if not stream:
response = await cls.get( response = await cls.get(current_url, **kwargs)
u,
params=params,
headers=headers,
cookies=cookies,
use_proxy=use_proxy,
proxy=proxy,
timeout=timeout,
follow_redirects=follow_redirects,
**kwargs,
)
response.raise_for_status() response.raise_for_status()
content = response.content async with aiofiles.open(path, "wb") as f:
async with aiofiles.open(path, "wb") as wf: await f.write(response.content)
await wf.write(content)
logger.info(f"下载 {u} 成功.. Path{path.absolute()}")
else: else:
if not headers: async with cls._create_client(**kwargs) as client:
headers = get_user_agent() stream_kwargs = {
_proxy = proxy or (cls.proxy if use_proxy else None) k: v
async with httpx.AsyncClient( for k, v in kwargs.items()
proxies=_proxy, # type: ignore if k not in ["use_proxy", "proxy", "verify"]
verify=verify, }
) as client:
async with client.stream( async with client.stream(
"GET", "GET", current_url, **stream_kwargs
u,
params=params,
headers=headers,
cookies=cookies,
timeout=timeout,
follow_redirects=True,
**kwargs,
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
logger.info( total = int(response.headers.get("Content-Length", 0))
f"开始下载 {path.name}.. "
f"Url: {u}.. " with Progress(
f"Path: {path.absolute()}" TextColumn(path.name),
) "[progress.percentage]{task.percentage:>3.0f}%",
async with aiofiles.open(path, "wb") as wf: BarColumn(bar_width=None),
total = int( DownloadColumn(),
response.headers.get("Content-Length", 0) TransferSpeedColumn(),
)
with rich.progress.Progress( # type: ignore
rich.progress.TextColumn(path.name), # type: ignore
"[progress.percentage]{task.percentage:>3.0f}%", # type: ignore
rich.progress.BarColumn(bar_width=None), # type: ignore
rich.progress.DownloadColumn(), # type: ignore
rich.progress.TransferSpeedColumn(), # type: ignore
) as progress: ) as progress:
download_task = progress.add_task( task_id = progress.add_task("Download", total=total)
"Download", async with aiofiles.open(path, "wb") as f:
total=total or None,
)
async for chunk in response.aiter_bytes(): async for chunk in response.aiter_bytes():
await wf.write(chunk) await f.write(chunk)
await wf.flush() progress.update(task_id, advance=len(chunk))
progress.update(
download_task, logger.info(f"下载 {current_url} 成功 -> {path.absolute()}")
completed=response.num_bytes_downloaded,
)
logger.info(
f"下载 {u} 成功.. Path{path.absolute()}"
)
return True return True
except (TimeoutError, ConnectTimeout, HTTPStatusError):
logger.warning(f"下载 {u} 失败.. 尝试下一个地址..")
except EndOfStream as e:
logger.warning(
f"下载 {url} EndOfStream 异常 Path{path.absolute()}", e=e
)
if path.exists():
return True
logger.error(f"下载 {url} 下载超时.. Path{path.absolute()}")
except Exception as e: except Exception as e:
logger.error(f"下载 {url} 错误 Path{path.absolute()}", e=e) logger.warning(f"下载 {current_url} 失败,尝试下一个。错误: {e}")
logger.error(f"所有URL {urls} 下载均失败 -> {path.absolute()}")
return False return False
@classmethod @classmethod
async def gather_download_file( async def gather_download_file(
cls, cls,
url_list: list[str] | list[list[str]], url_list: Sequence[list[str] | str],
path_list: list[str | Path], path_list: Sequence[str | Path],
*, *,
limit_async_number: int | None = None, limit_async_number: int = 5,
params: dict[str, str] | None = None,
use_proxy: bool = True,
proxy: dict[str, str] | None = None,
headers: dict[str, str] | None = None,
cookies: dict[str, str] | None = None,
timeout: int = 30, # noqa: ASYNC109
**kwargs, **kwargs,
) -> list[bool]: ) -> list[bool]:
"""分组同时下载文件 """并发下载多个文件,支持为每个文件提供备用镜像链接。
说明:
使用 asyncio.Semaphore 来控制并发请求的数量
对于 url_list 中的每个元素如果它是一个列表则会依次尝试直到下载成功
参数: 参数:
url_list: url列表 url_list: 包含所有文件下载任务的列表每个元素可以是
path_list: 存储路径列表 - 一个字符串 (str): 代表该任务的唯一URL
limit_async_number: 限制同时请求数量 - 一个字符串列表 (list[str]): 代表该任务的多个备用/镜像URL
params: params path_list: url_list 对应的文件保存路径列表
use_proxy: 使用代理 limit_async_number: (可选) 最大并发下载数默认为 5
proxy: 指定代理 **kwargs: 其他所有传递给 download_file() 方法的参数
headers: 请求头
cookies: cookies 返回:
timeout: 超时时间 list[bool]: 对应每个下载任务是否成功
""" """
if n := len(url_list) != len(path_list): if len(url_list) != len(path_list):
raise UrlPathNumberNotEqual( raise ValueError("URL 列表和路径列表的长度必须相等")
f"Url数量与Path数量不对等Url{len(url_list)}Path{len(path_list)}"
semaphore = asyncio.Semaphore(limit_async_number)
async def _download_with_semaphore(
urls_for_one_path: str | list[str], path: str | Path
):
async with semaphore:
return await cls.download_file(urls_for_one_path, path, **kwargs)
tasks = [
_download_with_semaphore(url_group, path)
for url_group, path in zip(url_list, path_list)
]
results = await asyncio.gather(*tasks, return_exceptions=True)
final_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
url_info = (
url_list[i]
if isinstance(url_list[i], str)
else ", ".join(url_list[i])
) )
if limit_async_number and n > limit_async_number: logger.error(f"并发下载任务 ({url_info}) 时发生错误", e=result)
m = float(n) / limit_async_number final_results.append(False)
x = 0
j = limit_async_number
_split_url_list = []
_split_path_list = []
for _ in range(int(m)):
_split_url_list.append(url_list[x:j])
_split_path_list.append(path_list[x:j])
x += limit_async_number
j += limit_async_number
if int(m) < m:
_split_url_list.append(url_list[j:])
_split_path_list.append(path_list[j:])
else: else:
_split_url_list = [url_list] # download_file 返回的是 bool可以直接附加
_split_path_list = [path_list] final_results.append(cast(bool, result))
tasks = []
result_ = [] return final_results
for x, y in zip(_split_url_list, _split_path_list):
tasks.extend(
asyncio.create_task(
cls.download_file(
url,
path,
params=params,
headers=headers,
cookies=cookies,
use_proxy=use_proxy,
timeout=timeout,
proxy=proxy,
**kwargs,
)
)
for url, path in zip(x, y)
)
_x = await asyncio.gather(*tasks)
result_ = result_ + list(_x)
tasks.clear()
return result_
@classmethod @classmethod
async def get_fastest_mirror(cls, url_list: list[str]) -> list[str]: async def get_fastest_mirror(cls, url_list: list[str]) -> list[str]:
"""测试并返回最快的镜像地址。
说明:
通过并发发送 HEAD 请求来测试每个 URL 的响应时间和可用性并按响应速度排序
参数:
url_list: 需要测试的镜像 URL 列表
返回:
list[str]: 按从快到慢的顺序包含了所有可用的 URL
"""
assert url_list assert url_list
async def head_mirror(client: type[AsyncHttpx], url: str) -> dict[str, Any]: async def head_mirror(client: type[AsyncHttpx], url: str) -> dict[str, Any]:
@ -478,7 +432,7 @@ class AsyncPlaywright:
wait_until: ( wait_until: (
Literal["domcontentloaded", "load", "networkidle"] | None Literal["domcontentloaded", "load", "networkidle"] | None
) = "networkidle", ) = "networkidle",
timeout: float | None = None, # noqa: ASYNC109 timeout: float | None = None,
type_: Literal["jpeg", "png"] | None = None, type_: Literal["jpeg", "png"] | None = None,
user_agent: str | None = None, user_agent: str | None = None,
cookies: list[dict[str, Any]] | dict[str, Any] | None = None, cookies: list[dict[str, Any]] | dict[str, Any] | None = None,
@ -522,9 +476,5 @@ class AsyncPlaywright:
return None return None
class UrlPathNumberNotEqual(Exception):
pass
class BrowserIsNone(Exception): class BrowserIsNone(Exception):
pass pass

View File

@ -22,6 +22,4 @@ class MessageManager:
@classmethod @classmethod
def get(cls, uid: str) -> list[str]: def get(cls, uid: str) -> list[str]:
if uid in cls.data: return cls.data[uid] if uid in cls.data else []
return cls.data[uid]
return []

View File

@ -0,0 +1,57 @@
from collections.abc import Callable
from typing import ClassVar
import nonebot
from nonebot.utils import is_coroutine_callable
from zhenxun.services.log import logger
from zhenxun.utils.enum import PriorityLifecycleType
from zhenxun.utils.exception import HookPriorityException
driver = nonebot.get_driver()
class PriorityLifecycle:
_data: ClassVar[dict[PriorityLifecycleType, dict[int, list[Callable]]]] = {}
@classmethod
def add(cls, hook_type: PriorityLifecycleType, func: Callable, priority: int):
if hook_type not in cls._data:
cls._data[hook_type] = {}
if priority not in cls._data[hook_type]:
cls._data[hook_type][priority] = []
cls._data[hook_type][priority].append(func)
@classmethod
def on_startup(cls, *, priority: int):
def wrapper(func):
cls.add(PriorityLifecycleType.STARTUP, func, priority)
return func
return wrapper
@classmethod
def on_shutdown(cls, *, priority: int):
def wrapper(func):
cls.add(PriorityLifecycleType.SHUTDOWN, func, priority)
return func
return wrapper
@driver.on_startup
async def _():
priority_data = PriorityLifecycle._data.get(PriorityLifecycleType.STARTUP)
if not priority_data:
return
priority_list = sorted(priority_data.keys())
priority = 0
try:
for priority in priority_list:
for func in priority_data[priority]:
if is_coroutine_callable(func):
await func()
else:
func()
except HookPriorityException as e:
logger.error(f"打断优先级 [{priority}] on_startup 方法. {type(e)}: {e}")

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
import random import random
from typing import Literal from typing import cast
import httpx import httpx
import nonebot import nonebot
@ -486,15 +486,134 @@ class PlatformUtils:
return target return target
class BroadcastEngine:
def __init__(
self,
message: str | UniMessage,
bot: Bot | list[Bot] | None = None,
bot_id: str | set[str] | None = None,
ignore_group: list[str] | None = None,
check_func: Callable[[Bot, str], Awaitable] | None = None,
log_cmd: str | None = None,
platform: str | None = None,
):
"""广播引擎
参数:
message: 广播消息内容
bot: 指定bot对象.
bot_id: 指定bot id.
ignore_group: 忽略群聊列表.
check_func: 发送前对群聊检测方法判断是否发送.
log_cmd: 日志标记.
platform: 指定平台.
异常:
ValueError: 没有可用的Bot对象
"""
if ignore_group is None:
ignore_group = []
self.message = MessageUtils.build_message(message)
self.ignore_group = ignore_group
self.check_func = check_func
self.log_cmd = log_cmd
self.platform = platform
self.bot_list = []
self.count = 0
if bot:
self.bot_list = [bot] if isinstance(bot, Bot) else bot
if isinstance(bot_id, str):
bot_id = set(bot_id)
if bot_id:
for i in bot_id:
try:
self.bot_list.append(nonebot.get_bot(i))
except KeyError:
logger.warning(f"Bot:{i} 对象未连接或不存在")
if not self.bot_list:
raise ValueError("当前没有可用的Bot对象...", log_cmd)
async def call_check(self, bot: Bot, group_id: str) -> bool:
"""运行发送检测函数
参数:
bot: Bot
group_id: 群组id
返回:
bool: 是否发送
"""
if not self.check_func:
return True
if is_coroutine_callable(self.check_func):
is_run = await self.check_func(bot, group_id)
else:
is_run = self.check_func(bot, group_id)
return cast(bool, is_run)
async def __send_message(self, bot: Bot, group: GroupConsole):
"""群组发送消息
参数:
bot: Bot
group: GroupConsole
"""
key = f"{group.group_id}:{group.channel_id}"
if not await self.call_check(bot, group.group_id):
logger.debug(
"广播方法检测运行方法为 False, 已跳过该群组...",
self.log_cmd,
group_id=group.group_id,
)
return
if target := PlatformUtils.get_target(
group_id=group.group_id,
channel_id=group.channel_id,
):
self.ignore_group.append(key)
await MessageUtils.build_message(self.message).send(target, bot)
logger.debug("广播消息发送成功...", self.log_cmd, target=key)
else:
logger.warning("广播消息获取Target失败...", self.log_cmd, target=key)
async def broadcast(self) -> int:
"""广播消息
返回:
int: 成功发送次数
"""
for bot in self.bot_list:
if self.platform and self.platform != PlatformUtils.get_platform(bot):
continue
group_list, _ = await PlatformUtils.get_group_list(bot)
if not group_list:
continue
for group in group_list:
if (
group.group_id in self.ignore_group
or group.channel_id in self.ignore_group
):
continue
try:
await self.__send_message(bot, group)
await asyncio.sleep(random.randint(1, 3))
self.count += 1
except Exception as e:
logger.warning(
"广播消息发送失败", self.log_cmd, target=group.group_id, e=e
)
return self.count
async def broadcast_group( async def broadcast_group(
message: str | UniMessage, message: str | UniMessage,
bot: Bot | list[Bot] | None = None, bot: Bot | list[Bot] | None = None,
bot_id: str | set[str] | None = None, bot_id: str | set[str] | None = None,
ignore_group: set[int] | None = None, ignore_group: list[str] = [],
check_func: Callable[[Bot, str], Awaitable] | None = None, check_func: Callable[[Bot, str], Awaitable] | None = None,
log_cmd: str | None = None, log_cmd: str | None = None,
platform: Literal["qq", "dodo", "kaiheila"] | None = None, platform: str | None = None,
): ) -> int:
"""获取所有Bot或指定Bot对象广播群聊 """获取所有Bot或指定Bot对象广播群聊
参数: 参数:
@ -505,81 +624,18 @@ async def broadcast_group(
check_func: 发送前对群聊检测方法判断是否发送. check_func: 发送前对群聊检测方法判断是否发送.
log_cmd: 日志标记. log_cmd: 日志标记.
platform: 指定平台 platform: 指定平台
返回:
int: 成功发送次数
""" """
if platform and platform not in ["qq", "dodo", "kaiheila"]: if not message.strip():
raise ValueError("指定平台不支持") raise ValueError("群聊广播消息不能为空...")
if not message: return await BroadcastEngine(
raise ValueError("群聊广播消息不能为空") message=message,
bot_dict = nonebot.get_bots() bot=bot,
bot_list: list[Bot] = [] bot_id=bot_id,
if bot: ignore_group=ignore_group,
if isinstance(bot, list): check_func=check_func,
bot_list = bot log_cmd=log_cmd,
else: platform=platform,
bot_list.append(bot) ).broadcast()
elif bot_id:
_bot_id_list = bot_id
if isinstance(bot_id, str):
_bot_id_list = [bot_id]
for id_ in _bot_id_list:
if bot_id in bot_dict:
bot_list.append(bot_dict[bot_id])
else:
logger.warning(f"Bot:{id_} 对象未连接或不存在")
else:
bot_list = list(bot_dict.values())
_used_group = []
for _bot in bot_list:
try:
if platform and platform != PlatformUtils.get_platform(_bot):
continue
group_list, _ = await PlatformUtils.get_group_list(_bot)
if group_list:
for group in group_list:
key = f"{group.group_id}:{group.channel_id}"
try:
if (
ignore_group
and (
group.group_id in ignore_group
or group.channel_id in ignore_group
)
) or key in _used_group:
logger.debug(
"广播方法群组重复, 已跳过...",
log_cmd,
group_id=group.group_id,
)
continue
is_run = False
if check_func:
if is_coroutine_callable(check_func):
is_run = await check_func(_bot, group.group_id)
else:
is_run = check_func(_bot, group.group_id)
if not is_run:
logger.debug(
"广播方法检测运行方法为 False, 已跳过...",
log_cmd,
group_id=group.group_id,
)
continue
target = PlatformUtils.get_target(
user_id=None,
group_id=group.group_id,
channel_id=group.channel_id,
)
if target:
_used_group.append(key)
message_list = message
await MessageUtils.build_message(message_list).send(
target, _bot
)
logger.debug("发送成功", log_cmd, target=key)
await asyncio.sleep(random.randint(1, 3))
else:
logger.warning("target为空", log_cmd, target=key)
except Exception as e:
logger.error("发送失败", log_cmd, target=key, e=e)
except Exception as e:
logger.error(f"Bot: {_bot.self_id} 获取群聊列表失败", command=log_cmd, e=e)