perf👌: 实现优化webui群组/好友管理api

This commit is contained in:
HibiKier 2024-01-09 13:47:24 +08:00
parent 9a1510fe7e
commit bf55a20241
7 changed files with 443 additions and 95 deletions

View File

@ -18,6 +18,7 @@ from .api.tabs.database import router as database_router
from .api.tabs.main import router as main_router from .api.tabs.main import router as main_router
from .api.tabs.main import ws_router as status_routes from .api.tabs.main import ws_router as status_routes
from .api.tabs.manage import router as manage_router from .api.tabs.manage import router as manage_router
from .api.tabs.manage import ws_router as chat_routes
from .api.tabs.plugin_manage import router as plugin_router from .api.tabs.plugin_manage import router as plugin_router
from .auth import router as auth_router from .auth import router as auth_router
@ -42,6 +43,7 @@ WsApiRouter = APIRouter(prefix="/zhenxun/socket")
WsApiRouter.include_router(ws_log_routes) WsApiRouter.include_router(ws_log_routes)
WsApiRouter.include_router(status_routes) WsApiRouter.include_router(status_routes)
WsApiRouter.include_router(chat_routes)
@driver.on_startup @driver.on_startup

View File

@ -21,18 +21,6 @@ class LogStorage(Generic[_T]):
self.listeners: Set[LogListener[str]] = set() self.listeners: Set[LogListener[str]] = set()
async def add(self, log: str): async def add(self, log: str):
# log = re.sub(PATTERN, "", log)
# log_split = log.split()
# time = log_split[0] + " " + log_split[1]
# level = log_split[2]
# main = log_split[3]
# type_ = None
# log_ = " ".join(log_split[3:])
# if "Calling API" in log_:
# sp = log_.split("|")
# type_ = sp[1]
# log_ = "|".join(log_[1:])
# data = {"time": time, "level": level, "main": main, "type": type_, "log": log_}
seq = self.count = self.count + 1 seq = self.count = self.count + 1
self.logs[seq] = log self.logs[seq] = log
asyncio.get_running_loop().call_later(self.rotation, self.remove, seq) asyncio.get_running_loop().call_later(self.rotation, self.remove, seq)
@ -48,3 +36,4 @@ class LogStorage(Generic[_T]):
LOG_STORAGE: LogStorage[str] = LogStorage[str]() LOG_STORAGE: LogStorage[str] = LogStorage[str]()

View File

@ -5,7 +5,6 @@ from typing import List, Optional
import nonebot import nonebot
from fastapi import APIRouter, WebSocket from fastapi import APIRouter, WebSocket
from nonebot.utils import escape_tag
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
from tortoise.functions import Count from tortoise.functions import Count
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
@ -18,21 +17,18 @@ from services.log import logger
from utils.manager import plugin_data_manager, plugins2settings_manager, plugins_manager from utils.manager import plugin_data_manager, plugins2settings_manager, plugins_manager
from ....base_model import Result from ....base_model import Result
from ....config import QueryDateType from ....config import AVA_URL, GROUP_AVA_URL, QueryDateType
from ....utils import authentication, get_system_status from ....utils import authentication, get_system_status
from .data_source import bot_live from .data_source import bot_live
from .model import ActiveGroup, BaseInfo, ChatHistoryCount, HotPlugin from .model import ActiveGroup, BaseInfo, ChatHistoryCount, HotPlugin
AVA_URL = "http://q1.qlogo.cn/g?b=qq&nk={}&s=160"
GROUP_AVA_URL = "http://p.qlogo.cn/gh/{}/{}/640/"
run_time = time.time() run_time = time.time()
ws_router = APIRouter() ws_router = APIRouter()
router = APIRouter() router = APIRouter()
@router.get("/get_base_info", dependencies=[authentication()], description="基础信息") @router.get("/get_base_info", dependencies=[authentication()], description="基础信息")
async def _(bot_id: Optional[str] = None) -> Result: async def _(bot_id: Optional[str] = None) -> Result:
""" """

View File

