zhenxun_bot/zhenxun/plugins/word_bank/_model.py
2024-07-27 04:30:03 +08:00

567 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import re
import time
import uuid
from datetime import datetime
from typing import Any
from nonebot_plugin_alconna import At as alcAt
from nonebot_plugin_alconna import Image as alcImage
from nonebot_plugin_alconna import Text as alcText
from nonebot_plugin_saa import Image, Mention, MessageFactory, Text
from tortoise import Tortoise, fields
from tortoise.expressions import Q
from typing_extensions import Self
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.db_context import Model
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.image_utils import get_img_hash
from ._config import int2type
# Base directory (under DATA_PATH) that the word-bank code in this file joins
# its "problem/..." and "answer/..." image paths onto.
path = DATA_PATH / "word_bank"
class WordBank(Model):
    """ORM model for one word-bank entry: a trigger ("problem") and one reply ("answer")."""

    id = fields.IntField(pk=True, generated=True, auto_increment=True)
    """Auto-increment id"""
    user_id = fields.CharField(255)
    """User id"""
    group_id = fields.CharField(255, null=True)
    """Group chat id"""
    word_scope = fields.IntField(default=0)
    """Scope in which the entry applies. 0: global, 1: group chat, 2: private chat"""
    word_type = fields.IntField(default=0)
    """Entry type. 0: exact match, 1: fuzzy, 2: regex, 3: image"""
    status = fields.BooleanField()
    """Entry status"""
    problem = fields.TextField()
    """Trigger text; holds the image hash when the trigger is an image"""
    answer = fields.TextField()
    """Reply"""
    placeholder = fields.TextField(null=True)
    """Placeholders"""
    image_path = fields.TextField(null=True)
    """Where the image is stored when an image is used as the trigger"""
    to_me = fields.CharField(255, null=True)
    """Nickname stored when the message starts with the bot's nickname"""
    # auto_now_add stamps the row once, on first save; auto_now re-stamps it on
    # every save. The two flags were previously swapped between these fields,
    # which let the creation time drift and froze the update time.
    create_time = fields.DatetimeField(auto_now_add=True)
    """Creation time"""
    update_time = fields.DatetimeField(auto_now=True)
    """Last update time"""
    platform = fields.CharField(255, default="qq")
    """Platform"""
    author = fields.CharField(255, null=True, default="")
    """User who recorded the entry"""

    class Meta:
        table = "word_bank2"
        table_description = "词条数据库"
@classmethod
async def exists(
    cls,
    user_id: str | None,
    group_id: str | None,
    problem: str,
    answer: str | None,
    word_scope: int | None = None,
    word_type: int | None = None,
) -> bool:
    """Check whether an entry with this problem is already stored.

    Args:
        user_id: user id; not filtered on when falsy
        group_id: group id; not filtered on when falsy
        problem: problem text, always part of the filter
        answer: answer text; not filtered on when falsy
        word_scope: entry scope; not filtered on when None
        word_type: entry type; not filtered on when None

    Returns:
        bool: True when at least one row satisfies every active filter
    """
    criteria: dict[str, Any] = {"problem": problem}
    # Ids and answer are skipped on any falsy value (None or "").
    truthy_only = {"user_id": user_id, "group_id": group_id, "answer": answer}
    criteria.update({name: value for name, value in truthy_only.items() if value})
    # 0 is a meaningful type/scope, so only None switches these filters off.
    none_only = {"word_type": word_type, "word_scope": word_scope}
    criteria.update(
        {name: value for name, value in none_only.items() if value is not None}
    )
    return await cls.filter(**criteria).exists()
