Spaces:
Runtime error
Runtime error
File size: 5,312 Bytes
cd6f98e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import json
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from urllib.parse import urlencode
import aiohttp
from fastapi import Depends, Path
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
from reworkd_platform.services.security import encryption_service
from reworkd_platform.settings import Settings
from reworkd_platform.settings import settings as platform_settings
from reworkd_platform.web.api.http_responses import forbidden
class OAuthInstaller(ABC):
    """Base class for provider-specific OAuth installation flows.

    Subclasses implement the install / install-callback / uninstall steps and
    may use the static helpers below to store tokens on an
    ``OauthCredentials`` object in encrypted form.
    """

    def __init__(self, crud: OAuthCrud, settings: Settings):
        # Persistence layer and platform settings used by concrete installers.
        self.settings = settings
        self.crud = crud

    @abstractmethod
    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Start an installation for ``user`` and return a URL string.

        Concrete installers return the provider's authorization URL.
        """
        raise NotImplementedError()

    @abstractmethod
    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Finish an installation from the provider's ``code``/``state`` callback."""
        raise NotImplementedError()

    @abstractmethod
    async def uninstall(self, user: UserBase) -> bool:
        """Remove the provider installation for ``user`` and return a bool status."""
        raise NotImplementedError()

    @staticmethod
    def store_access_token(creds: OauthCredentials, access_token: str) -> None:
        """Encrypt ``access_token`` onto ``creds.access_token_enc`` (does not save)."""
        ciphertext = encryption_service.encrypt(access_token)
        creds.access_token_enc = ciphertext

    @staticmethod
    def store_refresh_token(creds: OauthCredentials, refresh_token: str) -> None:
        """Encrypt ``refresh_token`` onto ``creds.refresh_token_enc`` (does not save)."""
        ciphertext = encryption_service.encrypt(refresh_token)
        creds.refresh_token_enc = ciphertext
class SIDInstaller(OAuthInstaller):
    """OAuth installer for SID (authorize at ``me.sid.ai``, tokens at ``auth.sid.ai``)."""

    PROVIDER = "sid"

    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Start (or resume) an SID installation and return the authorize URL.

        Args:
            user: The user starting the installation.
            redirect_uri: Stored via ``crud.create_installation`` when a new
                installation is created. The OAuth ``redirect_uri`` sent to SID
                is always ``settings.sid_redirect_uri``.

        Returns:
            The SID authorization URL. Its ``state`` parameter is the
            installation's state, which ``install_callback`` uses to find the
            installation again.
        """
        # Gracefully handle the case where the installation already exists;
        # this can happen if the user starts the process from multiple tabs.
        installation = await self.crud.get_installation_by_user_id(
            user.id, self.PROVIDER
        )
        if not installation:
            installation = await self.crud.create_installation(
                user,
                self.PROVIDER,
                redirect_uri,
            )

        scopes = ["data:query", "offline_access"]
        params = {
            "client_id": self.settings.sid_client_id,
            "redirect_uri": self.settings.sid_redirect_uri,
            "response_type": "code",
            "scope": " ".join(scopes),
            "state": installation.state,
            "audience": "https://api.sid.ai/api/v1/",
        }
        return "https://me.sid.ai/api/oauth/authorize?" + urlencode(params)

    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Exchange an authorization ``code`` for tokens and persist them.

        Args:
            code: Authorization code returned by SID.
            state: State value that identifies the pending installation.

        Returns:
            The saved credentials. The access token is always stored and the
            expiration is set from ``expires_in``. The refresh token is stored
            only when the response includes one.

        Raises:
            The exception built by ``forbidden()`` in any of these cases:
            ``state`` matches no installation; the token endpoint answers with
            an error status; the response carries no ``access_token``.
        """
        creds = await self.crud.get_installation_by_state(state)
        if not creds:
            raise forbidden()

        req = {
            "grant_type": "authorization_code",
            "client_id": self.settings.sid_client_id,
            "client_secret": self.settings.sid_client_secret,
            "redirect_uri": self.settings.sid_redirect_uri,
            "code": code,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://auth.sid.ai/oauth/token",
                headers={"Accept": "application/json"},
                json=req,  # serializes the body and sets Content-Type: application/json
            ) as response:
                # Check the status before parsing: an error body may not be JSON,
                # and a rejected/expired code should not surface as a 500 KeyError.
                if response.status >= 400:
                    raise forbidden()
                res_data = await response.json()

        access_token = res_data.get("access_token")
        if not access_token:
            raise forbidden()
        OAuthInstaller.store_access_token(creds, access_token)

        # A refresh token is only issued when offline access is granted;
        # don't crash if it is absent.
        refresh_token = res_data.get("refresh_token")
        if refresh_token:
            OAuthInstaller.store_refresh_token(creds, refresh_token)

        creds.access_token_expiration = datetime.now() + timedelta(
            seconds=res_data["expires_in"]
        )
        return await creds.save(self.crud.session)

    async def uninstall(self, user: UserBase) -> bool:
        """Delete the user's SID credentials and revoke their refresh token.

        Args:
            user: The user whose SID installation should be removed.

        Returns:
            False if the user has no SID installation, True otherwise.
            The revoke call's HTTP status is not checked.
        """
        creds = await self.crud.get_installation_by_user_id(user.id, self.PROVIDER)
        if not creds:
            return False

        # Decrypt before deleting the row. The refresh token may be missing
        # (presumably an install that never reached install_callback, or a
        # provider response without one); then there is nothing to revoke.
        refresh_token = (
            encryption_service.decrypt(creds.refresh_token_enc)
            if creds.refresh_token_enc
            else None
        )

        # Delete credentials from the database session.
        await self.crud.session.delete(creds)

        if refresh_token:
            async with aiohttp.ClientSession() as session:
                # `async with` ensures the response is released; the result is
                # intentionally not inspected (same as before).
                async with session.post(
                    "https://auth.sid.ai/oauth/revoke",
                    json={
                        "client_id": self.settings.sid_client_id,
                        "client_secret": self.settings.sid_client_secret,
                        "token": refresh_token,
                    },
                ):
                    pass
        return True
# Registry of supported OAuth providers: provider name (each installer's
# PROVIDER constant) -> installer class. Looked up by installer_factory.
integrations = {
    installer_cls.PROVIDER: installer_cls
    for installer_cls in (SIDInstaller,)
}
def installer_factory(
    provider: str = Path(description="OAuth Provider"),
    crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> OAuthInstaller:
    """Factory for OAuth installers.

    Args:
        provider (str): OAuth provider name taken from the path; must be a key
            of ``integrations`` (e.g. ``"sid"``) (injected)
        crud (OAuthCrud): OAuth Crud (injected)

    Returns:
        A new installer instance for ``provider``, built with the platform
        settings.

    Raises:
        NotImplementedError: if no installer is registered for ``provider``.
    """
    installer_cls = integrations.get(provider)
    if installer_cls is None:
        # Keep the original exception type, but say which provider was requested.
        raise NotImplementedError(f"Unsupported OAuth provider: {provider!r}")
    return installer_cls(crud, platform_settings)
|