@ -1,32 +1,54 @@
from typing import Literal import re
from typing import Literal, Optional
import nonebot import nonebot
from fastapi import APIRouter from fastapi import APIRouter
from pydantic.error_wrappers import ValidationError from nonebot.adapters.onebot.v11.exception import ActionFailed
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
from tortoise.functions import Count
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from configs.config import NICKNAME from configs.config import NICKNAME
from models.ban_user import BanUser
from models.chat_history import ChatHistory
from models.friend_user import FriendUser
from models.group_info import GroupInfo from models.group_info import GroupInfo
from models.group_member_info import GroupInfoUser
from models.statistics import Statistics
from services.log import logger from services.log import logger
from utils.manager import group_manager, requests_manager from utils.manager import group_manager, plugin_data_manager, requests_manager
from utils.utils import get_bot from utils.utils import get_bot
from ....base_model import Result from ....base_model import Result
from ....config import AVA_URL, GROUP_AVA_URL
from ....utils import authentication from ....utils import authentication
from ...logs.log_manager import LOG_STORAGE
from .model import ( from .model import (
DeleteFriend, DeleteFriend,
Friend, Friend,
FriendRequestResult, FriendRequestResult,
Group, GroupDetail,
GroupRequestResult, GroupRequestResult,
GroupResult, GroupResult,
HandleRequest, HandleRequest,
LeaveGroup, LeaveGroup,
Message,
Plugin,
ReqResult,
SendMessage,
Task, Task,
UpdateGroup, UpdateGroup,
UserDetail,
) )
ws_router = APIRouter()
router = APIRouter() router = APIRouter()
SUB_PATTERN = r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))"
GROUP_PATTERN = r'.*?Message (-?\d*) from (\d*)@\[群:(\d*)] "(.*)"'
PRIVATE_PATTERN = r'.*?Message (-?\d*) from (\d*) "(.*)"'
@router.get("/get_group_list", dependencies=[authentication()], description="获取群组列表") @router.get("/get_group_list", dependencies=[authentication()], description="获取群组列表")
async def _(bot_id: str) -> Result: async def _(bot_id: str) -> Result:
@ -41,27 +63,9 @@ async def _(bot_id: str) -> Result:
group_info = {} group_info = {}
group_list = await bots[bot_id].get_group_list() group_list = await bots[bot_id].get_group_list()
for g in group_list: for g in group_list:
group_info[g["group_id"]] = Group(**g) gid = g['group_id']
group_data = group_manager.get_data() g['ava_url'] = GROUP_AVA_URL.format(gid, gid)
for group_id in group_data.group_manager: group_list_result.append(GroupResult(**g))
task_list = []
data = group_manager[group_id].dict()
for tn, status in data["group_task_status"].items():
task_list.append(
Task(
**{
"name": tn,
"nameZh": group_manager.get_task_data().get(tn) or tn,
"status": status,
}
)
)
data["task"] = task_list
if x := group_info.get(int(group_id)):
data["group"] = x
else:
continue
group_list_result.append(GroupResult(**data))
except Exception as e: except Exception as e:
logger.error("调用API错误", "/get_group_list", e=e) logger.error("调用API错误", "/get_group_list", e=e)
return Result.fail(f"{type(e)}: {e}") return Result.fail(f"{type(e)}: {e}")
@ -78,12 +82,14 @@ async def _(group: UpdateGroup) -> Result:
group_manager.turn_on_group_bot_status(group_id) group_manager.turn_on_group_bot_status(group_id)
else: else:
group_manager.shutdown_group_bot_status(group_id) group_manager.shutdown_group_bot_status(group_id)
if group.task_status: all_task = group_manager.get_task_data().keys()
for task in group.task_status: if group.task:
if group.task_status[task]: for task in all_task:
if task in group.task:
group_manager.open_group_task(group_id, task) group_manager.open_group_task(group_id, task)
else: else:
group_manager.close_group_task(group_id, task) group_manager.close_group_task(group_id, task)
group_manager[group_id].close_plugins = group.close_plugins
group_manager.save() group_manager.save()
except Exception as e: except Exception as e:
logger.error("调用API错误", "/get_group", e=e) logger.error("调用API错误", "/get_group", e=e)
@ -101,7 +107,9 @@ async def _(bot_id: str) -> Result:
return Result.warning_("指定Bot未连接...") return Result.warning_("指定Bot未连接...")
try: try:
friend_list = await bots[bot_id].get_friend_list() friend_list = await bots[bot_id].get_friend_list()
return Result.ok([Friend(**f) for f in friend_list], "拿到了新鲜出炉的数据!") for f in friend_list:
f['ava_url'] = AVA_URL.format(f['user_id'])
return Result.ok([Friend(**f) for f in friend_list if str(f['user_id']) != bot_id], "拿到了新鲜出炉的数据!")
except Exception as e: except Exception as e:
logger.error("调用API错误", "/get_group_list", e=e) logger.error("调用API错误", "/get_group_list", e=e)
return Result.fail(f"{type(e)}: {e}") return Result.fail(f"{type(e)}: {e}")
@ -118,21 +126,27 @@ def _() -> Result:
@router.get("/get_request_list", dependencies=[authentication()], description="获取请求列表") @router.get("/get_request_list", dependencies=[authentication()], description="获取请求列表")
def _(request_type: Literal["private", "group"]) -> Result: def _() -> Result:
try: try:
req_data = requests_manager.get_data().get(request_type) or [] req_result = ReqResult()
req_list = [] data = requests_manager.get_data()
for x in req_data: for type_ in requests_manager.get_data():
req_data[x]["oid"] = x for x in data[type_]:
if request_type == "private": data[type_][x]["oid"] = x
req_list.append(FriendRequestResult(**req_data[x])) data[type_][x]['type'] = type_
if type_ == "private":
data[type_][x]['ava_url'] = AVA_URL.format(data[type_][x]['id'])
req_result.friend.append(FriendRequestResult(**data[type_][x]))
else: else:
req_list.append(GroupRequestResult(**req_data[x])) gid = data[type_][x]['id']
req_list.reverse() data[type_][x]['ava_url'] = GROUP_AVA_URL.format(gid, gid)
req_result.group.append(GroupRequestResult(**data[type_][x]))
req_result.friend.reverse()
req_result.group.reverse()
except Exception as e: except Exception as e:
logger.error("调用API错误", "/get_request", e=e) logger.error("调用API错误", "/get_request", e=e)
return Result.fail(f"{type(e)}: {e}") return Result.fail(f"{type(e)}: {e}")
return Result.ok(req_list, f"{NICKNAME}带来了最新的数据!") return Result.ok(req_result, f"{NICKNAME}带来了最新的数据!")
@router.delete("/clear_request", dependencies=[authentication()], description="清空请求列表") @router.delete("/clear_request", dependencies=[authentication()], description="清空请求列表")
@ -156,14 +170,18 @@ async def _(parma: HandleRequest) -> Result:
bot_id = parma.bot_id bot_id = parma.bot_id
if bot_id not in nonebot.get_bots(): if bot_id not in nonebot.get_bots():
return Result.warning_("指定Bot未连接...") return Result.warning_("指定Bot未连接...")
flag = await requests_manager.refused(bots[bot_id], parma.id, parma.request_type) # type: ignore try:
flag = await requests_manager.refused(bots[bot_id], parma.flag, parma.request_type) # type: ignore
except ActionFailed as e:
requests_manager.delete_request(parma.flag, parma.request_type)
return Result.warning_("请求失败,可能该请求已失效或请求数据错误...")
if flag == 1: if flag == 1:
requests_manager.delete_request(parma.id, parma.request_type) requests_manager.delete_request(parma.flag, parma.request_type)
return Result.warning_("该请求已失效...") return Result.warning_("该请求已失效...")
elif flag == 2: elif flag == 2:
return Result.warning_("未找到此Id请求...") return Result.warning_("未找到此Id请求...")
return Result.ok(info="成功处理了请求!") return Result.ok(info="成功处理了请求!")
return Result.warning_("Bot连接...") return Result.warning_("Bot连接...")
except Exception as e: except Exception as e:
logger.error("调用API错误", "/refuse_request", e=e) logger.error("调用API错误", "/refuse_request", e=e)
return Result.fail(f"{type(e)}: {e}") return Result.fail(f"{type(e)}: {e}")
@ -175,7 +193,7 @@ async def _(parma: HandleRequest) -> Result:
操作请求 操作请求
:param parma: 参数 :param parma: 参数
""" """
requests_manager.delete_request(parma.id, parma.request_type) requests_manager.delete_request(parma.flag, parma.request_type)
return Result.ok(info="成功处理了请求!") return Result.ok(info="成功处理了请求!")
@ -191,7 +209,7 @@ async def _(parma: HandleRequest) -> Result:
if bot_id not in nonebot.get_bots(): if bot_id not in nonebot.get_bots():
return Result.warning_("指定Bot未连接...") return Result.warning_("指定Bot未连接...")
if parma.request_type == "group": if parma.request_type == "group":
if rid := requests_manager.get_group_id(parma.id): if rid := requests_manager.get_group_id(parma.flag):
if group := await GroupInfo.get_or_none(group_id=str(rid)): if group := await GroupInfo.get_or_none(group_id=str(rid)):
await group.update_or_create(group_flag=1) await group.update_or_create(group_flag=1)
else: else:
@ -205,9 +223,13 @@ async def _(parma: HandleRequest) -> Result:
"group_flag": 1, "group_flag": 1,
}, },
) )
await requests_manager.approve(bots[bot_id], parma.id, parma.request_type) # type: ignore try:
await requests_manager.approve(bots[bot_id], parma.flag, parma.request_type) # type: ignore
return Result.ok(info="成功处理了请求!") return Result.ok(info="成功处理了请求!")
return Result.warning_("Bot未连接...") except ActionFailed as e:
requests_manager.delete_request(parma.flag, parma.request_type)
return Result.warning_("请求失败,可能该请求已失效或请求数据错误...")
return Result.warning_("无Bot连接...")
except Exception as e: except Exception as e:
logger.error("调用API错误", "/approve_request", e=e) logger.error("调用API错误", "/approve_request", e=e)
return Result.fail(f"{type(e)}: {e}") return Result.fail(f"{type(e)}: {e}")
@ -223,7 +245,7 @@ async def _(param: LeaveGroup) -> Result:
return Result.warning_("Bot未在该群聊中...") return Result.warning_("Bot未在该群聊中...")
await bots[bot_id].set_group_leave(group_id=param.group_id) await bots[bot_id].set_group_leave(group_id=param.group_id)
return Result.ok(info="成功处理了请求!") return Result.ok(info="成功处理了请求!")
return Result.warning_("Bot连接...") return Result.warning_("Bot连接...")
except Exception as e: except Exception as e:
logger.error("调用API错误", "/leave_group", e=e) logger.error("调用API错误", "/leave_group", e=e)
return Result.fail(f"{type(e)}: {e}") return Result.fail(f"{type(e)}: {e}")
@ -243,3 +265,163 @@ async def _(param: DeleteFriend) -> Result:
except Exception as e: except Exception as e:
logger.error("调用API错误", "/delete_friend", e=e) logger.error("调用API错误", "/delete_friend", e=e)
return Result.fail(f"{type(e)}: {e}") return Result.fail(f"{type(e)}: {e}")
@router.get("/get_friend_detail", dependencies=[authentication()], description="获取好友详情")
async def _(bot_id: str, user_id: str) -> Result:
if bots := nonebot.get_bots():
if bot_id in bots:
if fd := [x for x in await bots[bot_id].get_friend_list() if str(x['user_id']) == user_id]:
like_plugin_list = (
await Statistics.filter(user_id=user_id).annotate(count=Count("id"))
.group_by("plugin_name").order_by("-count").limit(5)
.values_list("plugin_name", "count")
)
like_plugin = {}
for data in like_plugin_list:
name = data[0]
if plugin_data := plugin_data_manager.get(data[0]):
name = plugin_data.name
like_plugin[name] = data[1]
user = fd[0]
user_detail = UserDetail(
user_id=user_id,
ava_url=AVA_URL.format(user_id),
nickname=user['nickname'],
remark=user['remark'],
is_ban=await BanUser.is_ban(user_id),
chat_count=await ChatHistory.filter(user_id=user_id).count(),
call_count=await Statistics.filter(user_id=user_id).count(),
like_plugin=like_plugin,
)
return Result.ok(user_detail)
else:
return Result.warning_("未添加指定好友...")
return Result.warning_("无Bot连接...")
@router.get("/get_group_detail", dependencies=[authentication()], description="获取群组详情")
async def _(bot_id: str, group_id: str) -> Result:
if bots := nonebot.get_bots():
if bot_id in bots:
group_info = await bots[bot_id].get_group_info(group_id=int(group_id))
g = group_manager[group_id]
if not g:
return Result.warning_("指定群组未被收录...")
if group_info:
like_plugin_list = (
await Statistics.filter(group_id=group_id).annotate(count=Count("id"))
.group_by("plugin_name").order_by("-count").limit(5)
.values_list("plugin_name", "count")
)
like_plugin = {}
for data in like_plugin_list:
name = data[0]
if plugin_data := plugin_data_manager.get(data[0]):
name = plugin_data.name
like_plugin[name] = data[1]
close_plugins = []
for module in g.close_plugins:
plugin = Plugin(module=module, plugin_name=module)
if plugin_data := plugin_data_manager.get(module):
plugin.plugin_name = plugin_data.name
close_plugins.append(plugin)
task_list = []
task_data = group_manager.get_task_data()
for tn, status in g.group_task_status.items():
task_list.append(
Task(
name=tn,
zh_name=task_data.get(tn) or tn,
status=status
)
)
group_detail = GroupDetail(
group_id=group_id,
ava_url=GROUP_AVA_URL.format(group_id, group_id),
name=group_info['group_name'],
member_count=group_info['member_count'],
max_member_count=group_info['max_member_count'],
chat_count=await ChatHistory.filter(group_id=group_id).count(),
call_count=await Statistics.filter(group_id=group_id).count(),
like_plugin=like_plugin,
level=g.level,
status=g.status,
close_plugins=close_plugins,
task=task_list
)
return Result.ok(group_detail)
else:
return Result.warning_("未添加指定群组...")
return Result.warning_("无Bot连接...")
@router.post("/send_message", dependencies=[authentication()], description="获取群组详情")
async def _(param: SendMessage) -> Result:
if bots := nonebot.get_bots():
if param.bot_id in bots:
try:
if param.user_id:
await bots[param.bot_id].send_private_msg(user_id=str(param.user_id), message=param.message)
else:
await bots[param.bot_id].send_group_msg(group_id=str(param.group_id), message=param.message)
except Exception as e:
return Result.fail(str(e))
return Result.ok("发送成功!")
return Result.warning_("指定Bot未连接...")
return Result.warning_("无Bot连接...")
MSG_LIST = []
@ws_router.websocket("/chat")
async def _(websocket: WebSocket, group_id: Optional[str] = None, user_id: Optional[str] = None):
global MSG_LIST
await websocket.accept()
async def log_listener(log: str):
sub_log = re.sub(SUB_PATTERN, "", log)
if "message.private.friend" in log:
if r := re.search(PRIVATE_PATTERN, sub_log):
msg_id = r.group(1)
uid = r.group(2)
msg = r.group(3)
user = await FriendUser.filter(user_id=user_id).first()
name = user.user_name
if uid and uid == user_id and msg_id not in MSG_LIST:
MSG_LIST.append(msg_id)
message = Message(
user_id=uid,
message=msg,
name=name,
ava_url=AVA_URL.format(uid)
)
await websocket.send_json(message.dict())
else:
if r := re.search(GROUP_PATTERN, sub_log):
msg_id = r.group(1)
uid = r.group(2)
gid = r.group(3)
msg = r.group(4)
user = await GroupInfoUser.filter(user_id=uid, group_id=gid).first()
name = user.user_name or user.nickname
if gid and gid == group_id and msg_id not in MSG_LIST:
MSG_LIST.append(msg_id)
message = Message(
user_id=uid,
group_id=gid,
message=msg,
name=name,
ava_url=AVA_URL.format(uid)
)
await websocket.send_json(message.dict())
LOG_STORAGE.listeners.add(log_listener)
try:
while websocket.client_state == WebSocketState.CONNECTED:
recv = await websocket.receive()
except WebSocketDisconnect:
pass
finally:
LOG_STORAGE.listeners.remove(log_listener)
return