@classmethod
async def add_problem_answer(
    cls,
    user_id: str,
    group_id: str | None,
    word_scope: int,
    word_type: int,
    problem: str,
    answer: list[str | alcText | alcAt | alcImage],
    to_me_nickname: str | None = None,
    platform: str = "",
    author: str = "",
):
    """Store a new problem/answer pair unless an identical one already exists.

    Args:
        user_id: user id
        group_id: group id
        word_scope: entry scope
        word_type: entry type; 3 means the problem is an image
        problem: problem text, or the image URL when word_type is 3
        answer: answer segments
        to_me_nickname: bot nickname stored in the to_me column
        platform: platform the entry belongs to
        author: id of the user who recorded the entry
    """
    image_path = None
    if word_type == 3:
        # Build the relative path exactly once. It used to be formatted a
        # second time after the download, so a clock tick in between made the
        # stored image_path point at a file that was never written.
        image_path = f"problem/{group_id}/{user_id}_{int(time.time())}.jpg"
        _file = path / image_path
        _file.parent.mkdir(exist_ok=True, parents=True)
        await AsyncHttpx.download_file(problem, _file)
        # Image entries are keyed by the hash of the downloaded file.
        problem = get_img_hash(_file)
    # Normalise before the duplicate check so the lookup compares against the
    # same value that is written to the database.
    problem = str(problem).strip()
    new_answer, placeholder_list = await cls._answer2format(
        answer, user_id, group_id
    )
    if await cls.exists(
        user_id, group_id, problem, new_answer, word_scope, word_type
    ):
        return
    now = datetime.now().replace(microsecond=0)
    await cls.create(
        user_id=user_id,
        group_id=group_id,
        word_scope=word_scope,
        word_type=word_type,
        status=True,
        problem=problem,
        answer=new_answer,
        image_path=image_path,
        # Placeholder values are not guaranteed to be strings.
        placeholder=",".join(str(item) for item in placeholder_list),
        create_time=now,
        update_time=now,
        to_me=to_me_nickname,
        platform=platform,
        author=author,
    )
@classmethod
async def _answer2format(
    cls,
    answer: list[str | alcText | alcAt | alcImage],
    user_id: str,
    group_id: str | None,
) -> tuple[str, list[Any]]:
    """Flatten answer segments into text with placeholder tokens.

    Plain strings and text segments are copied through. Each face, at or image
    segment becomes a "[<type>:placeholder_<uuid>]" token in the text and
    contributes one value, in the same order, to the placeholder list. Images
    are downloaded under path / "answer".

    Args:
        answer: answer segments
        user_id: user id
        group_id: group id; user_id is used for the image folder when falsy

    Returns:
        tuple[str, list[Any]]: the tokenised text and the placeholder values
    """
    placeholder_list: list[Any] = []
    parts: list[str] = []
    for seg in answer:
        if isinstance(seg, str):
            parts.append(seg)
            continue
        if isinstance(seg, alcText):
            parts.append(seg.text)
            continue
        # Only non-text segments need a unique token.
        placeholder = uuid.uuid1()
        if seg.type == "face":  # TODO: face seems to be unused...
            parts.append(f"[face:placeholder_{placeholder}]")
            placeholder_list.append(seg.data["id"])
        elif isinstance(seg, alcAt):
            parts.append(f"[at:placeholder_{placeholder}]")
            placeholder_list.append(seg.target)
        elif isinstance(seg, alcImage) and seg.url:
            parts.append(f"[image:placeholder_{placeholder}]")
            relative = f"answer/{group_id or user_id}/{user_id}_{placeholder}.jpg"
            _file = path / relative
            _file.parent.mkdir(exist_ok=True, parents=True)
            await AsyncHttpx.download_file(seg.url, _file)
            # Store the path relative to `path`.
            placeholder_list.append(relative)
        # NOTE(review): any other segment, including an image without a url,
        # is silently dropped - confirm this is intended.
    return "".join(parts), placeholder_list
