# Hugging Face Spaces status: Sleeping
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from starlette import status
# Router for the authentication endpoints: every route registered on it is
# served under the "/v1/auth" path prefix and grouped under the "auth" tag in
# the generated OpenAPI docs.
router = APIRouter(tags=['auth'], prefix='/v1/auth')
# HMAC secret used both to sign and to verify JWTs, taken from the PASSWORD
# environment variable.
# NOTE(review): os.getenv returns None when PASSWORD is unset; the module still
# imports, but token signing/verification presumably fails at request time.
SECRET_KEY = os.getenv("PASSWORD")
# JWT signature algorithm (HMAC-SHA256), shared by encode and decode.
ALGORITHM = "HS256"
# Password hashing context. deprecated="auto" marks every scheme other than
# the default (bcrypt) as deprecated.
bcrypt_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
# Pulls the bearer token out of incoming requests. tokenUrl is advertised in
# the OpenAPI docs and should equal the router prefix ("/v1/auth") plus the
# login route — presumably "/token"; keep the two in sync.
oauth2_bearer = OAuth2PasswordBearer(tokenUrl="/v1/auth/token")
def _hash_password_from_env(var_name: str) -> str:
    """Return a bcrypt hash of the password stored in env var *var_name*.

    The value of the variable is hashed here, so despite the ``*_HASH``
    naming used below it must hold the plaintext password, not a hash.

    Raises:
        RuntimeError: if the variable is unset. Without this check passlib
            raises an opaque ``TypeError`` on ``None`` at import time.
    """
    plaintext = os.environ.get(var_name)
    if plaintext is None:
        raise RuntimeError(
            f"Environment variable {var_name!r} must be set to seed the "
            "built-in user accounts."
        )
    return bcrypt_context.hash(plaintext)


# In-memory user store seeded with two fixed accounts. Each "id" is a random
# UUID4 string, so it changes on every process start. The list lives only in
# process memory; create_user() appends new accounts to it at runtime.
hardcoded_users = [
    {
        "id": str(uuid.uuid4()),
        "first_name": "Ibraaheem",
        "last_name": "Akbar",
        "username": "test",
        "password_hash": _hash_password_from_env("USER_HASH"),
        "role": "user",
    },
    {
        "id": str(uuid.uuid4()),
        "username": "admin",
        "first_name": "John",
        "last_name": "Doe",
        "password_hash": _hash_password_from_env("ADMIN_HASH"),
        "role": "admin",
    },
]
class CreateUserRequest(BaseModel):
    """Request body for registering a new account.

    ``password`` arrives in plaintext; create_user hashes it with
    bcrypt_context before storing it.
    """
    username: str
    password: str  # plaintext; only the bcrypt hash is stored
    first_name: str
    last_name: str
class Token(BaseModel):
    """Response body returned by login_for_access_token."""
    access_token: str  # signed JWT produced by create_access_token
    token_type: str  # login_for_access_token always sets this to "bearer"
def authenticate_user(username: str, password: str, role: str):
    """Check *username* / *password* against the in-memory user store.

    Only the FIRST entry whose username matches is considered. Previously the
    scan continued past a failed password check, so a later entry reusing an
    existing username (create_user did not reject duplicates) could
    authenticate under that name with a different password.

    Args:
        username: login name, matched exactly (case-sensitive).
        password: plaintext password, verified against the stored bcrypt hash.
        role: currently ignored; kept so existing callers keep working.

    Returns:
        On success, a dict with ``username``, ``id``, ``role``,
        ``first_name`` and ``last_name``. Returns ``None`` if the user is
        unknown, the entry lacks a password hash or role, or the password is
        wrong.
    """
    user = next((u for u in hardcoded_users if u["username"] == username), None)
    if user is None:
        # Spend roughly the time of a real bcrypt check so unknown usernames
        # are harder to distinguish from wrong passwords by response timing.
        bcrypt_context.dummy_verify()
        return None

    stored_password_hash = user.get("password_hash")
    stored_role = user.get("role")
    if not (stored_password_hash and stored_role):
        return None
    if not bcrypt_context.verify(password, stored_password_hash):
        return None

    # Include 'first_name' and 'last_name' so they can be embedded in the JWT.
    return {
        "username": username,
        "id": user["id"],
        "role": stored_role,
        "first_name": user.get("first_name", ""),
        "last_name": user.get("last_name", ""),
    }
