Spaces:
Runtime error
Runtime error
import json | |
from abc import ABC, abstractmethod | |
from datetime import datetime, timedelta | |
from urllib.parse import urlencode | |
import aiohttp | |
from fastapi import Depends, Path | |
from reworkd_platform.db.crud.oauth import OAuthCrud | |
from reworkd_platform.db.models.auth import OauthCredentials | |
from reworkd_platform.schemas import UserBase | |
from reworkd_platform.services.security import encryption_service | |
from reworkd_platform.settings import Settings | |
from reworkd_platform.settings import settings as platform_settings | |
from reworkd_platform.web.api.http_responses import forbidden | |
class OAuthInstaller(ABC):
    """Base class for a provider-specific OAuth installation flow.

    Subclasses implement the three lifecycle steps of an integration:
    starting authorization (``install``), handling the provider's redirect
    back (``install_callback``) and removing the integration (``uninstall``).
    """

    def __init__(self, crud: OAuthCrud, settings: Settings):
        # crud: access to stored OAuth installations/credentials.
        # settings: provider configuration (client ids, secrets, redirect URIs).
        self.crud = crud
        self.settings = settings

    @abstractmethod
    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Start the installation for ``user`` and return the provider's authorization URL."""
        raise NotImplementedError()

    @abstractmethod
    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Finish the installation identified by ``state`` using the authorization ``code``."""
        raise NotImplementedError()

    @abstractmethod
    async def uninstall(self, user: UserBase) -> bool:
        """Remove ``user``'s installation; return True if one was removed."""
        raise NotImplementedError()

    @staticmethod
    def store_access_token(creds: OauthCredentials, access_token: str) -> None:
        """Encrypt ``access_token`` and set it on ``creds``.

        Only mutates ``creds``; the caller is responsible for persisting it.
        """
        creds.access_token_enc = encryption_service.encrypt(access_token)

    @staticmethod
    def store_refresh_token(creds: OauthCredentials, refresh_token: str) -> None:
        """Encrypt ``refresh_token`` and set it on ``creds``.

        Only mutates ``creds``; the caller is responsible for persisting it.
        """
        creds.refresh_token_enc = encryption_service.encrypt(refresh_token)
class SIDInstaller(OAuthInstaller):
    """OAuth installer for the SID (sid.ai) integration."""

    PROVIDER = "sid"

    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Ensure an installation record exists for ``user`` and build SID's authorize URL.

        Args:
            user: The user starting the installation.
            redirect_uri: Stored on the installation record when a new one is created.

        Returns:
            The SID authorization URL. It carries the installation's ``state`` so
            ``install_callback`` can look the installation up again.
        """
        # gracefully handle the case where the installation already exists
        # this can happen if the user starts the process from multiple tabs
        installation = await self.crud.get_installation_by_user_id(
            user.id, self.PROVIDER
        )
        if not installation:
            installation = await self.crud.create_installation(
                user,
                self.PROVIDER,
                redirect_uri,
            )

        scopes = ["data:query", "offline_access"]
        params = {
            "client_id": self.settings.sid_client_id,
            # The OAuth redirect sent to SID comes from settings; the caller's
            # redirect_uri is only stored on the installation record above.
            "redirect_uri": self.settings.sid_redirect_uri,
            "response_type": "code",
            "scope": " ".join(scopes),
            "state": installation.state,
            "audience": "https://api.sid.ai/api/v1/",
        }
        auth_url = "https://me.sid.ai/api/oauth/authorize"
        auth_url += "?" + urlencode(params)
        return auth_url

    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Exchange the authorization ``code`` for SID tokens and save them.

        Args:
            code: Authorization code from SID's redirect.
            state: The ``state`` value placed in the authorize URL by ``install``.

        Returns:
            The updated credentials, as returned by ``creds.save``.

        Raises:
            The exception built by ``forbidden()``: raised when ``state`` matches no
            installation, or when SID's token response contains no access token
            (i.e. the code exchange failed).
        """
        creds = await self.crud.get_installation_by_state(state)
        if not creds:
            raise forbidden()

        req = {
            "grant_type": "authorization_code",
            "client_id": self.settings.sid_client_id,
            "client_secret": self.settings.sid_client_secret,
            "redirect_uri": self.settings.sid_redirect_uri,
            "code": code,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://auth.sid.ai/oauth/token",
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                data=json.dumps(req),
            ) as response:
                res_data = await response.json()

        # An error response (e.g. an invalid or reused code) carries no access
        # token; reject it instead of failing later with a KeyError.
        if "access_token" not in res_data:
            raise forbidden()

        OAuthInstaller.store_access_token(creds, res_data["access_token"])
        OAuthInstaller.store_refresh_token(creds, res_data["refresh_token"])
        # expires_in is a lifetime in seconds.
        # NOTE(review): datetime.now() is naive local time; confirm this matches
        # how access_token_expiration is compared elsewhere.
        creds.access_token_expiration = datetime.now() + timedelta(
            seconds=res_data["expires_in"]
        )
        return await creds.save(self.crud.session)

    async def uninstall(self, user: UserBase) -> bool:
        """Delete ``user``'s SID installation and revoke its refresh token at SID.

        Returns:
            False if the user has no SID installation, True otherwise.
        """
        creds = await self.crud.get_installation_by_user_id(user.id, self.PROVIDER)
        if not creds:
            return False

        # Read the refresh token before deleting the row: it is what we revoke.
        # An installation whose callback never completed may have none stored;
        # such a row is still deleted, there is just nothing to revoke.
        refresh_token_enc = creds.refresh_token_enc
        delete_token = (
            encryption_service.decrypt(refresh_token_enc) if refresh_token_enc else None
        )

        # delete credentials from database
        await self.crud.session.delete(creds)

        if delete_token:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    "https://auth.sid.ai/oauth/revoke",
                    headers={
                        "Content-Type": "application/json",
                    },
                    data=json.dumps(
                        {
                            "client_id": self.settings.sid_client_id,
                            "client_secret": self.settings.sid_client_secret,
                            "token": delete_token,
                        }
                    ),
                ):
                    # The revoke response status is not inspected; the context
                    # manager just makes sure the response is released.
                    pass
        return True
# Registry of supported OAuth installers, keyed by each class's PROVIDER name.
# installer_factory resolves the requested provider against this mapping.
integrations = {
    installer_cls.PROVIDER: installer_cls
    for installer_cls in (SIDInstaller,)
}
def installer_factory(
    provider: str = Path(description="OAuth Provider"),
    crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> OAuthInstaller:
    """Factory for OAuth installers (used as a FastAPI dependency).

    Args:
        provider (str): OAuth provider name from the path (injected); must be
            a key of ``integrations``.
        crud (OAuthCrud): OAuth Crud (injected)

    Returns:
        OAuthInstaller: The provider's installer, built with the platform settings.

    Raises:
        NotImplementedError: If no installer is registered for ``provider``.
    """
    installer_cls = integrations.get(provider)
    if installer_cls is None:
        # Name the offending provider so the failure is diagnosable.
        raise NotImplementedError(f"Unsupported OAuth provider: {provider!r}")
    return installer_cls(crud, platform_settings)