# Hugging Face Space (status: running on CPU Upgrade)
import os
import secrets

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI, Depends
from fastapi import Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse
app = FastAPI()

# OAuth settings, all read from the environment.
GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
    # An unset variable would otherwise hand None to SessionMiddleware, leaving
    # the session cookies without a real signing secret. Fall back to a random
    # key: sessions then only live as long as this process, but stay unforgeable.
    print("SECRET_KEY is not set; using a random per-process session key")
    SECRET_KEY = secrets.token_urlsafe(32)

# Set up OAuth: the Google client credentials are handed to Authlib through a
# Starlette Config built from an explicit mapping instead of the process env.
config_data = {
    'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID,
    'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET,
}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# The handlers below keep the signed-in user in request.session, which requires
# the session middleware to be installed on the app.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
# Dependency to get the current user
def get_user(request: Request):
    """Return the name of the signed-in user, or ``None`` when nobody is.

    The ``user`` entry of the session is the userinfo mapping stored by
    ``auth``. Its ``name`` field is preferred; when that field is absent or
    empty the ``email`` field is used instead, so a profile without a name no
    longer raises ``KeyError``.

    Args:
        request: Incoming request; only ``request.session`` is read.

    Returns:
        The user's name (or e-mail as a fallback), or ``None`` if the session
        holds no user or neither field is available.
    """
    user = request.session.get('user')
    if not user:
        return None
    return user.get('name') or user.get('email')
@app.get('/')
def public(request: Request, user = Depends(get_user)):
    """Send the visitor to the right Gradio app for their login state.

    Registered at ``/`` because ``logout`` and the failure path of ``auth``
    both redirect there.

    Args:
        request: Incoming request, used to work out the externally visible
            root URL.
        user: Result of ``get_user``; truthy when a user is signed in.

    Returns:
        A redirect to ``<root>/gradio/`` for a signed-in user, otherwise to
        ``<root>/main/`` (the login page).
    """
    root_url = gr.route_utils.get_root_url(request, "/", None)
    target = 'gradio' if user else 'main'
    return RedirectResponse(url=f'{root_url}/{target}/')
@app.get('/logout')
async def logout(request: Request):
    """Sign the user out and return to the landing route.

    Registered at ``/logout`` because the "Logout" button of the main app
    links to that path.

    Args:
        request: Incoming request whose session is cleared of its ``user``.

    Returns:
        A redirect to ``/``.
    """
    # pop() with a default keeps this safe when nobody is signed in.
    request.session.pop('user', None)
    return RedirectResponse(url='/')
@app.get('/login')
async def login(request: Request):
    """Start the Google OAuth flow.

    Registered at ``/login`` because the login page's button opens that path.

    Args:
        request: Incoming request, used to derive the externally visible root
            URL and passed on to the OAuth client.

    Returns:
        The response produced by ``oauth.google.authorize_redirect``.
    """
    root_url = gr.route_utils.get_root_url(request, "/login", None)
    # The provider sends the browser back to the /auth handler of this app.
    redirect_uri = f"{root_url}/auth"
    print("Redirecting to", redirect_uri)
    return await oauth.google.authorize_redirect(request, redirect_uri)
@app.get('/auth')
async def auth(request: Request):
    """Finish the OAuth flow and store the user in the session.

    Registered at ``/auth`` because ``login`` passes ``<root>/auth`` as the
    OAuth redirect URI.

    Args:
        request: The callback request from the OAuth provider.

    Returns:
        A redirect to ``/gradio`` on success, or to ``/`` when the token
        exchange raises ``OAuthError``.
    """
    try:
        access_token = await oauth.google.authorize_access_token(request)
    except OAuthError as err:
        # Report the caught error itself; str(OAuthError) would only print the
        # exception class and hide the actual failure reason.
        print("Error getting access token", str(err))
        return RedirectResponse(url='/')
    request.session['user'] = dict(access_token)["userinfo"]
    print("Redirecting to /gradio")
    return RedirectResponse(url='/gradio')
# Browser-side click handler: opens /login in a new tab, carrying over the
# current query string.
_js_redirect = """
() => {
url = '/login' + window.location.search;
window.open(url, '_blank');
}
"""

# Login page: a single button whose click runs only the JavaScript above
# (no Python callback).
with gr.Blocks() as login_demo:
    btn = gr.Button("Login")
    btn.click(fn=None, js=_js_redirect)

app = gr.mount_gradio_app(app, login_demo, path="/main")
def greet(request: gr.Request):
    """Build the greeting text from the username carried by the request."""
    username = request.username
    return f"Welcome to Gradio, {username}"
# Main app: a Markdown banner that greet() fills in on page load, plus a
# logout link.
with gr.Blocks() as main_demo:
    m = gr.Markdown(value="Welcome to Gradio!")
    gr.Button(value="Logout", link="/logout")
    main_demo.load(fn=greet, inputs=None, outputs=m)

# get_user is passed as the auth dependency for everything under /gradio.
app = gr.mount_gradio_app(
    app,
    main_demo,
    path="/gradio",
    auth_dependency=get_user,
)
# Serve the combined FastAPI/Gradio app when this file is run as a script.
if __name__ == "__main__":
    uvicorn.run(app)