|
from typing import Any, Dict, List, Optional, Union, cast |
|
|
|
from fastapi.exceptions import HTTPException |
|
from fastapi.openapi.models import OAuth2 as OAuth2Model |
|
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel |
|
from fastapi.param_functions import Form |
|
from fastapi.security.base import SecurityBase |
|
from fastapi.security.utils import get_authorization_scheme_param |
|
from starlette.requests import Request |
|
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN |
|
|
|
|
|
from typing_extensions import Annotated |
|
|
|
|
|
class OAuth2PasswordRequestForm:
    """
    Dependency class that collects OAuth2 "password flow" credentials from form data.

    Use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestForm = Depends()):
            print(form_data.username)
            print(form_data.password)
            for scope in form_data.scopes:
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return form_data

    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        Nevertheless, this dependency class is permissive and allows not passing it. If you
        want to enforce it, use instead the OAuth2PasswordRequestFormStrict dependency.
        When it is passed, it must be exactly "password".
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
        The split values are stored on the instance as the `scopes` list.
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret
        (if any) using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and
        client_secret (if any) using HTTP Basic auth, as: client_id:client_secret
    """

    def __init__(
        self,
        *,
        # Anchored with ^...$ so only the exact string "password" validates; an
        # unanchored pattern can let through values that merely contain it
        # (e.g. "notpassword"), depending on the validator's match semantics.
        grant_type: Annotated[Union[str, None], Form(pattern="^password$")] = None,
        username: Annotated[str, Form()],
        password: Annotated[str, Form()],
        scope: Annotated[str, Form()] = "",
        client_id: Annotated[Union[str, None], Form()] = None,
        client_secret: Annotated[Union[str, None], Form()] = None,
    ):
        self.grant_type = grant_type
        self.username = username
        self.password = password
        # str.split() with no separator splits on runs of whitespace and drops
        # empty strings, so an empty `scope` yields an empty list.
        self.scopes = scope.split()
        self.client_id = client_id
        self.client_secret = client_secret
|
|
|
|
|
class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
    """
    Strict variant of OAuth2PasswordRequestForm that requires `grant_type`.

    Use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestFormStrict = Depends()):
            print(form_data.username)
            print(form_data.password)
            for scope in form_data.scopes:
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return form_data

    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        This dependency is strict about it: the field is required and must be exactly
        "password". If you want to be permissive, use instead the
        OAuth2PasswordRequestForm dependency class.
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret
        (if any) using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and
        client_secret (if any) using HTTP Basic auth, as: client_id:client_secret
    """

    def __init__(
        self,
        # Anchored with ^...$ so only the exact string "password" validates; an
        # unanchored pattern can let through values that merely contain it.
        grant_type: Annotated[str, Form(pattern="^password$")],
        username: Annotated[str, Form()],
        password: Annotated[str, Form()],
        scope: Annotated[str, Form()] = "",
        client_id: Annotated[Union[str, None], Form()] = None,
        client_secret: Annotated[Union[str, None], Form()] = None,
    ):
        # All storage and scope splitting is delegated to the permissive base class.
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )
|
|
|
|
|
class OAuth2(SecurityBase):
    """
    Base OAuth2 security scheme.

    Holds an OpenAPI `OAuth2` model built from the supplied flows. When called
    as a dependency, it hands back the raw `Authorization` header value.
    """

    def __init__(
        self,
        *,
        flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        # `cast` only informs the type checker; it does no conversion at runtime.
        openapi_model = OAuth2Model(
            flows=cast(OAuthFlowsModel, flows),
            description=description,
        )
        self.model = openapi_model
        # Without an explicit name, fall back to the concrete class name.
        self.scheme_name = scheme_name if scheme_name else type(self).__name__
        self.auto_error = auto_error

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the raw `Authorization` header value.

        If the header is missing or empty, raise a 403 HTTPException when
        `auto_error` is enabled; otherwise return None.
        """
        header_value = request.headers.get("Authorization")
        if header_value:
            return header_value
        if not self.auto_error:
            return None
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
        )
|
|
|
|
|
class OAuth2PasswordBearer(OAuth2):
    """
    OAuth2 "password" flow scheme that extracts a Bearer token from requests.

    `tokenUrl` and `scopes` are placed in the `password` flow of the OpenAPI
    model. When called as a dependency, the instance returns the credential
    that follows the `Bearer` scheme in the `Authorization` header.
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        # A missing (None) or empty scopes mapping is replaced by a fresh dict.
        password_flow = {"tokenUrl": tokenUrl, "scopes": scopes or {}}
        super().__init__(
            flows=OAuthFlowsModel(password=cast(Any, password_flow)),
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the parameter of an `Authorization: Bearer ...` header.

        The scheme comparison is case-insensitive. If the header is missing
        or uses another scheme, raise a 401 HTTPException carrying
        `WWW-Authenticate: Bearer` when `auto_error` is enabled; otherwise
        return None.
        """
        header_value = request.headers.get("Authorization")
        scheme, token = get_authorization_scheme_param(header_value)
        if header_value and scheme.lower() == "bearer":
            return token
        if not self.auto_error:
            return None
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
|
|
|
class OAuth2AuthorizationCodeBearer(OAuth2):
    """
    OAuth2 "authorization code" flow scheme that extracts a Bearer token.

    `authorizationUrl`, `tokenUrl`, `refreshUrl` (possibly None) and `scopes`
    are placed in the `authorizationCode` flow of the OpenAPI model. When
    called as a dependency, the instance returns the credential that follows
    the `Bearer` scheme in the `Authorization` header.
    """

    def __init__(
        self,
        authorizationUrl: str,
        tokenUrl: str,
        refreshUrl: Optional[str] = None,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        # refreshUrl is always included in the flow, even when it is None.
        code_flow = {
            "authorizationUrl": authorizationUrl,
            "tokenUrl": tokenUrl,
            "refreshUrl": refreshUrl,
            # A missing (None) or empty scopes mapping becomes a fresh dict.
            "scopes": scopes or {},
        }
        super().__init__(
            flows=OAuthFlowsModel(authorizationCode=cast(Any, code_flow)),
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the parameter of an `Authorization: Bearer ...` header.

        The scheme comparison is case-insensitive. If the header is missing
        or uses another scheme, raise a 401 HTTPException carrying
        `WWW-Authenticate: Bearer` when `auto_error` is enabled; otherwise
        return None.
        """
        header_value = request.headers.get("Authorization")
        scheme, token = get_authorization_scheme_param(header_value)
        if header_value and scheme.lower() == "bearer":
            return token
        if not self.auto_error:
            return None
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
|
|
|
class SecurityScopes:
    """
    Container for the security scopes requested by dependencies.

    `scopes` is the list of scope strings (an empty list when none are given),
    and `scope_str` is the same scopes joined with single spaces.
    """

    def __init__(self, scopes: Optional[List[str]] = None):
        # None or an empty list both resolve to a new empty list.
        resolved: List[str] = scopes if scopes else []
        self.scopes = resolved
        self.scope_str = " ".join(resolved)
|
|