zhenxun_bot/zhenxun/plugins/draw_card/handles/base_handle.py
2024-08-10 02:25:04 +08:00

296 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import random
from asyncio.exceptions import TimeoutError
from collections import Counter
from datetime import datetime
from typing import Generic, TypeVar

import aiohttp
import anyio
import ujson as json
from nonebot_plugin_saa import Image as SaaImage
from nonebot_plugin_saa import MessageFactory, Text
from PIL import Image
from pydantic import BaseModel, Extra

from zhenxun.configs.path_config import DATA_PATH
from zhenxun.services.log import logger
from zhenxun.utils.image_utils import BuildImage

from ..config import DRAW_PATH, draw_config
from ..util import circled_number, cn2py
class BaseData(BaseModel, extra=Extra.ignore):
    """Common card data shared by all games.

    Unknown fields are ignored on construction (``extra=Extra.ignore``).
    Equality and hashing use only ``name``, so two entries with the same
    name are treated as the same card (e.g. when counting duplicates).
    """

    name: str  # card name
    star: int  # star rating
    limited: bool  # whether the card is limited

    def __eq__(self, other: object) -> bool:
        # Comparing with anything that is not card data must not raise
        # AttributeError; defer to the other operand instead.
        if not isinstance(other, BaseData):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        # Must agree with __eq__: identity is the card name only.
        return hash(self.name)

    @property
    def star_str(self) -> str:
        """Star rating rendered as ``star`` copies of "★" (e.g. 3 -> "★★★")."""
        return "★" * self.star
class UpChar(BaseData):
    """A card that is boosted ("UP") in a rate-up pool."""

    zoom: float  # UP boost multiplier for this card
class UpEvent(BaseModel):
    """A rate-up ("UP") pool event."""

    title: str  # title of the UP pool
    pool_img: str  # UP pool cover image (presumably a URL or path — TODO confirm)
    start_time: datetime | None  # start time, may be None
    end_time: datetime | None  # end time, may be None
    up_char: list[UpChar]  # the boosted cards in this pool
