Spaces:
Sleeping
Sleeping
import functools | |
import torch | |
import torch.nn as nn | |
from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder | |
def convert_flow_to_deformation(flow):
    r"""Turn a pixel-unit flow field into a sampling grid for ``grid_sample``.

    Channel 0 of ``flow`` is divided by ``(width - 1)`` and the remaining
    channel(s) by ``(height - 1)``. Both are then doubled, which rescales pixel
    offsets onto the normalized [-1, 1] coordinate span. The result is added to
    an identity coordinate grid of the same spatial size.

    Args:
        flow (tensor): Flow field obtained by the model, unpacked as
            ``(batch, channel, height, width)``. Channel 0 is the horizontal
            offset and channel 1 the vertical offset.

    Returns:
        deformation (tensor): Sampling grid laid out as
            ``(batch, height, width, 2)``, used for warping.
    """
    _, _, height, width = flow.shape
    x_offset = flow[:, :1, ...] / (width - 1)
    y_offset = flow[:, 1:, ...] / (height - 1)
    # The full width/height spans 2 units in normalized coordinates.
    normalized = 2 * torch.cat([x_offset, y_offset], 1)
    identity = make_coordinate_grid(flow)
    # Move the offset channels last to line up with the grid's xy axis.
    return identity + normalized.permute(0, 2, 3, 1)
def make_coordinate_grid(flow):
    r"""Build an identity sampling grid matching the spatial size of ``flow``.

    Positions ``0 .. size-1`` along each axis are mapped linearly onto
    [-1, 1], so the first column/row sits at -1 and the last at +1.
    ``grid[..., 0]`` holds x (varies along width) and ``grid[..., 1]`` holds
    y (varies along height).

    NOTE: for an axis of size 1 the mapping divides 0 by 0, which yields NaN
    for floating-point ``flow``.

    Args:
        flow (tensor): Flow field obtained by the model. Only its shape,
            dtype and device are used.

    Returns:
        grid (tensor): Grid of shape ``(batch, height, width, 2)`` with the
            dtype/device of ``flow``. The batch axis is an ``expand`` view,
            so all batch entries share the same storage.
    """
    batch, _, height, width = flow.shape

    def _axis(size):
        # Evenly spaced indices converted to flow's dtype/device, mapped to [-1, 1].
        positions = torch.arange(size).to(flow)
        return 2 * (positions / (size - 1)) - 1

    xs = _axis(width)
    ys = _axis(height)
    xx = xs.view(1, -1).repeat(height, 1)
    yy = ys.view(-1, 1).repeat(1, width)
    grid = torch.stack([xx, yy], dim=2)
    return grid.expand(batch, -1, -1, -1)
def warp_image(source_image, deformation):
    r"""Warp ``source_image`` by sampling it at the locations in ``deformation``.

    If the deformation's spatial size differs from the image's, the
    deformation is first bilinearly resized to the image size. Its values are
    normalized coordinates, so only the grid resolution changes, not the
    value range.

    Args:
        source_image (tensor): Images to be warped, unpacked as
            ``(batch, channel, height, width)``.
        deformation (tensor): Sampling grid unpacked as
            ``(batch, height, width, 2)``; values in range (-1, 1).

    Returns:
        output (tensor): The warped images, sampled with ``grid_sample``'s
            default mode and padding, at the spatial size of ``source_image``.
    """
    _, grid_h, grid_w, _ = deformation.shape
    _, _, img_h, img_w = source_image.shape
    if (grid_h, grid_w) != (img_h, img_w):
        # interpolate expects channel-first input: move the xy axis to dim 1,
        # resize, then move it back to the last axis.
        deformation = torch.nn.functional.interpolate(
            deformation.permute(0, 3, 1, 2),
            size=(img_h, img_w),
            mode='bilinear',
            align_corners=False,
        ).permute(0, 2, 3, 1)
    # align_corners=False is PyTorch's implicit default (since 1.3). Passing it
    # explicitly keeps the output unchanged while silencing the UserWarning
    # grid_sample emits on every call when the argument is omitted.
    # NOTE(review): with align_corners=False, grid values of -1/+1 address the
    # outer edges of the corner pixels. An identity grid spanning the corner
    # pixel *centres* would need align_corners=True. Switching would change
    # outputs that trained weights may depend on, so confirm before changing.
    return torch.nn.functional.grid_sample(
        source_image, deformation, align_corners=False)
