|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict |
|
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
from paddleseg.models import layers |
|
from paddleseg import utils |
|
from paddleseg.cvlibs import manager |
|
|
|
from ppmatting.models.losses import MRSD |
|
|
|
|
|
@manager.MODELS.add_component
class DIM(nn.Layer):
    """
    The DIM implementation based on PaddlePaddle.

    The original article refers to
    Ning Xu, et al. "Deep Image Matting"
    (https://arxiv.org/pdf/1703.03872.pdf).

    Args:
        backbone: Backbone model. It receives ``img`` concatenated with
            ``trimap / 255`` along the channel axis and must return a list of
            feature maps; entries 0-4 and the last entry are consumed by the
            decoder.
        stage (int, optional): The training stage of the model. Default: 3.
            1: encoder-decoder only, trained with the raw-alpha and
               composition losses.
            2: backbone and decoder are frozen; only the refinement head is
               trained, with the refined-alpha loss.
            3: everything is trained, with all three losses.
        decoder_input_channels (int, optional): The channel of decoder input.
            Default: 512.
        pretrained (str, optional): The path of pretrained model.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 stage=3,
                 decoder_input_channels=512,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.stage = stage
        # Created lazily by ``loss`` unless a dict is passed in there.
        self.loss_func_dict = None

        decoder_output_channels = [64, 128, 256, 512]
        self.decoder = Decoder(
            input_channels=decoder_input_channels,
            output_channels=decoder_output_channels)
        if self.stage == 2:
            # Stage 2 trains the refinement head only.
            for param in self.backbone.parameters():
                param.stop_gradient = True
            for param in self.decoder.parameters():
                param.stop_gradient = True
        if self.stage >= 2:
            self.refine = Refine()
        self.init_weight()

    def forward(self, inputs):
        """Predict an alpha matte.

        Args:
            inputs (dict): Must contain ``img`` and ``trimap``. During
                training it must also hold the labels read by ``loss``.

        Returns:
            In training mode, ``(logit_dict, loss_dict)`` for every stage.
            Otherwise the final alpha tensor: ``alpha_raw`` for stage < 2,
            or ``alpha_pred`` (clipped to [0, 1]) for stage >= 2.
        """
        input_shape = paddle.shape(inputs['img'])[-2:]
        x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
        fea_list = self.backbone(x)

        # Decoder stage: the decoder upsamples to the sizes of the first five
        # backbone feature maps.
        up_shape = [paddle.shape(fea_list[i])[-2:] for i in range(5)]
        alpha_raw = self.decoder(fea_list, up_shape)
        alpha_raw = F.interpolate(
            alpha_raw, input_shape, mode='bilinear', align_corners=False)
        logit_dict = {'alpha_raw': alpha_raw}

        if self.stage < 2:
            # No refinement head is built below stage 2; the decoder output is
            # the final prediction.
            alpha = alpha_raw
        else:
            # Refinement stage: predict a residual from the image + raw alpha.
            refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1)
            alpha_refine = self.refine(refine_input)
            alpha_pred = alpha_refine + alpha_raw
            alpha_pred = F.interpolate(
                alpha_pred, input_shape, mode='bilinear', align_corners=False)
            if not self.training:
                alpha_pred = paddle.clip(alpha_pred, min=0, max=1)
            logit_dict['alpha_pred'] = alpha_pred
            alpha = alpha_pred

        # Previously stage 1 returned ``logit_dict`` before this point, so it
        # never produced a loss in training and returned a dict at inference.
        if self.training:
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        return alpha

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """Compute the stage-dependent DIM training losses.

        Args:
            logit_dict (dict): Output of ``forward``. It holds ``alpha_raw``
                and, for stage >= 2, ``alpha_pred``.
            label_dict (dict): Must provide ``trimap`` and ``alpha``. Stages 1
                and 3 also read ``fg``, ``bg`` and ``img`` for the
                composition loss.
            loss_func_dict (dict, optional): Maps ``'alpha_raw'``, ``'comp'``
                and ``'alpha_pred'`` to lists whose first element is the loss
                callable. When given, it replaces ``self.loss_func_dict``.
                Otherwise MRSD losses are created once and cached.

        Returns:
            dict: The individual loss terms, plus their weighted sum under
            ``'all'``.
        """
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['alpha_raw'].append(MRSD())
                self.loss_func_dict['comp'].append(MRSD())
                self.loss_func_dict['alpha_pred'].append(MRSD())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}
        # Losses are only evaluated where the trimap equals 128 (the unknown
        # band under the usual 0/128/255 trimap convention).
        mask = label_dict['trimap'] == 128
        loss['all'] = 0

        if self.stage != 2:
            loss['alpha_raw'] = self.loss_func_dict['alpha_raw'][0](
                logit_dict['alpha_raw'], label_dict['alpha'], mask)
            loss['alpha_raw'] = 0.5 * loss['alpha_raw']
            loss['all'] = loss['all'] + loss['alpha_raw']

        if self.stage == 1 or self.stage == 3:
            # Re-composite the image with the predicted alpha and compare it
            # with the real image.
            comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \
                (1 - logit_dict['alpha_raw']) * label_dict['bg']
            loss['comp'] = self.loss_func_dict['comp'][0](
                comp_pred, label_dict['img'], mask)
            loss['comp'] = 0.5 * loss['comp']
            loss['all'] = loss['all'] + loss['comp']

        if self.stage == 2 or self.stage == 3:
            loss['alpha_pred'] = self.loss_func_dict['alpha_pred'][0](
                logit_dict['alpha_pred'], label_dict['alpha'], mask)
            loss['all'] = loss['all'] + loss['alpha_pred']

        return loss

    def init_weight(self):
        """Load the full model from ``self.pretrained`` when it is set."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
