mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
149 lines
4.7 KiB
Python
149 lines
4.7 KiB
Python
from services.db_context import db
|
|
from typing import Optional, List
|
|
from datetime import datetime, timedelta
|
|
|
|
|
|
class BlackWord(db.Model):
    """Record of a sensitive (blacklisted) word sent by a user.

    Each row stores the user, the optional group, the message text, the
    matched blacklisted word, the punish level, and the punishments that
    have been applied to that record.
    """

    __tablename__ = "black_word"

    # Auto-incrementing record id.
    id = db.Column(db.Integer(), primary_key=True, autoincrement=True)
    # User id (QQ number); declared as part of the primary key together with ``id``.
    user_qq = db.Column(db.BigInteger(), nullable=False, primary_key=True)
    # Group number; nullable.
    group_id = db.Column(db.BigInteger())
    # Text of the message that contained the blacklisted word.
    plant_text = db.Column(db.String())
    # The blacklisted word that was matched.
    black_word = db.Column(db.String())
    # Applied punishments; set_user_punish appends each one followed by a space.
    punish = db.Column(db.String(), default="")
    # Punish level; get_user_count excludes rows whose level is -1.
    punish_level = db.Column(db.Integer())
    # Time the record was created.
    create_time = db.Column(db.DateTime(timezone=True), nullable=False)

    @classmethod
    async def add_user_black_word(
        cls,
        user_qq: int,
        group_id: Optional[int],
        black_word: str,
        plant_text: str,
        punish_level: int,
    ) -> None:
        """
        Description:
            Record a sensitive word sent by a user.
        Parameters:
            :param user_qq: user id
            :param group_id: group number
            :param black_word: blacklisted word
            :param plant_text: message text
            :param punish_level: punish level
        """
        await cls.create(
            user_qq=user_qq,
            group_id=group_id,
            plant_text=plant_text,
            black_word=black_word,
            punish_level=punish_level,
            create_time=datetime.now(),
        )

    @classmethod
    async def set_user_punish(
        cls,
        user_qq: int,
        punish: str,
        black_word: Optional[str] = None,
        id_: Optional[int] = None,
    ) -> bool:
        """
        Description:
            Append a punishment to one of the user's records.

            The record is selected either by ``black_word`` (the newest
            record of that user with that word) or, when ``black_word`` is
            empty, by ``id_`` (a 0-based index into the user's records
            ordered by ascending id).
        Parameters:
            :param user_qq: user id
            :param punish: punishment text to append
            :param black_word: blacklisted word
            :param id_: index of the record
        Returns:
            True when a record was found and updated, otherwise False.
        """
        # ``id_ is None`` (not ``not id_``) so that index 0 is a valid selector.
        if (not black_word and id_ is None) or not punish:
            return False
        query = cls.query.where(cls.user_qq == user_qq).with_for_update()
        if black_word:
            user = await (
                query.where(cls.black_word == black_word)
                .order_by(cls.id.desc())
                .gino.first()
            )
        else:
            # Order by id so that the index maps to a deterministic row.
            user_list = await query.order_by(cls.id).gino.all()
            # Valid indexes are 0 .. len - 1; anything else would raise IndexError.
            if not 0 <= id_ < len(user_list):
                return False
            user = user_list[id_]
        if not user:
            return False
        # ``cls.punish + ...`` is a column expression, so the concatenation
        # is performed by the database on the stored value.
        await user.update(punish=cls.punish + punish + " ").apply()
        return True

    @classmethod
    async def get_user_count(
        cls, user_qq: int, days: int = 7, punish_level: Optional[int] = None
    ) -> int:
        """
        Description:
            Count the user's offences within the given period.
            Records whose punish level is -1 are not counted.
        Parameters:
            :param user_qq: user id
            :param days: length of the period in days
            :param punish_level: when given, only count records of this level
        """
        # Build the aggregate directly instead of attaching a temporary
        # ``count`` attribute to the model class on every call.
        query = db.select([db.func.count(cls.id)]).where(
            (cls.user_qq == user_qq)
            & (cls.punish_level != -1)
            & (cls.create_time > datetime.now() - timedelta(days=days))
        )
        if punish_level is not None:
            query = query.where(cls.punish_level == punish_level)
        return await query.gino.scalar()

    @classmethod
    async def get_user_punish_level(cls, user_qq: int, days: int = 7) -> Optional[int]:
        """
        Description:
            Get the punish level of the user's most recent record within
            the given period.
        Parameters:
            :param user_qq: user id
            :param days: length of the period in days
        Returns:
            The punish level of the newest record, or None when the user
            has no record in the period.
        """
        since = datetime.now() - timedelta(days=days)
        latest = await (
            cls.query.where(cls.user_qq == user_qq)
            .where(cls.create_time > since)
            .order_by(cls.id.desc())
            .gino.first()
        )
        if latest:
            return latest.punish_level
        return None

    @classmethod
    async def get_black_data(
        cls,
        user_qq: Optional[int],
        group_id: Optional[int],
        date: Optional[datetime],
        date_type: str = "=",
    ) -> List["BlackWord"]:
        """
        Description:
            Query records by the given conditions. Every falsy condition
            is skipped, so passing no condition returns all records.
        Parameters:
            :param user_qq: user id
            :param group_id: group number
            :param date: date to compare ``create_time`` against
            :param date_type: comparison operator for the date, one of
                "=", ">" or "<"; any other value ignores ``date``
        """
        query = cls.query
        if user_qq:
            query = query.where(cls.user_qq == user_qq)
        if group_id:
            query = query.where(cls.group_id == group_id)
        if date:
            if date_type == "=":
                query = query.where(cls.create_time == date)
            elif date_type == ">":
                query = query.where(cls.create_time > date)
            elif date_type == "<":
                query = query.where(cls.create_time < date)
        return await query.gino.all()
|