# zhenxun_bot/zhenxun/plugins/word_bank/_model.py
import random
import re
import time
import uuid
from datetime import datetime
from typing import Any
from nonebot_plugin_alconna import At as alcAt
from nonebot_plugin_alconna import Image as alcImage
from nonebot_plugin_alconna import Text as alcText
2024-08-10 02:25:04 +08:00
from nonebot_plugin_alconna import UniMessage
2024-07-27 04:30:03 +08:00
from tortoise import Tortoise, fields
from tortoise.expressions import Q
from typing_extensions import Self
from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.db_context import Model
from zhenxun.utils.http_utils import AsyncHttpx
from zhenxun.utils.image_utils import get_img_hash
2024-08-10 02:25:04 +08:00
from zhenxun.utils.message import MessageUtils
2024-07-27 04:30:03 +08:00
from ._config import int2type
# Root directory of the word-bank assets; images downloaded by WordBank are
# written below its "problem" and "answer" sub-directories.
path = DATA_PATH / "word_bank"
class WordBank(Model):
    """ORM model for one question/answer entry of the word bank.

    Rows live in the ``word_bank2`` table. A question may have several rows,
    one per answer.
    """

    id = fields.IntField(pk=True, generated=True, auto_increment=True)
    """Auto-increment primary key."""
    user_id = fields.CharField(255)
    """User id."""
    group_id = fields.CharField(255, null=True)
    """Group chat id."""
    word_scope = fields.IntField(default=0)
    """Scope in which the entry applies. 0: global, 1: group chat, 2: private chat."""
    word_type = fields.IntField(default=0)
    """Entry type. 0: exact match, 1: fuzzy, 2: regex, 3: image."""
    status = fields.BooleanField()
    """Entry status."""
    problem = fields.TextField()
    """Question; when the question is an image, the image hash is stored instead."""
    answer = fields.TextField()
    """Answer."""
    placeholder = fields.TextField(null=True)
    """Placeholder values."""
    image_path = fields.TextField(null=True)
    """Path where the image is stored when an image is used as the question."""
    to_me = fields.CharField(255, null=True)
    """Nickname stored when the question starts with the bot's nickname."""
    # auto_now_add stamps the row only on its first save, so the creation
    # time stays fixed (it was previously declared with auto_now and was
    # therefore rewritten on every save).
    create_time = fields.DatetimeField(auto_now_add=True)
    """Creation time."""
    # auto_now refreshes the value on every save (it was previously declared
    # with auto_now_add and therefore never changed after creation).
    update_time = fields.DatetimeField(auto_now=True)
    """Last update time."""
    platform = fields.CharField(255, default="qq")
    """Platform."""
    author = fields.CharField(255, null=True, default="")
    """User who recorded the entry."""

    class Meta:
        table = "word_bank2"
        table_description = "词条数据库"
@classmethod
async def exists(
    cls,
    user_id: str | None,
    group_id: str | None,
    problem: str,
    answer: str | None,
    word_scope: int | None = None,
    word_type: int | None = None,
) -> bool:
    """Check whether an entry for the given question already exists.

    ``problem`` is always part of the lookup. ``user_id``, ``group_id`` and
    ``answer`` are only applied when they are truthy, while ``word_type`` and
    ``word_scope`` are applied whenever they are not ``None`` (so ``0`` is a
    valid filter value).

    Args:
        user_id: User id.
        group_id: Group id.
        problem: Question.
        answer: Answer.
        word_scope: Entry scope.
        word_type: Entry type.

    Returns:
        bool: True when at least one matching row exists.
    """
    conditions: dict[str, Any] = {"problem": problem}
    optional_when_truthy = (
        ("user_id", user_id),
        ("group_id", group_id),
        ("answer", answer),
    )
    for column, value in optional_when_truthy:
        if value:
            conditions[column] = value
    optional_when_given = (
        ("word_type", word_type),
        ("word_scope", word_scope),
    )
    for column, value in optional_when_given:
        if value is not None:
            conditions[column] = value
    return await cls.filter(**conditions).exists()