|
|
|
|
|
|
|
class Up(nn.Layer):
    """Upsample a feature map, add a skip connection, then convolve.

    Args:
        input_channels (int): Channels of the incoming feature map. The skip
            feature is summed with it elementwise, so its channels must be
            compatible.
        output_channels (int): Channels produced by the 5x5 ConvBNReLU.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = layers.ConvBNReLU(
            input_channels,
            output_channels,
            kernel_size=5,
            padding=2,
            bias_attr=False)

    def forward(self, x, skip, output_shape):
        """Resize ``x`` to ``output_shape``, add ``skip``, and apply the conv."""
        x = F.interpolate(
            x, size=output_shape, mode='bilinear', align_corners=False)
        x = x + skip
        # ConvBNReLU already ends with a ReLU. The former trailing F.relu was
        # an idempotent no-op and has been removed.
        return self.conv(x)
|
|
|
|
|
class Decoder(nn.Layer):
    """Decoder that turns backbone features into a single-channel alpha map.

    Args:
        input_channels (int): Channels of the last backbone feature map.
        output_channels (sequence of int): Output widths of the ``Up`` stages,
            listed from shallowest to deepest. Default: (64, 128, 256, 512).
    """

    def __init__(self, input_channels, output_channels=(64, 128, 256, 512)):
        super().__init__()
        width_deep, width_3, width_2, width_shallow = (
            output_channels[-1], output_channels[-2], output_channels[-3],
            output_channels[-4])

        # Attribute names are kept as-is so checkpoint keys stay stable.
        self.deconv6 = nn.Conv2D(
            input_channels, input_channels, kernel_size=1, bias_attr=False)
        self.deconv5 = Up(input_channels, width_deep)
        self.deconv4 = Up(width_deep, width_3)
        self.deconv3 = Up(width_3, width_2)
        self.deconv2 = Up(width_2, width_shallow)
        self.deconv1 = Up(width_shallow, 64)
        self.alpha_conv = nn.Conv2D(
            64, 1, kernel_size=5, padding=2, bias_attr=False)

    def forward(self, fea_list, shape_list):
        """Decode ``fea_list`` to an alpha map passed through a sigmoid.

        ``fea_list[-1]`` is the starting feature. ``fea_list[4]`` down to
        ``fea_list[0]`` are the skip inputs, and ``shape_list[4]`` down to
        ``shape_list[0]`` are the matching target sizes.
        """
        out = self.deconv6(fea_list[-1])
        stages = (self.deconv5, self.deconv4, self.deconv3, self.deconv2,
                  self.deconv1)
        for level, stage in zip(range(4, -1, -1), stages):
            out = stage(out, fea_list[level], shape_list[level])
        return F.sigmoid(self.alpha_conv(out))
|
|
|
|
|
class Refine(nn.Layer):
    """Refinement head: four 3x3 ConvBNReLU layers mapping 4 channels to 1.

    NOTE(review): the final layer is also a ConvBNReLU. Assuming it ends in a
    ReLU, the predicted residual can never be negative. The DIM paper's last
    refinement layer has no activation, but switching layer types here would
    change checkpoint keys, so it is left unchanged.
    """

    def __init__(self):
        super().__init__()

        def _conv3x3(in_ch, out_ch):
            # Every layer shares the same 3x3 / pad-1 / bias-free configuration.
            return layers.ConvBNReLU(
                in_ch, out_ch, kernel_size=3, padding=1, bias_attr=False)

        self.conv1 = _conv3x3(4, 64)
        self.conv2 = _conv3x3(64, 64)
        self.conv3 = _conv3x3(64, 64)
        self.alpha_pred = _conv3x3(64, 1)

    def forward(self, x):
        """Run the three hidden layers, then the single-channel output layer."""
        for hidden in (self.conv1, self.conv2, self.conv3):
            x = hidden(x)
        return self.alpha_pred(x)
|
|