Spaces:
Sleeping
Sleeping
Create authentication.py
Browse files
private_gpt/server/utils/authentication.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import timedelta, datetime
|
2 |
+
from typing import Annotated
|
3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
4 |
+
from starlette import status
|
5 |
+
from passlib.context import CryptContext
|
6 |
+
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
|
7 |
+
from jose import jwt, JWTError
|
8 |
+
from pydantic import BaseModel
|
9 |
+
|
10 |
+
router = APIRouter(
|
11 |
+
prefix='/v1/auth',
|
12 |
+
tags=['auth']
|
13 |
+
)
|
14 |
+
|
15 |
+
SECRET_KEY = '1b510738bd0cd2d5757a27935aa4355b04efcd268a180fafd5fb7d8de60cc73a'
|
16 |
+
ALGORITHM = 'HS256'
|
17 |
+
|
18 |
+
bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
|
19 |
+
oauth2_bearer = OAuth2PasswordBearer(tokenUrl='/v1/auth/token')
|
20 |
+
|
21 |
+
hardcoded_users = [
|
22 |
+
{ "id": 1,
|
23 |
+
"username": "test",
|
24 |
+
"password_hash": bcrypt_context.hash("secret"),
|
25 |
+
"role": "user"
|
26 |
+
},
|
27 |
+
|
28 |
+
{ "id": 2,
|
29 |
+
"username": "admin",
|
30 |
+
"password_hash": bcrypt_context.hash("admin"),
|
31 |
+
"role": "admin"
|
32 |
+
},
|
33 |
+
|
34 |
+
|
35 |
+
]
|
36 |
+
|
37 |
+
|
38 |
+
class CreateUserRequest(BaseModel):
|
39 |
+
username: str
|
40 |
+
password: str
|
41 |
+
|
42 |
+
class Token(BaseModel):
|
43 |
+
access_token: str
|
44 |
+
token_type: str
|
45 |
+
|
46 |
+
def authenticate_user(username: str, password: str, role: str):
|
47 |
+
for user in hardcoded_users:
|
48 |
+
if user["username"] == username:
|
49 |
+
stored_password_hash = user.get("password_hash")
|
50 |
+
stored_role = user.get("role")
|
51 |
+
if (
|
52 |
+
stored_password_hash
|
53 |
+
and stored_role
|
54 |
+
and bcrypt_context.verify(password, stored_password_hash)
|
55 |
+
):
|
56 |
+
return {"username": username, "id": user["id"], "role": stored_role}
|
57 |
+
return None
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
@router.post("/", status_code=status.HTTP_201_CREATED)
|
62 |
+
async def create_user(create_user_request: CreateUserRequest):
|
63 |
+
# This function is not necessary for hardcoded users, as users are predefined
|
64 |
+
pass
|
65 |
+
|
66 |
+
@router.post("/token", response_model=Token)
|
67 |
+
async def login_for_access_token(
|
68 |
+
form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
|
69 |
+
):
|
70 |
+
user = authenticate_user(form_data.username, form_data.password, role="user")
|
71 |
+
if not user:
|
72 |
+
raise HTTPException(
|
73 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
74 |
+
detail="Could not validate user.",
|
75 |
+
)
|
76 |
+
|
77 |
+
token = create_access_token(user["username"], user["id"], user["role"], timedelta(minutes=10080))
|
78 |
+
|
79 |
+
return Token(access_token=token, token_type="bearer")
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
def create_access_token(username: str, user_id: int, role: str, expires_delta: timedelta):
|
84 |
+
encode = {'sub': username, 'id': user_id, 'role': role}
|
85 |
+
expires = datetime.utcnow() + expires_delta
|
86 |
+
encode.update({'exp': expires})
|
87 |
+
return jwt.encode(encode, SECRET_KEY, algorithm=ALGORITHM)
|
88 |
+
|
89 |
+
|
90 |
+
async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]):
|
91 |
+
try:
|
92 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=ALGORITHM)
|
93 |
+
username: str = payload.get('sub')
|
94 |
+
user_id: int = payload.get('id')
|
95 |
+
role: str = payload.get('role') # Add this line to get the role
|
96 |
+
if username is None or user_id is None:
|
97 |
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Could not validate user.')
|
98 |
+
return {'username': username, 'id': user_id, 'role': role} # Include the role in the returned dictionary
|
99 |
+
except JWTError:
|
100 |
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user.")
|
101 |
+
|