@classmethod
async def add_problem_answer(
    cls,
    user_id: str,
    group_id: str | None,
    word_scope: int,
    word_type: int,
    problem: str,
    answer: list[str | alcText | alcAt | alcImage],
    to_me_nickname: str | None = None,
    platform: str = "",
    author: str = "",
):
    """Add a question/answer pair unless an identical one already exists.

    Args:
        user_id: User id.
        group_id: Group id.
        word_scope: Entry scope.
        word_type: Entry type.
        problem: Question; a URL when the question is an image.
        answer: Answer segments.
        to_me_nickname: Bot nickname the question was addressed to.
        platform: Platform the entry belongs to.
        author: Id of the user who recorded the entry.
    """
    # An image question (word_type 3) is downloaded and then identified by
    # the hash of the downloaded file rather than by its URL.
    image_path = None
    if word_type == 3:
        _uuid = uuid.uuid1()
        # NOTE(review): when group_id is None the folder is literally named
        # "None" (f-string formatting); kept as-is.
        image_path = f"problem/{group_id}/{user_id}_{_uuid}.jpg"
        _file = path / "problem" / f"{group_id}" / f"{user_id}_{_uuid}.jpg"
        _file.parent.mkdir(exist_ok=True, parents=True)
        await AsyncHttpx.download_file(problem, _file)
        problem = get_img_hash(_file)
    # Normalise once so the duplicate check looks for exactly the text that
    # gets stored; checking the raw value while storing the stripped one let
    # duplicates through for questions with surrounding whitespace.
    problem = str(problem).strip()
    new_answer, placeholder_list = await cls._answer2format(
        answer, user_id, group_id
    )
    if await cls.exists(
        user_id, group_id, problem, new_answer, word_scope, word_type
    ):
        return
    await cls.create(
        user_id=user_id,
        group_id=group_id,
        word_scope=word_scope,
        word_type=word_type,
        status=True,
        problem=problem,
        answer=new_answer,
        image_path=image_path,
        placeholder=",".join(placeholder_list),
        create_time=datetime.now().replace(microsecond=0),
        update_time=datetime.now().replace(microsecond=0),
        to_me=to_me_nickname,
        platform=platform,
        author=author,
    )
@classmethod
async def _answer2format(
cls,
answer: list[str | alcText | alcAt | alcImage],
user_id: str,
group_id: str | None,
) -> tuple[str, list[Any]]:
"""将特殊字段转化为占位符图片at等
参数:
answer: 回答内容
user_id: 用户id
group_id: 群号
返回:
tuple[str, list[Any]]: 替换后的文本回答内容占位符
"""
placeholder_list = []
text = ""
index = 0
for seg in answer:
placeholder = uuid.uuid1()
if isinstance(seg, str):
text += seg
elif isinstance(seg, alcText):
text += seg.text
elif seg.type == "face": # TODO: face貌似无用...
text += f"[face:placeholder_{placeholder}]"
placeholder_list.append(seg.data["id"])
elif isinstance(seg, alcAt):
text += f"[at:placeholder_{placeholder}]"
placeholder_list.append(seg.target)
elif isinstance(seg, alcImage) and seg.url:
text += f"[image:placeholder_{placeholder}]"
index += 1
_file = (
path
/ "answer"
/ f"{group_id or user_id}"
/ f"{user_id}_{placeholder}.jpg"
)
_file.parent.mkdir(exist_ok=True, parents=True)
await AsyncHttpx.download_file(seg.url, _file)
placeholder_list.append(
f"answer/{group_id or user_id}/{user_id}_{placeholder}.jpg"
)
return text, placeholder_list
@classmethod
async def _format2answer(
    cls,
    problem: str,
    answer: str,
    user_id: int,
    group_id: int,
    query: Self | None = None,
) -> UniMessage:
    """Turn a stored answer with placeholder tokens back into a message.

    Args:
        problem: Question text (only used to look the row up when ``query``
            is not supplied).
        answer: Answer text; when empty, the answer of the row is used.
        user_id: User id (lookup only).
        group_id: Group id (lookup only).
        query: The row to render; looked up from the other arguments when
            omitted.

    Returns:
        UniMessage: The answer with ``image`` and ``at`` placeholders
        replaced by the stored values. Tokens of any other type are removed
        from the text without a replacement.
    """
    if not query:
        query = await cls.get_or_none(
            problem=problem,
            user_id=user_id,
            group_id=group_id,
            answer=answer,
        )
    # Guarded: without a row there is nothing to fall back to.
    if not answer and query:
        answer = str(query.answer)
    if not (query and query.placeholder):
        return MessageUtils.build_message(answer)

    placeholder_values = query.placeholder.split(",")
    # One pattern both locates each token and captures its type. The old
    # code split with a greedy ``\[.*:placeholder_`` which, with two or more
    # tokens in one answer, swallowed everything between the first "[" and
    # the last token and so lost text and misaligned the segments.
    token = re.compile(r"\[([^\[\]]*?):placeholder_[^\]]*\]")
    result_list = []
    cursor = 0
    for index, found in enumerate(token.finditer(answer)):
        # Text preceding this token (may be an empty string).
        result_list.append(answer[cursor : found.start()])
        cursor = found.end()
        if index >= len(placeholder_values):
            # More tokens than stored values: nothing to substitute.
            continue
        kind = found.group(1)
        value = placeholder_values[index]
        if kind == "image":
            result_list.append(path / value)
        elif kind == "at":
            result_list.append(alcAt(flag="user", target=value))
    # Text after the last token.
    result_list.append(answer[cursor:])
    return MessageUtils.build_message(result_list)
