|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict |
|
import time |
|
|
|
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
import paddleseg |
|
from paddleseg.models import layers |
|
from paddleseg import utils |
|
from paddleseg.cvlibs import manager |
|
|
|
from ppmatting.models.losses import MRSD |
|
|
|
|
|
def conv_up_psp(in_channels, out_channels, up_sample):
    """Build a 3x3 ConvBNReLU followed by bilinear upsampling.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        up_sample (int|float): Scale factor of the upsampling layer.

    Returns:
        nn.Sequential: ConvBNReLU(3x3, padding=1) + bilinear Upsample.
    """
    conv = layers.ConvBNReLU(in_channels, out_channels, 3, padding=1)
    upsample = nn.Upsample(
        scale_factor=up_sample, mode='bilinear', align_corners=False)
    return nn.Sequential(conv, upsample)
|
|
|
|
|
@manager.MODELS.add_component
class HumanMatting(nn.Layer):
    """A model for human matting.

    The network runs two decoders on shared backbone features: a "glance"
    decoder that predicts a 3-class trimap (foreground / transition /
    background) and a "focus" decoder that predicts a detail alpha, an error
    map, and a 32-channel hidden encoding. The two predictions are fused into
    a coarse alpha matte, which is optionally refined to full resolution by
    ``Refiner``.

    Args:
        backbone (nn.Layer): Backbone network; its ``feat_channels`` must
            provide the channel counts of at least 5 feature maps (indices
            -1..-5 are used by the decoders).
        pretrained (str, optional): Path or URL of pretrained weights for the
            whole model. Default: None.
        backbone_scale (float, optional): Downsampling factor applied to the
            input before the backbone. Must not exceed 0.5 when ``if_refine``
            is True. Default: 0.25.
        refine_kernel_size (int, optional): Kernel size used by the refiner,
            1 or 3. Default: 3.
        if_refine (bool, optional): Whether to refine the coarse alpha at
            full resolution. Default: True.
    """

    def __init__(self,
                 backbone,
                 pretrained=None,
                 backbone_scale=0.25,
                 refine_kernel_size=3,
                 if_refine=True):
        super().__init__()
        if if_refine:
            if backbone_scale > 0.5:
                raise ValueError(
                    'Backbone_scale should not be greater than 1/2, but it is {}'
                    .format(backbone_scale))
        else:
            # Without the refiner the coarse branch runs at full resolution.
            backbone_scale = 1

        self.backbone = backbone
        self.backbone_scale = backbone_scale
        self.pretrained = pretrained
        self.if_refine = if_refine
        if if_refine:
            self.refiner = Refiner(kernel_size=refine_kernel_size)
        # Loss functions are built lazily on the first call to `loss`.
        self.loss_func_dict = None

        self.backbone_channels = backbone.feat_channels

        ############################
        # Glance decoder (trimap)
        ############################
        self.psp_module = layers.PPModule(
            self.backbone_channels[-1],
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)
        # Upsampled PSP features injected at each glance-decoder stage
        # (scale factors 2/4/8/16 match the decoder output resolutions).
        self.psp4 = conv_up_psp(512, 256, 2)
        self.psp3 = conv_up_psp(512, 128, 4)
        self.psp2 = conv_up_psp(512, 64, 8)
        self.psp1 = conv_up_psp(512, 64, 16)

        self.decoder5_g = nn.Sequential(
            layers.ConvBNReLU(
                512 + self.backbone_channels[-1], 512, 3, padding=1),
            layers.ConvBNReLU(
                512, 512, 3, padding=2, dilation=2),
            layers.ConvBNReLU(
                512, 256, 3, padding=2, dilation=2),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        self.decoder4_g = nn.Sequential(
            layers.ConvBNReLU(
                512, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        self.decoder3_g = nn.Sequential(
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        self.decoder2_g = nn.Sequential(
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        self.decoder1_g = nn.Sequential(
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        # Final glance head: 3 output channels = fg / transition / bg scores.
        self.decoder0_g = nn.Sequential(
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Conv2D(
                64, 3, 3, padding=1))

        ############################
        # Focus decoder (detail alpha)
        ############################
        self.bridge_block = nn.Sequential(
            layers.ConvBNReLU(
                self.backbone_channels[-1], 512, 3, dilation=2, padding=2),
            layers.ConvBNReLU(
                512, 512, 3, dilation=2, padding=2),
            layers.ConvBNReLU(
                512, 512, 3, dilation=2, padding=2))

        self.decoder5_f = nn.Sequential(
            layers.ConvBNReLU(
                512 + self.backbone_channels[-1], 512, 3, padding=1),
            layers.ConvBNReLU(
                512, 512, 3, padding=2, dilation=2),
            layers.ConvBNReLU(
                512, 256, 3, padding=2, dilation=2),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        # Each focus stage concatenates the matching backbone skip feature.
        self.decoder4_f = nn.Sequential(
            layers.ConvBNReLU(
                256 + self.backbone_channels[-2], 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        self.decoder3_f = nn.Sequential(
            layers.ConvBNReLU(
                128 + self.backbone_channels[-3], 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        self.decoder2_f = nn.Sequential(
            layers.ConvBNReLU(
                64 + self.backbone_channels[-4], 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        self.decoder1_f = nn.Sequential(
            layers.ConvBNReLU(
                64 + self.backbone_channels[-5], 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))

        # Final focus head: 1 alpha + 1 error + 32 hidden channels.
        self.decoder0_f = nn.Sequential(
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Conv2D(
                64, 1 + 1 + 32, 3, padding=1))
        self.init_weight()

    def forward(self, data):
        """Run the coarse branches and (optionally) the refiner.

        Args:
            data (dict): Must contain 'img', a (B, 3, H, W) image tensor;
                during training must also contain the labels used by `loss`.

        Returns:
            Training: tuple of (logit_dict, loss_dict).
            Inference: the refined alpha if ``if_refine`` else the coarse
            fused alpha.
        """
        src = data['img']
        src_h, src_w = paddle.shape(src)[2:]
        if self.if_refine:
            # The shape is only a concrete tensor in dynamic-graph mode;
            # skip the check when it is not (e.g. during export).
            if isinstance(src_h, paddle.Tensor):
                if (src_h % 4 != 0) or (src_w % 4) != 0:
                    raise ValueError(
                        'The input image must have width and height that are divisible by 4'
                    )

        # Downsample the input before feeding the backbone.
        src_sm = F.interpolate(
            src,
            scale_factor=self.backbone_scale,
            mode='bilinear',
            align_corners=False)

        fea_list = self.backbone(src_sm)

        # Glance branch: PSP context + progressive upsampling to a trimap.
        psp = self.psp_module(fea_list[-1])
        d5_g = self.decoder5_g(paddle.concat((psp, fea_list[-1]), 1))
        d4_g = self.decoder4_g(paddle.concat((self.psp4(psp), d5_g), 1))
        d3_g = self.decoder3_g(paddle.concat((self.psp3(psp), d4_g), 1))
        d2_g = self.decoder2_g(paddle.concat((self.psp2(psp), d3_g), 1))
        d1_g = self.decoder1_g(paddle.concat((self.psp1(psp), d2_g), 1))
        d0_g = self.decoder0_g(d1_g)
        # 3-way class probabilities over fg / transition / bg.
        glance_sigmoid = F.softmax(d0_g, axis=1)

        # Focus branch: dilated bridge + U-Net-style skip connections.
        bb = self.bridge_block(fea_list[-1])
        d5_f = self.decoder5_f(paddle.concat((bb, fea_list[-1]), 1))
        d4_f = self.decoder4_f(paddle.concat((d5_f, fea_list[-2]), 1))
        d3_f = self.decoder3_f(paddle.concat((d4_f, fea_list[-3]), 1))
        d2_f = self.decoder2_f(paddle.concat((d3_f, fea_list[-4]), 1))
        d1_f = self.decoder1_f(paddle.concat((d2_f, fea_list[-5]), 1))
        d0_f = self.decoder0_f(d1_f)

        # Split the focus head: channel 0 = alpha, 1 = error, 2.. = hidden.
        focus_sigmoid = F.sigmoid(d0_f[:, 0:1, :, :])
        pha_sm = self.fusion(glance_sigmoid, focus_sigmoid)
        err_sm = d0_f[:, 1:2, :, :]
        err_sm = paddle.clip(err_sm, 0., 1.)
        hid_sm = F.relu(d0_f[:, 2:, :, :])

        if self.if_refine:
            pha = self.refiner(
                src=src, pha=pha_sm, err=err_sm, hid=hid_sm, tri=glance_sigmoid)
            pha = paddle.clip(pha, 0., 1.)

        if self.training:
            logit_dict = {
                'glance': glance_sigmoid,
                'focus': focus_sigmoid,
                'fusion': pha_sm,
                'error': err_sm
            }
            if self.if_refine:
                logit_dict['refine'] = pha
            loss_dict = self.loss(logit_dict, data)
            return logit_dict, loss_dict
        else:
            return pha if self.if_refine else pha_sm

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """Compute the training losses.

        Args:
            logit_dict (dict): Outputs of `forward` ('glance', 'focus',
                'fusion', 'error' and optionally 'refine').
            label_dict (dict): Must contain 'trimap' (values 0/128/255 —
                inferred from the equality checks below; confirm against the
                dataset) and 'alpha' ground truths.
            loss_func_dict (dict, optional): Overrides the default loss
                functions when given.

        Returns:
            dict: Individual loss terms plus their weighted sum under 'all'.
        """
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['glance'].append(nn.NLLLoss())
                self.loss_func_dict['focus'].append(MRSD())
                self.loss_func_dict['cm'].append(MRSD())
                self.loss_func_dict['err'].append(paddleseg.models.MSELoss())
                self.loss_func_dict['refine'].append(paddleseg.models.L1Loss())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}

        # Glance loss: map trimap values to class ids (fg=0, transition=1,
        # bg=2) and apply NLL over the log of the softmax probabilities.
        glance_label = F.interpolate(
            label_dict['trimap'],
            logit_dict['glance'].shape[2:],
            mode='nearest',
            align_corners=False)
        glance_label_trans = (glance_label == 128).astype('int64')
        glance_label_bg = (glance_label == 0).astype('int64')
        glance_label = glance_label_trans + glance_label_bg * 2
        loss_glance = self.loss_func_dict['glance'][0](
            paddle.log(logit_dict['glance'] + 1e-6), glance_label.squeeze(1))
        loss['glance'] = loss_glance

        # Focus loss: MRSD on the detail alpha, masked to the transition area.
        focus_label = F.interpolate(
            label_dict['alpha'],
            logit_dict['focus'].shape[2:],
            mode='bilinear',
            align_corners=False)
        loss_focus = self.loss_func_dict['focus'][0](
            logit_dict['focus'], focus_label, glance_label_trans)
        loss['focus'] = loss_focus

        # Collaborative-matting loss on the fused (coarse) alpha.
        loss_cm_func = self.loss_func_dict['cm']
        loss_cm = loss_cm_func[0](logit_dict['fusion'], focus_label)
        loss['cm'] = loss_cm

        # Error loss: the predicted error map should match the absolute
        # difference between the fused alpha and the ground truth.
        err = F.interpolate(
            logit_dict['error'],
            label_dict['alpha'].shape[2:],
            mode='bilinear',
            align_corners=False)
        err_label = (F.interpolate(
            logit_dict['fusion'],
            label_dict['alpha'].shape[2:],
            mode='bilinear',
            align_corners=False) - label_dict['alpha']).abs()
        loss_err = self.loss_func_dict['err'][0](err, err_label)
        loss['err'] = loss_err

        loss_all = 0.25 * loss_glance + 0.25 * loss_focus + 0.25 * loss_cm + loss_err

        # Refinement loss: L1 against the full-resolution ground truth.
        if self.if_refine:
            loss_refine = self.loss_func_dict['refine'][0](logit_dict['refine'],
                                                           label_dict['alpha'])
            loss['refine'] = loss_refine
            loss_all = loss_all + loss_refine

        loss['all'] = loss_all
        return loss

    def fusion(self, glance_sigmoid, focus_sigmoid):
        """Fuse the glance and focus predictions into one alpha matte.

        Per pixel: foreground -> 1, background -> 0, transition -> the
        focus-branch alpha. Glance channels: 0 = fg, 1 = transition, 2 = bg.
        """
        index = paddle.argmax(glance_sigmoid, axis=1, keepdim=True)
        transition_mask = (index == 1).astype('float32')
        fg = (index == 0).astype('float32')
        fusion_sigmoid = focus_sigmoid * transition_mask + fg
        return fusion_sigmoid

    def init_weight(self):
        # Load pretrained weights for the whole model when provided.
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
|
|
|
|
|
class Refiner(nn.Layer):
    '''
    Refiner refines the coarse output to full resolution.

    Args:
        kernel_size: The convolution kernel_size. Options: [1, 3]. Default: 3.
    '''

    def __init__(self, kernel_size=3):
        super().__init__()
        if kernel_size not in [1, 3]:
            raise ValueError("kernel_size must be in [1, 3]")

        self.kernel_size = kernel_size

        # Channel plan: 32 hidden + 1 alpha + 3 trimap (+ 3 RGB) in,
        # narrowing 24 -> 16 -> 12 -> 1 alpha out.
        chans = [32, 24, 16, 12, 1]
        self.conv1 = layers.ConvBNReLU(
            chans[0] + 4 + 3,
            chans[1],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv2 = layers.ConvBNReLU(
            chans[1], chans[2], kernel_size, padding=0, bias_attr=False)
        self.conv3 = layers.ConvBNReLU(
            chans[2] + 3,
            chans[3],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv4 = nn.Conv2D(
            chans[3], chans[4], kernel_size, padding=0, bias_attr=True)

    def forward(self, src, pha, err, hid, tri):
        '''
        Args:
            src: (B, 3, H, W) full resolution source image.
            pha: (B, 1, Hc, Wc) coarse alpha prediction.
            err: (B, 1, Hc, Wc) coarse error prediction (currently unused).
            hid: (B, 32, Hc, Wc) coarse hidden encoding.
            tri: (B, 3, Hc, Wc) trimap (glance softmax) prediction.
        '''
        h_full, w_full = paddle.shape(src)[2:]
        h_half, w_half = h_full // 2, w_full // 2
        h_quat, w_quat = h_full // 4, w_full // 4

        # Stage 1: work at half resolution on [hidden, alpha, trimap] + RGB.
        half_size = paddle.concat((h_half, w_half))
        feat = F.interpolate(
            paddle.concat([hid, pha, tri], axis=1),
            half_size,
            mode='bilinear',
            align_corners=False)
        img = F.interpolate(
            src, half_size, mode='bilinear', align_corners=False)

        if self.kernel_size == 3:
            # Pre-pad so the valid (padding=0) 3x3 convolutions keep enough
            # border for the later crop-free upsampling.
            feat = F.pad(feat, [3, 3, 3, 3])
            img = F.pad(img, [3, 3, 3, 3])

        feat = self.conv2(self.conv1(paddle.concat([feat, img], axis=1)))

        # Stage 2: back to full resolution (plus a 2-pixel border for the
        # remaining two valid 3x3 convolutions).
        if self.kernel_size == 3:
            feat = F.interpolate(feat,
                                 paddle.concat((h_full + 4, w_full + 4)))
            img = F.pad(src, [2, 2, 2, 2])
        else:
            feat = F.interpolate(
                feat, paddle.concat((h_full, w_full)), mode='nearest')
            img = src

        pha = self.conv4(self.conv3(paddle.concat([feat, img], axis=1)))
        return pha
|
|