Spaces:
Running
Running
File size: 12,822 Bytes
b72ab63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 |
from __future__ import annotations
import hashlib
import os
import typing
import urllib.parse
import warnings
from dataclasses import dataclass, field
import fastapi
from fastapi.responses import RedirectResponse
from huggingface_hub import HfFolder, whoami
from .utils import get_space
# OAuth configuration, read once from the process environment at import time.
# Each constant is `None` when its variable is not set; `_add_oauth_routes`
# validates that all four are present before registering the real routes.
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", None)
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", None)
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES", None)
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", None)
def attach_oauth(app: fastapi.FastAPI):
    """Enable OAuth on `app`: register the login routes, then add the session middleware.

    When `get_space()` reports a Space, the real routes are registered;
    otherwise mocked routes are used so that a user can log in with a fake
    profile without any call to hf.co.

    Args:
        app: The FastAPI application to extend.

    Raises:
        ImportError: If starlette's `SessionMiddleware` cannot be imported.
    """
    try:
        from starlette.middleware.sessions import SessionMiddleware
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
            "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    # Adds the `/login/huggingface`, `/login/callback` and `/logout` routes.
    running_in_space = get_space() is not None
    register_routes = (
        _add_oauth_routes if running_in_space else _add_mocked_oauth_routes
    )
    register_routes(app)

    # The session middleware signs its cookies with a secret key. The key is a
    # hash of the OAuth client secret (so it changes whenever the OAuth config
    # changes) plus a version suffix. Bump the suffix if the session cookie
    # format ever changes: old cookies are then invalidated instead of
    # triggering an HTTP 500 for users that still hold them.
    versioned_secret = f"{OAUTH_CLIENT_SECRET or ''}-v4"
    cookie_signing_key = hashlib.sha256(versioned_secret.encode()).hexdigest()
    app.add_middleware(
        SessionMiddleware,
        secret_key=cookie_signing_key,
        same_site="none",
        https_only=True,
    )
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
    """Register the real OAuth endpoints (login, callback handler, logout) on `app`.

    Args:
        app: The FastAPI application that receives the three routes.

    Raises:
        ImportError: If `authlib` is not installed.
        ValueError: If one of `OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`,
            `OAUTH_SCOPES` or `OPENID_PROVIDER_URL` is not set (checked in
            that order).
    """
    try:
        from authlib.integrations.base_client.errors import MismatchingStateError
        from authlib.integrations.starlette_client import OAuth
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
            "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    def missing_env_error(variable: str) -> ValueError:
        """Build the error raised when a required environment variable is absent."""
        template = (
            "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
            " setting `hf_oauth: true` in the Space metadata."
        )
        return ValueError(template.format(variable))

    # Validate the configuration before touching authlib.
    if OAUTH_CLIENT_ID is None:
        raise missing_env_error("OAUTH_CLIENT_ID")
    if OAUTH_CLIENT_SECRET is None:
        raise missing_env_error("OAUTH_CLIENT_SECRET")
    if OAUTH_SCOPES is None:
        raise missing_env_error("OAUTH_SCOPES")
    if OPENID_PROVIDER_URL is None:
        raise missing_env_error("OPENID_PROVIDER_URL")

    # Declare the Hugging Face provider to authlib.
    oauth = OAuth()
    oauth.register(
        name="huggingface",
        client_id=OAUTH_CLIENT_ID,
        client_secret=OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": OAUTH_SCOPES},
        server_metadata_url=f"{OPENID_PROVIDER_URL}/.well-known/openid-configuration",
    )

    # NOTE: the handler names below double as route names;
    # `_generate_redirect_uri` looks up "oauth_redirect_callback" by name.

    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Endpoint that redirects to HF OAuth page."""
        # The callback URL carries the page to return to once login is done.
        callback_uri = _generate_redirect_uri(request)
        return await oauth.huggingface.authorize_redirect(request, callback_uri)  # type: ignore

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        try:
            oauth_info = await oauth.huggingface.authorize_access_token(request)  # type: ignore
        except MismatchingStateError:
            # A state mismatch most likely means the cookie is corrupted.
            # authlib has a reported bug where the stored token keeps growing
            # when a user logs in repeatedly; cookies are capped at 4kb, so the
            # value eventually gets truncated and the state is lost.
            # Workaround: drop the OAuth state from the session and send the
            # user through the login page again.
            # See https://github.com/lepture/authlib/issues/622 for details.
            stale_keys = [
                key for key in request.session if key.startswith("_state_huggingface")
            ]
            for key in stale_keys:
                request.session.pop(key)

            retry_uri = "/login/huggingface"
            if "_target_url" in request.query_params:
                # Keep the same `_target_url` as in the failed attempt.
                retry_query = urllib.parse.urlencode(
                    {"_target_url": request.query_params["_target_url"]}
                )
                retry_uri = f"{retry_uri}?{retry_query}"
            return RedirectResponse(retry_uri)

        # Login succeeded: remember the user in the session, then redirect.
        request.session["oauth_info"] = oauth_info
        return _redirect_to_target(request)

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_info", None)
        return _redirect_to_target(request)
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
    """Register fake OAuth endpoints for running outside of a Space.

    The routes mirror the real ones (`/login/huggingface`, `/login/callback`,
    `/logout`), but instead of authenticating against HF, the callback stores
    a mocked user profile in the session. The mocked profile is built once,
    when this function runs, by `_get_mocked_oauth_info()`.

    Args:
        app: The FastAPI application that receives the three routes.
    """
    warnings.warn(
        "Gradio does not support OAuth features outside of a Space environment. To help"
        " you debug your app locally, the login and logout buttons are mocked with your"
        " profile. To make it work, your machine must be logged in to Huggingface."
    )
    fake_session_payload = _get_mocked_oauth_info()

    # NOTE: the handler names below double as route names;
    # `_generate_redirect_uri` looks up "oauth_redirect_callback" by name.

    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Fake endpoint that redirects to HF OAuth page."""
        # Skip the provider entirely: jump straight to the local callback,
        # passing along the URL to land on afterwards.
        callback_query = urllib.parse.urlencode(
            {"_target_url": _generate_redirect_uri(request)}
        )
        return RedirectResponse(f"/login/callback?{callback_query}")

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        request.session["oauth_info"] = fake_session_payload
        return _redirect_to_target(request)

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_info", None)
        # Rewriting the current URL (rather than redirecting to a bare "/")
        # preserves the query params.
        home_url = str(request.url).replace("/logout", "/")
        return RedirectResponse(url=home_url)
