Spaces:
Runtime error
Runtime error
import secrets | |
from typing import Dict, Optional | |
from fastapi import Depends | |
from sqlalchemy import func, select | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from reworkd_platform.db.crud.base import BaseCrud | |
from reworkd_platform.db.dependencies import get_db_session | |
from reworkd_platform.db.models.auth import OauthCredentials | |
from reworkd_platform.schemas import UserBase | |
class OAuthCrud(BaseCrud):
    """Data-access helpers for ``OauthCredentials`` rows (OAuth installations)."""

    @classmethod
    async def inject(
        cls,
        session: AsyncSession = Depends(get_db_session),
    ) -> "OAuthCrud":
        """Build an ``OAuthCrud`` bound to the session from ``get_db_session``.

        Intended for use as a FastAPI dependency (``Depends(OAuthCrud.inject)``).
        ``@classmethod`` is required: without it ``cls`` is an ordinary,
        unannotated parameter that FastAPI would try to resolve as a request
        (query) parameter instead of binding it to the class.
        """
        return cls(session)

    async def create_installation(
        self, user: UserBase, provider: str, redirect_uri: Optional[str]
    ) -> OauthCredentials:
        """Create and persist a new installation record for ``user``.

        The record carries the user's id and organization id, the provider
        name, the caller-supplied ``redirect_uri`` and a fresh random
        ``state`` (32 hex characters from ``secrets.token_hex(16)``), which
        ``get_installation_by_state`` can later look up.

        Returns:
            Whatever ``OauthCredentials.save`` returns for the new row
            (presumably the persisted instance — the model is defined elsewhere).
        """
        return await OauthCredentials(
            user_id=user.id,
            organization_id=user.organization_id,
            provider=provider,
            state=secrets.token_hex(16),
            redirect_uri=redirect_uri,
        ).save(self.session)

    async def get_installation_by_state(self, state: str) -> Optional[OauthCredentials]:
        """Return the installation whose ``state`` equals ``state``, or ``None``.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row shares
                the state (``scalar_one_or_none`` semantics).
        """
        query = select(OauthCredentials).filter(OauthCredentials.state == state)
        return (await self.session.execute(query)).scalar_one_or_none()

    async def get_installation_by_user_id(
        self, user_id: str, provider: str
    ) -> Optional[OauthCredentials]:
        """Return one of ``user_id``'s ``provider`` installations that has a token.

        Only rows with a non-NULL ``access_token_enc`` are considered. There
        is no ORDER BY, so if several rows match, which one is returned is
        unspecified. Returns ``None`` when nothing matches.
        """
        query = select(OauthCredentials).filter(
            OauthCredentials.user_id == user_id,
            OauthCredentials.provider == provider,
            OauthCredentials.access_token_enc.isnot(None),
        )
        return (await self.session.execute(query)).scalars().first()

    async def get_installation_by_organization_id(
        self, organization_id: str, provider: str
    ) -> Optional[OauthCredentials]:
        """Return one of an organization's ``provider`` installations that has a token.

        Only rows with a non-NULL ``access_token_enc`` are considered. There
        is no ORDER BY, so if several rows match, which one is returned is
        unspecified. Returns ``None`` when nothing matches.
        """
        query = select(OauthCredentials).filter(
            OauthCredentials.organization_id == organization_id,
            OauthCredentials.provider == provider,
            OauthCredentials.access_token_enc.isnot(None),
            # If organization_id is None, ``== None`` renders as ``IS NULL``;
            # this guard keeps org-less rows from ever matching an org lookup.
            OauthCredentials.organization_id.isnot(None),
        )
        return (await self.session.execute(query)).scalars().first()

    async def get_all(self, user: UserBase) -> Dict[str, str]:
        """Map each provider installed for the user's organization to a token.

        For every provider with at least one tokened installation in
        ``user.organization_id``, one (arbitrary, via ``ANY_VALUE``) encrypted
        access token is returned. The database backend must provide an
        ``ANY_VALUE`` aggregate.

        Users without an organization get an empty mapping, matching the
        behaviour of ``get_installation_by_organization_id``.
        """
        query = (
            select(
                OauthCredentials.provider,
                func.any_value(OauthCredentials.access_token_enc),
            )
            .filter(
                OauthCredentials.access_token_enc.isnot(None),
                OauthCredentials.organization_id == user.organization_id,
                # Without this guard, a user whose organization_id is None would
                # match ``organization_id IS NULL`` and receive the tokens of
                # every org-less user's installations.
                OauthCredentials.organization_id.isnot(None),
            )
            .group_by(OauthCredentials.provider)
        )
        rows = (await self.session.execute(query)).all()
        return {provider: token for provider, token in rows}