dia-gov's picture
Upload 569 files
cd6f98e verified
raw
history blame
5.31 kB
import json
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from urllib.parse import urlencode
import aiohttp
from fastapi import Depends, Path
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
from reworkd_platform.services.security import encryption_service
from reworkd_platform.settings import Settings
from reworkd_platform.settings import settings as platform_settings
from reworkd_platform.web.api.http_responses import forbidden
class OAuthInstaller(ABC):
    """Base class for provider-specific OAuth installation flows.

    Subclasses implement the three lifecycle steps (start, callback, removal).
    The static helpers encrypt tokens onto a credentials row so that callers
    never assign plaintext tokens to the model directly.
    """

    def __init__(self, crud: OAuthCrud, settings: Settings):
        # Kept on the instance for use by the subclass implementations.
        self.crud, self.settings = crud, settings

    @abstractmethod
    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Start the OAuth flow for ``user``; return the provider authorization URL."""
        raise NotImplementedError()

    @abstractmethod
    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Finish the flow identified by ``state`` using ``code``; return the credentials."""
        raise NotImplementedError()

    @abstractmethod
    async def uninstall(self, user: UserBase) -> bool:
        """Remove ``user``'s installation and report the outcome as a bool."""
        raise NotImplementedError()

    @staticmethod
    def store_access_token(creds: OauthCredentials, access_token: str) -> None:
        """Encrypt ``access_token`` and place it on ``creds.access_token_enc``."""
        encrypted = encryption_service.encrypt(access_token)
        creds.access_token_enc = encrypted

    @staticmethod
    def store_refresh_token(creds: OauthCredentials, refresh_token: str) -> None:
        """Encrypt ``refresh_token`` and place it on ``creds.refresh_token_enc``."""
        encrypted = encryption_service.encrypt(refresh_token)
        creds.refresh_token_enc = encrypted
class SIDInstaller(OAuthInstaller):
    """OAuth installer for the SID provider (authorize at me.sid.ai, tokens at auth.sid.ai)."""

    PROVIDER = "sid"

    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Return the SID authorization URL the user should be redirected to.

        Reuses the user's existing SID installation row if there is one,
        otherwise creates it. The installation's ``state`` is sent as the OAuth
        ``state`` parameter so ``install_callback`` can find the row again.

        Args:
            user: The user starting the installation.
            redirect_uri: Passed to ``crud.create_installation`` when a new
                installation is created. The OAuth ``redirect_uri`` sent to SID
                comes from ``settings.sid_redirect_uri`` instead.

        Returns:
            The fully encoded authorization URL.
        """
        # Gracefully handle the case where the installation already exists;
        # this can happen if the user starts the process from multiple tabs.
        installation = await self.crud.get_installation_by_user_id(
            user.id, self.PROVIDER
        )
        if not installation:
            installation = await self.crud.create_installation(
                user,
                self.PROVIDER,
                redirect_uri,
            )

        # "offline_access" is requested so that a refresh token is issued.
        scopes = ["data:query", "offline_access"]
        params = {
            "client_id": self.settings.sid_client_id,
            "redirect_uri": self.settings.sid_redirect_uri,
            "response_type": "code",
            "scope": " ".join(scopes),
            "state": installation.state,
            "audience": "https://api.sid.ai/api/v1/",
        }
        auth_url = "https://me.sid.ai/api/oauth/authorize"
        auth_url += "?" + urlencode(params)
        return auth_url

    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Exchange an authorization code for tokens and persist them.

        Args:
            code: Authorization code SID sent back to the redirect URI.
            state: OAuth ``state`` used to look up the pending installation.

        Returns:
            The saved credentials with encrypted access/refresh tokens and the
            access-token expiration set.

        Raises:
            The ``forbidden()`` HTTP error if no installation matches ``state``
            or if SID's token endpoint does not answer with HTTP 200.
        """
        creds = await self.crud.get_installation_by_state(state)
        if not creds:
            raise forbidden()

        req = {
            "grant_type": "authorization_code",
            "client_id": self.settings.sid_client_id,
            "client_secret": self.settings.sid_client_secret,
            "redirect_uri": self.settings.sid_redirect_uri,
            "code": code,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://auth.sid.ai/oauth/token",
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                data=json.dumps(req),
            ) as response:
                # A rejected exchange (expired/reused/invalid code) is an
                # error response without tokens. Fail as forbidden instead of
                # crashing with KeyError/ContentTypeError below.
                if response.status != 200:
                    raise forbidden()
                res_data = await response.json()

        OAuthInstaller.store_access_token(creds, res_data["access_token"])
        OAuthInstaller.store_refresh_token(creds, res_data["refresh_token"])
        # NOTE(review): naive local time; confirm this matches how the
        # expiration is compared wherever tokens are refreshed.
        creds.access_token_expiration = datetime.now() + timedelta(
            seconds=res_data["expires_in"]
        )
        return await creds.save(self.crud.session)

    async def uninstall(self, user: UserBase) -> bool:
        """Delete the user's SID credentials and revoke their refresh token.

        Returns:
            ``False`` if the user has no SID installation, ``True`` otherwise.
        """
        creds = await self.crud.get_installation_by_user_id(user.id, self.PROVIDER)
        if not creds:
            return False

        # Tokens are only stored by install_callback. An installation whose
        # callback never completed presumably has no refresh token, so there
        # is nothing to decrypt or revoke. Decrypt before deleting the row.
        refresh_token = (
            encryption_service.decrypt(creds.refresh_token_enc)
            if creds.refresh_token_enc
            else None
        )

        # Delete credentials from the database.
        await self.crud.session.delete(creds)

        if refresh_token is None:
            return True

        # Revoke the refresh token. This is best-effort: the response status
        # is not inspected, but the response is released via `async with`.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://auth.sid.ai/oauth/revoke",
                headers={
                    "Content-Type": "application/json",
                },
                data=json.dumps(
                    {
                        "client_id": self.settings.sid_client_id,
                        "client_secret": self.settings.sid_client_secret,
                        "token": refresh_token,
                    }
                ),
            ):
                pass
        return True
# Registry of supported OAuth installers, keyed by each installer's PROVIDER name.
integrations = {
    installer_cls.PROVIDER: installer_cls
    for installer_cls in (SIDInstaller,)
}
def installer_factory(
    provider: str = Path(description="OAuth Provider"),
    crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> OAuthInstaller:
    """Build the OAuth installer registered for ``provider``.

    Used as a FastAPI dependency: ``provider`` is read from the path and
    ``crud`` is injected via ``OAuthCrud.inject``.

    Args:
        provider: Key into ``integrations`` (e.g. ``"sid"``) (injected).
        crud: OAuth CRUD helper handed to the installer (injected).

    Returns:
        A new installer instance configured with the platform settings.

    Raises:
        NotImplementedError: If no installer is registered for ``provider``.
    """
    installer_cls = integrations.get(provider)
    if installer_cls is None:
        # Same exception type as before, but say which provider was rejected.
        raise NotImplementedError(f"Unsupported OAuth provider: {provider!r}")
    return installer_cls(crud, platform_settings)