View File

@ -27,27 +27,33 @@ class Task(BaseModel):
name: str name: str
"""被动名称""" """被动名称"""
nameZh: str zh_name: str
"""被动中文名称""" """被动中文名称"""
status: bool status: bool
"""状态""" """状态"""
class Plugin(BaseModel):
"""
插件
"""
module: str
"""模块名"""
plugin_name: str
"""中文名"""
class GroupResult(BaseModel): class GroupResult(BaseModel):
""" """
群组返回数据 群组返回数据
""" """
group: Group group_id: Union[str, int]
"""Group""" """群组id"""
level: int group_name: str
"""群等级""" """群组名称"""
status: bool ava_url: str
"""状态""" """群组头像"""
close_plugins: List[str]
"""关闭的插件"""
task: List[Task]
"""被动列表"""
class Friend(BaseModel): class Friend(BaseModel):
@ -61,7 +67,8 @@ class Friend(BaseModel):
"""昵称""" """昵称"""
remark: str = "" remark: str = ""
"""备注""" """备注"""
ava_url: str = ""
"""头像url"""
class UpdateGroup(BaseModel): class UpdateGroup(BaseModel):
""" """
@ -74,8 +81,10 @@ class UpdateGroup(BaseModel):
"""状态""" """状态"""
level: int level: int
"""群权限""" """群权限"""
task_status: Dict[str, bool] task: List[str]
"""被动状态""" """被动状态"""
close_plugins: List[str]
"""关闭插件"""
class FriendRequestResult(BaseModel): class FriendRequestResult(BaseModel):
@ -103,6 +112,10 @@ class FriendRequestResult(BaseModel):
"""来自""" """来自"""
comment: Optional[str] comment: Optional[str]
"""备注信息""" """备注信息"""
ava_url: str
"""头像"""
type: str
"""类型 private group"""
class GroupRequestResult(FriendRequestResult): class GroupRequestResult(FriendRequestResult):
@ -121,10 +134,10 @@ class HandleRequest(BaseModel):
操作请求接收数据 操作请求接收数据
""" """
bot_id: str bot_id: Optional[str] = None
"""bot_id""" """bot_id"""
id: int flag: str
"""id""" """flag"""
request_type: Literal["private", "group"] request_type: Literal["private", "group"]
"""类型""" """类型"""
@ -149,3 +162,97 @@ class DeleteFriend(BaseModel):
"""bot_id""" """bot_id"""
user_id: str user_id: str
"""用户id""" """用户id"""
class ReqResult(BaseModel):
"""
好友/群组请求列表
"""
friend: List[FriendRequestResult] = []
"""好友请求列表"""
group: List[GroupRequestResult] = []
"""群组请求列表"""
class UserDetail(BaseModel):
"""
用户详情
"""
user_id: str
"""用户id"""
ava_url: str
"""头像url"""
nickname: str
"""昵称"""
remark: str
"""备注"""
is_ban: bool
"""是否被ban"""
chat_count: int
"""发言次数"""
call_count: int
"""功能调用次数"""
like_plugin: Dict[str, int]
"""最喜爱的功能"""
class GroupDetail(BaseModel):
"""
用户详情
"""
group_id: str
"""群组id"""
ava_url: str
"""头像url"""
name: str
"""名称"""
member_count: int
"""成员数"""
max_member_count: int
"""最大成员数"""
chat_count: int
"""发言次数"""
call_count: int
"""功能调用次数"""
like_plugin: Dict[str, int]
"""最喜爱的功能"""
level: int
"""群权限"""
status: bool
"""状态(睡眠)"""
close_plugins: List[Plugin]
"""关闭的插件"""
task: List[Task]
"""被动列表"""
class Message(BaseModel):
"""
消息
"""
user_id: str
"""用户id"""
group_id: Optional[str] = None
"""群组id"""
message: str
"""消息"""
name: str
"""用户名称"""
ava_url: str
"""用户头像"""
class SendMessage(BaseModel):
"""
发送消息
"""
bot_id: str
"""bot id"""
user_id: Optional[str] = None
"""用户id"""
group_id: Optional[str] = None
"""群组id"""
message: str
"""消息"""

