zhenxun_bot/models/omega_pixiv_illusts.py
2021-09-09 16:01:50 +08:00

140 lines
4.4 KiB
Python

from datetime import datetime
from typing import List, Optional, Tuple

from services.db_context import db
class OmegaPixivIllusts(db.Model):
    """Model for the ``omega_pixiv_illusts`` table (Pixiv illustration records)."""

    __tablename__ = "omega_pixiv_illusts"

    id = db.Column(db.Integer(), primary_key=True)
    # Illustration pid
    pid = db.Column(db.BigInteger(), nullable=False)
    # Author uid
    uid = db.Column(db.BigInteger(), nullable=False)
    title = db.Column(db.String(), nullable=False)
    # Author name
    uname = db.Column(db.String(), nullable=False)
    # nsfw tag: 0=safe, 1=setu, 2=r18
    nsfw_tag = db.Column(db.Integer(), nullable=False)
    width = db.Column(db.Integer(), nullable=False)
    height = db.Column(db.Integer(), nullable=False)
    # Related tags stored as one string; searched by substring match
    tags = db.Column(db.String(), nullable=False)
    url = db.Column(db.String(), nullable=False)
    created_at = db.Column(db.DateTime(timezone=True))
    updated_at = db.Column(db.DateTime(timezone=True))

    # A (pid, url) pair may appear only once in the table.
    _idx1 = db.Index("omega_pixiv_illusts_idx1", "pid", "url", unique=True)

    @classmethod
    async def add_image_data(
        cls,
        pid: int,
        title: str,
        width: int,
        height: int,
        url: str,
        uid: int,
        uname: str,
        nsfw_tag: int,
        tags: str,
        created_at: datetime,
        updated_at: datetime,
    ) -> bool:
        """
        Description:
            Add an image record unless a record with the same pid already exists
        Args:
            :param pid: illustration pid
            :param title: title
            :param width: width
            :param height: height
            :param url: image url
            :param uid: author uid
            :param uname: author name
            :param nsfw_tag: nsfw tag, 0=safe, 1=setu, 2=r18
            :param tags: related tags
            :param created_at: creation date (currently not persisted, see note below)
            :param updated_at: update date (currently not persisted, see note below)
        Returns:
            True if a new row was created, False if the pid was already present
        """
        if await cls.check_exists(pid):
            return False
        # NOTE(review): created_at / updated_at are accepted but never written
        # to the row, so both columns stay NULL. Left as-is because callers are
        # not visible from here -- confirm whether this is intentional before
        # passing them through to create().
        await cls.create(
            pid=pid,
            title=title,
            width=width,
            height=height,
            url=url,
            uid=uid,
            uname=uname,
            nsfw_tag=nsfw_tag,
            tags=tags,
        )
        return True

    @classmethod
    async def query_images(
        cls,
        keywords: Optional[List[str]] = None,
        uid: Optional[int] = None,
        pid: Optional[int] = None,
        nsfw_tag: Optional[int] = 0,
        num: int = 100,
    ) -> List["OmegaPixivIllusts"]:
        """
        Description:
            Fetch up to ``num`` randomly ordered images matching the filters.
            Only one of keywords / uid / pid is applied, in that priority order.
        Args:
            :param keywords: keywords; every keyword must be contained in ``tags``
            :param uid: author uid
            :param pid: illustration pid
            :param nsfw_tag: nsfw tag, 0=safe, 1=setu, 2=r18; None disables this filter
            :param num: maximum number of images to return
        Returns:
            List of matching rows (empty when nothing matches)
        """
        query = cls.query
        if nsfw_tag is not None:
            query = query.where(cls.nsfw_tag == nsfw_tag)
        if keywords:
            for keyword in keywords:
                query = query.where(cls.tags.contains(keyword))
        elif uid:
            query = query.where(cls.uid == uid)
        elif pid:
            # Fixed: this branch used to compare the uid column against pid.
            query = query.where(cls.pid == pid)
        query = query.order_by(db.func.random()).limit(num)
        return await query.gino.all()

    @classmethod
    async def check_exists(cls, pid: int) -> bool:
        """
        Description:
            Check whether a record with the given pid already exists
        Args:
            :param pid: illustration pid
        Returns:
            True if at least one row has this pid
        """
        # One row is enough to answer the question; no need to load them all.
        row = await cls.query.where(cls.pid == pid).gino.first()
        return row is not None

    @classmethod
    async def get_keyword_num(
        cls, tags: Optional[List[str]] = None
    ) -> Tuple[int, int, int]:
        """
        Description:
            Count the images in the gallery matching the given keywords/tags
        Args:
            :param tags: keywords/tags; every one must be contained in ``tags``.
                         None or empty counts all images.
        Returns:
            (safe count, setu count, r18 count), i.e. counts for nsfw_tag 0, 1, 2
        """
        # Build the aggregate directly instead of attaching a temporary
        # ``count`` attribute to the model class on every call.
        query = db.select([db.func.count(cls.pid)])
        for tag in tags or ():
            query = query.where(cls.tags.contains(tag))
        safe_count = await query.where(cls.nsfw_tag == 0).gino.scalar()
        setu_count = await query.where(cls.nsfw_tag == 1).gino.scalar()
        r18_count = await query.where(cls.nsfw_tag == 2).gino.scalar()
        return safe_count, setu_count, r18_count

    @classmethod
    async def get_all_pid(cls) -> List[int]:
        """
        Description:
            Get the pid of every stored image
        Returns:
            List of pids, one entry per row
        """
        rows = await cls.select("pid").gino.all()
        return [row[0] for row in rows]