zhenxun_bot/plugins/word_clouds/data_source.py

130 lines
4.1 KiB
Python
Raw Normal View History

2022-05-18 20:31:14 +08:00
import asyncio
import os
import random
import jieba.analyse
import re
from typing import List
from PIL import Image as IMG
import jieba
from emoji import replace_emoji # type: ignore
from wordcloud import WordCloud, ImageColorGenerator
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
2022-05-19 15:14:20 +08:00
from configs.path_config import IMAGE_PATH, FONT_PATH
2022-05-23 22:01:56 +08:00
from services import logger
2022-05-18 20:31:14 +08:00
from utils.http_utils import AsyncHttpx
from models.chat_history import ChatHistory
from configs.config import Config
2022-05-19 15:14:20 +08:00
async def pre_precess(msg: List[str], config) -> str:
    """Run the message clean-up off the event loop and return the cleaned text.

    The synchronous ``_pre_precess`` is handed to the loop's default executor
    (``run_in_executor`` with ``None``), and its result is awaited.

    Args:
        msg: Message texts to clean.
        config: Forwarded unchanged to ``_pre_precess``.

    Returns:
        Whatever ``_pre_precess(msg, config)`` returns.
    """
    loop = asyncio.get_event_loop()
    cleaned = await loop.run_in_executor(None, _pre_precess, msg, config)
    return cleaned
2022-05-18 20:31:14 +08:00
2022-05-23 22:01:56 +08:00
def _pre_precess(msg: List[str], config) -> str:
    """Join chat messages into one string and strip non-textual noise.

    Args:
        msg: Raw message texts.
        config: Object exposing ``command_start``, an iterable of command
            prefixes. Empty prefixes are ignored.

    Returns:
        The space-joined messages with command messages, URLs, zero-width
        spaces, CQ codes, escaped square brackets and emoji removed.
    """
    # Drop messages that start with a command prefix. Empty prefixes are
    # skipped because every string starts with "" and would be filtered out.
    prefixes = tuple(prefix for prefix in config.command_start if prefix)
    text = " ".join(m for m in msg if not m.startswith(prefixes))
    # Remove URLs.
    text = re.sub(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+", "", text)
    # Remove zero-width spaces (U+200B).
    text = re.sub(r"[\u200b]", "", text)
    # Remove CQ codes such as "[CQ:...]".
    text = re.sub(r"\[CQ:.*?]", "", text)
    # Remove the escaped brackets "&#91;" and "&#93;" as whole entities.
    # A character class such as "[&#9(1|3);]" would instead delete every
    # single '&', '#', '9', '(', '1', '|', '3', ')' and ';' from the text.
    text = re.sub(r"&#9[13];", "", text)
    # Remove emoji (https://github.com/carpedm20/emoji).
    return replace_emoji(text)
async def draw_word_cloud(messages, config):
    """Render a word cloud from chat messages and return it as PNG bytes.

    Args:
        messages: List of message texts.
        config: Forwarded to ``pre_precess`` for message clean-up.

    Returns:
        The PNG-encoded image as ``bytes``, or ``False`` when a required
        resource could not be downloaded or no keywords were extracted.
    """
    wordcloud_dir = IMAGE_PATH / "wordcloud"
    wordcloud_dir.mkdir(exist_ok=True, parents=True)
    # The zhenxun logo is the default mask image.
    zx_logo_path = wordcloud_dir / "default.png"
    wordcloud_ttf = FONT_PATH / "STKAITI.TTF"

    # Only fetch the default mask when the directory holds no file at all.
    if not os.listdir(wordcloud_dir):
        url = "https://ghproxy.com/https://raw.githubusercontent.com/HibiKier/zhenxun_bot/main/resources/image/wordcloud/default.png"
        try:
            await AsyncHttpx.download_file(url, zx_logo_path)
        except Exception as e:
            logger.error(f"词云图片资源下载发生错误 {type(e)}{e}")
            return False
    if not wordcloud_ttf.exists():
        ttf_url = "https://ghproxy.com/https://raw.githubusercontent.com/HibiKier/zhenxun_bot/main/resources/font/STKAITI.TTF"
        try:
            await AsyncHttpx.download_file(ttf_url, wordcloud_ttf)
        except Exception as e:
            logger.error(f"词云字体资源下载发生错误 {type(e)}{e}")
            return False

    # Ask for at most one keyword per message, capped at 100000.
    top_k = min(len(messages), 100000)
    tags = jieba.analyse.extract_tags(
        await pre_precess(messages, config), topK=top_k, withWeight=True, allowPOS=()
    )
    # extract_tags(withWeight=True) yields (word, weight) pairs.
    frequencies = {str(word): weight for word, weight in tags}
    if not frequencies:
        # WordCloud.generate_from_frequencies raises ValueError on an empty
        # mapping; report failure the same way as the download errors above.
        # NOTE(review): assumes callers already check for a False result.
        return False

    if Config.get_config("word_clouds", "WORD_CLOUDS_TEMPLATE") == 1:
        # Template mode: a randomly chosen file from the wordcloud directory
        # serves both as the shape mask and as the colour source.
        mask_path = wordcloud_dir / random.choice(os.listdir(wordcloud_dir))
        mask = np.array(IMG.open(mask_path))
        wc = WordCloud(
            font_path=str(wordcloud_ttf),
            background_color="white",
            max_font_size=100,
            width=1920,
            height=1080,
            mask=mask,
        )
        wc.generate_from_frequencies(frequencies)
        image_colors = ImageColorGenerator(mask, default_color=(255, 255, 255))
        wc.recolor(color_func=image_colors)
    else:
        wc = WordCloud(
            font_path=str(wordcloud_ttf),
            width=1920,
            height=1200,
            background_color="black",
        )
        wc.generate_from_frequencies(frequencies)

    # The picture comes straight from WordCloud.to_image(); no matplotlib
    # figure is involved, so nothing accumulates in pyplot's global state.
    buffer = BytesIO()
    wc.to_image().save(buffer, format="PNG")
    return buffer.getvalue()
async def get_list_msg(user_id, group_id, days):
    """Fetch the text of stored group chat messages.

    Args:
        user_id: Passed to the chat-history query as ``uid``.
        group_id: Passed to the chat-history query as ``gid``.
        days: Passed to the chat-history query as ``days``.

    Returns:
        A list with the ``text`` attribute of every record returned by the
        query, or ``False`` when the query returns nothing.
    """
    query = ChatHistory()._get_msg(
        uid=user_id, gid=group_id, type_="group", days=days
    )
    records = await query.gino.all()
    if not records:
        return False
    return [record.text for record in records]