def _generate_redirect_uri(request: fastapi.Request) -> str:
    """Return the callback URL (as a string) carrying the post-login target.

    The target is the incoming `_target_url` query parameter when present;
    otherwise it is "/" followed by the incoming query params re-encoded.
    The URL is resolved from the route named "oauth_redirect_callback".

    Args:
        request: The incoming login request.

    Returns:
        The callback URL with a `_target_url` query parameter, forced to
        https when the host ends with ".hf.space".
    """
    params = request.query_params
    if "_target_url" not in params:
        # No explicit target => go back to the root, keeping the query params.
        target = "/?" + urllib.parse.urlencode(params)
    else:
        # An explicit target was requested => respect it.
        target = params["_target_url"]

    callback_url = request.url_for("oauth_redirect_callback").include_query_params(
        _target_url=target
    )
    callback_url_text = str(callback_url)
    if not callback_url.netloc.endswith(".hf.space"):
        return callback_url_text
    # In a Space, FastAPI builds the URL as http but we want https.
    return callback_url_text.replace("http://", "https://")
def _redirect_to_target(
    request: fastapi.Request, default_target: str = "/"
) -> RedirectResponse:
    """Redirect to the requested target, restricted to the current origin.

    The target comes from the `_target_url` query parameter, which is
    user-controlled. Redirecting to it verbatim would be an open redirect
    (a crafted login/logout link could bounce users to any external site),
    so only its path, query and fragment are kept.

    Args:
        request: The incoming request; `_target_url` is read from its query params.
        default_target: Target used when `_target_url` is absent or unparsable.

    Returns:
        A `RedirectResponse` pointing to a host-relative URL.
    """
    requested = request.query_params.get("_target_url", default_target)
    return RedirectResponse(_same_origin_target(requested, default_target))


def _same_origin_target(target: str, fallback: str = "/") -> str:
    """Reduce `target` to a host-relative URL (path + query + fragment)."""
    try:
        parts = urllib.parse.urlsplit(target)
    except ValueError:
        # e.g. malformed IPv6 netloc
        return fallback
    # Browsers treat "\" like "/", and a path starting with "//" is read as a
    # new host (protocol-relative URL). Normalise both so that the result
    # always starts with exactly one "/".
    path = "/" + parts.path.replace("\\", "/").lstrip("/")
    # Scheme and netloc are dropped on purpose.
    return urllib.parse.urlunsplit(("", "", path, parts.query, parts.fragment))
