import secrets

from typing import List
from typing_extensions import Literal

import torch.cuda
from pydantic import BaseModel, validator


class AuthConfig(BaseModel):
    """Config for web api token authentication"""

    auth: bool = True
    """Enables Token Authentication for API"""
    admin_token: str = secrets.token_hex(32)
    """Admin Token"""
    allowed_tokens: List[str] = [secrets.token_hex(32)]
    """All allowed tokens"""


class MLConfig(BaseModel):
    """Config for ml part of framework"""

    segmentation_network: Literal[
        "u2net", "deeplabv3", "basnet", "tracer_b7"
    ] = "tracer_b7"
    """Segmentation Network"""
    preprocessing_method: Literal["none", "stub"] = "none"
    """Pre-processing Method"""
    postprocessing_method: Literal["fba", "none"] = "fba"
    """Post-Processing Network"""
    device: str = "cpu"
    """Processing device"""
    batch_size_seg: int = 5
    """Batch size for segmentation network"""
    batch_size_matting: int = 1
    """Batch size for matting network"""
    seg_mask_size: int = 640
    """The size of the input image for the segmentation neural network."""
    matting_mask_size: int = 2048
    """The size of the input image for the matting neural network."""
    fp16: bool = False
    """Use half precision for inference"""
    trimap_dilation: int = 30
    """Dilation size for trimap"""
    trimap_erosion: int = 5
    """Erosion levels for trimap"""
    trimap_prob_threshold: int = 231
    """Probability threshold for trimap generation"""

    @validator("seg_mask_size")
    def seg_mask_size_validator(cls, value: int, values):
        if value > 0:
            return value
        else:
            raise ValueError("Incorrect seg_mask_size!")

    @validator("matting_mask_size")
    def matting_mask_size_validator(cls, value: int, values):
        if value > 0:
            return value
        else:
            raise ValueError("Incorrect matting_mask_size!")

    @validator("batch_size_seg")
    def batch_size_seg_validator(cls, value: int, values):
        if value > 0:
            return value
        else:
            raise ValueError("Incorrect batch size!")

    @validator("batch_size_matting")
    def batch_size_matting_validator(cls, value: int, values):
        if value > 0:
            return value
        else:
            raise ValueError("Incorrect batch size!")

    @validator("device")
    def device_validator(cls, value):
        # Accepts "cpu" or any CUDA device string (e.g. "cuda", "cuda:0").
        if "cuda" in value and not torch.cuda.is_available():
            raise ValueError(
                "GPU is not available, but specified as processing device!"
            )
        if "cuda" not in value and value != "cpu":
            raise ValueError("Unknown processing device! It should be cpu or cuda!")
        return value


class WebAPIConfig(BaseModel):
    """FastAPI app config"""

    port: int = 5000
    """Web API port"""
    host: str = "0.0.0.0"
    """Web API host"""
    ml: MLConfig = MLConfig()
    """Config for ml part of framework"""
    auth: AuthConfig = AuthConfig()
    """Config for web api token authentication"""
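

# Usage sketch (illustrative, not part of the original module): construct the
# nested config and serialize it. Field names follow the classes above; the
# override values below are only examples.
if __name__ == "__main__":
    config = WebAPIConfig(
        port=8080,
        ml=MLConfig(device="cpu", fp16=False, seg_mask_size=320),
    )
    # Pydantic v1 BaseModel provides .json() for serialization.
    print(config.json())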