|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict |
|
import time |
|
|
|
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
import paddleseg |
|
from paddleseg.models import layers |
|
from paddleseg import utils |
|
from paddleseg.cvlibs import manager |
|
|
|
from ppmatting.models.losses import MRSD, GradientLoss |
|
from ppmatting.models.backbone import resnet_vd |
|
|
|
|
|
@manager.MODELS.add_component
class PPMatting(nn.Layer):
    """
    The PPMatting implementation based on PaddlePaddle.

    The original article refers to
    Guowei Chen, et, al. "PP-Matting: High-Accuracy Natural Image Matting"
    (https://arxiv.org/pdf/2204.09433.pdf).

    Args:
        backbone: backbone model.
        pretrained(str, optional): The path of pretrained model. Default: None.
    """

    def __init__(self, backbone, pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.loss_func_dict = self.get_loss_func_dict()

        self.backbone_channels = backbone.feat_channels

        # Semantic Context Branch: predicts a coarse 3-class trimap from the
        # deepest (lowest-resolution) backbone feature.
        self.scb = SCB(self.backbone_channels[-1])

        # High-Resolution Detail Branch: predicts fine alpha detail from the
        # two shallowest backbone features, guided by SCB logits 0, 2 and 4.
        self.hrdb = HRDB(
            self.backbone_channels[0] + self.backbone_channels[1],
            scb_channels=self.scb.out_channels,
            gf_index=[0, 2, 4])

        self.init_weight()

    def forward(self, inputs):
        """Run the network.

        Args:
            inputs (dict): Must contain 'img' (N, C, H, W). During training it
                must also contain the keys read by ``loss()``: 'trimap',
                'alpha', 'fg', 'bg'.

        Returns:
            Training: a ``(logit_dict, loss_dict)`` tuple.
            Inference: the fused alpha matte tensor.
        """
        x = inputs['img']
        input_shape = paddle.shape(x)
        fea_list = self.backbone(x)

        # The last SCB logit is the 3-class trimap logit; classes are
        # 0: foreground, 1: transition, 2: background (see fusion()).
        scb_logits = self.scb(fea_list[-1])
        semantic_map = F.softmax(scb_logits[-1], axis=1)

        # Detail branch input: the two shallow features upsampled to the
        # original image resolution and concatenated channel-wise.
        fea0 = F.interpolate(
            fea_list[0], input_shape[2:], mode='bilinear', align_corners=False)
        fea1 = F.interpolate(
            fea_list[1], input_shape[2:], mode='bilinear', align_corners=False)
        hrdb_input = paddle.concat([fea0, fea1], 1)
        hrdb_logit = self.hrdb(hrdb_input, scb_logits)
        detail_map = F.sigmoid(hrdb_logit)
        fusion = self.fusion(semantic_map, detail_map)

        if self.training:
            logit_dict = {
                'semantic': semantic_map,
                'detail': detail_map,
                'fusion': fusion
            }
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        else:
            return fusion

    def get_loss_func_dict(self):
        """Build the per-branch loss functions.

        Returns:
            defaultdict(list): 'semantic' -> [NLLLoss];
                'detail' -> [MRSD, GradientLoss] (alpha + gradient terms);
                'fusion' -> [MRSD, MRSD, GradientLoss]
                (alpha + composition + gradient terms).
        """
        loss_func_dict = defaultdict(list)
        loss_func_dict['semantic'].append(nn.NLLLoss())
        loss_func_dict['detail'].append(MRSD())
        loss_func_dict['detail'].append(GradientLoss())
        loss_func_dict['fusion'].append(MRSD())
        loss_func_dict['fusion'].append(MRSD())
        loss_func_dict['fusion'].append(GradientLoss())
        return loss_func_dict

    def loss(self, logit_dict, label_dict):
        """Compute per-branch losses and the combined 'all' loss.

        Args:
            logit_dict (dict): 'semantic', 'detail' and 'fusion' predictions
                from ``forward``.
            label_dict (dict): Must contain 'trimap' (gray values 0/128/255),
                'alpha', 'fg' and 'bg'.

        Returns:
            dict: Individual loss terms plus the weighted total under 'all'.
        """
        loss = {}

        # Semantic loss: convert trimap gray values to class indices.
        # trimap == 128 -> transition (class 1); trimap == 0 -> background
        # (class 2); everything else (255) stays foreground (class 0).
        semantic_label = label_dict['trimap']
        semantic_label_trans = (semantic_label == 128).astype('int64')
        semantic_label_bg = (semantic_label == 0).astype('int64')
        semantic_label = semantic_label_trans + semantic_label_bg * 2
        # NLLLoss expects log-probabilities; softmax was already applied in
        # forward(), and the 1e-6 guards against log(0).
        loss_semantic = self.loss_func_dict['semantic'][0](
            paddle.log(logit_dict['semantic'] + 1e-6),
            semantic_label.squeeze(1))
        loss['semantic'] = loss_semantic

        # Detail losses are restricted to the transition region of the trimap.
        transparent = label_dict['trimap'] == 128
        detail_alpha_loss = self.loss_func_dict['detail'][0](
            logit_dict['detail'], label_dict['alpha'], transparent)

        detail_gradient_loss = self.loss_func_dict['detail'][1](
            logit_dict['detail'], label_dict['alpha'], transparent)
        loss_detail = detail_alpha_loss + detail_gradient_loss
        loss['detail'] = loss_detail
        loss['detail_alpha'] = detail_alpha_loss
        loss['detail_gradient'] = detail_gradient_loss

        # Fusion losses over the full image: alpha, composition
        # (I = alpha * fg + (1 - alpha) * bg) and gradient.
        loss_fusion_func = self.loss_func_dict['fusion']

        fusion_alpha_loss = loss_fusion_func[0](logit_dict['fusion'],
                                                label_dict['alpha'])

        comp_pred = logit_dict['fusion'] * label_dict['fg'] + (
            1 - logit_dict['fusion']) * label_dict['bg']
        comp_gt = label_dict['alpha'] * label_dict['fg'] + (
            1 - label_dict['alpha']) * label_dict['bg']
        fusion_composition_loss = loss_fusion_func[1](comp_pred, comp_gt)

        fusion_grad_loss = loss_fusion_func[2](logit_dict['fusion'],
                                               label_dict['alpha'])

        loss_fusion = fusion_alpha_loss + fusion_composition_loss + fusion_grad_loss
        loss['fusion'] = loss_fusion
        loss['fusion_alpha'] = fusion_alpha_loss
        loss['fusion_composition'] = fusion_composition_loss
        loss['fusion_gradient'] = fusion_grad_loss

        loss[
            'all'] = 0.25 * loss_semantic + 0.25 * loss_detail + 0.25 * loss_fusion

        return loss

    def fusion(self, semantic_map, detail_map):
        """Merge the semantic and detail predictions into the final alpha.

        The per-pixel argmax class of ``semantic_map`` selects the source:
        detail prediction in the transition band, 1.0 on foreground,
        0.0 on background.
        """
        # Class per pixel: 0 foreground, 1 transition, 2 background.
        index = paddle.argmax(semantic_map, axis=1, keepdim=True)
        transition_mask = (index == 1).astype('float32')
        fg = (index == 0).astype('float32')
        alpha = detail_map * transition_mask + fg
        return alpha

    def init_weight(self):
        # Load a full pretrained checkpoint if a path was given.
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
|
|
|
|
|
class SCB(nn.Layer):
    """Semantic Context Branch.

    A pyramid-pooling module followed by six progressively upsampling stages.
    Each of the first five stages receives, concatenated to its input, either
    the raw PSP output (stage 0) or an upsampled 1x-conv projection of it
    (stages 1-4). The final stage emits a 3-channel logit.

    Args:
        in_channels (int): Channels of the deepest backbone feature.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = [512 + in_channels, 512, 256, 128, 128, 64]
        self.mid_channels = [512, 256, 128, 128, 64, 64]
        self.out_channels = [256, 128, 64, 64, 64, 3]

        self.psp_module = layers.PPModule(
            in_channels,
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)

        # Projections of the PSP output, upsampled to match stages 1-4.
        up_factors = (2, 4, 8, 16)
        self.psps = nn.LayerList([
            self.conv_up_psp(512, self.out_channels[idx], factor)
            for idx, factor in enumerate(up_factors)
        ])

        # Stages 0-4 are conv stacks (stage 0 uses dilation 2, the rest 1);
        # the last stage is a plain conv stack producing the 3-channel logit.
        stages = []
        for idx in range(len(self.in_channels) - 1):
            rate = 2 if idx == 0 else 1
            stages.append(
                self._make_stage(
                    self.in_channels[idx],
                    self.mid_channels[idx],
                    self.out_channels[idx],
                    padding=rate,
                    dilation=rate))
        stages.append(
            nn.Sequential(
                layers.ConvBNReLU(
                    self.in_channels[-1], self.mid_channels[-1], 3, padding=1),
                layers.ConvBNReLU(
                    self.mid_channels[-1], self.mid_channels[-1], 3,
                    padding=1),
                nn.Conv2D(
                    self.mid_channels[-1], self.out_channels[-1], 3,
                    padding=1)))
        self.scb_stages = nn.LayerList(stages)

    def forward(self, x):
        """Return the list of per-stage outputs; the last is the 3-channel logit."""
        psp_x = self.psp_module(x)
        # Guidance tensors concatenated to the first five stages' inputs:
        # the PSP output itself for stage 0, its upsampled projections after.
        guides = [psp_x] + [branch(psp_x) for branch in self.psps]

        scb_logits = []
        for idx, stage in enumerate(self.scb_stages):
            if idx < len(guides):
                x = stage(paddle.concat((guides[idx], x), 1))
            else:
                x = stage(x)
            scb_logits.append(x)
        return scb_logits

    def conv_up_psp(self, in_channels, out_channels, up_sample):
        """A 3x3 ConvBNReLU followed by bilinear upsampling by ``up_sample``."""
        project = layers.ConvBNReLU(in_channels, out_channels, 3, padding=1)
        upsample = nn.Upsample(
            scale_factor=up_sample, mode='bilinear', align_corners=False)
        return nn.Sequential(project, upsample)

    def _make_stage(self,
                    in_channels,
                    mid_channels,
                    out_channels,
                    padding=1,
                    dilation=1):
        """Three ConvBNReLUs (last two possibly dilated) plus 2x bilinear upsampling."""
        return nn.Sequential(
            layers.ConvBNReLU(
                in_channels, mid_channels, 3, padding=1),
            layers.ConvBNReLU(
                mid_channels,
                mid_channels,
                3,
                padding=padding,
                dilation=dilation),
            layers.ConvBNReLU(
                mid_channels,
                out_channels,
                3,
                padding=padding,
                dilation=dilation),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
|
|
|
|
|
class HRDB(nn.Layer):
    """
    The High-Resolution Detail Branch.

    Args:
        in_channels (int): The number of input channels.
        scb_channels (list|tuple): The channels of the SCB logits.
        gf_index (list|tuple, optional): Which SCB logits are selected as
            guidance flow. Default: (0, 2, 4).
    """

    def __init__(self, in_channels, scb_channels, gf_index=(0, 2, 4)):
        super().__init__()
        self.gf_index = gf_index
        # One 1x1 conv per selected SCB logit, squeezing it down to a
        # single guidance channel.
        self.gf_list = nn.LayerList(
            [nn.Conv2D(scb_channels[idx], 1, 1) for idx in gf_index])

        channels = [64, 32, 16, 8]
        blocks = [
            resnet_vd.BasicBlock(
                in_channels, channels[0], stride=1, shortcut=False)
        ]
        for width in channels[1:-1]:
            blocks.append(resnet_vd.BasicBlock(width, width, stride=1))
        self.res_list = nn.LayerList(blocks)

        # 1x1 convs narrowing the feature width between residual blocks.
        self.convs = nn.LayerList([
            nn.Conv2D(
                cin, cout, kernel_size=1)
            for cin, cout in zip(channels[:-1], channels[1:])
        ])
        self.gates = nn.LayerList(
            [GatedSpatailConv2d(width, width) for width in channels[1:]])

        self.detail_conv = nn.Conv2D(channels[-1], 1, 1, bias_attr=False)

    def forward(self, x, scb_logits):
        """Return the single-channel detail logit for ``x``."""
        pipeline = zip(self.res_list, self.convs, self.gf_list, self.gates,
                       self.gf_index)
        for res_block, squeeze, gf_conv, gate, logit_idx in pipeline:
            x = squeeze(res_block(x))
            # Resize the guidance map to the current feature resolution.
            guide = gf_conv(scb_logits[logit_idx])
            guide = F.interpolate(
                guide,
                paddle.shape(x)[-2:],
                mode='bilinear',
                align_corners=False)
            x = gate(x, guide)
        return self.detail_conv(x)
|
|
|
|
|
class GatedSpatailConv2d(nn.Layer):
    """Convolution whose input is modulated by a learned spatial gate.

    A per-pixel gate in (0, 1) is computed from the input features
    concatenated with a single-channel guidance map; the input is scaled by
    ``gate + 1`` (a multiplier in (1, 2)) before the main convolution.

    Args:
        in_channels (int): Channels of the input features.
        out_channels (int): Channels of the output.
        kernel_size/stride/padding/dilation/groups/bias_attr: Forwarded to
            the main ``nn.Conv2D``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias_attr=False):
        super().__init__()
        gated_channels = in_channels + 1  # features + 1 guidance channel
        self._gate_conv = nn.Sequential(
            layers.SyncBatchNorm(gated_channels),
            nn.Conv2D(
                gated_channels, gated_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2D(
                gated_channels, 1, kernel_size=1),
            layers.SyncBatchNorm(1),
            nn.Sigmoid())
        self.conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr)

    def forward(self, input_features, gating_features):
        """Gate ``input_features`` with ``gating_features``, then convolve."""
        stacked = paddle.concat([input_features, gating_features], axis=1)
        alphas = self._gate_conv(stacked)
        # Sigmoid keeps alphas in (0, 1), so the multiplier stays in (1, 2).
        gated = input_features * (alphas + 1)
        return self.conv(gated)
|
|