# Type variable for the card-data type a handler works with. It is bound to
# BaseData, so handler code can rely on name / star / limited / star_str.
TC = TypeVar("TC", bound="BaseData")
class BaseHandle(Generic[TC]):
    """Base class for a per-game card-draw (gacha) handler.

    Subclasses must implement ``get_card`` (one pull), ``_init_data`` (load
    local card data) and ``_update_info`` (refresh data from the web); the
    base versions raise ``NotImplementedError``. ``get_cards``,
    ``_reload_pool`` and ``reset_count`` have trivial defaults that
    subclasses may override.
    """

    # Shared by every handler so that draw_config.SEMAPHORE really limits how
    # many update_info() calls run at once. A Semaphore created anew on each
    # call never blocks anything. It is created lazily on first use.
    _update_semaphore: asyncio.Semaphore | None = None

    # HTTP session used by get_url/download_img. It is assigned by
    # update_info()/reload_pool() and closed when they return, so those
    # helpers are only usable from inside _update_info/_reload_pool.
    session: aiohttp.ClientSession

    def __init__(self, game_name: str, game_name_cn: str):
        """
        game_name: internal game identifier, used for file and directory names
        game_name_cn: display name, used in log messages
        """
        self.game_name = game_name
        self.game_name_cn = game_name_cn
        self.max_star = 1  # highest star rating (placeholder; presumably set by subclasses)
        self.game_card_color: str = "#ffffff"  # background colour of the summary image
        self.data_path = DATA_PATH / "draw_card"
        self.img_path = DRAW_PATH / f"{self.game_name}"
        self.up_path = DATA_PATH / "draw_card" / "draw_card_up"
        self.img_path.mkdir(parents=True, exist_ok=True)
        # Also creates data_path, since up_path lives underneath it.
        self.up_path.mkdir(parents=True, exist_ok=True)
        # Files that must all exist for data_exists() to report True.
        self.data_files: list[str] = [f"{self.game_name}.json"]

    async def draw(self, count: int, **kwargs) -> MessageFactory:
        """Pull ``count`` cards and build the reply message.

        The reply holds the rendered card image followed by the text summary
        from ``format_result``. ``kwargs`` are forwarded to ``get_cards`` only.
        """
        index2card = self.get_cards(count, **kwargs)
        cards = [card for card, _ in index2card]
        result = self.format_result(index2card)
        gen_img = await self.generate_img(cards)
        return MessageFactory([SaaImage(gen_img.pic2bytes()), Text(result)])

    def get_card(self, **kwargs) -> TC:
        """Pull a single card. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_cards(self, count: int, **kwargs) -> list[tuple[TC, int]]:
        """Pull ``count`` independent cards.

        Returns ``(card, n)`` pairs where ``n`` is the 1-based pull number.
        ``format_max_star`` renders it as "<n> 抽获取 ...", so the first pull
        is 1, not 0.
        """
        return [(self.get_card(**kwargs), n) for n in range(1, count + 1)]

    @staticmethod
    def get_star(star_list: list[int], probability_list: list[float]) -> int:
        """Pick one star rating from ``star_list`` using ``probability_list`` as weights."""
        return random.choices(star_list, weights=probability_list, k=1)[0]

    def format_result(self, index2card: list[tuple[TC, int]], **kwargs) -> str:
        """Join the star summary, top-rarity lines and most-drawn line.

        Empty sections are skipped. ``kwargs`` are passed to every formatter.
        """
        card_list = [card for card, _ in index2card]
        results = [
            self.format_star_result(card_list, **kwargs),
            self.format_max_star(index2card, **kwargs),
            self.format_max_card(card_list, **kwargs),
        ]
        return "\n".join(rst for rst in results if rst)

    def format_star_result(self, card_list: list[TC], **kwargs) -> str:
        """Count cards per star rating, highest rating first.

        Example output: "[★★★×1] [★★×9]".
        """
        ordered = sorted(card_list, key=lambda c: c.star, reverse=True)
        # Counter keeps first-seen order, so the ratings stay sorted.
        star_counts = Counter(card.star_str for card in ordered)
        return " ".join(f"[{star_str}×{count}]" for star_str, count in star_counts.items())

    def format_max_star(
        self,
        card_list: list[tuple[TC, int]],
        up_list: list[str] | None = None,
        **kwargs,
    ) -> str:
        """List every pull of a ``max_star`` card, one line per pull.

        card_list: ``(card, pull_number)`` pairs
        up_list: names of boosted cards; these are tagged "UP" in the output
        """
        up_names = set(up_list or ())
        lines: list[str] = []
        for card, index in card_list:
            if card.star != self.max_star:
                continue
            if card.name in up_names:
                lines.append(f"{index} 抽获取UP {card.name}")
            else:
                lines.append(f"{index} 抽获取 {card.name}")
        return "\n".join(lines)

    def format_max_card(self, card_list: list[TC], **kwargs) -> str:
        """Describe the card drawn most often.

        Returns "" when nothing was drawn or when no card appeared more than
        once. On ties, the card that appeared first wins.
        """
        if not card_list:
            return ""
        max_card, max_count = Counter(card_list).most_common(1)[0]
        if max_count <= 1:
            return ""
        return f"抽取到最多的是{max_card.name},共抽取了{max_count}"

    async def generate_img(
        self,
        cards: list[TC],
        num_per_line: int = 5,
        max_per_line: tuple[int, int] = (40, 10),
    ) -> BuildImage:
        """Render the pulled cards into one summary image.

        cards: pulled cards, in pull order
        num_per_line: intended number of cards per row
        max_per_line: (threshold, per_row): the intended per-row count once
            more than ``threshold`` cards are drawn

        NOTE: num_per_line / max_per_line are kept for interface compatibility
        but do not currently affect the layout; BuildImage.auto_paste is always
        given 10.
        """
        if len(cards) > 90:
            # Large batches: render each distinct card once and mark how many
            # copies were drawn with a circled-number badge.
            counts = Counter(cards)
            card_list = list(counts)
            num_list = list(counts.values())
        else:
            card_list = cards
            num_list = [1] * len(cards)

        card_imgs: list[BuildImage] = []
        for card, num in zip(card_list, num_list):
            card_img = await self.generate_card_img(card)
            if num > 1:
                label = circled_number(num)
                # The badge is 1/7 of the card's shorter side tall, aspect kept.
                label_w = int(min(card_img.width, card_img.height) / 7)
                label = label.resize(
                    (int(label_w * label.width / label.height), label_w),
                    # Image.ANTIALIAS (an alias of LANCZOS) was removed in Pillow 10.
                    Image.LANCZOS,
                )
                await card_img.paste(label)
            card_imgs.append(card_img)
        return await BuildImage.auto_paste(card_imgs, 10, color=self.game_card_color)  # type: ignore

    async def generate_card_img(self, card: TC) -> BuildImage:
        """Build the image for one card from ``img_path/<cn2py(name)>.png``.

        The card is drawn as a 100x100 BuildImage using that file as its background.
        """
        img = str(self.img_path / f"{cn2py(card.name)}.png")
        return BuildImage(100, 100, background=img)

    def load_data(self, filename: str = "") -> dict:
        """Load JSON from ``data_path/filename``.

        Defaults to ``<game_name>.json``. Returns {} if the file does not exist.
        """
        if not filename:
            filename = f"{self.game_name}.json"
        filepath = self.data_path / filename
        if not filepath.exists():
            return {}
        with filepath.open("r", encoding="utf8") as f:
            return json.load(f)

    def dump_data(self, data: dict, filename: str = ""):
        """Write ``data`` as indented UTF-8 JSON to ``data_path/filename``.

        Defaults to ``<game_name>.json``.
        """
        if not filename:
            filename = f"{self.game_name}.json"
        filepath = self.data_path / filename
        with filepath.open("w", encoding="utf8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

    def data_exists(self) -> bool:
        """True when every file listed in ``data_files`` exists under ``data_path``."""
        return all((self.data_path / file).exists() for file in self.data_files)

    def _init_data(self):
        """Load local card data. Must be implemented by subclasses."""
        raise NotImplementedError

    def init_data(self):
        """Call ``_init_data``; any error is logged as a warning, not raised."""
        try:
            self._init_data()
        except Exception as e:
            logger.warning(f"{self.game_name_cn} 导入角色数据错误:{type(e)}: {e}")

    async def _update_info(self):
        """Refresh data from the web using ``self.session``. Subclass hook."""
        raise NotImplementedError

    def client(self) -> aiohttp.ClientSession:
        """Create a new aiohttp session with this plugin's User-Agent header."""
        # NOTE(review): the header value contains literal double quotes.
        headers = {
            "User-Agent": '"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; TencentTraveler 4.0)"'
        }
        return aiohttp.ClientSession(headers=headers)

    @staticmethod
    def _get_update_semaphore() -> asyncio.Semaphore:
        """Return the semaphore shared by all handlers for ``update_info``."""
        if BaseHandle._update_semaphore is None:
            BaseHandle._update_semaphore = asyncio.Semaphore(draw_config.SEMAPHORE)
        return BaseHandle._update_semaphore

    async def update_info(self):
        """Run ``_update_info`` with a fresh session, then reload local data.

        Update errors are logged as warnings; ``init_data`` runs in either case.
        """
        try:
            async with self._get_update_semaphore():
                async with self.client() as session:
                    self.session = session
                    await self._update_info()
        except Exception as e:
            logger.warning(f"{self.game_name_cn} 更新数据错误:{type(e)}: {e}")
        self.init_data()

    async def get_url(self, url: str) -> str:
        """GET ``url`` and return the response text.

        Retries up to 5 times on timeout, sleeping 1 s between attempts.
        Returns "" if every attempt times out. Other errors propagate.
        Requires ``self.session`` (see class docs).
        """
        retry = 5
        for attempt in range(retry):
            try:
                async with self.session.get(url, timeout=10) as response:
                    return await response.text()
            except TimeoutError:
                logger.warning(f"访问 {url} 超时, 重试 {attempt + 1}/{retry}")
                await asyncio.sleep(1)
        return ""

    async def download_img(self, url: str, name: str) -> bool:
        """Download ``url`` to ``img_path/<cn2py(name)>.png`` unless it already exists.

        Returns True if the file exists or was written, False on failure
        (the failure is logged). HTTP error statuses count as failures. The
        body is read in full before the file is created, so a failed download
        does not leave an empty file that would later be mistaken for a cached
        image. Task cancellation is not swallowed.
        """
        img_path = self.img_path / f"{cn2py(name)}.png"
        if img_path.exists():
            return True
        try:
            async with self.session.get(url, timeout=10) as response:
                response.raise_for_status()
                content = await response.read()
            async with await anyio.open_file(img_path, "wb") as f:
                await f.write(content)
            return True
        except TimeoutError:
            logger.warning(
                f"下载 {self.game_name_cn} 图片超时,名称:{name}, url:{url}"
            )
            return False
        except Exception:
            logger.warning(
                f"下载 {self.game_name_cn} 链接错误,名称:{name}, url:{url}"
            )
            return False

    async def _reload_pool(self) -> MessageFactory | None:
        """Reload the UP pool using ``self.session``. The default does nothing."""
        return None

    async def reload_pool(self) -> MessageFactory | None:
        """Run ``_reload_pool`` with a fresh session.

        Returns its message, or None if it raised (the error is logged).
        """
        try:
            async with self.client() as session:
                self.session = session
                return await self._reload_pool()
        except Exception as e:
            logger.warning(f"{self.game_name_cn} 重载UP池错误", e=e)
            return None

    def reset_count(self, user_id: str) -> bool:
        """Reset a user's pull counter. The base handler has none and returns False."""
        return False