2024-07-27 04:30:03 +08:00
@classmethod
async def check_problem(
    cls,
    group_id: str | None,
    problem: str,
    word_scope: int | None = None,
    word_type: int | None = None,
) -> Any:
    """Find every entry whose question matches ``problem``.

    Matching is attempted in three stages and the first stage that yields
    rows wins: exact match (entry types 0 and 3), then fuzzy (type 1, the
    stored question is contained in ``problem``), then regex (type 2, the
    stored question is used as a pattern against ``problem``).

    Args:
        group_id: Group id.
        problem: Question text.
        word_scope: Entry scope.
        word_type: Entry type.

    Returns:
        A list of matching entries, or None when nothing matches.
    """
    query = cls
    if group_id:
        if word_scope:
            # NOTE(review): this branch does not restrict by group_id, so an
            # explicit scope matches entries of every group - confirm intent.
            query = query.filter(word_scope=word_scope)
        else:
            query = query.filter(Q(group_id=group_id) | Q(word_scope=0))
    else:
        query = query.filter(Q(word_scope=2) | Q(word_scope=0))
    if word_type:
        # Fixed: the type used to be compared against the word_scope column.
        query = query.filter(word_type=word_type)
    # Exact match
    if data_list := await query.filter(
        Q(Q(word_type=0) | Q(word_type=3)), Q(problem=problem)
    ).all():
        return data_list
    db = Tortoise.get_connection("default")
    # Fuzzy match: the extra condition is appended to the generated SQL and
    # the incoming text is passed as a bound parameter ($1).
    sql = query.filter(word_type=1).sql() + " and POSITION(problem in $1) > 0"
    data_list = await db.execute_query_dict(sql, [problem])
    if data_list:
        return [cls(**data) for data in data_list]
    # Regex match
    sql = (
        query.filter(word_type=2, word_scope__not=999).sql() + " and $1 ~ problem;"
    )
    data_list = await db.execute_query_dict(sql, [problem])
    if data_list:
        return [cls(**data) for data in data_list]
    return None
