zhenxun_bot/models/chat_history.py
2022-04-26 14:45:04 +08:00

215 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime, timedelta
from typing import List, Literal, Optional, Tuple, Union
from services.db_context import db
class ChatHistory(db.Model):
__tablename__ = "chat_history"
id = db.Column(db.Integer(), primary_key=True)
user_qq = db.Column(db.BigInteger(), nullable=False)
group_id = db.Column(db.BigInteger())
text = db.Column(db.Text())
create_time = db.Column(db.DateTime(timezone=True), nullable=False)
@classmethod
async def add_chat_msg(cls, user_qq: int, group_id: Optional[int], text: str):
await cls.create(
user_qq=user_qq, group_id=group_id, text=text, create_time=datetime.now()
)
@classmethod
async def get_user_msg(
cls,
uid: int,
msg_type: Optional[Literal["private", "group"]],
days: Optional[int] = None,
) -> List["ChatHistory"]:
"""
说明:
获取用户消息
参数:
:param uid: 用户qq
:param msg_type: 消息类型,私聊或群聊
:param days: 限制日期
"""
return await cls._get_msg(uid, None, "user", msg_type, days).gino.all()
@classmethod
async def get_group_user_msg(
cls,
uid: int,
gid: int,
limit: int = 10,
date_scope: Tuple[datetime, datetime] = None,
) -> List["ChatHistory"]:
"""
说明:
获取群聊指定用户聊天记录
参数:
:param uid: qq
:param gid: 群号
:param limit: 获取数量
:param date_scope: 日期范围默认None为全搜索
"""
return (
await cls._get_msg(uid, gid, "group", days=date_scope)
.limit(limit)
.gino.all()
)
@classmethod
async def get_group_user_msg_count(cls, uid: int, gid: int) -> Optional[int]:
"""
说明:
查询群聊指定用户的聊天记录数量
参数:
:param uid: qq
:param gid: 群号
"""
if x := await db.first(
db.text(
f"SELECT COUNT(id) as sum FROM public.chat_history WHERE user_qq = {uid} AND group_id = {gid}"
)
):
return x[0]
return None
@classmethod
async def get_group_msg_rank(
cls,
gid: int,
limit: int = 10,
order: str = "DESC",
date_scope: Optional[Tuple[datetime, datetime]] = None,
) -> Optional[Tuple[int, int]]:
"""
说明:
获取排行数据
参数:
:param gid: 群号
:param limit: 获取数量
:param order: 排序类型descdes
:param date_scope: 日期范围
"""
sql = f"SELECT user_qq, COUNT(id) as sum FROM public.chat_history WHERE group_id = {gid} "
if date_scope:
sql += f"AND create_time BETWEEN '{date_scope[0]}' AND '{date_scope[1]}' "
sql += f"GROUP BY user_qq ORDER BY sum {order if order and order.upper() != 'DES' else ''} LIMIT {limit}"
print(sql)
return await db.all(db.text(sql))
@classmethod
async def get_group_first_msg_datetime(cls, gid: int) -> Optional[datetime]:
"""
说明:
获取群第一条记录消息时间
参数:
:param gid:
"""
if (
msg := await cls.query.where(cls.group_id == gid)
.order_by(cls.create_time)
.gino.first()
):
return msg.create_time
return None
@classmethod
async def get_user_msg_count(
cls,
uid: int,
msg_type: Optional[Literal["private", "group"]],
days: Optional[int] = None,
) -> int:
"""
说明:
获取用户消息数量
参数:
:param uid: 用户qq
:param msg_type: 消息类型,私聊或群聊
:param days: 限制日期
"""
return (
await cls._get_msg(uid, None, "user", msg_type, days, True).gino.first()
)[0]
@classmethod
async def get_group_msg(
cls,
gid: int,
days: Optional[int] = None,
) -> List["ChatHistory"]:
"""
说明:
获取群聊消息
参数:
:param gid: 用户qq
:param days: 限制日期
"""
return await cls._get_msg(None, gid, "group", None, days).gino.all()
@classmethod
async def get_group_msg_count(
cls,
gid: int,
days: Optional[int] = None,
) -> List["ChatHistory"]:
"""
说明:
获取群聊消息数量
参数:
:param gid: 用户qq
:param days: 限制日期
"""
return (await cls._get_msg(None, gid, "group", None, days, True).gino.first())[
0
]
@classmethod
def _get_msg(
cls,
uid: Optional[int],
gid: Optional[int],
type_: Literal["user", "group"],
msg_type: Optional[Literal["private", "group"]] = None,
days: Optional[Union[int, Tuple[datetime, datetime]]] = None,
is_select_count: bool = False,
):
"""
说明:
获取消息查询query
参数:
:param uid: 用户qq
:param gid: 群号
:param type_: 类型,私聊或群聊
:param msg_type: 消息类型,用户或群聊
:param days: 限制日期
"""
if is_select_count:
setattr(ChatHistory, "count", db.func.count(cls.id).label("count"))
query = cls.select("count")
else:
query = cls.query
if type_ == "user":
query = query.where(cls.user_qq == uid)
if msg_type == "private":
query = query.where(cls.group_id == None)
elif msg_type == "group":
query = query.where(cls.group_id != None)
else:
query = query.where(cls.group_id == gid)
if uid:
query = query.where(cls.user_qq == uid)
if days:
if isinstance(days, int):
query = query.where(
cls.create_time >= datetime.now() - timedelta(days=days)
)
elif isinstance(days, tuple):
query = query.where(cls.create_time >= days[0]).where(
cls.create_time <= days[1]
)
return query