zhenxun_bot/utils/utils.py

364 lines
8.3 KiB
Python
Raw Normal View History

2021-11-04 16:11:50 +08:00
from datetime import datetime
2021-06-30 19:50:55 +08:00
from collections import defaultdict
from nonebot import require
2021-07-30 21:21:51 +08:00
from configs.config import SYSTEM_PROXY
2021-11-04 16:11:50 +08:00
from typing import List, Union, Optional, Type, Any
2021-10-03 14:24:07 +08:00
from nonebot.adapters.cqhttp import Bot
from nonebot.matcher import matchers, Matcher
2021-11-04 16:11:50 +08:00
import httpx
2021-06-30 19:50:55 +08:00
import nonebot
import pytz
import pypinyin
2021-07-30 21:21:51 +08:00
import time
try:
import ujson as json
except ModuleNotFoundError:
import json
2021-06-30 19:50:55 +08:00
2021-07-30 21:21:51 +08:00
# Module-level scheduler object: obtained through nonebot's ``require`` for the
# "nonebot_plugin_apscheduler" plugin, reading its ``scheduler`` attribute.
scheduler = require("nonebot_plugin_apscheduler").scheduler
2021-06-30 19:50:55 +08:00
class CountLimiter:
    """
    Call-count limiter: reports when a key has been hit ``max_count`` times.

    ``add`` increments the key's counter; ``check`` returns True once the
    counter has reached ``max_count`` and, in that case, resets it to zero.
    """

    def __init__(self, max_count: int):
        # key -> number of recorded hits (unseen keys start at 0)
        self.count = defaultdict(int)
        self.max_count = max_count

    def add(self, key: Any):
        """Record one more hit for ``key``."""
        self.count[key] = self.count[key] + 1

    def check(self, key: Any) -> bool:
        """
        Return True if ``key`` has reached ``max_count`` hits.

        A positive result also resets the key's counter to zero.
        """
        reached = self.count[key] >= self.max_count
        if reached:
            self.count[key] = 0
        return reached
2021-10-03 14:24:07 +08:00
class UserBlockLimiter:
    """
    Track whether a user (key) is currently running a command.

    A key flagged with ``set_true`` counts as busy until ``set_false`` is
    called for it, or until ``timeout`` seconds have passed since that same
    key was flagged. The expiry is tracked per key, so activity from other
    users cannot keep a stale flag alive.
    """

    def __init__(self, timeout: float = 30):
        # key -> True while that key is marked busy
        self.flag_data = defaultdict(bool)
        # Time of the most recent set_true for ANY key; kept only for
        # backward compatibility, the expiry check uses the per-key times.
        self.time = time.time()
        # key -> time.time() of that key's last set_true (0.0 = never set)
        self._set_time = defaultdict(float)
        # Seconds after which a busy flag is considered stale.
        self.timeout = timeout

    def set_true(self, key: Any):
        """Mark ``key`` as busy, starting its expiry window now."""
        now = time.time()
        self.time = now
        self._set_time[key] = now
        self.flag_data[key] = True

    def set_false(self, key: Any):
        """Clear the busy mark for ``key``."""
        self.flag_data[key] = False

    def check(self, key: Any) -> bool:
        """
        Return True if ``key`` is currently marked busy.

        A flag older than ``timeout`` seconds is cleared and reported as False.
        """
        if time.time() - self._set_time[key] > self.timeout:
            self.set_false(key)
            return False
        return self.flag_data[key]
2021-06-30 19:50:55 +08:00
class FreqLimiter:
    """
    Command cooldown: tracks whether a key (user) is still cooling down.

    ``start_cd`` puts a key on cooldown, ``check`` tells whether the
    cooldown is over, and ``left_time`` reports how many seconds remain.
    """

    def __init__(self, default_cd_seconds: int):
        # key -> absolute time.time() at which the cooldown ends (0.0 = never started)
        self.next_time = defaultdict(float)
        self.default_cd = default_cd_seconds

    def check(self, key: Any) -> bool:
        """Return True when ``key`` is not (or no longer) on cooldown."""
        now = time.time()
        return now >= self.next_time[key]

    def start_cd(self, key: Any, cd_time: int = 0):
        """
        Start a cooldown for ``key``.

        ``cd_time`` seconds are used when positive; otherwise ``default_cd``.
        """
        now = time.time()
        if cd_time > 0:
            duration = cd_time
        else:
            duration = self.default_cd
        self.next_time[key] = now + duration

    def left_time(self, key: Any) -> float:
        """Seconds until ``key``'s cooldown ends (negative once it has ended)."""
        remaining = self.next_time[key] - time.time()
        return remaining