View File

@ -20,6 +20,11 @@ app.add_middleware(
) )
AVA_URL = "http://q1.qlogo.cn/g?b=qq&nk={}&s=160"
GROUP_AVA_URL = "http://p.qlogo.cn/gh/{}/{}/640/"
class QueryDateType(StrEnum): class QueryDateType(StrEnum):
""" """

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Union, overload
from nonebot.adapters.onebot.v11 import ActionFailed, Bot from nonebot.adapters.onebot.v11 import ActionFailed, Bot
@ -67,14 +67,24 @@ class RequestManager(StaticData):
} }
self.save() self.save()
@overload
def remove_request(self, type_: str, flag: str):
...
@overload
def remove_request(self, type_: str, id_: int): def remove_request(self, type_: str, id_: int):
...
def remove_request(self, type_: str, id_: Union[int, str]):
""" """
删除一个请求数据 删除一个请求数据
:param type_: 类型 :param type_: 类型
:param id_: iduser_id group_id :param id_: iduser_id group_id
""" """
for x in self._data[type_].keys(): for x in self._data[type_].keys():
if self._data[type_][x].get("id") == id_: a_id = self._data[type_][x].get("id")
a_flag = self._data[type_][x].get("flag")
if a_id == id_ or a_flag == id_:
del self._data[type_][x] del self._data[type_][x]
break break
self.save() self.save()
@ -89,7 +99,16 @@ class RequestManager(StaticData):
return data["invite_group"] return data["invite_group"]
return None return None
@overload
async def approve(self, bot: Bot, id_: int, type_: str) -> int: async def approve(self, bot: Bot, id_: int, type_: str) -> int:
...
@overload
async def approve(self, bot: Bot, flag: str, type_: str) -> int:
...
async def approve(self, bot: Bot, id_: Union[int, str], type_: str) -> int:
""" """
同意请求 同意请求
:param bot: Bot :param bot: Bot
@ -98,7 +117,15 @@ class RequestManager(StaticData):
""" """
return await self._set_add_request(bot, id_, type_, True) return await self._set_add_request(bot, id_, type_, True)
async def refused(self, bot: Bot, id_: int, type_: str) -> Optional[int]: @overload
async def refused(self, bot: Bot, id_: int, type_: str) -> int:
...
@overload
async def refused(self, bot: Bot, flag: str, type_: str) -> int:
...
async def refused(self, bot: Bot, id_: Union[int, str], type_: str) -> Optional[int]:
""" """
拒绝请求 拒绝请求
:param bot: Bot :param bot: Bot
@ -120,18 +147,32 @@ class RequestManager(StaticData):
self._data = {"private": {}, "group": {}} self._data = {"private": {}, "group": {}}
self.save() self.save()
@overload
async def delete_request(self, id_: int, type_: str) -> int:
...
@overload
async def delete_request(self, flag: str, type_: str) -> int:
...
def delete_request( def delete_request(
self, id_: int, type_: str self, id_: Union[str, int], type_: str
): # type_: Literal["group", "private"] ): # type_: Literal["group", "private"]
""" """
删除请求 删除请求
:param id_: id :param id_: id
:param type_: 类型 :param type_: 类型
""" """
id_ = str(id_) if type(id_) == int:
if self._data[type_].get(id_): if self._data[type_].get(id_):
del self._data[type_][id_] del self._data[type_][id_]
self.save() self.save()
else:
for k, item in self._data[type_].items():
if item['flag'] == id_:
del self._data[type_][k]
self.save()
break
def set_group_name(self, group_name: str, group_id: int): def set_group_name(self, group_name: str, group_id: int):
""" """
@ -239,7 +280,7 @@ class RequestManager(StaticData):
return bk.pic2bs4() return bk.pic2bs4()
async def _set_add_request( async def _set_add_request(
self, bot: Bot, idx: int, type_: str, approve: bool self, bot: Bot, idx: Union[str, int], type_: str, approve: bool
) -> int: ) -> int:
""" """
处理请求 处理请求
@ -248,8 +289,13 @@ class RequestManager(StaticData):
:param type_: 类型private group :param type_: 类型private group
:param approve: 是否同意 :param approve: 是否同意
""" """
flag = None
id_ = None
if type(idx) == str:
flag = idx
else:
id_ = str(idx) id_ = str(idx)
if id_ in self._data[type_].keys(): if id_ and id_ in self._data[type_].keys():
try: try:
if type_ == "private": if type_ == "private":
await bot.set_friend_add_request( await bot.set_friend_add_request(
@ -277,4 +323,25 @@ class RequestManager(StaticData):
del self._data[type_][id_] del self._data[type_][id_]
self.save() self.save()
return rid return rid
if flag:
rm_id = None
for k, item in self._data[type_].items():
if item['flag'] == flag:
rm_id = k
if type_ == 'private':
await bot.set_friend_add_request(
flag=item['flag'], approve=approve
)
rid = item["id"]
else:
await bot.set_group_add_request(
flag=item['flag'],
sub_type="invite",
approve=approve,
)
rid = item["invite_group"]
if rm_id is not None:
del self._data[type_][rm_id]
self.save()
return rid
return 2 # 未找到id return 2 # 未找到id