@classmethod
async def _format2answer(
    cls,
    problem: str,
    answer: str,
    user_id: int,
    group_id: int,
    query: Self | None = None,
) -> MessageFactory | Text:
    """Turn a stored answer with placeholder tokens back into a sendable message.

    Args:
        problem: problem text, used only to look the row up when query is absent
        answer: answer text; taken from the row when empty
        user_id: user id, used only for the lookup
        group_id: group id, used only for the lookup
        query: the already-loaded row, if the caller has one

    Returns:
        MessageFactory | Text: a MessageFactory when the row has placeholders,
        otherwise a plain Text of the answer
    """
    if not query:
        query = await cls.get_or_none(
            problem=problem,
            user_id=user_id,
            group_id=group_id,
            answer=answer,
        )
    if not answer and query:
        answer = str(query.answer)
    if not (query and query.placeholder):
        return Text(answer)
    # One pattern both locates each token and captures its type. Brackets are
    # excluded from the inner parts so a match can never span two tokens; the
    # previous greedy split pattern swallowed everything between the first and
    # last token and dropped the text in between.
    token = re.compile(r"\[([^\[\]:]*):placeholder_[^\[\]]*\]")
    placeholder_split = query.placeholder.split(",")
    result_list = []
    cursor = 0
    for index, found in enumerate(token.finditer(answer)):
        result_list.append(Text(answer[cursor : found.start()]))
        cursor = found.end()
        if index >= len(placeholder_split):
            # More tokens than stored values: keep the text, skip the token.
            continue
        kind = found.group(1)
        value = placeholder_split[index]
        if kind == "image":
            result_list.append(Image(path / value))
        elif kind == "at":
            result_list.append(Mention(value))
        # Other kinds (e.g. "face") consume their value but emit nothing.
    result_list.append(Text(answer[cursor:]))
    return MessageFactory(result_list)
@classmethod
async def check_problem(
    cls,
    group_id: str | None,
    problem: str,
    word_scope: int | None = None,
    word_type: int | None = None,
) -> Any:
    """Find every entry that the given message triggers.

    Matching is tried in three stages and stops at the first one that finds
    rows: exact/image entries, then fuzzy entries, then regex entries.

    Args:
        group_id: group id; without one only scope 2 and scope 0 are searched
        problem: incoming message text (or image hash)
        word_scope: entry scope; only consulted when group_id is given
        word_type: entry type; not filtered on when None

    Returns:
        Any: a list of matching rows, or None when nothing matches
    """
    query = cls
    if group_id:
        if word_scope:
            # NOTE(review): this branch does not restrict by group_id, so it
            # can match other groups' entries - confirm this is intended.
            query = query.filter(word_scope=word_scope)
        else:
            query = query.filter(Q(group_id=group_id) | Q(word_scope=0))
    else:
        query = query.filter(Q(word_scope=2) | Q(word_scope=0))
    # Filter on the word_type column (it used to be compared against
    # word_scope), and use `is not None` as exists() does so type 0 is honoured.
    if word_type is not None:
        query = query.filter(word_type=word_type)
    # Stage 1: exact match; image entries are compared by their stored hash.
    if data_list := await query.filter(
        Q(Q(word_type=0) | Q(word_type=3)), Q(problem=problem)
    ).all():
        return data_list
    db = Tortoise.get_connection("default")
    # Stage 2: fuzzy - the stored problem must occur inside the message.
    # NOTE(review): the raw suffix assumes .sql() emits no bind parameters of
    # its own, so that $1 is the message - verify against the Tortoise version.
    sql = query.filter(word_type=1).sql() + " and POSITION(problem in $1) > 0"
    data_list = await db.execute_query_dict(sql, [problem])
    if data_list:
        return [cls(**data) for data in data_list]
    # Stage 3: regex - the stored problem is used as the pattern, matched by
    # the database (`~` is presumably the PostgreSQL regex operator).
    # NOTE(review): the purpose of word_scope__not=999 is not evident here.
    sql = (
        query.filter(word_type=2, word_scope__not=999).sql() + " and $1 ~ problem;"
    )
    data_list = await db.execute_query_dict(sql, [problem])
    if data_list:
        return [cls(**data) for data in data_list]
    return None
