zhenxun_bot/models/chat_history.py

124 lines
3.7 KiB
Python
Raw Normal View History

2022-04-10 22:19:50 +08:00
from datetime import datetime, timedelta
from typing import List, Literal, Optional
from services.db_context import db
class ChatHistory(db.Model):
    """
    Description:
        Chat message history table; ``add_chat_msg`` stores one row per call.
        A NULL ``group_id`` marks a private message, a non-NULL one the group
        the message was sent in.
    """

    __tablename__ = "chat_history"

    id = db.Column(db.Integer(), primary_key=True)
    user_qq = db.Column(db.BigInteger(), nullable=False)
    # NULL for private messages, otherwise the group number.
    group_id = db.Column(db.BigInteger())
    text = db.Column(db.Text())
    create_time = db.Column(db.DateTime(timezone=True), nullable=False)

    @classmethod
    async def add_chat_msg(cls, user_qq: int, group_id: Optional[int], text: str):
        """
        Description:
            Store one chat message, stamped with the current time.
        Parameters:
            :param user_qq: sender's QQ number
            :param group_id: group number, or None for a private message
            :param text: message text
        """
        await cls.create(
            user_qq=user_qq, group_id=group_id, text=text, create_time=datetime.now()
        )

    @classmethod
    async def get_user_msg(
        cls,
        uid: int,
        msg_type: Optional[Literal["private", "group"]],
        days: Optional[int] = None,
    ) -> List["ChatHistory"]:
        """
        Description:
            Fetch a user's messages.
        Parameters:
            :param uid: the user's QQ number
            :param msg_type: "private" = private messages only, "group" = group
                messages only, None = both
            :param days: only messages from the last ``days`` days;
                None (or 0) means no time limit
        """
        return await cls._get_msg(uid, None, "user", msg_type, days).gino.all()

    @classmethod
    async def get_user_msg_count(
        cls,
        uid: int,
        msg_type: Optional[Literal["private", "group"]],
        days: Optional[int] = None,
    ) -> int:
        """
        Description:
            Count a user's messages.
        Parameters:
            :param uid: the user's QQ number
            :param msg_type: "private" = private messages only, "group" = group
                messages only, None = both
            :param days: only messages from the last ``days`` days;
                None (or 0) means no time limit
        """
        return await cls._get_msg(
            uid, None, "user", msg_type, days, is_select_count=True
        ).gino.scalar()

    @classmethod
    async def get_group_msg(
        cls,
        gid: int,
        days: Optional[int] = None,
    ) -> List["ChatHistory"]:
        """
        Description:
            Fetch all messages of a group.
        Parameters:
            :param gid: group number
            :param days: only messages from the last ``days`` days;
                None (or 0) means no time limit
        """
        return await cls._get_msg(None, gid, "group", None, days).gino.all()

    @classmethod
    async def get_group_msg_count(
        cls,
        gid: int,
        days: Optional[int] = None,
    ) -> int:
        """
        Description:
            Count all messages of a group.
        Parameters:
            :param gid: group number
            :param days: only messages from the last ``days`` days;
                None (or 0) means no time limit
        """
        return await cls._get_msg(
            None, gid, "group", None, days, is_select_count=True
        ).gino.scalar()

    @classmethod
    def _get_msg(
        cls,
        uid: Optional[int],
        gid: Optional[int],
        type_: Literal["user", "group"],
        msg_type: Optional[Literal["private", "group"]],
        days: Optional[int],
        is_select_count: bool = False,
    ):
        """
        Description:
            Build (but do not execute) the message query.
        Parameters:
            :param uid: user's QQ number; only used when type_ == "user"
            :param gid: group number; only used when type_ == "group"
            :param type_: "user" filters by sender, "group" filters by group
            :param msg_type: only used when type_ == "user": "private" keeps
                private messages, "group" keeps group messages, None keeps both
            :param days: only messages from the last ``days`` days;
                None (or 0) means no time limit
            :param is_select_count: select COUNT(id) instead of model rows
        """
        if is_select_count:
            # Plain SELECT COUNT(chat_history.id); the FROM clause is inferred
            # from the column, so no attribute has to be patched onto the model.
            query = db.select([db.func.count(cls.id)])
        else:
            query = cls.query
        if type_ == "user":
            query = query.where(cls.user_qq == uid)
            if msg_type == "private":
                query = query.where(cls.group_id.is_(None))
            elif msg_type == "group":
                query = query.where(cls.group_id.isnot(None))
        else:
            query = query.where(cls.group_id == gid)
        if days:
            query = query.where(
                cls.create_time >= datetime.now() - timedelta(days=days)
            )
        return query