@dataclass
class OAuthProfile(typing.Dict):  # inherit from Dict for backward compatibility
    """
    Profile of the logged-in user, injectable into a Gradio function.

    A function that declares an `OAuthProfile` or `Optional[OAuthProfile]` parameter
    receives the value from the FastAPI session when the user is logged in. If the
    user is not logged in and the function expects `OAuthProfile`, an error is raised.

    The object is also a dict holding every key of the data it was built from.

    Attributes:
        name (str): The name of the user (e.g. 'Abubakar Abid').
        username (str): The username of the user (e.g. 'abidlabs')
        profile (str): The profile URL of the user (e.g. 'https://huggingface.co/abidlabs').
        picture (str): The profile picture URL of the user.

    Example:
        import gradio as gr
        from typing import Optional


        def hello(profile: Optional[gr.OAuthProfile]) -> str:
            if profile is None:
                return "I don't know you."
            return f"Hello {profile.name}"


        with gr.Blocks() as demo:
            gr.LoginButton()
            gr.Markdown().attach_load_event(hello, None)
    """

    name: str = field(init=False)
    username: str = field(init=False)
    profile: str = field(init=False)
    picture: str = field(init=False)

    def __init__(self, data: dict):  # hack to make OAuthProfile backward compatible
        # Keep the raw data reachable through the dict interface...
        self.update(data)
        # ...and mirror the well-known entries as attributes. A missing key
        # raises `KeyError`.
        attribute_to_key = (
            ("name", "name"),
            ("username", "preferred_username"),
            ("profile", "profile"),
            ("picture", "picture"),
        )
        for attribute, key in attribute_to_key:
            setattr(self, attribute, self[key])
@dataclass
class OAuthToken:
    """
    Access token of the logged-in user, injectable into a Gradio function.

    A function that declares an `OAuthToken` or `Optional[OAuthToken]` parameter
    receives the value from the FastAPI session when the user is logged in. If the
    user is not logged in and the function expects `OAuthToken`, an error is raised.

    Attributes:
        token (str): The access token of the user.
        scope (str): The scope of the access token.
        expires_at (int): The expiration timestamp of the access token.

    Example:
        import gradio as gr
        from typing import Optional
        from huggingface_hub import whoami


        def list_organizations(oauth_token: Optional[gr.OAuthToken]) -> str:
            if oauth_token is None:
                return "Please log in to list organizations."
            org_names = [org["name"] for org in whoami(oauth_token.token)["orgs"]]
            return f"You belong to {', '.join(org_names)}."


        with gr.Blocks() as demo:
            gr.LoginButton()
            gr.Markdown().attach_load_event(list_organizations, None)
    """

    token: str
    scope: str
    expires_at: int
def _get_mocked_oauth_info() -> typing.Dict:
    """Build a fake OAuth session payload from the locally logged-in HF account.

    The payload carries the local access token and a `userinfo` entry filled
    from `whoami()`. Timestamps are derived from the current time so that the
    mocked token is not already expired when it is created (it is valid for
    `expires_in` seconds from the moment this function is called).

    Returns:
        A dict with the keys `access_token`, `token_type`, `expires_in`,
        `id_token`, `scope`, `expires_at` and `userinfo`.

    Raises:
        ValueError: If no local token is found, or if the token does not
            belong to a personal ("user") account.
    """
    import time  # local import: only needed for the mocked timestamps

    token = HfFolder.get_token()
    if token is None:
        raise ValueError(
            "Your machine must be logged in to HF to debug a Gradio app locally. Please"
            " run `huggingface-cli login` or set `HF_TOKEN` as environment variable "
            "with one of your access token. You can generate a new token in your "
            "settings page (https://huggingface.co/settings/tokens)."
        )

    user = whoami()
    if user["type"] != "user":
        raise ValueError(
            "Your machine is not logged in with a personal account. Please use a "
            "personal access token. You can generate a new token in your settings page"
            " (https://huggingface.co/settings/tokens)."
        )

    # Keep the one-hour lifetime of the previous hard-coded values, but anchor
    # it to "now" instead of a fixed date in the past.
    lifetime_seconds = 3600
    issued_at = int(time.time())
    expires_at = issued_at + lifetime_seconds

    return {
        "access_token": token,
        "token_type": "bearer",
        "expires_in": lifetime_seconds,
        "id_token": "AAAAAAAAAAAAAAAAAAAAAAAAAA",
        "scope": "openid profile",
        "expires_at": expires_at,
        "userinfo": {
            "sub": "11111111111111111111111",
            "name": user["fullname"],
            "preferred_username": user["name"],
            "profile": f"https://huggingface.co/{user['name']}",
            "picture": user["avatarUrl"],
            "website": "",
            "aud": "00000000-0000-0000-0000-000000000000",
            "auth_time": issued_at,
            "nonce": "aaaaaaaaaaaaaaaaaaa",
            "iat": issued_at,
            "exp": expires_at,
            "iss": "https://huggingface.co",
        },
    }
|