# Shared module-level cooldown limiter with a 15-second default cooldown.
static_flmt = FreqLimiter(default_cd_seconds=15)
class BanCheckLimiter:
    """
    Malicious-command trigger detection.

    ``add`` counts hits per key. ``check`` returns True when at least
    ``default_count`` hits have accumulated and the current window started
    less than ``default_check_time`` seconds ago; a positive result resets
    the key. An expired window is reset and reported as False.
    """

    def __init__(self, default_check_time: float = 5, default_count: int = 4):
        # key -> number of hits recorded in the current window
        self.mint = defaultdict(int)
        # key -> time.time() at which the current window started
        self.mtime = defaultdict(float)
        self.default_check_time = default_check_time
        self.default_count = default_count

    def _reset(self, key: Union[str, int, float]):
        """Start a fresh, empty window for ``key`` at the current time."""
        self.mtime[key] = time.time()
        self.mint[key] = 0

    def add(self, key: Union[str, int, float]):
        """Record one hit for ``key``."""
        hits = self.mint[key]
        # NOTE(review): the window start is stamped when the count goes from
        # 1 to 2, not on the first hit after a reset -- confirm this is intended.
        if hits == 1:
            self.mtime[key] = time.time()
        self.mint[key] = hits + 1

    def check(self, key: Union[str, int, float]) -> bool:
        """Return True if ``key`` hit the trigger threshold inside its window."""
        if time.time() - self.mtime[key] > self.default_check_time:
            # Window expired: start over and do not flag.
            self._reset(key)
            return False
        flagged = (
            self.mint[key] >= self.default_count
            and time.time() - self.mtime[key] < self.default_check_time
        )
        if flagged:
            self._reset(key)
        return flagged
class DailyNumberLimiter:
    """
    Per-key daily call limit.

    All counters are cleared whenever ``check`` observes that the date in
    Asia/Shanghai has changed since the previous check.
    """

    tz = pytz.timezone("Asia/Shanghai")

    def __init__(self, max_num):
        # Date (in ``tz``) the counters belong to; -1 forces a reset on the first check.
        self.today = -1
        # key -> number of calls counted today
        self.count = defaultdict(int)
        self.max = max_num

    def check(self, key) -> bool:
        """Return True while ``key`` is still below the daily maximum."""
        # Compare the full date, not just the day-of-month: with ``.day`` alone the
        # reset was missed when consecutive checks fell on the same day number of
        # different months.
        today = datetime.now(self.tz).date()
        if today != self.today:
            self.today = today
            self.count.clear()
        return bool(self.count[key] < self.max)

    def get_num(self, key):
        """Return how many calls ``key`` has made today."""
        return self.count[key]

    def increase(self, key, num=1):
        """Add ``num`` calls to ``key``'s counter for today."""
        self.count[key] += num

    def reset(self, key):
        """Reset ``key``'s counter for today to zero."""
        self.count[key] = 0
2021-07-30 21:21:51 +08:00
def is_number(s: str) -> bool:
    """
    Description:
        Check whether ``s`` represents a number. That is anything ``float()``
        accepts (including "nan"/"inf"), or a single character that has a
        Unicode numeric value (e.g. "五", "½").
    Params:
        :param s: text
    """
    import unicodedata

    try:
        float(s)
    except ValueError:
        pass
    else:
        return True
    try:
        # numeric() raises TypeError for anything that is not exactly one character
        unicodedata.numeric(s)
    except (TypeError, ValueError):
        return False
    return True
2021-09-05 02:21:38 +08:00
def get_bot() -> Optional[Bot]:
    """
    Description:
        Return the first bot object from ``nonebot.get_bots()``, or None when
        there are none.
    """
    connected = nonebot.get_bots()
    return next(iter(connected.values()), None)
2021-06-30 19:50:55 +08:00
2021-10-03 14:24:07 +08:00
def get_matchers() -> List[Type[Matcher]]:
    """
    Description:
        Collect every matcher (all plugins) from nonebot's ``matchers``
        mapping into one flat list, group by group.
    """
    return [matcher for group in matchers.values() for matcher in group]
2021-07-30 21:21:51 +08:00
def get_message_at(data: str) -> List[int]:
    """
    Description:
        Get the qq numbers of all users @-mentioned in the message.
        ``at`` segments whose target is not a number (e.g. "all" for
        @everyone) are skipped.
    Params:
        :param data: event.json()
    Returns:
        the mentioned qq numbers in message order; [] if the payload is
        missing an expected key
    """
    try:
        qq_list = []
        payload = json.loads(data)
        for msg in payload["message"]:
            if msg["type"] != "at":
                continue
            target = msg["data"]["qq"]
            try:
                qq_list.append(int(target))
            except (TypeError, ValueError):
                # Not a user id (e.g. "all"); previously this crashed the call.
                continue
        return qq_list
    except KeyError:
        return []
2021-07-30 21:21:51 +08:00
def get_message_imgs(data: str) -> List[str]:
    """
    Description:
        Get the urls of all image segments in the message, in order.
    Params:
        :param data: event.json()
    Returns:
        the image urls; [] if the payload is missing an expected key
    """
    try:
        segments = json.loads(data)["message"]
        return [seg["data"]["url"] for seg in segments if seg["type"] == "image"]
    except KeyError:
        return []