async def create_user(create_user_request: CreateUserRequest):
    """Register a new account with role ``"user"`` in the in-memory store.

    NOTE(review): no route decorator is visible on this function here; confirm
    it is registered on ``router`` (e.g. as a POST endpoint) as intended.

    Raises:
        HTTPException: 409 if the username is already taken. Without this
            check, two accounts could share one username. Tokens identify the
            caller only by username (``sub``), so such accounts would be
            indistinguishable.

    Returns:
        A confirmation message dict.
    """
    username = create_user_request.username
    if any(user["username"] == username for user in hardcoded_users):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already exists.",
        )

    hardcoded_users.append(
        {
            "id": str(uuid.uuid4()),
            "first_name": create_user_request.first_name,
            "last_name": create_user_request.last_name,
            "username": username,
            # Only the bcrypt hash is stored, never the plaintext password.
            "password_hash": bcrypt_context.hash(create_user_request.password),
            # Self-registered accounts can never obtain elevated roles.
            "role": "user",
        }
    )
    return {"message": "User created successfully"}
# Lifetime of access tokens issued at login (7 days == 10080 minutes).
_ACCESS_TOKEN_TTL = timedelta(days=7)


async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
):
    """Exchange OAuth2 password-form credentials for a signed bearer JWT.

    NOTE(review): no route decorator is visible on this function here;
    presumably it should serve the ``tokenUrl`` ("/v1/auth/token") that
    oauth2_bearer advertises — confirm it is registered on ``router``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            credentials are rejected, as RFC 6750 requires for bearer auth.

    Returns:
        A ``Token`` whose JWT expires ``_ACCESS_TOKEN_TTL`` from now.
    """
    # authenticate_user ignores ``role``; it is passed only because the
    # parameter is required.
    user = authenticate_user(form_data.username, form_data.password, role="user")
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate user.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    token = create_access_token(
        user["username"],
        user["id"],
        user["role"],
        user["first_name"],
        user["last_name"],
        _ACCESS_TOKEN_TTL,
    )
    return Token(access_token=token, token_type="bearer")
def create_access_token(
    username: str,
    user_id: str,
    role: str,
    first_name: str,
    last_name: str,
    expires_delta: timedelta,
):
    """Build and sign a JWT for an authenticated user.

    The claims are ``sub`` (username), ``id``, ``role``, ``first_name``,
    ``last_name`` and ``exp`` (current UTC time plus *expires_delta*). The
    token is signed with ``SECRET_KEY`` using ``ALGORITHM``.

    ``user_id`` is a string: user ids in this module are ``str(uuid4())``.
    The old ``int`` annotation was wrong.
    """
    # Timezone-aware "now" replaces the deprecated naive datetime.utcnow().
    # Both encode to the same POSIX timestamp for the ``exp`` claim.
    expires = datetime.now(timezone.utc) + expires_delta
    claims = {
        "sub": username,
        "id": user_id,
        "role": role,
        "first_name": first_name,
        "last_name": last_name,
        "exp": expires,
    }
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]):
    """Decode the bearer JWT and return the caller's identity claims.

    Returns:
        A dict with ``username`` (from ``sub``), ``id``, ``role``,
        ``first_name`` and ``last_name``. Claims other than ``sub`` and ``id``
        may be ``None`` if absent from the token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header (matching
            what oauth2_bearer sends when the token is missing) if the token
            fails decoding/verification, or lacks ``sub`` or ``id``.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate user.",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Pin the accepted algorithm list so a token cannot choose its own.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise credentials_error from exc

    username = payload.get("sub")
    user_id = payload.get("id")  # UUID string, as issued by create_access_token
    if username is None or user_id is None:
        raise credentials_error

    return {
        "username": username,
        "id": user_id,
        "role": payload.get("role"),
        "first_name": payload.get("first_name"),
        "last_name": payload.get("last_name"),
    }