class FaceGenerator(nn.Module):
    r"""Generator that maps a driving source to a descriptor, then warps and edits the image.

    ``MappingNet`` turns ``driving_source`` into a descriptor. ``WarpingNet``
    uses the input image and that descriptor to produce a warped image.
    ``EditingNet`` then produces the final image from the input image, the
    warped image and the descriptor.

    Args:
        mapping_net (dict): Keyword arguments for ``MappingNet``.
        warpping_net (dict): Keyword arguments for ``WarpingNet``, combined
            with ``common``. The spelling matches the attribute name, which
            is also the state_dict key prefix.
        editing_net (dict): Keyword arguments for ``EditingNet``, combined
            with ``common``.
        common (dict): Keyword arguments shared by the warping and editing
            networks.
    """

    def __init__(
        self,
        mapping_net,
        warpping_net,
        editing_net,
        common
    ):
        super().__init__()
        self.mapping_net = MappingNet(**mapping_net)
        self.warpping_net = WarpingNet(**warpping_net, **common)
        self.editing_net = EditingNet(**editing_net, **common)

    def forward(
        self,
        input_image,
        driving_source,
        stage=None
    ):
        r"""Run the generator.

        Args:
            input_image (tensor): Source image.
            driving_source (tensor): Input to ``MappingNet`` (presumably 3DMM
                coefficients, given MappingNet's ``input_3dmm`` parameter;
                TODO confirm).
            stage (str, optional): ``'warp'`` runs only the mapping and
                warping networks. Any other value, including ``None``, also
                runs the editing network.

        Returns:
            dict: The dict produced by ``WarpingNet``. Unless
            ``stage == 'warp'``, it also holds ``'fake_image'`` from
            ``EditingNet``.
        """
        descriptor = self.mapping_net(driving_source)
        result = self.warpping_net(input_image, descriptor)
        if stage != 'warp':
            result['fake_image'] = self.editing_net(
                input_image, result['warp_image'], descriptor)
        return result
class MappingNet(nn.Module):
    r"""Encode a coefficient sequence into a single descriptor vector.

    A kernel-7 ``Conv1d`` lifts ``coeff_nc`` channels to ``descriptor_nc``.
    It is followed by ``layer`` residual blocks (LeakyReLU(0.1) then a
    kernel-3, dilation-3 ``Conv1d``), each of which shortens the sequence by
    6. The result is average-pooled over the sequence axis.

    Args:
        coeff_nc (int): Number of input channels.
        descriptor_nc (int): Number of descriptor channels.
        layer (int): Number of residual blocks.
    """

    def __init__(self, coeff_nc, descriptor_nc, layer):
        super().__init__()
        self.layer = layer
        # A single activation module is shared by every residual block.
        act = nn.LeakyReLU(0.1)
        # Wrapped in nn.Sequential so the parameter names stay 'first.0.*'.
        self.first = nn.Sequential(
            nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
        for idx in range(layer):
            # Registered as 'encoder0', 'encoder1', ...; these names are the
            # state_dict key prefixes, so they must not change.
            self.add_module(
                'encoder' + str(idx),
                nn.Sequential(
                    act,
                    nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3),
                ),
            )
        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        r"""Compute the descriptor.

        Args:
            input_3dmm (tensor): ``(batch, coeff_nc, length)``. ``length``
                must be at least ``7 + 6 * layer`` so the last block still
                has output.

        Returns:
            tensor: ``(batch, descriptor_nc, 1)``.
        """
        hidden = self.first(input_3dmm)
        for idx in range(self.layer):
            block = getattr(self, 'encoder' + str(idx))
            # The unpadded dilated conv drops 3 steps at each end; crop the
            # skip path to match.
            hidden = block(hidden) + hidden[:, :, 3:-3]
        return self.pooling(hidden)
