jayparmr's picture
Upload folder using huggingface_hub
a3d6c18
raw
history blame
3.1 kB
import secrets
from typing import List
from typing_extensions import Literal
import torch.cuda
from pydantic import BaseModel, validator
class AuthConfig(BaseModel):
    """Config for web api token authentication.

    NOTE: the token defaults below are evaluated once, when this class body
    runs at import time. Every ``AuthConfig()`` created in the same process
    without explicit token values therefore carries the same generated tokens.
    """

    # Master switch for token checking.
    auth: bool = True
    """Enables Token Authentication for API"""
    # ``secrets.token_hex(32)`` yields 32 random bytes as a 64-char hex string.
    # This value is generated independently of the token in ``allowed_tokens``.
    admin_token: str = secrets.token_hex(32)
    """Admin Token"""
    # Default is a one-element list with its own freshly generated token. The
    # list literal is built once at import; pydantic presumably copies
    # defaults per instance — confirm for the pinned pydantic version.
    allowed_tokens: List[str] = [secrets.token_hex(32)]
    """All allowed tokens"""
class MLConfig(BaseModel):
    """Config for ml part of framework.

    The ``@validator`` methods below reject non-positive mask and batch
    sizes, and any ``device`` other than ``"cpu"``, ``"cuda"`` or
    ``"cuda:<index>"`` (CUDA values additionally require a visible GPU).
    """

    segmentation_network: Literal[
        "u2net", "deeplabv3", "basnet", "tracer_b7"
    ] = "tracer_b7"
    """Segmentation Network"""
    preprocessing_method: Literal["none", "stub"] = "none"
    """Pre-processing Method"""
    postprocessing_method: Literal["fba", "none"] = "fba"
    """Post-Processing Method"""
    device: str = "cpu"
    """Processing device"""
    batch_size_seg: int = 5
    """Batch size for segmentation network"""
    batch_size_matting: int = 1
    """Batch size for matting network"""
    seg_mask_size: int = 640
    """The size of the input image for the segmentation neural network."""
    matting_mask_size: int = 2048
    """The size of the input image for the matting neural network."""
    fp16: bool = False
    """Use half precision for inference"""
    trimap_dilation: int = 30
    """Dilation size for trimap"""
    trimap_erosion: int = 5
    """Erosion levels for trimap"""
    trimap_prob_threshold: int = 231
    """Probability threshold for trimap generation"""

    @validator("seg_mask_size")
    def seg_mask_size_validator(cls, value: int, values):
        """Return ``value`` unchanged if positive.

        Raises:
            ValueError: if ``value`` is zero or negative.
        """
        if value > 0:
            return value
        raise ValueError("Incorrect seg_mask_size!")

    @validator("matting_mask_size")
    def matting_mask_size_validator(cls, value: int, values):
        """Return ``value`` unchanged if positive.

        Raises:
            ValueError: if ``value`` is zero or negative.
        """
        if value > 0:
            return value
        raise ValueError("Incorrect matting_mask_size!")

    @validator("batch_size_seg")
    def batch_size_seg_validator(cls, value: int, values):
        """Return ``value`` unchanged if positive.

        Raises:
            ValueError: if ``value`` is zero or negative.
        """
        if value > 0:
            return value
        raise ValueError("Incorrect batch size!")

    @validator("batch_size_matting")
    def batch_size_matting_validator(cls, value: int, values):
        """Return ``value`` unchanged if positive.

        Raises:
            ValueError: if ``value`` is zero or negative.
        """
        if value > 0:
            return value
        raise ValueError("Incorrect batch size!")

    @validator("device")
    def device_validator(cls, value):
        """Accept ``"cpu"``, ``"cuda"`` or ``"cuda:<index>"``.

        Raises:
            ValueError: if a CUDA device is requested but
                ``torch.cuda.is_available()`` is false, or if ``value`` is
                not one of the accepted forms.
        """
        # Match CUDA devices structurally instead of by substring: the old
        # ``"cuda" in value`` test treated strings such as "xcuda" or
        # "mycudadev" as CUDA devices.
        prefix = "cuda:"
        index = value[len(prefix):] if value.startswith(prefix) else ""
        is_cuda = value == "cuda" or (
            index.isascii() and index.isdigit()
        )
        if is_cuda and not torch.cuda.is_available():
            raise ValueError(
                "GPU is not available, but specified as processing device!"
            )
        if not is_cuda and value != "cpu":
            raise ValueError("Unknown processing device! It should be cpu or cuda!")
        return value
class WebAPIConfig(BaseModel):
    """FastAPI app config.

    Top-level Web API settings (port and host) plus the nested ML and
    token-authentication sub-configs.
    """

    port: int = 5000
    """Web API port"""
    # "0.0.0.0" denotes all IPv4 interfaces when used as a bind address;
    # presumably the server binds to this value — verify against the launcher.
    host: str = "0.0.0.0"
    """Web API host"""
    # The default sub-config instances below are constructed once, when this
    # class body runs at import time. The default ``auth`` therefore carries
    # tokens generated at import.
    ml: MLConfig = MLConfig()
    """Config for ml part of framework"""
    auth: AuthConfig = AuthConfig()
    """Config for web api token authentication """