|
from os import getenv |
|
from typing import Union |
|
|
|
from loguru import logger |
|
|
|
from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig |
|
from carvekit.api.interface import Interface |
|
from carvekit.ml.wrap.fba_matting import FBAMatting |
|
from carvekit.ml.wrap.u2net import U2NET |
|
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 |
|
from carvekit.ml.wrap.basnet import BASNET |
|
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 |
|
|
|
from carvekit.pipelines.postprocessing import MattingMethod |
|
from carvekit.pipelines.preprocessing import PreprocessingStub |
|
from carvekit.trimap.generator import TrimapGenerator |
|
|
|
|
|
def init_config() -> WebAPIConfig:
    """Build the Web API configuration from ``CARVEKIT_*`` environment variables.

    Every setting falls back to the value of a default-constructed
    ``WebAPIConfig`` when its environment variable is not set.

    Environment variables read:
        CARVEKIT_PORT, CARVEKIT_HOST,
        CARVEKIT_SEGMENTATION_NETWORK, CARVEKIT_PREPROCESSING_METHOD,
        CARVEKIT_POSTPROCESSING_METHOD, CARVEKIT_DEVICE,
        CARVEKIT_BATCH_SIZE_SEG, CARVEKIT_BATCH_SIZE_MATTING,
        CARVEKIT_SEG_MASK_SIZE, CARVEKIT_MATTING_MASK_SIZE, CARVEKIT_FP16,
        CARVEKIT_TRIMAP_PROB_THRESHOLD, CARVEKIT_TRIMAP_DILATION,
        CARVEKIT_TRIMAP_EROSION,
        CARVEKIT_AUTH_ENABLE, CARVEKIT_ADMIN_TOKEN, CARVEKIT_ALLOWED_TOKENS.

    ``CARVEKIT_FP16`` and ``CARVEKIT_AUTH_ENABLE`` are parsed with
    ``bool(int(...))``, so they must be integer strings such as ``"0"`` or
    ``"1"``. ``CARVEKIT_ALLOWED_TOKENS`` is a comma-separated list; whitespace
    around each token is stripped and blank entries are discarded.

    Returns:
        The populated ``WebAPIConfig``.

    Raises:
        ValueError: If a numeric or flag variable cannot be parsed by ``int``.
    """
    default_config = WebAPIConfig()
    default_ml = default_config.ml
    default_auth = default_config.auth

    # Read the token list once. Splitting an empty or sloppy value verbatim
    # would produce entries such as "" or " token", i.e. an empty string would
    # count as an allowed token, so blank entries are dropped here.
    raw_tokens = getenv("CARVEKIT_ALLOWED_TOKENS")
    if raw_tokens is None:
        allowed_tokens = default_auth.allowed_tokens
    else:
        allowed_tokens = [
            token.strip() for token in raw_tokens.split(",") if token.strip()
        ]

    config = WebAPIConfig(
        port=int(getenv("CARVEKIT_PORT", default_config.port)),
        host=getenv("CARVEKIT_HOST", default_config.host),
        ml=MLConfig(
            segmentation_network=getenv(
                "CARVEKIT_SEGMENTATION_NETWORK",
                default_ml.segmentation_network,
            ),
            preprocessing_method=getenv(
                "CARVEKIT_PREPROCESSING_METHOD",
                default_ml.preprocessing_method,
            ),
            postprocessing_method=getenv(
                "CARVEKIT_POSTPROCESSING_METHOD",
                default_ml.postprocessing_method,
            ),
            device=getenv("CARVEKIT_DEVICE", default_ml.device),
            batch_size_seg=int(
                getenv("CARVEKIT_BATCH_SIZE_SEG", default_ml.batch_size_seg)
            ),
            batch_size_matting=int(
                getenv("CARVEKIT_BATCH_SIZE_MATTING", default_ml.batch_size_matting)
            ),
            seg_mask_size=int(
                getenv("CARVEKIT_SEG_MASK_SIZE", default_ml.seg_mask_size)
            ),
            matting_mask_size=int(
                getenv("CARVEKIT_MATTING_MASK_SIZE", default_ml.matting_mask_size)
            ),
            fp16=bool(int(getenv("CARVEKIT_FP16", default_ml.fp16))),
            trimap_prob_threshold=int(
                getenv(
                    "CARVEKIT_TRIMAP_PROB_THRESHOLD",
                    default_ml.trimap_prob_threshold,
                )
            ),
            trimap_dilation=int(
                getenv("CARVEKIT_TRIMAP_DILATION", default_ml.trimap_dilation)
            ),
            trimap_erosion=int(
                getenv("CARVEKIT_TRIMAP_EROSION", default_ml.trimap_erosion)
            ),
        ),
        auth=AuthConfig(
            auth=bool(int(getenv("CARVEKIT_AUTH_ENABLE", default_auth.auth))),
            admin_token=getenv("CARVEKIT_ADMIN_TOKEN", default_auth.admin_token),
            allowed_tokens=allowed_tokens,
        ),
    )

    # NOTE(review): the admin token is written to the log at INFO level;
    # make sure log output is not exposed to untrusted readers.
    logger.info(f"Admin token for Web API is {config.auth.admin_token}")
    logger.debug(f"Running Web API with this config: {config.json()}")
    return config
|
|
|
|
|
def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
    """Assemble an ``Interface`` from the ML part of the configuration.

    Args:
        config: Either a full ``WebAPIConfig`` (its ``ml`` attribute is used)
            or an ``MLConfig`` directly.

    Returns:
        An ``Interface`` wired with the selected segmentation network and the
        optional preprocessing / postprocessing pipelines.
    """
    ml_config = config.ml if isinstance(config, WebAPIConfig) else config

    # Segmentation network: any name missing from the table falls back to
    # TracerUniversalB7, the same as the explicit "tracer_b7" choice.
    segmentation_classes = {
        "u2net": U2NET,
        "deeplabv3": DeepLabV3,
        "basnet": BASNET,
        "tracer_b7": TracerUniversalB7,
    }
    segmentation_cls = segmentation_classes.get(
        ml_config.segmentation_network, TracerUniversalB7
    )
    seg_net = segmentation_cls(
        device=ml_config.device,
        batch_size=ml_config.batch_size_seg,
        input_image_size=ml_config.seg_mask_size,
        fp16=ml_config.fp16,
    )

    # Preprocessing: only "stub" yields a pipeline; everything else is None.
    preprocessing = None
    if ml_config.preprocessing_method == "stub":
        preprocessing = PreprocessingStub()

    # Postprocessing: only "fba" yields a pipeline; everything else is None.
    postprocessing = None
    if ml_config.postprocessing_method == "fba":
        matting_net = FBAMatting(
            device=ml_config.device,
            batch_size=ml_config.batch_size_matting,
            input_tensor_size=ml_config.matting_mask_size,
            fp16=ml_config.fp16,
        )
        trimap = TrimapGenerator(
            prob_threshold=ml_config.trimap_prob_threshold,
            kernel_size=ml_config.trimap_dilation,
            erosion_iters=ml_config.trimap_erosion,
        )
        postprocessing = MattingMethod(
            device=ml_config.device,
            matting_module=matting_net,
            trimap_generator=trimap,
        )

    return Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=seg_net,
        device=ml_config.device,
    )
|
|