dia-gov's picture
Upload 569 files
cd6f98e verified
raw
history blame
2.71 kB
import secrets
from typing import Dict, Optional
from fastapi import Depends
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
class OAuthCrud(BaseCrud):
    """Database access helpers for ``OauthCredentials`` rows."""

    @classmethod
    async def inject(
        cls,
        session: AsyncSession = Depends(get_db_session),
    ) -> "OAuthCrud":
        """FastAPI dependency: build an ``OAuthCrud`` bound to the request's DB session."""
        return cls(session)

    async def create_installation(
        self, user: UserBase, provider: str, redirect_uri: Optional[str]
    ) -> OauthCredentials:
        """Create and save a new credentials row for ``user`` and ``provider``.

        The row is stamped with the user's id and organization id, and with a fresh
        random ``state`` value (``secrets.token_hex(16)``, i.e. 32 hex characters).
        The state can later be looked up with ``get_installation_by_state``.

        Args:
            user: The user starting the installation.
            provider: Provider identifier stored on the row.
            redirect_uri: Optional redirect URI stored on the row.

        Returns:
            The saved ``OauthCredentials`` record.
        """
        return await OauthCredentials(
            user_id=user.id,
            organization_id=user.organization_id,
            provider=provider,
            # ``secrets`` (not ``random``) because the state value is a security token.
            state=secrets.token_hex(16),
            redirect_uri=redirect_uri,
        ).save(self.session)

    async def get_installation_by_state(self, state: str) -> Optional[OauthCredentials]:
        """Return the credentials row whose ``state`` equals ``state``, or ``None``.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row has this state.
        """
        query = select(OauthCredentials).filter(OauthCredentials.state == state)
        return (await self.session.execute(query)).scalar_one_or_none()

    async def get_installation_by_user_id(
        self, user_id: str, provider: str
    ) -> Optional[OauthCredentials]:
        """Return a row for ``user_id`` and ``provider`` that has ``access_token_enc`` set.

        Rows without an access token (for example, installations that were never
        completed) are skipped. If several rows match, the first one the database
        returns is used; no ordering is applied. Returns ``None`` when nothing matches.
        """
        query = select(OauthCredentials).filter(
            OauthCredentials.user_id == user_id,
            OauthCredentials.provider == provider,
            OauthCredentials.access_token_enc.isnot(None),
        )
        return (await self.session.execute(query)).scalars().first()

    async def get_installation_by_organization_id(
        self, organization_id: str, provider: str
    ) -> Optional[OauthCredentials]:
        """Return a row for ``organization_id`` and ``provider`` that has ``access_token_enc`` set.

        If several rows match, the first one the database returns is used; no ordering
        is applied. Returns ``None`` when nothing matches.
        """
        query = select(OauthCredentials).filter(
            OauthCredentials.organization_id == organization_id,
            OauthCredentials.provider == provider,
            OauthCredentials.access_token_enc.isnot(None),
            # If ``organization_id`` is None, ``==`` above renders as ``IS NULL``.
            # This guard stops that from matching every organization-less row.
            OauthCredentials.organization_id.isnot(None),
        )
        return (await self.session.execute(query)).scalars().first()

    async def get_all(self, user: UserBase) -> Dict[str, str]:
        """Map each provider to one stored ``access_token_enc`` visible to ``user``.

        A user who belongs to an organization sees the organization's credentials.
        A user without an organization sees only their own organization-less
        credentials. Rows with no access token are ignored. When a provider has
        several rows, ``ANY_VALUE`` lets the database pick any one of their tokens.

        Args:
            user: The requesting user; ``organization_id`` and ``id`` define the scope.

        Returns:
            A dict mapping provider name to the stored access-token value.
        """
        conditions = [OauthCredentials.access_token_enc.isnot(None)]
        if user.organization_id is not None:
            conditions.append(OauthCredentials.organization_id == user.organization_id)
        else:
            # Comparing against ``None`` with ``==`` would render ``organization_id
            # IS NULL`` and match other users' personal credentials. So scope to the
            # caller's own organization-less rows instead.
            conditions.append(OauthCredentials.user_id == user.id)
            conditions.append(OauthCredentials.organization_id.is_(None))

        query = (
            select(
                OauthCredentials.provider,
                func.any_value(OauthCredentials.access_token_enc),
            )
            .filter(*conditions)
            .group_by(OauthCredentials.provider)
        )
        rows = (await self.session.execute(query)).all()
        return {provider: token for provider, token in rows}