File size: 7,288 Bytes
870ab6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from __future__ import annotations
import hashlib
import os
import typing
import warnings
import fastapi
from fastapi.responses import RedirectResponse
from .utils import get_space
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES")
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL")
def attach_oauth(app: fastapi.FastAPI):
try:
from starlette.middleware.sessions import SessionMiddleware
except ImportError as e:
raise ImportError(
"Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
"`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
) from e
# Add `/login/huggingface`, `/login/callback` and `/logout` routes to enable OAuth in the Gradio app.
# If the app is running in a Space, OAuth is enabled normally. Otherwise, we mock the "real" routes to make the
# user log in with a fake user profile - without any calls to hf.co.
if get_space() is not None:
_add_oauth_routes(app)
else:
_add_mocked_oauth_routes(app)
# Session Middleware requires a secret key to sign the cookies. Let's use a hash
# of the OAuth secret key to make it unique to the Space + updated in case OAuth
# config gets updated.
app.add_middleware(
SessionMiddleware,
secret_key=hashlib.sha256((OAUTH_CLIENT_SECRET or "").encode()).hexdigest(),
same_site="none",
https_only=True,
)
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
"""Add OAuth routes to the FastAPI app (login, callback handler and logout)."""
try:
from authlib.integrations.starlette_client import OAuth
except ImportError as e:
raise ImportError(
"Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
"`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
) from e
# Check environment variables
msg = (
"OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
" setting `hf_oauth: true` in the Space metadata."
)
if OAUTH_CLIENT_ID is None:
raise ValueError(msg.format("OAUTH_CLIENT_ID"))
if OAUTH_CLIENT_SECRET is None:
raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
if OAUTH_SCOPES is None:
raise ValueError(msg.format("OAUTH_SCOPES"))
if OPENID_PROVIDER_URL is None:
raise ValueError(msg.format("OPENID_PROVIDER_URL"))
# Register OAuth server
oauth = OAuth()
oauth.register(
name="huggingface",
client_id=OAUTH_CLIENT_ID,
client_secret=OAUTH_CLIENT_SECRET,
client_kwargs={"scope": OAUTH_SCOPES},
server_metadata_url=OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
)
# Define OAuth routes
@app.get("/login/huggingface")
async def oauth_login(request: fastapi.Request):
"""Endpoint that redirects to HF OAuth page."""
redirect_uri = str(request.url_for("oauth_redirect_callback"))
if ".hf.space" in redirect_uri:
# In Space, FastAPI redirect as http but we want https
redirect_uri = redirect_uri.replace("http://", "https://")
return await oauth.huggingface.authorize_redirect(request, redirect_uri)
@app.get("/login/callback")
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that handles the OAuth callback."""
token = await oauth.huggingface.authorize_access_token(request)
request.session["oauth_profile"] = token["userinfo"]
request.session["oauth_token"] = token
return RedirectResponse("/")
@app.get("/logout")
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that logs out the user (e.g. delete cookie session)."""
request.session.pop("oauth_profile", None)
request.session.pop("oauth_token", None)
return RedirectResponse("/")
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
"""Add fake oauth routes if Gradio is run locally and OAuth is enabled.
Clicking on a gr.LoginButton will have the same behavior as in a Space (i.e. gets redirected in a new tab) but
instead of authenticating with HF, a mocked user profile is added to the session.
"""
warnings.warn(
"Gradio does not support OAuth features outside of a Space environment. "
"To help you debug your app locally, the login and logout buttons are mocked with a fake user profile."
)
# Define OAuth routes
@app.get("/login/huggingface")
async def oauth_login(request: fastapi.Request):
"""Fake endpoint that redirects to HF OAuth page."""
return RedirectResponse("/login/callback")
@app.get("/login/callback")
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that handles the OAuth callback."""
request.session["oauth_profile"] = MOCKED_OAUTH_TOKEN["userinfo"]
request.session["oauth_token"] = MOCKED_OAUTH_TOKEN
return RedirectResponse("/")
@app.get("/logout")
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that logs out the user (e.g. delete cookie session)."""
request.session.pop("oauth_profile", None)
request.session.pop("oauth_token", None)
return RedirectResponse("/")
class OAuthProfile(typing.Dict):
"""
A Gradio OAuthProfile object that can be used to inject the profile of a user in a
function. If a function expects `OAuthProfile` or `Optional[OAuthProfile]` as input,
the value will be injected from the FastAPI session if the user is logged in. If the
user is not logged in and the function expects `OAuthProfile`, an error will be
raised.
Example:
import gradio as gr
from typing import Optional
def hello(profile: Optional[gr.OAuthProfile]) -> str:
if profile is None:
return "I don't know you."
return f"Hello {profile.name}"
with gr.Blocks() as demo:
gr.LoginButton()
gr.LogoutButton()
gr.Markdown().attach_load_event(hello, None)
"""
MOCKED_OAUTH_TOKEN = {
"access_token": "hf_oauth_AAAAAAAAAAAAAAAAAAAAAAAAAA",
"token_type": "bearer",
"expires_in": 3600,
"id_token": "AAAAAAAAAAAAAAAAAAAAAAAAAA",
"scope": "openid profile",
"expires_at": 1691676444,
"userinfo": {
"sub": "11111111111111111111111",
"name": "Fake Gradio User",
"preferred_username": "FakeGradioUser",
"profile": "https://huggingface.co/FakeGradioUser",
"picture": "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
"website": "",
"aud": "00000000-0000-0000-0000-000000000000",
"auth_time": 1691672844,
"nonce": "aaaaaaaaaaaaaaaaaaa",
"iat": 1691672844,
"exp": 1691676444,
"iss": "https://huggingface.co",
},
}
|