zhenxun_bot/plugins/word_bank/_model.py
2023-02-19 12:29:51 +08:00

504 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import re
import time
from datetime import datetime
from typing import Any, List, Optional, Tuple, Union
from nonebot.adapters.onebot.v11 import (
GroupMessageEvent,
Message,
MessageEvent,
MessageSegment,
)
from nonebot.internal.adapter.template import MessageTemplate
from tortoise import Tortoise, fields
from tortoise.expressions import Q
from configs.path_config import DATA_PATH
from services.db_context import Model
from utils.http_utils import AsyncHttpx
from utils.image_utils import get_img_hash
from utils.message_builder import at, face, image
from utils.utils import get_message_img
from ._config import int2type
# Root directory for word-bank files. WordBank stores downloaded problem images
# under "problem/<group_id>/" and answer images under "answer/<group_id>/"
# below this directory, and persists those relative paths in the database.
path = DATA_PATH / "word_bank"
class WordBank(Model):
    """Word-bank entry: one problem (trigger) paired with one answer.

    A problem with several answers is stored as several rows that share the
    same ``problem`` value.
    """

    id = fields.IntField(pk=True, generated=True, auto_increment=True)
    """Auto-increment id"""
    user_qq = fields.BigIntField()
    """Id of the user who created the entry"""
    group_id = fields.BigIntField(null=True)
    """Group id"""
    word_scope = fields.IntField(default=0)
    """Scope: 0 = global, 1 = group, 2 = private chat"""
    word_type = fields.IntField(default=0)
    """Type: 0 = exact match, 1 = fuzzy, 2 = regex, 3 = image"""
    status = fields.BooleanField()
    """Entry status"""
    problem = fields.TextField()
    """Problem text; for image problems this holds the image hash"""
    answer = fields.TextField()
    """Answer text, with media segments replaced by placeholders"""
    placeholder = fields.TextField(null=True)
    """Comma-separated values that fill the placeholders in ``answer``"""
    image_path = fields.TextField(null=True)
    """Storage path of the image when the problem is an image"""
    to_me = fields.CharField(255, null=True)
    """Nickname stored when the problem must start with the bot's nickname"""
    create_time = fields.DatetimeField(auto_now_add=True)
    """Creation time (set once, on insert)"""
    update_time = fields.DatetimeField(auto_now=True)
    """Last update time (refreshed on every save)"""

    class Meta:
        table = "word_bank2"
        table_description = "词条数据库"

    @classmethod
    async def exists(
        cls,
        user_id: Optional[int],
        group_id: Optional[int],
        problem: str,
        answer: Optional[str],
        word_scope: Optional[int] = None,
        word_type: Optional[int] = None,
    ) -> bool:
        """
        Description:
            Check whether an entry with the given problem exists.
            A falsy user_id / group_id / answer, or a None word_scope /
            word_type, means "do not filter on this column".
        Params:
            :param user_id: user id
            :param group_id: group id
            :param problem: problem text (always filtered on)
            :param answer: answer text
            :param word_scope: entry scope
            :param word_type: entry type
        Returns:
            True if at least one matching row exists.
        """
        query = cls.filter(problem=problem)
        if user_id:
            query = query.filter(user_qq=user_id)
        if group_id:
            query = query.filter(group_id=group_id)
        if answer:
            query = query.filter(answer=answer)
        if word_type is not None:
            query = query.filter(word_type=word_type)
        if word_scope is not None:
            query = query.filter(word_scope=word_scope)
        return bool(await query.first())

    @classmethod
    async def add_problem_answer(
        cls,
        user_id: int,
        group_id: Optional[int],
        word_scope: int,
        word_type: int,
        problem: Union[str, Message],
        answer: Union[str, Message],
        to_me_nickname: Optional[str] = None,
    ):
        """
        Description:
            Add a problem/answer pair, unless an identical entry already exists.
            For image problems (word_type == 3) the first image in ``problem``
            is downloaded and its hash is stored as the problem text.
        Params:
            :param user_id: user id
            :param group_id: group id
            :param word_scope: entry scope
            :param word_type: entry type
            :param problem: problem
            :param answer: answer
            :param to_me_nickname: nickname used to address the bot
        """
        image_path = None
        if word_type == 3:
            url = get_message_img(problem)[0]
            # Build the file name once: the download is awaited in between, so
            # calling time.time() twice could store a path that differs from
            # the file actually written.
            file_name = f"{user_id}_{int(time.time())}.jpg"
            _file = path / "problem" / f"{group_id}" / file_name
            _file.parent.mkdir(exist_ok=True, parents=True)
            await AsyncHttpx.download_file(url, _file)
            problem = str(get_img_hash(_file))
            image_path = f"problem/{group_id}/{file_name}"
        answer, _list = await cls._answer2format(answer, user_id, group_id)
        if not await cls.exists(
            user_id, group_id, problem, answer, word_scope, word_type
        ):
            now = datetime.now().replace(microsecond=0)
            await cls.create(
                user_qq=user_id,
                group_id=group_id,
                word_scope=word_scope,
                word_type=word_type,
                status=True,
                problem=problem,
                answer=answer,
                image_path=image_path,
                placeholder=",".join(_list),
                create_time=now,
                update_time=now,
                to_me=to_me_nickname,
            )

    @classmethod
    async def _answer2format(
        cls, answer: Union[str, Message], user_id: int, group_id: Optional[int]
    ) -> Tuple[str, List[Any]]:
        """
        Description:
            Convert message segments into placeholder text.
            face/at segments become "[face|at:placeholder_N]" and their id/qq is
            collected. Every other non-text segment is treated as an image:
            it is downloaded from its "url" and its relative path is collected.
        Params:
            :param answer: answer content
            :param user_id: user id
            :param group_id: group id
        Returns:
            (text with placeholders, list of placeholder values in order)
        """
        if isinstance(answer, str):
            return answer, []
        _list = []
        text = ""
        index = 0
        for seg in answer:
            if isinstance(seg, str):
                text += seg
            elif seg.type == "text":
                text += seg.data["text"]
            elif seg.type == "face":
                text += f"[face:placeholder_{index}]"
                _list.append(seg.data["id"])
            elif seg.type == "at":
                text += f"[at:placeholder_{index}]"
                _list.append(seg.data["qq"])
            else:
                text += f"[image:placeholder_{index}]"
                # Include the image index in the file name. Several images in
                # one answer are usually saved within the same second and would
                # otherwise overwrite each other on disk.
                file_name = f"{user_id}_{int(time.time())}_{index}.jpg"
                index += 1
                _file = path / "answer" / f"{group_id}" / file_name
                _file.parent.mkdir(exist_ok=True, parents=True)
                await AsyncHttpx.download_file(seg.data["url"], _file)
                _list.append(f"answer/{group_id}/{file_name}")
        return text, _list

    @classmethod
    async def _format2answer(
        cls,
        problem: str,
        answer: Union[str, Message],
        user_id: int,
        group_id: int,
        query: Optional["WordBank"] = None,
    ) -> Union[str, Message]:
        """
        Description:
            Convert placeholder text back into a message with image/face/at
            segments. If ``query`` is given, its answer/placeholder are used
            directly; otherwise the row is looked up by
            problem/user_id/group_id/answer.
        Params:
            :param problem: problem content
            :param answer: answer content
            :param user_id: user id
            :param group_id: group id
            :param query: already-loaded entry to format (skips the lookup)
        Returns:
            A Message when placeholders were substituted, else the answer as-is.
        """
        if query:
            answer = query.answer
        else:
            query = await cls.get_or_none(
                problem=problem,
                user_qq=user_id,
                group_id=group_id,
                answer=answer,
            )
        if query and query.placeholder:
            # The type name may not contain brackets, so a stray "[" earlier
            # in the text cannot swallow a following placeholder.
            pattern = r"\[([^\[\]]*?):placeholder_.*?]"
            type_list = re.findall(pattern, answer)
            # Escape literal braces in user text so MessageTemplate does not
            # parse them as format fields (a lone "{" would raise).
            escaped = answer.replace("{", "{{").replace("}", "}}")
            temp_answer = re.sub(pattern, "{}", escaped)
            seg_list = []
            for t, p in zip(type_list, query.placeholder.split(",")):
                if t == "image":
                    seg_list.append(image(path / p))
                elif t == "face":
                    seg_list.append(face(p))
                elif t == "at":
                    seg_list.append(at(p))
            return MessageTemplate(temp_answer, Message).format(*seg_list)
        return answer

    @classmethod
    async def check_problem(
        cls,
        event: MessageEvent,
        problem: str,
        word_scope: Optional[int] = None,
        word_type: Optional[int] = None,
    ) -> Optional[Any]:
        """
        Description:
            Find the entries matching a message. Exact/image matches are tried
            first, then fuzzy (stored problem is a substring of the message),
            then regex (message matches the stored pattern).
        Params:
            :param event: event
            :param problem: problem content (the incoming message text)
            :param word_scope: entry scope
            :param word_type: entry type
        Returns:
            The list of matching entries from the first stage that matches,
            or None.
        """
        query = cls
        if isinstance(event, GroupMessageEvent):
            if word_scope:
                query = query.filter(word_scope=word_scope)
            else:
                query = query.filter(Q(group_id=event.group_id) | Q(word_scope=0))
        else:
            # Private chat: private-scope and global entries.
            query = query.filter(Q(word_scope=2) | Q(word_scope=0))
        if word_type:
            query = query.filter(word_type=word_type)
        # Exact match (also covers image problems, whose text is an image hash)
        if data_list := await query.filter(
            Q(Q(word_type=0) | Q(word_type=3)), Q(problem=problem)
        ).all():
            return data_list
        # The raw-SQL suffixes below use PostgreSQL syntax ($1, POSITION, ~).
        db = Tortoise.get_connection("default")
        # Fuzzy: the stored problem occurs inside the message
        sql = query.filter(word_type=1).sql() + " and POSITION(problem in $1) > 0"
        data_list = await db.execute_query_dict(sql, [problem])
        if data_list:
            return [cls(**data) for data in data_list]
        # Regex: the message matches the stored pattern
        sql = (
            query.filter(word_type=2, word_scope__not=999).sql() + " and $1 ~ problem;"
        )
        data_list = await db.execute_query_dict(sql, [problem])
        if data_list:
            return [cls(**data) for data in data_list]
        return None

    @classmethod
    async def get_answer(
        cls,
        event: MessageEvent,
        problem: str,
        word_scope: Optional[int] = None,
        word_type: Optional[int] = None,
    ) -> Optional[Union[str, Message]]:
        """
        Description:
            Pick a random answer for the message.
        Params:
            :param event: event
            :param problem: problem content
            :param word_scope: entry scope
            :param word_type: entry type
        Returns:
            The formatted answer, or None when nothing matches.
        """
        data_list = await cls.check_problem(event, problem, word_scope, word_type)
        if not data_list:
            return None
        answer = random.choice(data_list)
        if answer.placeholder:
            # Pass the matched row itself. For fuzzy/regex matches ``problem``
            # is the incoming message, not the stored problem, so a lookup by
            # problem text would not find the row and placeholders would leak.
            return await cls._format2answer(
                problem, answer.answer, answer.user_qq, answer.group_id, answer
            )
        return answer.answer

    @classmethod
    async def get_problem_all_answer(
        cls,
        problem: str,
        index: Optional[int] = None,
        group_id: Optional[int] = None,
        word_scope: Optional[int] = 0,
    ) -> List[Union[str, Message]]:
        """
        Description:
            Get every answer of a problem.
        Params:
            :param problem: problem (ignored when ``index`` is given)
            :param index: position of the problem in the de-duplicated problem
                list (same order as ``_handle_problem`` produces)
            :param group_id: group id
            :param word_scope: entry scope used to resolve ``index`` when no
                group id is given
        Returns:
            The formatted answers.
        Raises:
            IndexError: if ``index`` is out of range.
        """
        if index is not None:
            if group_id:
                rows = await cls.filter(group_id=group_id).all()
            else:
                rows = await cls.filter(word_scope=(word_scope or 0)).all()
            # A problem with several answers occupies several rows, so index
            # into the unique problems (first-occurrence order), not raw rows.
            problems = list(dict.fromkeys(row.problem for row in rows))
            problem = problems[index]
        query = cls.filter(problem=problem)
        if group_id:
            query = query.filter(group_id=group_id)
        return [await cls._format2answer("", "", 0, 0, x) for x in (await query.all())]

    @classmethod
    async def delete_group_problem(
        cls,
        problem: str,
        group_id: int,
        index: Optional[int] = None,
        word_scope: int = 1,
    ) -> bool:
        """
        Description:
            Delete all answers of a problem, or only the answer at ``index``.
        Params:
            :param problem: problem text
            :param group_id: group id
            :param index: answer index
            :param word_scope: entry scope (used when no group id is given)
        Returns:
            True if the problem existed (and was deleted), else False.
        """
        if not await cls.exists(None, group_id, problem, None, word_scope):
            return False
        if group_id:
            query = cls.filter(group_id=group_id, problem=problem)
        else:
            # Use the same scope as the existence check above; a hard-coded
            # scope here could delete an entry from a different scope.
            query = cls.filter(word_scope=word_scope, problem=problem)
        if index is not None:
            await (await query.all())[index].delete()
        else:
            await query.delete()
        return True

    @classmethod
    async def update_group_problem(
        cls,
        problem: str,
        replace_str: str,
        group_id: int,
        index: Optional[int] = None,
        word_scope: int = 1,
    ):
        """
        Description:
            Rename a problem, for all its rows or only the row at ``index``.
        Params:
            :param problem: problem
            :param replace_str: new problem text
            :param group_id: group id
            :param index: row index
            :param word_scope: entry scope (used when no group id is given)
        """
        if group_id:
            query = cls.filter(group_id=group_id, problem=problem)
        else:
            query = cls.filter(word_scope=word_scope, problem=problem)
        if index is not None:
            entry = (await query.all())[index]
            entry.problem = replace_str
            await entry.save(update_fields=["problem"])
        else:
            await query.update(problem=replace_str)

    @classmethod
    async def get_group_all_problem(
        cls, group_id: int
    ) -> List[Tuple[Any, Union[MessageSegment, str]]]:
        """
        Description:
            Get every problem of a group.
        Params:
            :param group_id: group id
        """
        return cls._handle_problem(
            await cls.filter(group_id=group_id).all()  # type: ignore
        )

    @classmethod
    async def get_problem_by_scope(cls, word_scope: int):
        """
        Description:
            Get problems by entry scope.
        Params:
            :param word_scope: entry scope
        """
        return cls._handle_problem(
            await cls.filter(word_scope=word_scope).all()  # type: ignore
        )

    @classmethod
    async def get_problem_by_type(cls, word_type: int):
        """
        Description:
            Get problems by entry type.
        Params:
            :param word_type: entry type
        """
        return cls._handle_problem(
            await cls.filter(word_type=word_type).all()  # type: ignore
        )

    @classmethod
    def _handle_problem(cls, msg_list: List["WordBank"]):
        """
        Description:
            De-duplicate entries by problem (keeping first-occurrence order)
            and build a display value for each: the stored image for image
            problems, otherwise "[<type>] <problem>".
        Params:
            :param msg_list: entry list
        Returns:
            List of (problem, display value) tuples.
        """
        seen = set()
        problem_list = []
        for q in msg_list:
            if q.problem in seen:
                continue
            seen.add(q.problem)
            display = (
                image(path / q.image_path)
                if q.image_path
                else f"[{int2type[q.word_type]}] " + q.problem
            )
            problem_list.append((q.problem, display))
        return problem_list

    @classmethod
    async def _move(
        cls,
        user_id: int,
        group_id: Optional[int],
        problem: Union[str, Message],
        answer: Union[str, Message],
        placeholder: str,
    ):
        """
        Description:
            Migrate an old entry into this table as a global exact-match entry,
            unless an identical one already exists.
        Params:
            :param user_id: user id
            :param group_id: group id
            :param problem: problem
            :param answer: answer
            :param placeholder: placeholder values
        """
        word_scope = 0
        word_type = 0
        if not await cls.exists(
            user_id, group_id, problem, answer, word_scope, word_type
        ):
            now = datetime.now().replace(microsecond=0)
            await cls.create(
                user_qq=user_id,
                group_id=group_id,
                word_scope=word_scope,
                word_type=word_type,
                status=True,
                problem=problem,
                answer=answer,
                image_path=None,
                placeholder=placeholder,
                create_time=now,
                update_time=now,
            )

    @classmethod
    async def _run_script(cls):
        """Add the ``to_me`` column to the table."""
        await cls.raw("ALTER TABLE word_bank2 ADD to_me varchar(255);")