|
from __future__ import annotations |
|
|
|
import hashlib |
|
import os |
|
import typing |
|
import warnings |
|
|
|
import fastapi |
|
from fastapi.responses import RedirectResponse |
|
|
|
from .utils import get_space |
|
|
|
# OAuth settings, read once from the environment when this module is imported.
# Each one is None when its environment variable is unset; `_add_oauth_routes`
# refuses to register the real OAuth routes unless all four are present.
OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
OAUTH_SCOPES = os.getenv("OAUTH_SCOPES")
OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL")
|
|
|
|
|
def attach_oauth(app: fastapi.FastAPI) -> None:
    """Enable OAuth login/logout support on a FastAPI app.

    When ``get_space()`` returns a value (running inside a Space), the real Hugging
    Face OAuth routes are registered. Otherwise, mocked routes that log in a fake
    user are registered instead. In both cases a starlette ``SessionMiddleware`` is
    added so the routes can keep the user's profile and token in a signed session
    cookie.

    Args:
        app: The FastAPI application to attach the routes and middleware to.

    Raises:
        ImportError: If starlette's session middleware (installed via the
            ``gradio[oauth]`` extra) cannot be imported.
    """
    try:
        from starlette.middleware.sessions import SessionMiddleware
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install gradio[oauth]` or add "
            "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    # Register either the real or the mocked /login/huggingface, /login/callback
    # and /logout routes, depending on whether we are running inside a Space.
    if get_space() is not None:
        _add_oauth_routes(app)
    else:
        _add_mocked_oauth_routes(app)

    # The session cookie is signed with a key derived from the OAuth client secret,
    # so the key changes whenever the secret changes.
    # NOTE(review): when OAUTH_CLIENT_SECRET is unset (e.g. when running locally) the
    # key is sha256("") -- a publicly known value. That is only acceptable because the
    # mocked routes store a fake token; confirm before reusing this outside Spaces.
    #
    # `same_site="none"` lets the cookie be sent on cross-site requests (presumably so
    # the app keeps working when embedded in an iframe -- TODO confirm). Browsers only
    # accept SameSite=None cookies that are also Secure, hence `https_only=True`.
    app.add_middleware(
        SessionMiddleware,
        secret_key=hashlib.sha256((OAUTH_CLIENT_SECRET or "").encode()).hexdigest(),
        same_site="none",
        https_only=True,
    )
|
|
|
|
|
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
    """Add OAuth routes to the FastAPI app (login, callback handler and logout).

    Registers a "huggingface" OpenID Connect client with authlib and exposes:

    - ``GET /login/huggingface``: redirects the browser to the HF OAuth page.
    - ``GET /login/callback``: exchanges the authorization response for a token and
      stores the token and its ``userinfo`` in the session.
    - ``GET /logout``: removes the profile and token from the session.

    Args:
        app: The FastAPI application to register the routes on.

    Raises:
        ImportError: If ``authlib`` (installed via the ``gradio[oauth]`` extra) is
            missing.
        ValueError: If OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_SCOPES or
            OPENID_PROVIDER_URL is not set (checked in that order).
    """
    try:
        from authlib.integrations.starlette_client import OAuth
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install gradio[oauth]` or add "
            "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    # Fail fast with an actionable message if the Space was not configured for OAuth.
    msg = (
        "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
        " setting `hf_oauth: true` in the Space metadata."
    )
    if OAUTH_CLIENT_ID is None:
        raise ValueError(msg.format("OAUTH_CLIENT_ID"))
    if OAUTH_CLIENT_SECRET is None:
        raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
    if OAUTH_SCOPES is None:
        raise ValueError(msg.format("OAUTH_SCOPES"))
    if OPENID_PROVIDER_URL is None:
        raise ValueError(msg.format("OPENID_PROVIDER_URL"))

    # Tolerate a trailing slash in the provider URL so the discovery document URL
    # never contains a double slash.
    server_metadata_url = OPENID_PROVIDER_URL.rstrip("/") + "/.well-known/openid-configuration"

    oauth = OAuth()
    oauth.register(
        name="huggingface",
        client_id=OAUTH_CLIENT_ID,
        client_secret=OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": OAUTH_SCOPES},
        server_metadata_url=server_metadata_url,
    )

    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Endpoint that redirects to HF OAuth page."""
        redirect_uri = str(request.url_for("oauth_redirect_callback"))
        # `url_for` can produce an http:// URL for a *.hf.space host (presumably
        # because TLS is terminated by a proxy in front of the Space -- TODO confirm),
        # but the callback must be reached over https. Only the scheme is upgraded;
        # the rest of the URL is left untouched.
        if ".hf.space" in redirect_uri and redirect_uri.startswith("http://"):
            redirect_uri = "https://" + redirect_uri[len("http://"):]
        return await oauth.huggingface.authorize_redirect(request, redirect_uri)

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        token = await oauth.huggingface.authorize_access_token(request)
        # NOTE(review): assumes the token carries a "userinfo" entry (expected when
        # the "openid" scope is requested); a KeyError here means it did not.
        request.session["oauth_profile"] = token["userinfo"]
        request.session["oauth_token"] = token
        return RedirectResponse("/")

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_profile", None)
        request.session.pop("oauth_token", None)
        return RedirectResponse("/")
|
|
|
|
|
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
    """Add fake oauth routes if Gradio is run locally and OAuth is enabled.

    Clicking on a gr.LoginButton will have the same behavior as in a Space (i.e. gets redirected in a new tab) but
    instead of authenticating with HF, a mocked user profile is added to the session.

    Registers ``/login/huggingface``, ``/login/callback`` and ``/logout`` with the
    same paths as the real routes, and emits a warning explaining that OAuth is
    being mocked.

    Args:
        app: The FastAPI application to register the routes on.
    """
    import copy

    warnings.warn(
        "Gradio does not support OAuth features outside of a Space environment. "
        "To help you debug your app locally, the login and logout buttons are mocked with a fake user profile."
    )

    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Fake login endpoint: skips the HF OAuth page and redirects straight to the callback."""
        return RedirectResponse("/login/callback")

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Fake callback that logs in the mocked user by storing a fake token in the session."""
        # Store a private deep copy so that mutating the session can never alter the
        # shared module-level MOCKED_OAUTH_TOKEN used by every later login. The profile
        # aliases the copied token's "userinfo", mirroring the real callback.
        token = copy.deepcopy(MOCKED_OAUTH_TOKEN)
        request.session["oauth_profile"] = token["userinfo"]
        request.session["oauth_token"] = token
        return RedirectResponse("/")

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_profile", None)
        request.session.pop("oauth_token", None)
        return RedirectResponse("/")
|
|
|
|
|
class OAuthProfile(typing.Dict):
    """
    A Gradio OAuthProfile object that can be used to inject the profile of a user in a
    function. If a function expects `OAuthProfile` or `Optional[OAuthProfile]` as input,
    the value will be injected from the FastAPI session if the user is logged in. If the
    user is not logged in and the function expects `OAuthProfile`, an error will be
    raised.

    `OAuthProfile` is a plain `dict` subclass, so profile fields are read with item
    access (e.g. `profile["name"]`), not attribute access.

    Example:
        import gradio as gr
        from typing import Optional


        def hello(profile: Optional[gr.OAuthProfile]) -> str:
            if profile is None:
                return "I don't know you."
            return f"Hello {profile['name']}"


        with gr.Blocks() as demo:
            gr.LoginButton()
            gr.LogoutButton()
            gr.Markdown().attach_load_event(hello, None)
    """
|
|
|
|
|
# Fake user profile returned by the mocked login when running outside a Space.
_MOCKED_USERINFO = {
    "sub": "11111111111111111111111",
    "name": "Fake Gradio User",
    "preferred_username": "FakeGradioUser",
    "profile": "https://huggingface.co/FakeGradioUser",
    "picture": "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
    "website": "",
    "aud": "00000000-0000-0000-0000-000000000000",
    "auth_time": 1691672844,
    "nonce": "aaaaaaaaaaaaaaaaaaa",
    "iat": 1691672844,
    "exp": 1691676444,
    "iss": "https://huggingface.co",
}

# Fake token stored in the session by the mocked `/login/callback` route. Its
# "userinfo" entry doubles as the fake user's profile. None of these values are
# real credentials.
MOCKED_OAUTH_TOKEN = {
    "access_token": "hf_oauth_AAAAAAAAAAAAAAAAAAAAAAAAAA",
    "token_type": "bearer",
    "expires_in": 3600,
    "id_token": "AAAAAAAAAAAAAAAAAAAAAAAAAA",
    "scope": "openid profile",
    "expires_at": 1691676444,
    "userinfo": _MOCKED_USERINFO,
}
|
|