zhenxun_bot/plugins/web_ui/auth/__init__.py

108 lines
3.1 KiB
Python
Raw Normal View History

2022-06-05 19:51:23 +08:00
import json
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import nonebot
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from pydantic import BaseModel
from starlette import status

from configs.config import Config
from configs.path_config import DATA_PATH

from ..config import Result, router
2022-06-05 19:51:23 +08:00
2022-04-04 20:33:37 +08:00
app = nonebot.get_app()


def _load_secret_key() -> str:
    """Return this installation's JWT signing key, creating it on first use.

    The key used to be a fixed string committed to the source code, so anyone
    who knew the configured web-ui username could forge a token that passes
    signature verification in ``token_to_user``. A random key is now generated
    once and persisted next to ``token.json`` so issued tokens survive restarts.
    """
    secret_file = DATA_PATH / "web_ui" / "secret_key"
    secret_file.parent.mkdir(parents=True, exist_ok=True)
    if secret_file.exists():
        key = secret_file.read_text(encoding="utf8").strip()
        if key:
            return key
    # Missing or empty file: mint a fresh 256-bit key and remember it.
    key = secrets.token_hex(32)
    secret_file.write_text(key, encoding="utf8")
    return key


SECRET_KEY = _load_secret_key()
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/login")
2022-04-04 20:33:37 +08:00
2022-06-05 19:51:23 +08:00
token_file = DATA_PATH / "web_ui" / "token.json"
token_file.parent.mkdir(parents=True, exist_ok=True)


def _load_token_data() -> dict:
    """Read the persisted list of recently issued tokens.

    Returns the file's contents when it holds a dict with a ``"token"`` list;
    otherwise (file missing, unreadable, invalid JSON, or wrong shape) returns
    an empty store. Losing this store only means users must log in again, so
    a damaged file must not prevent the module from importing.
    """
    if token_file.exists():
        try:
            # Context manager so the handle is closed (the old one-liner leaked it).
            with open(token_file, "r", encoding="utf8") as fp:
                data = json.load(fp)
        except (OSError, ValueError):  # ValueError covers JSON and Unicode decode errors
            data = None
        if isinstance(data, dict) and isinstance(data.get("token"), list):
            return data
    return {"token": []}


token_data = _load_token_data()
2022-06-05 19:51:23 +08:00
2022-04-04 20:33:37 +08:00
class User(BaseModel):
    """The web-ui account; ``get_user`` builds it from the ``web-ui`` config section."""

    # Login name compared against the submitted form username.
    username: str
    # Plain-text password as stored in the config.
    password: str
class Token(BaseModel):
    """An issued access token together with its scheme name."""

    # Encoded JWT string.
    access_token: str
    # Authorization scheme, e.g. "bearer".
    token_type: str
def get_user(uname: str) -> Optional[User]:
    """Look up the configured web-ui account by name.

    Returns a ``User`` only when both ``web-ui.username`` and
    ``web-ui.password`` are configured (truthy) and *uname* equals the
    configured username; otherwise returns ``None``.
    """
    cfg_name = Config.get_config("web-ui", "username")
    cfg_password = Config.get_config("web-ui", "password")
    if not (cfg_name and cfg_password) or uname != cfg_name:
        return None
    return User(username=cfg_name, password=cfg_password)
# Raised by the login route when the username is unknown or the password is wrong.
form_exception = HTTPException(
    status.HTTP_401_UNAUTHORIZED,
    "Could not validate credentials",
    {"WWW-Authenticate": "Bearer"},
)
def create_token(user: User, expires_delta: Optional[timedelta] = None) -> str:
    """Sign a JWT identifying *user*.

    Args:
        user: Account whose ``username`` becomes the ``sub`` claim.
        expires_delta: Lifetime of the token. ``None`` means the 15-minute
            default; any explicit value, including ``timedelta(0)``, is honoured.

    Returns:
        The encoded token signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    if expires_delta is None:  # `or` would also replace a deliberate timedelta(0)
        expires_delta = timedelta(minutes=15)
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(
        claims={"sub": user.username, "exp": expire},
        key=SECRET_KEY,
        algorithm=ALGORITHM,
    )
2023-04-01 01:50:34 +08:00
@router.post("/login")
async def login_get_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange web-ui credentials for a bearer token.

    The new token is appended to ``token_data["token"]``. Only the three most
    recent tokens are kept, and the list is written back to ``token_file``.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: ``form_exception`` (401) when the user is unknown or the
            password does not match.
    """
    user = get_user(form_data.username)
    # Constant-time comparison so response timing does not leak the password.
    # Compare UTF-8 bytes because compare_digest rejects non-ASCII str input.
    if not user or not secrets.compare_digest(
        user.password.encode("utf8"), form_data.password.encode("utf8")
    ):
        raise form_exception
    access_token = create_token(
        user=user, expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    )
    tokens = token_data.setdefault("token", [])
    tokens.append(access_token)
    token_data["token"] = tokens[-3:]  # remember only the three newest tokens
    with open(token_file, "w", encoding="utf8") as f:
        json.dump(token_data, f, ensure_ascii=False, indent=4)
    return {"access_token": access_token, "token_type": "bearer"}
# 401 error for rejected bearer tokens. Not referenced in this module itself;
# NOTE(review): presumably imported by other web-ui routes — confirm.
credentials_exception = HTTPException(
    status.HTTP_401_UNAUTHORIZED,
    "Could not validate credentials",
    {"WWW-Authenticate": "Bearer"},
)
2023-04-01 01:50:34 +08:00
@app.post("/auth")
def token_to_user(token: str = Depends(oauth2_scheme)):
    """Check whether a bearer token belongs to the configured web-ui user.

    A token found in ``token_data["token"]`` (the most recently issued ones)
    is accepted without being decoded. Any other token must pass
    ``jwt.decode`` with ``SECRET_KEY``/``ALGORITHM``, and its ``sub`` claim
    must name the configured user.

    Returns:
        ``Result(code=200, info="登录成功")`` ("login successful") on success,
        ``Result(code=401)`` otherwise.
    """
    # NOTE(review): stored tokens skip decoding, so their expiry is never
    # checked — confirm this "remember recent sessions" behaviour is intended.
    if token in token_data["token"]:
        return Result(code=200, info="登录成功")
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return Result(code=401)
    if get_user(claims.get("sub")) is None:  # type: ignore
        return Result(code=401)
    return Result(code=200, info="登录成功")
2022-04-04 20:33:37 +08:00
2023-04-01 01:50:34 +08:00
if __name__ == "__main__":
    # Debug entry point: serve the NoneBot ASGI app directly on localhost.
    # NOTE(review): the package-relative `..config` import likely fails when this
    # file is executed as a script — confirm whether this entry is still used.
    import uvicorn

    uvicorn.run(app, port=8080, host="127.0.0.1")