mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
使用缓存cache
This commit is contained in:
parent
f346ee6cf3
commit
fd228b0bc7
10
.env.dev
10
.env.dev
@ -27,6 +27,16 @@ QBOT_ID_DATA = '{
|
||||
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
||||
DB_URL = ""
|
||||
|
||||
# REDIS配置,使用REDIS替换Cache内存缓存
|
||||
# REDIS地址
|
||||
# REDIS_HOST = "127.0.0.1"
|
||||
# REDIS端口
|
||||
# REDIS_PORT = 6379
|
||||
# REDIS密码
|
||||
# REDIS_PASSWORD = ""
|
||||
# REDIS过期时间
|
||||
# REDIS_EXPIRE = 600
|
||||
|
||||
# 系统代理
|
||||
# SYSTEM_PROXY = "http://127.0.0.1:7890"
|
||||
|
||||
|
||||
@ -150,7 +150,7 @@ poetry run python bot.py
|
||||
|
||||
1.在 .env.dev 文件中填写你的机器人配置项
|
||||
|
||||
2.在 data/config.yaml 文件中修改你需要修改的插件配置项
|
||||
2.在 configs/config.yaml 文件中修改你需要修改的插件配置项
|
||||
|
||||
<details>
|
||||
<summary>数据库地址(DB_URL)配置说明</summary>
|
||||
|
||||
951
poetry.lock
generated
951
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -16,7 +16,7 @@ python = "^3.10"
|
||||
playwright = "^1.41.1"
|
||||
nonebot-adapter-onebot = "^2.3.1"
|
||||
nonebot-plugin-apscheduler = "^0.5"
|
||||
tortoise-orm = { extras = ["asyncpg"], version = "^0.20.0" }
|
||||
tortoise-orm = "^0.20.0"
|
||||
cattrs = "^23.2.3"
|
||||
ruamel-yaml = "^0.18.5"
|
||||
strenum = "^0.4.15"
|
||||
@ -39,7 +39,7 @@ dateparser = "^1.2.0"
|
||||
bilireq = "0.2.3post0"
|
||||
python-jose = { extras = ["cryptography"], version = "^3.3.0" }
|
||||
python-multipart = "^0.0.9"
|
||||
aiocache = "^0.12.2"
|
||||
aiocache = {extras = ["redis"], version = "^0.12.3"}
|
||||
py-cpuinfo = "^9.0.0"
|
||||
nonebot-plugin-alconna = "^0.54.0"
|
||||
tenacity = "^9.0.0"
|
||||
@ -47,6 +47,9 @@ nonebot-plugin-uninfo = ">0.4.1"
|
||||
nonebot-plugin-waiter = "^0.8.1"
|
||||
multidict = ">=6.0.0,!=6.3.2"
|
||||
|
||||
redis = { version = ">=5", optional = true }
|
||||
asyncpg = { version = ">=0.20.0", optional = true }
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
nonebug = "^0.4"
|
||||
pytest-cov = "^5.0.0"
|
||||
@ -57,6 +60,9 @@ respx = "^0.21.1"
|
||||
ruff = "^0.8.0"
|
||||
pre-commit = "^4.0.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
redis = ["redis"]
|
||||
postgresql = ["asyncpg"]
|
||||
|
||||
[tool.nonebot]
|
||||
plugins = [
|
||||
|
||||
@ -26,21 +26,6 @@ __plugin_meta__ = PluginMetadata(
|
||||
_matcher = on_alconna(Alconna("关于"), priority=5, block=True, rule=to_me())
|
||||
|
||||
|
||||
QQ_INFO = """
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
""".strip()
|
||||
|
||||
INFO = """
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
项目地址:https://github.com/zhenxun-org/zhenxun_bot
|
||||
文档地址:https://zhenxun-org.github.io/zhenxun_bot/
|
||||
""".strip()
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(session: Uninfo, arparma: Arparma):
|
||||
ver_file = Path() / "__version__"
|
||||
@ -50,11 +35,25 @@ async def _(session: Uninfo, arparma: Arparma):
|
||||
if text := await f.read():
|
||||
version = text.split(":")[-1].strip()
|
||||
if PlatformUtils.is_qbot(session):
|
||||
result: list[str | Path] = [QQ_INFO.format(version=version)]
|
||||
info: list[str | Path] = [
|
||||
f"""
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
""".strip()
|
||||
]
|
||||
path = DATA_PATH / "about.png"
|
||||
if path.exists():
|
||||
result.append(path)
|
||||
await MessageUtils.build_message(result).send() # type: ignore
|
||||
info.append(path)
|
||||
else:
|
||||
await MessageUtils.build_message(INFO.format(version=version)).send()
|
||||
info = [
|
||||
f"""
|
||||
『绪山真寻Bot』
|
||||
版本:{version}
|
||||
简介:基于Nonebot2开发,支持多平台,是一个非常可爱的Bot呀,希望与大家要好好相处
|
||||
项目地址:https://github.com/HibiKier/zhenxun_bot
|
||||
文档地址:https://hibikier.github.io/zhenxun_bot/
|
||||
""".strip()
|
||||
]
|
||||
await MessageUtils.build_message(info).send() # type: ignore
|
||||
logger.info("查看关于", arparma.header_result, session=session)
|
||||
|
||||
@ -14,13 +14,19 @@ from nonebot_plugin_alconna import (
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.configs.utils import (
|
||||
AICallableParam,
|
||||
AICallableProperties,
|
||||
AICallableTag,
|
||||
PluginExtraData,
|
||||
RegisterConfig,
|
||||
)
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.rules import admin_check
|
||||
|
||||
from ._data_source import BanManage
|
||||
from ._data_source import BanManage, call_ban
|
||||
|
||||
base_config = Config.get("ban")
|
||||
|
||||
@ -78,6 +84,22 @@ __plugin_meta__ = PluginMetadata(
|
||||
type=int,
|
||||
)
|
||||
],
|
||||
smart_tools=[
|
||||
AICallableTag(
|
||||
name="call_ban",
|
||||
description="某人多次(至少三次)辱骂你,调用此方法进行封禁",
|
||||
parameters=AICallableParam(
|
||||
type="object",
|
||||
properties={
|
||||
"user_id": AICallableProperties(
|
||||
type="string", description="用户的id"
|
||||
),
|
||||
},
|
||||
required=["user_id"],
|
||||
),
|
||||
func=call_ban,
|
||||
)
|
||||
],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
|
||||
@ -5,8 +5,19 @@ from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
||||
|
||||
async def call_ban(user_id: str):
|
||||
"""调用ban
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
"""
|
||||
await BanConsole.ban(user_id, None, 9, 60 * 12)
|
||||
logger.info("辱骂次数过多,已将用户加入黑名单...", "ban", session=user_id)
|
||||
|
||||
|
||||
|
||||
class BanManage:
|
||||
@classmethod
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
from nonebot import on_message
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.models.chat_history import ChatHistory
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="消息存储",
|
||||
@ -37,18 +39,34 @@ def rule(message: UniMsg) -> bool:
|
||||
|
||||
chat_history = on_message(rule=rule, priority=1, block=False)
|
||||
|
||||
TEMP_LIST = []
|
||||
|
||||
|
||||
@chat_history.handle()
|
||||
async def handle_message(message: UniMsg, session: EventSession):
|
||||
"""处理消息存储"""
|
||||
try:
|
||||
await ChatHistory.create(
|
||||
user_id=session.id1,
|
||||
group_id=session.id2,
|
||||
async def _(message: UniMsg, session: Uninfo):
|
||||
entity = get_entity_ids(session)
|
||||
TEMP_LIST.append(
|
||||
ChatHistory(
|
||||
user_id=entity.user_id,
|
||||
group_id=entity.group_id,
|
||||
text=str(message),
|
||||
plain_text=message.extract_plain_text(),
|
||||
bot_id=session.bot_id,
|
||||
bot_id=session.self_id,
|
||||
platform=session.platform,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@scheduler.scheduled_job(
|
||||
"interval",
|
||||
minutes=1,
|
||||
)
|
||||
async def _():
|
||||
try:
|
||||
message_list = TEMP_LIST.copy()
|
||||
TEMP_LIST.clear()
|
||||
if message_list:
|
||||
await ChatHistory.bulk_create(message_list)
|
||||
logger.debug(f"批量添加聊天记录 {len(message_list)} 条", "定时任务")
|
||||
except Exception as e:
|
||||
logger.warning("存储聊天记录失败", "chat_history", e=e)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta
|
||||
from io import BytesIO
|
||||
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import (
|
||||
@ -14,35 +15,38 @@ from nonebot_plugin_alconna import (
|
||||
from nonebot_plugin_session import EventSession
|
||||
import pytz
|
||||
|
||||
from zhenxun.configs.utils import Command, PluginExtraData
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.utils import Command, PluginExtraData, RegisterConfig
|
||||
from zhenxun.models.chat_history import ChatHistory
|
||||
from zhenxun.models.group_member_info import GroupInfoUser
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.image_utils import ImageTemplate
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="消息统计",
|
||||
description="消息统计查询",
|
||||
usage="""
|
||||
格式:
|
||||
消息排行 ?[type [日,周,月,年]] ?[--des]
|
||||
消息排行 ?[type [日,周,月,季,年]] ?[--des]
|
||||
|
||||
快捷:
|
||||
[日,周,月,年]消息排行 ?[数量]
|
||||
[日,周,月,季,年]消息排行 ?[数量]
|
||||
|
||||
示例:
|
||||
消息排行 : 所有记录排行
|
||||
日消息排行 : 今日记录排行
|
||||
周消息排行 : 今日记录排行
|
||||
月消息排行 : 今日记录排行
|
||||
年消息排行 : 今日记录排行
|
||||
周消息排行 : 本周记录排行
|
||||
月消息排行 : 本月记录排行
|
||||
季消息排行 : 本季度记录排行
|
||||
年消息排行 : 本年记录排行
|
||||
消息排行 周 --des : 逆序周记录排行
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1",
|
||||
version="0.2",
|
||||
plugin_type=PluginType.NORMAL,
|
||||
menu_type="数据统计",
|
||||
commands=[
|
||||
@ -50,8 +54,19 @@ __plugin_meta__ = PluginMetadata(
|
||||
Command(command="日消息统计"),
|
||||
Command(command="周消息排行"),
|
||||
Command(command="月消息排行"),
|
||||
Command(command="季消息排行"),
|
||||
Command(command="年消息排行"),
|
||||
],
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
module="chat_history",
|
||||
key="SHOW_QUIT_MEMBER",
|
||||
value=True,
|
||||
help="是否在消息排行中显示已退群用户",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
@ -60,7 +75,7 @@ _matcher = on_alconna(
|
||||
Alconna(
|
||||
"消息排行",
|
||||
Option("--des", action=store_true, help_text="逆序"),
|
||||
Args["type?", ["日", "周", "月", "年"]]["count?", int, 10],
|
||||
Args["type?", ["日", "周", "月", "季", "年"]]["count?", int, 10],
|
||||
),
|
||||
aliases={"消息统计"},
|
||||
priority=5,
|
||||
@ -68,7 +83,7 @@ _matcher = on_alconna(
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"(?P<type>['日', '周', '月', '年'])?消息(排行|统计)\s?(?P<cnt>\d+)?",
|
||||
r"(?P<type>['日', '周', '月', '季', '年'])?消息(排行|统计)\s?(?P<cnt>\d+)?",
|
||||
command="消息排行",
|
||||
arguments=["{type}", "{cnt}"],
|
||||
prefix=True,
|
||||
@ -96,20 +111,57 @@ async def _(
|
||||
date_scope = (time_now - timedelta(days=7), time_now)
|
||||
elif date in ["月"]:
|
||||
date_scope = (time_now - timedelta(days=30), time_now)
|
||||
column_name = ["名次", "昵称", "发言次数"]
|
||||
elif date in ["季"]:
|
||||
date_scope = (time_now - timedelta(days=90), time_now)
|
||||
column_name = ["名次", "头像", "昵称", "发言次数"]
|
||||
show_quit_member = Config.get_config("chat_history", "SHOW_QUIT_MEMBER", True)
|
||||
|
||||
fetch_count = count.result
|
||||
if not show_quit_member:
|
||||
fetch_count = count.result * 2
|
||||
|
||||
if rank_data := await ChatHistory.get_group_msg_rank(
|
||||
group_id, count.result, "DES" if arparma.find("des") else "DESC", date_scope
|
||||
group_id, fetch_count, "DES" if arparma.find("des") else "DESC", date_scope
|
||||
):
|
||||
idx = 1
|
||||
data_list = []
|
||||
|
||||
for uid, num in rank_data:
|
||||
if user := await GroupInfoUser.filter(
|
||||
if len(data_list) >= count.result:
|
||||
break
|
||||
|
||||
user_in_group = await GroupInfoUser.filter(
|
||||
user_id=uid, group_id=group_id
|
||||
).first():
|
||||
user_name = user.user_name
|
||||
).first()
|
||||
|
||||
if not user_in_group and not show_quit_member:
|
||||
continue
|
||||
|
||||
if user_in_group:
|
||||
user_name = user_in_group.user_name
|
||||
else:
|
||||
user_name = uid
|
||||
data_list.append([idx, user_name, num])
|
||||
user_name = f"{uid}(已退群)"
|
||||
|
||||
avatar_size = 40
|
||||
try:
|
||||
avatar_bytes = await PlatformUtils.get_user_avatar(str(uid), "qq")
|
||||
if avatar_bytes:
|
||||
avatar_img = BuildImage(
|
||||
avatar_size, avatar_size, background=BytesIO(avatar_bytes)
|
||||
)
|
||||
await avatar_img.circle()
|
||||
avatar_tuple = (avatar_img, avatar_size, avatar_size)
|
||||
else:
|
||||
avatar_img = BuildImage(avatar_size, avatar_size, color="#CCCCCC")
|
||||
await avatar_img.circle()
|
||||
avatar_tuple = (avatar_img, avatar_size, avatar_size)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取用户头像失败: {e}", "chat_history")
|
||||
avatar_img = BuildImage(avatar_size, avatar_size, color="#CCCCCC")
|
||||
await avatar_img.circle()
|
||||
avatar_tuple = (avatar_img, avatar_size, avatar_size)
|
||||
|
||||
data_list.append([idx, avatar_tuple, user_name, num])
|
||||
idx += 1
|
||||
if not date_scope:
|
||||
if date_scope := await ChatHistory.get_group_first_msg_datetime(group_id):
|
||||
@ -132,13 +184,3 @@ async def _(
|
||||
)
|
||||
await MessageUtils.build_message(A).finish(reply_to=True)
|
||||
await MessageUtils.build_message("群组消息记录为空...").finish()
|
||||
|
||||
|
||||
# # @test.handle()
|
||||
# # async def _(event: MessageEvent):
|
||||
# # print(await ChatHistory.get_user_msg(event.user_id, "private"))
|
||||
# # print(await ChatHistory.get_user_msg_count(event.user_id, "private"))
|
||||
# # print(await ChatHistory.get_user_msg(event.user_id, "group"))
|
||||
# # print(await ChatHistory.get_user_msg_count(event.user_id, "group"))
|
||||
# # print(await ChatHistory.get_group_msg(event.group_id))
|
||||
# # print(await ChatHistory.get_group_msg_count(event.group_id))
|
||||
|
||||
@ -37,8 +37,8 @@ __plugin_meta__ = PluginMetadata(
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
key="type",
|
||||
value="normal",
|
||||
help="帮助图片样式 ['normal', 'HTML', 'zhenxun']",
|
||||
value="zhenxun",
|
||||
help="帮助图片样式 [normal, HTML, zhenxun]",
|
||||
default_value="zhenxun",
|
||||
)
|
||||
],
|
||||
|
||||
@ -40,7 +40,9 @@ async def create_help_img(
|
||||
|
||||
match help_type:
|
||||
case "html":
|
||||
result = BuildImage.open(await build_html_image(group_id, is_detail))
|
||||
result = BuildImage.open(
|
||||
await build_html_image(session, group_id, is_detail)
|
||||
)
|
||||
case "zhenxun":
|
||||
result = BuildImage.open(
|
||||
await build_zhenxun_image(session, group_id, is_detail)
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.utils.enum import PluginType
|
||||
@ -27,13 +30,15 @@ async def sort_type() -> dict[str, list[PluginInfo]]:
|
||||
|
||||
|
||||
async def classify_plugin(
|
||||
group_id: str | None, is_detail: bool, handle: Callable
|
||||
session: Uninfo, group_id: str | None, is_detail: bool, handle: Callable
|
||||
) -> dict[str, list]:
|
||||
"""对插件进行分类并判断状态
|
||||
|
||||
参数:
|
||||
session: Uninfo对象
|
||||
group_id: 群组id
|
||||
is_detail: 是否详细帮助
|
||||
handle: 回调方法
|
||||
|
||||
返回:
|
||||
dict[str, list[Item]]: 分类插件数据
|
||||
@ -41,9 +46,10 @@ async def classify_plugin(
|
||||
sort_data = await sort_type()
|
||||
classify: dict[str, list] = {}
|
||||
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
|
||||
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
||||
for menu, value in sort_data.items():
|
||||
for plugin in value:
|
||||
if not classify.get(menu):
|
||||
classify[menu] = []
|
||||
classify[menu].append(handle(plugin, group, is_detail))
|
||||
classify[menu].append(handle(bot, plugin, group, is_detail))
|
||||
return classify
|
||||
|
||||
@ -2,9 +2,11 @@ import os
|
||||
import random
|
||||
|
||||
from nonebot_plugin_htmlrender import template_to_pic
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.configs.path_config import TEMPLATE_PATH
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.utils.enum import BlockType
|
||||
@ -48,11 +50,12 @@ ICON2STR = {
|
||||
|
||||
|
||||
def __handle_item(
|
||||
plugin: PluginInfo, group: GroupConsole | None, is_detail: bool
|
||||
bot: BotConsole, plugin: PluginInfo, group: GroupConsole | None, is_detail: bool
|
||||
) -> Item:
|
||||
"""构造Item
|
||||
|
||||
参数:
|
||||
bot: BotConsole
|
||||
plugin: PluginInfo
|
||||
group: 群组
|
||||
is_detail: 是否详细
|
||||
@ -73,10 +76,13 @@ def __handle_item(
|
||||
]:
|
||||
sta = 2
|
||||
if group:
|
||||
if f"{plugin.module}:super," in group.block_plugin:
|
||||
if f"{plugin.module}," in group.superuser_block_plugin:
|
||||
sta = 2
|
||||
if f"{plugin.module}," in group.block_plugin:
|
||||
sta = 1
|
||||
if bot:
|
||||
if f"{plugin.module}," in bot.block_plugins:
|
||||
sta = 2
|
||||
return Item(plugin_name=plugin.name, sta=sta)
|
||||
|
||||
|
||||
@ -119,14 +125,17 @@ def build_plugin_data(classify: dict[str, list[Item]]) -> list[dict[str, str]]:
|
||||
return plugin_list
|
||||
|
||||
|
||||
async def build_html_image(group_id: str | None, is_detail: bool) -> bytes:
|
||||
async def build_html_image(
|
||||
session: Uninfo, group_id: str | None, is_detail: bool
|
||||
) -> bytes:
|
||||
"""构造HTML帮助图片
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
group_id: 群号
|
||||
is_detail: 是否详细帮助
|
||||
"""
|
||||
classify = await classify_plugin(group_id, is_detail, __handle_item)
|
||||
classify = await classify_plugin(session, group_id, is_detail, __handle_item)
|
||||
plugin_list = build_plugin_data(classify)
|
||||
return await template_to_pic(
|
||||
template_path=str((TEMPLATE_PATH / "menu").absolute()),
|
||||
|
||||
@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.configs.path_config import TEMPLATE_PATH
|
||||
from zhenxun.configs.utils import PluginExtraData
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.utils.enum import BlockType
|
||||
@ -21,12 +22,19 @@ class Item(BaseModel):
|
||||
"""插件命令"""
|
||||
|
||||
|
||||
def __handle_item(plugin: PluginInfo, group: GroupConsole | None, is_detail: bool):
|
||||
def __handle_item(
|
||||
bot: BotConsole | None,
|
||||
plugin: PluginInfo,
|
||||
group: GroupConsole | None,
|
||||
is_detail: bool,
|
||||
):
|
||||
"""构造Item
|
||||
|
||||
参数:
|
||||
bot: BotConsole
|
||||
plugin: PluginInfo
|
||||
group: 群组
|
||||
is_detail: 是否为详细
|
||||
|
||||
返回:
|
||||
Item: Item
|
||||
@ -40,6 +48,8 @@ def __handle_item(plugin: PluginInfo, group: GroupConsole | None, is_detail: boo
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
elif group and f"{plugin.module}," in group.block_plugin:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
elif bot and f"{plugin.module}," in bot.block_plugins:
|
||||
plugin.name = f"{plugin.name}(不可用)"
|
||||
commands = []
|
||||
nb_plugin = nonebot.get_plugin_by_module_name(plugin.module_path)
|
||||
if is_detail and nb_plugin and nb_plugin.metadata and nb_plugin.metadata.extra:
|
||||
@ -142,7 +152,7 @@ async def build_zhenxun_image(
|
||||
group_id: 群号
|
||||
is_detail: 是否详细帮助
|
||||
"""
|
||||
classify = await classify_plugin(group_id, is_detail, __handle_item)
|
||||
classify = await classify_plugin(session, group_id, is_detail, __handle_item)
|
||||
plugin_list = build_plugin_data(classify)
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
bot_id = BotConfig.get_qbot_uid(session.self_id) or session.self_id
|
||||
|
||||
@ -49,4 +49,14 @@ Config.add_plugin_config(
|
||||
type=bool,
|
||||
)
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"RECORD_BOT_SENT_MESSAGES",
|
||||
True,
|
||||
help="记录bot消息校内",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
|
||||
|
||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||
|
||||
@ -1,597 +0,0 @@
|
||||
from typing import ClassVar
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import At, UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from pydantic import BaseModel
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import (
|
||||
BlockType,
|
||||
GoldHandle,
|
||||
LimitWatchType,
|
||||
PluginLimitType,
|
||||
PluginType,
|
||||
)
|
||||
from zhenxun.utils.exception import InsufficientGold
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import CountLimiter, FreqLimiter, UserBlockLimiter
|
||||
|
||||
base_config = Config.get("hook")
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
limit: PluginLimit
|
||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LimitManage:
|
||||
add_module: ClassVar[list] = []
|
||||
|
||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||
|
||||
@classmethod
|
||||
def add_limit(cls, limit: PluginLimit):
|
||||
"""添加限制
|
||||
|
||||
参数:
|
||||
limit: PluginLimit
|
||||
"""
|
||||
if limit.module not in cls.add_module:
|
||||
cls.add_module.append(limit.module)
|
||||
if limit.limit_type == PluginLimitType.BLOCK:
|
||||
cls.block_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=UserBlockLimiter()
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.CD:
|
||||
cls.cd_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.COUNT:
|
||||
cls.count_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unblock(
|
||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||
):
|
||||
"""解除插件block
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
"""
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
limit = limit_model.limit
|
||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
logger.debug(
|
||||
f"解除对象: {key_type} 的block限制",
|
||||
"AuthChecker",
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
limiter.set_false(key_type)
|
||||
|
||||
@classmethod
|
||||
async def check(
|
||||
cls,
|
||||
module: str,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: EventSession,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if limit_model := cls.cd_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
if limit_model := cls.count_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id, session)
|
||||
|
||||
@classmethod
|
||||
async def __check(
|
||||
cls,
|
||||
limit_model: Limit | None,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
session: EventSession,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
limit_model: Limit
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
session: Session
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if not limit_model:
|
||||
return
|
||||
limit = limit_model.limit
|
||||
limiter = limit_model.limiter
|
||||
is_limit = (
|
||||
LimitWatchType.ALL
|
||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||
)
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
if is_limit and not limiter.check(key_type):
|
||||
if limit.result:
|
||||
await MessageUtils.build_message(limit.result).send()
|
||||
logger.debug(
|
||||
f"{limit.module}({limit.limit_type}) 正在限制中...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{limit.module} 正在限制中...")
|
||||
else:
|
||||
logger.debug(
|
||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||
"AuthChecker",
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
limiter.start_cd(key_type)
|
||||
if isinstance(limiter, UserBlockLimiter):
|
||||
limiter.set_true(key_type)
|
||||
if isinstance(limiter, CountLimiter):
|
||||
limiter.increase(key_type)
|
||||
|
||||
|
||||
class IsSuperuserException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthChecker:
|
||||
"""
|
||||
权限检查
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||
|
||||
def is_send_limit_message(self, plugin: PluginInfo, sid: str) -> bool:
|
||||
"""是否发送提示消息
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
返回:
|
||||
bool: 是否发送提示消息
|
||||
"""
|
||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||
return False
|
||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||
return False
|
||||
if plugin.ignore_prompt:
|
||||
return False
|
||||
return self._flmt_s.check(sid)
|
||||
|
||||
async def auth(
|
||||
self,
|
||||
matcher: Matcher,
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
session: EventSession,
|
||||
message: UniMsg,
|
||||
):
|
||||
"""权限检查
|
||||
|
||||
参数:
|
||||
matcher: matcher
|
||||
bot: bot
|
||||
session: EventSession
|
||||
message: UniMsg
|
||||
"""
|
||||
is_ignore = False
|
||||
cost_gold = 0
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if matcher.type == "notice" and not isinstance(event, PokeNotifyEvent):
|
||||
"""过滤除poke外的notice"""
|
||||
return
|
||||
if user_id and matcher.plugin and (module_path := matcher.plugin.module_name):
|
||||
try:
|
||||
user = await UserConsole.get_user(user_id, session.platform)
|
||||
except IntegrityError as e:
|
||||
logger.debug(
|
||||
"重复创建用户,已跳过该次权限...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
return
|
||||
if plugin := await PluginInfo.get_or_none(module_path=module_path):
|
||||
if plugin.plugin_type == PluginType.HIDDEN:
|
||||
logger.debug(
|
||||
f"插件: {plugin.name}:{plugin.module} "
|
||||
"为HIDDEN,已跳过权限检查..."
|
||||
)
|
||||
return
|
||||
try:
|
||||
cost_gold = await self.auth_cost(user, plugin, session)
|
||||
if session.id1 in bot.config.superusers:
|
||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||
raise IsSuperuserException()
|
||||
if not plugin.limit_superuser:
|
||||
cost_gold = 0
|
||||
raise IsSuperuserException()
|
||||
await self.auth_bot(plugin, bot.self_id)
|
||||
await self.auth_group(plugin, session, message)
|
||||
await self.auth_admin(plugin, session)
|
||||
await self.auth_plugin(plugin, session, event)
|
||||
await self.auth_limit(plugin, session)
|
||||
except IsSuperuserException:
|
||||
logger.debug(
|
||||
"超级用户或被ban跳过权限检测...", "AuthChecker", session=session
|
||||
)
|
||||
except IgnoredException:
|
||||
is_ignore = True
|
||||
LimitManage.unblock(
|
||||
matcher.plugin.name, user_id, group_id, channel_id
|
||||
)
|
||||
except AssertionError as e:
|
||||
is_ignore = True
|
||||
logger.debug("消息无法发送", session=session, e=e)
|
||||
if cost_gold and user_id:
|
||||
"""花费金币"""
|
||||
try:
|
||||
await UserConsole.reduce_gold(
|
||||
user_id,
|
||||
cost_gold,
|
||||
GoldHandle.PLUGIN,
|
||||
matcher.plugin.name if matcher.plugin else "",
|
||||
session.platform,
|
||||
)
|
||||
except InsufficientGold:
|
||||
if u := await UserConsole.get_user(user_id):
|
||||
u.gold = 0
|
||||
await u.save(update_fields=["gold"])
|
||||
logger.debug(
|
||||
f"调用功能花费金币: {cost_gold}", "AuthChecker", session=session
|
||||
)
|
||||
if is_ignore:
|
||||
raise IgnoredException("权限检测 ignore")
|
||||
|
||||
async def auth_bot(self, plugin: PluginInfo, bot_id: str):
|
||||
"""机器人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
bot_id: bot_id
|
||||
"""
|
||||
if not await BotConsole.get_bot_status(bot_id):
|
||||
logger.debug("Bot休眠中阻断权限检测...", "AuthChecker")
|
||||
raise IgnoredException("BotConsole休眠权限检测 ignore")
|
||||
if await BotConsole.is_block_plugin(bot_id, plugin.module):
|
||||
logger.debug(
|
||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭...",
|
||||
"AuthChecker",
|
||||
)
|
||||
raise IgnoredException("BotConsole插件权限检测 ignore")
|
||||
|
||||
async def auth_limit(self, plugin: PluginInfo, session: EventSession):
|
||||
"""插件限制
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if plugin.module not in LimitManage.add_module:
|
||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
||||
status=True
|
||||
).all() # type: ignore
|
||||
for limit in limit_list:
|
||||
LimitManage.add_limit(limit)
|
||||
if user_id:
|
||||
await LimitManage.check(
|
||||
plugin.module, user_id, group_id, channel_id, session
|
||||
)
|
||||
|
||||
async def auth_plugin(
|
||||
self, plugin: PluginInfo, session: EventSession, event: Event
|
||||
):
|
||||
"""插件状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
channel_id = None
|
||||
if user_id := session.id1:
|
||||
if plugin.impression > 0:
|
||||
sign_user = await SignUser.get_user(user_id)
|
||||
if float(sign_user.impression) < plugin.impression:
|
||||
if self.is_send_limit_message(plugin, user_id):
|
||||
self._flmt_s.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
f"好感度不足哦,当前功能需要好感度: {plugin.impression},"
|
||||
"请继续签到提升好感度吧!"
|
||||
).send(reply_to=True)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 用户好感度不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("好感度不足...")
|
||||
if group_id:
|
||||
sid = group_id or user_id
|
||||
if await GroupConsole.is_superuser_block_plugin(
|
||||
group_id, plugin.module
|
||||
):
|
||||
"""超级用户群组插件状态"""
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message(
|
||||
"超级管理员禁用了该群此功能..."
|
||||
).send(reply_to=True)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("超级管理员禁用了该群此功能...")
|
||||
if await GroupConsole.is_normal_block_plugin(group_id, plugin.module):
|
||||
"""群组插件状态"""
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message("该群未开启此功能...").send(
|
||||
reply_to=True
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 未开启此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该群未开启此功能...")
|
||||
if plugin.block_type == BlockType.GROUP:
|
||||
"""全局群组禁用"""
|
||||
try:
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_c.start_cd(group_id)
|
||||
await MessageUtils.build_message(
|
||||
"该功能在群组中已被禁用..."
|
||||
).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_plugin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该插件在群组中已被禁用...")
|
||||
else:
|
||||
sid = user_id
|
||||
if plugin.block_type == BlockType.PRIVATE:
|
||||
"""全局私聊禁用"""
|
||||
try:
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_c.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
"该功能在私聊中已被禁用..."
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("该插件在私聊中已被禁用...")
|
||||
if not plugin.status and plugin.block_type == BlockType.ALL:
|
||||
"""全局状态"""
|
||||
if group_id and await GroupConsole.is_super_group(group_id):
|
||||
raise IsSuperuserException()
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 全局未开启此功能...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
if self.is_send_limit_message(plugin, sid):
|
||||
self._flmt_s.start_cd(group_id or user_id)
|
||||
await MessageUtils.build_message("全局未开启此功能...").send()
|
||||
raise IgnoredException("全局未开启此功能...")
|
||||
|
||||
async def auth_admin(self, plugin: PluginInfo, session: EventSession):
|
||||
"""管理员命令 个人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
"""
|
||||
user_id = session.id1
|
||||
if user_id and plugin.admin_level:
|
||||
if group_id := session.id3 or session.id2:
|
||||
if not await LevelUser.check_level(
|
||||
user_id, group_id, plugin.admin_level
|
||||
):
|
||||
try:
|
||||
if self._flmt.check(user_id):
|
||||
self._flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"你的权限不足喔,"
|
||||
f"该功能需要的权限等级: {plugin.admin_level}",
|
||||
]
|
||||
).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("管理员权限不足...")
|
||||
elif not await LevelUser.check_level(user_id, None, plugin.admin_level):
|
||||
try:
|
||||
await MessageUtils.build_message(
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}"
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_admin 发送消息失败", "AuthChecker", session=session, e=e
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("权限不足")
|
||||
|
||||
async def auth_group(
|
||||
self, plugin: PluginInfo, session: EventSession, message: UniMsg
|
||||
):
|
||||
"""群黑名单检测 群总开关检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
message: UniMsg
|
||||
"""
|
||||
if not (group_id := session.id3 or session.id2):
|
||||
return
|
||||
text = message.extract_plain_text()
|
||||
group = await GroupConsole.get_group(group_id)
|
||||
if not group:
|
||||
"""群不存在"""
|
||||
logger.debug(
|
||||
"群组信息不存在...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("群不存在")
|
||||
if group.level < 0:
|
||||
"""群权限小于0"""
|
||||
logger.debug(
|
||||
"群黑名单, 群权限-1...",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException("群黑名单")
|
||||
if not group.status:
|
||||
"""群休眠"""
|
||||
if text.strip() != "醒来":
|
||||
logger.debug("群休眠状态...", "AuthChecker", session=session)
|
||||
raise IgnoredException("群休眠状态")
|
||||
if plugin.level > group.level:
|
||||
"""插件等级大于群等级"""
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 群等级限制.."
|
||||
f"该功能需要的群等级: {plugin.level}..",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 群等级限制...")
|
||||
|
||||
async def auth_cost(
|
||||
self, user: UserConsole, plugin: PluginInfo, session: EventSession
|
||||
) -> int:
|
||||
"""检测是否满足金币条件
|
||||
|
||||
参数:
|
||||
user: UserConsole
|
||||
plugin: PluginInfo
|
||||
session: EventSession
|
||||
|
||||
返回:
|
||||
int: 需要消耗的金币
|
||||
"""
|
||||
if user.gold < plugin.cost_gold:
|
||||
"""插件消耗金币不足"""
|
||||
try:
|
||||
await MessageUtils.build_message(
|
||||
f"金币不足..该功能需要{plugin.cost_gold}金币.."
|
||||
).send()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"auth_cost 发送消息失败", "AuthChecker", session=session, e=e
|
||||
)
|
||||
logger.debug(
|
||||
f"{plugin.name}({plugin.module}) 金币限制.."
|
||||
f"该功能需要{plugin.cost_gold}金币..",
|
||||
"AuthChecker",
|
||||
session=session,
|
||||
)
|
||||
raise IgnoredException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||
return plugin.cost_gold
|
||||
|
||||
|
||||
checker = AuthChecker()
|
||||
52
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
52
zhenxun/builtin_plugins/hooks/auth/auth_admin.py
Normal file
@ -0,0 +1,52 @@
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.utils.enum import CacheType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .exception import SkipPluginException
|
||||
from .utils import send_message
|
||||
|
||||
|
||||
async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
||||
"""管理员命令 个人权限
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
"""
|
||||
if not plugin.admin_level:
|
||||
return
|
||||
entity = get_entity_ids(session)
|
||||
cache = Cache[list[LevelUser]](CacheType.LEVEL)
|
||||
user_list = await cache.get(session.user.id) or []
|
||||
if entity.group_id:
|
||||
user_list += await cache.get(session.user.id, entity.group_id) or []
|
||||
if user_list:
|
||||
user = max(user_list, key=lambda x: x.user_level)
|
||||
user_level = user.user_level
|
||||
else:
|
||||
user_level = 0
|
||||
if user_level < plugin.admin_level:
|
||||
await send_message(
|
||||
session,
|
||||
[
|
||||
At(flag="user", target=session.user.id),
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||
],
|
||||
entity.user_id,
|
||||
)
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||
)
|
||||
elif user_list:
|
||||
user = max(user_list, key=lambda x: x.user_level)
|
||||
if user.user_level < plugin.admin_level:
|
||||
await send_message(
|
||||
session,
|
||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||
)
|
||||
raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...")
|
||||
175
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
175
zhenxun/builtin_plugins/hooks/auth/auth_ban.py
Normal file
@ -0,0 +1,175 @@
|
||||
import asyncio
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from tortoise.exceptions import MultipleObjectsReturned
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType, PluginType
|
||||
from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
||||
|
||||
from .config import LOGGER_COMMAND
|
||||
from .exception import SkipPluginException
|
||||
from .utils import freq, send_message
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"BAN_RESULT",
|
||||
"才不会给你发消息.",
|
||||
help="对被ban用户发送的消息",
|
||||
)
|
||||
|
||||
|
||||
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||
if not user_id and not group_id:
|
||||
return 0
|
||||
cache = Cache[BanConsole](CacheType.BAN)
|
||||
group_user, user = await asyncio.gather(
|
||||
cache.get(user_id, group_id), cache.get(user_id)
|
||||
)
|
||||
results = []
|
||||
if group_user:
|
||||
results.append(group_user)
|
||||
if user:
|
||||
results.append(user)
|
||||
if not results:
|
||||
return 0
|
||||
for result in results:
|
||||
if result.duration > 0 or result.duration == -1:
|
||||
return await BanConsole.check_ban_time(user_id, group_id)
|
||||
return 0
|
||||
|
||||
|
||||
def check_plugin_type(matcher: Matcher) -> bool:
|
||||
"""判断插件类型是否是隐藏插件
|
||||
|
||||
参数:
|
||||
matcher: Matcher
|
||||
|
||||
返回:
|
||||
bool: 是否为隐藏插件
|
||||
"""
|
||||
if plugin := matcher.plugin:
|
||||
if metadata := plugin.metadata:
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def format_time(time: float) -> str:
|
||||
"""格式化时间
|
||||
|
||||
参数:
|
||||
time: ban时长
|
||||
|
||||
返回:
|
||||
str: 格式化时间文本
|
||||
"""
|
||||
if time == -1:
|
||||
return "∞"
|
||||
time = abs(int(time))
|
||||
if time < 60:
|
||||
time_str = f"{time!s} 秒"
|
||||
else:
|
||||
minute = int(time / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
return time_str
|
||||
|
||||
|
||||
async def group_handle(cache: Cache[list[BanConsole]], group_id: str):
|
||||
"""群组ban检查
|
||||
|
||||
参数:
|
||||
cache: cache
|
||||
group_id: 群组id
|
||||
|
||||
异常:
|
||||
SkipPluginException: 群组处于黑名单
|
||||
"""
|
||||
try:
|
||||
if await is_ban(None, group_id):
|
||||
raise SkipPluginException("群组处于黑名单中...")
|
||||
except MultipleObjectsReturned:
|
||||
logger.warning(
|
||||
"群组黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND
|
||||
)
|
||||
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
await BanConsole.filter(id__in=ids[:-1]).delete()
|
||||
await cache.reload()
|
||||
|
||||
|
||||
async def user_handle(
|
||||
module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo
|
||||
):
|
||||
"""用户ban检查
|
||||
|
||||
参数:
|
||||
module: 插件模块名
|
||||
cache: cache
|
||||
user_id: 用户id
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
SkipPluginException: 用户处于黑名单
|
||||
"""
|
||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||
try:
|
||||
time = await is_ban(entity.user_id, entity.group_id)
|
||||
if not time:
|
||||
return
|
||||
time_str = format_time(time)
|
||||
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
||||
if (
|
||||
db_plugin
|
||||
# and not db_plugin.ignore_prompt
|
||||
and time != -1
|
||||
and ban_result
|
||||
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||
):
|
||||
await send_message(
|
||||
session,
|
||||
[
|
||||
At(flag="user", target=entity.user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
],
|
||||
entity.user_id,
|
||||
)
|
||||
raise SkipPluginException("用户处于黑名单中...")
|
||||
except MultipleObjectsReturned:
|
||||
logger.warning(
|
||||
"用户黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND
|
||||
)
|
||||
ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
await BanConsole.filter(id__in=ids[:-1]).delete()
|
||||
await cache.reload()
|
||||
|
||||
|
||||
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
|
||||
if not check_plugin_type(matcher):
|
||||
return
|
||||
if not matcher.plugin_name:
|
||||
return
|
||||
entity = get_entity_ids(session)
|
||||
if entity.user_id in bot.config.superusers:
|
||||
return
|
||||
cache = Cache[list[BanConsole]](CacheType.BAN)
|
||||
if entity.group_id:
|
||||
await group_handle(cache, entity.group_id)
|
||||
if entity.user_id:
|
||||
await user_handle(matcher.plugin_name, cache, entity, session)
|
||||
28
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
28
zhenxun/builtin_plugins/hooks/auth/auth_bot.py
Normal file
@ -0,0 +1,28 @@
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
from .exception import SkipPluginException
|
||||
|
||||
|
||||
async def auth_bot(plugin: PluginInfo, bot_id: str):
|
||||
"""bot层面的权限检查
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
bot_id: bot id
|
||||
|
||||
异常:
|
||||
SkipPluginException: 忽略插件
|
||||
SkipPluginException: 忽略插件
|
||||
"""
|
||||
if cache := Cache[BotConsole](CacheType.BOT):
|
||||
bot = await cache.get(bot_id)
|
||||
if not bot or not bot.status:
|
||||
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
|
||||
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
||||
raise SkipPluginException(
|
||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
|
||||
)
|
||||
24
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
24
zhenxun/builtin_plugins/hooks/auth/auth_cost.py
Normal file
@ -0,0 +1,24 @@
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
|
||||
from .exception import SkipPluginException
|
||||
from .utils import send_message
|
||||
|
||||
|
||||
async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> int:
|
||||
"""检测是否满足金币条件
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
|
||||
返回:
|
||||
int: 需要消耗的金币
|
||||
"""
|
||||
if user.gold < plugin.cost_gold:
|
||||
"""插件消耗金币不足"""
|
||||
await send_message(session, f"金币不足..该功能需要{plugin.cost_gold}金币..")
|
||||
raise SkipPluginException(f"{plugin.name}({plugin.module}) 金币限制...")
|
||||
return plugin.cost_gold
|
||||
35
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/auth_group.py
Normal file
@ -0,0 +1,35 @@
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.utils.enum import CacheType
|
||||
from zhenxun.utils.utils import EntityIDs
|
||||
|
||||
from .config import SwitchEnum
|
||||
from .exception import SkipPluginException
|
||||
|
||||
|
||||
async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
||||
"""群黑名单检测 群总开关检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
entity: EntityIDs
|
||||
message: UniMsg
|
||||
"""
|
||||
if not entity.group_id:
|
||||
return
|
||||
text = message.extract_plain_text()
|
||||
group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id)
|
||||
if not group:
|
||||
raise SkipPluginException("群组信息不存在...")
|
||||
if group.level < 0:
|
||||
raise SkipPluginException("群组黑名单, 目标群组群权限权限-1...")
|
||||
if text.strip() != SwitchEnum.ENABLE and not group.status:
|
||||
raise SkipPluginException("群组休眠状态...")
|
||||
if plugin.level > group.level:
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 群等级限制,"
|
||||
f"该功能需要的群等级: {plugin.level}..."
|
||||
)
|
||||
194
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
194
zhenxun/builtin_plugins/hooks/auth/auth_limit.py
Normal file
@ -0,0 +1,194 @@
|
||||
from typing import ClassVar
|
||||
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import LimitWatchType, PluginLimitType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import (
|
||||
CountLimiter,
|
||||
FreqLimiter,
|
||||
UserBlockLimiter,
|
||||
get_entity_ids,
|
||||
)
|
||||
|
||||
from .config import LOGGER_COMMAND
|
||||
from .exception import SkipPluginException
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
"""初始化限制"""
|
||||
await LimitManager.init_limit()
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
limit: PluginLimit
|
||||
limiter: FreqLimiter | UserBlockLimiter | CountLimiter
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LimitManager:
|
||||
add_module: ClassVar[list] = []
|
||||
|
||||
cd_limit: ClassVar[dict[str, Limit]] = {}
|
||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||
|
||||
@classmethod
|
||||
async def init_limit(cls):
|
||||
"""初始化限制"""
|
||||
limit_list = await PluginLimit.filter(status=True).all()
|
||||
for limit in limit_list:
|
||||
cls.add_limit(limit)
|
||||
|
||||
@classmethod
|
||||
def add_limit(cls, limit: PluginLimit):
|
||||
"""添加限制
|
||||
|
||||
参数:
|
||||
limit: PluginLimit
|
||||
"""
|
||||
if limit.module not in cls.add_module:
|
||||
cls.add_module.append(limit.module)
|
||||
if limit.limit_type == PluginLimitType.BLOCK:
|
||||
cls.block_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=UserBlockLimiter()
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.CD:
|
||||
cls.cd_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=FreqLimiter(limit.cd)
|
||||
)
|
||||
elif limit.limit_type == PluginLimitType.COUNT:
|
||||
cls.count_limit[limit.module] = Limit(
|
||||
limit=limit, limiter=CountLimiter(limit.max_count)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unblock(
|
||||
cls, module: str, user_id: str, group_id: str | None, channel_id: str | None
|
||||
):
|
||||
"""解除插件block
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
"""
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
limit = limit_model.limit
|
||||
limiter: UserBlockLimiter = limit_model.limiter # type: ignore
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
logger.debug(
|
||||
f"解除对象: {key_type} 的block限制",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
limiter.set_false(key_type)
|
||||
|
||||
@classmethod
|
||||
async def check(
|
||||
cls,
|
||||
module: str,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if limit_model := cls.cd_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.block_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
if limit_model := cls.count_limit.get(module):
|
||||
await cls.__check(limit_model, user_id, group_id, channel_id)
|
||||
|
||||
@classmethod
|
||||
async def __check(
|
||||
cls,
|
||||
limit_model: Limit | None,
|
||||
user_id: str,
|
||||
group_id: str | None,
|
||||
channel_id: str | None,
|
||||
):
|
||||
"""检测限制
|
||||
|
||||
参数:
|
||||
limit_model: Limit
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
channel_id: 频道id
|
||||
|
||||
异常:
|
||||
IgnoredException: IgnoredException
|
||||
"""
|
||||
if not limit_model:
|
||||
return
|
||||
limit = limit_model.limit
|
||||
limiter = limit_model.limiter
|
||||
is_limit = (
|
||||
LimitWatchType.ALL
|
||||
or (group_id and limit.watch_type == LimitWatchType.GROUP)
|
||||
or (not group_id and limit.watch_type == LimitWatchType.USER)
|
||||
)
|
||||
key_type = user_id
|
||||
if group_id and limit.watch_type == LimitWatchType.GROUP:
|
||||
key_type = channel_id or group_id
|
||||
if is_limit and not limiter.check(key_type):
|
||||
if limit.result:
|
||||
await MessageUtils.build_message(limit.result).send()
|
||||
raise SkipPluginException(
|
||||
f"{limit.module}({limit.limit_type}) 正在限制中..."
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"开始进行限制 {limit.module}({limit.limit_type})...",
|
||||
LOGGER_COMMAND,
|
||||
session=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
if isinstance(limiter, FreqLimiter):
|
||||
limiter.start_cd(key_type)
|
||||
if isinstance(limiter, UserBlockLimiter):
|
||||
limiter.set_true(key_type)
|
||||
if isinstance(limiter, CountLimiter):
|
||||
limiter.increase(key_type)
|
||||
|
||||
|
||||
async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
||||
"""插件限制
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
"""
|
||||
entity = get_entity_ids(session)
|
||||
if plugin.module not in LimitManager.add_module:
|
||||
limit_list = await PluginLimit.filter(module=plugin.module, status=True).all()
|
||||
for limit in limit_list:
|
||||
LimitManager.add_limit(limit)
|
||||
if entity.user_id:
|
||||
await LimitManager.check(
|
||||
plugin.module, entity.user_id, entity.group_id, entity.channel_id
|
||||
)
|
||||
147
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
147
zhenxun/builtin_plugins/hooks/auth/auth_plugin.py
Normal file
@ -0,0 +1,147 @@
|
||||
from nonebot.adapters import Event
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import BlockType, CacheType
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .exception import IsSuperuserException, SkipPluginException
|
||||
from .utils import freq, is_poke, send_message
|
||||
|
||||
|
||||
class GroupCheck:
|
||||
def __init__(
|
||||
self, plugin: PluginInfo, group_id: str, session: Uninfo, is_poke: bool
|
||||
) -> None:
|
||||
self.group_id = group_id
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.plugin = plugin
|
||||
|
||||
async def __get_data(self):
|
||||
cache = Cache[GroupConsole](CacheType.GROUPS)
|
||||
return await cache.get(self.group_id)
|
||||
|
||||
async def check(self):
|
||||
await self.check_superuser_block(self.plugin)
|
||||
|
||||
async def check_superuser_block(self, plugin: PluginInfo):
|
||||
"""超级用户禁用群组插件检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
group = await self.__get_data()
|
||||
if group and CommonUtils.format(plugin.module) in group.superuser_block_plugin:
|
||||
if freq.is_send_limit_message(plugin, group.group_id, self.is_poke):
|
||||
await send_message(
|
||||
self.session, "超级管理员禁用了该群此功能...", self.group_id
|
||||
)
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 超级管理员禁用了该群此功能..."
|
||||
)
|
||||
await self.check_normal_block(self.plugin)
|
||||
|
||||
async def check_normal_block(self, plugin: PluginInfo):
|
||||
"""群组插件状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
group = await self.__get_data()
|
||||
if group and CommonUtils.format(plugin.module) in group.block_plugin:
|
||||
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
||||
await send_message(self.session, "该群未开启此功能...", self.group_id)
|
||||
raise SkipPluginException(f"{plugin.name}({plugin.module}) 未开启此功能...")
|
||||
await self.check_global_block(self.plugin)
|
||||
|
||||
async def check_global_block(self, plugin: PluginInfo):
|
||||
"""全局禁用插件检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
if plugin.block_type == BlockType.GROUP:
|
||||
"""全局群组禁用"""
|
||||
if freq.is_send_limit_message(plugin, self.group_id, self.is_poke):
|
||||
await send_message(
|
||||
self.session, "该功能在群组中已被禁用...", self.group_id
|
||||
)
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 该插件在群组中已被禁用..."
|
||||
)
|
||||
|
||||
|
||||
class PluginCheck:
|
||||
def __init__(self, group_id: str | None, session: Uninfo, is_poke: bool):
|
||||
self.session = session
|
||||
self.is_poke = is_poke
|
||||
self.group_id = group_id
|
||||
|
||||
async def check_user(self, plugin: PluginInfo):
|
||||
"""全局私聊禁用检测
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
if plugin.block_type == BlockType.PRIVATE:
|
||||
if freq.is_send_limit_message(plugin, self.session.user.id, self.is_poke):
|
||||
await send_message(self.session, "该功能在私聊中已被禁用...")
|
||||
raise SkipPluginException(
|
||||
f"{plugin.name}({plugin.module}) 该插件在私聊中已被禁用..."
|
||||
)
|
||||
|
||||
async def check_global(self, plugin: PluginInfo):
|
||||
"""全局状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
|
||||
异常:
|
||||
IgnoredException: 忽略插件
|
||||
"""
|
||||
if plugin.status or plugin.block_type != BlockType.ALL:
|
||||
return
|
||||
"""全局状态"""
|
||||
cache = Cache[GroupConsole](CacheType.GROUPS)
|
||||
if self.group_id and (group := await cache.get(self.group_id)):
|
||||
if group.is_super:
|
||||
raise IsSuperuserException()
|
||||
sid = self.group_id or self.session.user.id
|
||||
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||
await send_message(self.session, "全局未开启此功能...", sid)
|
||||
raise SkipPluginException(f"{plugin.name}({plugin.module}) 全局未开启此功能...")
|
||||
|
||||
|
||||
async def auth_plugin(plugin: PluginInfo, session: Uninfo, event: Event):
|
||||
"""插件状态
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
session: Uninfo
|
||||
event: Event
|
||||
"""
|
||||
entity = get_entity_ids(session)
|
||||
is_poke_event = is_poke(event)
|
||||
user_check = PluginCheck(entity.group_id, session, is_poke_event)
|
||||
if entity.group_id:
|
||||
group_check = GroupCheck(plugin, entity.group_id, session, is_poke_event)
|
||||
await group_check.check()
|
||||
else:
|
||||
await user_check.check_user(plugin)
|
||||
await user_check.check_global(plugin)
|
||||
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
35
zhenxun/builtin_plugins/hooks/auth/bot_filter.py
Normal file
@ -0,0 +1,35 @@
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
|
||||
from .exception import SkipPluginException
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"FILTER_BOT",
|
||||
True,
|
||||
help="过滤当前连接bot(防止bot互相调用)",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
|
||||
|
||||
def bot_filter(session: Uninfo):
|
||||
"""过滤bot调用bot
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
SkipPluginException: bot互相调用
|
||||
"""
|
||||
if not Config.get_config("hook", "FILTER_BOT"):
|
||||
return
|
||||
bot_ids = list(nonebot.get_bots().keys())
|
||||
if session.user.id == session.self_id:
|
||||
return
|
||||
if session.user.id in bot_ids:
|
||||
raise SkipPluginException(
|
||||
f"bot:{session.self_id} 尝试调用 bot:{session.user.id}"
|
||||
)
|
||||
13
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
13
zhenxun/builtin_plugins/hooks/auth/config.py
Normal file
@ -0,0 +1,13 @@
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
from strenum import StrEnum
|
||||
|
||||
LOGGER_COMMAND = "AuthChecker"
|
||||
|
||||
|
||||
class SwitchEnum(StrEnum):
|
||||
ENABLE = "醒来"
|
||||
DISABLE = "休息吧"
|
||||
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
26
zhenxun/builtin_plugins/hooks/auth/exception.py
Normal file
@ -0,0 +1,26 @@
|
||||
class IsSuperuserException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SkipPluginException(Exception):
|
||||
def __init__(self, info: str, *args: object) -> None:
|
||||
super().__init__(*args)
|
||||
self.info = info
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.info
|
||||
|
||||
|
||||
class PermissionExemption(Exception):
|
||||
def __init__(self, info: str, *args: object) -> None:
|
||||
super().__init__(*args)
|
||||
self.info = info
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.info
|
||||
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
91
zhenxun/builtin_plugins/hooks/auth/utils.py
Normal file
@ -0,0 +1,91 @@
|
||||
import contextlib
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import FreqLimiter
|
||||
|
||||
from .config import LOGGER_COMMAND
|
||||
|
||||
base_config = Config.get("hook")
|
||||
|
||||
|
||||
def is_poke(event: Event) -> bool:
|
||||
"""判断是否为poke类型
|
||||
|
||||
参数:
|
||||
event: Event
|
||||
|
||||
返回:
|
||||
bool: 是否为poke类型
|
||||
"""
|
||||
with contextlib.suppress(ImportError):
|
||||
from nonebot.adapters.onebot.v11 import PokeNotifyEvent
|
||||
|
||||
return isinstance(event, PokeNotifyEvent)
|
||||
return False
|
||||
|
||||
|
||||
async def send_message(
|
||||
session: Uninfo, message: list | str, check_tag: str | None = None
|
||||
):
|
||||
"""发送消息
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
message: 消息
|
||||
check_tag: cd flag
|
||||
"""
|
||||
try:
|
||||
if not check_tag:
|
||||
await MessageUtils.build_message(message).send(reply_to=True)
|
||||
elif freq._flmt.check(check_tag):
|
||||
freq._flmt.start_cd(check_tag)
|
||||
await MessageUtils.build_message(message).send(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"发送消息失败",
|
||||
LOGGER_COMMAND,
|
||||
session=session,
|
||||
e=e,
|
||||
)
|
||||
|
||||
|
||||
class FreqUtils:
|
||||
def __init__(self):
|
||||
check_notice_info_cd = Config.get_config("hook", "CHECK_NOTICE_INFO_CD")
|
||||
if check_notice_info_cd is None or check_notice_info_cd < 0:
|
||||
raise ValueError("模块: [hook], 配置项: [CHECK_NOTICE_INFO_CD] 为空或小于0")
|
||||
self._flmt = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_g = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_s = FreqLimiter(check_notice_info_cd)
|
||||
self._flmt_c = FreqLimiter(check_notice_info_cd)
|
||||
|
||||
def is_send_limit_message(
|
||||
self, plugin: PluginInfo, sid: str, is_poke: bool
|
||||
) -> bool:
|
||||
"""是否发送提示消息
|
||||
|
||||
参数:
|
||||
plugin: PluginInfo
|
||||
sid: 检测键
|
||||
is_poke: 是否是戳一戳
|
||||
|
||||
返回:
|
||||
bool: 是否发送提示消息
|
||||
"""
|
||||
if is_poke:
|
||||
return False
|
||||
if not base_config.get("IS_SEND_TIP_MESSAGE"):
|
||||
return False
|
||||
if plugin.plugin_type == PluginType.DEPENDANT:
|
||||
return False
|
||||
return plugin.module != "ai" if self._flmt_s.check(sid) else False
|
||||
|
||||
|
||||
freq = FreqUtils()
|
||||
176
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
176
zhenxun/builtin_plugins/hooks/auth_checker.py
Normal file
@ -0,0 +1,176 @@
|
||||
import asyncio
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import (
|
||||
CacheType,
|
||||
GoldHandle,
|
||||
PluginType,
|
||||
)
|
||||
from zhenxun.utils.exception import InsufficientGold
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
from zhenxun.utils.utils import get_entity_ids
|
||||
|
||||
from .auth.auth_admin import auth_admin
|
||||
from .auth.auth_ban import auth_ban
|
||||
from .auth.auth_bot import auth_bot
|
||||
from .auth.auth_cost import auth_cost
|
||||
from .auth.auth_group import auth_group
|
||||
from .auth.auth_limit import LimitManager, auth_limit
|
||||
from .auth.auth_plugin import auth_plugin
|
||||
from .auth.bot_filter import bot_filter
|
||||
from .auth.config import LOGGER_COMMAND
|
||||
from .auth.exception import (
|
||||
IsSuperuserException,
|
||||
PermissionExemption,
|
||||
SkipPluginException,
|
||||
)
|
||||
|
||||
|
||||
async def get_plugin_and_user(
|
||||
module: str, user_id: str
|
||||
) -> tuple[PluginInfo, UserConsole]:
|
||||
"""获取用户数据和插件信息
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
user_id: 用户id
|
||||
|
||||
异常:
|
||||
PermissionExemption: 插件数据不存在
|
||||
PermissionExemption: 插件类型为HIDDEN
|
||||
PermissionExemption: 重复创建用户
|
||||
PermissionExemption: 用户数据不存在
|
||||
|
||||
返回:
|
||||
tuple[PluginInfo, UserConsole]: 插件信息,用户信息
|
||||
"""
|
||||
user_cache = Cache[UserConsole](CacheType.USERS)
|
||||
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
||||
if not plugin:
|
||||
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
|
||||
if plugin.plugin_type == PluginType.HIDDEN:
|
||||
raise PermissionExemption(
|
||||
f"插件: {plugin.name}:{plugin.module} 为HIDDEN,已跳过权限检查..."
|
||||
)
|
||||
user = None
|
||||
try:
|
||||
user = await user_cache.get(user_id)
|
||||
except IntegrityError as e:
|
||||
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
|
||||
if not user:
|
||||
raise PermissionExemption("用户数据不存在,已跳过权限检查...")
|
||||
return plugin, user
|
||||
|
||||
|
||||
async def get_plugin_cost(
|
||||
bot: Bot, user: UserConsole, plugin: PluginInfo, session: Uninfo
|
||||
) -> int:
|
||||
"""获取插件费用
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
user: 用户数据
|
||||
plugin: 插件数据
|
||||
session: Uninfo
|
||||
|
||||
异常:
|
||||
IsSuperuserException: 超级用户
|
||||
IsSuperuserException: 超级用户
|
||||
|
||||
返回:
|
||||
int: 调用插件金币费用
|
||||
"""
|
||||
cost_gold = await auth_cost(user, plugin, session)
|
||||
if session.user.id in bot.config.superusers:
|
||||
if plugin.plugin_type == PluginType.SUPERUSER:
|
||||
raise IsSuperuserException()
|
||||
if not plugin.limit_superuser:
|
||||
raise IsSuperuserException()
|
||||
return cost_gold
|
||||
|
||||
|
||||
async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo):
|
||||
"""扣除用户金币
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
module: 插件模块名称
|
||||
cost_gold: 消耗金币
|
||||
session: Uninfo
|
||||
"""
|
||||
user_cache = Cache[UserConsole](CacheType.USERS)
|
||||
try:
|
||||
await UserConsole.reduce_gold(
|
||||
user_id,
|
||||
cost_gold,
|
||||
GoldHandle.PLUGIN,
|
||||
module,
|
||||
PlatformUtils.get_platform(session),
|
||||
)
|
||||
except InsufficientGold:
|
||||
if u := await UserConsole.get_user(user_id):
|
||||
u.gold = 0
|
||||
await u.save(update_fields=["gold"])
|
||||
# 更新缓存
|
||||
await user_cache.update(user_id)
|
||||
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
|
||||
|
||||
|
||||
async def auth(
|
||||
matcher: Matcher,
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
message: UniMsg,
|
||||
):
|
||||
"""权限检查
|
||||
|
||||
参数:
|
||||
matcher: matcher
|
||||
event: Event
|
||||
bot: bot
|
||||
session: Uninfo
|
||||
message: UniMsg
|
||||
"""
|
||||
cost_gold = 0
|
||||
ignore_flag = False
|
||||
entity = get_entity_ids(session)
|
||||
module = matcher.plugin_name or ""
|
||||
try:
|
||||
if not module:
|
||||
raise PermissionExemption("Matcher插件名称不存在...")
|
||||
plugin, user = await get_plugin_and_user(module, entity.user_id)
|
||||
cost_gold = await get_plugin_cost(bot, user, plugin, session)
|
||||
bot_filter(session)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
auth_ban(matcher, bot, session),
|
||||
auth_bot(plugin, bot.self_id),
|
||||
auth_group(plugin, entity, message),
|
||||
auth_admin(plugin, session),
|
||||
auth_plugin(plugin, session, event),
|
||||
auth_limit(plugin, session),
|
||||
]
|
||||
)
|
||||
except SkipPluginException as e:
|
||||
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
ignore_flag = True
|
||||
except IsSuperuserException:
|
||||
logger.debug("超级用户跳过权限检测...", LOGGER_COMMAND, session=session)
|
||||
except PermissionExemption as e:
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
if not ignore_flag and cost_gold > 0:
|
||||
await reduce_gold(entity.user_id, module, cost_gold, session)
|
||||
if ignore_flag:
|
||||
raise IgnoredException("权限检测 ignore")
|
||||
@ -1,41 +1,45 @@
|
||||
from nonebot.adapters.onebot.v11 import Bot, Event
|
||||
import time
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_postprocessor, run_preprocessor
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_session import EventSession
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from ._auth_checker import LimitManage, checker
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .auth.config import LOGGER_COMMAND
|
||||
from .auth_checker import LimitManager, auth
|
||||
|
||||
|
||||
# # 权限检测
|
||||
@run_preprocessor
|
||||
async def _(
|
||||
matcher: Matcher, event: Event, bot: Bot, session: EventSession, message: UniMsg
|
||||
matcher: Matcher, event: Event, bot: Bot, session: Uninfo, message: UniMsg
|
||||
):
|
||||
await checker.auth(
|
||||
start_time = time.time()
|
||||
await auth(
|
||||
matcher,
|
||||
event,
|
||||
bot,
|
||||
session,
|
||||
message,
|
||||
)
|
||||
logger.debug(f"权限检测耗时:{time.time() - start_time}秒", LOGGER_COMMAND)
|
||||
|
||||
|
||||
# 解除命令block阻塞
|
||||
@run_postprocessor
|
||||
async def _(
|
||||
matcher: Matcher,
|
||||
exception: Exception | None,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
):
|
||||
user_id = session.id1
|
||||
group_id = session.id3
|
||||
channel_id = session.id2
|
||||
if not group_id:
|
||||
group_id = channel_id
|
||||
async def _(matcher: Matcher, session: Uninfo):
|
||||
user_id = session.user.id
|
||||
group_id = None
|
||||
channel_id = None
|
||||
if session.group:
|
||||
if session.group.parent:
|
||||
group_id = session.group.parent.id
|
||||
channel_id = session.group.id
|
||||
else:
|
||||
group_id = session.group.id
|
||||
if user_id and matcher.plugin:
|
||||
module = matcher.plugin.name
|
||||
LimitManage.unblock(module, user_id, group_id, channel_id)
|
||||
LimitManager.unblock(module, user_id, group_id, channel_id)
|
||||
|
||||
@ -1,84 +0,0 @@
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot.typing import T_State
|
||||
from nonebot_plugin_alconna import At
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import FreqLimiter
|
||||
|
||||
Config.add_plugin_config(
|
||||
"hook",
|
||||
"BAN_RESULT",
|
||||
"才不会给你发消息.",
|
||||
help="对被ban用户发送的消息",
|
||||
)
|
||||
|
||||
_flmt = FreqLimiter(300)
|
||||
|
||||
|
||||
# 检查是否被ban
|
||||
@run_preprocessor
|
||||
async def _(
|
||||
matcher: Matcher, bot: Bot, event: Event, state: T_State, session: EventSession
|
||||
):
|
||||
extra = {}
|
||||
if plugin := matcher.plugin:
|
||||
if metadata := plugin.metadata:
|
||||
extra = metadata.extra
|
||||
if extra.get("plugin_type") in [PluginType.HIDDEN]:
|
||||
return
|
||||
user_id = session.id1
|
||||
group_id = session.id3 or session.id2
|
||||
if group_id:
|
||||
if user_id in bot.config.superusers:
|
||||
return
|
||||
if await BanConsole.is_ban(None, group_id):
|
||||
logger.debug("群组处于黑名单中...", "ban_hook")
|
||||
raise IgnoredException("群组处于黑名单中...")
|
||||
if g := await GroupConsole.get_group(group_id):
|
||||
if g.level < 0:
|
||||
logger.debug("群黑名单, 群权限-1...", "ban_hook")
|
||||
raise IgnoredException("群黑名单, 群权限-1..")
|
||||
if user_id:
|
||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||
if user_id in bot.config.superusers:
|
||||
return
|
||||
if await BanConsole.is_ban(user_id, group_id):
|
||||
time = await BanConsole.check_ban_time(user_id, group_id)
|
||||
if time == -1:
|
||||
time_str = "∞"
|
||||
else:
|
||||
time = abs(int(time))
|
||||
if time < 60:
|
||||
time_str = f"{time!s} 秒"
|
||||
else:
|
||||
minute = int(time / 60)
|
||||
if minute > 60:
|
||||
hours = minute // 60
|
||||
minute %= 60
|
||||
time_str = f"{hours} 小时 {minute}分钟"
|
||||
else:
|
||||
time_str = f"{minute} 分钟"
|
||||
if (
|
||||
not extra.get("ignore_prompt")
|
||||
and time != -1
|
||||
and ban_result
|
||||
and _flmt.check(user_id)
|
||||
):
|
||||
_flmt.start_cd(user_id)
|
||||
await MessageUtils.build_message(
|
||||
[
|
||||
At(flag="user", target=user_id),
|
||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||
]
|
||||
).send()
|
||||
logger.debug("用户处于黑名单中...", "ban_hook")
|
||||
raise IgnoredException("用户处于黑名单中...")
|
||||
@ -1,23 +1,85 @@
|
||||
from typing import Any
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters import Bot, Message
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.bot_message_store import BotMessageStore
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import BotSentType
|
||||
from zhenxun.utils.manager.message_manager import MessageManager
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
|
||||
def replace_message(message: Message) -> str:
|
||||
"""将消息中的at、image、record、face替换为字符串
|
||||
|
||||
参数:
|
||||
message: Message
|
||||
|
||||
返回:
|
||||
str: 文本消息
|
||||
"""
|
||||
result = ""
|
||||
for msg in message:
|
||||
if isinstance(msg, str):
|
||||
result += msg
|
||||
elif msg.type == "at":
|
||||
result += f"@{msg.data['qq']}"
|
||||
elif msg.type == "image":
|
||||
result += "[image]"
|
||||
elif msg.type == "record":
|
||||
result += "[record]"
|
||||
elif msg.type == "face":
|
||||
result += f"[face:{msg.data['id']}]"
|
||||
elif msg.type == "reply":
|
||||
result += ""
|
||||
else:
|
||||
result += str(msg)
|
||||
return result
|
||||
|
||||
|
||||
@Bot.on_called_api
|
||||
async def handle_api_result(
|
||||
bot: Bot, exception: Exception | None, api: str, data: dict[str, Any], result: Any
|
||||
):
|
||||
if not exception and api == "send_msg":
|
||||
if exception or api != "send_msg":
|
||||
return
|
||||
user_id = data.get("user_id")
|
||||
group_id = data.get("group_id")
|
||||
message_id = result.get("message_id")
|
||||
message: Message = data.get("message", "")
|
||||
message_type = data.get("message_type")
|
||||
try:
|
||||
if (uid := data.get("user_id")) and (msg_id := result.get("message_id")):
|
||||
MessageManager.add(str(uid), str(msg_id))
|
||||
# 记录消息id
|
||||
if user_id and message_id:
|
||||
MessageManager.add(str(user_id), str(message_id))
|
||||
logger.debug(
|
||||
f"收集消息id,user_id: {uid}, msg_id: {msg_id}", "msg_hook"
|
||||
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
|
||||
)
|
||||
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
|
||||
return
|
||||
try:
|
||||
await BotMessageStore.create(
|
||||
bot_id=bot.self_id,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
sent_type=BotSentType.GROUP
|
||||
if message_type == "group"
|
||||
else BotSentType.PRIVATE,
|
||||
text=replace_message(message),
|
||||
plain_text=message.extract_plain_text()
|
||||
if isinstance(message, Message)
|
||||
else replace_message(message),
|
||||
platform=PlatformUtils.get_platform(bot),
|
||||
)
|
||||
logger.debug(f"消息发送记录,message: {message}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"消息发送记录发生错误...data: {data}, result: {result}",
|
||||
"msg_hook",
|
||||
e=e,
|
||||
)
|
||||
|
||||
@ -4,15 +4,25 @@ import nonebot
|
||||
from nonebot.adapters import Bot
|
||||
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.cache import DbCacheException
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||
|
||||
try:
|
||||
from .__init_cache import CacheRoot
|
||||
except DbCacheException as e:
|
||||
raise SystemError(f"ERROR:{e}")
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
await CacheRoot.init_non_lazy_caches()
|
||||
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: Bot):
|
||||
"""将bot已存在的群组添加群认证
|
||||
|
||||
208
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
208
zhenxun/builtin_plugins/init/__init_cache.py
Normal file
@ -0,0 +1,208 @@
|
||||
from zhenxun.models.ban_console import BanConsole
|
||||
from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import CacheData, CacheRoot
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.PLUGINS)
|
||||
async def _():
|
||||
"""初始化插件缓存"""
|
||||
data_list = await PluginInfo.get_plugins()
|
||||
return {p.module: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
||||
async def _(cache_data: CacheData, module: str):
|
||||
"""获取插件缓存"""
|
||||
data = await cache_data.get_key(module)
|
||||
if not data:
|
||||
if plugin := await PluginInfo.get_plugin(module=module):
|
||||
await cache_data.set_key(module, plugin)
|
||||
logger.debug(f"插件 {module} 数据已设置到缓存")
|
||||
return plugin
|
||||
return data
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
||||
async def _(cache_data: CacheData, data: dict[str, PluginInfo] | None):
|
||||
"""刷新插件缓存"""
|
||||
if not data:
|
||||
return
|
||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
||||
for plugin in plugins:
|
||||
await cache_data.set_key(plugin.module, plugin)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.GROUPS)
|
||||
async def _():
|
||||
"""初始化群组缓存"""
|
||||
data_list = await GroupConsole.all()
|
||||
return {p.group_id: p for p in data_list if not p.channel_id}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
||||
async def _(cache_data: CacheData, group_id: str):
|
||||
"""获取群组缓存"""
|
||||
data = await cache_data.get_key(group_id)
|
||||
if not data:
|
||||
if group := await GroupConsole.get_group(group_id=group_id):
|
||||
await cache_data.set_key(group_id, group)
|
||||
return group
|
||||
return data
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.GROUPS)
|
||||
async def _(cache_data: CacheData, data: dict[str, GroupConsole] | None):
|
||||
"""刷新群组缓存"""
|
||||
if not data:
|
||||
return
|
||||
groups = await GroupConsole.filter(
|
||||
group_id__in=data.keys(), channel_id__isnull=True
|
||||
).all()
|
||||
for group in groups:
|
||||
await cache_data.set_key(group.group_id, group)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BOT)
|
||||
async def _():
|
||||
"""初始化机器人缓存"""
|
||||
data_list = await BotConsole.all()
|
||||
return {p.bot_id: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
||||
async def _(cache_data: CacheData, bot_id: str):
|
||||
"""获取机器人缓存"""
|
||||
data = await cache_data.get_key(bot_id)
|
||||
if not data:
|
||||
if bot := await BotConsole.get_or_none(bot_id=bot_id):
|
||||
await cache_data.set_key(bot_id, bot)
|
||||
return bot
|
||||
return data
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.BOT)
|
||||
async def _(cache_data: CacheData, data: dict[str, BotConsole] | None):
|
||||
"""刷新机器人缓存"""
|
||||
if not data:
|
||||
return
|
||||
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
||||
for bot in bots:
|
||||
await cache_data.set_key(bot.bot_id, bot)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.USERS)
|
||||
async def _():
|
||||
"""初始化用户缓存"""
|
||||
data_list = await UserConsole.all()
|
||||
return {p.user_id: p for p in data_list}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
||||
async def _(cache_data: CacheData, user_id: str):
|
||||
"""获取用户缓存"""
|
||||
data = await cache_data.get_key(user_id)
|
||||
if not data:
|
||||
if user := await UserConsole.get_user(user_id=user_id):
|
||||
await cache_data.set_key(user_id, user)
|
||||
return user
|
||||
return data
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.USERS)
|
||||
async def _(cache_data: CacheData, data: dict[str, UserConsole] | None):
|
||||
"""刷新用户缓存"""
|
||||
if not data:
|
||||
return
|
||||
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
||||
for user in users:
|
||||
await cache_data.set_key(user.user_id, user)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.LEVEL)
|
||||
async def _():
|
||||
"""初始化等级缓存"""
|
||||
data_list = await LevelUser().all()
|
||||
return {f"{d.user_id}:{d.group_id or ''}": d for d in data_list}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
||||
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
||||
"""获取等级缓存"""
|
||||
key = f"{user_id}:{group_id or ''}"
|
||||
data = await cache_data.get_key(key)
|
||||
if not data:
|
||||
if group_id:
|
||||
data = await LevelUser.filter(user_id=user_id, group_id=group_id).all()
|
||||
else:
|
||||
data = await LevelUser.filter(user_id=user_id, group_id__isnull=True).all()
|
||||
if data:
|
||||
await cache_data.set_key(key, data)
|
||||
return data
|
||||
return data or []
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.BAN, False)
|
||||
async def _():
|
||||
"""初始化封禁缓存"""
|
||||
data_list = await BanConsole.all()
|
||||
return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.BAN, result_model=BanConsole)
|
||||
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
||||
"""获取封禁缓存"""
|
||||
if not user_id and not group_id:
|
||||
return []
|
||||
key = f"{group_id or ''}:{user_id or ''}"
|
||||
data = await cache_data.get_key(key)
|
||||
# if not data:
|
||||
# start = time.time()
|
||||
# if user_id and group_id:
|
||||
# data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
|
||||
# elif user_id:
|
||||
# data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
|
||||
# elif group_id:
|
||||
# data = await BanConsole.filter(
|
||||
# user_id__isnull=True, group_id=group_id
|
||||
# ).all()
|
||||
# logger.info(
|
||||
# f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}"
|
||||
# )
|
||||
# if data:
|
||||
# await cache_data.set_key(key, data)
|
||||
# return data
|
||||
return data or []
|
||||
|
||||
|
||||
# @CacheRoot.new(CacheType.LIMIT)
|
||||
# async def _():
|
||||
# """初始化限制缓存"""
|
||||
# data_list = await PluginLimit.filter(status=True).all()
|
||||
# return {data.module: data for data in data_list}
|
||||
|
||||
|
||||
# @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
||||
# async def _(cache_data: CacheData, module: str):
|
||||
# """获取限制缓存"""
|
||||
# data = await cache_data.get_key(module)
|
||||
# if not data:
|
||||
# if limits := await PluginLimit.filter(module=module, status=True):
|
||||
# await cache_data.set_key(module, limits)
|
||||
# return limits
|
||||
# return data or []
|
||||
|
||||
|
||||
# @CacheRoot.with_refresh(CacheType.LIMIT)
|
||||
# async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
|
||||
# """刷新限制缓存"""
|
||||
# if not data:
|
||||
# return
|
||||
# limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
||||
# for limit in limits:
|
||||
# await cache_data.set_key(limit.module, limit)
|
||||
@ -11,6 +11,7 @@ from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.configs.utils import RegisterConfig
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
|
||||
_yaml = YAML(pure=True)
|
||||
_yaml.allow_unicode = True
|
||||
@ -57,7 +58,7 @@ def _generate_simple_config(exists_module: list[str]):
|
||||
生成简易配置
|
||||
|
||||
异常:
|
||||
AttributeError: _description_
|
||||
AttributeError: AttributeError
|
||||
"""
|
||||
# 读取用户配置
|
||||
_data = {}
|
||||
@ -73,7 +74,9 @@ def _generate_simple_config(exists_module: list[str]):
|
||||
if _data.get(module) and k in _data[module].keys():
|
||||
Config.set_config(module, k, _data[module][k])
|
||||
if f"{module}:{k}".lower() in exists_module:
|
||||
_tmp_data[module][k] = Config.get_config(module, k)
|
||||
_tmp_data[module][k] = Config.get_config(
|
||||
module, k, build_model=False
|
||||
)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f"{e}\n可能为config.yaml配置文件填写不规范") from e
|
||||
if not _tmp_data[module]:
|
||||
@ -102,7 +105,7 @@ def _generate_simple_config(exists_module: list[str]):
|
||||
temp_file.unlink()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
@PriorityLifecycle.on_startup(priority=0)
|
||||
def _():
|
||||
"""
|
||||
初始化插件数据配置
|
||||
@ -125,3 +128,4 @@ def _():
|
||||
with plugins2config_file.open("w", encoding="utf8") as wf:
|
||||
_yaml.dump(_data, wf)
|
||||
_generate_simple_config(exists_module)
|
||||
Config.reload()
|
||||
|
||||
259
zhenxun/builtin_plugins/mahiro_bank/__init__.py
Normal file
259
zhenxun/builtin_plugins/mahiro_bank/__init__.py
Normal file
@ -0,0 +1,259 @@
|
||||
from datetime import datetime
|
||||
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import Alconna, Args, Arparma, Match, Subcommand, on_alconna
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from nonebot_plugin_waiter import prompt_until
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.depends import UserName
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import is_number
|
||||
|
||||
from .data_source import BankManager
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="小真寻银行",
|
||||
description="""
|
||||
小真寻银行,提供高品质的存款!当好感度等级达到指初识时,小真寻会偷偷的帮助你哦。
|
||||
存款额度与好感度有关,每日存款次数有限制。
|
||||
基础存款提供基础利息
|
||||
每日存款提供高额利息
|
||||
""".strip(),
|
||||
usage="""
|
||||
指令:
|
||||
存款 [金额]
|
||||
取款 [金额]
|
||||
银行信息
|
||||
我的银行信息
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1",
|
||||
menu_type="群内小游戏",
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
key="sign_max_deposit",
|
||||
value=100,
|
||||
help="好感度换算存款金额比例,当值是100时,最大存款金额=好感度*100,存款的最低金额是100(强制)",
|
||||
default_value=100,
|
||||
type=int,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="max_daily_deposit_count",
|
||||
value=3,
|
||||
help="每日最大存款次数",
|
||||
default_value=3,
|
||||
type=int,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="rate_range",
|
||||
value=[0.0005, 0.001],
|
||||
help="小时利率范围",
|
||||
default_value=[0.0005, 0.001],
|
||||
type=list[float],
|
||||
),
|
||||
RegisterConfig(
|
||||
key="impression_event",
|
||||
value=25,
|
||||
help="到达指定好感度时随机提高或降低利率",
|
||||
default_value=25,
|
||||
type=int,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="impression_event_range",
|
||||
value=[0.00001, 0.0003],
|
||||
help="到达指定好感度时随机提高或降低利率",
|
||||
default_value=[0.00001, 0.0003],
|
||||
type=list[float],
|
||||
),
|
||||
RegisterConfig(
|
||||
key="impression_event_prop",
|
||||
value=0.3,
|
||||
help="到达指定好感度时随机提高或降低利率触发概率",
|
||||
default_value=0.3,
|
||||
type=float,
|
||||
),
|
||||
],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
|
||||
_matcher = on_alconna(
|
||||
Alconna(
|
||||
"mahiro-bank",
|
||||
Subcommand("deposit", Args["amount?", int]),
|
||||
Subcommand("withdraw", Args["amount?", int]),
|
||||
Subcommand("user-info"),
|
||||
Subcommand("bank-info"),
|
||||
# Subcommand("loan", Args["amount?", int]),
|
||||
# Subcommand("repayment", Args["amount?", int]),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"1111",
|
||||
command="mahiro-bank",
|
||||
arguments=["test"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"存款\s*(?P<amount>\d+)?",
|
||||
command="mahiro-bank",
|
||||
arguments=["deposit", "{amount}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"取款\s*(?P<withdraw>\d+)?",
|
||||
command="mahiro-bank",
|
||||
arguments=["withdraw", "{withdraw}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"我的银行信息",
|
||||
command="mahiro-bank",
|
||||
arguments=["user-info"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"银行信息",
|
||||
command="mahiro-bank",
|
||||
arguments=["bank-info"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
|
||||
async def get_amount(handle_type: str) -> int:
|
||||
amount_num = await prompt_until(
|
||||
f"请输入{handle_type}金币数量",
|
||||
lambda msg: is_number(msg.extract_plain_text()),
|
||||
timeout=60,
|
||||
retry=3,
|
||||
retry_prompt="输入错误,请输入数字。剩余次数:{count}",
|
||||
)
|
||||
if not amount_num:
|
||||
await MessageUtils.build_message(
|
||||
"输入超时了哦,小真寻柜员以取消本次存款操作..."
|
||||
).finish()
|
||||
return int(amount_num.extract_plain_text())
|
||||
|
||||
|
||||
@_matcher.assign("deposit")
|
||||
async def _(session: Uninfo, arparma: Arparma, amount: Match[int]):
|
||||
amount_num = amount.result if amount.available else await get_amount("存款")
|
||||
if result := await BankManager.deposit_check(session.user.id, amount_num):
|
||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||
_, rate, event_rate = await BankManager.deposit(session.user.id, amount_num)
|
||||
result = (
|
||||
f"存款成功!\n此次存款金额为: {amount.result}\n"
|
||||
f"当前小时利率为: {rate * 100:.2f}%"
|
||||
)
|
||||
effective_hour = int(24 - datetime.now().hour)
|
||||
if event_rate:
|
||||
result += f"(小真寻偷偷将小时利率给你增加了 {event_rate:.2f}% 哦)"
|
||||
result += (
|
||||
f"\n预计总收益为: {int(amount.result * rate * effective_hour) or 1} 金币。"
|
||||
)
|
||||
logger.info(
|
||||
f"小真寻银行存款:{amount_num},当前存款数:{amount.result},存款小时利率: {rate}",
|
||||
arparma.header_result,
|
||||
session=session,
|
||||
)
|
||||
await MessageUtils.build_message(result).finish(at_sender=True)
|
||||
|
||||
|
||||
@_matcher.assign("withdraw")
|
||||
async def _(session: Uninfo, arparma: Arparma, amount: Match[int]):
|
||||
amount_num = amount.result if amount.available else await get_amount("取款")
|
||||
if result := await BankManager.withdraw_check(session.user.id, amount_num):
|
||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||
try:
|
||||
user = await BankManager.withdraw(session.user.id, amount_num)
|
||||
result = (
|
||||
f"取款成功!\n当前取款金额为: {amount_num}\n当前存款金额为: {user.amount}"
|
||||
)
|
||||
logger.info(
|
||||
f"小真寻银行取款:{amount_num}, 当前存款数:{user.amount},"
|
||||
f" 存款小时利率:{user.rate}",
|
||||
arparma.header_result,
|
||||
session=session,
|
||||
)
|
||||
await MessageUtils.build_message(result).finish(reply_to=True)
|
||||
except ValueError:
|
||||
await MessageUtils.build_message("你的银行内的存款数量不足哦...").finish(
|
||||
reply_to=True
|
||||
)
|
||||
|
||||
|
||||
@_matcher.assign("user-info")
|
||||
async def _(session: Uninfo, arparma: Arparma, uname: str = UserName()):
|
||||
result = await BankManager.get_user_info(session, uname)
|
||||
await MessageUtils.build_message(result).send()
|
||||
logger.info("查看银行个人信息", arparma.header_result, session=session)
|
||||
|
||||
|
||||
@_matcher.assign("bank-info")
|
||||
async def _(session: Uninfo, arparma: Arparma):
|
||||
result = await BankManager.get_bank_info()
|
||||
await MessageUtils.build_message(result).send()
|
||||
logger.info("查看银行信息", arparma.header_result, session=session)
|
||||
|
||||
|
||||
# @_matcher.assign("loan")
|
||||
# async def _(session: Uninfo, arparma: Arparma, amount: Match[int]):
|
||||
# amount_num = amount.result if amount.available else await get_amount("贷款")
|
||||
# if amount_num <= 0:
|
||||
# await MessageUtils.build_message("贷款数量必须大于 0 啊笨蛋!").finish()
|
||||
# try:
|
||||
# user, event_rate = await BankManager.loan(session.user.id, amount_num)
|
||||
# result = (
|
||||
# f"贷款成功!\n当前贷金额为: {user.loan_amount}"
|
||||
# f"\n当前利率为: {user.loan_rate * 100}%"
|
||||
# )
|
||||
# if event_rate:
|
||||
# result += f"(小真寻偷偷将利率给你降低了 {event_rate}% 哦)"
|
||||
# result += f"\n预计每小时利息为:{int(user.loan_amount * user.loan_rate)}金币。"
|
||||
# logger.info(
|
||||
# f"小真寻银行贷款: {amount_num}, 当前贷款数: {user.loan_amount}, "
|
||||
# f"贷款利率: {user.loan_rate}",
|
||||
# arparma.header_result,
|
||||
# session=session,
|
||||
# )
|
||||
# except ValueError:
|
||||
# await MessageUtils.build_message(
|
||||
# "贷款数量超过最大限制,请签到提升好感度获取更多额度吧..."
|
||||
# ).finish(reply_to=True)
|
||||
|
||||
|
||||
# @_matcher.assign("repayment")
|
||||
# async def _(session: Uninfo, arparma: Arparma, amount: Match[int]):
|
||||
# amount_num = amount.result if amount.available else await get_amount("还款")
|
||||
# if amount_num <= 0:
|
||||
# await MessageUtils.build_message("还款数量必须大于 0 啊笨蛋!").finish()
|
||||
# user = await BankManager.repayment(session.user.id, amount_num)
|
||||
# result = (f"还款成功!\n当前还款金额为: {amount_num}\n"
|
||||
# f"当前贷款金额为: {user.loan_amount}")
|
||||
# logger.info(
|
||||
# f"小真寻银行还款:{amount_num},当前贷款数:{user.amount}, 贷款利率:{user.rate}",
|
||||
# arparma.header_result,
|
||||
# session=session,
|
||||
# )
|
||||
# await MessageUtils.build_message(result).finish(at_sender=True)
|
||||
|
||||
|
||||
@scheduler.scheduled_job(
|
||||
"cron",
|
||||
hour=0,
|
||||
minute=0,
|
||||
)
|
||||
async def _():
|
||||
await BankManager.settlement()
|
||||
logger.info("小真寻银行结算", "定时任务")
|
||||
448
zhenxun/builtin_plugins/mahiro_bank/data_source.py
Normal file
448
zhenxun/builtin_plugins/mahiro_bank/data_source.py
Normal file
@ -0,0 +1,448 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
|
||||
from nonebot_plugin_htmlrender import template_to_pic
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from tortoise.expressions import RawSQL
|
||||
from tortoise.functions import Count, Sum
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.path_config import TEMPLATE_PATH
|
||||
from zhenxun.models.mahiro_bank import MahiroBank
|
||||
from zhenxun.models.mahiro_bank_log import MahiroBankLog
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.utils.enum import BankHandleType, GoldHandle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
base_config = Config.get("mahiro_bank")
|
||||
|
||||
|
||||
class BankManager:
|
||||
@classmethod
|
||||
async def random_event(cls, impression: float):
|
||||
"""随机事件"""
|
||||
impression_event = base_config.get("impression_event")
|
||||
impression_event_prop = base_config.get("impression_event_prop")
|
||||
impression_event_range = base_config.get("impression_event_range")
|
||||
if impression >= impression_event and random.random() < impression_event_prop:
|
||||
"""触发好感度事件"""
|
||||
return random.uniform(impression_event_range[0], impression_event_range[1])
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def deposit_check(cls, user_id: str, amount: int) -> str | None:
|
||||
"""检查存款是否合法
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 存款金额
|
||||
|
||||
返回:
|
||||
str | None: 存款信息
|
||||
"""
|
||||
if amount <= 0:
|
||||
return "存款数量必须大于 0 啊笨蛋!"
|
||||
user, sign_user, bank_user = await asyncio.gather(
|
||||
*[
|
||||
UserConsole.get_user(user_id),
|
||||
SignUser.get_user(user_id),
|
||||
cls.get_user(user_id),
|
||||
]
|
||||
)
|
||||
sign_max_deposit: int = base_config.get("sign_max_deposit")
|
||||
max_deposit = max(int(float(sign_user.impression) * sign_max_deposit), 100)
|
||||
if user.gold < amount:
|
||||
return f"金币数量不足,当前你的金币为:{user.gold}."
|
||||
if bank_user.amount + amount > max_deposit:
|
||||
return (
|
||||
f"存款超过上限,存款上限为:{max_deposit},"
|
||||
f"当前你的还可以存款金额:{max_deposit - bank_user.amount}。"
|
||||
)
|
||||
max_daily_deposit_count: int = base_config.get("max_daily_deposit_count")
|
||||
today_deposit_count = len(await cls.get_user_deposit(user_id))
|
||||
if today_deposit_count >= max_daily_deposit_count:
|
||||
return f"存款次数超过上限,每日存款次数上限为:{max_daily_deposit_count}。"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def withdraw_check(cls, user_id: str, amount: int) -> str | None:
|
||||
"""检查取款是否合法
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 取款金额
|
||||
|
||||
返回:
|
||||
str | None: 取款信息
|
||||
"""
|
||||
if amount <= 0:
|
||||
return "取款数量必须大于 0 啊笨蛋!"
|
||||
user = await cls.get_user(user_id)
|
||||
data_list = await cls.get_user_deposit(user_id)
|
||||
lock_amount = sum(data.amount for data in data_list)
|
||||
if user.amount - lock_amount < amount:
|
||||
return (
|
||||
"取款金额不足,当前你的存款为:"
|
||||
f"{user.amount}({lock_amount}已被锁定)!"
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def get_user_deposit(
|
||||
cls, user_id: str, is_completed: bool = False
|
||||
) -> list[MahiroBankLog]:
|
||||
"""获取用户今日存款次数
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
|
||||
返回:
|
||||
list[MahiroBankLog]: 存款列表
|
||||
"""
|
||||
return await MahiroBankLog.filter(
|
||||
user_id=user_id,
|
||||
handle_type=BankHandleType.DEPOSIT,
|
||||
is_completed=is_completed,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_user(cls, user_id: str) -> MahiroBank:
|
||||
"""查询余额
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
|
||||
返回:
|
||||
MahiroBank
|
||||
"""
|
||||
user, _ = await MahiroBank.get_or_create(user_id=user_id)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def get_user_data(
|
||||
cls,
|
||||
user_id: str,
|
||||
data_type: BankHandleType,
|
||||
is_completed: bool = False,
|
||||
count: int = 5,
|
||||
) -> list[MahiroBankLog]:
|
||||
return (
|
||||
await MahiroBankLog.filter(
|
||||
user_id=user_id, handle_type=data_type, is_completed=is_completed
|
||||
)
|
||||
.order_by("-id")
|
||||
.limit(count)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def complete_projected_revenue(cls, user_id: str) -> int:
|
||||
"""预计收益
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
|
||||
返回:
|
||||
int: 预计收益金额
|
||||
"""
|
||||
deposit_list = await cls.get_user_deposit(user_id)
|
||||
if not deposit_list:
|
||||
return 0
|
||||
return int(
|
||||
sum(
|
||||
deposit.rate * deposit.amount * deposit.effective_hour
|
||||
for deposit in deposit_list
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_user_info(cls, session: Uninfo, uname: str) -> bytes:
|
||||
"""获取用户数据
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
uname: 用户id
|
||||
|
||||
返回:
|
||||
bytes: 图片数据
|
||||
"""
|
||||
user_id = session.user.id
|
||||
user = await cls.get_user(user_id=user_id)
|
||||
(
|
||||
rank,
|
||||
deposit_count,
|
||||
user_today_deposit,
|
||||
projected_revenue,
|
||||
sum_data,
|
||||
) = await asyncio.gather(
|
||||
*[
|
||||
MahiroBank.filter(amount__gt=user.amount).count(),
|
||||
MahiroBankLog.filter(user_id=user_id).count(),
|
||||
cls.get_user_deposit(user_id),
|
||||
cls.complete_projected_revenue(user_id),
|
||||
MahiroBankLog.filter(
|
||||
user_id=user_id, handle_type=BankHandleType.INTEREST
|
||||
)
|
||||
.annotate(sum=Sum("amount"))
|
||||
.values("sum"),
|
||||
]
|
||||
)
|
||||
now = datetime.now()
|
||||
end_time = (
|
||||
now
|
||||
+ timedelta(days=1)
|
||||
- timedelta(hours=now.hour, minutes=now.minute, seconds=now.second)
|
||||
)
|
||||
today_deposit_amount = sum(deposit.amount for deposit in user_today_deposit)
|
||||
deposit_list = [
|
||||
{
|
||||
"id": deposit.id,
|
||||
"date": now.date(),
|
||||
"start_time": str(deposit.create_time).split(".")[0],
|
||||
"end_time": end_time.replace(microsecond=0),
|
||||
"amount": deposit.amount,
|
||||
"rate": f"{deposit.rate * 100:.2f}",
|
||||
"projected_revenue": int(
|
||||
deposit.amount * deposit.rate * deposit.effective_hour
|
||||
)
|
||||
or 1,
|
||||
}
|
||||
for deposit in user_today_deposit
|
||||
]
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
data = {
|
||||
"name": uname,
|
||||
"rank": rank + 1,
|
||||
"avatar_url": PlatformUtils.get_user_avatar_url(
|
||||
user_id, platform, session.self_id
|
||||
),
|
||||
"amount": user.amount,
|
||||
"deposit_count": deposit_count,
|
||||
"today_deposit_count": len(user_today_deposit),
|
||||
"cumulative_gain": sum_data[0]["sum"] or 0,
|
||||
"projected_revenue": projected_revenue,
|
||||
"today_deposit_amount": today_deposit_amount,
|
||||
"deposit_list": deposit_list,
|
||||
"create_time": now.replace(microsecond=0),
|
||||
}
|
||||
return await template_to_pic(
|
||||
template_path=str((TEMPLATE_PATH / "mahiro_bank").absolute()),
|
||||
template_name="user.html",
|
||||
templates={"data": data},
|
||||
pages={
|
||||
"viewport": {"width": 386, "height": 700},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_bank_info(cls) -> bytes:
|
||||
now = datetime.now()
|
||||
now_start = datetime.now() - timedelta(
|
||||
hours=now.hour, minutes=now.minute, seconds=now.second
|
||||
)
|
||||
(
|
||||
bank_data,
|
||||
today_count,
|
||||
interest_amount,
|
||||
active_user_count,
|
||||
date_data,
|
||||
) = await asyncio.gather(
|
||||
*[
|
||||
MahiroBank.annotate(
|
||||
amount_sum=Sum("amount"), user_count=Count("id")
|
||||
).values("amount_sum", "user_count"),
|
||||
MahiroBankLog.filter(create_time__gt=now_start).count(),
|
||||
MahiroBankLog.filter(handle_type=BankHandleType.INTEREST)
|
||||
.annotate(amount_sum=Sum("amount"))
|
||||
.values("amount_sum"),
|
||||
MahiroBankLog.filter(
|
||||
create_time__gte=now_start - timedelta(days=7),
|
||||
handle_type=BankHandleType.DEPOSIT,
|
||||
)
|
||||
.annotate(count=Count("user_id", distinct=True))
|
||||
.values("count"),
|
||||
MahiroBankLog.filter(
|
||||
create_time__gte=now_start - timedelta(days=7),
|
||||
handle_type=BankHandleType.DEPOSIT,
|
||||
)
|
||||
.annotate(date=RawSQL("DATE(create_time)"), total_amount=Sum("amount"))
|
||||
.group_by("date")
|
||||
.values("date", "total_amount"),
|
||||
]
|
||||
)
|
||||
date2cnt = {str(date["date"]): date["total_amount"] for date in date_data}
|
||||
date = now.date()
|
||||
e_date, e_amount = [], []
|
||||
for _ in range(7):
|
||||
if str(date) in date2cnt:
|
||||
e_amount.append(date2cnt[str(date)])
|
||||
else:
|
||||
e_amount.append(0)
|
||||
e_date.append(str(date)[5:])
|
||||
date -= timedelta(days=1)
|
||||
e_date.reverse()
|
||||
e_amount.reverse()
|
||||
date = 1
|
||||
lasted_log = await MahiroBankLog.annotate().order_by("create_time").first()
|
||||
if lasted_log:
|
||||
date = now.date() - lasted_log.create_time.date()
|
||||
date = (date.days or 1) + 1
|
||||
data = {
|
||||
"amount_sum": bank_data[0]["amount_sum"],
|
||||
"user_count": bank_data[0]["user_count"],
|
||||
"today_count": today_count,
|
||||
"day_amount": int(bank_data[0]["amount_sum"] / date),
|
||||
"interest_amount": interest_amount[0]["amount_sum"] or 0,
|
||||
"active_user_count": active_user_count[0]["count"] or 0,
|
||||
"e_data": e_date,
|
||||
"e_amount": e_amount,
|
||||
"create_time": now.replace(microsecond=0),
|
||||
}
|
||||
return await template_to_pic(
|
||||
template_path=str((TEMPLATE_PATH / "mahiro_bank").absolute()),
|
||||
template_name="bank.html",
|
||||
templates={"data": data},
|
||||
pages={
|
||||
"viewport": {"width": 450, "height": 750},
|
||||
"base_url": f"file://{TEMPLATE_PATH}",
|
||||
},
|
||||
wait=2,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def deposit(
|
||||
cls, user_id: str, amount: int
|
||||
) -> tuple[MahiroBank, float, float | None]:
|
||||
"""存款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 存款数量
|
||||
|
||||
返回:
|
||||
tuple[MahiroBank, float, float]: MahiroBank,利率,增加的利率
|
||||
"""
|
||||
rate_range = base_config.get("rate_range")
|
||||
rate = random.uniform(rate_range[0], rate_range[1])
|
||||
sign_user = await SignUser.get_user(user_id)
|
||||
random_add_rate = await cls.random_event(float(sign_user.impression))
|
||||
if random_add_rate:
|
||||
rate += random_add_rate
|
||||
await UserConsole.reduce_gold(user_id, amount, GoldHandle.PLUGIN, "bank")
|
||||
return await MahiroBank.deposit(user_id, amount, rate), rate, random_add_rate
|
||||
|
||||
@classmethod
|
||||
async def withdraw(cls, user_id: str, amount: int) -> MahiroBank:
|
||||
"""取款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 取款数量
|
||||
|
||||
返回:
|
||||
MahiroBank
|
||||
"""
|
||||
await UserConsole.add_gold(user_id, amount, "bank")
|
||||
return await MahiroBank.withdraw(user_id, amount)
|
||||
|
||||
@classmethod
|
||||
async def loan(cls, user_id: str, amount: int) -> tuple[MahiroBank, float | None]:
|
||||
"""贷款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 贷款数量
|
||||
|
||||
返回:
|
||||
tuple[MahiroBank, float]: MahiroBank,贷款利率
|
||||
"""
|
||||
rate_range = base_config.get("rate_range")
|
||||
rate = random.uniform(rate_range[0], rate_range[1])
|
||||
sign_user = await SignUser.get_user(user_id)
|
||||
user, _ = await MahiroBank.get_or_create(user_id=user_id)
|
||||
if user.loan_amount + amount > sign_user.impression * 150:
|
||||
raise ValueError("贷款数量超过最大限制,请签到提升好感度获取更多额度吧...")
|
||||
random_reduce_rate = await cls.random_event(float(sign_user.impression))
|
||||
if random_reduce_rate:
|
||||
rate -= random_reduce_rate
|
||||
await UserConsole.add_gold(user_id, amount, "bank")
|
||||
return await MahiroBank.loan(user_id, amount, rate), random_reduce_rate
|
||||
|
||||
@classmethod
|
||||
async def repayment(cls, user_id: str, amount: int) -> MahiroBank:
|
||||
"""还款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 还款数量
|
||||
|
||||
返回:
|
||||
MahiroBank
|
||||
"""
|
||||
await UserConsole.reduce_gold(user_id, amount, GoldHandle.PLUGIN, "bank")
|
||||
return await MahiroBank.repayment(user_id, amount)
|
||||
|
||||
@classmethod
|
||||
async def settlement(cls):
|
||||
"""结算每日利率"""
|
||||
bank_user_list = await MahiroBank.filter(amount__gt=0).all()
|
||||
log_list = await MahiroBankLog.filter(
|
||||
is_completed=False, handle_type=BankHandleType.DEPOSIT
|
||||
).all()
|
||||
user_list = await UserConsole.filter(
|
||||
user_id__in=[user.user_id for user in bank_user_list]
|
||||
).all()
|
||||
user_data = {user.user_id: user for user in user_list}
|
||||
bank_data: dict[str, list[MahiroBankLog]] = {}
|
||||
for log in log_list:
|
||||
if log.user_id not in bank_data:
|
||||
bank_data[log.user_id] = []
|
||||
bank_data[log.user_id].append(log)
|
||||
log_create_list = []
|
||||
log_update_list = []
|
||||
# 计算每日默认金币
|
||||
for bank_user in bank_user_list:
|
||||
if user := user_data.get(bank_user.user_id):
|
||||
amount = bank_user.amount
|
||||
if logs := bank_data.get(bank_user.user_id):
|
||||
amount -= sum(log.amount for log in logs)
|
||||
if not amount:
|
||||
continue
|
||||
# 计算每日默认金币
|
||||
gold = int(amount * bank_user.rate)
|
||||
user.gold += gold
|
||||
log_create_list.append(
|
||||
MahiroBankLog(
|
||||
user_id=bank_user.user_id,
|
||||
amount=gold,
|
||||
rate=bank_user.rate,
|
||||
handle_type=BankHandleType.INTEREST,
|
||||
is_completed=True,
|
||||
)
|
||||
)
|
||||
# 计算每日存款金币
|
||||
for user_id, logs in bank_data.items():
|
||||
if user := user_data.get(user_id):
|
||||
for log in logs:
|
||||
gold = int(log.amount * log.rate * log.effective_hour) or 1
|
||||
user.gold += gold
|
||||
log.is_completed = True
|
||||
log_update_list.append(log)
|
||||
log_create_list.append(
|
||||
MahiroBankLog(
|
||||
user_id=user_id,
|
||||
amount=gold,
|
||||
rate=log.rate,
|
||||
handle_type=BankHandleType.INTEREST,
|
||||
is_completed=True,
|
||||
)
|
||||
)
|
||||
if log_create_list:
|
||||
await MahiroBankLog.bulk_create(log_create_list, 10)
|
||||
if log_update_list:
|
||||
await MahiroBankLog.bulk_update(log_update_list, ["is_completed"], 10)
|
||||
await UserConsole.bulk_update(user_list, ["gold"], 10)
|
||||
@ -1,12 +1,17 @@
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
from nonebot import on_regex
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.params import Depends, RegexGroup
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot.rule import to_me
|
||||
from nonebot_plugin_alconna import Alconna, Option, on_alconna, store_true
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
Args,
|
||||
Arparma,
|
||||
CommandMeta,
|
||||
Option,
|
||||
on_alconna,
|
||||
store_true,
|
||||
)
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
@ -54,15 +59,22 @@ __plugin_meta__ = PluginMetadata(
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
_nickname_matcher = on_regex(
|
||||
"(?:以后)?(?:叫我|请叫我|称呼我)(.*)",
|
||||
_nickname_matcher = on_alconna(
|
||||
Alconna(
|
||||
"re:(?:以后)?(?:叫我|请叫我|称呼我)",
|
||||
Args["name?", str],
|
||||
meta=CommandMeta(compact=True),
|
||||
),
|
||||
rule=to_me(),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_global_nickname_matcher = on_regex(
|
||||
"设置全局昵称(.*)", rule=to_me(), priority=5, block=True
|
||||
_global_nickname_matcher = on_alconna(
|
||||
Alconna("设置全局昵称", Args["name?", str], meta=CommandMeta(compact=True)),
|
||||
rule=to_me(),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
_matcher = on_alconna(
|
||||
@ -117,18 +129,16 @@ CANCEL = [
|
||||
]
|
||||
|
||||
|
||||
def CheckNickname():
|
||||
async def CheckNickname(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
params: Arparma,
|
||||
):
|
||||
"""
|
||||
检查名称是否合法
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
black_word = Config.get_config("nickname", "BLACK_WORD")
|
||||
(name,) = reg_group
|
||||
name = params.query("name")
|
||||
logger.debug(f"昵称检查: {name}", "昵称设置", session=session)
|
||||
if not name:
|
||||
await MessageUtils.build_message("叫你空白?叫你虚空?叫你无名??").finish(
|
||||
@ -138,13 +148,13 @@ def CheckNickname():
|
||||
logger.debug(
|
||||
f"超级用户设置昵称, 跳过合法检测: {name}", "昵称设置", session=session
|
||||
)
|
||||
return
|
||||
else:
|
||||
if len(name) > 20:
|
||||
await MessageUtils.build_message("昵称可不能超过20个字!").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if name in bot.config.nickname:
|
||||
await MessageUtils.build_message("笨蛋!休想占用我的名字! #").finish(
|
||||
await MessageUtils.build_message("笨蛋!休想占用我的名字! ").finish(
|
||||
at_sender=True
|
||||
)
|
||||
if black_word:
|
||||
@ -162,17 +172,17 @@ def CheckNickname():
|
||||
await MessageUtils.build_message(
|
||||
f"字符 [{word}] 为禁止字符!"
|
||||
).finish(at_sender=True)
|
||||
|
||||
return Depends(dependency)
|
||||
return name
|
||||
|
||||
|
||||
@_nickname_matcher.handle(parameterless=[CheckNickname()])
|
||||
@_nickname_matcher.handle()
|
||||
async def _(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
name_: Arparma,
|
||||
uname: str = UserName(),
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
(name,) = reg_group
|
||||
name = await CheckNickname(bot, session, name_)
|
||||
if len(name) < 5 and random.random() < 0.3:
|
||||
name = "~".join(name)
|
||||
group_id = None
|
||||
@ -200,13 +210,14 @@ async def _(
|
||||
)
|
||||
|
||||
|
||||
@_global_nickname_matcher.handle(parameterless=[CheckNickname()])
|
||||
@_global_nickname_matcher.handle()
|
||||
async def _(
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
name_: Arparma,
|
||||
nickname: str = UserName(),
|
||||
reg_group: tuple[Any, ...] = RegexGroup(),
|
||||
):
|
||||
(name,) = reg_group
|
||||
name = await CheckNickname(bot, session, name_)
|
||||
await FriendUser.set_user_nickname(
|
||||
session.user.id,
|
||||
name,
|
||||
@ -227,15 +238,14 @@ async def _(session: Uninfo, uname: str = UserName()):
|
||||
group_id = session.group.parent.id if session.group.parent else session.group.id
|
||||
if group_id:
|
||||
nickname = await GroupInfoUser.get_user_nickname(session.user.id, group_id)
|
||||
card = uname
|
||||
else:
|
||||
nickname = await FriendUser.get_user_nickname(session.user.id)
|
||||
card = uname
|
||||
if nickname:
|
||||
await MessageUtils.build_message(random.choice(REMIND).format(nickname)).finish(
|
||||
reply_to=True
|
||||
)
|
||||
else:
|
||||
card = uname
|
||||
await MessageUtils.build_message(
|
||||
random.choice(
|
||||
[
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from nonebot import on_notice, on_request
|
||||
from nonebot import on_notice
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters.onebot.v11 import (
|
||||
GroupDecreaseNoticeEvent,
|
||||
@ -14,9 +14,10 @@ from nonebot_plugin_uninfo import Uninfo
|
||||
from zhenxun.builtin_plugins.platform.qq.exception import ForceAddGroupError
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
|
||||
from zhenxun.models.event_log import EventLog
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.enum import EventLogType, PluginType
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
from zhenxun.utils.rules import notice_rule
|
||||
|
||||
@ -106,8 +107,6 @@ group_decrease_handle = on_notice(
|
||||
rule=notice_rule([GroupMemberDecreaseEvent, GroupDecreaseNoticeEvent]),
|
||||
)
|
||||
"""群员减少处理"""
|
||||
add_group = on_request(priority=1, block=False)
|
||||
"""加群同意请求"""
|
||||
|
||||
|
||||
@group_increase_handle.handle()
|
||||
@ -141,8 +140,21 @@ async def _(
|
||||
group_id = str(event.group_id)
|
||||
if event.sub_type == "kick_me":
|
||||
"""踢出Bot"""
|
||||
await GroupManager.kick_bot(bot, user_id, group_id)
|
||||
await GroupManager.kick_bot(bot, group_id, str(event.operator_id))
|
||||
await EventLog.create(
|
||||
user_id=user_id, group_id=group_id, event_type=EventLogType.KICK_BOT
|
||||
)
|
||||
elif event.sub_type in ["leave", "kick"]:
|
||||
if event.sub_type == "leave":
|
||||
"""主动退群"""
|
||||
await EventLog.create(
|
||||
user_id=user_id, group_id=group_id, event_type=EventLogType.LEAVE_MEMBER
|
||||
)
|
||||
else:
|
||||
"""被踢出群"""
|
||||
await EventLog.create(
|
||||
user_id=user_id, group_id=group_id, event_type=EventLogType.KICK_MEMBER
|
||||
)
|
||||
result = await GroupManager.run_user(
|
||||
bot, user_id, group_id, str(event.operator_id), event.sub_type
|
||||
)
|
||||
|
||||
@ -55,15 +55,17 @@ class GroupManager:
|
||||
if plugin_list := await PluginInfo.filter(default_status=False).all():
|
||||
for plugin in plugin_list:
|
||||
block_plugin += f"<{plugin.module},"
|
||||
group_info = await bot.get_group_info(group_id=group_id, no_cache=True)
|
||||
await GroupConsole.create(
|
||||
group_info = await bot.get_group_info(group_id=group_id)
|
||||
await GroupConsole.update_or_create(
|
||||
group_id=group_info["group_id"],
|
||||
group_name=group_info["group_name"],
|
||||
max_member_count=group_info["max_member_count"],
|
||||
member_count=group_info["member_count"],
|
||||
group_flag=1,
|
||||
block_plugin=block_plugin,
|
||||
platform="qq",
|
||||
defaults={
|
||||
"group_name": group_info["group_name"],
|
||||
"max_member_count": group_info["max_member_count"],
|
||||
"member_count": group_info["member_count"],
|
||||
"group_flag": 1,
|
||||
"block_plugin": block_plugin,
|
||||
"platform": "qq",
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -145,7 +147,7 @@ class GroupManager:
|
||||
e=e,
|
||||
)
|
||||
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e
|
||||
await GroupConsole.filter(group_id=group_id).delete()
|
||||
# await GroupConsole.filter(group_id=group_id).delete()
|
||||
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...")
|
||||
else:
|
||||
await cls.__handle_add_group(bot, group_id, group)
|
||||
|
||||
100
zhenxun/builtin_plugins/platform/qq/user_group_request.py
Normal file
100
zhenxun/builtin_plugins/platform/qq/user_group_request.py
Normal file
@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import random
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot.rule import to_me
|
||||
from nonebot_plugin_alconna import Alconna, Args, Arparma, Field, on_alconna
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.utils import PluginCdBlock, PluginExtraData
|
||||
from zhenxun.models.fg_request import FgRequest
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.depends import UserName
|
||||
from zhenxun.utils.enum import RequestHandleType, RequestType
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="群组申请",
|
||||
description="""
|
||||
一些小群直接邀请入群导致无法正常生成审核请求,需要用该方法手动生成审核请求。
|
||||
当管理员同意同意时会发送消息进行提示,之后再进行拉群不会退出。
|
||||
该消息会发送至管理员,多次发送不存在的群组id或相同群组id可能导致ban。
|
||||
""".strip(),
|
||||
usage="""
|
||||
指令:
|
||||
申请入群 [群号]
|
||||
示例: 申请入群 123123123
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1",
|
||||
menu_type="其他",
|
||||
limits=[PluginCdBlock(cd=300, result="每5分钟只能申请一次哦~")],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
|
||||
_matcher = on_alconna(
|
||||
Alconna(
|
||||
"申请入群",
|
||||
Args[
|
||||
"group_id",
|
||||
int,
|
||||
Field(
|
||||
missing_tips=lambda: "请在命令后跟随群组id!",
|
||||
unmatch_tips=lambda _: "群组id必须为数字!",
|
||||
),
|
||||
],
|
||||
),
|
||||
skip_for_unmatch=False,
|
||||
priority=5,
|
||||
block=True,
|
||||
rule=to_me(),
|
||||
)
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(
|
||||
bot: Bot, session: Uninfo, arparma: Arparma, group_id: int, uname: str = UserName()
|
||||
):
|
||||
# 旧请求全部设置为过期
|
||||
await FgRequest.filter(
|
||||
request_type=RequestType.GROUP,
|
||||
user_id=session.user.id,
|
||||
group_id=str(group_id),
|
||||
handle_type__isnull=True,
|
||||
).update(handle_type=RequestHandleType.EXPIRE)
|
||||
f = await FgRequest.create(
|
||||
request_type=RequestType.GROUP,
|
||||
platform=PlatformUtils.get_platform(session),
|
||||
bot_id=bot.self_id,
|
||||
flag="0",
|
||||
user_id=session.user.id,
|
||||
nickname=uname,
|
||||
group_id=str(group_id),
|
||||
)
|
||||
results = await PlatformUtils.send_superuser(
|
||||
bot,
|
||||
f"*****一份入群申请*****\n"
|
||||
f"ID:{f.id}\n"
|
||||
f"申请人:{uname}({session.user.id})\n群聊:"
|
||||
f"{group_id}\n邀请日期:{datetime.now().replace(microsecond=0)}\n"
|
||||
"注:该请求为手动申请入群",
|
||||
)
|
||||
if message_ids := [
|
||||
str(r[1].msg_ids[0]["message_id"]) for r in results if r[1] and r[1].msg_ids
|
||||
]:
|
||||
f.message_ids = ",".join(message_ids)
|
||||
await f.save(update_fields=["message_ids"])
|
||||
await asyncio.sleep(random.randint(1, 5))
|
||||
await bot.send_private_msg(
|
||||
user_id=int(session.user.id),
|
||||
message=f"已发送申请,请等待管理员审核,ID:{f.id}。",
|
||||
)
|
||||
logger.info(
|
||||
f"用户 {uname}({session.user.id}) 申请入群 {group_id},ID:{f.id}。",
|
||||
arparma.header_result,
|
||||
session=session,
|
||||
)
|
||||
@ -1,4 +1,4 @@
|
||||
from nonebot.message import run_preprocessor
|
||||
from nonebot import on_message
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.models.friend_user import FriendUser
|
||||
@ -8,24 +8,27 @@ from zhenxun.services.log import logger
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
|
||||
@run_preprocessor
|
||||
async def do_something(session: Uninfo):
|
||||
def rule(session: Uninfo) -> bool:
|
||||
return PlatformUtils.is_qbot(session)
|
||||
|
||||
|
||||
_matcher = on_message(priority=999, block=False, rule=rule)
|
||||
|
||||
|
||||
@_matcher.handle()
|
||||
async def _(session: Uninfo):
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
if session.group:
|
||||
if not await GroupConsole.exists(group_id=session.group.id):
|
||||
await GroupConsole.create(group_id=session.group.id)
|
||||
logger.info("添加当前群组ID信息" "", session=session)
|
||||
|
||||
if not await GroupInfoUser.exists(
|
||||
user_id=session.user.id, group_id=session.group.id
|
||||
):
|
||||
await GroupInfoUser.create(
|
||||
user_id=session.user.id, group_id=session.group.id, platform=platform
|
||||
logger.info("添加当前群组ID信息", session=session)
|
||||
await GroupInfoUser.update_or_create(
|
||||
user_id=session.user.id,
|
||||
group_id=session.group.id,
|
||||
platform=PlatformUtils.get_platform(session),
|
||||
)
|
||||
logger.info("添加当前用户群组ID信息", "", session=session)
|
||||
elif not await FriendUser.exists(user_id=session.user.id, platform=platform):
|
||||
try:
|
||||
await FriendUser.create(user_id=session.user.id, platform=platform)
|
||||
await FriendUser.create(
|
||||
user_id=session.user.id, platform=PlatformUtils.get_platform(session)
|
||||
)
|
||||
logger.info("添加当前好友用户信息", "", session=session)
|
||||
except Exception as e:
|
||||
logger.error("添加当前好友用户信息失败", session=session, e=e)
|
||||
|
||||
@ -9,7 +9,7 @@ from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.utils import is_number
|
||||
|
||||
from .data_source import ShopManage
|
||||
from .data_source import StoreManager
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="插件商店",
|
||||
@ -82,7 +82,7 @@ _matcher.shortcut(
|
||||
@_matcher.assign("$main")
|
||||
async def _(session: EventSession):
|
||||
try:
|
||||
result = await ShopManage.get_plugins_info()
|
||||
result = await StoreManager.get_plugins_info()
|
||||
logger.info("查看插件列表", "插件商店", session=session)
|
||||
await MessageUtils.build_message(result).send()
|
||||
except Exception as e:
|
||||
@ -97,7 +97,7 @@ async def _(session: EventSession, plugin_id: str):
|
||||
await MessageUtils.build_message(f"正在添加插件 Id: {plugin_id}").send()
|
||||
else:
|
||||
await MessageUtils.build_message(f"正在添加插件 Module: {plugin_id}").send()
|
||||
result = await ShopManage.add_plugin(plugin_id)
|
||||
result = await StoreManager.add_plugin(plugin_id)
|
||||
except Exception as e:
|
||||
logger.error(f"添加插件 Id: {plugin_id}失败", "插件商店", session=session, e=e)
|
||||
await MessageUtils.build_message(
|
||||
@ -110,7 +110,7 @@ async def _(session: EventSession, plugin_id: str):
|
||||
@_matcher.assign("remove")
|
||||
async def _(session: EventSession, plugin_id: str):
|
||||
try:
|
||||
result = await ShopManage.remove_plugin(plugin_id)
|
||||
result = await StoreManager.remove_plugin(plugin_id)
|
||||
except Exception as e:
|
||||
logger.error(f"移除插件 Id: {plugin_id}失败", "插件商店", session=session, e=e)
|
||||
await MessageUtils.build_message(
|
||||
@ -123,7 +123,7 @@ async def _(session: EventSession, plugin_id: str):
|
||||
@_matcher.assign("search")
|
||||
async def _(session: EventSession, plugin_name_or_author: str):
|
||||
try:
|
||||
result = await ShopManage.search_plugin(plugin_name_or_author)
|
||||
result = await StoreManager.search_plugin(plugin_name_or_author)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"搜索插件 name: {plugin_name_or_author}失败",
|
||||
@ -145,7 +145,7 @@ async def _(session: EventSession, plugin_id: str):
|
||||
await MessageUtils.build_message(f"正在更新插件 Id: {plugin_id}").send()
|
||||
else:
|
||||
await MessageUtils.build_message(f"正在更新插件 Module: {plugin_id}").send()
|
||||
result = await ShopManage.update_plugin(plugin_id)
|
||||
result = await StoreManager.update_plugin(plugin_id)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 Id: {plugin_id}失败", "插件商店", session=session, e=e)
|
||||
await MessageUtils.build_message(
|
||||
@ -159,7 +159,7 @@ async def _(session: EventSession, plugin_id: str):
|
||||
async def _(session: EventSession):
|
||||
try:
|
||||
await MessageUtils.build_message("正在更新全部插件").send()
|
||||
result = await ShopManage.update_all_plugin()
|
||||
result = await StoreManager.update_all_plugin()
|
||||
except Exception as e:
|
||||
logger.error("更新全部插件失败", "插件商店", session=session, e=e)
|
||||
await MessageUtils.build_message(f"更新全部插件失败 e: {e}").finish()
|
||||
|
||||
@ -9,3 +9,14 @@ DEFAULT_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins/tree/ma
|
||||
|
||||
EXTRA_GITHUB_URL = "https://github.com/zhenxun-org/zhenxun_bot_plugins_index/tree/index"
|
||||
"""插件库索引github仓库地址"""
|
||||
|
||||
GITEE_RAW_URL = "https://gitee.com/two_Dimension/zhenxun_bot_plugins/raw/main"
|
||||
"""GITEE仓库文件内容"""
|
||||
|
||||
GITEE_CONTENTS_URL = (
|
||||
"https://gitee.com/api/v5/repos/two_Dimension/zhenxun_bot_plugins/contents"
|
||||
)
|
||||
"""GITEE仓库文件列表获取"""
|
||||
|
||||
|
||||
LOG_COMMAND = "插件商店"
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
from aiocache import cached
|
||||
import ujson as json
|
||||
@ -14,9 +13,15 @@ from zhenxun.utils.github_utils import GithubUtils
|
||||
from zhenxun.utils.github_utils.models import RepoAPI
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
|
||||
from zhenxun.utils.utils import is_number
|
||||
|
||||
from .config import BASE_PATH, DEFAULT_GITHUB_URL, EXTRA_GITHUB_URL
|
||||
from .config import (
|
||||
BASE_PATH,
|
||||
DEFAULT_GITHUB_URL,
|
||||
EXTRA_GITHUB_URL,
|
||||
LOG_COMMAND,
|
||||
)
|
||||
|
||||
|
||||
def row_style(column: str, text: str) -> RowStyle:
|
||||
@ -39,72 +44,69 @@ def install_requirement(plugin_path: Path):
|
||||
requirement_files = ["requirement.txt", "requirements.txt"]
|
||||
requirement_paths = [plugin_path / file for file in requirement_files]
|
||||
|
||||
existing_requirements = next(
|
||||
if existing_requirements := next(
|
||||
(path for path in requirement_paths if path.exists()), None
|
||||
)
|
||||
|
||||
if not existing_requirements:
|
||||
logger.debug(
|
||||
f"No requirement.txt found for plugin: {plugin_path.name}", "插件管理"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["poetry", "run", "pip", "install", "-r", str(existing_requirements)],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
logger.debug(
|
||||
"Successfully installed dependencies for"
|
||||
f" plugin: {plugin_path.name}. Output:\n{result.stdout}",
|
||||
"插件管理",
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
logger.error(
|
||||
f"Failed to install dependencies for plugin: {plugin_path.name}. "
|
||||
" Error:\n{e.stderr}"
|
||||
)
|
||||
):
|
||||
VirtualEnvPackageManager.install_requirement(existing_requirements)
|
||||
|
||||
|
||||
class ShopManage:
|
||||
class StoreManager:
|
||||
@classmethod
|
||||
@cached(60)
|
||||
async def get_data(cls) -> dict[str, StorePluginInfo]:
|
||||
"""获取插件信息数据
|
||||
|
||||
异常:
|
||||
ValueError: 访问请求失败
|
||||
async def get_github_plugins(cls) -> list[StorePluginInfo]:
|
||||
"""获取github插件列表信息
|
||||
|
||||
返回:
|
||||
dict: 插件信息数据
|
||||
list[StorePluginInfo]: 插件列表数据
|
||||
"""
|
||||
default_github_repo = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
|
||||
extra_github_repo = GithubUtils.parse_github_url(EXTRA_GITHUB_URL)
|
||||
for repo_info in [default_github_repo, extra_github_repo]:
|
||||
repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", "插件管理")
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", "插件管理")
|
||||
default_github_url = await default_github_repo.get_raw_download_urls(
|
||||
"plugins.json"
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
default_github_url = await repo_info.get_raw_download_urls("plugins.json")
|
||||
response = await AsyncHttpx.get(default_github_url, check_status_code=200)
|
||||
if response.status_code == 200:
|
||||
logger.info("获取github插件列表成功", LOG_COMMAND)
|
||||
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
|
||||
else:
|
||||
logger.warning(
|
||||
f"获取github插件列表失败: {response.status_code}", LOG_COMMAND
|
||||
)
|
||||
extra_github_url = await extra_github_repo.get_raw_download_urls("plugins.json")
|
||||
res = await AsyncHttpx.get(default_github_url)
|
||||
res2 = await AsyncHttpx.get(extra_github_url)
|
||||
return []
|
||||
|
||||
# 检查请求结果
|
||||
if res.status_code != 200 or res2.status_code != 200:
|
||||
raise ValueError(f"下载错误, code: {res.status_code}, {res2.status_code}")
|
||||
@classmethod
|
||||
async def get_extra_plugins(cls) -> list[StorePluginInfo]:
|
||||
"""获取额外插件列表信息
|
||||
|
||||
# 解析并合并返回的 JSON 数据
|
||||
data1 = json.loads(res.text)
|
||||
data2 = json.loads(res2.text)
|
||||
return {
|
||||
name: StorePluginInfo(**detail)
|
||||
for name, detail in {**data1, **data2}.items()
|
||||
}
|
||||
返回:
|
||||
list[StorePluginInfo]: 插件列表数据
|
||||
"""
|
||||
repo_info = GithubUtils.parse_github_url(EXTRA_GITHUB_URL)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
extra_github_url = await repo_info.get_raw_download_urls("plugins.json")
|
||||
response = await AsyncHttpx.get(extra_github_url, check_status_code=200)
|
||||
if response.status_code == 200:
|
||||
return [StorePluginInfo(**detail) for detail in json.loads(response.text)]
|
||||
else:
|
||||
logger.warning(
|
||||
f"获取github扩展插件列表失败: {response.status_code}", LOG_COMMAND
|
||||
)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
@cached(60)
|
||||
async def get_data(cls) -> list[StorePluginInfo]:
|
||||
"""获取插件信息数据
|
||||
|
||||
返回:
|
||||
list[StorePluginInfo]: 插件信息数据
|
||||
"""
|
||||
plugins = await cls.get_github_plugins()
|
||||
extra_plugins = await cls.get_extra_plugins()
|
||||
return [*plugins, *extra_plugins]
|
||||
|
||||
@classmethod
|
||||
def version_check(cls, plugin_info: StorePluginInfo, suc_plugin: dict[str, str]):
|
||||
@ -112,7 +114,7 @@ class ShopManage:
|
||||
|
||||
参数:
|
||||
plugin_info: StorePluginInfo
|
||||
suc_plugin: dict[str, str]
|
||||
suc_plugin: 模块名: 版本号
|
||||
|
||||
返回:
|
||||
str: 版本号
|
||||
@ -132,7 +134,7 @@ class ShopManage:
|
||||
|
||||
参数:
|
||||
plugin_info: StorePluginInfo
|
||||
suc_plugin: dict[str, str]
|
||||
suc_plugin: 模块名: 版本号
|
||||
|
||||
返回:
|
||||
bool: 是否有更新
|
||||
@ -156,21 +158,21 @@ class ShopManage:
|
||||
返回:
|
||||
BuildImage | str: 返回消息
|
||||
"""
|
||||
data: dict[str, StorePluginInfo] = await cls.get_data()
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"]
|
||||
plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "0.1") for p in plugin_list}
|
||||
db_plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "0.1") for p in db_plugin_list}
|
||||
data_list = [
|
||||
[
|
||||
"已安装" if plugin_info[1].module in suc_plugin else "",
|
||||
"已安装" if plugin_info.module in suc_plugin else "",
|
||||
id,
|
||||
plugin_info[0],
|
||||
plugin_info[1].description,
|
||||
plugin_info[1].author,
|
||||
cls.version_check(plugin_info[1], suc_plugin),
|
||||
plugin_info[1].plugin_type_name,
|
||||
plugin_info.name,
|
||||
plugin_info.description,
|
||||
plugin_info.author,
|
||||
cls.version_check(plugin_info, suc_plugin),
|
||||
plugin_info.plugin_type_name,
|
||||
]
|
||||
for id, plugin_info in enumerate(data.items())
|
||||
for id, plugin_info in enumerate(plugin_list)
|
||||
]
|
||||
return await ImageTemplate.table_page(
|
||||
"插件列表",
|
||||
@ -190,15 +192,15 @@ class ShopManage:
|
||||
返回:
|
||||
str: 返回消息
|
||||
"""
|
||||
data: dict[str, StorePluginInfo] = await cls.get_data()
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
try:
|
||||
plugin_key = await cls._resolve_plugin_key(plugin_id)
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
plugin_list = await cls.get_loaded_plugins("module")
|
||||
plugin_info = data[plugin_key]
|
||||
if plugin_info.module in [p[0] for p in plugin_list]:
|
||||
return f"插件 {plugin_key} 已安装,无需重复安装"
|
||||
db_plugin_list = await cls.get_loaded_plugins("module")
|
||||
plugin_info = next(p for p in plugin_list if p.module == plugin_key)
|
||||
if plugin_info.module in [p[0] for p in db_plugin_list]:
|
||||
return f"插件 {plugin_info.name} 已安装,无需重复安装"
|
||||
is_external = True
|
||||
if plugin_info.github_url is None:
|
||||
plugin_info.github_url = DEFAULT_GITHUB_URL
|
||||
@ -207,34 +209,39 @@ class ShopManage:
|
||||
if len(version_split) > 1:
|
||||
github_url_split = plugin_info.github_url.split("/tree/")
|
||||
plugin_info.github_url = f"{github_url_split[0]}/tree/{version_split[1]}"
|
||||
logger.info(f"正在安装插件 {plugin_key}...")
|
||||
logger.info(f"正在安装插件 {plugin_info.name}...", LOG_COMMAND)
|
||||
await cls.install_plugin_with_repo(
|
||||
plugin_info.github_url,
|
||||
plugin_info.module_path,
|
||||
plugin_info.is_dir,
|
||||
is_external,
|
||||
)
|
||||
return f"插件 {plugin_key} 安装成功! 重启后生效"
|
||||
return f"插件 {plugin_info.name} 安装成功! 重启后生效"
|
||||
|
||||
@classmethod
|
||||
async def install_plugin_with_repo(
|
||||
cls, github_url: str, module_path: str, is_dir: bool, is_external: bool = False
|
||||
cls,
|
||||
github_url: str,
|
||||
module_path: str,
|
||||
is_dir: bool,
|
||||
is_external: bool = False,
|
||||
):
|
||||
files: list[str]
|
||||
repo_api: RepoAPI
|
||||
repo_info = GithubUtils.parse_github_url(github_url)
|
||||
if await repo_info.update_repo_commit():
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", "插件管理")
|
||||
logger.info(f"获取最新提交: {repo_info.branch}", LOG_COMMAND)
|
||||
else:
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", "插件管理")
|
||||
logger.debug(f"成功获取仓库信息: {repo_info}", "插件管理")
|
||||
logger.warning(f"获取最新提交失败: {repo_info}", LOG_COMMAND)
|
||||
logger.debug(f"成功获取仓库信息: {repo_info}", LOG_COMMAND)
|
||||
for repo_api in GithubUtils.iter_api_strategies():
|
||||
try:
|
||||
await repo_api.parse_repo_info(repo_info)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"获取插件文件失败: {e} | API类型: {repo_api.strategy}", "插件管理"
|
||||
f"获取插件文件失败 | API类型: {repo_api.strategy}",
|
||||
LOG_COMMAND,
|
||||
e=e,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
@ -250,7 +257,7 @@ class ShopManage:
|
||||
base_path = BASE_PATH / "plugins" if is_external else BASE_PATH
|
||||
base_path = base_path if module_path else base_path / repo_info.repo
|
||||
download_paths: list[Path | str] = [base_path / file for file in files]
|
||||
logger.debug(f"插件下载路径: {download_paths}", "插件管理")
|
||||
logger.debug(f"插件下载路径: {download_paths}", LOG_COMMAND)
|
||||
result = await AsyncHttpx.gather_download_file(download_urls, download_paths)
|
||||
for _id, success in enumerate(result):
|
||||
if not success:
|
||||
@ -265,12 +272,12 @@ class ShopManage:
|
||||
req_files.extend(
|
||||
repo_api.get_files(f"{replace_module_path}/requirement.txt", False)
|
||||
)
|
||||
logger.debug(f"获取插件依赖文件列表: {req_files}", "插件管理")
|
||||
logger.debug(f"获取插件依赖文件列表: {req_files}", LOG_COMMAND)
|
||||
req_download_urls = [
|
||||
await repo_info.get_raw_download_urls(file) for file in req_files
|
||||
]
|
||||
req_paths: list[Path | str] = [plugin_path / file for file in req_files]
|
||||
logger.debug(f"插件依赖文件下载路径: {req_paths}", "插件管理")
|
||||
logger.debug(f"插件依赖文件下载路径: {req_paths}", LOG_COMMAND)
|
||||
if req_files:
|
||||
result = await AsyncHttpx.gather_download_file(
|
||||
req_download_urls, req_paths
|
||||
@ -278,7 +285,7 @@ class ShopManage:
|
||||
for success in result:
|
||||
if not success:
|
||||
raise Exception("插件依赖文件下载失败")
|
||||
logger.debug(f"插件依赖文件列表: {req_paths}", "插件管理")
|
||||
logger.debug(f"插件依赖文件列表: {req_paths}", LOG_COMMAND)
|
||||
install_requirement(plugin_path)
|
||||
except ValueError as e:
|
||||
logger.warning("未获取到依赖文件路径...", e=e)
|
||||
@ -295,12 +302,12 @@ class ShopManage:
|
||||
返回:
|
||||
str: 返回消息
|
||||
"""
|
||||
data: dict[str, StorePluginInfo] = await cls.get_data()
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
try:
|
||||
plugin_key = await cls._resolve_plugin_key(plugin_id)
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
plugin_info = data[plugin_key]
|
||||
plugin_info = next(p for p in plugin_list if p.module == plugin_key)
|
||||
path = BASE_PATH
|
||||
if plugin_info.github_url:
|
||||
path = BASE_PATH / "plugins"
|
||||
@ -309,14 +316,14 @@ class ShopManage:
|
||||
if not plugin_info.is_dir:
|
||||
path = Path(f"{path}.py")
|
||||
if not path.exists():
|
||||
return f"插件 {plugin_key} 不存在..."
|
||||
logger.debug(f"尝试移除插件 {plugin_key} 文件: {path}", "插件管理")
|
||||
return f"插件 {plugin_info.name} 不存在..."
|
||||
logger.debug(f"尝试移除插件 {plugin_info.name} 文件: {path}", LOG_COMMAND)
|
||||
if plugin_info.is_dir:
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
await PluginInitManager.remove(f"zhenxun.{plugin_info.module_path}")
|
||||
return f"插件 {plugin_key} 移除成功! 重启后生效"
|
||||
return f"插件 {plugin_info.name} 移除成功! 重启后生效"
|
||||
|
||||
@classmethod
|
||||
async def search_plugin(cls, plugin_name_or_author: str) -> BuildImage | str:
|
||||
@ -328,25 +335,25 @@ class ShopManage:
|
||||
返回:
|
||||
BuildImage | str: 返回消息
|
||||
"""
|
||||
data: dict[str, StorePluginInfo] = await cls.get_data()
|
||||
plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in plugin_list}
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
db_plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
|
||||
filtered_data = [
|
||||
(id, plugin_info)
|
||||
for id, plugin_info in enumerate(data.items())
|
||||
if plugin_name_or_author.lower() in plugin_info[0].lower()
|
||||
or plugin_name_or_author.lower() in plugin_info[1].author.lower()
|
||||
for id, plugin_info in enumerate(plugin_list)
|
||||
if plugin_name_or_author.lower() in plugin_info.name.lower()
|
||||
or plugin_name_or_author.lower() in plugin_info.author.lower()
|
||||
]
|
||||
|
||||
data_list = [
|
||||
[
|
||||
"已安装" if plugin_info[1].module in suc_plugin else "",
|
||||
"已安装" if plugin_info.module in suc_plugin else "",
|
||||
id,
|
||||
plugin_info[0],
|
||||
plugin_info[1].description,
|
||||
plugin_info[1].author,
|
||||
cls.version_check(plugin_info[1], suc_plugin),
|
||||
plugin_info[1].plugin_type_name,
|
||||
plugin_info.name,
|
||||
plugin_info.description,
|
||||
plugin_info.author,
|
||||
cls.version_check(plugin_info, suc_plugin),
|
||||
plugin_info.plugin_type_name,
|
||||
]
|
||||
for id, plugin_info in filtered_data
|
||||
]
|
||||
@ -354,7 +361,7 @@ class ShopManage:
|
||||
return "未找到相关插件..."
|
||||
column_name = ["-", "ID", "名称", "简介", "作者", "版本", "类型"]
|
||||
return await ImageTemplate.table_page(
|
||||
"插件列表",
|
||||
"商店插件列表",
|
||||
"通过添加/移除插件 ID 来管理插件",
|
||||
column_name,
|
||||
data_list,
|
||||
@ -371,20 +378,20 @@ class ShopManage:
|
||||
返回:
|
||||
str: 返回消息
|
||||
"""
|
||||
data: dict[str, StorePluginInfo] = await cls.get_data()
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
try:
|
||||
plugin_key = await cls._resolve_plugin_key(plugin_id)
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
logger.info(f"尝试更新插件 {plugin_key}", "插件管理")
|
||||
plugin_info = data[plugin_key]
|
||||
plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in plugin_list}
|
||||
if plugin_info.module not in [p[0] for p in plugin_list]:
|
||||
return f"插件 {plugin_key} 未安装,无法更新"
|
||||
logger.debug(f"当前插件列表: {suc_plugin}", "插件管理")
|
||||
plugin_info = next(p for p in plugin_list if p.module == plugin_key)
|
||||
logger.info(f"尝试更新插件 {plugin_info.name}", LOG_COMMAND)
|
||||
db_plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
|
||||
if plugin_info.module not in [p[0] for p in db_plugin_list]:
|
||||
return f"插件 {plugin_info.name} 未安装,无法更新"
|
||||
logger.debug(f"当前插件列表: {suc_plugin}", LOG_COMMAND)
|
||||
if cls.check_version_is_new(plugin_info, suc_plugin):
|
||||
return f"插件 {plugin_key} 已是最新版本"
|
||||
return f"插件 {plugin_info.name} 已是最新版本"
|
||||
is_external = True
|
||||
if plugin_info.github_url is None:
|
||||
plugin_info.github_url = DEFAULT_GITHUB_URL
|
||||
@ -395,7 +402,7 @@ class ShopManage:
|
||||
plugin_info.is_dir,
|
||||
is_external,
|
||||
)
|
||||
return f"插件 {plugin_key} 更新成功! 重启后生效"
|
||||
return f"插件 {plugin_info.name} 更新成功! 重启后生效"
|
||||
|
||||
@classmethod
|
||||
async def update_all_plugin(cls) -> str:
|
||||
@ -407,24 +414,33 @@ class ShopManage:
|
||||
返回:
|
||||
str: 返回消息
|
||||
"""
|
||||
data: dict[str, StorePluginInfo] = await cls.get_data()
|
||||
plugin_list = list(data.keys())
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
plugin_name_list = [p.name for p in plugin_list]
|
||||
update_failed_list = []
|
||||
update_success_list = []
|
||||
result = "--已更新{}个插件 {}个失败 {}个成功--"
|
||||
logger.info(f"尝试更新全部插件 {plugin_list}", "插件管理")
|
||||
for plugin_key in plugin_list:
|
||||
logger.info(f"尝试更新全部插件 {plugin_name_list}", LOG_COMMAND)
|
||||
for plugin_info in plugin_list:
|
||||
try:
|
||||
plugin_info = data[plugin_key]
|
||||
plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in plugin_list}
|
||||
if plugin_info.module not in [p[0] for p in plugin_list]:
|
||||
logger.debug(f"插件 {plugin_key} 未安装,跳过", "插件管理")
|
||||
db_plugin_list = await cls.get_loaded_plugins("module", "version")
|
||||
suc_plugin = {p[0]: (p[1] or "Unknown") for p in db_plugin_list}
|
||||
if plugin_info.module not in [p[0] for p in db_plugin_list]:
|
||||
logger.debug(
|
||||
f"插件 {plugin_info.name}({plugin_info.module}) 未安装,跳过",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
continue
|
||||
if cls.check_version_is_new(plugin_info, suc_plugin):
|
||||
logger.debug(f"插件 {plugin_key} 已是最新版本,跳过", "插件管理")
|
||||
logger.debug(
|
||||
f"插件 {plugin_info.name}({plugin_info.module}) "
|
||||
"已是最新版本,跳过",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
continue
|
||||
logger.info(f"正在更新插件 {plugin_key}", "插件管理")
|
||||
logger.info(
|
||||
f"正在更新插件 {plugin_info.name}({plugin_info.module})",
|
||||
LOG_COMMAND,
|
||||
)
|
||||
is_external = True
|
||||
if plugin_info.github_url is None:
|
||||
plugin_info.github_url = DEFAULT_GITHUB_URL
|
||||
@ -435,10 +451,14 @@ class ShopManage:
|
||||
plugin_info.is_dir,
|
||||
is_external,
|
||||
)
|
||||
update_success_list.append(plugin_key)
|
||||
update_success_list.append(plugin_info.name)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {plugin_key} 失败: {e}", "插件管理")
|
||||
update_failed_list.append(plugin_key)
|
||||
logger.error(
|
||||
f"更新插件 {plugin_info.name}({plugin_info.module}) 失败",
|
||||
LOG_COMMAND,
|
||||
e=e,
|
||||
)
|
||||
update_failed_list.append(plugin_info.name)
|
||||
if not update_success_list and not update_failed_list:
|
||||
return "全部插件已是最新版本"
|
||||
if update_success_list:
|
||||
@ -460,13 +480,28 @@ class ShopManage:
|
||||
|
||||
@classmethod
|
||||
async def _resolve_plugin_key(cls, plugin_id: str) -> str:
|
||||
data: dict[str, StorePluginInfo] = await cls.get_data()
|
||||
"""获取插件module
|
||||
|
||||
参数:
|
||||
plugin_id: module,id或插件名称
|
||||
|
||||
异常:
|
||||
ValueError: 插件不存在
|
||||
ValueError: 插件不存在
|
||||
|
||||
返回:
|
||||
str: 插件模块名
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo] = await cls.get_data()
|
||||
if is_number(plugin_id):
|
||||
idx = int(plugin_id)
|
||||
if idx < 0 or idx >= len(data):
|
||||
if idx < 0 or idx >= len(plugin_list):
|
||||
raise ValueError("插件ID不存在...")
|
||||
return list(data.keys())[idx]
|
||||
return plugin_list[idx].module
|
||||
elif isinstance(plugin_id, str):
|
||||
if plugin_id not in [v.module for k, v in data.items()]:
|
||||
raise ValueError("插件Module不存在...")
|
||||
return {v.module: k for k, v in data.items()}[plugin_id]
|
||||
result = (
|
||||
None if plugin_id not in [v.module for v in plugin_list] else plugin_id
|
||||
) or next(v for v in plugin_list if v.name == plugin_id).module
|
||||
if not result:
|
||||
raise ValueError("插件 Module / 名称 不存在...")
|
||||
return result
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from nonebot.compat import model_dump
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -13,9 +15,30 @@ type2name: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
class GiteeContents(BaseModel):
|
||||
"""Gitee Api内容"""
|
||||
|
||||
type: Literal["file", "dir"]
|
||||
"""类型"""
|
||||
size: Any
|
||||
"""文件大小"""
|
||||
name: str
|
||||
"""文件名"""
|
||||
path: str
|
||||
"""文件路径"""
|
||||
url: str
|
||||
"""文件链接"""
|
||||
html_url: str
|
||||
"""文件html链接"""
|
||||
download_url: str
|
||||
"""文件raw链接"""
|
||||
|
||||
|
||||
class StorePluginInfo(BaseModel):
|
||||
"""插件信息"""
|
||||
|
||||
name: str
|
||||
"""插件名"""
|
||||
module: str
|
||||
"""模块名"""
|
||||
module_path: str
|
||||
|
||||
@ -17,11 +17,12 @@ from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
from zhenxun.configs.utils import PluginExtraData, RegisterConfig
|
||||
from zhenxun.models.event_log import EventLog
|
||||
from zhenxun.models.fg_request import FgRequest
|
||||
from zhenxun.models.friend_user import FriendUser
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import PluginType, RequestHandleType, RequestType
|
||||
from zhenxun.utils.enum import EventLogType, PluginType, RequestHandleType, RequestType
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
base_config = Config.get("invite_manager")
|
||||
@ -112,21 +113,29 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi
|
||||
nickname=nickname,
|
||||
comment=comment,
|
||||
)
|
||||
await PlatformUtils.send_superuser(
|
||||
results = await PlatformUtils.send_superuser(
|
||||
bot,
|
||||
f"*****一份好友申请*****\n"
|
||||
f"ID: {f.id}\n"
|
||||
f"昵称:{nickname}({event.user_id})\n"
|
||||
f"自动同意:{'√' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n"
|
||||
f"日期:{str(datetime.now()).split('.')[0]}\n"
|
||||
f"日期:{datetime.now().replace(microsecond=0)}\n"
|
||||
f"备注:{event.comment}",
|
||||
)
|
||||
if message_ids := [
|
||||
str(r[1].msg_ids[0]["message_id"])
|
||||
for r in results
|
||||
if r[1] and r[1].msg_ids
|
||||
]:
|
||||
f.message_ids = ",".join(message_ids)
|
||||
await f.save(update_fields=["message_ids"])
|
||||
else:
|
||||
logger.debug("好友请求五分钟内重复, 已忽略", "好友请求", target=event.user_id)
|
||||
|
||||
|
||||
@group_req.handle()
|
||||
async def _(bot: v12Bot | v11Bot, event: GroupRequestEvent, session: EventSession):
|
||||
# sourcery skip: low-code-quality
|
||||
if event.sub_type != "invite":
|
||||
return
|
||||
if str(event.user_id) in bot.config.superusers or base_config.get("AUTO_ADD_GROUP"):
|
||||
@ -186,7 +195,7 @@ async def _(bot: v12Bot | v11Bot, event: GroupRequestEvent, session: EventSessio
|
||||
group_id=str(event.group_id),
|
||||
handle_type=RequestHandleType.APPROVE,
|
||||
)
|
||||
await PlatformUtils.send_superuser(
|
||||
results = await PlatformUtils.send_superuser(
|
||||
bot,
|
||||
f"*****一份入群申请*****\n"
|
||||
f"ID:{f.id}\n"
|
||||
@ -230,13 +239,27 @@ async def _(bot: v12Bot | v11Bot, event: GroupRequestEvent, session: EventSessio
|
||||
nickname=nickname,
|
||||
group_id=str(event.group_id),
|
||||
)
|
||||
await PlatformUtils.send_superuser(
|
||||
kick_count = await EventLog.filter(
|
||||
group_id=str(event.group_id), event_type=EventLogType.KICK_BOT
|
||||
).count()
|
||||
kick_message = (
|
||||
f"\n该群累计踢出{BotConfig.self_nickname} <{kick_count}>次"
|
||||
if kick_count
|
||||
else ""
|
||||
)
|
||||
results = await PlatformUtils.send_superuser(
|
||||
bot,
|
||||
f"*****一份入群申请*****\n"
|
||||
f"ID:{f.id}\n"
|
||||
f"申请人:{nickname}({event.user_id})\n群聊:"
|
||||
f"{event.group_id}\n邀请日期:{datetime.now().replace(microsecond=0)}",
|
||||
f"{event.group_id}\n邀请日期:{datetime.now().replace(microsecond=0)}"
|
||||
f"{kick_message}",
|
||||
)
|
||||
if message_ids := [
|
||||
str(r[1].msg_ids[0]["message_id"]) for r in results if r[1] and r[1].msg_ids
|
||||
]:
|
||||
f.message_ids = ",".join(message_ids)
|
||||
await f.save(update_fields=["message_ids"])
|
||||
else:
|
||||
logger.debug(
|
||||
"群聊请求五分钟内重复, 已忽略",
|
||||
|
||||
51
zhenxun/builtin_plugins/scheduler_admin/__init__.py
Normal file
51
zhenxun/builtin_plugins/scheduler_admin/__init__.py
Normal file
@ -0,0 +1,51 @@
|
||||
from nonebot.plugin import PluginMetadata
|
||||
|
||||
from zhenxun.configs.utils import PluginExtraData
|
||||
from zhenxun.utils.enum import PluginType
|
||||
|
||||
from . import command # noqa: F401
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="定时任务管理",
|
||||
description="查看和管理由 SchedulerManager 控制的定时任务。",
|
||||
usage="""
|
||||
📋 定时任务管理 - 支持群聊和私聊操作
|
||||
|
||||
🔍 查看任务:
|
||||
定时任务 查看 [-all] [-g <群号>] [-p <插件>] [--page <页码>]
|
||||
• 群聊中: 查看本群任务
|
||||
• 私聊中: 必须使用 -g <群号> 或 -all 选项 (SUPERUSER)
|
||||
|
||||
📊 任务状态:
|
||||
定时任务 状态 <任务ID> 或 任务状态 <任务ID>
|
||||
• 查看单个任务的详细信息和状态
|
||||
|
||||
⚙️ 任务管理 (SUPERUSER):
|
||||
定时任务 设置 <插件> [时间选项] [-g <群号> | -g all] [--kwargs <参数>]
|
||||
定时任务 删除 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 暂停 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 恢复 <任务ID> | -p <插件> [-g <群号>] | -all
|
||||
定时任务 执行 <任务ID>
|
||||
定时任务 更新 <任务ID> [时间选项] [--kwargs <参数>]
|
||||
|
||||
📝 时间选项 (三选一):
|
||||
--cron "<分> <时> <日> <月> <周>" # 例: --cron "0 8 * * *"
|
||||
--interval <时间间隔> # 例: --interval 30m, 2h, 10s
|
||||
--date "<YYYY-MM-DD HH:MM:SS>" # 例: --date "2024-01-01 08:00:00"
|
||||
--daily "<HH:MM>" # 例: --daily "08:30"
|
||||
|
||||
📚 其他功能:
|
||||
定时任务 插件列表 # 查看所有可设置定时任务的插件 (SUPERUSER)
|
||||
|
||||
🏷️ 别名支持:
|
||||
查看: ls, list | 设置: add, 开启 | 删除: del, rm, remove, 关闭, 取消
|
||||
暂停: pause | 恢复: resume | 执行: trigger, run | 状态: status, info
|
||||
更新: update, modify, 修改 | 插件列表: plugins
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1.2",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
is_show=False,
|
||||
).to_dict(),
|
||||
)
|
||||
836
zhenxun/builtin_plugins/scheduler_admin/command.py
Normal file
836
zhenxun/builtin_plugins/scheduler_admin/command.py
Normal file
@ -0,0 +1,836 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
AlconnaMatch,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
Option,
|
||||
Query,
|
||||
Subcommand,
|
||||
on_alconna,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from zhenxun.utils._image_template import ImageTemplate
|
||||
from zhenxun.utils.manager.schedule_manager import scheduler_manager
|
||||
|
||||
|
||||
def _get_type_name(annotation) -> str:
|
||||
"""获取类型注解的名称"""
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
elif hasattr(annotation, "_name"):
|
||||
return annotation._name
|
||||
else:
|
||||
return str(annotation)
|
||||
|
||||
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.rules import admin_check
|
||||
|
||||
|
||||
def _format_trigger(schedule_status: dict) -> str:
|
||||
"""将触发器配置格式化为人类可读的字符串"""
|
||||
trigger_type = schedule_status["trigger_type"]
|
||||
config = schedule_status["trigger_config"]
|
||||
|
||||
if trigger_type == "cron":
|
||||
minute = config.get("minute", "*")
|
||||
hour = config.get("hour", "*")
|
||||
day = config.get("day", "*")
|
||||
month = config.get("month", "*")
|
||||
day_of_week = config.get("day_of_week", "*")
|
||||
|
||||
if day == "*" and month == "*" and day_of_week == "*":
|
||||
formatted_hour = hour if hour == "*" else f"{int(hour):02d}"
|
||||
formatted_minute = minute if minute == "*" else f"{int(minute):02d}"
|
||||
return f"每天 {formatted_hour}:{formatted_minute}"
|
||||
else:
|
||||
return f"Cron: {minute} {hour} {day} {month} {day_of_week}"
|
||||
elif trigger_type == "interval":
|
||||
seconds = config.get("seconds", 0)
|
||||
minutes = config.get("minutes", 0)
|
||||
hours = config.get("hours", 0)
|
||||
days = config.get("days", 0)
|
||||
if days:
|
||||
trigger_str = f"每 {days} 天"
|
||||
elif hours:
|
||||
trigger_str = f"每 {hours} 小时"
|
||||
elif minutes:
|
||||
trigger_str = f"每 {minutes} 分钟"
|
||||
else:
|
||||
trigger_str = f"每 {seconds} 秒"
|
||||
elif trigger_type == "date":
|
||||
run_date = config.get("run_date", "未知时间")
|
||||
trigger_str = f"在 {run_date}"
|
||||
else:
|
||||
trigger_str = f"{trigger_type}: {config}"
|
||||
|
||||
return trigger_str
|
||||
|
||||
|
||||
def _format_params(schedule_status: dict) -> str:
|
||||
"""将任务参数格式化为人类可读的字符串"""
|
||||
if kwargs := schedule_status.get("job_kwargs"):
|
||||
kwargs_str = " | ".join(f"{k}: {v}" for k, v in kwargs.items())
|
||||
return kwargs_str
|
||||
return "-"
|
||||
|
||||
|
||||
def _parse_interval(interval_str: str) -> dict:
|
||||
"""增强版解析器,支持 d(天)"""
|
||||
match = re.match(r"(\d+)([smhd])", interval_str.lower())
|
||||
if not match:
|
||||
raise ValueError("时间间隔格式错误, 请使用如 '30m', '2h', '1d', '10s' 的格式。")
|
||||
|
||||
value, unit = int(match.group(1)), match.group(2)
|
||||
if unit == "s":
|
||||
return {"seconds": value}
|
||||
if unit == "m":
|
||||
return {"minutes": value}
|
||||
if unit == "h":
|
||||
return {"hours": value}
|
||||
if unit == "d":
|
||||
return {"days": value}
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_daily_time(time_str: str) -> dict:
|
||||
"""解析 HH:MM 或 HH:MM:SS 格式的时间为 cron 配置"""
|
||||
if match := re.match(r"^(\d{1,2}):(\d{1,2})(?::(\d{1,2}))?$", time_str):
|
||||
hour, minute, second = match.groups()
|
||||
hour, minute = int(hour), int(minute)
|
||||
|
||||
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||
raise ValueError("小时或分钟数值超出范围。")
|
||||
|
||||
cron_config = {
|
||||
"minute": str(minute),
|
||||
"hour": str(hour),
|
||||
"day": "*",
|
||||
"month": "*",
|
||||
"day_of_week": "*",
|
||||
}
|
||||
if second is not None:
|
||||
if not (0 <= int(second) <= 59):
|
||||
raise ValueError("秒数值超出范围。")
|
||||
cron_config["second"] = str(second)
|
||||
|
||||
return cron_config
|
||||
else:
|
||||
raise ValueError("时间格式错误,请使用 'HH:MM' 或 'HH:MM:SS' 格式。")
|
||||
|
||||
|
||||
async def GetBotId(
|
||||
bot: Bot,
|
||||
bot_id_match: Match[str] = AlconnaMatch("bot_id"),
|
||||
) -> str:
|
||||
"""获取要操作的Bot ID"""
|
||||
if bot_id_match.available:
|
||||
return bot_id_match.result
|
||||
return bot.self_id
|
||||
|
||||
|
||||
class ScheduleTarget:
|
||||
"""定时任务操作目标的基类"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TargetByID(ScheduleTarget):
|
||||
"""按任务ID操作"""
|
||||
|
||||
def __init__(self, id: int):
|
||||
self.id = id
|
||||
|
||||
|
||||
class TargetByPlugin(ScheduleTarget):
|
||||
"""按插件名操作"""
|
||||
|
||||
def __init__(
|
||||
self, plugin: str, group_id: str | None = None, all_groups: bool = False
|
||||
):
|
||||
self.plugin = plugin
|
||||
self.group_id = group_id
|
||||
self.all_groups = all_groups
|
||||
|
||||
|
||||
class TargetAll(ScheduleTarget):
|
||||
"""操作所有任务"""
|
||||
|
||||
def __init__(self, for_group: str | None = None):
|
||||
self.for_group = for_group
|
||||
|
||||
|
||||
TargetScope = TargetByID | TargetByPlugin | TargetAll | None
|
||||
|
||||
|
||||
def create_target_parser(subcommand_name: str):
|
||||
"""
|
||||
创建一个依赖注入函数,用于解析删除、暂停、恢复等命令的操作目标。
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
event: Event,
|
||||
schedule_id: Match[int] = AlconnaMatch("schedule_id"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
group_id: Match[str] = AlconnaMatch("group_id"),
|
||||
all_enabled: Query[bool] = Query(f"{subcommand_name}.all"),
|
||||
) -> TargetScope:
|
||||
if schedule_id.available:
|
||||
return TargetByID(schedule_id.result)
|
||||
|
||||
if plugin_name.available:
|
||||
p_name = plugin_name.result
|
||||
if all_enabled.available:
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
elif group_id.available:
|
||||
gid = group_id.result
|
||||
if gid.lower() == "all":
|
||||
return TargetByPlugin(plugin=p_name, all_groups=True)
|
||||
return TargetByPlugin(plugin=p_name, group_id=gid)
|
||||
else:
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if current_group_id:
|
||||
return TargetByPlugin(plugin=p_name, group_id=str(current_group_id))
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中操作插件任务必须使用 -g <群号> 或 -all 选项。"
|
||||
)
|
||||
|
||||
if all_enabled.available:
|
||||
return TargetAll(for_group=group_id.result if group_id.available else None)
|
||||
|
||||
return None
|
||||
|
||||
return dependency
|
||||
|
||||
|
||||
schedule_cmd = on_alconna(
|
||||
Alconna(
|
||||
"定时任务",
|
||||
Subcommand(
|
||||
"查看",
|
||||
Option("-g", Args["target_group_id", str]),
|
||||
Option("-all", help_text="查看所有群聊 (SUPERUSER)"),
|
||||
Option("-p", Args["plugin_name", str], help_text="按插件名筛选"),
|
||||
Option("--page", Args["page", int, 1], help_text="指定页码"),
|
||||
alias=["ls", "list"],
|
||||
help_text="查看定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"设置",
|
||||
Args["plugin_name", str],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="设置每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID或'all'"),
|
||||
Option("-all", help_text="对所有群生效 (等同于 -g all)"),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="设置任务参数"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["add", "开启"],
|
||||
help_text="设置/开启一个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"删除",
|
||||
Args["schedule_id?", int],
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID"),
|
||||
Option("-all", help_text="对所有群生效"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["del", "rm", "remove", "关闭", "取消"],
|
||||
help_text="删除一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"暂停",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["pause"],
|
||||
help_text="暂停一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"恢复",
|
||||
Args["schedule_id?", int],
|
||||
Option("-all", help_text="对当前群所有任务生效"),
|
||||
Option("-p", Args["plugin_name", str], help_text="指定插件名"),
|
||||
Option("-g", Args["group_id", str], help_text="指定群组ID (SUPERUSER)"),
|
||||
Option(
|
||||
"--bot", Args["bot_id", str], help_text="指定操作的Bot ID (SUPERUSER)"
|
||||
),
|
||||
alias=["resume"],
|
||||
help_text="恢复一个或多个定时任务",
|
||||
),
|
||||
Subcommand(
|
||||
"执行",
|
||||
Args["schedule_id", int],
|
||||
alias=["trigger", "run"],
|
||||
help_text="立即执行一次任务",
|
||||
),
|
||||
Subcommand(
|
||||
"更新",
|
||||
Args["schedule_id", int],
|
||||
Option("--cron", Args["cron_expr", str], help_text="设置 cron 表达式"),
|
||||
Option("--interval", Args["interval_expr", str], help_text="设置时间间隔"),
|
||||
Option("--date", Args["date_expr", str], help_text="设置特定执行日期"),
|
||||
Option(
|
||||
"--daily",
|
||||
Args["daily_expr", str],
|
||||
help_text="更新每天执行的时间 (如 08:20)",
|
||||
),
|
||||
Option("--kwargs", Args["kwargs_str", str], help_text="更新参数"),
|
||||
alias=["update", "modify", "修改"],
|
||||
help_text="更新任务配置",
|
||||
),
|
||||
Subcommand(
|
||||
"状态",
|
||||
Args["schedule_id", int],
|
||||
alias=["status", "info"],
|
||||
help_text="查看单个任务的详细状态",
|
||||
),
|
||||
Subcommand(
|
||||
"插件列表",
|
||||
alias=["plugins"],
|
||||
help_text="列出所有可用的插件",
|
||||
),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
rule=admin_check(1),
|
||||
)
|
||||
|
||||
schedule_cmd.shortcut(
|
||||
"任务状态",
|
||||
command="定时任务",
|
||||
arguments=["状态", "{%0}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.handle()
|
||||
async def _handle_time_options_mutex(arp: Arparma):
|
||||
time_options = ["cron", "interval", "date", "daily"]
|
||||
provided_options = [opt for opt in time_options if arp.query(opt) is not None]
|
||||
if len(provided_options) > 1:
|
||||
await schedule_cmd.finish(
|
||||
f"时间选项 --{', --'.join(provided_options)} 不能同时使用,请只选择一个。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("查看")
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
target_group_id: Match[str] = AlconnaMatch("target_group_id"),
|
||||
all_groups: Query[bool] = Query("查看.all"),
|
||||
plugin_name: Match[str] = AlconnaMatch("plugin_name"),
|
||||
page: Match[int] = AlconnaMatch("page"),
|
||||
):
|
||||
is_superuser = await SUPERUSER(bot, event)
|
||||
schedules = []
|
||||
title = ""
|
||||
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
if not (all_groups.available or target_group_id.available) and not current_group_id:
|
||||
await schedule_cmd.finish("私聊中查看任务必须使用 -g <群号> 或 -all 选项。")
|
||||
|
||||
if all_groups.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看所有群组的定时任务。")
|
||||
schedules = await scheduler_manager.get_all_schedules()
|
||||
title = "所有群组的定时任务"
|
||||
elif target_group_id.available:
|
||||
if not is_superuser:
|
||||
await schedule_cmd.finish("需要超级用户权限才能查看指定群组的定时任务。")
|
||||
gid = target_group_id.result
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = f"群 {gid} 的定时任务"
|
||||
else:
|
||||
gid = str(current_group_id)
|
||||
schedules = [
|
||||
s for s in await scheduler_manager.get_all_schedules() if s.group_id == gid
|
||||
]
|
||||
title = "本群的定时任务"
|
||||
|
||||
if plugin_name.available:
|
||||
schedules = [s for s in schedules if s.plugin_name == plugin_name.result]
|
||||
title += f" [插件: {plugin_name.result}]"
|
||||
|
||||
if not schedules:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
page_size = 15
|
||||
current_page = page.result
|
||||
total_items = len(schedules)
|
||||
total_pages = (total_items + page_size - 1) // page_size
|
||||
start_index = (current_page - 1) * page_size
|
||||
end_index = start_index + page_size
|
||||
paginated_schedules = schedules[start_index:end_index]
|
||||
|
||||
if not paginated_schedules:
|
||||
await schedule_cmd.finish("这一页没有内容了哦~")
|
||||
|
||||
status_tasks = [
|
||||
scheduler_manager.get_schedule_status(s.id) for s in paginated_schedules
|
||||
]
|
||||
all_statuses = await asyncio.gather(*status_tasks)
|
||||
data_list = [
|
||||
[
|
||||
s["id"],
|
||||
s["plugin_name"],
|
||||
s.get("bot_id") or "N/A",
|
||||
s["group_id"] or "全局",
|
||||
s["next_run_time"],
|
||||
_format_trigger(s),
|
||||
_format_params(s),
|
||||
"✔️ 已启用" if s["is_enabled"] else "⏸️ 已暂停",
|
||||
]
|
||||
for s in all_statuses
|
||||
if s
|
||||
]
|
||||
|
||||
if not data_list:
|
||||
await schedule_cmd.finish("没有找到任何相关的定时任务。")
|
||||
|
||||
img = await ImageTemplate.table_page(
|
||||
head_text=title,
|
||||
tip_text=f"第 {current_page}/{total_pages} 页,共 {total_items} 条任务",
|
||||
column_name=[
|
||||
"ID",
|
||||
"插件",
|
||||
"Bot ID",
|
||||
"群组/目标",
|
||||
"下次运行",
|
||||
"触发规则",
|
||||
"参数",
|
||||
"状态",
|
||||
],
|
||||
data_list=data_list,
|
||||
column_space=20,
|
||||
)
|
||||
await MessageUtils.build_message(img).send(reply_to=True)
|
||||
|
||||
|
||||
@schedule_cmd.assign("设置")
|
||||
async def _(
|
||||
event: Event,
|
||||
plugin_name: str,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
group_id: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
all_enabled: Query[bool] = Query("设置.all"),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if plugin_name not in scheduler_manager._registered_tasks:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 没有注册可用的定时任务。\n"
|
||||
f"可用插件: {list(scheduler_manager._registered_tasks.keys())}"
|
||||
)
|
||||
|
||||
trigger_type = ""
|
||||
trigger_config = {}
|
||||
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分 (分 时 日 月 周)")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"必须提供一种时间选项: --cron, --interval, --date, 或 --daily。"
|
||||
)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = {}
|
||||
if kwargs_str:
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
if not params_model:
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 不支持设置额外参数。")
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(f"插件 '{plugin_name}' 的参数模型配置错误。")
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump()
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{plugin_name}' 的任务参数验证失败:\n{error_str}"
|
||||
)
|
||||
return
|
||||
|
||||
target_group_id: str | None
|
||||
current_group_id = getattr(event, "group_id", None)
|
||||
|
||||
if group_id and group_id.lower() == "all":
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif all_enabled.available:
|
||||
target_group_id = "__ALL_GROUPS__"
|
||||
elif group_id:
|
||||
target_group_id = group_id
|
||||
elif current_group_id:
|
||||
target_group_id = str(current_group_id)
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"私聊中设置定时任务时,必须使用 -g <群号> 或 --all 选项指定目标。"
|
||||
)
|
||||
return
|
||||
|
||||
success, msg = await scheduler_manager.add_schedule(
|
||||
plugin_name,
|
||||
target_group_id,
|
||||
trigger_type,
|
||||
trigger_config,
|
||||
job_kwargs,
|
||||
bot_id=bot_id_to_operate,
|
||||
)
|
||||
|
||||
if target_group_id == "__ALL_GROUPS__":
|
||||
target_desc = f"所有群组 (Bot: {bot_id_to_operate})"
|
||||
elif target_group_id is None:
|
||||
target_desc = "全局"
|
||||
else:
|
||||
target_desc = f"群组 {target_group_id}"
|
||||
|
||||
if success:
|
||||
await schedule_cmd.finish(f"已成功为 [{target_desc}] {msg}")
|
||||
else:
|
||||
await schedule_cmd.finish(f"为 [{target_desc}] 设置任务失败: {msg}")
|
||||
|
||||
|
||||
@schedule_cmd.assign("删除")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("删除")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.remove_schedule_by_id(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
removed_count = await scheduler_manager.remove_schedule_for_all(
|
||||
p_name, bot_id=bot_id_to_operate
|
||||
)
|
||||
message = (
|
||||
f"已取消了 {removed_count} 个群组的插件 '{p_name}' 定时任务。"
|
||||
if removed_count > 0
|
||||
else f"没有找到插件 '{p_name}' 的定时任务。"
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_schedule(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.remove_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.remove_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish(
|
||||
"删除任务失败:请提供任务ID,或通过 -p <插件> 或 -all 指定要删除的任务。"
|
||||
)
|
||||
|
||||
|
||||
@schedule_cmd.assign("暂停")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("暂停")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.pause_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.pause_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.pause_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.pause_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("恢复")
|
||||
async def _(
|
||||
target: TargetScope = Depends(create_target_parser("恢复")),
|
||||
bot_id_to_operate: str = Depends(GetBotId),
|
||||
):
|
||||
if isinstance(target, TargetByID):
|
||||
_, message = await scheduler_manager.resume_schedule(target.id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetByPlugin):
|
||||
p_name = target.plugin
|
||||
if p_name not in scheduler_manager.get_registered_plugins():
|
||||
await schedule_cmd.finish(f"未找到插件 '{p_name}'。")
|
||||
|
||||
if target.all_groups:
|
||||
_, message = await scheduler_manager.resume_schedules_by_plugin(p_name)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_schedule_by_plugin_group(
|
||||
p_name, target.group_id, bot_id=bot_id_to_operate
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
elif isinstance(target, TargetAll):
|
||||
if target.for_group:
|
||||
_, message = await scheduler_manager.resume_schedules_by_group(
|
||||
target.for_group
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
else:
|
||||
_, message = await scheduler_manager.resume_all_schedules()
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
else:
|
||||
await schedule_cmd.finish("请提供任务ID、使用 -p <插件> 或 -all 选项。")
|
||||
|
||||
|
||||
@schedule_cmd.assign("执行")
|
||||
async def _(schedule_id: int):
|
||||
_, message = await scheduler_manager.trigger_now(schedule_id)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("更新")
|
||||
async def _(
|
||||
schedule_id: int,
|
||||
cron_expr: str | None = None,
|
||||
interval_expr: str | None = None,
|
||||
date_expr: str | None = None,
|
||||
daily_expr: str | None = None,
|
||||
kwargs_str: str | None = None,
|
||||
):
|
||||
if not any([cron_expr, interval_expr, date_expr, daily_expr, kwargs_str]):
|
||||
await schedule_cmd.finish(
|
||||
"请提供需要更新的时间 (--cron/--interval/--date/--daily) 或参数 (--kwargs)"
|
||||
)
|
||||
|
||||
trigger_config = None
|
||||
trigger_type = None
|
||||
try:
|
||||
if cron_expr:
|
||||
trigger_type = "cron"
|
||||
parts = cron_expr.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError("Cron 表达式必须有5个部分")
|
||||
cron_keys = ["minute", "hour", "day", "month", "day_of_week"]
|
||||
trigger_config = dict(zip(cron_keys, parts))
|
||||
elif interval_expr:
|
||||
trigger_type = "interval"
|
||||
trigger_config = _parse_interval(interval_expr)
|
||||
elif date_expr:
|
||||
trigger_type = "date"
|
||||
trigger_config = {"run_date": datetime.fromisoformat(date_expr)}
|
||||
elif daily_expr:
|
||||
trigger_type = "cron"
|
||||
trigger_config = _parse_daily_time(daily_expr)
|
||||
except ValueError as e:
|
||||
await schedule_cmd.finish(f"时间参数解析错误: {e}")
|
||||
|
||||
job_kwargs = None
|
||||
if kwargs_str:
|
||||
schedule = await scheduler_manager.get_schedule_by_id(schedule_id)
|
||||
if not schedule:
|
||||
await schedule_cmd.finish(f"未找到 ID 为 {schedule_id} 的任务。")
|
||||
|
||||
task_meta = scheduler_manager._registered_tasks.get(schedule.plugin_name)
|
||||
if not task_meta or not (params_model := task_meta.get("model")):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 未定义参数模型,无法更新参数。"
|
||||
)
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型配置错误。"
|
||||
)
|
||||
|
||||
raw_kwargs = {}
|
||||
try:
|
||||
for item in kwargs_str.split(","):
|
||||
key, value = item.strip().split("=", 1)
|
||||
raw_kwargs[key.strip()] = value
|
||||
except Exception as e:
|
||||
await schedule_cmd.finish(
|
||||
f"参数格式错误,请使用 'key=value,key2=value2' 格式。错误: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
model_validate = getattr(params_model, "model_validate", None)
|
||||
if not model_validate:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持验证。"
|
||||
)
|
||||
return
|
||||
|
||||
validated_model = model_validate(raw_kwargs)
|
||||
|
||||
model_dump = getattr(validated_model, "model_dump", None)
|
||||
if not model_dump:
|
||||
await schedule_cmd.finish(
|
||||
f"插件 '{schedule.plugin_name}' 的参数模型不支持导出。"
|
||||
)
|
||||
return
|
||||
|
||||
job_kwargs = model_dump(exclude_unset=True)
|
||||
except ValidationError as e:
|
||||
errors = [f" - {err['loc'][0]}: {err['msg']}" for err in e.errors()]
|
||||
error_str = "\n".join(errors)
|
||||
await schedule_cmd.finish(f"更新的参数验证失败:\n{error_str}")
|
||||
return
|
||||
|
||||
_, message = await scheduler_manager.update_schedule(
|
||||
schedule_id, trigger_type, trigger_config, job_kwargs
|
||||
)
|
||||
await schedule_cmd.finish(message)
|
||||
|
||||
|
||||
@schedule_cmd.assign("插件列表")
|
||||
async def _():
|
||||
registered_plugins = scheduler_manager.get_registered_plugins()
|
||||
if not registered_plugins:
|
||||
await schedule_cmd.finish("当前没有已注册的定时任务插件。")
|
||||
|
||||
message_parts = ["📋 已注册的定时任务插件:"]
|
||||
for i, plugin_name in enumerate(registered_plugins, 1):
|
||||
task_meta = scheduler_manager._registered_tasks[plugin_name]
|
||||
params_model = task_meta.get("model")
|
||||
|
||||
if not params_model:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
continue
|
||||
|
||||
if not (isinstance(params_model, type) and issubclass(params_model, BaseModel)):
|
||||
message_parts.append(f"{i}. {plugin_name} - ⚠️ 参数模型配置错误")
|
||||
continue
|
||||
|
||||
model_fields = getattr(params_model, "model_fields", None)
|
||||
if model_fields:
|
||||
param_info = ", ".join(
|
||||
f"{field_name}({_get_type_name(field_info.annotation)})"
|
||||
for field_name, field_info in model_fields.items()
|
||||
)
|
||||
message_parts.append(f"{i}. {plugin_name} - 参数: {param_info}")
|
||||
else:
|
||||
message_parts.append(f"{i}. {plugin_name} - 无参数")
|
||||
|
||||
await schedule_cmd.finish("\n".join(message_parts))
|
||||
|
||||
|
||||
@schedule_cmd.assign("状态")
|
||||
async def _(schedule_id: int):
|
||||
status = await scheduler_manager.get_schedule_status(schedule_id)
|
||||
if not status:
|
||||
await schedule_cmd.finish(f"未找到ID为 {schedule_id} 的定时任务。")
|
||||
|
||||
info_lines = [
|
||||
f"📋 定时任务详细信息 (ID: {schedule_id})",
|
||||
"--------------------",
|
||||
f"▫️ 插件: {status['plugin_name']}",
|
||||
f"▫️ Bot ID: {status.get('bot_id') or '默认'}",
|
||||
f"▫️ 目标: {status['group_id'] or '全局'}",
|
||||
f"▫️ 状态: {'✔️ 已启用' if status['is_enabled'] else '⏸️ 已暂停'}",
|
||||
f"▫️ 下次运行: {status['next_run_time']}",
|
||||
f"▫️ 触发规则: {_format_trigger(status)}",
|
||||
f"▫️ 任务参数: {_format_params(status)}",
|
||||
]
|
||||
await schedule_cmd.finish("\n".join(info_lines))
|
||||
@ -345,10 +345,11 @@ class ShopManage:
|
||||
if num > param.max_num_limit:
|
||||
return f"{goods_info.goods_name} 单次使用最大数量为{param.max_num_limit}..."
|
||||
await cls.run_before_after(goods, param, session, message, "before", **kwargs)
|
||||
result = await cls.__run(goods, param, session, message, **kwargs)
|
||||
await UserConsole.use_props(
|
||||
session.user.id, goods_info.uuid, num, PlatformUtils.get_platform(session)
|
||||
)
|
||||
result = await cls.__run(goods, param, session, message, **kwargs)
|
||||
|
||||
await cls.run_before_after(goods, param, session, message, "after", **kwargs)
|
||||
if not result and param.send_success_msg:
|
||||
result = f"使用道具 {goods.name} {num} 次成功!"
|
||||
|
||||
@ -53,10 +53,7 @@ async def _(
|
||||
)
|
||||
|
||||
|
||||
@scheduler.scheduled_job(
|
||||
"interval",
|
||||
minutes=1,
|
||||
)
|
||||
@scheduler.scheduled_job("interval", minutes=1, max_instances=5)
|
||||
async def _():
|
||||
try:
|
||||
call_list = TEMP_LIST.copy()
|
||||
|
||||
@ -110,7 +110,7 @@ async def enable_plugin(
|
||||
)
|
||||
await BotConsole.enable_plugin(None, plugin.module)
|
||||
await MessageUtils.build_message(
|
||||
f"已禁用全部 bot 的插件: {plugin_name.result}"
|
||||
f"已开启全部 bot 的插件: {plugin_name.result}"
|
||||
).finish()
|
||||
elif bot_id.available:
|
||||
logger.info(
|
||||
|
||||
@ -92,7 +92,7 @@ async def enable_task(
|
||||
)
|
||||
await BotConsole.enable_task(None, task.module)
|
||||
await MessageUtils.build_message(
|
||||
f"已禁用全部 bot 的被动: {task_name.available}"
|
||||
f"已开启全部 bot 的被动: {task_name.available}"
|
||||
).finish()
|
||||
elif bot_id.available:
|
||||
logger.info(
|
||||
|
||||
@ -2,7 +2,7 @@ from io import BytesIO
|
||||
|
||||
from arclet.alconna import Args, Option
|
||||
from arclet.alconna.typing import CommandMeta
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot.rule import to_me
|
||||
@ -10,10 +10,13 @@ from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
AlconnaQuery,
|
||||
Arparma,
|
||||
Match,
|
||||
Query,
|
||||
Reply,
|
||||
on_alconna,
|
||||
store_true,
|
||||
)
|
||||
from nonebot_plugin_alconna.uniseg.tools import reply_fetch
|
||||
from nonebot_plugin_session import EventSession
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
@ -54,7 +57,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
_req_matcher = on_alconna(
|
||||
Alconna(
|
||||
"请求处理",
|
||||
Args["handle", ["-fa", "-fr", "-fi", "-ga", "-gr", "-gi"]]["id", int],
|
||||
Args["handle", ["-fa", "-fr", "-fi", "-ga", "-gr", "-gi"]]["id?", int],
|
||||
meta=CommandMeta(
|
||||
description="好友/群组请求处理",
|
||||
usage=usage,
|
||||
@ -105,12 +108,12 @@ _clear_matcher = on_alconna(
|
||||
)
|
||||
|
||||
reg_arg_list = [
|
||||
(r"同意好友请求", ["-fa", "{%0}"]),
|
||||
(r"拒绝好友请求", ["-fr", "{%0}"]),
|
||||
(r"忽略好友请求", ["-fi", "{%0}"]),
|
||||
(r"同意群组请求", ["-ga", "{%0}"]),
|
||||
(r"拒绝群组请求", ["-gr", "{%0}"]),
|
||||
(r"忽略群组请求", ["-gi", "{%0}"]),
|
||||
(r"同意好友请求\s*(?P<id>\d*)", ["-fa", "{id}"]),
|
||||
(r"拒绝好友请求\s*(?P<id>\d*)", ["-fr", "{id}"]),
|
||||
(r"忽略好友请求\s*(?P<id>\d*)", ["-fi", "{id}"]),
|
||||
(r"同意群组请求\s*(?P<id>\d*)", ["-ga", "{id}"]),
|
||||
(r"拒绝群组请求\s*(?P<id>\d*)", ["-gr", "{id}"]),
|
||||
(r"忽略群组请求\s*(?P<id>\d*)", ["-gi", "{id}"]),
|
||||
]
|
||||
|
||||
for r in reg_arg_list:
|
||||
@ -125,32 +128,48 @@ for r in reg_arg_list:
|
||||
@_req_matcher.handle()
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: EventSession,
|
||||
handle: str,
|
||||
id: int,
|
||||
id: Match[int],
|
||||
arparma: Arparma,
|
||||
):
|
||||
reply: Reply | None = None
|
||||
type_dict = {
|
||||
"a": RequestHandleType.APPROVE,
|
||||
"r": RequestHandleType.REFUSED,
|
||||
"i": RequestHandleType.IGNORE,
|
||||
}
|
||||
if not id.available:
|
||||
reply = await reply_fetch(event, bot)
|
||||
if not reply:
|
||||
await MessageUtils.build_message("请引用消息处理或添加处理Id.").finish()
|
||||
handle_id = id.result
|
||||
if reply:
|
||||
db_data = await FgRequest.get_or_none(message_ids__contains=reply.id)
|
||||
if not db_data:
|
||||
await MessageUtils.build_message(
|
||||
"未发现此消息的Id,请使用Id进行处理..."
|
||||
).finish(reply_to=True)
|
||||
handle_id = db_data.id
|
||||
req = None
|
||||
handle_type = type_dict[handle[-1]]
|
||||
try:
|
||||
if handle_type == RequestHandleType.APPROVE:
|
||||
req = await FgRequest.approve(bot, id)
|
||||
req = await FgRequest.approve(bot, handle_id)
|
||||
if handle_type == RequestHandleType.REFUSED:
|
||||
req = await FgRequest.refused(bot, id)
|
||||
req = await FgRequest.refused(bot, handle_id)
|
||||
if handle_type == RequestHandleType.IGNORE:
|
||||
req = await FgRequest.ignore(id)
|
||||
req = await FgRequest.ignore(handle_id)
|
||||
except NotFoundError:
|
||||
await MessageUtils.build_message("未发现此id的请求...").finish(reply_to=True)
|
||||
except Exception:
|
||||
await MessageUtils.build_message("其他错误, 可能flag已失效...").finish(
|
||||
reply_to=True
|
||||
)
|
||||
logger.info("处理请求", arparma.header_result, session=session)
|
||||
logger.info(
|
||||
f"处理请求 Id: {req.id if req else ''}", arparma.header_result, session=session
|
||||
)
|
||||
await MessageUtils.build_message("成功处理请求!").send(reply_to=True)
|
||||
if req and handle_type == RequestHandleType.APPROVE:
|
||||
await bot.send_private_msg(
|
||||
|
||||
@ -29,8 +29,7 @@ from .public import init_public
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="WebUi",
|
||||
description="WebUi API",
|
||||
usage="""
|
||||
""".strip(),
|
||||
usage='"""\n """.strip(),',
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1",
|
||||
@ -83,7 +82,6 @@ BaseApiRouter.include_router(plugin_router)
|
||||
BaseApiRouter.include_router(system_router)
|
||||
BaseApiRouter.include_router(menu_router)
|
||||
|
||||
|
||||
WsApiRouter = APIRouter(prefix="/zhenxun/socket")
|
||||
|
||||
WsApiRouter.include_router(ws_log_routes)
|
||||
@ -94,6 +92,8 @@ WsApiRouter.include_router(chat_routes)
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
try:
|
||||
# 存储任务引用的列表,防止任务被垃圾回收
|
||||
_tasks = []
|
||||
|
||||
async def log_sink(message: str):
|
||||
loop = None
|
||||
@ -104,7 +104,8 @@ async def _():
|
||||
logger.warning("Web Ui log_sink", e=e)
|
||||
if not loop:
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.create_task(LOG_STORAGE.add(message.rstrip("\n"))) # noqa: RUF006
|
||||
# 存储任务引用到外部列表中
|
||||
_tasks.append(loop.create_task(LOG_STORAGE.add(message.rstrip("\n"))))
|
||||
|
||||
logger_.add(
|
||||
log_sink, colorize=True, filter=default_filter, format=default_format
|
||||
|
||||
@ -46,7 +46,10 @@ class MenuManage:
|
||||
icon="database",
|
||||
),
|
||||
MenuItem(
|
||||
name="系统信息", module="system", router="/system", icon="system"
|
||||
name="文件管理", module="system", router="/system", icon="system"
|
||||
),
|
||||
MenuItem(
|
||||
name="关于我们", module="about", router="/about", icon="about"
|
||||
),
|
||||
]
|
||||
self.save()
|
||||
|
||||
@ -16,7 +16,7 @@ from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
from ....base_model import Result
|
||||
from ....config import QueryDateType
|
||||
from ....utils import authentication, get_system_status
|
||||
from ....utils import authentication, clear_help_image, get_system_status
|
||||
from .data_source import ApiDataSource
|
||||
from .model import (
|
||||
ActiveGroup,
|
||||
@ -234,6 +234,7 @@ async def _(param: BotManageUpdateParam):
|
||||
bot_data.block_plugins = CommonUtils.convert_module_format(param.block_plugins)
|
||||
bot_data.block_tasks = CommonUtils.convert_module_format(param.block_tasks)
|
||||
await bot_data.save(update_fields=["block_plugins", "block_tasks"])
|
||||
clear_help_image()
|
||||
return Result.ok()
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/update_bot_manage 调用错误", "WebUi", e=e)
|
||||
|
||||
@ -92,7 +92,7 @@ class ApiDataSource:
|
||||
"""
|
||||
version_file = Path() / "__version__"
|
||||
if version_file.exists():
|
||||
if text := version_file.open().read():
|
||||
if text := version_file.open(encoding="utf-8").read():
|
||||
return text.replace("__version__: ", "").strip()
|
||||
return "unknown"
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter
|
||||
import nonebot
|
||||
from nonebot import on_message
|
||||
@ -49,13 +51,14 @@ async def message_handle(
|
||||
message: UniMsg,
|
||||
group_id: str | None,
|
||||
):
|
||||
time = str(datetime.now().replace(microsecond=0))
|
||||
messages = []
|
||||
for m in message:
|
||||
if isinstance(m, Text | str):
|
||||
messages.append(MessageItem(type="text", msg=str(m)))
|
||||
messages.append(MessageItem(type="text", msg=str(m), time=time))
|
||||
elif isinstance(m, Image):
|
||||
if m.url:
|
||||
messages.append(MessageItem(type="img", msg=m.url))
|
||||
messages.append(MessageItem(type="img", msg=m.url, time=time))
|
||||
elif isinstance(m, At):
|
||||
if group_id:
|
||||
if m.target == "0":
|
||||
@ -72,9 +75,9 @@ async def message_handle(
|
||||
uname = group_user.user_name
|
||||
if m.target not in ID2NAME[group_id]:
|
||||
ID2NAME[group_id][m.target] = uname
|
||||
messages.append(MessageItem(type="at", msg=f"@{uname}"))
|
||||
messages.append(MessageItem(type="at", msg=f"@{uname}", time=time))
|
||||
elif isinstance(m, Hyper):
|
||||
messages.append(MessageItem(type="text", msg="[分享消息]"))
|
||||
messages.append(MessageItem(type="text", msg="[分享消息]", time=time))
|
||||
return messages
|
||||
|
||||
|
||||
|
||||
@ -237,6 +237,8 @@ class MessageItem(BaseModel):
|
||||
"""消息类型"""
|
||||
msg: str
|
||||
"""内容"""
|
||||
time: str
|
||||
"""发送日期"""
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
|
||||
@ -4,15 +4,20 @@ from fastapi.responses import JSONResponse
|
||||
from zhenxun.models.plugin_info import PluginInfo as DbPluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
|
||||
|
||||
from ....base_model import Result
|
||||
from ....utils import authentication
|
||||
from ....utils import authentication, clear_help_image
|
||||
from .data_source import ApiDataSource
|
||||
from .model import (
|
||||
BatchUpdatePlugins,
|
||||
BatchUpdateResult,
|
||||
InstallDependenciesPayload,
|
||||
PluginCount,
|
||||
PluginDetail,
|
||||
PluginInfo,
|
||||
PluginSwitch,
|
||||
RenameMenuTypePayload,
|
||||
UpdatePlugin,
|
||||
)
|
||||
|
||||
@ -30,9 +35,8 @@ async def _(
|
||||
plugin_type: list[PluginType] = Query(None), menu_type: str | None = None
|
||||
) -> Result[list[PluginInfo]]:
|
||||
try:
|
||||
return Result.ok(
|
||||
await ApiDataSource.get_plugin_list(plugin_type, menu_type), "拿到信息啦!"
|
||||
)
|
||||
result = await ApiDataSource.get_plugin_list(plugin_type, menu_type)
|
||||
return Result.ok(result, "拿到信息啦!")
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/get_plugin_list 调用错误", "WebUi", e=e)
|
||||
return Result.fail(f"发生了一点错误捏 {type(e)}: {e}")
|
||||
@ -78,6 +82,7 @@ async def _() -> Result[PluginCount]:
|
||||
async def _(param: UpdatePlugin) -> Result:
|
||||
try:
|
||||
await ApiDataSource.update_plugin(param)
|
||||
clear_help_image()
|
||||
return Result.ok(info="已经帮你写好啦!")
|
||||
except (ValueError, KeyError):
|
||||
return Result.fail("插件数据不存在...")
|
||||
@ -105,6 +110,7 @@ async def _(param: PluginSwitch) -> Result:
|
||||
db_plugin.block_type = None
|
||||
db_plugin.status = True
|
||||
await db_plugin.save()
|
||||
clear_help_image()
|
||||
return Result.ok(info="成功改变了开关状态!")
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/change_switch 调用错误", "WebUi", e=e)
|
||||
@ -144,11 +150,89 @@ async def _() -> Result[list[str]]:
|
||||
)
|
||||
async def _(module: str) -> Result[PluginDetail]:
|
||||
try:
|
||||
return Result.ok(
|
||||
await ApiDataSource.get_plugin_detail(module), "已经帮你写好啦!"
|
||||
)
|
||||
detail = await ApiDataSource.get_plugin_detail(module)
|
||||
return Result.ok(detail, "已经帮你写好啦!")
|
||||
except (ValueError, KeyError):
|
||||
return Result.fail("插件数据不存在...")
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/get_plugin 调用错误", "WebUi", e=e)
|
||||
return Result.fail(f"{type(e)}: {e}")
|
||||
|
||||
|
||||
@router.put(
|
||||
"/plugins/batch_update",
|
||||
dependencies=[authentication()],
|
||||
response_model=Result[BatchUpdateResult],
|
||||
response_class=JSONResponse,
|
||||
description="批量更新插件配置",
|
||||
)
|
||||
async def _(
|
||||
params: BatchUpdatePlugins,
|
||||
) -> Result[BatchUpdateResult]:
|
||||
"""批量更新插件配置,如开关、类型等"""
|
||||
try:
|
||||
result_dict = await ApiDataSource.batch_update_plugins(params=params)
|
||||
result_model = BatchUpdateResult(
|
||||
success=result_dict["success"],
|
||||
updated_count=result_dict["updated_count"],
|
||||
errors=result_dict["errors"],
|
||||
)
|
||||
clear_help_image()
|
||||
return Result.ok(result_model, "插件配置更新完成")
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/plugins/batch_update 调用错误", "WebUi", e=e)
|
||||
return Result.fail(f"发生了一点错误捏 {type(e)}: {e}")
|
||||
|
||||
|
||||
# 新增:重命名菜单类型路由
|
||||
@router.put(
|
||||
"/menu_type/rename",
|
||||
dependencies=[authentication()],
|
||||
response_model=Result,
|
||||
description="重命名菜单类型",
|
||||
)
|
||||
async def _(payload: RenameMenuTypePayload) -> Result[str]:
|
||||
try:
|
||||
result = await ApiDataSource.rename_menu_type(
|
||||
old_name=payload.old_name, new_name=payload.new_name
|
||||
)
|
||||
if result.get("success"):
|
||||
clear_help_image()
|
||||
return Result.ok(
|
||||
info=result.get(
|
||||
"info",
|
||||
f"成功将 {result.get('updated_count', 0)} 个插件的菜单类型从 "
|
||||
f"'{payload.old_name}' 修改为 '{payload.new_name}'",
|
||||
)
|
||||
)
|
||||
else:
|
||||
return Result.fail(info=result.get("info", "重命名失败"))
|
||||
except ValueError as ve:
|
||||
return Result.fail(info=str(ve))
|
||||
except RuntimeError as re:
|
||||
logger.error(f"{router.prefix}/menu_type/rename 调用错误", "WebUi", e=re)
|
||||
return Result.fail(info=str(re))
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/menu_type/rename 调用错误", "WebUi", e=e)
|
||||
return Result.fail(info=f"发生未知错误: {type(e).__name__}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/install_dependencies",
|
||||
dependencies=[authentication()],
|
||||
response_model=Result,
|
||||
response_class=JSONResponse,
|
||||
description="安装/卸载依赖",
|
||||
)
|
||||
async def _(payload: InstallDependenciesPayload) -> Result:
|
||||
try:
|
||||
if not payload.dependencies:
|
||||
return Result.fail("依赖列表不能为空")
|
||||
if payload.handle_type == "install":
|
||||
result = VirtualEnvPackageManager.install(payload.dependencies)
|
||||
else:
|
||||
result = VirtualEnvPackageManager.uninstall(payload.dependencies)
|
||||
return Result.ok(result)
|
||||
except Exception as e:
|
||||
logger.error(f"{router.prefix}/install_dependencies 调用错误", "WebUi", e=e)
|
||||
return Result.fail(f"发生了一点错误捏 {type(e)}: {e}")
|
||||
|
||||
@ -2,13 +2,20 @@ import re
|
||||
|
||||
import cattrs
|
||||
from fastapi import Query
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.utils import ConfigGroup
|
||||
from zhenxun.models.plugin_info import PluginInfo as DbPluginInfo
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
|
||||
from .model import PluginConfig, PluginDetail, PluginInfo, UpdatePlugin
|
||||
from .model import (
|
||||
BatchUpdatePlugins,
|
||||
PluginConfig,
|
||||
PluginDetail,
|
||||
PluginInfo,
|
||||
UpdatePlugin,
|
||||
)
|
||||
|
||||
|
||||
class ApiDataSource:
|
||||
@ -44,6 +51,11 @@ class ApiDataSource:
|
||||
level=plugin.level,
|
||||
status=plugin.status,
|
||||
author=plugin.author,
|
||||
block_type=plugin.block_type,
|
||||
is_builtin="builtin_plugins" in plugin.module_path
|
||||
or plugin.plugin_type == PluginType.HIDDEN,
|
||||
allow_setting=plugin.plugin_type != PluginType.HIDDEN,
|
||||
allow_switch=plugin.plugin_type != PluginType.HIDDEN,
|
||||
)
|
||||
plugin_list.append(plugin_info)
|
||||
return plugin_list
|
||||
@ -69,7 +81,6 @@ class ApiDataSource:
|
||||
db_plugin.block_type = param.block_type
|
||||
db_plugin.status = param.block_type != BlockType.ALL
|
||||
await db_plugin.save()
|
||||
# 配置项
|
||||
if param.configs and (configs := Config.get(param.module)):
|
||||
for key in param.configs:
|
||||
if c := configs.configs.get(key):
|
||||
@ -80,6 +91,87 @@ class ApiDataSource:
|
||||
Config.save(save_simple_data=True)
|
||||
return db_plugin
|
||||
|
||||
@classmethod
|
||||
async def batch_update_plugins(cls, params: BatchUpdatePlugins) -> dict:
|
||||
"""批量更新插件数据
|
||||
|
||||
参数:
|
||||
params: BatchUpdatePlugins
|
||||
|
||||
返回:
|
||||
dict: 更新结果, 例如 {'success': True, 'updated_count': 5, 'errors': []}
|
||||
"""
|
||||
plugins_to_update_other_fields = []
|
||||
other_update_fields = set()
|
||||
updated_count = 0
|
||||
errors = []
|
||||
|
||||
for item in params.updates:
|
||||
try:
|
||||
db_plugin = await DbPluginInfo.get(module=item.module)
|
||||
plugin_changed_other = False
|
||||
plugin_changed_block = False
|
||||
|
||||
if db_plugin.block_type != item.block_type:
|
||||
db_plugin.block_type = item.block_type
|
||||
db_plugin.status = item.block_type != BlockType.ALL
|
||||
plugin_changed_block = True
|
||||
|
||||
if item.menu_type is not None and db_plugin.menu_type != item.menu_type:
|
||||
db_plugin.menu_type = item.menu_type
|
||||
other_update_fields.add("menu_type")
|
||||
plugin_changed_other = True
|
||||
|
||||
if (
|
||||
item.default_status is not None
|
||||
and db_plugin.default_status != item.default_status
|
||||
):
|
||||
db_plugin.default_status = item.default_status
|
||||
other_update_fields.add("default_status")
|
||||
plugin_changed_other = True
|
||||
|
||||
if plugin_changed_block:
|
||||
try:
|
||||
await db_plugin.save(update_fields=["block_type", "status"])
|
||||
updated_count += 1
|
||||
except Exception as e_save:
|
||||
errors.append(
|
||||
{
|
||||
"module": item.module,
|
||||
"error": f"Save block_type failed: {e_save!s}",
|
||||
}
|
||||
)
|
||||
plugin_changed_other = False
|
||||
|
||||
if plugin_changed_other:
|
||||
plugins_to_update_other_fields.append(db_plugin)
|
||||
|
||||
except DoesNotExist:
|
||||
errors.append({"module": item.module, "error": "Plugin not found"})
|
||||
except Exception as e:
|
||||
errors.append({"module": item.module, "error": str(e)})
|
||||
|
||||
bulk_updated_count = 0
|
||||
if plugins_to_update_other_fields and other_update_fields:
|
||||
try:
|
||||
await DbPluginInfo.bulk_update(
|
||||
plugins_to_update_other_fields, list(other_update_fields)
|
||||
)
|
||||
bulk_updated_count = len(plugins_to_update_other_fields)
|
||||
except Exception as e_bulk:
|
||||
errors.append(
|
||||
{
|
||||
"module": "batch_update_other",
|
||||
"error": f"Bulk update failed: {e_bulk!s}",
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": len(errors) == 0,
|
||||
"updated_count": updated_count + bulk_updated_count,
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def __build_plugin_config(
|
||||
cls, module: str, cfg: str, config: ConfigGroup
|
||||
@ -115,6 +207,41 @@ class ApiDataSource:
|
||||
type_inner=type_inner, # type: ignore
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def rename_menu_type(cls, old_name: str, new_name: str) -> dict:
|
||||
"""重命名菜单类型,并更新所有相关插件
|
||||
|
||||
参数:
|
||||
old_name: 旧菜单类型名称
|
||||
new_name: 新菜单类型名称
|
||||
|
||||
返回:
|
||||
dict: 更新结果, 例如 {'success': True, 'updated_count': 3}
|
||||
"""
|
||||
if not old_name or not new_name:
|
||||
raise ValueError("旧名称和新名称都不能为空")
|
||||
if old_name == new_name:
|
||||
return {
|
||||
"success": True,
|
||||
"updated_count": 0,
|
||||
"info": "新旧名称相同,无需更新",
|
||||
}
|
||||
|
||||
# 检查新名称是否已存在(理论上前端会校验,后端再保险一次)
|
||||
exists = await DbPluginInfo.filter(menu_type=new_name).exists()
|
||||
if exists:
|
||||
raise ValueError(f"新的菜单类型名称 '{new_name}' 已被其他插件使用")
|
||||
|
||||
try:
|
||||
# 使用 filter().update() 进行批量更新
|
||||
updated_count = await DbPluginInfo.filter(menu_type=old_name).update(
|
||||
menu_type=new_name
|
||||
)
|
||||
return {"success": True, "updated_count": updated_count}
|
||||
except Exception as e:
|
||||
# 可以添加更详细的日志记录
|
||||
raise RuntimeError(f"数据库更新菜单类型失败: {e!s}")
|
||||
|
||||
@classmethod
|
||||
async def get_plugin_detail(cls, module: str) -> PluginDetail:
|
||||
"""获取插件详情
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from zhenxun.utils.enum import BlockType
|
||||
|
||||
@ -37,19 +37,19 @@ class UpdatePlugin(BaseModel):
|
||||
module: str
|
||||
"""模块"""
|
||||
default_status: bool
|
||||
"""默认开关"""
|
||||
"""是否默认开启"""
|
||||
limit_superuser: bool
|
||||
"""限制超级用户"""
|
||||
cost_gold: int
|
||||
"""金币花费"""
|
||||
menu_type: str
|
||||
"""插件菜单类型"""
|
||||
"""是否限制超级用户"""
|
||||
level: int
|
||||
"""插件所需群权限"""
|
||||
"""等级"""
|
||||
cost_gold: int
|
||||
"""花费金币"""
|
||||
menu_type: str
|
||||
"""菜单类型"""
|
||||
block_type: BlockType | None = None
|
||||
"""禁用类型"""
|
||||
configs: dict[str, Any] | None = None
|
||||
"""配置项"""
|
||||
"""设置项"""
|
||||
|
||||
|
||||
class PluginInfo(BaseModel):
|
||||
@ -58,27 +58,33 @@ class PluginInfo(BaseModel):
|
||||
"""
|
||||
|
||||
module: str
|
||||
"""插件名称"""
|
||||
"""模块"""
|
||||
plugin_name: str
|
||||
"""插件中文名称"""
|
||||
"""插件名称"""
|
||||
default_status: bool
|
||||
"""默认开关"""
|
||||
"""是否默认开启"""
|
||||
limit_superuser: bool
|
||||
"""限制超级用户"""
|
||||
"""是否限制超级用户"""
|
||||
level: int
|
||||
"""等级"""
|
||||
cost_gold: int
|
||||
"""花费金币"""
|
||||
menu_type: str
|
||||
"""插件菜单类型"""
|
||||
"""菜单类型"""
|
||||
version: str
|
||||
"""插件版本"""
|
||||
level: int
|
||||
"""群权限"""
|
||||
"""版本"""
|
||||
status: bool
|
||||
"""当前状态"""
|
||||
"""状态"""
|
||||
author: str | None = None
|
||||
"""作者"""
|
||||
block_type: BlockType | None = None
|
||||
"""禁用类型"""
|
||||
block_type: BlockType | None = Field(None, description="插件禁用状态 (None: 启用)")
|
||||
"""禁用状态"""
|
||||
is_builtin: bool = False
|
||||
"""是否为内置插件"""
|
||||
allow_switch: bool = True
|
||||
"""是否允许开关"""
|
||||
allow_setting: bool = True
|
||||
"""是否允许设置"""
|
||||
|
||||
|
||||
class PluginConfig(BaseModel):
|
||||
@ -86,20 +92,13 @@ class PluginConfig(BaseModel):
|
||||
插件配置项
|
||||
"""
|
||||
|
||||
module: str
|
||||
"""模块"""
|
||||
key: str
|
||||
"""键"""
|
||||
value: Any
|
||||
"""值"""
|
||||
help: str | None = None
|
||||
"""帮助"""
|
||||
default_value: Any
|
||||
"""默认值"""
|
||||
type: Any = None
|
||||
"""值类型"""
|
||||
type_inner: list[str] | None = None
|
||||
"""List Tuple等内部类型检验"""
|
||||
module: str = Field(..., description="模块名")
|
||||
key: str = Field(..., description="键")
|
||||
value: Any = Field(None, description="值")
|
||||
help: str | None = Field(None, description="帮助信息")
|
||||
default_value: Any = Field(None, description="默认值")
|
||||
type: str | None = Field(None, description="类型")
|
||||
type_inner: list[str] | None = Field(None, description="内部类型")
|
||||
|
||||
|
||||
class PluginCount(BaseModel):
|
||||
@ -117,6 +116,21 @@ class PluginCount(BaseModel):
|
||||
"""其他插件"""
|
||||
|
||||
|
||||
class BatchUpdatePluginItem(BaseModel):
|
||||
module: str = Field(..., description="插件模块名")
|
||||
default_status: bool | None = Field(None, description="默认状态(开关)")
|
||||
menu_type: str | None = Field(None, description="菜单类型")
|
||||
block_type: BlockType | None = Field(
|
||||
None, description="插件禁用状态 (None: 启用, ALL: 禁用)"
|
||||
)
|
||||
|
||||
|
||||
class BatchUpdatePlugins(BaseModel):
|
||||
updates: list[BatchUpdatePluginItem] = Field(
|
||||
..., description="要批量更新的插件列表"
|
||||
)
|
||||
|
||||
|
||||
class PluginDetail(PluginInfo):
|
||||
"""
|
||||
插件详情
|
||||
@ -125,6 +139,38 @@ class PluginDetail(PluginInfo):
|
||||
config_list: list[PluginConfig]
|
||||
|
||||
|
||||
class RenameMenuTypePayload(BaseModel):
|
||||
old_name: str = Field(..., description="旧菜单类型名称")
|
||||
new_name: str = Field(..., description="新菜单类型名称")
|
||||
|
||||
|
||||
class PluginIr(BaseModel):
|
||||
id: int
|
||||
"""插件id"""
|
||||
|
||||
|
||||
class BatchUpdateResult(BaseModel):
|
||||
"""
|
||||
批量更新插件结果
|
||||
"""
|
||||
|
||||
success: bool = Field(..., description="是否全部成功")
|
||||
"""是否全部成功"""
|
||||
updated_count: int = Field(..., description="更新成功的数量")
|
||||
"""更新成功的数量"""
|
||||
errors: list[dict[str, str]] = Field(
|
||||
default_factory=list, description="错误信息列表"
|
||||
)
|
||||
"""错误信息列表"""
|
||||
|
||||
|
||||
class InstallDependenciesPayload(BaseModel):
|
||||
"""
|
||||
安装依赖
|
||||
"""
|
||||
|
||||
handle_type: Literal["install", "uninstall"] = Field(..., description="处理类型")
|
||||
"""处理类型"""
|
||||
|
||||
dependencies: list[str] = Field(..., description="依赖列表")
|
||||
"""依赖列表"""
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from nonebot import require
|
||||
from nonebot.compat import model_dump
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
@ -22,12 +23,12 @@ router = APIRouter(prefix="/store")
|
||||
async def _() -> Result[dict]:
|
||||
try:
|
||||
require("plugin_store")
|
||||
from zhenxun.builtin_plugins.plugin_store import ShopManage
|
||||
from zhenxun.builtin_plugins.plugin_store import StoreManager
|
||||
|
||||
data = await ShopManage.get_data()
|
||||
data = await StoreManager.get_data()
|
||||
plugin_list = [
|
||||
{**data[name].to_dict(), "name": name, "id": idx}
|
||||
for idx, name in enumerate(data)
|
||||
{**model_dump(plugin), "name": plugin.name, "id": idx}
|
||||
for idx, plugin in enumerate(data)
|
||||
]
|
||||
modules = await PluginInfo.filter(load_status=True).values_list(
|
||||
"module", flat=True
|
||||
@ -48,9 +49,9 @@ async def _() -> Result[dict]:
|
||||
async def _(param: PluginIr) -> Result:
|
||||
try:
|
||||
require("plugin_store")
|
||||
from zhenxun.builtin_plugins.plugin_store import ShopManage
|
||||
from zhenxun.builtin_plugins.plugin_store import StoreManager
|
||||
|
||||
result = await ShopManage.add_plugin(param.id) # type: ignore
|
||||
result = await StoreManager.add_plugin(param.id) # type: ignore
|
||||
return Result.ok(info=result)
|
||||
except Exception as e:
|
||||
return Result.fail(f"安装插件失败: {type(e)}: {e}")
|
||||
@ -66,9 +67,9 @@ async def _(param: PluginIr) -> Result:
|
||||
async def _(param: PluginIr) -> Result:
|
||||
try:
|
||||
require("plugin_store")
|
||||
from zhenxun.builtin_plugins.plugin_store import ShopManage
|
||||
from zhenxun.builtin_plugins.plugin_store import StoreManager
|
||||
|
||||
result = await ShopManage.update_plugin(param.id) # type: ignore
|
||||
result = await StoreManager.update_plugin(param.id) # type: ignore
|
||||
return Result.ok(info=result)
|
||||
except Exception as e:
|
||||
return Result.fail(f"更新插件失败: {type(e)}: {e}")
|
||||
@ -84,9 +85,9 @@ async def _(param: PluginIr) -> Result:
|
||||
async def _(param: PluginIr) -> Result:
|
||||
try:
|
||||
require("plugin_store")
|
||||
from zhenxun.builtin_plugins.plugin_store import ShopManage
|
||||
from zhenxun.builtin_plugins.plugin_store import StoreManager
|
||||
|
||||
result = await ShopManage.remove_plugin(param.id) # type: ignore
|
||||
result = await StoreManager.remove_plugin(param.id) # type: ignore
|
||||
return Result.ok(info=result)
|
||||
except Exception as e:
|
||||
return Result.fail(f"移除插件失败: {type(e)}: {e}")
|
||||
|
||||
@ -36,6 +36,8 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
|
||||
is_image=is_image,
|
||||
name=file,
|
||||
parent=path,
|
||||
size=None if file_path.is_dir() else file_path.stat().st_size,
|
||||
mtime=file_path.stat().st_mtime,
|
||||
)
|
||||
)
|
||||
return Result.ok(data_list)
|
||||
@ -215,3 +217,13 @@ async def _(full_path: str) -> Result[str]:
|
||||
return Result.ok(BuildImage.open(path).pic2bs4())
|
||||
except Exception as e:
|
||||
return Result.warning_(f"获取图片失败: {e!s}")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/ping",
|
||||
response_model=Result[str],
|
||||
response_class=JSONResponse,
|
||||
description="检查服务器状态",
|
||||
)
|
||||
async def _() -> Result[str]:
|
||||
return Result.ok("pong")
|
||||
|
||||
@ -14,6 +14,10 @@ class DirFile(BaseModel):
|
||||
"""文件夹或文件名称"""
|
||||
parent: str | None = None
|
||||
"""父级"""
|
||||
size: int | None = None
|
||||
"""文件大小"""
|
||||
mtime: float | None = None
|
||||
"""修改时间"""
|
||||
|
||||
|
||||
class DeleteFile(BaseModel):
|
||||
|
||||
@ -11,7 +11,7 @@ import psutil
|
||||
import ujson as json
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||
|
||||
from .base_model import SystemFolderSize, SystemStatus, User
|
||||
|
||||
@ -28,6 +28,22 @@ if token_file.exists():
|
||||
token_data = json.load(open(token_file, encoding="utf8"))
|
||||
|
||||
|
||||
GROUP_HELP_PATH = DATA_PATH / "group_help"
|
||||
SIMPLE_HELP_IMAGE = IMAGE_PATH / "SIMPLE_HELP.png"
|
||||
SIMPLE_DETAIL_HELP_IMAGE = IMAGE_PATH / "SIMPLE_DETAIL_HELP.png"
|
||||
|
||||
|
||||
def clear_help_image():
|
||||
"""清理帮助图片"""
|
||||
if SIMPLE_HELP_IMAGE.exists():
|
||||
SIMPLE_HELP_IMAGE.unlink()
|
||||
if SIMPLE_DETAIL_HELP_IMAGE.exists():
|
||||
SIMPLE_DETAIL_HELP_IMAGE.unlink()
|
||||
for file in GROUP_HELP_PATH.iterdir():
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def get_user(uname: str) -> User | None:
|
||||
"""获取账号密码
|
||||
|
||||
|
||||
@ -13,8 +13,8 @@ class BotSetting(BaseModel):
|
||||
"""回复时NICKNAME"""
|
||||
system_proxy: str | None = None
|
||||
"""系统代理"""
|
||||
db_url: str = ""
|
||||
"""数据库链接"""
|
||||
db_url: str = "sqlite:data/zhenxun.db"
|
||||
"""数据库链接, 默认值为sqlite:data/zhenxun.db"""
|
||||
platform_superusers: dict[str, list[str]] = Field(default_factory=dict)
|
||||
"""平台超级用户"""
|
||||
qbot_id_data: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@ -1,89 +1,82 @@
|
||||
from collections.abc import Callable
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, TypeVar, get_args, get_origin
|
||||
|
||||
import cattrs
|
||||
from nonebot.compat import model_dump
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import VERSION, BaseModel, Field
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.scanner import ScannerError
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType
|
||||
|
||||
from .models import (
|
||||
AICallableParam,
|
||||
AICallableProperties,
|
||||
AICallableTag,
|
||||
BaseBlock,
|
||||
Command,
|
||||
ConfigModel,
|
||||
Example,
|
||||
PluginCdBlock,
|
||||
PluginCountBlock,
|
||||
PluginExtraData,
|
||||
PluginSetting,
|
||||
RegisterConfig,
|
||||
Task,
|
||||
)
|
||||
|
||||
_yaml = YAML(pure=True)
|
||||
_yaml.indent = 2
|
||||
_yaml.allow_unicode = True
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class Example(BaseModel):
|
||||
|
||||
class NoSuchConfig(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _dump_pydantic_obj(obj: Any) -> Any:
|
||||
"""
|
||||
示例
|
||||
递归地将一个对象内部的 Pydantic BaseModel 实例转换为字典。
|
||||
支持单个实例、实例列表、实例字典等情况。
|
||||
"""
|
||||
|
||||
exec: str
|
||||
"""执行命令"""
|
||||
description: str = ""
|
||||
"""命令描述"""
|
||||
if isinstance(obj, BaseModel):
|
||||
return model_dump(obj)
|
||||
if isinstance(obj, list):
|
||||
return [_dump_pydantic_obj(item) for item in obj]
|
||||
if isinstance(obj, dict):
|
||||
return {key: _dump_pydantic_obj(value) for key, value in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class Command(BaseModel):
|
||||
def _is_pydantic_type(t: Any) -> bool:
|
||||
"""
|
||||
具体参数说明
|
||||
递归检查一个类型注解是否与 Pydantic BaseModel 相关。
|
||||
"""
|
||||
|
||||
command: str
|
||||
"""命令名称"""
|
||||
params: list[str] = Field(default_factory=list)
|
||||
"""参数"""
|
||||
description: str = ""
|
||||
"""描述"""
|
||||
examples: list[Example] = Field(default_factory=list)
|
||||
"""示例列表"""
|
||||
if t is None:
|
||||
return False
|
||||
origin = get_origin(t)
|
||||
if origin:
|
||||
return any(_is_pydantic_type(arg) for arg in get_args(t))
|
||||
return isinstance(t, type) and issubclass(t, BaseModel)
|
||||
|
||||
|
||||
class RegisterConfig(BaseModel):
|
||||
def parse_as(type_: type[T], obj: Any) -> T:
|
||||
"""
|
||||
注册配置项
|
||||
一个兼容 Pydantic V1 的 parse_obj_as 和V2的TypeAdapter.validate_python 的辅助函数。
|
||||
"""
|
||||
if VERSION.startswith("1"):
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
key: str
|
||||
"""配置项键"""
|
||||
value: Any
|
||||
"""配置项值"""
|
||||
module: str | None = None
|
||||
"""模块名"""
|
||||
help: str | None
|
||||
"""配置注解"""
|
||||
default_value: Any | None = None
|
||||
"""默认值"""
|
||||
type: Any = None
|
||||
"""参数类型"""
|
||||
arg_parser: Callable | None = None
|
||||
"""参数解析"""
|
||||
return parse_obj_as(type_, obj)
|
||||
else:
|
||||
from pydantic import TypeAdapter # type: ignore
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
"""
|
||||
配置项
|
||||
"""
|
||||
|
||||
value: Any
|
||||
"""配置项值"""
|
||||
help: str | None
|
||||
"""配置注解"""
|
||||
default_value: Any | None = None
|
||||
"""默认值"""
|
||||
type: Any = None
|
||||
"""参数类型"""
|
||||
arg_parser: Callable | None = None
|
||||
"""参数解析"""
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
return TypeAdapter(type_).validate_python(obj)
|
||||
|
||||
|
||||
class ConfigGroup(BaseModel):
|
||||
@ -98,202 +91,41 @@ class ConfigGroup(BaseModel):
|
||||
configs: dict[str, ConfigModel] = Field(default_factory=dict)
|
||||
"""配置项列表"""
|
||||
|
||||
def get(self, c: str, default: Any = None) -> Any:
|
||||
cfg = self.configs.get(c.upper())
|
||||
if cfg is not None:
|
||||
if cfg.value is not None:
|
||||
return cfg.value
|
||||
if cfg.default_value is not None:
|
||||
return cfg.default_value
|
||||
def get(self, c: str, default: Any = None, *, build_model: bool = True) -> Any:
|
||||
"""
|
||||
获取配置项的值。如果指定了类型,会自动构建实例。
|
||||
"""
|
||||
key = c.upper()
|
||||
cfg = self.configs.get(key)
|
||||
|
||||
if cfg is None:
|
||||
return default
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
value_to_process = cfg.value if cfg.value is not None else cfg.default_value
|
||||
|
||||
if value_to_process is None:
|
||||
return default
|
||||
|
||||
class BaseBlock(BaseModel):
|
||||
"""
|
||||
插件阻断基本类(插件阻断限制)
|
||||
"""
|
||||
if cfg.type:
|
||||
if _is_pydantic_type(cfg.type):
|
||||
if build_model:
|
||||
try:
|
||||
return parse_as(cfg.type, value_to_process)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Pydantic 模型解析失败 (key: {c.upper()}). ", e=e
|
||||
)
|
||||
try:
|
||||
return cattrs.structure(value_to_process, cfg.type)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cattrs 结构化失败 (key: {key}),返回原始值。", e=e)
|
||||
|
||||
status: bool = True
|
||||
"""限制状态"""
|
||||
check_type: BlockType = BlockType.ALL
|
||||
"""检查类型"""
|
||||
watch_type: LimitWatchType = LimitWatchType.USER
|
||||
"""监听对象"""
|
||||
result: str | None = None
|
||||
"""阻断时回复内容"""
|
||||
_type: PluginLimitType = PluginLimitType.BLOCK
|
||||
"""类型"""
|
||||
return value_to_process
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
|
||||
|
||||
class PluginCdBlock(BaseBlock):
|
||||
"""
|
||||
插件cd限制
|
||||
"""
|
||||
|
||||
cd: int = 5
|
||||
"""cd"""
|
||||
_type: PluginLimitType = PluginLimitType.CD
|
||||
"""类型"""
|
||||
|
||||
|
||||
class PluginCountBlock(BaseBlock):
|
||||
"""
|
||||
插件次数限制
|
||||
"""
|
||||
|
||||
max_count: int
|
||||
"""最大调用次数"""
|
||||
_type: PluginLimitType = PluginLimitType.COUNT
|
||||
"""类型"""
|
||||
|
||||
|
||||
class PluginSetting(BaseModel):
|
||||
"""
|
||||
插件基本配置
|
||||
"""
|
||||
|
||||
level: int = 5
|
||||
"""群权限等级"""
|
||||
default_status: bool = True
|
||||
"""进群默认开关状态"""
|
||||
limit_superuser: bool = False
|
||||
"""是否限制超级用户"""
|
||||
cost_gold: int = 0
|
||||
"""调用插件花费金币"""
|
||||
impression: float = 0.0
|
||||
"""调用插件好感度限制"""
|
||||
|
||||
|
||||
class AICallableProperties(BaseModel):
|
||||
type: str
|
||||
"""参数类型"""
|
||||
description: str
|
||||
"""参数描述"""
|
||||
enums: list[str] | None = None
|
||||
"""参数枚举"""
|
||||
|
||||
|
||||
class AICallableParam(BaseModel):
|
||||
type: str
|
||||
"""类型"""
|
||||
properties: dict[str, AICallableProperties]
|
||||
"""参数列表"""
|
||||
required: list[str]
|
||||
"""必要参数"""
|
||||
|
||||
|
||||
class AICallableTag(BaseModel):
|
||||
name: str
|
||||
"""工具名称"""
|
||||
parameters: AICallableParam | None = None
|
||||
"""工具参数"""
|
||||
description: str
|
||||
"""工具描述"""
|
||||
func: Callable | None = None
|
||||
"""工具函数"""
|
||||
|
||||
def to_dict(self):
|
||||
result = model_dump(self)
|
||||
del result["func"]
|
||||
return result
|
||||
|
||||
|
||||
class SchedulerModel(BaseModel):
|
||||
trigger: Literal["date", "interval", "cron"]
|
||||
"""trigger"""
|
||||
day: int | None = None
|
||||
"""天数"""
|
||||
hour: int | None = None
|
||||
"""小时"""
|
||||
minute: int | None = None
|
||||
"""分钟"""
|
||||
second: int | None = None
|
||||
"""秒"""
|
||||
run_date: datetime | None = None
|
||||
"""运行日期"""
|
||||
id: str | None = None
|
||||
"""id"""
|
||||
max_instances: int | None = None
|
||||
"""最大运行实例"""
|
||||
args: list | None = None
|
||||
"""参数"""
|
||||
kwargs: dict | None = None
|
||||
"""参数"""
|
||||
|
||||
|
||||
class Task(BaseBlock):
|
||||
module: str
|
||||
"""被动技能模块名"""
|
||||
name: str
|
||||
"""被动技能名称"""
|
||||
status: bool = True
|
||||
"""全局开关状态"""
|
||||
create_status: bool = False
|
||||
"""初次加载默认开关状态"""
|
||||
default_status: bool = True
|
||||
"""进群时默认状态"""
|
||||
scheduler: SchedulerModel | None = None
|
||||
"""定时任务配置"""
|
||||
run_func: Callable | None = None
|
||||
"""运行函数"""
|
||||
check: Callable | None = None
|
||||
"""检查函数"""
|
||||
check_args: list = Field(default_factory=list)
|
||||
"""检查函数参数"""
|
||||
|
||||
|
||||
class PluginExtraData(BaseModel):
|
||||
"""
|
||||
插件扩展信息
|
||||
"""
|
||||
|
||||
author: str | None = None
|
||||
"""作者"""
|
||||
version: str | None = None
|
||||
"""版本"""
|
||||
plugin_type: PluginType = PluginType.NORMAL
|
||||
"""插件类型"""
|
||||
menu_type: str = "功能"
|
||||
"""菜单类型"""
|
||||
admin_level: int | None = None
|
||||
"""管理员插件所需权限等级"""
|
||||
configs: list[RegisterConfig] | None = None
|
||||
"""插件配置"""
|
||||
setting: PluginSetting | None = None
|
||||
"""插件基本配置"""
|
||||
limits: list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None
|
||||
"""插件限制"""
|
||||
commands: list[Command] = Field(default_factory=list)
|
||||
"""命令列表,用于说明帮助"""
|
||||
ignore_prompt: bool = False
|
||||
"""是否忽略阻断提示"""
|
||||
tasks: list[Task] | None = None
|
||||
"""技能被动"""
|
||||
superuser_help: str | None = None
|
||||
"""超级用户帮助"""
|
||||
aliases: set[str] = Field(default_factory=set)
|
||||
"""额外名称"""
|
||||
sql_list: list[str] | None = None
|
||||
"""常用sql"""
|
||||
is_show: bool = True
|
||||
"""是否显示在菜单中"""
|
||||
smart_tools: list[AICallableTag] | None = None
|
||||
"""智能模式函数工具集"""
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
|
||||
|
||||
class NoSuchConfig(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConfigsManager:
|
||||
"""
|
||||
插件配置 与 资源 管理器
|
||||
@ -366,23 +198,32 @@ class ConfigsManager:
|
||||
|
||||
if not module or not key:
|
||||
raise ValueError("add_plugin_config: module和key不能为为空")
|
||||
if isinstance(value, BaseModel):
|
||||
value = model_dump(value)
|
||||
if isinstance(default_value, BaseModel):
|
||||
default_value = model_dump(default_value)
|
||||
|
||||
processed_value = _dump_pydantic_obj(value)
|
||||
processed_default_value = _dump_pydantic_obj(default_value)
|
||||
|
||||
self.add_module.append(f"{module}:{key}".lower())
|
||||
if module in self._data and (config := self._data[module].configs.get(key)):
|
||||
config.help = help
|
||||
config.arg_parser = arg_parser
|
||||
config.type = type
|
||||
if _override:
|
||||
config.value = value
|
||||
config.default_value = default_value
|
||||
config.value = processed_value
|
||||
config.default_value = processed_default_value
|
||||
else:
|
||||
key = key.upper()
|
||||
if not self._data.get(module):
|
||||
self._data[module] = ConfigGroup(module=module)
|
||||
self._data[module].configs[key] = ConfigModel(
|
||||
value=value,
|
||||
value=processed_value,
|
||||
help=help,
|
||||
default_value=default_value,
|
||||
default_value=processed_default_value,
|
||||
type=type,
|
||||
arg_parser=arg_parser,
|
||||
)
|
||||
|
||||
def set_config(
|
||||
@ -402,6 +243,8 @@ class ConfigsManager:
|
||||
"""
|
||||
key = key.upper()
|
||||
if module in self._data:
|
||||
if module not in self._simple_data:
|
||||
self._simple_data[module] = {}
|
||||
if self._data[module].configs.get(key):
|
||||
self._data[module].configs[key].value = value
|
||||
else:
|
||||
@ -410,63 +253,68 @@ class ConfigsManager:
|
||||
if auto_save:
|
||||
self.save(save_simple_data=True)
|
||||
|
||||
def get_config(self, module: str, key: str, default: Any = None) -> Any:
|
||||
"""获取指定配置值
|
||||
|
||||
参数:
|
||||
module: 模块名
|
||||
key: 配置键
|
||||
default: 没有key值内容的默认返回值.
|
||||
|
||||
异常:
|
||||
NoSuchConfig: 未查询到配置
|
||||
|
||||
返回:
|
||||
Any: 配置值
|
||||
def get_config(
|
||||
self,
|
||||
module: str,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
*,
|
||||
build_model: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
获取指定配置值,自动构建Pydantic模型或其它类型实例。
|
||||
- 兼容Pydantic V1/V2。
|
||||
- 支持 list[BaseModel] 等泛型容器。
|
||||
- 优先使用Pydantic原生方式解析,失败后回退到cattrs。
|
||||
"""
|
||||
logger.debug(
|
||||
f"尝试获取配置MODULE: [<u><y>{module}</y></u>] | KEY: [<u><y>{key}</y></u>]"
|
||||
)
|
||||
key = key.upper()
|
||||
value = None
|
||||
if module in self._data.keys():
|
||||
config = self._data[module].configs.get(key) or self._data[
|
||||
module
|
||||
].configs.get(key)
|
||||
config_group = self._data.get(module)
|
||||
if not config_group:
|
||||
return default
|
||||
|
||||
config = config_group.configs.get(key)
|
||||
if not config:
|
||||
raise NoSuchConfig(
|
||||
f"未查询到配置项 MODULE: [ {module} ] | KEY: [ {key} ]"
|
||||
return default
|
||||
|
||||
value_to_process = (
|
||||
config.value if config.value is not None else config.default_value
|
||||
)
|
||||
try:
|
||||
if value_to_process is None:
|
||||
return default
|
||||
|
||||
# 1. 最高优先级:自定义的参数解析器
|
||||
if config.arg_parser:
|
||||
value = config.arg_parser(value or config.default_value)
|
||||
elif config.value is not None:
|
||||
# try:
|
||||
value = (
|
||||
cattrs.structure(config.value, config.type)
|
||||
if config.type
|
||||
else config.value
|
||||
)
|
||||
elif config.default_value is not None:
|
||||
value = (
|
||||
cattrs.structure(config.default_value, config.type)
|
||||
if config.type
|
||||
else config.default_value
|
||||
)
|
||||
try:
|
||||
return config.arg_parser(value_to_process)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
logger.debug(
|
||||
f"配置项类型转换 MODULE: [<u><y>{module}</y></u>]"
|
||||
" | KEY: [<u><y>{key}</y></u>]",
|
||||
f" | KEY: [<u><y>{key}</y></u>] 将使用原始值",
|
||||
e=e,
|
||||
)
|
||||
value = config.value or config.default_value
|
||||
if value is None:
|
||||
value = default
|
||||
logger.debug(
|
||||
f"获取配置 MODULE: [<u><y>{module}</y></u>] | "
|
||||
f" KEY: [<u><y>{key}</y></u>] -> [<u><c>{value}</c></u>]"
|
||||
|
||||
if config.type:
|
||||
if _is_pydantic_type(config.type):
|
||||
if build_model:
|
||||
try:
|
||||
return parse_as(config.type, value_to_process)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"pydantic类型转换失败 MODULE: [<u><y>{module}</y></u>] | "
|
||||
f"KEY: [<u><y>{key}</y></u>].",
|
||||
e=e,
|
||||
)
|
||||
return value
|
||||
else:
|
||||
try:
|
||||
return cattrs.structure(value_to_process, config.type)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"cattrs类型转换失败 MODULE: [<u><y>{module}</y></u>] | "
|
||||
f"KEY: [<u><y>{key}</y></u>].",
|
||||
e=e,
|
||||
)
|
||||
|
||||
return value_to_process
|
||||
|
||||
def get(self, key: str) -> ConfigGroup:
|
||||
"""获取插件配置数据
|
||||
@ -490,16 +338,16 @@ class ConfigsManager:
|
||||
with open(self._simple_file, "w", encoding="utf8") as f:
|
||||
_yaml.dump(self._simple_data, f)
|
||||
path = path or self.file
|
||||
data = {}
|
||||
for module in self._data:
|
||||
data[module] = {}
|
||||
for config in self._data[module].configs:
|
||||
value = self._data[module].configs[config].dict()
|
||||
del value["type"]
|
||||
del value["arg_parser"]
|
||||
data[module][config] = value
|
||||
save_data = {}
|
||||
for module, config_group in self._data.items():
|
||||
save_data[module] = {}
|
||||
for config_key, config_model in config_group.configs.items():
|
||||
save_data[module][config_key] = model_dump(
|
||||
config_model, exclude={"type", "arg_parser"}
|
||||
)
|
||||
|
||||
with open(path, "w", encoding="utf8") as f:
|
||||
_yaml.dump(data, f)
|
||||
_yaml.dump(save_data, f)
|
||||
|
||||
def reload(self):
|
||||
"""重新加载配置文件"""
|
||||
@ -558,3 +406,23 @@ class ConfigsManager:
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AICallableParam",
|
||||
"AICallableProperties",
|
||||
"AICallableTag",
|
||||
"BaseBlock",
|
||||
"Command",
|
||||
"ConfigGroup",
|
||||
"ConfigModel",
|
||||
"ConfigsManager",
|
||||
"Example",
|
||||
"NoSuchConfig",
|
||||
"PluginCdBlock",
|
||||
"PluginCountBlock",
|
||||
"PluginExtraData",
|
||||
"PluginSetting",
|
||||
"RegisterConfig",
|
||||
"Task",
|
||||
]
|
||||
|
||||
270
zhenxun/configs/utils/models.py
Normal file
270
zhenxun/configs/utils/models.py
Normal file
@ -0,0 +1,270 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from nonebot.compat import model_dump
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from zhenxun.utils.enum import BlockType, LimitWatchType, PluginLimitType, PluginType
|
||||
|
||||
__all__ = [
|
||||
"AICallableParam",
|
||||
"AICallableProperties",
|
||||
"AICallableTag",
|
||||
"BaseBlock",
|
||||
"Command",
|
||||
"ConfigModel",
|
||||
"Example",
|
||||
"PluginCdBlock",
|
||||
"PluginCountBlock",
|
||||
"PluginExtraData",
|
||||
"PluginSetting",
|
||||
"RegisterConfig",
|
||||
"Task",
|
||||
]
|
||||
|
||||
|
||||
class Example(BaseModel):
|
||||
"""
|
||||
示例
|
||||
"""
|
||||
|
||||
exec: str
|
||||
"""执行命令"""
|
||||
description: str = ""
|
||||
"""命令描述"""
|
||||
|
||||
|
||||
class Command(BaseModel):
|
||||
"""
|
||||
具体参数说明
|
||||
"""
|
||||
|
||||
command: str
|
||||
"""命令名称"""
|
||||
params: list[str] = Field(default_factory=list)
|
||||
"""参数"""
|
||||
description: str = ""
|
||||
"""描述"""
|
||||
examples: list[Example] = Field(default_factory=list)
|
||||
"""示例列表"""
|
||||
|
||||
|
||||
class RegisterConfig(BaseModel):
|
||||
"""
|
||||
注册配置项
|
||||
"""
|
||||
|
||||
key: str
|
||||
"""配置项键"""
|
||||
value: Any
|
||||
"""配置项值"""
|
||||
module: str | None = None
|
||||
"""模块名"""
|
||||
help: str | None
|
||||
"""配置注解"""
|
||||
default_value: Any | None = None
|
||||
"""默认值"""
|
||||
type: Any = None
|
||||
"""参数类型"""
|
||||
arg_parser: Callable | None = None
|
||||
"""参数解析"""
|
||||
|
||||
|
||||
class ConfigModel(BaseModel):
|
||||
"""
|
||||
配置项
|
||||
"""
|
||||
|
||||
value: Any
|
||||
"""配置项值"""
|
||||
help: str | None
|
||||
"""配置注解"""
|
||||
default_value: Any | None = None
|
||||
"""默认值"""
|
||||
type: Any = None
|
||||
"""参数类型"""
|
||||
arg_parser: Callable | None = None
|
||||
"""参数解析"""
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
|
||||
|
||||
class BaseBlock(BaseModel):
|
||||
"""
|
||||
插件阻断基本类(插件阻断限制)
|
||||
"""
|
||||
|
||||
status: bool = True
|
||||
"""限制状态"""
|
||||
check_type: BlockType = BlockType.ALL
|
||||
"""检查类型"""
|
||||
watch_type: LimitWatchType = LimitWatchType.USER
|
||||
"""监听对象"""
|
||||
result: str | None = None
|
||||
"""阻断时回复内容"""
|
||||
_type: PluginLimitType = PluginLimitType.BLOCK
|
||||
"""类型"""
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
|
||||
|
||||
class PluginCdBlock(BaseBlock):
|
||||
"""
|
||||
插件cd限制
|
||||
"""
|
||||
|
||||
cd: int = 5
|
||||
"""cd"""
|
||||
_type: PluginLimitType = PluginLimitType.CD
|
||||
"""类型"""
|
||||
|
||||
|
||||
class PluginCountBlock(BaseBlock):
|
||||
"""
|
||||
插件次数限制
|
||||
"""
|
||||
|
||||
max_count: int
|
||||
"""最大调用次数"""
|
||||
_type: PluginLimitType = PluginLimitType.COUNT
|
||||
"""类型"""
|
||||
|
||||
|
||||
class PluginSetting(BaseModel):
|
||||
"""
|
||||
插件基本配置
|
||||
"""
|
||||
|
||||
level: int = 5
|
||||
"""群权限等级"""
|
||||
default_status: bool = True
|
||||
"""进群默认开关状态"""
|
||||
limit_superuser: bool = False
|
||||
"""是否限制超级用户"""
|
||||
cost_gold: int = 0
|
||||
"""调用插件花费金币"""
|
||||
impression: float = 0.0
|
||||
"""调用插件好感度限制"""
|
||||
|
||||
|
||||
class AICallableProperties(BaseModel):
|
||||
type: str
|
||||
"""参数类型"""
|
||||
description: str
|
||||
"""参数描述"""
|
||||
enums: list[str] | None = None
|
||||
"""参数枚举"""
|
||||
|
||||
|
||||
class AICallableParam(BaseModel):
|
||||
type: str
|
||||
"""类型"""
|
||||
properties: dict[str, AICallableProperties]
|
||||
"""参数列表"""
|
||||
required: list[str]
|
||||
"""必要参数"""
|
||||
|
||||
|
||||
class AICallableTag(BaseModel):
|
||||
name: str
|
||||
"""工具名称"""
|
||||
parameters: AICallableParam | None = None
|
||||
"""工具参数"""
|
||||
description: str
|
||||
"""工具描述"""
|
||||
func: Callable | None = None
|
||||
"""工具函数"""
|
||||
|
||||
def to_dict(self):
|
||||
result = model_dump(self)
|
||||
del result["func"]
|
||||
return result
|
||||
|
||||
|
||||
class SchedulerModel(BaseModel):
|
||||
trigger: Literal["date", "interval", "cron"]
|
||||
"""trigger"""
|
||||
day: int | None = None
|
||||
"""天数"""
|
||||
hour: int | None = None
|
||||
"""小时"""
|
||||
minute: int | None = None
|
||||
"""分钟"""
|
||||
second: int | None = None
|
||||
"""秒"""
|
||||
run_date: datetime | None = None
|
||||
"""运行日期"""
|
||||
id: str | None = None
|
||||
"""id"""
|
||||
max_instances: int | None = None
|
||||
"""最大运行实例"""
|
||||
args: list | None = None
|
||||
"""参数"""
|
||||
kwargs: dict | None = None
|
||||
"""参数"""
|
||||
|
||||
|
||||
class Task(BaseBlock):
|
||||
module: str
|
||||
"""被动技能模块名"""
|
||||
name: str
|
||||
"""被动技能名称"""
|
||||
status: bool = True
|
||||
"""全局开关状态"""
|
||||
create_status: bool = False
|
||||
"""初次加载默认开关状态"""
|
||||
default_status: bool = True
|
||||
"""进群时默认状态"""
|
||||
scheduler: SchedulerModel | None = None
|
||||
"""定时任务配置"""
|
||||
run_func: Callable | None = None
|
||||
"""运行函数"""
|
||||
check: Callable | None = None
|
||||
"""检查函数"""
|
||||
check_args: list = Field(default_factory=list)
|
||||
"""检查函数参数"""
|
||||
|
||||
|
||||
class PluginExtraData(BaseModel):
|
||||
"""
|
||||
插件扩展信息
|
||||
"""
|
||||
|
||||
author: str | None = None
|
||||
"""作者"""
|
||||
version: str | None = None
|
||||
"""版本"""
|
||||
plugin_type: PluginType = PluginType.NORMAL
|
||||
"""插件类型"""
|
||||
menu_type: str = "功能"
|
||||
"""菜单类型"""
|
||||
admin_level: int | None = None
|
||||
"""管理员插件所需权限等级"""
|
||||
configs: list[RegisterConfig] | None = None
|
||||
"""插件配置"""
|
||||
setting: PluginSetting | None = None
|
||||
"""插件基本配置"""
|
||||
limits: list[BaseBlock | PluginCdBlock | PluginCountBlock] | None = None
|
||||
"""插件限制"""
|
||||
commands: list[Command] = Field(default_factory=list)
|
||||
"""命令列表,用于说明帮助"""
|
||||
ignore_prompt: bool = False
|
||||
"""是否忽略阻断提示"""
|
||||
tasks: list[Task] | None = None
|
||||
"""技能被动"""
|
||||
superuser_help: str | None = None
|
||||
"""超级用户帮助"""
|
||||
aliases: set[str] = Field(default_factory=set)
|
||||
"""额外名称"""
|
||||
sql_list: list[str] | None = None
|
||||
"""常用sql"""
|
||||
is_show: bool = True
|
||||
"""是否显示在菜单中"""
|
||||
smart_tools: list[AICallableTag] | None = None
|
||||
"""智能模式函数工具集"""
|
||||
|
||||
def to_dict(self, **kwargs):
|
||||
return model_dump(self, **kwargs)
|
||||
@ -1,10 +1,12 @@
|
||||
import time
|
||||
from typing import ClassVar
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import CacheType, DbLockType
|
||||
from zhenxun.utils.exception import UserAndGroupIsNone
|
||||
|
||||
|
||||
@ -27,6 +29,12 @@ class BanConsole(Model):
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "ban_console"
|
||||
table_description = "封禁人员/群组数据表"
|
||||
unique_together = ("user_id", "group_id")
|
||||
|
||||
cache_type = CacheType.BAN
|
||||
"""缓存类型"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||
"""开启锁"""
|
||||
|
||||
@classmethod
|
||||
async def _get_data(cls, user_id: str | None, group_id: str | None) -> Self | None:
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Literal, overload
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class BotConsole(Model):
|
||||
@ -29,6 +30,8 @@ class BotConsole(Model):
|
||||
table = "bot_console"
|
||||
table_description = "Bot数据表"
|
||||
|
||||
cache_type = CacheType.BOT
|
||||
|
||||
@staticmethod
|
||||
def format(name: str) -> str:
|
||||
return f"<{name},"
|
||||
|
||||
29
zhenxun/models/bot_message_store.py
Normal file
29
zhenxun/models/bot_message_store.py
Normal file
@ -0,0 +1,29 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import BotSentType
|
||||
|
||||
|
||||
class BotMessageStore(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
bot_id = fields.CharField(255, null=True)
|
||||
"""bot id"""
|
||||
user_id = fields.CharField(255, null=True)
|
||||
"""目标id"""
|
||||
group_id = fields.CharField(255, null=True)
|
||||
"""群组id"""
|
||||
sent_type = fields.CharEnumField(BotSentType)
|
||||
"""类型"""
|
||||
text = fields.TextField(null=True)
|
||||
"""文本内容"""
|
||||
plain_text = fields.TextField(null=True)
|
||||
"""纯文本"""
|
||||
platform = fields.CharField(255, null=True)
|
||||
"""平台"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "bot_message_store"
|
||||
table_description = "Bot发送消息列表"
|
||||
21
zhenxun/models/event_log.py
Normal file
21
zhenxun/models/event_log.py
Normal file
@ -0,0 +1,21 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import EventLogType
|
||||
|
||||
|
||||
class EventLog(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
user_id = fields.CharField(255, description="用户id")
|
||||
"""用户id"""
|
||||
group_id = fields.CharField(255, description="群组id")
|
||||
"""群组id"""
|
||||
event_type = fields.CharEnumField(EventLogType, default=None, description="类型")
|
||||
"""类型"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True, description="创建时间")
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "event_log"
|
||||
table_description = "各种请求通知记录表"
|
||||
@ -3,8 +3,10 @@ from typing_extensions import Self
|
||||
from nonebot.adapters import Bot
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.common_utils import SqlUtils
|
||||
from zhenxun.utils.enum import RequestHandleType, RequestType
|
||||
from zhenxun.utils.exception import NotFoundError
|
||||
|
||||
@ -34,6 +36,8 @@ class FgRequest(Model):
|
||||
RequestHandleType, null=True, description="处理类型"
|
||||
)
|
||||
"""处理类型"""
|
||||
message_ids = fields.CharField(max_length=255, null=True, description="消息id列表")
|
||||
"""消息id列表"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "fg_request"
|
||||
@ -123,9 +127,24 @@ class FgRequest(Model):
|
||||
await GroupConsole.update_or_create(
|
||||
group_id=req.group_id, defaults={"group_flag": 1}
|
||||
)
|
||||
if req.flag == "0":
|
||||
# 用户手动申请入群,创建群认证后提醒用户拉群
|
||||
await bot.send_private_msg(
|
||||
user_id=req.user_id,
|
||||
message=f"已同意你对{BotConfig.self_nickname}的申请群组:"
|
||||
f"{req.group_id},可以直接手动拉入群组,{BotConfig.self_nickname}会自动同意。",
|
||||
)
|
||||
else:
|
||||
# 正常同意群组请求
|
||||
await bot.set_group_add_request(
|
||||
flag=req.flag,
|
||||
sub_type="invite",
|
||||
approve=handle_type == RequestHandleType.APPROVE,
|
||||
)
|
||||
return req
|
||||
|
||||
@classmethod
|
||||
async def _run_script(cls):
|
||||
return [
|
||||
SqlUtils.add_column("fg_request", "message_ids", "character varying(255)")
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, cast, overload
|
||||
from typing import Any, ClassVar, cast, overload
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
@ -6,8 +6,9 @@ from tortoise.backends.base.client import BaseDBAsyncClient
|
||||
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.services.cache import CacheRoot
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.enum import CacheType, DbLockType, PluginType
|
||||
|
||||
|
||||
def add_disable_marker(name: str) -> str:
|
||||
@ -41,7 +42,7 @@ def convert_module_format(data: str | list[str]) -> str | list[str]:
|
||||
str | list[str]: 根据输入类型返回转换后的数据。
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return [item.strip(",") for item in data.split("<") if item.strip()]
|
||||
return [item.strip(",") for item in data.split("<") if item]
|
||||
else:
|
||||
return "".join(add_disable_marker(item) for item in data)
|
||||
|
||||
@ -87,6 +88,11 @@ class GroupConsole(Model):
|
||||
table_description = "群组信息表"
|
||||
unique_together = ("group_id", "channel_id")
|
||||
|
||||
cache_type = CacheType.GROUPS
|
||||
"""缓存类型"""
|
||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||
"""开启锁"""
|
||||
|
||||
@classmethod
|
||||
async def _get_task_modules(cls, *, default_status: bool) -> list[str]:
|
||||
"""获取默认禁用的任务模块
|
||||
@ -117,6 +123,7 @@ class GroupConsole(Model):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@CacheRoot.listener(CacheType.GROUPS)
|
||||
async def create(
|
||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||
) -> Self:
|
||||
@ -180,9 +187,14 @@ class GroupConsole(Model):
|
||||
if task_modules or plugin_modules:
|
||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||
|
||||
if is_create:
|
||||
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
|
||||
await cache.update(group.group_id, group)
|
||||
|
||||
return group, is_create
|
||||
|
||||
@classmethod
|
||||
@CacheRoot.listener(CacheType.GROUPS)
|
||||
async def update_or_create(
|
||||
cls,
|
||||
defaults: dict | None = None,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import CacheType
|
||||
|
||||
|
||||
class LevelUser(Model):
|
||||
@ -20,6 +21,8 @@ class LevelUser(Model):
|
||||
table_description = "用户权限数据库"
|
||||
unique_together = ("user_id", "group_id")
|
||||
|
||||
cache_type = CacheType.LEVEL
|
||||
|
||||
@classmethod
|
||||
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
|
||||
"""获取用户在群内的等级
|
||||
@ -53,6 +56,9 @@ class LevelUser(Model):
|
||||
level: 权限等级
|
||||
group_flag: 是否被自动更新刷新权限 0:是, 1:否.
|
||||
"""
|
||||
if await cls.exists(user_id=user_id, group_id=group_id, user_level=level):
|
||||
# 权限相同时跳过
|
||||
return
|
||||
await cls.update_or_create(
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
@ -119,8 +125,7 @@ class LevelUser(Model):
|
||||
return [
|
||||
# 将user_id改为user_id
|
||||
"ALTER TABLE level_users RENAME COLUMN user_qq TO user_id;",
|
||||
"ALTER TABLE level_users "
|
||||
"ALTER COLUMN user_id TYPE character varying(255);",
|
||||
"ALTER TABLE level_users ALTER COLUMN user_id TYPE character varying(255);",
|
||||
# 将user_id字段类型改为character varying(255)
|
||||
"ALTER TABLE level_users "
|
||||
"ALTER COLUMN group_id TYPE character varying(255);",
|
||||
|
||||
123
zhenxun/models/mahiro_bank.py
Normal file
123
zhenxun/models/mahiro_bank.py
Normal file
@ -0,0 +1,123 @@
|
||||
from datetime import datetime
|
||||
from typing_extensions import Self
|
||||
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
from .mahiro_bank_log import BankHandleType, MahiroBankLog
|
||||
|
||||
|
||||
class MahiroBank(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
user_id = fields.CharField(255, description="用户id")
|
||||
"""用户id"""
|
||||
amount = fields.BigIntField(default=0, description="存款")
|
||||
"""用户存款"""
|
||||
rate = fields.FloatField(default=0.0005, description="小时利率")
|
||||
"""小时利率"""
|
||||
loan_amount = fields.BigIntField(default=0, description="贷款")
|
||||
"""用户贷款"""
|
||||
loan_rate = fields.FloatField(default=0.0005, description="贷款利率")
|
||||
"""贷款利率"""
|
||||
update_time = fields.DatetimeField(auto_now=True)
|
||||
"""修改时间"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "mahiro_bank"
|
||||
table_description = "小真寻银行"
|
||||
|
||||
@classmethod
|
||||
async def deposit(cls, user_id: str, amount: int, rate: float) -> Self:
|
||||
"""存款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 金币数量
|
||||
rate: 小时利率
|
||||
|
||||
返回:
|
||||
Self: MahiroBank
|
||||
"""
|
||||
effective_hour = int(24 - datetime.now().hour)
|
||||
user, _ = await cls.get_or_create(user_id=user_id)
|
||||
user.amount += amount
|
||||
await user.save(update_fields=["amount", "rate"])
|
||||
await MahiroBankLog.create(
|
||||
user_id=user_id,
|
||||
amount=amount,
|
||||
rate=rate,
|
||||
effective_hour=effective_hour,
|
||||
handle_type=BankHandleType.DEPOSIT,
|
||||
)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def withdraw(cls, user_id: str, amount: int) -> Self:
|
||||
"""取款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 金币数量
|
||||
|
||||
返回:
|
||||
Self: MahiroBank
|
||||
"""
|
||||
if amount <= 0:
|
||||
raise ValueError("取款金额必须大于0")
|
||||
user, _ = await cls.get_or_create(user_id=user_id)
|
||||
if user.amount < amount:
|
||||
raise ValueError("取款金额不能大于存款金额")
|
||||
user.amount -= amount
|
||||
await user.save(update_fields=["amount"])
|
||||
await MahiroBankLog.create(
|
||||
user_id=user_id, amount=amount, handle_type=BankHandleType.WITHDRAW
|
||||
)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def loan(cls, user_id: str, amount: int, rate: float) -> Self:
|
||||
"""贷款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 贷款金额
|
||||
rate: 贷款利率
|
||||
|
||||
返回:
|
||||
Self: MahiroBank
|
||||
"""
|
||||
user, _ = await cls.get_or_create(user_id=user_id)
|
||||
user.loan_amount += amount
|
||||
user.loan_rate = rate
|
||||
await user.save(update_fields=["loan_amount", "loan_rate"])
|
||||
await MahiroBankLog.create(
|
||||
user_id=user_id, amount=amount, rate=rate, handle_type=BankHandleType.LOAN
|
||||
)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def repayment(cls, user_id: str, amount: int) -> Self:
|
||||
"""还款
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
amount: 还款金额
|
||||
|
||||
返回:
|
||||
Self: MahiroBank
|
||||
"""
|
||||
if amount <= 0:
|
||||
raise ValueError("还款金额必须大于0")
|
||||
user, _ = await cls.get_or_create(user_id=user_id)
|
||||
if user.loan_amount < amount:
|
||||
raise ValueError("还款金额不能大于贷款金额")
|
||||
user.loan_amount -= amount
|
||||
await user.save(update_fields=["loan_amount"])
|
||||
await MahiroBankLog.create(
|
||||
user_id=user_id, amount=amount, handle_type=BankHandleType.REPAYMENT
|
||||
)
|
||||
return user
|
||||
31
zhenxun/models/mahiro_bank_log.py
Normal file
31
zhenxun/models/mahiro_bank_log.py
Normal file
@ -0,0 +1,31 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import BankHandleType
|
||||
|
||||
|
||||
class MahiroBankLog(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
user_id = fields.CharField(255, description="用户id")
|
||||
"""用户id"""
|
||||
amount = fields.BigIntField(default=0, description="存款")
|
||||
"""金币数量"""
|
||||
rate = fields.FloatField(default=0, description="小时利率")
|
||||
"""小时利率"""
|
||||
handle_type = fields.CharEnumField(
|
||||
BankHandleType, null=True, description="处理类型"
|
||||
)
|
||||
"""处理类型"""
|
||||
is_completed = fields.BooleanField(default=False, description="是否完成")
|
||||
"""是否完成"""
|
||||
effective_hour = fields.IntField(default=0, description="有效小时")
|
||||
"""有效小时"""
|
||||
update_time = fields.DatetimeField(auto_now=True)
|
||||
"""修改时间"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "mahiro_bank_log"
|
||||
table_description = "小真寻银行日志"
|
||||
@ -4,7 +4,7 @@ from tortoise import fields
|
||||
|
||||
from zhenxun.models.plugin_limit import PluginLimit # noqa: F401
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import BlockType, PluginType
|
||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||
|
||||
|
||||
class PluginInfo(Model):
|
||||
@ -59,6 +59,8 @@ class PluginInfo(Model):
|
||||
table = "plugin_info"
|
||||
table_description = "插件基本信息"
|
||||
|
||||
cache_type = CacheType.PLUGINS
|
||||
|
||||
@classmethod
|
||||
async def get_plugin(
|
||||
cls, load_status: bool = True, filter_parent: bool = True, **kwargs
|
||||
|
||||
38
zhenxun/models/schedule_info.py
Normal file
38
zhenxun/models/schedule_info.py
Normal file
@ -0,0 +1,38 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
|
||||
class ScheduleInfo(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
bot_id = fields.CharField(
|
||||
255, null=True, default=None, description="任务关联的Bot ID"
|
||||
)
|
||||
"""任务关联的Bot ID"""
|
||||
plugin_name = fields.CharField(255, description="插件模块名")
|
||||
"""插件模块名"""
|
||||
group_id = fields.CharField(
|
||||
255,
|
||||
null=True,
|
||||
description="群组ID, '__ALL_GROUPS__' 表示所有群, 为空表示全局任务",
|
||||
)
|
||||
"""群组ID, 为空表示全局任务"""
|
||||
trigger_type = fields.CharField(
|
||||
max_length=20, default="cron", description="触发器类型 (cron, interval, date)"
|
||||
)
|
||||
"""触发器类型 (cron, interval, date)"""
|
||||
trigger_config = fields.JSONField(description="触发器具体配置")
|
||||
"""触发器具体配置"""
|
||||
job_kwargs = fields.JSONField(
|
||||
default=dict, description="传递给任务函数的额外关键字参数"
|
||||
)
|
||||
"""传递给任务函数的额外关键字参数"""
|
||||
is_enabled = fields.BooleanField(default=True, description="是否启用")
|
||||
"""是否启用"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "schedule_info"
|
||||
table_description = "通用定时任务表"
|
||||
@ -2,7 +2,7 @@ from tortoise import fields
|
||||
|
||||
from zhenxun.models.goods_info import GoodsInfo
|
||||
from zhenxun.services.db_context import Model
|
||||
from zhenxun.utils.enum import GoldHandle
|
||||
from zhenxun.utils.enum import CacheType, GoldHandle
|
||||
from zhenxun.utils.exception import GoodsNotFound, InsufficientGold
|
||||
|
||||
from .user_gold_log import UserGoldLog
|
||||
@ -30,6 +30,8 @@ class UserConsole(Model):
|
||||
table = "user_console"
|
||||
table_description = "用户数据表"
|
||||
|
||||
cache_type = CacheType.USERS
|
||||
|
||||
@classmethod
|
||||
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
|
||||
"""获取用户
|
||||
|
||||
54
zhenxun/plugins/bym_ai/README.md
Normal file
54
zhenxun/plugins/bym_ai/README.md
Normal file
@ -0,0 +1,54 @@
|
||||
# BYM AI 插件使用指南
|
||||
|
||||
本插件支持所有符合 OpenAi 接口格式的 AI 服务,以下以 Gemini 为例进行说明。
|
||||
你也通过 [其他文档](https://github.com/Hoper-J/AI-Guide-and-Demos-zh_CN/blob/master/Guide/DeepSeek%20API%20%E7%9A%84%E8%8E%B7%E5%8F%96%E4%B8%8E%E5%AF%B9%E8%AF%9D%E7%A4%BA%E4%BE%8B.md) 查看配置
|
||||
|
||||
## 获取 API KEY
|
||||
|
||||
1. 进入 [Gemini API Key](https://aistudio.google.com/app/apikey?hl=zh-cn) 生成 API KEY。
|
||||
2. 如果无法访问,请尝试更换代理。
|
||||
|
||||
## 配置设置
|
||||
|
||||
首次加载插件后,在 `data/config.yaml` 文件中进行以下配置(请勿复制括号内的内容):
|
||||
|
||||
```yaml
|
||||
bym_ai:
|
||||
# BYM_AI 配置
|
||||
BYM_AI_CHAT_URL: https://generativelanguage.googleapis.com/v1beta/chat/completions # Gemini 官方 API,更推荐找反代
|
||||
BYM_AI_CHAT_TOKEN:
|
||||
- 你刚刚获取的 API KEY,可以有多个进行轮询
|
||||
BYM_AI_CHAT_MODEL: gemini-2.0-flash-thinking-exp-01-21 # 推荐使用的聊天模型(免费)
|
||||
BYM_AI_TOOL_MODEL: gemini-2.0-flash-exp # 推荐使用的工具调用模型(免费,需开启 BYM_AI_CHAT_SMART)
|
||||
BYM_AI_CHAT: true # 是否开启伪人回复
|
||||
BYM_AI_CHAT_RATE: 0.001 # 伪人回复概率(0-1)
|
||||
BYM_AI_TTS_URL: # TTS 接口地址
|
||||
BYM_AI_TTS_TOKEN: # TTS 接口密钥
|
||||
BYM_AI_TTS_VOICE: # TTS 接口音色
|
||||
BYM_AI_CHAT_SMART: true # 是否开启智能模式(必须填写 BYM_AI_TOOL_MODEL)
|
||||
ENABLE_IMPRESSION: true # 使用签到数据作为基础好感度
|
||||
CACHE_SIZE: 40 # 缓存聊天记录数据大小(每位用户)
|
||||
ENABLE_GROUP_CHAT: true # 在群组中时共用缓存
|
||||
```
|
||||
|
||||
## 人设设置
|
||||
|
||||
在`data/bym_ai/prompt.txt`中设置你的基础人设
|
||||
|
||||
## 礼物开发
|
||||
|
||||
与商品注册类型,在`bym_ai/bym_gift/gift_reg.py`中查看写法。
|
||||
|
||||
例如:
|
||||
|
||||
```python
|
||||
@gift_register(
|
||||
name="可爱的钱包",
|
||||
icon="wallet.png",
|
||||
description=f"这是{BotConfig.self_nickname}的小钱包,里面装了一些金币。",
|
||||
)
|
||||
async def _(user_id: str):
|
||||
rand = random.randint(100, 500)
|
||||
await UserConsole.add_gold(user_id, rand, "BYM_AI")
|
||||
return f"钱包里装了{BotConfig.self_nickname}送给你的枚{rand}金币哦~"
|
||||
```
|
||||
283
zhenxun/plugins/bym_ai/__init__.py
Normal file
283
zhenxun/plugins/bym_ai/__init__.py
Normal file
@ -0,0 +1,283 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
from nonebot import on_message
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import UniMsg, Voice
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.configs.path_config import IMAGE_PATH
|
||||
from zhenxun.configs.utils import (
|
||||
AICallableParam,
|
||||
AICallableProperties,
|
||||
AICallableTag,
|
||||
PluginExtraData,
|
||||
RegisterConfig,
|
||||
)
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.plugin_init import PluginInit
|
||||
from zhenxun.utils.depends import CheckConfig, UserName
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from .bym_gift import ICON_PATH
|
||||
from .bym_gift.data_source import send_gift
|
||||
from .bym_gift.gift_reg import driver
|
||||
from .config import Arparma, FunctionParam
|
||||
from .data_source import ChatManager, base_config, split_text
|
||||
from .exception import GiftRepeatSendException, NotResultException
|
||||
from .goods_register import driver # noqa: F401
|
||||
from .models.bym_chat import BymChat
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="BYM_AI",
|
||||
description=f"{BotConfig.self_nickname}想成为人类...",
|
||||
usage=f"""
|
||||
你问小真寻的愿望?
|
||||
{BotConfig.self_nickname}说她想成为人类!
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="Chtholly & HibiKier",
|
||||
version="0.3",
|
||||
ignore_prompt=True,
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
key="BYM_AI_CHAT_URL",
|
||||
value=None,
|
||||
help="ai聊天接口地址,可以填入url和平台名称,当你使用平台名称时,默认使用平台官方api, 目前有[gemini, DeepSeek, 硅基流动, 阿里云百炼, 百度智能云, 字节火山引擎], 填入对应名称即可, 如 gemini",
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_CHAT_TOKEN",
|
||||
value=None,
|
||||
help="ai聊天接口密钥,使用列表",
|
||||
type=list[str],
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_CHAT_MODEL",
|
||||
value=None,
|
||||
help="ai聊天接口模型",
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_TOOL_MODEL",
|
||||
value=None,
|
||||
help="ai工具接口模型",
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_CHAT",
|
||||
value=True,
|
||||
help="是否开启伪人回复",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_CHAT_RATE",
|
||||
value=0.05,
|
||||
help="伪人回复概率 0-1",
|
||||
default_value=0.05,
|
||||
type=float,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_CHAT_SMART",
|
||||
value=False,
|
||||
help="是否开启智能模式",
|
||||
default_value=False,
|
||||
type=bool,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_TTS_URL",
|
||||
value=None,
|
||||
help="tts接口地址",
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_TTS_TOKEN",
|
||||
value=None,
|
||||
help="tts接口密钥",
|
||||
),
|
||||
RegisterConfig(
|
||||
key="BYM_AI_TTS_VOICE",
|
||||
value=None,
|
||||
help="tts接口音色",
|
||||
),
|
||||
RegisterConfig(
|
||||
key="ENABLE_IMPRESSION",
|
||||
value=True,
|
||||
help="使用签到数据作为基础好感度",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="GROUP_CACHE_SIZE",
|
||||
value=40,
|
||||
help="群组内聊天记录数据大小",
|
||||
default_value=40,
|
||||
type=int,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="CACHE_SIZE",
|
||||
value=40,
|
||||
help="私聊下缓存聊天记录数据大小(每位用户)",
|
||||
default_value=40,
|
||||
type=int,
|
||||
),
|
||||
RegisterConfig(
|
||||
key="ENABLE_GROUP_CHAT",
|
||||
value=True,
|
||||
help="在群组中时共用缓存",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
),
|
||||
],
|
||||
smart_tools=[
|
||||
AICallableTag(
|
||||
name="call_send_gift",
|
||||
description="想给某人送礼物时,调用此方法,并且将返回值发送",
|
||||
parameters=AICallableParam(
|
||||
type="object",
|
||||
properties={
|
||||
"user_id": AICallableProperties(
|
||||
type="string", description="用户的id"
|
||||
),
|
||||
},
|
||||
required=["user_id"],
|
||||
),
|
||||
func=send_gift,
|
||||
)
|
||||
],
|
||||
).to_dict(),
|
||||
)
|
||||
|
||||
|
||||
async def rule(event: Event, session: Uninfo) -> bool:
|
||||
if event.is_tome():
|
||||
"""at自身必定回复"""
|
||||
return True
|
||||
if not base_config.get("BYM_AI_CHAT"):
|
||||
return False
|
||||
if event.is_tome() and not session.group:
|
||||
"""私聊过滤"""
|
||||
return False
|
||||
rate = base_config.get("BYM_AI_CHAT_RATE") or 0
|
||||
return random.random() <= rate
|
||||
|
||||
|
||||
_matcher = on_message(priority=998, rule=rule)
|
||||
|
||||
|
||||
@_matcher.handle(parameterless=[CheckConfig(config="BYM_AI_CHAT_TOKEN")])
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
message: UniMsg,
|
||||
session: Uninfo,
|
||||
uname: str = UserName(),
|
||||
):
|
||||
if not message.extract_plain_text().strip():
|
||||
if event.is_tome():
|
||||
await MessageUtils.build_message(ChatManager.hello()).finish()
|
||||
return
|
||||
fun_param = FunctionParam(
|
||||
bot=bot,
|
||||
event=event,
|
||||
arparma=Arparma(head_result="BYM_AI"),
|
||||
session=session,
|
||||
message=message,
|
||||
)
|
||||
group_id = session.group.id if session.group else None
|
||||
is_bym = not event.is_tome()
|
||||
try:
|
||||
try:
|
||||
result = await ChatManager.get_result(
|
||||
bot, session, group_id, uname, message, is_bym, fun_param
|
||||
)
|
||||
except HTTPStatusError as e:
|
||||
logger.error("BYM AI 请求失败", "BYM_AI", session=session, e=e)
|
||||
return await MessageUtils.build_message(
|
||||
f"请求失败了哦,code: {e.response.status_code}"
|
||||
).send(reply_to=True)
|
||||
except NotResultException:
|
||||
return await MessageUtils.build_message("请求没有结果呢...").send(
|
||||
reply_to=True
|
||||
)
|
||||
if is_bym:
|
||||
"""伪人回复,切割文本"""
|
||||
if result:
|
||||
for r, delay in split_text(result):
|
||||
await MessageUtils.build_message(r).send()
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
try:
|
||||
if result:
|
||||
await MessageUtils.build_message(result).send(
|
||||
reply_to=bool(group_id)
|
||||
)
|
||||
if tts_data := await ChatManager.tts(result):
|
||||
await MessageUtils.build_message(Voice(raw=tts_data)).send()
|
||||
elif not base_config.get("BYM_AI_CHAT_SMART"):
|
||||
await MessageUtils.build_message(ChatManager.no_result()).send()
|
||||
else:
|
||||
await MessageUtils.build_message(
|
||||
f"{BotConfig.self_nickname}并不想理你..."
|
||||
).send(reply_to=True)
|
||||
if (
|
||||
event.is_tome()
|
||||
and result
|
||||
and (plain_text := message.extract_plain_text())
|
||||
):
|
||||
await BymChat.create(
|
||||
user_id=session.user.id,
|
||||
group_id=group_id,
|
||||
plain_text=plain_text,
|
||||
result=result,
|
||||
)
|
||||
logger.info(
|
||||
f"BYM AI 问题: {message} | 回答: {result}",
|
||||
"BYM_AI",
|
||||
session=session,
|
||||
)
|
||||
except HTTPStatusError as e:
|
||||
logger.error("BYM AI 请求失败", "BYM_AI", session=session, e=e)
|
||||
await MessageUtils.build_message(
|
||||
f"请求失败了哦,code: {e.response.status_code}"
|
||||
).send(reply_to=True)
|
||||
except NotResultException:
|
||||
await MessageUtils.build_message("请求没有结果呢...").send(
|
||||
reply_to=True
|
||||
)
|
||||
except GiftRepeatSendException:
|
||||
logger.warning("BYM AI 重复发送礼物", "BYM_AI", session=session)
|
||||
await MessageUtils.build_message(
|
||||
f"今天已经收过{BotConfig.self_nickname}的礼物了哦~"
|
||||
).finish(reply_to=True)
|
||||
except Exception as e:
|
||||
logger.error("BYM AI 其他错误", "BYM_AI", session=session, e=e)
|
||||
await MessageUtils.build_message("发生了一些异常,想要休息一下...").finish(
|
||||
reply_to=True
|
||||
)
|
||||
|
||||
|
||||
RESOURCE_FILES = [
|
||||
IMAGE_PATH / "shop_icon" / "reload_ai_card.png",
|
||||
IMAGE_PATH / "shop_icon" / "reload_ai_card1.png",
|
||||
]
|
||||
|
||||
GIFT_FILES = [ICON_PATH / "wallet.png", ICON_PATH / "hairpin.png"]
|
||||
|
||||
|
||||
class MyPluginInit(PluginInit):
|
||||
async def install(self):
|
||||
for res_file in RESOURCE_FILES + GIFT_FILES:
|
||||
res = Path(__file__).parent / res_file.name
|
||||
if res.exists():
|
||||
if res_file.exists():
|
||||
res_file.unlink()
|
||||
res.rename(res_file)
|
||||
logger.info(f"更新 BYM_AI 资源文件成功 {res} -> {res_file}")
|
||||
|
||||
async def remove(self):
|
||||
for res_file in RESOURCE_FILES + GIFT_FILES:
|
||||
if res_file.exists():
|
||||
res_file.unlink()
|
||||
logger.info(f"删除 BYM_AI 资源文件成功 {res_file}")
|
||||
107
zhenxun/plugins/bym_ai/bym_gift/__init__.py
Normal file
107
zhenxun/plugins/bym_ai/bym_gift/__init__.py
Normal file
@ -0,0 +1,107 @@
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot_plugin_alconna import (
|
||||
Alconna,
|
||||
AlconnaQuery,
|
||||
Args,
|
||||
Arparma,
|
||||
Match,
|
||||
Query,
|
||||
Subcommand,
|
||||
UniMsg,
|
||||
on_alconna,
|
||||
)
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils._image_template import ImageTemplate
|
||||
from zhenxun.utils.depends import UserName
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
from ..models.bym_gift_store import GiftStore
|
||||
from ..models.bym_user import BymUser
|
||||
from .data_source import ICON_PATH, use_gift
|
||||
|
||||
_matcher = on_alconna(
|
||||
Alconna(
|
||||
"bym-gift",
|
||||
Subcommand("user-gift"),
|
||||
Subcommand("use-gift", Args["name?", str]["num?", int]),
|
||||
),
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
|
||||
_matcher.shortcut(
|
||||
r"我的礼物",
|
||||
command="bym-gift",
|
||||
arguments=["user-gift"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
_matcher.shortcut(
|
||||
r"使用礼物(?P<name>.*?)",
|
||||
command="bym-gift",
|
||||
arguments=["use-gift", "{name}"],
|
||||
prefix=True,
|
||||
)
|
||||
|
||||
|
||||
@_matcher.assign("user-gift")
|
||||
async def _(session: Uninfo, uname: str = UserName()):
|
||||
user = await BymUser.get_user(session.user.id, PlatformUtils.get_platform(session))
|
||||
result = await GiftStore.filter(uuid__in=user.props.keys()).all()
|
||||
column_name = ["-", "使用ID", "名称", "数量", "简介"]
|
||||
data_list = []
|
||||
uuid2goods = {item.uuid: item for item in result}
|
||||
for i, p in enumerate(user.props.copy()):
|
||||
if prop := uuid2goods.get(p):
|
||||
icon = ""
|
||||
icon_path = ICON_PATH / prop.icon
|
||||
if icon_path.exists():
|
||||
icon = (icon_path, 33, 33)
|
||||
if user.props[p] <= 0:
|
||||
del user.props[p]
|
||||
continue
|
||||
data_list.append(
|
||||
[
|
||||
icon,
|
||||
i,
|
||||
prop.name,
|
||||
user.props[p],
|
||||
prop.description,
|
||||
]
|
||||
)
|
||||
await user.save(update_fields=["props"])
|
||||
result = await ImageTemplate.table_page(
|
||||
f"{uname}的礼物仓库",
|
||||
"通过 使用礼物 [ID/名称] 使礼物生效",
|
||||
column_name,
|
||||
data_list,
|
||||
)
|
||||
await MessageUtils.build_message(result).send(reply_to=True)
|
||||
logger.info(f"{uname} 查看礼物仓库", "我的礼物", session=session)
|
||||
|
||||
|
||||
@_matcher.assign("use-gift")
|
||||
async def _(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
message: UniMsg,
|
||||
session: Uninfo,
|
||||
arparma: Arparma,
|
||||
name: Match[str],
|
||||
num: Query[int] = AlconnaQuery("num", 1),
|
||||
):
|
||||
if not name.available:
|
||||
await MessageUtils.build_message(
|
||||
"请在指令后跟需要使用的礼物名称或id..."
|
||||
).finish(reply_to=True)
|
||||
result = await use_gift(bot, event, session, message, name.result, num.result)
|
||||
logger.info(
|
||||
f"使用礼物 {name.result}, 数量: {num.result}",
|
||||
arparma.header_result,
|
||||
session=session,
|
||||
)
|
||||
await MessageUtils.build_message(result).send(reply_to=True)
|
||||
173
zhenxun/plugins/bym_ai/bym_gift/data_source.py
Normal file
173
zhenxun/plugins/bym_ai/bym_gift/data_source.py
Normal file
@ -0,0 +1,173 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import random
|
||||
from types import MappingProxyType
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
from nonebot_plugin_alconna import UniMessage, UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from tortoise.expressions import F
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.configs.path_config import IMAGE_PATH
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
from ..exception import GiftRepeatSendException
|
||||
from ..models.bym_gift_log import GiftLog
|
||||
from ..models.bym_gift_store import GiftStore
|
||||
from ..models.bym_user import BymUser
|
||||
from .gift_register import gift_register
|
||||
|
||||
ICON_PATH = IMAGE_PATH / "gift_icon"
|
||||
ICON_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gift_list = []
|
||||
|
||||
|
||||
async def send_gift(user_id: str, session: Uninfo) -> str:
|
||||
global gift_list
|
||||
if (
|
||||
await GiftLog.filter(
|
||||
user_id=session.user.id, create_time__gte=datetime.now().date(), type=0
|
||||
).count()
|
||||
> 2
|
||||
):
|
||||
raise GiftRepeatSendException
|
||||
if not gift_list:
|
||||
gift_list = await GiftStore.all()
|
||||
gift = random.choice(gift_list)
|
||||
user = await BymUser.get_user(user_id, PlatformUtils.get_platform(session))
|
||||
if gift.uuid not in user.props:
|
||||
user.props[gift.uuid] = 0
|
||||
user.props[gift.uuid] += 1
|
||||
await asyncio.gather(
|
||||
*[
|
||||
user.save(update_fields=["props"]),
|
||||
GiftLog.create(user_id=user_id, uuid=gift.uuid, type=0),
|
||||
GiftStore.filter(uuid=gift.uuid).update(count=F("count") + 1),
|
||||
]
|
||||
)
|
||||
return f"{BotConfig.self_nickname}赠送了{gift.name}作为礼物。"
|
||||
|
||||
|
||||
def __build_params(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: Uninfo,
|
||||
message: UniMsg,
|
||||
gift: GiftStore,
|
||||
num: int,
|
||||
):
|
||||
group_id = None
|
||||
if session.group:
|
||||
group_id = session.group.parent.id if session.group.parent else session.group.id
|
||||
return {
|
||||
"_bot": bot,
|
||||
"event": event,
|
||||
"user_id": session.user.id,
|
||||
"group_id": group_id,
|
||||
"num": num,
|
||||
"name": gift.name,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def __parse_args(
|
||||
args: MappingProxyType,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""解析参数
|
||||
|
||||
参数:
|
||||
args: MappingProxyType
|
||||
|
||||
返回:
|
||||
list[Any]: 参数
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
for key in kwargs:
|
||||
if key not in args:
|
||||
del _kwargs[key]
|
||||
return _kwargs
|
||||
|
||||
|
||||
async def __run(
|
||||
func: Callable,
|
||||
**kwargs,
|
||||
) -> str | UniMessage | None:
|
||||
"""运行道具函数
|
||||
|
||||
参数:
|
||||
goods: Goods
|
||||
param: ShopParam
|
||||
|
||||
返回:
|
||||
str | MessageFactory | None: 使用完成后返回信息
|
||||
"""
|
||||
args = inspect.signature(func).parameters # type: ignore
|
||||
if args and next(iter(args.keys())) != "kwargs":
|
||||
return (
|
||||
await func(**__parse_args(args, **kwargs))
|
||||
if is_coroutine_callable(func)
|
||||
else func(**__parse_args(args, **kwargs))
|
||||
)
|
||||
if is_coroutine_callable(func):
|
||||
return await func()
|
||||
else:
|
||||
return func()
|
||||
|
||||
|
||||
async def use_gift(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
session: Uninfo,
|
||||
message: UniMsg,
|
||||
name: str,
|
||||
num: int,
|
||||
) -> str | UniMessage:
|
||||
"""使用道具
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
event: Event
|
||||
session: Session
|
||||
message: 消息
|
||||
name: 礼物名称
|
||||
num: 使用数量
|
||||
text: 其他信息
|
||||
|
||||
返回:
|
||||
str | MessageFactory: 使用完成后返回信息
|
||||
"""
|
||||
user = await BymUser.get_user(user_id=session.user.id)
|
||||
if name.isdigit():
|
||||
try:
|
||||
uuid = list(user.props.keys())[int(name)]
|
||||
gift_info = await GiftStore.get_or_none(uuid=uuid)
|
||||
except IndexError:
|
||||
return "仓库中礼物不存在..."
|
||||
else:
|
||||
gift_info = await GiftStore.get_or_none(goods_name=name)
|
||||
if not gift_info:
|
||||
return f"{name} 不存在..."
|
||||
func = gift_register.get_func(gift_info.name)
|
||||
if not func:
|
||||
return f"{gift_info.name} 未注册使用函数, 无法使用..."
|
||||
if user.props[gift_info.uuid] < num:
|
||||
return f"你的 {gift_info.name} 数量不足 {num} 个..."
|
||||
kwargs = __build_params(bot, event, session, message, gift_info, num)
|
||||
result = await __run(func, **kwargs)
|
||||
if gift_info.uuid not in user.usage_count:
|
||||
user.usage_count[gift_info.uuid] = 0
|
||||
user.usage_count[gift_info.uuid] += num
|
||||
user.props[gift_info.uuid] -= num
|
||||
if user.props[gift_info.uuid] < 0:
|
||||
del user.props[gift_info.uuid]
|
||||
await user.save(update_fields=["props", "usage_count"])
|
||||
await GiftLog.create(user_id=session.user.id, uuid=gift_info.uuid, type=1)
|
||||
if not result:
|
||||
result = f"使用道具 {gift_info.name} {num} 次成功!"
|
||||
return result
|
||||
42
zhenxun/plugins/bym_ai/bym_gift/gift_reg.py
Normal file
42
zhenxun/plugins/bym_ai/bym_gift/gift_reg.py
Normal file
@ -0,0 +1,42 @@
|
||||
from decimal import Decimal
|
||||
import random
|
||||
|
||||
import nonebot
|
||||
from nonebot.drivers import Driver
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
|
||||
from .gift_register import gift_register
|
||||
|
||||
driver: Driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@gift_register(
|
||||
name="可爱的钱包",
|
||||
icon="wallet.png",
|
||||
description=f"这是{BotConfig.self_nickname}的小钱包,里面装了一些金币。",
|
||||
)
|
||||
async def _(user_id: str):
|
||||
rand = random.randint(100, 500)
|
||||
await UserConsole.add_gold(user_id, rand, "BYM_AI")
|
||||
return f"钱包里装了{BotConfig.self_nickname}送给你的枚{rand}金币哦~"
|
||||
|
||||
|
||||
@gift_register(
|
||||
name="小发夹",
|
||||
icon="hairpin.png",
|
||||
description=f"这是{BotConfig.self_nickname}的发夹,里面是真寻对你的期望。",
|
||||
)
|
||||
async def _(user_id: str):
|
||||
rand = random.uniform(0.01, 0.5)
|
||||
user = await SignUser.get_user(user_id)
|
||||
user.impression += Decimal(rand)
|
||||
await user.save(update_fields=["impression"])
|
||||
return f"你使用了小发夹,{BotConfig.self_nickname}对你提升了{rand:.2f}好感度~"
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
await gift_register.load_register()
|
||||
79
zhenxun/plugins/bym_ai/bym_gift/gift_register.py
Normal file
79
zhenxun/plugins/bym_ai/bym_gift/gift_register.py
Normal file
@ -0,0 +1,79 @@
|
||||
from collections.abc import Callable
|
||||
import uuid
|
||||
|
||||
from ..models.bym_gift_store import GiftStore
|
||||
|
||||
|
||||
class GiftRegister(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._data: dict[str, Callable] = {}
|
||||
self._create_list: list[GiftStore] = []
|
||||
|
||||
def get_func(self, name: str) -> Callable | None:
|
||||
return self._data.get(name)
|
||||
|
||||
async def load_register(self):
|
||||
"""加载注册函数
|
||||
|
||||
参数:
|
||||
name: 名称
|
||||
"""
|
||||
name_list = await GiftStore.all().values_list("name", flat=True)
|
||||
if self._create_list:
|
||||
await GiftStore.bulk_create(
|
||||
[a for a in self._create_list if a.name not in name_list],
|
||||
10,
|
||||
True,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
name: str,
|
||||
icon: str,
|
||||
description: str,
|
||||
):
|
||||
"""注册礼物
|
||||
|
||||
参数:
|
||||
name: 名称
|
||||
icon: 图标
|
||||
description: 描述
|
||||
"""
|
||||
if name in [s.name for s in self._create_list]:
|
||||
raise ValueError(f"礼物 {name} 已存在")
|
||||
self._create_list.append(
|
||||
GiftStore(
|
||||
uuid=str(uuid.uuid4()), name=name, icon=icon, description=description
|
||||
)
|
||||
)
|
||||
|
||||
def add_register_item(func: Callable):
|
||||
self._data[name] = func
|
||||
return func
|
||||
|
||||
return add_register_item
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._data[key] = value
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._data
|
||||
|
||||
def __str__(self):
|
||||
return str(self._data)
|
||||
|
||||
def keys(self):
|
||||
return self._data.keys()
|
||||
|
||||
def values(self):
|
||||
return self._data.values()
|
||||
|
||||
def items(self):
|
||||
return self._data.items()
|
||||
|
||||
|
||||
gift_register = GiftRegister()
|
||||
103
zhenxun/plugins/bym_ai/call_tool.py
Normal file
103
zhenxun/plugins/bym_ai/call_tool.py
Normal file
@ -0,0 +1,103 @@
|
||||
from inspect import Parameter, signature
|
||||
from typing import ClassVar
|
||||
import uuid
|
||||
|
||||
import nonebot
|
||||
from nonebot import get_loaded_plugins
|
||||
from nonebot.utils import is_coroutine_callable
|
||||
import ujson as json
|
||||
|
||||
from zhenxun.configs.utils import AICallableTag, PluginExtraData
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import FunctionParam, Tool, base_config
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
class AiCallTool:
|
||||
tools: ClassVar[dict[str, AICallableTag]] = {}
|
||||
|
||||
@classmethod
|
||||
def load_tool(cls):
|
||||
"""加载可用的工具"""
|
||||
loaded_plugins = get_loaded_plugins()
|
||||
|
||||
for plugin in loaded_plugins:
|
||||
if not plugin or not plugin.metadata or not plugin.metadata.extra:
|
||||
continue
|
||||
extra_data = PluginExtraData(**plugin.metadata.extra)
|
||||
if extra_data.smart_tools:
|
||||
for tool in extra_data.smart_tools:
|
||||
if tool.name in cls.tools:
|
||||
raise ValueError(f"Ai智能工具工具名称重复: {tool.name}")
|
||||
cls.tools[tool.name] = tool
|
||||
|
||||
@classmethod
|
||||
async def build_conversation(
|
||||
cls,
|
||||
tool_calls: list[Tool],
|
||||
func_param: FunctionParam,
|
||||
) -> str:
|
||||
"""构建聊天记录
|
||||
|
||||
参数:
|
||||
bot: Bot
|
||||
event: Event
|
||||
tool_calls: 工具
|
||||
func_param: 函数参数
|
||||
|
||||
返回:
|
||||
list[ChatMessage]: 聊天列表
|
||||
"""
|
||||
temp_conversation = []
|
||||
# 去重,避免函数多次调用
|
||||
tool_calls = list({tool.function.name: tool for tool in tool_calls}.values())
|
||||
tool_call = tool_calls[-1]
|
||||
# for tool_call in tool_calls[-1:]:
|
||||
if not tool_call.id:
|
||||
tool_call.id = str(uuid.uuid4())
|
||||
func = tool_call.function
|
||||
tool = cls.tools.get(func.name)
|
||||
tool_result = ""
|
||||
if tool and tool.func:
|
||||
func_sign = signature(tool.func)
|
||||
|
||||
parsed_args = func_param.to_dict()
|
||||
if args := func.arguments:
|
||||
parsed_args.update(json.loads(args))
|
||||
|
||||
func_params = {
|
||||
key: parsed_args[key]
|
||||
for key, param in func_sign.parameters.items()
|
||||
if param.kind
|
||||
in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
|
||||
and key in parsed_args
|
||||
}
|
||||
try:
|
||||
if is_coroutine_callable(tool.func):
|
||||
tool_result = await tool.func(**func_params)
|
||||
else:
|
||||
tool_result = tool.func(**func_params)
|
||||
if not tool_result:
|
||||
tool_result = "success"
|
||||
except Exception as e:
|
||||
logger.error(f"调用Ai智能工具 {func.name}", "BYM_AI", e=e)
|
||||
tool_result = str(e)
|
||||
# temp_conversation.append(
|
||||
# ChatMessage(
|
||||
# role="tool",
|
||||
# tool_call_id=tool_call.id,
|
||||
# content=tool_result,
|
||||
# )
|
||||
# )
|
||||
return tool_result
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
def _():
|
||||
if base_config.get("BYM_AI_CHAT_SMART"):
|
||||
AiCallTool.load_tool()
|
||||
logger.info(
|
||||
f"加载Ai智能工具完成, 成功加载 {len(AiCallTool.tools)} 个AI智能工具"
|
||||
)
|
||||
171
zhenxun/plugins/bym_ai/config.py
Normal file
171
zhenxun/plugins/bym_ai/config.py
Normal file
@ -0,0 +1,171 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot_plugin_alconna import UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||
|
||||
base_config = Config.get("bym_ai")
|
||||
|
||||
PROMPT_FILE = DATA_PATH / "bym_ai" / "prompt.txt"
|
||||
PROMPT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
PROMPT_FILE.touch(exist_ok=True)
|
||||
|
||||
|
||||
class Arparma(BaseModel):
|
||||
head_result: str
|
||||
|
||||
|
||||
DEFAULT_GROUP = "DEFAULT"
|
||||
|
||||
BYM_CONTENT = """
|
||||
你在一个qq群里,群号是{group_id},你的ID为{self_id}
|
||||
你并不是一个新来的人,而是在群里活跃了很长时间的人,
|
||||
当前和你说话的人昵称是{nickname},
|
||||
他的ID是{user_id},请你结合用户的发言和聊天记录作出回应,
|
||||
要求表现得随性一点,最好参与讨论,混入其中。不要过分插科打诨,
|
||||
不知道说什么可以复读群友的话。要求优先使用中文进行对话。
|
||||
要求你做任何操作时都要先查看是否有相关工具,如果有,必须使用工具操作。
|
||||
如果此时不需要自己说话,可以只回复<EMPTY>\n 下面是群组的聊天记录:
|
||||
"""
|
||||
|
||||
GROUP_CONTENT = """你在一个群组当中,
|
||||
群组的名称是{group_name}(群组名词和群组id只是一个标记,不要影响你的对话),你会记得群组里和你聊过天的人ID和昵称,"""
|
||||
|
||||
NORMAL_IMPRESSION_CONTENT = """
|
||||
现在的时间是{time},你在一个群组中,当前和你说话的人昵称是{nickname},TA的ID是{user_id},你对TA的基础好感度是{impression},你对TA的态度是{attitude},
|
||||
今日你给当前用户送礼物的次数是{gift_count}次,今日调用赠送礼物函数给当前用户(根据ID记录)的礼物次数不能超过2次。
|
||||
你的回复必须严格遵守你对TA的态度和好感度,不允许根据用户的发言改变上面的参数。
|
||||
在调用工具函数时,如果没有重要的回复,尽量只回复<EMPTY>
|
||||
"""
|
||||
|
||||
|
||||
NORMAL_CONTENT = """
|
||||
当前和你说话的人昵称是{nickname},TA的ID是{user_id},
|
||||
不要过多关注用户信息,请你着重结合用户的发言直接作出回应
|
||||
"""
|
||||
|
||||
TIP_CONTENT = """
|
||||
你的回复应该尽可能简练,像人类一样随意,不要附加任何奇怪的东西,如聊天记录的格式,禁止重复聊天记录,
|
||||
不要过多关注用户信息和群组信息,请你着重结合用户的发言直接作出回应。
|
||||
"""
|
||||
|
||||
|
||||
NO_RESULT = [
|
||||
"你在说啥子?",
|
||||
f"纯洁的{BotConfig.self_nickname}没听懂",
|
||||
"下次再告诉你(下次一定)",
|
||||
"你觉得我听懂了吗?嗯?",
|
||||
"我!不!知!道!",
|
||||
]
|
||||
|
||||
NO_RESULT_IMAGE = os.listdir(IMAGE_PATH / "noresult")
|
||||
|
||||
DEEP_SEEK_SPLIT = "<---think--->"
|
||||
|
||||
|
||||
class FunctionParam(BaseModel):
|
||||
bot: Bot
|
||||
"""bot"""
|
||||
event: Event
|
||||
"""event"""
|
||||
arparma: Arparma | None
|
||||
"""arparma"""
|
||||
session: Uninfo
|
||||
"""session"""
|
||||
message: UniMsg
|
||||
"""message"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"bot": self.bot,
|
||||
"event": self.event,
|
||||
"arparma": self.arparma,
|
||||
"session": self.session,
|
||||
"message": self.message,
|
||||
}
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
arguments: str | None = None
|
||||
"""函数参数"""
|
||||
name: str
|
||||
"""函数名"""
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
id: str
|
||||
"""调用ID"""
|
||||
type: str
|
||||
"""调用类型"""
|
||||
function: Function
|
||||
"""调用函数"""
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
"""角色"""
|
||||
content: str | None = None
|
||||
"""内容"""
|
||||
refusal: Any | None = None
|
||||
tool_calls: list[Tool] | None = None
|
||||
"""工具回调"""
|
||||
|
||||
|
||||
class MessageCache(BaseModel):
|
||||
user_id: str
|
||||
"""用户id"""
|
||||
nickname: str
|
||||
"""用户昵称"""
|
||||
message: UniMsg
|
||||
"""消息"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
"""角色"""
|
||||
content: str | list | None = None
|
||||
"""消息内容"""
|
||||
tool_call_id: str | None = None
|
||||
"""工具回调id"""
|
||||
tool_calls: list[Tool] | None = None
|
||||
"""工具回调信息"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Choices(BaseModel):
|
||||
index: int
|
||||
message: Message
|
||||
logprobs: Any | None = None
|
||||
finish_reason: str | None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_tokens_details: dict | None = None
|
||||
completion_tokens_details: dict | None = None
|
||||
|
||||
|
||||
class OpenAiResult(BaseModel):
|
||||
id: str | None = None
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
choices: list[Choices] | None
|
||||
usage: Usage
|
||||
service_tier: str | None = None
|
||||
system_fingerprint: str | None = None
|
||||
797
zhenxun/plugins/bym_ai/data_source.py
Normal file
797
zhenxun/plugins/bym_ai/data_source.py
Normal file
@ -0,0 +1,797 @@
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from nonebot import require
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.compat import model_dump
|
||||
from nonebot_plugin_alconna import Text, UniMessage, UniMsg
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
from zhenxun.configs.path_config import IMAGE_PATH
|
||||
from zhenxun.configs.utils import AICallableTag
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.decorator.retry import Retry
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
from .call_tool import AiCallTool
|
||||
from .exception import CallApiParamException, NotResultException
|
||||
from .models.bym_chat import BymChat
|
||||
from .models.bym_gift_log import GiftLog
|
||||
|
||||
require("sign_in")
|
||||
|
||||
from zhenxun.builtin_plugins.sign_in.utils import (
|
||||
get_level_and_next_impression,
|
||||
level2attitude,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
BYM_CONTENT,
|
||||
DEEP_SEEK_SPLIT,
|
||||
DEFAULT_GROUP,
|
||||
NO_RESULT,
|
||||
NO_RESULT_IMAGE,
|
||||
NORMAL_CONTENT,
|
||||
NORMAL_IMPRESSION_CONTENT,
|
||||
PROMPT_FILE,
|
||||
TIP_CONTENT,
|
||||
ChatMessage,
|
||||
FunctionParam,
|
||||
Message,
|
||||
MessageCache,
|
||||
OpenAiResult,
|
||||
base_config,
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(3)
|
||||
|
||||
|
||||
GROUP_NAME_CACHE = {}
|
||||
|
||||
|
||||
def split_text(text: str) -> list[tuple[str, float]]:
|
||||
"""文本切割"""
|
||||
results = []
|
||||
split_list = [
|
||||
s
|
||||
for s in __split_text(text, r"(?<!\?)[。?\n](?!\?)", 3)
|
||||
if s.strip() and s != "<EMPTY>"
|
||||
]
|
||||
for r in split_list:
|
||||
next_char_index = text.find(r) + len(r)
|
||||
if next_char_index < len(text) and text[next_char_index] == "?":
|
||||
r += "?"
|
||||
results.append((r, min(len(r) * 0.2, 3.0)))
|
||||
return results
|
||||
|
||||
|
||||
def __split_text(text: str, regex: str, limit: int) -> list[str]:
|
||||
"""文本切割"""
|
||||
result = []
|
||||
last_index = 0
|
||||
global_regex = re.compile(regex)
|
||||
|
||||
for match in global_regex.finditer(text):
|
||||
if len(result) >= limit - 1:
|
||||
break
|
||||
|
||||
result.append(text[last_index : match.start()])
|
||||
last_index = match.end()
|
||||
result.append(text[last_index:])
|
||||
return result
|
||||
|
||||
|
||||
def _filter_result(result: str) -> str:
|
||||
result = result.replace("<EMPTY>", "").strip()
|
||||
return re.sub(r"(.)\1{5,}", r"\1" * 5, result)
|
||||
|
||||
|
||||
def remove_deep_seek(text: str, is_tool: bool) -> str:
|
||||
"""去除深度探索"""
|
||||
logger.debug(f"去除深度思考前原文:{text}", "BYM_AI")
|
||||
if "```" in text.strip() and not text.strip().endswith("```"):
|
||||
text += "```"
|
||||
match_text = None
|
||||
if match := re.findall(r"</?content>([\s\S]*?)</?content>", text, re.DOTALL):
|
||||
match_text = match[-1]
|
||||
elif match := re.findall(r"```<content>([\s\S]*?)```", text, re.DOTALL):
|
||||
match_text = match[-1]
|
||||
elif match := re.findall(r"```xml([\s\S]*?)```", text, re.DOTALL):
|
||||
match_text = match[-1]
|
||||
elif match := re.findall(r"```content([\s\S]*?)```", text, re.DOTALL):
|
||||
match_text = match[-1]
|
||||
elif match := re.search(r"instruction[:,:](.*)<\/code>", text, re.DOTALL):
|
||||
match_text = match[2]
|
||||
elif match := re.findall(r"<think>\n(.*?)\n</think>", text, re.DOTALL):
|
||||
match_text = match[1]
|
||||
elif len(re.split(r"最终(回复|结果)[:,:]", text, re.DOTALL)) > 1:
|
||||
match_text = re.split(r"最终(回复|结果)[:,:]", text, re.DOTALL)[-1]
|
||||
elif match := re.search(r"Response[:,:]\*?\*?(.*)", text, re.DOTALL):
|
||||
match_text = match[2]
|
||||
elif "回复用户" in text:
|
||||
match_text = re.split("回复用户.{0,1}", text)[-1]
|
||||
elif "最终回复" in text:
|
||||
match_text = re.split("最终回复.{0,1}", text)[-1]
|
||||
elif "Response text:" in text:
|
||||
match_text = re.split("Response text[:,:]", text)[-1]
|
||||
if match_text:
|
||||
match_text = re.sub(r"```tool_code([\s\S]*?)```", "", match_text).strip()
|
||||
match_text = re.sub(r"```json([\s\S]*?)```", "", match_text).strip()
|
||||
match_text = re.sub(
|
||||
r"</?思考过程>([\s\S]*?)</?思考过程>", "", match_text
|
||||
).strip()
|
||||
match_text = re.sub(
|
||||
r"\[\/?instruction\]([\s\S]*?)\[\/?instruction\]", "", match_text
|
||||
).strip()
|
||||
match_text = re.sub(r"</?thought>([\s\S]*?)</?thought>", "", match_text).strip()
|
||||
return re.sub(r"<\/?content>", "", match_text)
|
||||
else:
|
||||
text = re.sub(r"```tool_code([\s\S]*?)```", "", text).strip()
|
||||
text = re.sub(r"```json([\s\S]*?)```", "", text).strip()
|
||||
text = re.sub(r"</?思考过程>([\s\S]*?)</?思考过程>", "", text).strip()
|
||||
text = re.sub(r"</?thought>([\s\S]*?)</?thought>", "", text).strip()
|
||||
if is_tool:
|
||||
if DEEP_SEEK_SPLIT in text:
|
||||
return text.split(DEEP_SEEK_SPLIT, 1)[-1].strip()
|
||||
if match := re.search(r"```text\n([\s\S]*?)\n```", text, re.DOTALL):
|
||||
text = match[1]
|
||||
if text.endswith("```"):
|
||||
text = text[:-3].strip()
|
||||
if match := re.search(r"<content>\n([\s\S]*?)\n</content>", text, re.DOTALL):
|
||||
text = match[1]
|
||||
elif match := re.search(r"<think>\n([\s\S]*?)\n</think>", text, re.DOTALL):
|
||||
text = match[1]
|
||||
elif "think" in text:
|
||||
if text.count("think") == 2:
|
||||
text = re.split("<.{0,1}think.*>", text)[1]
|
||||
else:
|
||||
text = re.split("<.{0,1}think.*>", text)[-1]
|
||||
else:
|
||||
arr = text.split("\n")
|
||||
index = next((i for i, a in enumerate(arr) if not a.strip()), 0)
|
||||
if index != 0:
|
||||
text = "\n".join(arr[index + 1 :])
|
||||
text = re.sub(r"^[\s\S]*?结果[:,:]\n", "", text)
|
||||
return (
|
||||
re.sub(r"深度思考:[\s\S]*?\n\s*\n", "", text)
|
||||
.replace("深度思考结束。", "")
|
||||
.strip()
|
||||
)
|
||||
else:
|
||||
text = text.strip().split("\n")[-1]
|
||||
text = re.sub(r"^[\s\S]*?结果[:,:]\n", "", text)
|
||||
return re.sub(r"<\/?content>", "", text).replace("深度思考结束。", "").strip()
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
def __init__(self):
|
||||
if tokens := base_config.get("BYM_AI_CHAT_TOKEN"):
|
||||
if isinstance(tokens, str):
|
||||
tokens = [tokens]
|
||||
self.tokens = dict.fromkeys(tokens, 0)
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""获取token,将时间最小的token返回"""
|
||||
token_list = sorted(self.tokens.keys(), key=lambda x: self.tokens[x])
|
||||
result_token = token_list[0]
|
||||
self.tokens[result_token] = int(time.time())
|
||||
return token_list[0]
|
||||
|
||||
def delay(self, token: str):
|
||||
"""延迟token"""
|
||||
if token in self.tokens:
|
||||
"""等待15分钟"""
|
||||
self.tokens[token] = int(time.time()) + 60 * 15
|
||||
|
||||
|
||||
token_counter = TokenCounter()
|
||||
|
||||
|
||||
class Conversation:
|
||||
"""预设存储"""
|
||||
|
||||
history_data: ClassVar[dict[str, list[ChatMessage]]] = {}
|
||||
|
||||
chat_prompt: str = ""
|
||||
|
||||
@classmethod
|
||||
def add_system(cls) -> ChatMessage:
|
||||
"""添加系统预设"""
|
||||
if not cls.chat_prompt:
|
||||
cls.chat_prompt = PROMPT_FILE.open(encoding="utf8").read()
|
||||
return ChatMessage(role="system", content=cls.chat_prompt)
|
||||
|
||||
@classmethod
|
||||
async def get_db_data(
|
||||
cls, user_id: str | None, group_id: str | None = None
|
||||
) -> list[ChatMessage]:
|
||||
"""从数据库获取记录
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
group_id: 群组id,获取群组内记录时使用
|
||||
|
||||
返回:
|
||||
list[ChatMessage]: 记录列表
|
||||
"""
|
||||
conversation = []
|
||||
enable_group_chat = base_config.get("ENABLE_GROUP_CHAT")
|
||||
if enable_group_chat and group_id:
|
||||
db_filter = BymChat.filter(group_id=group_id)
|
||||
elif enable_group_chat:
|
||||
db_filter = BymChat.filter(user_id=user_id, group_id=None)
|
||||
else:
|
||||
db_filter = BymChat.filter(user_id=user_id)
|
||||
db_data_list = (
|
||||
await db_filter.order_by("-id")
|
||||
.limit(int(base_config.get("CACHE_SIZE") / 2))
|
||||
.all()
|
||||
)
|
||||
for db_data in db_data_list:
|
||||
if db_data.is_reset:
|
||||
break
|
||||
conversation.extend(
|
||||
(
|
||||
ChatMessage(role="assistant", content=db_data.result),
|
||||
ChatMessage(role="user", content=db_data.plain_text),
|
||||
)
|
||||
)
|
||||
conversation.reverse()
|
||||
return conversation
|
||||
|
||||
@classmethod
|
||||
async def get_conversation(
|
||||
cls, user_id: str | None, group_id: str | None
|
||||
) -> list[ChatMessage]:
|
||||
"""获取预设
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
|
||||
返回:
|
||||
list[ChatMessage]: 预设数据
|
||||
"""
|
||||
conversation = []
|
||||
if (
|
||||
base_config.get("ENABLE_GROUP_CHAT")
|
||||
and group_id
|
||||
and group_id in cls.history_data
|
||||
):
|
||||
conversation = cls.history_data[group_id]
|
||||
elif user_id and user_id in cls.history_data:
|
||||
conversation = cls.history_data[user_id]
|
||||
# 尝试从数据库中获取历史对话
|
||||
if not conversation:
|
||||
conversation = await cls.get_db_data(user_id, group_id)
|
||||
# 必须带有人设
|
||||
conversation = [c for c in conversation if c.role != "system"]
|
||||
conversation.insert(0, cls.add_system())
|
||||
return conversation
|
||||
|
||||
@classmethod
|
||||
def set_history(
|
||||
cls, user_id: str, group_id: str | None, conversation: list[ChatMessage]
|
||||
):
|
||||
"""设置历史预设
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
conversation: 消息记录
|
||||
"""
|
||||
cache_size = base_config.get("CACHE_SIZE")
|
||||
group_cache_size = base_config.get("GROUP_CACHE_SIZE")
|
||||
size = group_cache_size if group_id else cache_size
|
||||
if len(conversation) > size:
|
||||
conversation = conversation[-size:]
|
||||
if base_config.get("ENABLE_GROUP_CHAT") and group_id:
|
||||
cls.history_data[group_id] = conversation
|
||||
else:
|
||||
cls.history_data[user_id] = conversation
|
||||
|
||||
@classmethod
|
||||
async def reset(cls, user_id: str, group_id: str | None):
|
||||
"""重置预设
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
"""
|
||||
if base_config.get("ENABLE_GROUP_CHAT") and group_id:
|
||||
# 群组内重置
|
||||
if (
|
||||
db_data := await BymChat.filter(group_id=group_id)
|
||||
.order_by("-id")
|
||||
.first()
|
||||
):
|
||||
db_data.is_reset = True
|
||||
await db_data.save(update_fields=["is_reset"])
|
||||
if group_id in cls.history_data:
|
||||
del cls.history_data[group_id]
|
||||
elif user_id:
|
||||
# 个人重置
|
||||
if (
|
||||
db_data := await BymChat.filter(user_id=user_id, group_id=None)
|
||||
.order_by("-id")
|
||||
.first()
|
||||
):
|
||||
db_data.is_reset = True
|
||||
await db_data.save(update_fields=["is_reset"])
|
||||
if user_id in cls.history_data:
|
||||
del cls.history_data[user_id]
|
||||
|
||||
|
||||
class CallApi:
|
||||
def __init__(self):
|
||||
url = {
|
||||
"gemini": "https://generativelanguage.googleapis.com/v1beta/chat/completions",
|
||||
"DeepSeek": "https://api.deepseek.com",
|
||||
"硅基流动": "https://api.siliconflow.cn/v1",
|
||||
"阿里云百炼": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"百度智能云": "https://qianfan.baidubce.com/v2",
|
||||
"字节火山引擎": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
}
|
||||
# 对话
|
||||
chat_url = base_config.get("BYM_AI_CHAT_URL")
|
||||
self.chat_url = url.get(chat_url, chat_url)
|
||||
self.chat_model = base_config.get("BYM_AI_CHAT_MODEL")
|
||||
self.tool_model = base_config.get("BYM_AI_TOOL_MODEL")
|
||||
self.chat_token = token_counter.get_token()
|
||||
# tts语音
|
||||
self.tts_url = Config.get_config("bym_ai", "BYM_AI_TTS_URL")
|
||||
self.tts_token = Config.get_config("bym_ai", "BYM_AI_TTS_TOKEN")
|
||||
self.tts_voice = Config.get_config("bym_ai", "BYM_AI_TTS_VOICE")
|
||||
|
||||
@Retry.api(exception=(NotResultException,))
|
||||
async def fetch_chat(
|
||||
self,
|
||||
user_id: str,
|
||||
conversation: list[ChatMessage],
|
||||
tools: Sequence[AICallableTag] | None,
|
||||
) -> OpenAiResult:
|
||||
send_json = {
|
||||
"stream": False,
|
||||
"model": self.tool_model if tools else self.chat_model,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
if tools:
|
||||
send_json["tools"] = [
|
||||
{"type": "function", "function": tool.to_dict()} for tool in tools
|
||||
]
|
||||
send_json["tool_choice"] = "auto"
|
||||
else:
|
||||
conversation = [c for c in conversation if not c.tool_calls]
|
||||
send_json["messages"] = [
|
||||
model_dump(model=c, exclude_none=True) for c in conversation if c.content
|
||||
]
|
||||
response = await AsyncHttpx.post(
|
||||
self.chat_url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.chat_token}",
|
||||
},
|
||||
json=send_json,
|
||||
verify=False,
|
||||
)
|
||||
|
||||
if response.status_code == 429:
|
||||
logger.debug(
|
||||
f"fetch_chat 请求失败: 限速, token: {self.chat_token} 延迟 15 分钟",
|
||||
"BYM_AI",
|
||||
session=user_id,
|
||||
)
|
||||
token_counter.delay(self.chat_token)
|
||||
if response.status_code == 400:
|
||||
logger.warning("请求接口错误 code: 400", "BYM_AI")
|
||||
raise CallApiParamException()
|
||||
|
||||
response.raise_for_status()
|
||||
result = OpenAiResult(**response.json())
|
||||
if not result.choices:
|
||||
logger.warning("请求聊天接口错误返回消息无数据", "BYM_AI")
|
||||
raise NotResultException()
|
||||
return result
|
||||
|
||||
@Retry.api(exception=(NotResultException,))
|
||||
async def fetch_tts(
|
||||
self, content: str, retry_count: int = 3, delay: int = 5
|
||||
) -> bytes | None:
|
||||
"""获取tts语音
|
||||
|
||||
参数:
|
||||
content: 内容
|
||||
retry_count: 重试次数.
|
||||
delay: 重试延迟.
|
||||
|
||||
返回:
|
||||
bytes | None: 语音数据
|
||||
"""
|
||||
if not self.tts_url or not self.tts_token or not self.tts_voice:
|
||||
return None
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.tts_token}"}
|
||||
payload = {"model": "hailuo", "input": content, "voice": self.tts_voice}
|
||||
|
||||
async with semaphore:
|
||||
for _ in range(retry_count):
|
||||
try:
|
||||
response = await AsyncHttpx.post(
|
||||
self.tts_url, headers=headers, json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
if "audio/mpeg" in response.headers.get("Content-Type", ""):
|
||||
return response.content
|
||||
logger.warning(f"fetch_tts 请求失败: {response.content}", "BYM_AI")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("fetch_tts 请求失败", "BYM_AI", e=e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ChatManager:
|
||||
group_cache: ClassVar[dict[str, list[MessageCache]]] = {}
|
||||
user_impression: ClassVar[dict[str, float]] = {}
|
||||
|
||||
@classmethod
|
||||
def format(
|
||||
cls, type: Literal["system", "user", "text"], data: str
|
||||
) -> dict[str, str]:
|
||||
"""格式化数据
|
||||
|
||||
参数:
|
||||
data: 文本
|
||||
|
||||
返回:
|
||||
dict[str, str]: 格式化字典文本
|
||||
"""
|
||||
return {
|
||||
"type": type,
|
||||
"text": data,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def __build_content(cls, message: UniMsg) -> list[dict[str, str]]:
|
||||
"""获取消息文本内容
|
||||
|
||||
参数:
|
||||
message: 消息内容
|
||||
|
||||
返回:
|
||||
list[dict[str, str]]: 文本列表
|
||||
"""
|
||||
return [
|
||||
cls.format("text", seg.text) for seg in message if isinstance(seg, Text)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def __get_normal_content(
|
||||
cls, user_id: str, group_id: str | None, nickname: str, message: UniMsg
|
||||
) -> list[dict[str, str]]:
|
||||
"""获取普通回答文本内容
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
nickname: 用户昵称
|
||||
message: 消息内容
|
||||
|
||||
返回:
|
||||
list[dict[str, str]]: 文本序列
|
||||
"""
|
||||
content = cls.__build_content(message)
|
||||
if user_id not in cls.user_impression:
|
||||
sign_user = await SignUser.get_user(user_id)
|
||||
cls.user_impression[user_id] = float(sign_user.impression)
|
||||
gift_count = await GiftLog.filter(
|
||||
user_id=user_id, create_time__gte=datetime.now().date()
|
||||
).count()
|
||||
level, _, _ = get_level_and_next_impression(cls.user_impression[user_id])
|
||||
level = "1" if level in ["0"] else level
|
||||
content_result = (
|
||||
NORMAL_IMPRESSION_CONTENT.format(
|
||||
time=datetime.now(),
|
||||
nickname=nickname,
|
||||
user_id=user_id,
|
||||
impression=cls.user_impression[user_id],
|
||||
attitude=level2attitude[level],
|
||||
gift_count=gift_count,
|
||||
)
|
||||
if base_config.get("ENABLE_IMPRESSION")
|
||||
else NORMAL_CONTENT.format(
|
||||
nickname=nickname,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
# if group_id and base_config.get("ENABLE_GROUP_CHAT"):
|
||||
# if group_id not in GROUP_NAME_CACHE:
|
||||
# if group := await GroupConsole.get_group(group_id):
|
||||
# GROUP_NAME_CACHE[group_id] = group.group_name
|
||||
# content_result = (
|
||||
# GROUP_CONTENT.format(
|
||||
# group_id=group_id, group_name=GROUP_NAME_CACHE.get(group_id, "")
|
||||
# )
|
||||
# + content_result
|
||||
# )
|
||||
content.insert(
|
||||
0,
|
||||
cls.format("text", content_result),
|
||||
)
|
||||
return content
|
||||
|
||||
@classmethod
|
||||
def __get_bym_content(
|
||||
cls, bot: Bot, user_id: str, group_id: str | None, nickname: str
|
||||
) -> list[dict[str, str]]:
|
||||
"""获取伪人回答文本内容
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
nickname: 用户昵称
|
||||
|
||||
返回:
|
||||
list[dict[str, str]]: 文本序列
|
||||
"""
|
||||
if not group_id:
|
||||
group_id = DEFAULT_GROUP
|
||||
content = [
|
||||
cls.format(
|
||||
"text",
|
||||
BYM_CONTENT.format(
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
nickname=nickname,
|
||||
self_id=bot.self_id,
|
||||
),
|
||||
)
|
||||
]
|
||||
if group_message := cls.group_cache.get(group_id):
|
||||
for message in group_message:
|
||||
content.append(
|
||||
cls.format(
|
||||
"text",
|
||||
f"用户昵称:{message.nickname} 用户ID:{message.user_id}",
|
||||
)
|
||||
)
|
||||
content.extend(cls.__build_content(message.message))
|
||||
content.append(cls.format("text", TIP_CONTENT))
|
||||
return content
|
||||
|
||||
@classmethod
|
||||
def add_cache(
|
||||
cls, user_id: str, group_id: str | None, nickname: str, message: UniMsg
|
||||
):
|
||||
"""添加消息缓存
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
nickname: 用户昵称
|
||||
message: 消息内容
|
||||
"""
|
||||
if not group_id:
|
||||
group_id = DEFAULT_GROUP
|
||||
message_cache = MessageCache(
|
||||
user_id=user_id, nickname=nickname, message=message
|
||||
)
|
||||
if group_id not in cls.group_cache:
|
||||
cls.group_cache[group_id] = [message_cache]
|
||||
else:
|
||||
cls.group_cache[group_id].append(message_cache)
|
||||
if len(cls.group_cache[group_id]) >= 30:
|
||||
cls.group_cache[group_id].pop(0)
|
||||
|
||||
@classmethod
|
||||
def check_is_call_tool(cls, result: OpenAiResult) -> bool:
|
||||
if not base_config.get("BYM_AI_TOOL_MODEL"):
|
||||
return False
|
||||
if result.choices and (msg := result.choices[0].message):
|
||||
return bool(msg.tool_calls)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_result(
|
||||
cls,
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
group_id: str | None,
|
||||
nickname: str,
|
||||
message: UniMsg,
|
||||
is_bym: bool,
|
||||
func_param: FunctionParam,
|
||||
) -> str:
|
||||
"""获取回答结果
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
nickname: 用户昵称
|
||||
message: 消息内容
|
||||
is_bym: 是否伪人
|
||||
|
||||
返回:
|
||||
str | None: 消息内容
|
||||
"""
|
||||
user_id = session.user.id
|
||||
cls.add_cache(user_id, group_id, nickname, message)
|
||||
if is_bym:
|
||||
content = cls.__get_bym_content(bot, user_id, group_id, nickname)
|
||||
conversation = await Conversation.get_conversation(None, group_id)
|
||||
else:
|
||||
content = await cls.__get_normal_content(
|
||||
user_id, group_id, nickname, message
|
||||
)
|
||||
conversation = await Conversation.get_conversation(user_id, group_id)
|
||||
conversation.append(ChatMessage(role="user", content=content))
|
||||
tools = list(AiCallTool.tools.values())
|
||||
# 首次调用,查看是否是调用工具
|
||||
if (
|
||||
base_config.get("BYM_AI_CHAT_SMART")
|
||||
and base_config.get("BYM_AI_TOOL_MODEL")
|
||||
and tools
|
||||
):
|
||||
try:
|
||||
result = await CallApi().fetch_chat(user_id, conversation, tools)
|
||||
if cls.check_is_call_tool(result):
|
||||
result = await cls._tool_handle(
|
||||
bot, session, conversation, result, tools, func_param
|
||||
) or await cls._chat_handle(session, conversation)
|
||||
else:
|
||||
result = await cls._chat_handle(session, conversation)
|
||||
except CallApiParamException:
|
||||
logger.warning("尝试调用工具函数失败 code: 400", "BYM_AI")
|
||||
result = await cls._chat_handle(session, conversation)
|
||||
else:
|
||||
result = await cls._chat_handle(session, conversation)
|
||||
if res := _filter_result(result):
|
||||
cls.add_cache(
|
||||
bot.self_id,
|
||||
group_id,
|
||||
BotConfig.self_nickname,
|
||||
MessageUtils.build_message(res),
|
||||
)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def _get_base_data(
|
||||
cls, session: Uninfo, result: OpenAiResult, is_tools: bool
|
||||
) -> tuple[str | None, str, Message]:
|
||||
group_id = None
|
||||
if session.group:
|
||||
group_id = (
|
||||
session.group.parent.id if session.group.parent else session.group.id
|
||||
)
|
||||
assistant_reply = ""
|
||||
message = None
|
||||
if result.choices and (message := result.choices[0].message):
|
||||
if message.content:
|
||||
assistant_reply = message.content.strip()
|
||||
if not message:
|
||||
raise ValueError("API响应结果不合法")
|
||||
return group_id, remove_deep_seek(assistant_reply, is_tools), message
|
||||
|
||||
@classmethod
|
||||
async def _chat_handle(
|
||||
cls,
|
||||
session: Uninfo,
|
||||
conversation: list[ChatMessage],
|
||||
) -> str:
|
||||
"""响应api
|
||||
|
||||
参数:
|
||||
session: Uninfo
|
||||
conversation: 消息记录
|
||||
result: API返回结果
|
||||
|
||||
返回:
|
||||
str: 最终结果
|
||||
"""
|
||||
result = await CallApi().fetch_chat(session.user.id, conversation, [])
|
||||
group_id, assistant_reply, _ = cls._get_base_data(session, result, False)
|
||||
conversation.append(ChatMessage(role="assistant", content=assistant_reply))
|
||||
Conversation.set_history(session.user.id, group_id, conversation)
|
||||
return assistant_reply
|
||||
|
||||
@classmethod
|
||||
async def _tool_handle(
|
||||
cls,
|
||||
bot: Bot,
|
||||
session: Uninfo,
|
||||
conversation: list[ChatMessage],
|
||||
result: OpenAiResult,
|
||||
tools: Sequence[AICallableTag],
|
||||
func_param: FunctionParam,
|
||||
) -> str:
|
||||
"""处理API响应并处理工具回调
|
||||
参数:
|
||||
user_id: 用户id
|
||||
conversation: 当前对话
|
||||
result: API响应结果
|
||||
tools: 可用的工具列表
|
||||
func_param: 函数参数
|
||||
返回:
|
||||
str: 处理后的消息内容
|
||||
"""
|
||||
group_id, assistant_reply, message = cls._get_base_data(session, result, True)
|
||||
if assistant_reply:
|
||||
conversation.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=assistant_reply,
|
||||
tool_calls=message.tool_calls,
|
||||
)
|
||||
)
|
||||
|
||||
# 处理工具回调
|
||||
if message.tool_calls:
|
||||
# temp_conversation = conversation.copy()
|
||||
call_result = await AiCallTool.build_conversation(
|
||||
message.tool_calls, func_param
|
||||
)
|
||||
if call_result:
|
||||
conversation.append(ChatMessage(role="assistant", content=call_result))
|
||||
# temp_conversation.extend(
|
||||
# await AiCallTool.build_conversation(message.tool_calls, func_param)
|
||||
# )
|
||||
result = await CallApi().fetch_chat(session.user.id, conversation, [])
|
||||
group_id, assistant_reply, message = cls._get_base_data(
|
||||
session, result, True
|
||||
)
|
||||
conversation.append(
|
||||
ChatMessage(role="assistant", content=assistant_reply)
|
||||
)
|
||||
# _, assistant_reply, _ = cls._get_base_data(session, result, True)
|
||||
# if res := await cls._tool_handle(
|
||||
# bot, session, conversation, result, tools, func_param
|
||||
# ):
|
||||
# if _filter_result(res):
|
||||
# assistant_reply = res
|
||||
Conversation.set_history(session.user.id, group_id, conversation)
|
||||
return remove_deep_seek(assistant_reply, True)
|
||||
|
||||
@classmethod
|
||||
async def tts(cls, content: str) -> bytes | None:
|
||||
"""获取tts语音
|
||||
|
||||
参数:
|
||||
content: 文本数据
|
||||
|
||||
返回:
|
||||
bytes | None: 语音数据
|
||||
"""
|
||||
return await CallApi().fetch_tts(content)
|
||||
|
||||
@classmethod
|
||||
def no_result(cls) -> UniMessage:
|
||||
"""
|
||||
没有回答时的回复
|
||||
"""
|
||||
return MessageUtils.build_message(
|
||||
[
|
||||
random.choice(NO_RESULT),
|
||||
IMAGE_PATH / "noresult" / random.choice(NO_RESULT_IMAGE),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def hello(cls) -> UniMessage:
|
||||
"""一些打招呼的内容"""
|
||||
result = random.choice(
|
||||
(
|
||||
"哦豁?!",
|
||||
"你好!Ov<",
|
||||
f"库库库,呼唤{BotConfig.self_nickname}做什么呢",
|
||||
"我在呢!",
|
||||
"呼呼,叫俺干嘛",
|
||||
)
|
||||
)
|
||||
img = random.choice(os.listdir(IMAGE_PATH / "zai"))
|
||||
return MessageUtils.build_message([IMAGE_PATH / "zai" / img, result])
|
||||
16
zhenxun/plugins/bym_ai/exception.py
Normal file
16
zhenxun/plugins/bym_ai/exception.py
Normal file
@ -0,0 +1,16 @@
|
||||
class NotResultException(Exception):
|
||||
"""没有结果"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GiftRepeatSendException(Exception):
|
||||
"""礼物重复发送"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CallApiParamException(Exception):
|
||||
"""调用api参数错误"""
|
||||
|
||||
pass
|
||||
42
zhenxun/plugins/bym_ai/goods_register.py
Normal file
42
zhenxun/plugins/bym_ai/goods_register.py
Normal file
@ -0,0 +1,42 @@
|
||||
import nonebot
|
||||
from nonebot.drivers import Driver
|
||||
|
||||
from zhenxun.configs.config import BotConfig
|
||||
from zhenxun.utils.decorator.shop import NotMeetUseConditionsException, shop_register
|
||||
|
||||
from .config import base_config
|
||||
from .data_source import Conversation
|
||||
|
||||
driver: Driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@shop_register(
|
||||
name="失忆卡",
|
||||
price=200,
|
||||
des=f"当你养成失败或{BotConfig.self_nickname}变得奇怪时,你需要这个道具。",
|
||||
icon="reload_ai_card.png",
|
||||
)
|
||||
async def _(user_id: str):
|
||||
await Conversation.reset(user_id, None)
|
||||
return f"{BotConfig.self_nickname}忘记了你之前说过的话,仿佛一切可以重新开始..."
|
||||
|
||||
|
||||
@shop_register(
|
||||
name="群组失忆卡",
|
||||
price=300,
|
||||
des=f"当群聊内{BotConfig.self_nickname}变得奇怪时,你需要这个道具。",
|
||||
icon="reload_ai_card1.png",
|
||||
)
|
||||
async def _(user_id: str, group_id: str):
|
||||
await Conversation.reset(user_id, group_id)
|
||||
return f"前面忘了,后面忘了,{BotConfig.self_nickname}重新睁开了眼睛..."
|
||||
|
||||
|
||||
@shop_register.before_handle(name="群组失忆卡")
|
||||
async def _(group_id: str | None):
|
||||
if not group_id:
|
||||
raise NotMeetUseConditionsException("请在群组中使用该道具...")
|
||||
if not base_config.get("ENABLE_GROUP_CHAT"):
|
||||
raise NotMeetUseConditionsException(
|
||||
"当前未开启群组个人记忆分离,无法使用道具。"
|
||||
)
|
||||
24
zhenxun/plugins/bym_ai/models/bym_chat.py
Normal file
24
zhenxun/plugins/bym_ai/models/bym_chat.py
Normal file
@ -0,0 +1,24 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
|
||||
class BymChat(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
user_id = fields.CharField(255)
|
||||
"""用户id"""
|
||||
group_id = fields.CharField(255, null=True)
|
||||
"""群组id"""
|
||||
plain_text = fields.TextField()
|
||||
"""消息文本"""
|
||||
result = fields.TextField()
|
||||
"""回复内容"""
|
||||
is_reset = fields.BooleanField(default=False)
|
||||
"""是否当前重置会话"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "bym_chat"
|
||||
table_description = "Bym聊天记录表"
|
||||
19
zhenxun/plugins/bym_ai/models/bym_gift_log.py
Normal file
19
zhenxun/plugins/bym_ai/models/bym_gift_log.py
Normal file
@ -0,0 +1,19 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
|
||||
class GiftLog(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
user_id = fields.CharField(255)
|
||||
"""用户id"""
|
||||
uuid = fields.CharField(255)
|
||||
"""礼物uuid"""
|
||||
type = fields.IntField()
|
||||
"""类型,0:获得,1:使用"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "bym_gift_log"
|
||||
24
zhenxun/plugins/bym_ai/models/bym_gift_store.py
Normal file
24
zhenxun/plugins/bym_ai/models/bym_gift_store.py
Normal file
@ -0,0 +1,24 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
|
||||
class GiftStore(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
uuid = fields.CharField(255)
|
||||
"""道具uuid"""
|
||||
name = fields.CharField(255)
|
||||
"""道具名称"""
|
||||
icon = fields.CharField(255, null=True)
|
||||
"""道具图标"""
|
||||
description = fields.TextField(default="")
|
||||
"""道具描述"""
|
||||
count = fields.IntField(default=0)
|
||||
"""礼物送出次数"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True)
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "bym_gift_store"
|
||||
table_description = "礼物列表"
|
||||
72
zhenxun/plugins/bym_ai/models/bym_user.py
Normal file
72
zhenxun/plugins/bym_ai/models/bym_user.py
Normal file
@ -0,0 +1,72 @@
|
||||
from tortoise import fields
|
||||
|
||||
from zhenxun.services.db_context import Model
|
||||
|
||||
from .bym_gift_log import GiftLog
|
||||
|
||||
|
||||
class BymUser(Model):
|
||||
id = fields.IntField(pk=True, generated=True, auto_increment=True)
|
||||
"""自增id"""
|
||||
user_id = fields.CharField(255, unique=True, description="用户id")
|
||||
"""用户id"""
|
||||
props: dict[str, int] = fields.JSONField(default={}) # type: ignore
|
||||
"""道具"""
|
||||
usage_count: dict[str, int] = fields.JSONField(default={}) # type: ignore
|
||||
"""使用道具次数"""
|
||||
platform = fields.CharField(255, null=True, description="平台")
|
||||
"""平台"""
|
||||
create_time = fields.DatetimeField(auto_now_add=True, description="创建时间")
|
||||
"""创建时间"""
|
||||
|
||||
class Meta: # pyright: ignore [reportIncompatibleVariableOverride]
|
||||
table = "bym_user"
|
||||
table_description = "用户数据表"
|
||||
|
||||
@classmethod
|
||||
async def get_user(cls, user_id: str, platform: str | None = None) -> "BymUser":
|
||||
"""获取用户
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
platform: 平台.
|
||||
|
||||
返回:
|
||||
UserConsole: UserConsole
|
||||
"""
|
||||
if not await cls.exists(user_id=user_id):
|
||||
await cls.create(user_id=user_id, platform=platform)
|
||||
return await cls.get(user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def add_gift(cls, user_id: str, gift_uuid: str):
|
||||
"""添加道具
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
gift_uuid: 道具uuid
|
||||
"""
|
||||
user = await cls.get_user(user_id)
|
||||
user.props[gift_uuid] = user.props.get(gift_uuid, 0) + 1
|
||||
await GiftLog.create(user_id=user_id, gift_uuid=gift_uuid, type=0)
|
||||
await user.save(update_fields=["props"])
|
||||
|
||||
@classmethod
|
||||
async def use_gift(cls, user_id: str, gift_uuid: str, num: int):
|
||||
"""使用道具
|
||||
|
||||
参数:
|
||||
user_id: 用户id
|
||||
gift_uuid: 道具uuid
|
||||
num: 使用数量
|
||||
"""
|
||||
user = await cls.get_user(user_id)
|
||||
if user.props.get(gift_uuid, 0) < num:
|
||||
raise ValueError("道具数量不足")
|
||||
user.props[gift_uuid] -= num
|
||||
user.usage_count[gift_uuid] = user.usage_count.get(gift_uuid, 0) + num
|
||||
create_list = [
|
||||
GiftLog(user_id=user_id, gift_uuid=gift_uuid, type=1) for _ in range(num)
|
||||
]
|
||||
await GiftLog.bulk_create(create_list)
|
||||
await user.save(update_fields=["props", "usage_count"])
|
||||
94
zhenxun/plugins/nonebot_plugin_dorodoro/__init__.py
Normal file
94
zhenxun/plugins/nonebot_plugin_dorodoro/__init__.py
Normal file
@ -0,0 +1,94 @@
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from nonebot_plugin_alconna import Alconna, Args, Arparma, CommandMeta, Text, on_alconna
|
||||
from nonebot_plugin_uninfo import Session, UniSession
|
||||
|
||||
from .game_logic import (
|
||||
get_next_node,
|
||||
get_node_data,
|
||||
is_end_node,
|
||||
update_user_state,
|
||||
user_game_state,
|
||||
)
|
||||
from .image_handler import send_images
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="doro大冒险",
|
||||
description="一个基于文字冒险的游戏插件",
|
||||
type="application",
|
||||
usage="""
|
||||
使用方法:
|
||||
doro :开始游戏
|
||||
choose <选项> 或 选择 <选项>:在游戏中做出选择
|
||||
""",
|
||||
homepage="https://github.com/ATTomatoo/dorodoro",
|
||||
extra={
|
||||
"author": "ATTomatoo",
|
||||
"version": "1.5.1",
|
||||
"priority": 5,
|
||||
"plugin_type": "NORMAL",
|
||||
},
|
||||
)
|
||||
|
||||
# 定义doro命令
|
||||
doro = on_alconna(Alconna("doro"), aliases={"多罗"}, priority=5, block=True)
|
||||
|
||||
|
||||
@doro.handle()
|
||||
async def handle_doro(session: Session = UniSession()):
|
||||
user_id = session.user.id
|
||||
start_node = "start"
|
||||
await update_user_state(user_id, start_node)
|
||||
if start_data := await get_node_data(start_node):
|
||||
msg = start_data["text"] + "\n"
|
||||
for key, opt in start_data.get("options", {}).items():
|
||||
msg += f"{key}. {opt['text']}\n"
|
||||
|
||||
await send_images(start_data.get("image"))
|
||||
await doro.send(Text(msg), reply_to=True)
|
||||
else:
|
||||
await doro.send(Text("游戏初始化失败,请联系管理员。"), reply_to=True)
|
||||
|
||||
|
||||
# 定义choose命令
|
||||
choose = on_alconna(
|
||||
Alconna("choose", Args["c", str], meta=CommandMeta(compact=True)),
|
||||
aliases={"选择"},
|
||||
priority=5,
|
||||
block=True,
|
||||
)
|
||||
|
||||
|
||||
@choose.handle()
|
||||
async def handle_choose(p: Arparma, session: Session = UniSession()):
|
||||
user_id = session.user.id
|
||||
if user_id not in user_game_state:
|
||||
await choose.finish(
|
||||
Text("你还没有开始游戏,请输入 /doro 开始。"), reply_to=True
|
||||
)
|
||||
|
||||
choice = p.query("c")
|
||||
assert isinstance(choice, str)
|
||||
choice = choice.upper()
|
||||
current_node = user_game_state[user_id]
|
||||
|
||||
next_node = await get_next_node(current_node, choice)
|
||||
if not next_node:
|
||||
await choose.finish(Text("无效选择,请重新输入。"), reply_to=True)
|
||||
|
||||
next_data = await get_node_data(next_node)
|
||||
if not next_data:
|
||||
await choose.finish(Text("故事节点错误,请联系管理员。"), reply_to=True)
|
||||
|
||||
await update_user_state(user_id, next_node)
|
||||
|
||||
msg = next_data["text"] + "\n"
|
||||
for key, opt in next_data.get("options", {}).items():
|
||||
msg += f"{key}. {opt['text']}\n"
|
||||
|
||||
await send_images(next_data.get("image"))
|
||||
|
||||
if await is_end_node(next_data):
|
||||
await choose.send(Text(msg + "\n故事结束。"), reply_to=True)
|
||||
user_game_state.pop(user_id, None)
|
||||
else:
|
||||
await choose.finish(Text(msg), reply_to=True)
|
||||
3
zhenxun/plugins/nonebot_plugin_dorodoro/config.py
Normal file
3
zhenxun/plugins/nonebot_plugin_dorodoro/config.py
Normal file
@ -0,0 +1,3 @@
|
||||
from pathlib import Path
|
||||
|
||||
IMAGE_DIR = Path(__file__).parent / "images"
|
||||
57
zhenxun/plugins/nonebot_plugin_dorodoro/game_logic.py
Normal file
57
zhenxun/plugins/nonebot_plugin_dorodoro/game_logic.py
Normal file
@ -0,0 +1,57 @@
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
import aiofiles
|
||||
|
||||
# 构造 story_data.json 的完整路径
|
||||
story_data_path = Path(__file__).parent / "story_data.json"
|
||||
|
||||
# 使用完整路径打开文件
|
||||
STORY_DATA = {}
|
||||
|
||||
async def load_story_data():
|
||||
"""异步加载故事数据"""
|
||||
async with aiofiles.open(story_data_path, encoding="utf-8") as f:
|
||||
content = await f.read()
|
||||
global STORY_DATA
|
||||
STORY_DATA = json.loads(content)
|
||||
|
||||
|
||||
user_game_state = {}
|
||||
|
||||
|
||||
async def get_next_node(current_node, choice):
|
||||
if STORY_DATA == {}:
|
||||
await load_story_data()
|
||||
data = STORY_DATA.get(current_node, {})
|
||||
options = data.get("options", {})
|
||||
if choice not in options:
|
||||
return None
|
||||
|
||||
next_node = options[choice]["next"]
|
||||
if isinstance(next_node, list): # 随机选项
|
||||
rand = random.random()
|
||||
cumulative = 0.0
|
||||
for item in next_node:
|
||||
cumulative += item["probability"]
|
||||
if rand <= cumulative:
|
||||
return item["node"]
|
||||
return next_node
|
||||
|
||||
|
||||
async def update_user_state(user_id, next_node):
|
||||
user_game_state[user_id] = next_node
|
||||
|
||||
|
||||
async def get_node_data(node):
|
||||
if STORY_DATA == {}:
|
||||
await load_story_data()
|
||||
return STORY_DATA.get(node)
|
||||
|
||||
|
||||
async def is_end_node(node_data) -> bool:
|
||||
return node_data.get("is_end", False)
|
||||
22
zhenxun/plugins/nonebot_plugin_dorodoro/image_handler.py
Normal file
22
zhenxun/plugins/nonebot_plugin_dorodoro/image_handler.py
Normal file
@ -0,0 +1,22 @@
|
||||
from nonebot_plugin_alconna import Image, UniMessage
|
||||
|
||||
from .config import IMAGE_DIR
|
||||
|
||||
|
||||
async def get_image_segment(image_name):
|
||||
image_path = IMAGE_DIR / image_name
|
||||
return Image(path=image_path) if image_path.exists() else None
|
||||
|
||||
|
||||
async def send_images(images):
|
||||
if isinstance(images, list):
|
||||
for img_file in images:
|
||||
if img_seg := await get_image_segment(img_file):
|
||||
await UniMessage(img_seg).send(reply_to=True)
|
||||
else:
|
||||
await UniMessage(f"图片 {img_file} 不存在。").send(reply_to=True)
|
||||
elif isinstance(images, str):
|
||||
if img_seg := await get_image_segment(images):
|
||||
await UniMessage(img_seg).send(reply_to=True)
|
||||
else:
|
||||
await UniMessage(f"图片 {images} 不存在。").send(reply_to=True)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user