@classmethod
async def get_answer(
    cls,
    group_id: str | None,
    problem: str,
    word_scope: int | None = None,
    word_type: int | None = None,
) -> Text | MessageFactory | None:
    """Pick one random answer among the entries triggered by the message.

    Args:
        group_id: group id
        problem: incoming message text
        word_scope: entry scope
        word_type: entry type

    Returns:
        Text | MessageFactory | None: the answer, or None when nothing matched
    """
    candidates = await cls.check_problem(group_id, problem, word_scope, word_type)
    if not candidates:
        return None
    chosen = random.choice(candidates)
    if chosen.word_type == 2:
        # Regex entry: "$N" in the answer refers to capture group N.
        matched = re.search(chosen.problem, problem)
        uses_group_refs = re.search(rf"\$(\d)", chosen.answer)
        if matched and matched.groups() and uses_group_refs:
            # Rewrite "$N" into the "\N" form that re.sub understands.
            template = re.sub(r"\$(\d)", r"\\\1", chosen.answer)
            chosen.answer = re.sub(chosen.problem, template, problem)
    if not chosen.placeholder:
        return Text(chosen.answer)
    return await cls._format2answer(
        chosen.problem,
        chosen.answer,
        chosen.user_id,
        chosen.group_id,
        chosen,
    )
@classmethod
async def get_problem_all_answer(
    cls,
    problem: str,
    index: int | None = None,
    group_id: str | None = None,
    word_scope: int | None = 0,
) -> tuple[str, list[Text | MessageFactory]]:
    """Return every answer stored for one problem.

    Args:
        problem: problem text; replaced by the indexed problem when index is given
        index: position of the problem among the distinct problems ordered by
            create_time
        group_id: group id
        word_scope: entry scope; None is treated as 0

    Returns:
        tuple[str, list[Text | MessageFactory]]: the problem and its answers, or
        an error message with an empty list
    """
    if index is not None:
        # TODO: group_by and order_by cannot be used together
        if group_id:
            listing = cls.filter(group_id=group_id)
        else:
            listing = cls.filter(word_scope=(word_scope or 0))
        _problem = await listing.order_by("create_time").values_list(
            "problem", flat=True
        )
        # Order-preserving de-duplication.
        sort_problem = list(dict.fromkeys(_problem))
        # Reject negative indexes too; they would otherwise wrap around and
        # silently select an entry from the end of the list.
        if index < 0 or index > len(sort_problem) - 1:
            return "下标错误,必须小于问题数量...", []
        problem = sort_problem[index]  # type: ignore
    f = cls.filter(problem=problem, word_scope=(word_scope or 0))
    if group_id:
        f = f.filter(group_id=group_id)
    answer_list = await f.all()
    if not answer_list:
        return "词条不存在...", []
    # The row itself is handed over, so the lookup arguments are dummies.
    return problem, [await cls._format2answer("", "", 0, 0, a) for a in answer_list]
@classmethod
async def delete_group_problem(
    cls,
    problem: str,
    group_id: str | None,
    index: int | None = None,
    word_scope: int = 1,
):
    """Delete every answer of a problem, or only the answer at one index.

    Args:
        problem: problem text
        group_id: group id; not filtered on when falsy
        index: position of the answer to delete; None deletes all of them
        word_scope: entry scope

    Returns:
        bool: True when something was deleted, False when the problem does not
        exist or index is out of range
    """
    if not await cls.exists(None, group_id, problem, None, word_scope):
        return False
    conditions: dict[str, Any] = {"problem": problem, "word_scope": word_scope}
    if group_id:
        conditions["group_id"] = group_id
    matched = cls.filter(**conditions)
    if index is None:
        await matched.delete()
        return True
    rows = await matched.all()
    try:
        target = rows[index]
    except IndexError:
        # An out-of-range index used to escape as an unhandled IndexError.
        return False
    await target.delete()
    return True
@classmethod
async def update_group_problem(
    cls,
    problem: str,
    replace_str: str,
    group_id: str | None,
    index: int | None = None,
    word_scope: int = 1,
) -> str:
    """Rename the problem text of stored entries.

    Args:
        problem: current problem text
        replace_str: new problem text
        group_id: group id
        index: position of the single row to rename; None renames every match
        word_scope: entry scope, used only when group_id is falsy

    Returns:
        str: the problem text before the change
    """
    # A group id takes precedence; word_scope is only consulted without one.
    selector = {"group_id": group_id} if group_id else {"word_scope": word_scope}
    matching = cls.filter(problem=problem, **selector)
    if index is None:
        await matching.update(problem=replace_str)
        return problem
    entry = (await matching.all())[index]
    previous = entry.problem
    entry.problem = replace_str
    await entry.save(update_fields=["problem"])
    return previous
