vivym's picture
init
4a582ec
raw
history blame
No virus
8.01 kB
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The gca code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting
# and https://github.com/open-mmlab/mmediting
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import param_init
class GuidedCxtAtten(nn.Layer):
    """Guided contextual attention module.

    Image (guidance) features are used to compute patch-wise similarity;
    the resulting attention scores propagate alpha-feature patches. The
    propagated features pass through a 1x1 conv + BatchNorm and are added
    back to ``alpha_feat`` as a residual.

    Args:
        out_channels (int): Number of channels of ``alpha_feat`` (and of the
            module output).
        guidance_channels (int): Number of channels of ``img_feat``; halved
            by a 1x1 conv before similarity is computed.
        kernel_size (int, optional): Patch size used to compare image
            features. Default: 3.
        stride (int, optional): Stride used when extracting image-feature and
            unknown-mask patches. Default: 1.
            NOTE(review): the self-correlation mask assumes one patch per
            position (h * w patches), which only holds for stride 1 — confirm
            before using other values.
        rate (int, optional): Downsampling rate applied to the image feature
            and unknown mask; alpha patches of size ``2 * rate`` are extracted
            with stride ``rate``. Default: 2.
    """

    def __init__(self,
                 out_channels,
                 guidance_channels,
                 kernel_size=3,
                 stride=1,
                 rate=2):
        super().__init__()

        self.kernel_size = kernel_size
        self.rate = rate
        self.stride = stride

        # 1x1 conv halving the guidance channels.
        self.guidance_conv = nn.Conv2D(
            in_channels=guidance_channels,
            out_channels=guidance_channels // 2,
            kernel_size=1)

        # Projection of the propagated alpha features before the residual add.
        self.out_conv = nn.Sequential(
            nn.Conv2D(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias_attr=False),
            nn.BatchNorm(out_channels))

        self.init_weight()

    def init_weight(self):
        """Initialize conv weights (Xavier) and BatchNorm/bias constants."""
        param_init.xavier_uniform(self.guidance_conv.weight)
        param_init.constant_init(self.guidance_conv.bias, value=0.0)
        param_init.xavier_uniform(self.out_conv[0].weight)
        # Small BN weight so the attention branch starts close to zero and the
        # module initially behaves like the identity on ``alpha_feat``.
        param_init.constant_init(self.out_conv[1].weight, value=1e-3)
        param_init.constant_init(self.out_conv[1].bias, value=0.0)

    def forward(self, img_feat, alpha_feat, unknown=None, softmax_scale=1.):
        """Apply guided contextual attention.

        Args:
            img_feat (Tensor): Guidance feature map (NCHW, ``guidance_channels``
                channels).
            alpha_feat (Tensor): Alpha feature map (NCHW, ``out_channels``
                channels).
            unknown (Tensor, optional): Single-channel unknown-region mask;
                positions whose downsampled patch mean is > 0 count as
                unknown. When None, every position is treated as unknown.
            softmax_scale (float, optional): Similarity scale for both areas,
                used only when ``unknown`` is None. Default: 1.

        Returns:
            Tensor: Tensor with the same shape as ``alpha_feat``.
        """
        img_feat = self.guidance_conv(img_feat)
        img_feat = F.interpolate(
            img_feat, scale_factor=1 / self.rate, mode='nearest')

        # process unknown mask
        unknown, softmax_scale = self.process_unknown_mask(unknown, img_feat,
                                                           softmax_scale)

        img_ps, alpha_ps, unknown_ps = self.extract_feature_maps_patches(
            img_feat, alpha_feat, unknown)

        self_mask = self.get_self_correlation_mask(img_feat)

        # Split every tensor into per-sample groups along the batch axis.
        # paddle.split's int argument is the NUMBER of sections (unlike
        # torch.split's chunk size), so the batch size must be passed here;
        # passing 1 would keep the whole batch in one group and every sample
        # would be matched against the first sample's patches.
        batch_size = img_feat.shape[0]
        groups = [
            paddle.split(t, batch_size, axis=0)
            for t in (img_feat, img_ps, alpha_ps, unknown_ps, softmax_scale)
        ]

        y = []
        for img_i, img_ps_i, alpha_ps_i, unknown_ps_i, scale_i in zip(*groups):
            # correlate this sample's features with its own patches
            similarity_map = self.compute_similarity_map(img_i, img_ps_i)
            gca_score = self.compute_guided_attention_score(
                similarity_map, unknown_ps_i, scale_i, self_mask)
            y.append(self.propagate_alpha_feature(gca_score, alpha_ps_i))

        y = paddle.concat(y, axis=0)  # back to the mini-batch
        y = paddle.reshape(y, alpha_feat.shape)

        return self.out_conv(y) + alpha_feat

    def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown):
        """Extract image, alpha and unknown-mask patches.

        Returns:
            tuple: ``(img_ps, alpha_ps, unknown_ps)`` where
            ``img_ps`` is (N, L, img_c, kernel_size, kernel_size),
            ``alpha_ps`` is (N, L', alpha_c, 2 * rate, 2 * rate) and
            ``unknown_ps`` is (N, L, 1, 1) holding each patch's mean.
        """
        img_ks = self.kernel_size
        img_ps = self.extract_patches(img_feat, img_ks, self.stride)

        # Alpha patches are 2*rate wide with stride rate, matching the
        # transposed convolution used in ``propagate_alpha_feature``.
        alpha_ps = self.extract_patches(alpha_feat, self.rate * 2, self.rate)

        unknown_ps = self.extract_patches(unknown, img_ks, self.stride)
        unknown_ps = unknown_ps.squeeze(axis=2)  # squeeze channel dimension
        unknown_ps = unknown_ps.mean(axis=[2, 3], keepdim=True)

        return img_ps, alpha_ps, unknown_ps

    def extract_patches(self, x, kernel_size, stride):
        """Reflect-pad ``x`` and unfold it into patches.

        Returns:
            Tensor: Shape (N, num_patches, C, kernel_size, kernel_size).
        """
        n, c, _, _ = x.shape
        x = self.pad(x, kernel_size, stride)
        x = F.unfold(x, [kernel_size, kernel_size], strides=[stride, stride])
        x = paddle.transpose(x, (0, 2, 1))
        return paddle.reshape(x, (n, -1, c, kernel_size, kernel_size))

    def pad(self, x, kernel_size, stride):
        """Reflect-pad the last two axes of ``x``.

        The total padding per axis is ``kernel_size - stride`` (split with the
        extra pixel on the left/top), so with ``stride=1`` a following
        ``kernel_size`` convolution preserves the spatial size.
        """
        left = (kernel_size - stride + 1) // 2
        right = (kernel_size - stride) // 2
        pad = (left, right, left, right)
        return F.pad(x, pad, mode='reflect')

    def compute_guided_attention_score(self, similarity_map, unknown_ps, scale,
                                       self_mask):
        """Scale the similarity map per area, mask self-matches, and softmax.

        Args:
            similarity_map (Tensor): Output of ``compute_similarity_map``.
            unknown_ps (Tensor): Per-patch unknown-mask means for one sample.
            scale (Tensor): (1, 2) tensor of ``[unknown_scale, known_scale]``.
            self_mask (Tensor): Output of ``get_self_correlation_mask``.

        Returns:
            Tensor: Attention scores, softmax-normalized over axis 1.
        """
        unknown_scale, known_scale = scale[0]

        # Explicit float masks for the unknown (> 0) and known (<= 0) areas,
        # instead of relying on implicit bool -> float promotion.
        zero = paddle.zeros([1], dtype=unknown_ps.dtype)
        unknown_area = paddle.greater_than(unknown_ps, zero).astype(
            similarity_map.dtype)
        known_area = paddle.less_equal(unknown_ps, zero).astype(
            similarity_map.dtype)

        out = similarity_map * (
            unknown_scale * unknown_area + known_scale * known_area)

        # mask itself, self-mask only applied to unknown area
        out = out + self_mask * unknown_ps
        return F.softmax(out, axis=1)

    def propagate_alpha_feature(self, gca_score, alpha_ps):
        """Reconstruct alpha features from attention scores and alpha patches.

        Args:
            gca_score (Tensor): Attention scores for one sample.
            alpha_ps (Tensor): Alpha patches for one sample, (1, L', C, k, k).

        Returns:
            Tensor: Propagated alpha features, averaged by the constant 4.
        """
        alpha_ps = alpha_ps[0]  # squeeze dim 0
        if self.rate == 1:
            # No upsampling needed: regular conv with 2x2 alpha patches.
            gca_score = self.pad(gca_score, kernel_size=2, stride=1)
            alpha_ps = paddle.transpose(alpha_ps, (1, 0, 2, 3))
            out = F.conv2d(gca_score, alpha_ps) / 4.
        else:
            # Transposed conv pastes each 2*rate patch back at stride rate.
            out = F.conv2d_transpose(
                gca_score, alpha_ps, stride=self.rate, padding=1) / 4.
        return out

    def compute_similarity_map(self, img_feat, img_ps):
        """Correlate ``img_feat`` with its own L2-normalized patches.

        Args:
            img_feat (Tensor): Feature map of one sample, (1, C, h, w).
            img_ps (Tensor): Patches of that sample, (1, L, C, k, k).

        Returns:
            Tensor: Similarity map, (1, L, h, w).
        """
        img_ps = img_ps[0]  # squeeze dim 0
        # Clamp the norm to avoid division by zero for all-zero patches.
        img_ps_normed = img_ps / paddle.clip(self.l2_norm(img_ps), min=1e-4)
        # Pad with the same rule used for patch extraction so the output keeps
        # the h x w size for any kernel_size (a fixed 1-pixel pad is only
        # correct for kernel_size=3).
        img_feat = self.pad(img_feat, self.kernel_size, 1)
        return F.conv2d(img_feat, img_ps_normed)

    def get_self_correlation_mask(self, img_feat):
        """Build a (1, h*w, h, w) mask with -1e4 where a patch meets itself.

        Channel ``i`` is -1e4 at spatial position ``i`` (row-major) and 0
        elsewhere; added before the softmax, it suppresses self-matches.
        """
        _, _, h, w = img_feat.shape
        self_mask = F.one_hot(
            paddle.reshape(paddle.arange(h * w), (h, w)),
            num_classes=int(h * w))
        self_mask = paddle.transpose(self_mask, (2, 0, 1))
        self_mask = paddle.reshape(self_mask, (1, h * w, h, w))
        return self_mask * (-1e4)

    def process_unknown_mask(self, unknown, img_feat, softmax_scale):
        """Downsample the unknown mask and derive per-sample softmax scales.

        Args:
            unknown (Tensor | None): Unknown-region mask, or None.
            img_feat (Tensor): Downsampled guidance features (for N, h, w).
            softmax_scale (float): Scale used when ``unknown`` is None.

        Returns:
            tuple: ``(unknown, softmax_scale)`` where ``softmax_scale`` is an
            (N, 2) tensor of ``[unknown_scale, known_scale]``. With a mask,
            the scales are ``sqrt(unknown_mean / known_mean)`` and its
            reciprocal, clipped to [0.1, 10].
        """
        n, _, h, w = img_feat.shape

        if unknown is not None:
            # F.interpolate returns a new tensor, so the caller's mask is
            # never modified (no defensive clone needed).
            unknown = F.interpolate(
                unknown, scale_factor=1 / self.rate, mode='nearest')
            unknown_mean = unknown.mean(axis=[2, 3])
            known_mean = 1 - unknown_mean
            # Clipping also bounds the inf/0 values produced when one of the
            # areas is empty.
            unknown_scale = paddle.clip(
                paddle.sqrt(unknown_mean / known_mean), 0.1, 10)
            known_scale = paddle.clip(
                paddle.sqrt(known_mean / unknown_mean), 0.1, 10)
            softmax_scale = paddle.concat([unknown_scale, known_scale], axis=1)
        else:
            unknown = paddle.ones([n, 1, h, w])
            softmax_scale = paddle.reshape(
                paddle.to_tensor([softmax_scale, softmax_scale]), (1, 2))
            softmax_scale = paddle.expand(softmax_scale, (n, 2))

        return unknown, softmax_scale

    @staticmethod
    def l2_norm(x):
        """Return the L2 norm of ``x`` over axes 1-3, keeping dimensions."""
        x = x**2
        x = x.sum(axis=[1, 2, 3], keepdim=True)
        return paddle.sqrt(x)