@classmethod
async def get_answer(
    cls,
    group_id: str | None,
    problem: str,
    word_scope: int | None = None,
    word_type: int | None = None,
) -> UniMessage | None:
    """Pick a random answer for the given question.

    Args:
        group_id: Group id.
        problem: Question text.
        word_scope: Entry scope.
        word_type: Entry type.

    Returns:
        UniMessage | None: The rendered answer, or None when no entry
        matches.
    """
    candidates = await cls.check_problem(group_id, problem, word_scope, word_type)
    if not candidates:
        return None
    chosen = random.choice(candidates)
    if chosen.word_type == 2:
        # Regex entry: "$N" in the answer refers to capture group N of the
        # stored pattern, so rewrite it to "\N" and substitute.
        matched = re.search(chosen.problem, problem)
        uses_groups = re.search(r"\$(\d)", chosen.answer)
        if matched and matched.groups() and uses_groups:
            template = re.sub(r"\$(\d)", r"\\\1", chosen.answer)
            chosen.answer = re.sub(chosen.problem, template, problem)
    if chosen.placeholder:
        return await cls._format2answer(
            chosen.problem,
            chosen.answer,
            chosen.user_id,
            chosen.group_id,
            chosen,
        )
    return MessageUtils.build_message(chosen.answer)
@classmethod
async def get_problem_all_answer(
    cls,
    problem: str,
    index: int | None = None,
    group_id: str | None = None,
    word_scope: int | None = 0,
) -> tuple[str, list[UniMessage]]:
    """Get every answer of one question.

    When ``index`` is given, the question is picked by position from the
    distinct questions (ordered by creation time) of the group, or of the
    scope when no group is given, and ``problem`` is ignored.

    Args:
        problem: Question.
        index: Position of the question.
        group_id: Group id.
        word_scope: Entry scope.

    Returns:
        tuple[str, list[UniMessage]]: The question and all of its answers;
        on failure, an error text and an empty list.
    """
    if index is not None:
        # TODO: group_by and order_by cannot be used together
        if group_id:
            scoped = cls.filter(group_id=group_id)
        else:
            scoped = cls.filter(word_scope=(word_scope or 0))
        _problem = await scoped.order_by("create_time").values_list(
            "problem", flat=True
        )
        # Order-preserving de-duplication in a single pass.
        sort_problem = list(dict.fromkeys(_problem))
        # A negative index is rejected too; it used to wrap around and
        # silently select a question counted from the end.
        if index < 0 or index > len(sort_problem) - 1:
            return "下标错误,必须小于问题数量...", []
        problem = sort_problem[index]  # type: ignore
    f = cls.filter(problem=problem, word_scope=(word_scope or 0))
    if group_id:
        f = f.filter(group_id=group_id)
    answer_list = await f.all()
    if not answer_list:
        return "词条不存在...", []
    return problem, [await cls._format2answer("", "", 0, 0, a) for a in answer_list]
@classmethod
async def delete_group_problem(
    cls,
    problem: str,
    group_id: str | None,
    index: int | None = None,
    word_scope: int = 1,
):
    """Delete all answers of a question, or a single one by position.

    Args:
        problem: Question text.
        group_id: Group id.
        index: Position of the answer to delete; all answers when None.
        word_scope: Entry scope.

    Returns:
        bool: True when something was deleted; False when the question does
        not exist or ``index`` is out of range.
    """
    if not await cls.exists(None, group_id, problem, None, word_scope):
        return False
    if group_id:
        matched = cls.filter(
            group_id=group_id, problem=problem, word_scope=word_scope
        )
    else:
        matched = cls.filter(word_scope=word_scope, problem=problem)
    if index is None:
        await matched.delete()
        return True
    rows = await matched.all()
    # Reject positions outside the list instead of raising IndexError (or,
    # for a negative value, deleting an answer counted from the end).
    if not 0 <= index < len(rows):
        return False
    await rows[index].delete()
    return True
