# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The GCA code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting
# and https://github.com/open-mmlab/mmediting

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import param_init


class GuidedCxtAtten(nn.Layer):
    def __init__(self,
                 out_channels,
                 guidance_channels,
                 kernel_size=3,
                 stride=1,
                 rate=2):
        super().__init__()

        self.kernel_size = kernel_size
        self.rate = rate
        self.stride = stride
        self.guidance_conv = nn.Conv2D(
            in_channels=guidance_channels,
            out_channels=guidance_channels // 2,
            kernel_size=1)

        self.out_conv = nn.Sequential(
            nn.Conv2D(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias_attr=False),
            nn.BatchNorm(out_channels))

        self.init_weight()

    def init_weight(self):
        param_init.xavier_uniform(self.guidance_conv.weight)
        param_init.constant_init(self.guidance_conv.bias, value=0.0)
        param_init.xavier_uniform(self.out_conv[0].weight)
        param_init.constant_init(self.out_conv[1].weight, value=1e-3)
        param_init.constant_init(self.out_conv[1].bias, value=0.0)

    def forward(self, img_feat, alpha_feat, unknown=None, softmax_scale=1.):
        img_feat = self.guidance_conv(img_feat)
        img_feat = F.interpolate(
            img_feat, scale_factor=1 / self.rate, mode='nearest')

        # process unknown mask
        unknown, softmax_scale = self.process_unknown_mask(
            unknown, img_feat, softmax_scale)

        img_ps, alpha_ps, unknown_ps = self.extract_feature_maps_patches(
            img_feat, alpha_feat, unknown)

        self_mask = self.get_self_correlation_mask(img_feat)

        # Split the tensors along the batch dimension into per-sample chunks.
        # Note: the second argument of paddle.split is the *number* of
        # sections, so the batch size is passed here to get chunks of size 1
        # (passing 1 would return the whole batch unsplit).
        n = img_feat.shape[0]
        img_groups = paddle.split(img_feat, n, axis=0)
        img_ps_groups = paddle.split(img_ps, n, axis=0)
        alpha_ps_groups = paddle.split(alpha_ps, n, axis=0)
        unknown_ps_groups = paddle.split(unknown_ps, n, axis=0)
        scale_groups = paddle.split(softmax_scale, n, axis=0)
        groups = (img_groups, img_ps_groups, alpha_ps_groups,
                  unknown_ps_groups, scale_groups)

        y = []
        for img_i, img_ps_i, alpha_ps_i, unknown_ps_i, scale_i in zip(*groups):
            # correlate guidance patches with the guidance feature map
            similarity_map = self.compute_similarity_map(img_i, img_ps_i)

            gca_score = self.compute_guided_attention_score(
                similarity_map, unknown_ps_i, scale_i, self_mask)

            yi = self.propagate_alpha_feature(gca_score, alpha_ps_i)
            y.append(yi)

        y = paddle.concat(y, axis=0)  # back to the mini-batch
        y = paddle.reshape(y, alpha_feat.shape)

        y = self.out_conv(y) + alpha_feat

        return y

    def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown):
        # extract image feature patches with shape:
        # (N, img_h*img_w, img_c, img_ks, img_ks)
        img_ks = self.kernel_size
        img_ps = self.extract_patches(img_feat, img_ks, self.stride)

        # extract alpha feature patches with shape:
        # (N, img_h*img_w, alpha_c, alpha_ks, alpha_ks)
        alpha_ps = self.extract_patches(alpha_feat, self.rate * 2, self.rate)

        # extract unknown mask patches with shape: (N, img_h*img_w, 1, 1)
        unknown_ps = self.extract_patches(unknown, img_ks, self.stride)
        unknown_ps = unknown_ps.squeeze(axis=2)  # squeeze channel dimension
        unknown_ps = unknown_ps.mean(axis=[2, 3], keepdim=True)

        return img_ps, alpha_ps, unknown_ps

    def extract_patches(self, x, kernel_size, stride):
        n, c, _, _ = x.shape
        x = self.pad(x, kernel_size, stride)
        x = F.unfold(x, [kernel_size, kernel_size], strides=[stride, stride])
        x = paddle.transpose(x, (0, 2, 1))
        x = paddle.reshape(x, (n, -1, c, kernel_size, kernel_size))

        return x

    def pad(self, x, kernel_size, stride):
        left = (kernel_size - stride + 1) // 2
        right = (kernel_size - stride) // 2
        pad = (left, right, left, right)
        return F.pad(x, pad, mode='reflect')

    def compute_guided_attention_score(self, similarity_map, unknown_ps, scale,
                                       self_mask):
        # scale the correlation with the predicted scale factors for the
        # known and unknown areas
        unknown_scale, known_scale = scale[0]
        out = similarity_map * (
            unknown_scale * paddle.greater_than(unknown_ps,
                                                paddle.to_tensor([0.])) +
            known_scale * paddle.less_equal(unknown_ps, paddle.to_tensor([0.])))
        # mask itself; the self-mask is only applied to the unknown area
        out = out + self_mask * unknown_ps
        gca_score = F.softmax(out, axis=1)

        return gca_score

    def propagate_alpha_feature(self, gca_score, alpha_ps):
        alpha_ps = alpha_ps[0]  # squeeze dim 0
        if self.rate == 1:
            gca_score = self.pad(gca_score, kernel_size=2, stride=1)
            alpha_ps = paddle.transpose(alpha_ps, (1, 0, 2, 3))
            out = F.conv2d(gca_score, alpha_ps) / 4.
        else:
            out = F.conv2d_transpose(
                gca_score, alpha_ps, stride=self.rate, padding=1) / 4.

        return out

    def compute_similarity_map(self, img_feat, img_ps):
        img_ps = img_ps[0]  # squeeze dim 0

        # convolve the feature map with L2-normalized patches to get the
        # correlation (similarity) map
        img_ps_normed = img_ps / paddle.clip(self.l2_norm(img_ps), 1e-4)
        img_feat = F.pad(img_feat, (1, 1, 1, 1), mode='reflect')
        similarity_map = F.conv2d(img_feat, img_ps_normed)

        return similarity_map

    def get_self_correlation_mask(self, img_feat):
        _, _, h, w = img_feat.shape
        self_mask = F.one_hot(
            paddle.reshape(paddle.arange(h * w), (h, w)),
            num_classes=int(h * w))
        self_mask = paddle.transpose(self_mask, (2, 0, 1))
        self_mask = paddle.reshape(self_mask, (1, h * w, h, w))

        return self_mask * (-1e4)

    def process_unknown_mask(self, unknown, img_feat, softmax_scale):
        n, _, h, w = img_feat.shape

        if unknown is not None:
            unknown = unknown.clone()
            unknown = F.interpolate(
                unknown, scale_factor=1 / self.rate, mode='nearest')
            unknown_mean = unknown.mean(axis=[2, 3])
            known_mean = 1 - unknown_mean
            unknown_scale = paddle.clip(
                paddle.sqrt(unknown_mean / known_mean), 0.1, 10)
            known_scale = paddle.clip(
                paddle.sqrt(known_mean / unknown_mean), 0.1, 10)
            softmax_scale = paddle.concat([unknown_scale, known_scale], axis=1)
        else:
            unknown = paddle.ones([n, 1, h, w])
            softmax_scale = paddle.reshape(
                paddle.to_tensor([softmax_scale, softmax_scale]), (1, 2))
            softmax_scale = paddle.expand(softmax_scale, (n, 2))

        return unknown, softmax_scale

    @staticmethod
    def l2_norm(x):
        x = x**2
        x = x.sum(axis=[1, 2, 3], keepdim=True)
        return paddle.sqrt(x)
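

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an editorial addition, not part of the original
# module). It assumes the usual GCA-Matting setup: guidance (image) features
# and alpha features share the same spatial size, the unknown mask is a
# binary float tensor at that same size, and with the default rate=2 the
# module internally downsamples the guidance by half before matching patches.
# All shapes and values below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    paddle.seed(42)

    gca = GuidedCxtAtten(out_channels=128, guidance_channels=512)

    img_feat = paddle.rand([2, 512, 32, 32])    # guidance features (N, C, H, W)
    alpha_feat = paddle.rand([2, 128, 32, 32])  # alpha features at the same H, W
    unknown = (paddle.rand([2, 1, 32, 32]) > 0.5).astype('float32')

    out = gca(img_feat, alpha_feat, unknown=unknown)
    print(out.shape)  # expected: [2, 128, 32, 32]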