From bf55a20241fb23d6076af25b230bfcf93715e508 Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Tue, 9 Jan 2024 13:47:24 +0800 Subject: [PATCH] =?UTF-8?q?perf=F0=9F=91=8C:=20=E5=AE=9E=E7=8E=B0=E4=BC=98?= =?UTF-8?q?=E5=8C=96webui=E7=BE=A4=E7=BB=84/=E5=A5=BD=E5=8F=8B=E7=AE=A1?= =?UTF-8?q?=E7=90=86api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/web_ui/__init__.py | 2 + plugins/web_ui/api/logs/log_manager.py | 13 +- plugins/web_ui/api/tabs/main/__init__.py | 8 +- plugins/web_ui/api/tabs/manage/__init__.py | 280 +++++++++++++++++---- plugins/web_ui/api/tabs/manage/model.py | 139 ++++++++-- plugins/web_ui/config.py | 5 + utils/manager/requests_manager.py | 91 ++++++- 7 files changed, 443 insertions(+), 95 deletions(-) diff --git a/plugins/web_ui/__init__.py b/plugins/web_ui/__init__.py index cb05b979..3de51bfe 100644 --- a/plugins/web_ui/__init__.py +++ b/plugins/web_ui/__init__.py @@ -18,6 +18,7 @@ from .api.tabs.database import router as database_router from .api.tabs.main import router as main_router from .api.tabs.main import ws_router as status_routes from .api.tabs.manage import router as manage_router +from .api.tabs.manage import ws_router as chat_routes from .api.tabs.plugin_manage import router as plugin_router from .auth import router as auth_router @@ -42,6 +43,7 @@ WsApiRouter = APIRouter(prefix="/zhenxun/socket") WsApiRouter.include_router(ws_log_routes) WsApiRouter.include_router(status_routes) +WsApiRouter.include_router(chat_routes) @driver.on_startup diff --git a/plugins/web_ui/api/logs/log_manager.py b/plugins/web_ui/api/logs/log_manager.py index c7c8140e..f375313d 100644 --- a/plugins/web_ui/api/logs/log_manager.py +++ b/plugins/web_ui/api/logs/log_manager.py @@ -21,18 +21,6 @@ class LogStorage(Generic[_T]): self.listeners: Set[LogListener[str]] = set() async def add(self, log: str): - # log = re.sub(PATTERN, "", log) - # log_split = log.split() - # time = log_split[0] + " " + log_split[1] - # level = log_split[2] - # main = log_split[3] - # type_ = None - # log_ = " ".join(log_split[3:]) - # if "Calling API" in log_: - # sp = log_.split("|") - # type_ = sp[1] - # log_ = "|".join(log_[1:]) - # data = {"time": time, "level": level, "main": main, "type": type_, "log": log_} seq = self.count = self.count + 1 self.logs[seq] = log asyncio.get_running_loop().call_later(self.rotation, self.remove, seq) @@ -48,3 +36,4 @@ class LogStorage(Generic[_T]): LOG_STORAGE: LogStorage[str] = LogStorage[str]() + diff --git a/plugins/web_ui/api/tabs/main/__init__.py b/plugins/web_ui/api/tabs/main/__init__.py index 125cb587..6fec031c 100644 --- a/plugins/web_ui/api/tabs/main/__init__.py +++ b/plugins/web_ui/api/tabs/main/__init__.py @@ -5,7 +5,6 @@ from typing import List, Optional import nonebot from fastapi import APIRouter, WebSocket -from nonebot.utils import escape_tag from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState from tortoise.functions import Count from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK @@ -18,21 +17,18 @@ from services.log import logger from utils.manager import plugin_data_manager, plugins2settings_manager, plugins_manager from ....base_model import Result -from ....config import QueryDateType +from ....config import AVA_URL, GROUP_AVA_URL, QueryDateType from ....utils import authentication, get_system_status from .data_source import bot_live from .model import ActiveGroup, BaseInfo, ChatHistoryCount, HotPlugin -AVA_URL = "http://q1.qlogo.cn/g?b=qq&nk={}&s=160" - -GROUP_AVA_URL = "http://p.qlogo.cn/gh/{}/{}/640/" - run_time = time.time() ws_router = APIRouter() router = APIRouter() + @router.get("/get_base_info", dependencies=[authentication()], description="基础信息") async def _(bot_id: Optional[str] = None) -> Result: """ diff --git a/plugins/web_ui/api/tabs/manage/__init__.py b/plugins/web_ui/api/tabs/manage/__init__.py index a764f1bd..96e70ccb 100644 --- a/plugins/web_ui/api/tabs/manage/__init__.py +++ b/plugins/web_ui/api/tabs/manage/__init__.py @@ -1,32 +1,54 @@ -from typing import Literal +import re +from typing import Literal, Optional import nonebot from fastapi import APIRouter -from pydantic.error_wrappers import ValidationError +from nonebot.adapters.onebot.v11.exception import ActionFailed +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState +from tortoise.functions import Count +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from configs.config import NICKNAME +from models.ban_user import BanUser +from models.chat_history import ChatHistory +from models.friend_user import FriendUser from models.group_info import GroupInfo +from models.group_member_info import GroupInfoUser +from models.statistics import Statistics from services.log import logger -from utils.manager import group_manager, requests_manager +from utils.manager import group_manager, plugin_data_manager, requests_manager from utils.utils import get_bot from ....base_model import Result +from ....config import AVA_URL, GROUP_AVA_URL from ....utils import authentication +from ...logs.log_manager import LOG_STORAGE from .model import ( DeleteFriend, Friend, FriendRequestResult, - Group, + GroupDetail, GroupRequestResult, GroupResult, HandleRequest, LeaveGroup, + Message, + Plugin, + ReqResult, + SendMessage, Task, UpdateGroup, + UserDetail, ) +ws_router = APIRouter() router = APIRouter() +SUB_PATTERN = r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))" + +GROUP_PATTERN = r'.*?Message (-?\d*) from (\d*)@\[群:(\d*)] "(.*)"' + +PRIVATE_PATTERN = r'.*?Message (-?\d*) from (\d*) "(.*)"' @router.get("/get_group_list", dependencies=[authentication()], description="获取群组列表") async def _(bot_id: str) -> Result: @@ -41,27 +63,9 @@ async def _(bot_id: str) -> Result: group_info = {} group_list = await bots[bot_id].get_group_list() for g in group_list: - group_info[g["group_id"]] = Group(**g) - group_data = group_manager.get_data() - for group_id in group_data.group_manager: - task_list = [] - data = group_manager[group_id].dict() - for tn, status in data["group_task_status"].items(): - task_list.append( - Task( - **{ - "name": tn, - "nameZh": group_manager.get_task_data().get(tn) or tn, - "status": status, - } - ) - ) - data["task"] = task_list - if x := group_info.get(int(group_id)): - data["group"] = x - else: - continue - group_list_result.append(GroupResult(**data)) + gid = g['group_id'] + g['ava_url'] = GROUP_AVA_URL.format(gid, gid) + group_list_result.append(GroupResult(**g)) except Exception as e: logger.error("调用API错误", "/get_group_list", e=e) return Result.fail(f"{type(e)}: {e}") @@ -78,12 +82,14 @@ async def _(group: UpdateGroup) -> Result: group_manager.turn_on_group_bot_status(group_id) else: group_manager.shutdown_group_bot_status(group_id) - if group.task_status: - for task in group.task_status: - if group.task_status[task]: + all_task = group_manager.get_task_data().keys() + if group.task: + for task in all_task: + if task in group.task: group_manager.open_group_task(group_id, task) else: group_manager.close_group_task(group_id, task) + group_manager[group_id].close_plugins = group.close_plugins group_manager.save() except Exception as e: logger.error("调用API错误", "/get_group", e=e) @@ -101,7 +107,9 @@ async def _(bot_id: str) -> Result: return Result.warning_("指定Bot未连接...") try: friend_list = await bots[bot_id].get_friend_list() - return Result.ok([Friend(**f) for f in friend_list], "拿到了新鲜出炉的数据!") + for f in friend_list: + f['ava_url'] = AVA_URL.format(f['user_id']) + return Result.ok([Friend(**f) for f in friend_list if str(f['user_id']) != bot_id], "拿到了新鲜出炉的数据!") except Exception as e: logger.error("调用API错误", "/get_group_list", e=e) return Result.fail(f"{type(e)}: {e}") @@ -118,21 +126,27 @@ def _() -> Result: @router.get("/get_request_list", dependencies=[authentication()], description="获取请求列表") -def _(request_type: Literal["private", "group"]) -> Result: +def _() -> Result: try: - req_data = requests_manager.get_data().get(request_type) or [] - req_list = [] - for x in req_data: - req_data[x]["oid"] = x - if request_type == "private": - req_list.append(FriendRequestResult(**req_data[x])) - else: - req_list.append(GroupRequestResult(**req_data[x])) - req_list.reverse() + req_result = ReqResult() + data = requests_manager.get_data() + for type_ in requests_manager.get_data(): + for x in data[type_]: + data[type_][x]["oid"] = x + data[type_][x]['type'] = type_ + if type_ == "private": + data[type_][x]['ava_url'] = AVA_URL.format(data[type_][x]['id']) + req_result.friend.append(FriendRequestResult(**data[type_][x])) + else: + gid = data[type_][x]['id'] + data[type_][x]['ava_url'] = GROUP_AVA_URL.format(gid, gid) + req_result.group.append(GroupRequestResult(**data[type_][x])) + req_result.friend.reverse() + req_result.group.reverse() except Exception as e: logger.error("调用API错误", "/get_request", e=e) return Result.fail(f"{type(e)}: {e}") - return Result.ok(req_list, f"{NICKNAME}带来了最新的数据!") + return Result.ok(req_result, f"{NICKNAME}带来了最新的数据!") @router.delete("/clear_request", dependencies=[authentication()], description="清空请求列表") @@ -156,14 +170,18 @@ async def _(parma: HandleRequest) -> Result: bot_id = parma.bot_id if bot_id not in nonebot.get_bots(): return Result.warning_("指定Bot未连接...") - flag = await requests_manager.refused(bots[bot_id], parma.id, parma.request_type) # type: ignore + try: + flag = await requests_manager.refused(bots[bot_id], parma.flag, parma.request_type) # type: ignore + except ActionFailed as e: + requests_manager.delete_request(parma.flag, parma.request_type) + return Result.warning_("请求失败,可能该请求已失效或请求数据错误...") if flag == 1: - requests_manager.delete_request(parma.id, parma.request_type) + requests_manager.delete_request(parma.flag, parma.request_type) return Result.warning_("该请求已失效...") elif flag == 2: return Result.warning_("未找到此Id请求...") return Result.ok(info="成功处理了请求!") - return Result.warning_("Bot未连接...") + return Result.warning_("无Bot连接...") except Exception as e: logger.error("调用API错误", "/refuse_request", e=e) return Result.fail(f"{type(e)}: {e}") @@ -175,7 +193,7 @@ async def _(parma: HandleRequest) -> Result: 操作请求 :param parma: 参数 """ - requests_manager.delete_request(parma.id, parma.request_type) + requests_manager.delete_request(parma.flag, parma.request_type) return Result.ok(info="成功处理了请求!") @@ -191,7 +209,7 @@ async def _(parma: HandleRequest) -> Result: if bot_id not in nonebot.get_bots(): return Result.warning_("指定Bot未连接...") if parma.request_type == "group": - if rid := requests_manager.get_group_id(parma.id): + if rid := requests_manager.get_group_id(parma.flag): if group := await GroupInfo.get_or_none(group_id=str(rid)): await group.update_or_create(group_flag=1) else: @@ -205,9 +223,13 @@ async def _(parma: HandleRequest) -> Result: "group_flag": 1, }, ) - await requests_manager.approve(bots[bot_id], parma.id, parma.request_type) # type: ignore - return Result.ok(info="成功处理了请求!") - return Result.warning_("Bot未连接...") + try: + await requests_manager.approve(bots[bot_id], parma.flag, parma.request_type) # type: ignore + return Result.ok(info="成功处理了请求!") + except ActionFailed as e: + requests_manager.delete_request(parma.flag, parma.request_type) + return Result.warning_("请求失败,可能该请求已失效或请求数据错误...") + return Result.warning_("无Bot连接...") except Exception as e: logger.error("调用API错误", "/approve_request", e=e) return Result.fail(f"{type(e)}: {e}") @@ -223,7 +245,7 @@ async def _(param: LeaveGroup) -> Result: return Result.warning_("Bot未在该群聊中...") await bots[bot_id].set_group_leave(group_id=param.group_id) return Result.ok(info="成功处理了请求!") - return Result.warning_("Bot未连接...") + return Result.warning_("无Bot连接...") except Exception as e: logger.error("调用API错误", "/leave_group", e=e) return Result.fail(f"{type(e)}: {e}") @@ -243,3 +265,163 @@ async def _(param: DeleteFriend) -> Result: except Exception as e: logger.error("调用API错误", "/delete_friend", e=e) return Result.fail(f"{type(e)}: {e}") + + + +@router.get("/get_friend_detail", dependencies=[authentication()], description="获取好友详情") +async def _(bot_id: str, user_id: str) -> Result: + if bots := nonebot.get_bots(): + if bot_id in bots: + if fd := [x for x in await bots[bot_id].get_friend_list() if str(x['user_id']) == user_id]: + like_plugin_list = ( + await Statistics.filter(user_id=user_id).annotate(count=Count("id")) + .group_by("plugin_name").order_by("-count").limit(5) + .values_list("plugin_name", "count") + ) + like_plugin = {} + for data in like_plugin_list: + name = data[0] + if plugin_data := plugin_data_manager.get(data[0]): + name = plugin_data.name + like_plugin[name] = data[1] + user = fd[0] + user_detail = UserDetail( + user_id=user_id, + ava_url=AVA_URL.format(user_id), + nickname=user['nickname'], + remark=user['remark'], + is_ban=await BanUser.is_ban(user_id), + chat_count=await ChatHistory.filter(user_id=user_id).count(), + call_count=await Statistics.filter(user_id=user_id).count(), + like_plugin=like_plugin, + ) + return Result.ok(user_detail) + else: + return Result.warning_("未添加指定好友...") + return Result.warning_("无Bot连接...") + + +@router.get("/get_group_detail", dependencies=[authentication()], description="获取群组详情") +async def _(bot_id: str, group_id: str) -> Result: + if bots := nonebot.get_bots(): + if bot_id in bots: + group_info = await bots[bot_id].get_group_info(group_id=int(group_id)) + g = group_manager[group_id] + if not g: + return Result.warning_("指定群组未被收录...") + if group_info: + like_plugin_list = ( + await Statistics.filter(group_id=group_id).annotate(count=Count("id")) + .group_by("plugin_name").order_by("-count").limit(5) + .values_list("plugin_name", "count") + ) + like_plugin = {} + for data in like_plugin_list: + name = data[0] + if plugin_data := plugin_data_manager.get(data[0]): + name = plugin_data.name + like_plugin[name] = data[1] + close_plugins = [] + for module in g.close_plugins: + plugin = Plugin(module=module, plugin_name=module) + if plugin_data := plugin_data_manager.get(module): + plugin.plugin_name = plugin_data.name + close_plugins.append(plugin) + task_list = [] + task_data = group_manager.get_task_data() + for tn, status in g.group_task_status.items(): + task_list.append( + Task( + name=tn, + zh_name=task_data.get(tn) or tn, + status=status + ) + ) + group_detail = GroupDetail( + group_id=group_id, + ava_url=GROUP_AVA_URL.format(group_id, group_id), + name=group_info['group_name'], + member_count=group_info['member_count'], + max_member_count=group_info['max_member_count'], + chat_count=await ChatHistory.filter(group_id=group_id).count(), + call_count=await Statistics.filter(group_id=group_id).count(), + like_plugin=like_plugin, + level=g.level, + status=g.status, + close_plugins=close_plugins, + task=task_list + ) + return Result.ok(group_detail) + else: + return Result.warning_("未添加指定群组...") + return Result.warning_("无Bot连接...") + + +@router.post("/send_message", dependencies=[authentication()], description="获取群组详情") +async def _(param: SendMessage) -> Result: + if bots := nonebot.get_bots(): + if param.bot_id in bots: + try: + if param.user_id: + await bots[param.bot_id].send_private_msg(user_id=str(param.user_id), message=param.message) + else: + await bots[param.bot_id].send_group_msg(group_id=str(param.group_id), message=param.message) + except Exception as e: + return Result.fail(str(e)) + return Result.ok("发送成功!") + return Result.warning_("指定Bot未连接...") + return Result.warning_("无Bot连接...") + +MSG_LIST = [] + +@ws_router.websocket("/chat") +async def _(websocket: WebSocket, group_id: Optional[str] = None, user_id: Optional[str] = None): + global MSG_LIST + await websocket.accept() + + async def log_listener(log: str): + sub_log = re.sub(SUB_PATTERN, "", log) + if "message.private.friend" in log: + if r := re.search(PRIVATE_PATTERN, sub_log): + msg_id = r.group(1) + uid = r.group(2) + msg = r.group(3) + user = await FriendUser.filter(user_id=user_id).first() + name = user.user_name + if uid and uid == user_id and msg_id not in MSG_LIST: + MSG_LIST.append(msg_id) + message = Message( + user_id=uid, + message=msg, + name=name, + ava_url=AVA_URL.format(uid) + ) + await websocket.send_json(message.dict()) + else: + if r := re.search(GROUP_PATTERN, sub_log): + msg_id = r.group(1) + uid = r.group(2) + gid = r.group(3) + msg = r.group(4) + user = await GroupInfoUser.filter(user_id=uid, group_id=gid).first() + name = user.user_name or user.nickname + if gid and gid == group_id and msg_id not in MSG_LIST: + MSG_LIST.append(msg_id) + message = Message( + user_id=uid, + group_id=gid, + message=msg, + name=name, + ava_url=AVA_URL.format(uid) + ) + await websocket.send_json(message.dict()) + + LOG_STORAGE.listeners.add(log_listener) + try: + while websocket.client_state == WebSocketState.CONNECTED: + recv = await websocket.receive() + except WebSocketDisconnect: + pass + finally: + LOG_STORAGE.listeners.remove(log_listener) + return \ No newline at end of file diff --git a/plugins/web_ui/api/tabs/manage/model.py b/plugins/web_ui/api/tabs/manage/model.py index 089da80d..bc8df37f 100644 --- a/plugins/web_ui/api/tabs/manage/model.py +++ b/plugins/web_ui/api/tabs/manage/model.py @@ -27,27 +27,33 @@ class Task(BaseModel): name: str """被动名称""" - nameZh: str + zh_name: str """被动中文名称""" status: bool """状态""" +class Plugin(BaseModel): + """ + 插件 + """ + + module: str + """模块名""" + plugin_name: str + """中文名""" + class GroupResult(BaseModel): """ 群组返回数据 """ - group: Group - """Group""" - level: int - """群等级""" - status: bool - """状态""" - close_plugins: List[str] - """关闭的插件""" - task: List[Task] - """被动列表""" + group_id: Union[str, int] + """群组id""" + group_name: str + """群组名称""" + ava_url: str + """群组头像""" class Friend(BaseModel): @@ -61,7 +67,8 @@ class Friend(BaseModel): """昵称""" remark: str = "" """备注""" - + ava_url: str = "" + """头像url""" class UpdateGroup(BaseModel): """ @@ -74,8 +81,10 @@ class UpdateGroup(BaseModel): """状态""" level: int """群权限""" - task_status: Dict[str, bool] + task: List[str] """被动状态""" + close_plugins: List[str] + """关闭插件""" class FriendRequestResult(BaseModel): @@ -103,6 +112,10 @@ class FriendRequestResult(BaseModel): """来自""" comment: Optional[str] """备注信息""" + ava_url: str + """头像""" + type: str + """类型 private group""" class GroupRequestResult(FriendRequestResult): @@ -121,10 +134,10 @@ class HandleRequest(BaseModel): 操作请求接收数据 """ - bot_id: str + bot_id: Optional[str] = None """bot_id""" - id: int - """id""" + flag: str + """flag""" request_type: Literal["private", "group"] """类型""" @@ -149,3 +162,97 @@ class DeleteFriend(BaseModel): """bot_id""" user_id: str """用户id""" + +class ReqResult(BaseModel): + """ + 好友/群组请求列表 + """ + + friend: List[FriendRequestResult] = [] + """好友请求列表""" + group: List[GroupRequestResult] = [] + """群组请求列表""" + + +class UserDetail(BaseModel): + """ + 用户详情 + """ + + user_id: str + """用户id""" + ava_url: str + """头像url""" + nickname: str + """昵称""" + remark: str + """备注""" + is_ban: bool + """是否被ban""" + chat_count: int + """发言次数""" + call_count: int + """功能调用次数""" + like_plugin: Dict[str, int] + """最喜爱的功能""" + + +class GroupDetail(BaseModel): + """ + 用户详情 + """ + + group_id: str + """群组id""" + ava_url: str + """头像url""" + name: str + """名称""" + member_count: int + """成员数""" + max_member_count: int + """最大成员数""" + chat_count: int + """发言次数""" + call_count: int + """功能调用次数""" + like_plugin: Dict[str, int] + """最喜爱的功能""" + level: int + """群权限""" + status: bool + """状态(睡眠)""" + close_plugins: List[Plugin] + """关闭的插件""" + task: List[Task] + """被动列表""" + + +class Message(BaseModel): + """ + 消息 + """ + + user_id: str + """用户id""" + group_id: Optional[str] = None + """群组id""" + message: str + """消息""" + name: str + """用户名称""" + ava_url: str + """用户头像""" + +class SendMessage(BaseModel): + """ + 发送消息 + """ + bot_id: str + """bot id""" + user_id: Optional[str] = None + """用户id""" + group_id: Optional[str] = None + """群组id""" + message: str + """消息""" diff --git a/plugins/web_ui/config.py b/plugins/web_ui/config.py index a85d9126..b6e6fde0 100644 --- a/plugins/web_ui/config.py +++ b/plugins/web_ui/config.py @@ -20,6 +20,11 @@ app.add_middleware( ) +AVA_URL = "http://q1.qlogo.cn/g?b=qq&nk={}&s=160" + +GROUP_AVA_URL = "http://p.qlogo.cn/gh/{}/{}/640/" + + class QueryDateType(StrEnum): """ diff --git a/utils/manager/requests_manager.py b/utils/manager/requests_manager.py index c97f2000..d9b5175c 100644 --- a/utils/manager/requests_manager.py +++ b/utils/manager/requests_manager.py @@ -1,6 +1,6 @@ from io import BytesIO from pathlib import Path -from typing import Optional +from typing import Optional, Union, overload from nonebot.adapters.onebot.v11 import ActionFailed, Bot @@ -67,14 +67,24 @@ class RequestManager(StaticData): } self.save() + @overload + def remove_request(self, type_: str, flag: str): + ... + + @overload def remove_request(self, type_: str, id_: int): + ... + + def remove_request(self, type_: str, id_: Union[int, str]): """ 删除一个请求数据 :param type_: 类型 :param id_: id,user_id 或 group_id """ for x in self._data[type_].keys(): - if self._data[type_][x].get("id") == id_: + a_id = self._data[type_][x].get("id") + a_flag = self._data[type_][x].get("flag") + if a_id == id_ or a_flag == id_: del self._data[type_][x] break self.save() @@ -88,8 +98,17 @@ class RequestManager(StaticData): if data: return data["invite_group"] return None - + + @overload async def approve(self, bot: Bot, id_: int, type_: str) -> int: + ... + + @overload + async def approve(self, bot: Bot, flag: str, type_: str) -> int: + ... + + + async def approve(self, bot: Bot, id_: Union[int, str], type_: str) -> int: """ 同意请求 :param bot: Bot @@ -97,8 +116,16 @@ class RequestManager(StaticData): :param type_: 类型,private 或 group """ return await self._set_add_request(bot, id_, type_, True) + + @overload + async def refused(self, bot: Bot, id_: int, type_: str) -> int: + ... - async def refused(self, bot: Bot, id_: int, type_: str) -> Optional[int]: + @overload + async def refused(self, bot: Bot, flag: str, type_: str) -> int: + ... + + async def refused(self, bot: Bot, id_: Union[int, str], type_: str) -> Optional[int]: """ 拒绝请求 :param bot: Bot @@ -120,18 +147,32 @@ class RequestManager(StaticData): self._data = {"private": {}, "group": {}} self.save() + @overload + async def delete_request(self, id_: int, type_: str) -> int: + ... + + @overload + async def delete_request(self, flag: str, type_: str) -> int: + ... + def delete_request( - self, id_: int, type_: str + self, id_: Union[str, int], type_: str ): # type_: Literal["group", "private"] """ 删除请求 :param id_: id :param type_: 类型 """ - id_ = str(id_) - if self._data[type_].get(id_): - del self._data[type_][id_] - self.save() + if type(id_) == int: + if self._data[type_].get(id_): + del self._data[type_][id_] + self.save() + else: + for k, item in self._data[type_].items(): + if item['flag'] == id_: + del self._data[type_][k] + self.save() + break def set_group_name(self, group_name: str, group_id: int): """ @@ -239,7 +280,7 @@ class RequestManager(StaticData): return bk.pic2bs4() async def _set_add_request( - self, bot: Bot, idx: int, type_: str, approve: bool + self, bot: Bot, idx: Union[str, int], type_: str, approve: bool ) -> int: """ 处理请求 @@ -248,8 +289,13 @@ class RequestManager(StaticData): :param type_: 类型,private 或 group :param approve: 是否同意 """ - id_ = str(idx) - if id_ in self._data[type_].keys(): + flag = None + id_ = None + if type(idx) == str: + flag = idx + else: + id_ = str(idx) + if id_ and id_ in self._data[type_].keys(): try: if type_ == "private": await bot.set_friend_add_request( @@ -277,4 +323,25 @@ class RequestManager(StaticData): del self._data[type_][id_] self.save() return rid + if flag: + rm_id = None + for k, item in self._data[type_].items(): + if item['flag'] == flag: + rm_id = k + if type_ == 'private': + await bot.set_friend_add_request( + flag=item['flag'], approve=approve + ) + rid = item["id"] + else: + await bot.set_group_add_request( + flag=item['flag'], + sub_type="invite", + approve=approve, + ) + rid = item["invite_group"] + if rm_id is not None: + del self._data[type_][rm_id] + self.save() + return rid return 2 # 未找到id