File size: 3,306 Bytes
69ab315
 
 
 
 
 
 
 
 
 
 
 
 
 
ad9237d
69ab315
 
 
 
 
 
 
 
ad9237d
69ab315
 
 
 
 
ad9237d
69ab315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from datetime import datetime, timedelta, timezone
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette import status

# Router that groups every endpoint in this module under /v1/auth and tags
# them as "auth".
router = APIRouter(
    prefix='/v1/auth',
    tags=['auth']
)

# Key and algorithm used to sign (create_access_token) and verify
# (get_current_user) the JWTs issued by this module.
# NOTE(review): PASSWORD is not defined anywhere in the visible part of this
# file — confirm where the signing key is supposed to come from.
SECRET_KEY = PASSWORD
ALGORITHM = 'HS256'

# Password hashing context (bcrypt) used to hash the predefined users'
# passwords and to verify login attempts.
bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
# Bearer-token dependency used by get_current_user; tokenUrl is the router
# prefix plus the "/token" login route declared below.
oauth2_bearer = OAuth2PasswordBearer(tokenUrl='/v1/auth/token')

# In-memory list of the users that authenticate_user searches.
# Each password hash is produced by bcrypt_context.hash() when this module is
# imported.
# NOTE(review): USER_HASH and ADMIN_HASH are not defined in the visible part of
# this file — confirm where these values come from.
hardcoded_users = [
    {
        "id": 1,
        "username": "test",
        "password_hash": bcrypt_context.hash(USER_HASH),
        "role": "user",
    },
    {
        "id": 2,
        "username": "admin",
        "password_hash": bcrypt_context.hash(ADMIN_HASH),
        "role": "admin",
    },
]


class CreateUserRequest(BaseModel):
    """Request body declared for the ``create_user`` endpoint."""

    # Login name requested for the new user.
    username: str
    # Password supplied in the request body.
    password: str

class Token(BaseModel):
    """Response model of the ``/token`` login endpoint."""

    # The signed JWT produced by create_access_token.
    access_token: str
    # Token scheme; the login endpoint always sets this to "bearer".
    token_type: str

def authenticate_user(username: str, password: str, role: str):
    """Look up ``username`` in ``hardcoded_users`` and check its password.

    Args:
        username: Login name to search for.
        password: Plain-text password to verify against the stored hash.
        role: Accepted for the caller's convenience but never consulted; the
            role in the result is always the one stored for the user.

    Returns:
        A dict with the keys ``username``, ``id`` and ``role`` when a matching
        entry has a password hash, a role, and the password verifies;
        otherwise ``None``.
    """
    for candidate in hardcoded_users:
        if candidate["username"] != username:
            continue

        hashed = candidate.get("password_hash")
        granted_role = candidate.get("role")
        # Entries missing a hash or a role can never authenticate.
        if not hashed or not granted_role:
            continue

        if bcrypt_context.verify(password, hashed):
            return {"username": username, "id": candidate["id"], "role": granted_role}

    return None



@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_user(create_user_request: CreateUserRequest):
    """Placeholder endpoint that accepts a user-creation payload and does nothing.

    Users are predefined in ``hardcoded_users``, so the request body is never
    read and nothing is stored. The handler returns ``None``; the 201 status
    comes from the route decorator.
    """
    return None

@router.post("/token", response_model=Token)
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
):
    """Exchange a username/password form for a signed JWT bearer token.

    Args:
        form_data: OAuth2 password-flow form; only its ``username`` and
            ``password`` fields are read.

    Returns:
        A ``Token`` whose ``access_token`` is a JWT valid for seven days and
        whose ``token_type`` is ``"bearer"``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the credentials do not match a known user.
    """
    # NOTE: authenticate_user does not use the ``role`` argument; the role put
    # into the token is whatever is stored for the matched user.
    user = authenticate_user(form_data.username, form_data.password, role="user")
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate user.",
            # Bearer-scheme 401 responses should carry this challenge header.
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Seven days — the same lifetime as the previous timedelta(minutes=10080).
    token = create_access_token(
        user["username"], user["id"], user["role"], timedelta(days=7)
    )

    return Token(access_token=token, token_type="bearer")



def create_access_token(username: str, user_id: int, role: str, expires_delta: timedelta) -> str:
    """Build and sign a JWT describing the given user.

    Args:
        username: Stored in the ``sub`` claim.
        user_id: Stored in the ``id`` claim.
        role: Stored in the ``role`` claim.
        expires_delta: How long from now the token stays valid.

    Returns:
        The token encoded with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Use a timezone-aware UTC timestamp: datetime.utcnow() is deprecated and
    # returns a naive value that is easy to misinterpret as local time.
    expires_at = datetime.now(timezone.utc) + expires_delta
    claims = {"sub": username, "id": user_id, "role": role, "exp": expires_at}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)


async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]):
    """Decode the bearer token and return the user it describes.

    Args:
        token: Raw JWT extracted from the request by ``oauth2_bearer``.

    Returns:
        A dict with the keys ``username``, ``id`` and ``role`` taken from the
        token's ``sub``, ``id`` and ``role`` claims. ``role`` may be ``None``
        if the token carries no such claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token cannot be decoded or lacks the ``sub`` or ``id`` claim.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate user.",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try body to the single call that can raise JWTError.
    try:
        # Pass the allowed algorithms as a list so the check is an exact
        # membership test rather than a substring match against a bare string.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_error from err

    username = payload.get("sub")
    user_id = payload.get("id")
    role = payload.get("role")
    # A missing role is tolerated here; only the identity claims are required.
    if username is None or user_id is None:
        raise credentials_error

    return {"username": username, "id": user_id, "role": role}