File size: 1,882 Bytes
6da6215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11ae501
 
6da6215
 
 
 
 
11ae501
6da6215
 
 
 
 
11ae501
6da6215
 
 
11ae501
 
6da6215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import requests
import gradio as gr

import paddle
from paddleseg.cvlibs import Config

from matting.core import predict
from matting.model import *
from matting.dataset import MattingDataset


def download_file(http_address, file_name, timeout=60):
    """Download ``http_address`` and save the response body to ``file_name``.

    Args:
        http_address: URL to fetch. Redirects are followed.
        file_name: Destination path. It is opened in binary write mode, so
            an existing file with the same name is overwritten.
        timeout: Value forwarded to ``requests.get`` as ``timeout`` so a
            stalled connection cannot block start-up indefinitely.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    r = requests.get(http_address, allow_redirects=True, timeout=timeout)
    # Fail loudly rather than storing an error page as if it were the payload.
    r.raise_for_status()
    # The context manager guarantees the data is flushed and the handle closed.
    with open(file_name, 'wb') as f:
        f.write(r.content)

# One config per MODNet backbone; inference() selects an entry by position,
# so cfg_paths, cfgs, models_paths and models must all share the same order.
cfg_paths = [
    'configs/modnet/modnet_mobilenetv2.yml',
    'configs/modnet/modnet_resnet50_vd.yml',
    'configs/modnet/modnet_hrnet_w18.yml',
]
cfgs = [Config(path) for path in cfg_paths]

# (remote URL, local file name) for each set of pretrained weights.
_weight_sources = [
    ('https://paddleseg.bj.bcebos.com/matting/models/modnet-mobilenetv2.pdparams',
     'modnet-mobilenetv2.pdparams'),
    ('https://paddleseg.bj.bcebos.com/matting/models/modnet-resnet50_vd.pdparams',
     'modnet-resnet50_vd.pdparams'),
    ('https://paddleseg.bj.bcebos.com/matting/models/modnet-hrnet_w18.pdparams',
     'modnet-hrnet_w18.pdparams'),
]
# Fetch the weights at import time, in the same order as the configs.
for _url, _weights_file in _weight_sources:
    download_file(_url, _weights_file)

models_paths = [weights_file for _, weights_file in _weight_sources]
models = [loaded_cfg.model for loaded_cfg in cfgs]


def inference(image, chosen_model):
    """Run the selected matting model on ``image`` and return the prediction.

    Args:
        image: Input image; it is handed to ``predict`` as the single
            element of ``image_list``.
        chosen_model: Integer position used to index ``cfgs``, ``models``
            and ``models_paths``.

    Returns:
        Whatever ``predict`` returns for the given image.
    """
    # Inference is pinned to the CPU on every call.
    paddle.set_device('cpu')

    # Reuse the validation-time transforms declared in the chosen config.
    eval_transforms = cfgs[chosen_model].val_dataset.transforms

    return predict(
        models[chosen_model],
        model_path=models_paths[chosen_model],
        transforms=eval_transforms,
        image_list=[image],
    )


# Input widgets: the image to matte and a model selector. The radio is
# declared with type='index', and inference() uses the received value as a
# list index, so the label order must match the order of the model lists.
image_input = gr.inputs.Image(label='Input Image')
model_selector = gr.inputs.Radio(
    ['MobileNetV2', 'ResNet50_vd', 'HRNet_W18'],
    label='Model',
    type='index',
)
inputs = [image_input, model_selector]

# Example rows pair a bundled image with the model label to preselect.
example_rows = [
    ['images/armchair.jpg', 'MobileNetV2'],
    ['images/cat.jpg', 'ResNet50_vd'],
    ['images/plant.jpg', 'HRNet_W18'],
]

# Build the demo and start serving it immediately.
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=gr.outputs.Image(label='Output'),
    title='PaddleSeg - Matting',
    examples=example_rows,
).launch()