File size: 3,104 Bytes
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import secrets
from typing import List
from typing_extensions import Literal

import torch.cuda
from pydantic import BaseModel, validator


class AuthConfig(BaseModel):
    """Config for web api token authentication"""

    # NOTE: the token defaults below call secrets.token_hex(32) while this
    # class body executes, so the generated values are produced once per
    # process (when the module is imported), not once per AuthConfig() call.
    # NOTE(review): how the consuming code checks these tokens is not visible
    # in this file.

    auth: bool = True
    """Enables Token Authentication for API"""
    # Default admin token: 32 random bytes rendered as a 64-character hex string.
    admin_token: str = secrets.token_hex(32)
    """Admin Token"""
    # Default allow-list holds a single token produced by its own token_hex
    # call, i.e. generated independently of the admin_token default.
    allowed_tokens: List[str] = [secrets.token_hex(32)]
    """All allowed tokens"""


class MLConfig(BaseModel):
    """Config for ml part of framework"""

    # --- Model / pipeline selection ---------------------------------------
    segmentation_network: Literal[
        "u2net", "deeplabv3", "basnet", "tracer_b7"
    ] = "tracer_b7"
    """Segmentation Network"""
    preprocessing_method: Literal["none", "stub"] = "none"
    """Pre-processing Method"""
    postprocessing_method: Literal["fba", "none"] = "fba"
    """Post-Processing Network"""

    # --- Execution settings -----------------------------------------------
    device: str = "cpu"
    """Processing device"""
    batch_size_seg: int = 5
    """Batch size for segmentation network"""
    batch_size_matting: int = 1
    """Batch size for matting network"""
    seg_mask_size: int = 640
    """The size of the input image for the segmentation neural network."""
    matting_mask_size: int = 2048
    """The size of the input image for the matting neural network."""
    fp16: bool = False
    """Use half precision for inference"""

    # --- Trimap parameters (not validated in this class) --------------------
    trimap_dilation: int = 30
    """Dilation size for trimap"""
    trimap_erosion: int = 5
    """Erosion levels for trimap"""
    trimap_prob_threshold: int = 231
    """Probability threshold for trimap generation"""

    @validator("seg_mask_size")
    def seg_mask_size_validator(cls, value: int, values):
        """Return ``value`` unchanged; raise ValueError unless it is positive."""
        if value <= 0:
            raise ValueError("Incorrect seg_mask_size!")
        return value

    @validator("matting_mask_size")
    def matting_mask_size_validator(cls, value: int, values):
        """Return ``value`` unchanged; raise ValueError unless it is positive."""
        if value <= 0:
            raise ValueError("Incorrect matting_mask_size!")
        return value

    @validator("batch_size_seg")
    def batch_size_seg_validator(cls, value: int, values):
        """Return ``value`` unchanged; raise ValueError unless it is positive."""
        if value <= 0:
            raise ValueError("Incorrect batch size!")
        return value

    @validator("batch_size_matting")
    def batch_size_matting_validator(cls, value: int, values):
        """Return ``value`` unchanged; raise ValueError unless it is positive."""
        if value <= 0:
            raise ValueError("Incorrect batch size!")
        return value

    @validator("device")
    def device_validator(cls, value):
        """Check the processing device string.

        Raises ValueError when the string contains "cuda" but
        ``torch.cuda.is_available()`` reports False, or when the string
        neither contains "cuda" nor equals "cpu". Otherwise returns
        ``value`` unchanged.
        """
        cuda_available = torch.cuda.is_available()
        cuda_requested = "cuda" in value

        if cuda_available is False and cuda_requested:
            raise ValueError(
                "GPU is not available, but specified as processing device!"
            )

        # Anything mentioning "cuda" (e.g. with a device index) is accepted;
        # the only other allowed value is exactly "cpu".
        if not (cuda_requested or value == "cpu"):
            raise ValueError("Unknown processing device! It should be cpu or cuda!")
        return value


class WebAPIConfig(BaseModel):
    """FastAPI app config"""

    port: int = 5000
    """Web API port"""
    # NOTE(review): "0.0.0.0" presumably makes the server listen on every
    # IPv4 interface; the code that binds the socket is not in this file,
    # so confirm against the caller.
    host: str = "0.0.0.0"
    """Web API host"""
    # The two nested defaults below are instantiated when this class body
    # runs, so they are built once at import time using the sub-configs'
    # own default values.
    ml: MLConfig = MLConfig()
    """Config for ml part of framework"""
    auth: AuthConfig = AuthConfig()
    """Config for web api token authentication """