2021-06-30 19:50:55 +08:00
def get_message_text(data: str) -> str:
    """
    Description:
        Get the plain text of the message. Each text segment is stripped,
        the pieces are joined with single spaces, and the result is stripped.
    Params:
        :param data: event.json()
    Returns:
        the combined text; "" if the payload is missing an expected key
    """
    try:
        segments = json.loads(data)["message"]
        pieces = [
            seg["data"]["text"].strip() for seg in segments if seg["type"] == "text"
        ]
    except KeyError:
        return ""
    return " ".join(pieces).strip()
2021-07-30 21:21:51 +08:00
def get_message_record(data: str) -> List[str]:
    """
    Description:
        Get the urls of all voice (record) segments in the message, in order.
    Params:
        :param data: event.json()
    Returns:
        the record urls; [] if the payload is missing an expected key
    """
    try:
        segments = json.loads(data)["message"]
        return [seg["data"]["url"] for seg in segments if seg["type"] == "record"]
    except KeyError:
        return []
2021-07-30 21:21:51 +08:00
def get_message_json(data: str) -> List[dict]:
    """
    Description:
        Get the ``data`` payload of every json segment in the message, in order.
    Params:
        :param data: event.json()
    Returns:
        the segment data dicts; [] if the payload is missing an expected key
    """
    try:
        segments = json.loads(data)["message"]
        return [seg["data"] for seg in segments if seg["type"] == "json"]
    except KeyError:
        return []
2021-06-30 19:50:55 +08:00
def get_local_proxy():
    """
    Description:
        Return the proxy configured as SYSTEM_PROXY in config.py, or None
        when that setting is empty (any falsy value).
    """
    return SYSTEM_PROXY or None
2021-11-04 16:11:50 +08:00
def is_chinese(word: str) -> bool:
    """
    Description:
        Check whether every character of ``word`` lies in the CJK Unified
        Ideographs range U+4E00..U+9FFF. An empty string yields True.
    Params:
        :param word: text
    """
    return all("\u4e00" <= ch <= "\u9fff" for ch in word)
2021-06-30 19:50:55 +08:00
2021-11-04 16:11:50 +08:00
async def _fetch_avatar(url: str, retries: int = 3) -> Optional[bytes]:
    """
    Description:
        Download ``url`` and return the response body, retrying on timeouts.
    Params:
        :param url: image url
        :param retries: maximum number of attempts
    Returns:
        the response body, or None if every attempt timed out
    """
    async with httpx.AsyncClient() as client:
        for _ in range(retries):
            try:
                return (await client.get(url)).content
            except (httpx.TimeoutException, TimeoutError):
                # httpx raises its own TimeoutException subclasses rather than the
                # builtin TimeoutError, so both are caught to actually retry.
                continue
    return None


async def get_user_avatar(qq: int) -> Optional[bytes]:
    """
    Description:
        Quickly fetch a user's avatar (up to 3 attempts on timeout).
    Params:
        :param qq: qq number
    Returns:
        the avatar image bytes, or None if every attempt timed out
    """
    return await _fetch_avatar(f"http://q1.qlogo.cn/g?b=qq&nk={qq}&s=160")


async def get_group_avatar(group_id: int) -> Optional[bytes]:
    """
    Description:
        Quickly fetch a group's avatar (up to 3 attempts on timeout).
    Params:
        :param group_id: group number
    Returns:
        the avatar image bytes, or None if every attempt timed out
    """
    return await _fetch_avatar(f"http://p.qlogo.cn/gh/{group_id}/{group_id}/640/")
2021-06-30 19:50:55 +08:00
2021-07-30 21:21:51 +08:00
def cn2py(word: str) -> str:
    """
    Description:
        Convert a string to pinyin. The readings pypinyin returns in its
        NORMAL style are concatenated with no separators.
    Params:
        :param word: text
    """
    syllables = pypinyin.pinyin(word, style=pypinyin.NORMAL)
    return "".join("".join(group) for group in syllables)
2021-09-05 02:21:38 +08:00
def change_picture_links(url: str, mode: str) -> str:
    """
    Description:
        Adjust an image link according to the configured size mode. In
        "master" mode, "original" in the path becomes "master" and
        "_master1200" is inserted before the file extension. Other modes
        return the url unchanged.
    Params:
        :param url: link to the original image
        :param mode: mode
    Returns:
        the (possibly rewritten) url; returned unchanged when it has no file
        extension in its last path segment
    """
    if mode == "master":
        base, sep, img_type = url.rpartition(".")
        # Only rewrite when the final path segment really has an extension.
        # Otherwise the old split either raised IndexError (no dot at all) or
        # cut inside the host name and produced a garbage url.
        if sep and "/" not in img_type:
            url = base.replace("original", "master") + f"_master1200.{img_type}"
    return url
2021-11-04 16:11:50 +08:00