zhenxun_bot/plugins/black_word/model.py
2022-05-01 15:53:52 +08:00

150 lines
4.7 KiB
Python

from services.db_context import db
from typing import Optional, List
from datetime import datetime, timedelta
class BlackWord(db.Model):
    """GINO model for the ``black_word`` table.

    Each row records one message in which a user sent a blacklisted
    (sensitive) word, together with the punishment level assigned to it and
    a free-text log of the punishments applied afterwards.
    """

    __tablename__ = "black_word"

    # Auto-incrementing record id.
    id = db.Column(db.Integer(), primary_key=True, autoincrement=True)
    # Id (QQ number) of the user who sent the message.
    user_qq = db.Column(db.BigInteger(), nullable=False, primary_key=True)
    # Group the message was sent in; nullable, callers may pass None.
    group_id = db.Column(db.BigInteger())
    # Text of the offending message.
    plant_text = db.Column(db.String())
    # The blacklisted word that was detected.
    black_word = db.Column(db.String())
    # Space-separated log of punishments applied (see set_user_punish).
    punish = db.Column(db.String(), default="")
    # Punishment level; -1 rows are excluded from get_user_count.
    punish_level = db.Column(db.Integer())
    # Time the record was created.
    create_time = db.Column(db.DateTime(timezone=True), nullable=False)

    @classmethod
    async def add_user_black_word(
        cls,
        user_qq: int,
        group_id: Optional[int],
        black_word: str,
        plant_text: str,
        punish_level: int,
    ):
        """Record a sensitive word sent by a user.

        :param user_qq: user id
        :param group_id: group id, or None
        :param black_word: the blacklisted word that was detected
        :param plant_text: text of the message
        :param punish_level: punishment level
        """
        await cls.create(
            user_qq=user_qq,
            group_id=group_id,
            plant_text=plant_text,
            black_word=black_word,
            punish_level=punish_level,
            create_time=datetime.now(),
        )

    @classmethod
    async def set_user_punish(
        cls,
        user_qq: int,
        punish: str,
        black_word: Optional[str] = None,
        id_: Optional[int] = None,
    ) -> bool:
        """Append a punishment description to one of the user's records.

        The target record is chosen by ``black_word`` (the newest record of
        the user containing that word) or, when ``black_word`` is not given,
        by ``id_`` (a zero-based index into the user's records ordered by
        id). ``black_word`` takes precedence when both are supplied.

        :param user_qq: user id
        :param punish: punishment description to append
        :param black_word: blacklisted word identifying the record
        :param id_: zero-based index of the record among the user's records
        :return: True if a record was updated, False if the arguments were
            insufficient or no matching record exists
        """
        # ``id_`` is an index, so 0 is a valid value: test against None
        # rather than relying on truthiness.
        if not punish or (not black_word and id_ is None):
            return False
        query = cls.query.where(cls.user_qq == user_qq).with_for_update()
        record = None
        if black_word:
            record = (
                await query.where(cls.black_word == black_word)
                .order_by(cls.id.desc())
                .gino.first()
            )
        elif id_ is not None:
            # Order explicitly so that the index maps to a stable row.
            records = await query.order_by(cls.id).gino.all()
            # Valid indices are 0 .. len(records) - 1; this also rejects an
            # empty result set.
            if not 0 <= id_ < len(records):
                return False
            record = records[id_]
        if not record:
            return False
        # ``cls.punish`` is the column expression, so the concatenation is
        # performed by the database on the stored value.
        await record.update(punish=cls.punish + punish + " ").apply()
        return True

    @classmethod
    async def get_user_count(
        cls, user_qq: int, days: int = 7, punish_level: Optional[int] = None
    ) -> int:
        """Count the user's offences within the given period.

        Records whose ``punish_level`` is -1 are not counted.

        :param user_qq: user id
        :param days: length of the period in days, counted back from now
        :param punish_level: if given, only count records of this level
        :return: number of matching records
        """
        # Build the COUNT query directly instead of attaching a temporary
        # ``count`` attribute to the model class on every call.
        query = db.select([db.func.count(cls.id)]).where(
            (cls.user_qq == user_qq)
            & (cls.punish_level != -1)
            & (cls.create_time > datetime.now() - timedelta(days=days))
        )
        if punish_level is not None:
            query = query.where(cls.punish_level == punish_level)
        return (await query.gino.scalar()) or 0

    @classmethod
    async def get_user_punish_level(cls, user_qq: int, days: int = 7) -> Optional[int]:
        """Return the punishment level of the user's most recent record.

        :param user_qq: user id
        :param days: length of the period in days, counted back from now
        :return: ``punish_level`` of the newest record within the period, or
            None if the user has no record in that period
        """
        latest = (
            await cls.query.where(cls.user_qq == user_qq)
            .where(cls.create_time > datetime.now() - timedelta(days=days))
            .order_by(cls.id.desc())
            .gino.first()
        )
        if latest:
            return latest.punish_level
        return None

    @classmethod
    async def get_black_data(
        cls,
        user_qq: Optional[int],
        group_id: Optional[int],
        date: Optional[datetime],
        date_type: str = "=",
    ) -> List["BlackWord"]:
        """Query records matching the given conditions.

        Every condition whose argument is falsy is skipped.

        :param user_qq: user id to filter by
        :param group_id: group id to filter by
        :param date: reference date
        :param date_type: how ``create_time`` is compared with ``date``:
            ``"="`` selects the 24 hours starting at ``date``, ``">"``
            selects later records, ``"<"`` selects earlier records; any
            other value applies no date filter
        :return: list of matching records
        """
        query = cls.query
        if user_qq:
            query = query.where(cls.user_qq == user_qq)
        if group_id:
            query = query.where(cls.group_id == group_id)
        if date:
            if date_type == "=":
                # ``create_time`` is a full timestamp, so exact equality
                # with ``date`` would practically never match; select the
                # day-long window beginning at ``date`` instead.
                # NOTE(review): assumes callers pass the start of the day
                # they are interested in - confirm against callers.
                query = query.where(
                    (cls.create_time >= date)
                    & (cls.create_time < date + timedelta(days=1))
                )
            elif date_type == ">":
                query = query.where(cls.create_time > date)
            elif date_type == "<":
                query = query.where(cls.create_time < date)
        return await query.gino.all()