@classmethod
async def update_group_problem(
    cls,
    problem: str,
    replace_str: str,
    group_id: str | None,
    index: int | None = None,
    word_scope: int = 1,
) -> str:
    """Rename the question of an entry.

    Rows are selected by group when ``group_id`` is given and by scope
    otherwise. With ``index`` only the row at that position is renamed;
    without it every selected row is.

    Args:
        problem: Question.
        replace_str: New question text.
        group_id: Group id.
        index: Position of the row to rename.
        word_scope: Entry scope.

    Returns:
        str: The question text before the change.
    """
    if group_id:
        selected = cls.filter(group_id=group_id, problem=problem)
    else:
        selected = cls.filter(word_scope=word_scope, problem=problem)
    if index is None:
        await selected.update(problem=replace_str)
        return problem
    rows = await selected.all()
    target = rows[index]
    previous = target.problem
    target.problem = replace_str
    await target.save(update_fields=["problem"])
    return previous
@classmethod
async def get_group_all_problem(cls, group_id: str) -> list[tuple[Any | str]]:
    """List every entry of a group chat, ordered by creation time.

    Args:
        group_id: Group id.
    """
    rows = await cls.filter(group_id=group_id).order_by("create_time").all()
    return cls._handle_problem(rows)  # type: ignore
@classmethod
async def get_problem_by_scope(cls, word_scope: int):
    """List the entries of one scope, ordered by creation time.

    Args:
        word_scope: Entry scope.
    """
    rows = await cls.filter(word_scope=word_scope).order_by("create_time").all()
    return cls._handle_problem(rows)  # type: ignore
@classmethod
async def get_problem_by_type(cls, word_type: int):
    """List the entries of one type, ordered by creation time.

    Args:
        word_type: Entry type.
    """
    rows = await cls.filter(word_type=word_type).order_by("create_time").all()
    return cls._handle_problem(rows)  # type: ignore
@classmethod
def _handle_problem(cls, problem_list: list["WordBank"]):
"""格式化处理问题
参数:
msg_list: 消息列表
"""
_tmp = []
result_list = []
for q in problem_list:
if q.problem not in _tmp:
# TODO: 获取收录人名称
problem = (
(path / q.image_path, 30, 30) if q.image_path else q.problem,
int2type[q.word_type],
# q.author,
"-",
)
result_list.append(problem)
_tmp.append(q.problem)
return result_list
@classmethod
async def _move(
    cls,
    user_id: str,
    group_id: str | None,
    problem: str,
    answer: str,
    placeholder: str,
):
    """Migration helper: insert a legacy entry unless it already exists.

    The entry is stored as a global (scope 0), exact-match (type 0) entry
    without a question image.

    Args:
        user_id: User id.
        group_id: Group id.
        problem: Question.
        answer: Answer.
        placeholder: Placeholder values.
    """
    word_scope = 0
    word_type = 0
    already_stored = await cls.exists(
        user_id, group_id, problem, answer, word_scope, word_type
    )
    if already_stored:
        return
    await cls.create(
        user_id=user_id,
        group_id=group_id,
        word_scope=word_scope,
        word_type=word_type,
        status=True,
        problem=problem,
        answer=answer,
        image_path=None,
        placeholder=placeholder,
        create_time=datetime.now().replace(microsecond=0),
        update_time=datetime.now().replace(microsecond=0),
    )
@classmethod
async def _run_script(cls):
return [
"ALTER TABLE word_bank2 ADD to_me varchar(255);", # 添加 to_me 字段
"ALTER TABLE word_bank2 ALTER COLUMN create_time TYPE timestamp with time zone USING create_time::timestamp with time zone;",
"ALTER TABLE word_bank2 ALTER COLUMN update_time TYPE timestamp with time zone USING update_time::timestamp with time zone;",
"ALTER TABLE word_bank2 RENAME COLUMN user_qq TO user_id;", # 将user_qq改为user_id
"ALTER TABLE word_bank2 ALTER COLUMN user_id TYPE character varying(255);",
"ALTER TABLE word_bank2 ALTER COLUMN group_id TYPE character varying(255);",
"ALTER TABLE word_bank2 ADD platform varchar(255) DEFAULT 'qq';",
"ALTER TABLE word_bank2 ADD author varchar(255) DEFAULT '';",
]