from os import getenv
from typing import Union

from loguru import logger

from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig
from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator

# Segmentation networks selectable through ``MLConfig.segmentation_network``.
# Every entry is constructed with the same keyword arguments (see
# ``_build_segmentation_network``).
_SEGMENTATION_NETWORKS = {
    "u2net": U2NET,
    "deeplabv3": DeepLabV3,
    "basnet": BASNET,
    "tracer_b7": TracerUniversalB7,
}

# Network used when ``segmentation_network`` names none of the entries above.
_FALLBACK_SEGMENTATION_NETWORK = TracerUniversalB7

# Word forms accepted by ``_env_flag`` in addition to integer strings.
_TRUE_WORDS = frozenset({"true", "yes", "on"})
_FALSE_WORDS = frozenset({"false", "no", "off"})


def _env_int(name: str, default) -> int:
    """Return environment variable *name* as an int, or ``int(default)`` if unset.

    Raises:
        ValueError: if the variable is set but is not an integer literal.
    """
    return int(getenv(name, default))


def _env_flag(name: str, default) -> bool:
    """Return environment variable *name* as a boolean, or ``bool(int(default))`` if unset.

    Integer strings keep their original meaning (``"0"`` is False, any other
    integer is True). The words ``true/yes/on`` and ``false/no/off``
    (case-insensitive) are accepted as well, so ``CARVEKIT_FP16=true`` no longer
    crashes.

    Raises:
        ValueError: if the value is neither an integer nor one of those words.
    """
    raw = getenv(name)
    if raw is None:
        return bool(int(default))
    value = raw.strip().lower()
    if value in _TRUE_WORDS:
        return True
    if value in _FALSE_WORDS:
        return False
    return bool(int(value))


def _env_token_list(name: str, default):
    """Return environment variable *name* as a comma-separated list of tokens.

    Returns *default* unchanged when the variable is unset. Surrounding
    whitespace is stripped and empty items are dropped, so ``"a, b"`` yields
    ``["a", "b"]``. An empty value or a trailing comma cannot register an
    empty-string token.
    """
    raw = getenv(name)
    if raw is None:
        return default
    return [token.strip() for token in raw.split(",") if token.strip()]


def init_config() -> WebAPIConfig:
    """Build the Web API configuration from ``CARVEKIT_*`` environment variables.

    Every setting falls back to the value of a default-constructed
    ``WebAPIConfig`` when its variable is unset:

    * server: ``CARVEKIT_PORT``, ``CARVEKIT_HOST``
    * ml: ``CARVEKIT_SEGMENTATION_NETWORK``, ``CARVEKIT_PREPROCESSING_METHOD``,
      ``CARVEKIT_POSTPROCESSING_METHOD``, ``CARVEKIT_DEVICE``,
      ``CARVEKIT_BATCH_SIZE_SEG``, ``CARVEKIT_BATCH_SIZE_MATTING``,
      ``CARVEKIT_SEG_MASK_SIZE``, ``CARVEKIT_MATTING_MASK_SIZE``,
      ``CARVEKIT_FP16``, ``CARVEKIT_TRIMAP_PROB_THRESHOLD``,
      ``CARVEKIT_TRIMAP_DILATION``, ``CARVEKIT_TRIMAP_EROSION``
    * auth: ``CARVEKIT_AUTH_ENABLE``, ``CARVEKIT_ADMIN_TOKEN``,
      ``CARVEKIT_ALLOWED_TOKENS`` (comma-separated)

    Returns:
        The assembled ``WebAPIConfig``.

    Raises:
        ValueError: if an integer or flag variable holds a malformed value.
    """
    default_config = WebAPIConfig()
    ml_defaults = default_config.ml
    auth_defaults = default_config.auth

    config = WebAPIConfig(
        port=_env_int("CARVEKIT_PORT", default_config.port),
        host=getenv("CARVEKIT_HOST", default_config.host),
        ml=MLConfig(
            segmentation_network=getenv(
                "CARVEKIT_SEGMENTATION_NETWORK", ml_defaults.segmentation_network
            ),
            preprocessing_method=getenv(
                "CARVEKIT_PREPROCESSING_METHOD", ml_defaults.preprocessing_method
            ),
            postprocessing_method=getenv(
                "CARVEKIT_POSTPROCESSING_METHOD", ml_defaults.postprocessing_method
            ),
            device=getenv("CARVEKIT_DEVICE", ml_defaults.device),
            batch_size_seg=_env_int(
                "CARVEKIT_BATCH_SIZE_SEG", ml_defaults.batch_size_seg
            ),
            batch_size_matting=_env_int(
                "CARVEKIT_BATCH_SIZE_MATTING", ml_defaults.batch_size_matting
            ),
            seg_mask_size=_env_int(
                "CARVEKIT_SEG_MASK_SIZE", ml_defaults.seg_mask_size
            ),
            matting_mask_size=_env_int(
                "CARVEKIT_MATTING_MASK_SIZE", ml_defaults.matting_mask_size
            ),
            fp16=_env_flag("CARVEKIT_FP16", ml_defaults.fp16),
            trimap_prob_threshold=_env_int(
                "CARVEKIT_TRIMAP_PROB_THRESHOLD", ml_defaults.trimap_prob_threshold
            ),
            trimap_dilation=_env_int(
                "CARVEKIT_TRIMAP_DILATION", ml_defaults.trimap_dilation
            ),
            trimap_erosion=_env_int(
                "CARVEKIT_TRIMAP_EROSION", ml_defaults.trimap_erosion
            ),
        ),
        auth=AuthConfig(
            auth=_env_flag("CARVEKIT_AUTH_ENABLE", auth_defaults.auth),
            admin_token=getenv("CARVEKIT_ADMIN_TOKEN", auth_defaults.admin_token),
            allowed_tokens=_env_token_list(
                "CARVEKIT_ALLOWED_TOKENS", auth_defaults.allowed_tokens
            ),
        ),
    )

    # NOTE(review): the admin token is written to the log in plain text.
    # Presumably deliberate (so operators can find an auto-generated token),
    # but anyone with access to the logs can read it.
    logger.info(f"Admin token for Web API is {config.auth.admin_token}")
    logger.debug(f"Running Web API with this config: {config.json()}")
    return config