@classmethod
async def get_group_all_problem(cls, group_id: str) -> list[tuple[Any | str]]:
    """List the entries of one group, ordered by create_time.

    Args:
        group_id: group id

    Returns:
        the rows after formatting by _handle_problem
    """
    entries = await cls.filter(group_id=group_id).order_by("create_time").all()
    return cls._handle_problem(entries)  # type: ignore
@classmethod
async def get_problem_by_scope(cls, word_scope: int):
    """List the entries of one scope, ordered by create_time.

    Args:
        word_scope: entry scope

    Returns:
        the rows after formatting by _handle_problem
    """
    entries = await cls.filter(word_scope=word_scope).order_by("create_time").all()
    return cls._handle_problem(entries)  # type: ignore
@classmethod
async def get_problem_by_type(cls, word_type: int):
    """List the entries of one type, ordered by create_time.

    Args:
        word_type: entry type

    Returns:
        the rows after formatting by _handle_problem
    """
    entries = await cls.filter(word_type=word_type).order_by("create_time").all()
    return cls._handle_problem(entries)  # type: ignore
@classmethod
def _handle_problem(cls, problem_list: list["WordBank"]):
    """Format rows for display, keeping the first row of each distinct problem.

    Args:
        problem_list: rows to format, in the order they should be shown

    Returns:
        list of (problem, type name, "-") tuples; for rows with an image_path
        the problem element is the tuple (image path, 30, 30)
    """
    seen: set[str] = set()
    result_list = []
    for q in problem_list:
        if q.problem in seen:
            continue
        seen.add(q.problem)
        # TODO: look up the name of the user who recorded the entry
        shown = (path / q.image_path, 30, 30) if q.image_path else q.problem
        # The trailing "-" stands in for the author column.
        result_list.append((shown, int2type[q.word_type], "-"))
    return result_list
@classmethod
async def _move(
    cls,
    user_id: str,
    group_id: str | None,
    problem: str,
    answer: str,
    placeholder: str,
):
    """Migration helper for old entries: re-create one as a scope 0, type 0 row.

    Nothing is written when exists() already finds a matching row.

    Args:
        user_id: user id
        group_id: group id
        problem: problem text
        answer: answer text
        placeholder: placeholder string, stored as given
    """
    word_scope, word_type = 0, 0
    already_stored = await cls.exists(
        user_id, group_id, problem, answer, word_scope, word_type
    )
    if already_stored:
        return
    await cls.create(
        user_id=user_id,
        group_id=group_id,
        word_scope=word_scope,
        word_type=word_type,
        status=True,
        problem=problem,
        answer=answer,
        image_path=None,
        placeholder=placeholder,
        create_time=datetime.now().replace(microsecond=0),
        update_time=datetime.now().replace(microsecond=0),
    )
@classmethod
async def _run_script(cls):
    """Return the SQL statements that alter the word_bank2 table, in order."""
    migrations = (
        # add the to_me column
        "ALTER TABLE word_bank2 ADD to_me varchar(255);",
        # make both timestamp columns timezone-aware
        "ALTER TABLE word_bank2 ALTER COLUMN create_time TYPE timestamp with time zone USING create_time::timestamp with time zone;",
        "ALTER TABLE word_bank2 ALTER COLUMN update_time TYPE timestamp with time zone USING update_time::timestamp with time zone;",
        # rename user_qq to user_id
        "ALTER TABLE word_bank2 RENAME COLUMN user_qq TO user_id;",
        # store the ids as varchar(255)
        "ALTER TABLE word_bank2 ALTER COLUMN user_id TYPE character varying(255);",
        "ALTER TABLE word_bank2 ALTER COLUMN group_id TYPE character varying(255);",
        # add the platform and author columns with their defaults
        "ALTER TABLE word_bank2 ADD platform varchar(255) DEFAULT 'qq';",
        "ALTER TABLE word_bank2 ADD author varchar(255) DEFAULT '';",
    )
    return list(migrations)