class WarpingNet(nn.Module):
    r"""Predict a flow field from image + descriptor and warp the image with it.

    An ``ADAINHourglass`` conditioned on the descriptor produces features.
    A head (``LayerNorm2d`` -> LeakyReLU(0.1) -> 7x7 ``Conv2d``) maps those
    features to a 2-channel flow field. The flow is converted to a sampling
    grid and applied to the input image.

    Args:
        image_nc (int): Channels of the input image (passed to the hourglass).
        descriptor_nc (int): Channels of the descriptor (passed to the hourglass).
        base_nc (int): Passed to ``ADAINHourglass``.
        max_nc (int): Passed to ``ADAINHourglass``.
        encoder_layer (int): Passed to ``ADAINHourglass``.
        decoder_layer (int): Passed to ``ADAINHourglass``.
        use_spect (bool): Passed to ``ADAINHourglass``.
    """

    def __init__(
        self,
        image_nc,
        descriptor_nc,
        base_nc,
        max_nc,
        encoder_layer,
        decoder_layer,
        use_spect
    ):
        super().__init__()
        # The same activation instance is used inside the hourglass and the flow head.
        act = nn.LeakyReLU(0.1)
        self.descriptor_nc = descriptor_nc
        self.hourglass = ADAINHourglass(
            image_nc, descriptor_nc, base_nc, max_nc,
            encoder_layer, decoder_layer,
            nonlinearity=act, use_spect=use_spect)
        feat_nc = self.hourglass.output_nc
        self.flow_out = nn.Sequential(
            LayerNorm2d(feat_nc, affine=True),
            act,
            nn.Conv2d(feat_nc, 2, kernel_size=7, stride=1, padding=3),
        )
        # Not used by forward(); kept so the module structure is unchanged.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        r"""Predict the flow field and warp ``input_image`` with it.

        Args:
            input_image (tensor): Image to warp.
            descriptor (tensor): Conditioning descriptor (presumably the
                ``MappingNet`` output; TODO confirm).

        Returns:
            dict: ``'flow_field'`` holds the 2-channel flow from the head;
            ``'warp_image'`` holds ``input_image`` warped by that flow.
        """
        features = self.hourglass(input_image, descriptor)
        flow = self.flow_out(features)
        warped = warp_image(input_image, convert_flow_to_deformation(flow))
        return {'flow_field': flow, 'warp_image': warped}
class EditingNet(nn.Module):
    r"""Produce an output image from the input image, its warped version and a descriptor.

    The input and warped images are concatenated along channels and encoded
    by ``FineEncoder``. The encoding is then decoded by ``FineDecoder``
    together with the descriptor.

    Args:
        image_nc (int): Channels per image. The encoder receives
            ``2 * image_nc`` because the two images are stacked.
        descriptor_nc (int): Channels of the descriptor (passed to the decoder).
        layer (int): Passed to both ``FineEncoder`` and ``FineDecoder``.
        base_nc (int): Passed to both ``FineEncoder`` and ``FineDecoder``.
        max_nc (int): Passed to both ``FineEncoder`` and ``FineDecoder``.
        num_res_blocks (int): Passed to ``FineDecoder``.
        use_spect (bool): Passed to both ``FineEncoder`` and ``FineDecoder``.
    """

    def __init__(
        self,
        image_nc,
        descriptor_nc,
        layer,
        base_nc,
        max_nc,
        num_res_blocks,
        use_spect):
        super().__init__()
        act = nn.LeakyReLU(0.1)
        shared_opts = {
            'norm_layer': functools.partial(LayerNorm2d, affine=True),
            'nonlinearity': act,
            'use_spect': use_spect,
        }
        self.descriptor_nc = descriptor_nc
        # Input and warped images are stacked on channels, hence 2x.
        self.encoder = FineEncoder(image_nc * 2, base_nc, max_nc, layer, **shared_opts)
        self.decoder = FineDecoder(
            image_nc, descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **shared_opts)

    def forward(self, input_image, warp_image, descriptor):
        r"""Encode the stacked images and decode them with the descriptor.

        Args:
            input_image (tensor): Source image.
            warp_image (tensor): Warped image, concatenated with
                ``input_image`` on dim 1. The name shadows the module-level
                function but is kept for keyword callers.
            descriptor (tensor): Conditioning descriptor for the decoder.

        Returns:
            tensor: The decoder output.
        """
        stacked = torch.cat([input_image, warp_image], 1)
        return self.decoder(self.encoder(stacked), descriptor)