def _build_segmentation_network(config: MLConfig):
    """Instantiate the segmentation network named by ``config.segmentation_network``.

    Unknown names fall back to ``TracerUniversalB7`` and a warning is logged.
    """
    network_cls = _SEGMENTATION_NETWORKS.get(config.segmentation_network)
    if network_cls is None:
        logger.warning(
            f"Unknown segmentation network {config.segmentation_network!r}, "
            f"falling back to tracer_b7"
        )
        network_cls = _FALLBACK_SEGMENTATION_NETWORK
    return network_cls(
        device=config.device,
        batch_size=config.batch_size_seg,
        input_image_size=config.seg_mask_size,
        fp16=config.fp16,
    )


def _build_preprocessing(config: MLConfig):
    """Return the preprocessing pipeline, or None when preprocessing is disabled.

    ``"stub"`` selects ``PreprocessingStub``. ``"none"`` and any unrecognised
    value disable preprocessing; an unrecognised value also logs a warning.
    """
    method = config.preprocessing_method
    if method == "stub":
        return PreprocessingStub()
    if method != "none":
        logger.warning(
            f"Unknown preprocessing method {method!r}, preprocessing is disabled"
        )
    return None


def _build_postprocessing(config: MLConfig):
    """Return the postprocessing pipeline, or None when postprocessing is disabled.

    ``"fba"`` selects FBA matting driven by a trimap generator. ``"none"`` and
    any unrecognised value disable postprocessing; an unrecognised value also
    logs a warning.
    """
    method = config.postprocessing_method
    if method == "fba":
        fba = FBAMatting(
            device=config.device,
            batch_size=config.batch_size_matting,
            input_tensor_size=config.matting_mask_size,
            fp16=config.fp16,
        )
        trimap_generator = TrimapGenerator(
            prob_threshold=config.trimap_prob_threshold,
            # The "trimap_dilation" setting is passed as the generator's kernel size.
            kernel_size=config.trimap_dilation,
            erosion_iters=config.trimap_erosion,
        )
        return MattingMethod(
            device=config.device, matting_module=fba, trimap_generator=trimap_generator
        )
    if method != "none":
        logger.warning(
            f"Unknown postprocessing method {method!r}, postprocessing is disabled"
        )
    return None


def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
    """Build an inference ``Interface`` from the ML section of a configuration.

    Args:
        config: a full ``WebAPIConfig`` (only its ``ml`` section is used) or an
            ``MLConfig`` directly.

    Returns:
        An ``Interface`` combining the selected segmentation network with the
        optional pre- and postprocessing pipelines, all on ``config.device``.
    """
    if isinstance(config, WebAPIConfig):
        config = config.ml

    # Construction order is kept as: segmentation network, preprocessing,
    # postprocessing (matting model, then trimap generator).
    seg_net = _build_segmentation_network(config)
    preprocessing = _build_preprocessing(config)
    postprocessing = _build_postprocessing(config)

    return Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=seg_net,
        device=config.device,
    )