Upload 50 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +3 -0
- checkpoints/30_net_gen.pth +3 -0
- checkpoints/BFM/.gitkeep +0 -0
- checkpoints/BFM/01_MorphableModel.mat +3 -0
- checkpoints/BFM/BFM_exp_idx.mat +0 -0
- checkpoints/BFM/BFM_front_idx.mat +0 -0
- checkpoints/BFM/BFM_model_front.mat +3 -0
- checkpoints/BFM/Exp_Pca.bin +3 -0
- checkpoints/BFM/facemodel_info.mat +0 -0
- checkpoints/BFM/select_vertex_id.mat +0 -0
- checkpoints/BFM/similarity_Lm3D_all.mat +0 -0
- checkpoints/BFM/std_exp.txt +1 -0
- checkpoints/DNet.pt +3 -0
- checkpoints/ENet.pth +3 -0
- checkpoints/GFPGANv1.3.pth +3 -0
- checkpoints/GPEN-BFR-512.pth +3 -0
- checkpoints/LNet.pth +3 -0
- checkpoints/ParseNet-latest.pth +3 -0
- checkpoints/RetinaFace-R50.pth +3 -0
- checkpoints/expression.mat +0 -0
- checkpoints/face3d_pretrain_epoch_20.pth +3 -0
- checkpoints/shape_predictor_68_face_landmarks.dat +3 -0
- models/DNet.py +118 -0
- models/ENet.py +139 -0
- models/LNet.py +139 -0
- models/__init__.py +37 -0
- models/__pycache__/DNet.cpython-37.pyc +0 -0
- models/__pycache__/DNet.cpython-38.pyc +0 -0
- models/__pycache__/DNet.cpython-39.pyc +0 -0
- models/__pycache__/ENet.cpython-37.pyc +0 -0
- models/__pycache__/ENet.cpython-38.pyc +0 -0
- models/__pycache__/ENet.cpython-39.pyc +0 -0
- models/__pycache__/LNet.cpython-37.pyc +0 -0
- models/__pycache__/LNet.cpython-38.pyc +0 -0
- models/__pycache__/LNet.cpython-39.pyc +0 -0
- models/__pycache__/__init__.cpython-37.pyc +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/base_blocks.cpython-37.pyc +0 -0
- models/__pycache__/base_blocks.cpython-38.pyc +0 -0
- models/__pycache__/base_blocks.cpython-39.pyc +0 -0
- models/__pycache__/ffc.cpython-37.pyc +0 -0
- models/__pycache__/ffc.cpython-38.pyc +0 -0
- models/__pycache__/ffc.cpython-39.pyc +0 -0
- models/__pycache__/transformer.cpython-37.pyc +0 -0
- models/__pycache__/transformer.cpython-38.pyc +0 -0
- models/__pycache__/transformer.cpython-39.pyc +0 -0
- models/base_blocks.py +554 -0
- models/ffc.py +233 -0
- models/transformer.py +119 -0
.gitattributes
CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 temp/temp/temp.wav filter=lfs diff=lfs merge=lfs -text
+checkpoints/BFM/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text
+checkpoints/BFM/BFM_model_front.mat filter=lfs diff=lfs merge=lfs -text
+checkpoints/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
checkpoints/30_net_gen.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4db83e1727128e2c5de27bc80d2929586535e04a709af45016a63e7cf7c46b0c
+size 33877439
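Every large checkpoint in this upload is committed as a Git LFS pointer like the one above, not as the binary itself. A minimal sketch of reading such a pointer into its three fields (parse_lfs_pointer is a hypothetical helper, not part of this repo):

# Sketch only: split each "key value" line of an LFS pointer file.
def parse_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    return fields

# e.g. parse_lfs_pointer('checkpoints/30_net_gen.pth')
# -> {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'sha256:4db8...', 'size': '33877439'}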
checkpoints/BFM/.gitkeep
ADDED
File without changes
checkpoints/BFM/01_MorphableModel.mat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
+size 240875364
checkpoints/BFM/BFM_exp_idx.mat
ADDED
Binary file (91.9 kB)
checkpoints/BFM/BFM_front_idx.mat
ADDED
Binary file (44.9 kB)
checkpoints/BFM/BFM_model_front.mat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae3ff544aba3246c5f2c117f2be76fa44a7b76145326aae0bbfbfb564d4f82af
+size 127170280
checkpoints/BFM/Exp_Pca.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726
+size 51086404
checkpoints/BFM/facemodel_info.mat
ADDED
Binary file (739 kB)
checkpoints/BFM/select_vertex_id.mat
ADDED
Binary file (62.3 kB)
checkpoints/BFM/similarity_Lm3D_all.mat
ADDED
Binary file (994 Bytes)
checkpoints/BFM/std_exp.txt
ADDED
@@ -0,0 +1 @@
+453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19
checkpoints/DNet.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41220d2973c0ba2eab6e8f17ed00711aef5a0d76d19808f885dc0e3251df2e80
+size 180424655
checkpoints/ENet.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:967ee3ed857619cedd92b6407dc8a124cbfe763cc11cad58316fe21271a8928f
+size 573261168
checkpoints/GFPGANv1.3.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70
+size 348632874
checkpoints/GPEN-BFR-512.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1002c41add95b0decad69604d80455576f7187dd99ca16bd611bcfd44c10b51
+size 284085738
checkpoints/LNet.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ae06fef0454c421b828cc53e8d4b9c92d990867a858ea7bb9661ab6cf6ab774
+size 1534697728
checkpoints/ParseNet-latest.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
+size 85331193
checkpoints/RetinaFace-R50.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
+size 109497761
checkpoints/expression.mat
ADDED
Binary file (1.46 kB)
checkpoints/face3d_pretrain_epoch_20.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d17a6b23457b521801baae583cb6a58f7238fe6721fc3d65d76407460e9149b
+size 288860037
checkpoints/shape_predictor_68_face_landmarks.dat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
+size 99693937
models/DNet.py
ADDED
@@ -0,0 +1,118 @@
+# TODO
+import functools
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils import flow_util
+from models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
+
+# DNet
+class DNet(nn.Module):
+    def __init__(self):
+        super(DNet, self).__init__()
+        self.mapping_net = MappingNet()
+        self.warpping_net = WarpingNet()
+        self.editing_net = EditingNet()
+
+    def forward(self, input_image, driving_source, stage=None):
+        if stage == 'warp':
+            descriptor = self.mapping_net(driving_source)
+            output = self.warpping_net(input_image, descriptor)
+        else:
+            descriptor = self.mapping_net(driving_source)
+            output = self.warpping_net(input_image, descriptor)
+            output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
+        return output
+
+class MappingNet(nn.Module):
+    def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3):
+        super(MappingNet, self).__init__()
+
+        self.layer = layer
+        nonlinearity = nn.LeakyReLU(0.1)
+
+        self.first = nn.Sequential(
+            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+        for i in range(layer):
+            net = nn.Sequential(nonlinearity,
+                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+            setattr(self, 'encoder' + str(i), net)
+
+        self.pooling = nn.AdaptiveAvgPool1d(1)
+        self.output_nc = descriptor_nc
+
+    def forward(self, input_3dmm):
+        out = self.first(input_3dmm)
+        for i in range(self.layer):
+            model = getattr(self, 'encoder' + str(i))
+            out = model(out) + out[:, :, 3:-3]
+        out = self.pooling(out)
+        return out
+
+class WarpingNet(nn.Module):
+    def __init__(
+        self,
+        image_nc=3,
+        descriptor_nc=256,
+        base_nc=32,
+        max_nc=256,
+        encoder_layer=5,
+        decoder_layer=3,
+        use_spect=False
+    ):
+        super(WarpingNet, self).__init__()
+
+        nonlinearity = nn.LeakyReLU(0.1)
+        norm_layer = functools.partial(LayerNorm2d, affine=True)
+        kwargs = {'nonlinearity': nonlinearity, 'use_spect': use_spect}
+
+        self.descriptor_nc = descriptor_nc
+        self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
+                                        max_nc, encoder_layer, decoder_layer, **kwargs)
+
+        self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
+                                      nonlinearity,
+                                      nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
+
+        self.pool = nn.AdaptiveAvgPool2d(1)
+
+    def forward(self, input_image, descriptor):
+        final_output = {}
+        output = self.hourglass(input_image, descriptor)
+        final_output['flow_field'] = self.flow_out(output)
+
+        deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
+        final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
+        return final_output
+
+
+class EditingNet(nn.Module):
+    def __init__(
+        self,
+        image_nc=3,
+        descriptor_nc=256,
+        layer=3,
+        base_nc=64,
+        max_nc=256,
+        num_res_blocks=2,
+        use_spect=False):
+        super(EditingNet, self).__init__()
+
+        nonlinearity = nn.LeakyReLU(0.1)
+        norm_layer = functools.partial(LayerNorm2d, affine=True)
+        kwargs = {'norm_layer': norm_layer, 'nonlinearity': nonlinearity, 'use_spect': use_spect}
+        self.descriptor_nc = descriptor_nc
+
+        # encoder part
+        self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
+        self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+
+    def forward(self, input_image, warp_image, descriptor):
+        x = torch.cat([input_image, warp_image], 1)
+        x = self.encoder(x)
+        gen_image = self.decoder(x, descriptor)
+        return gen_image
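DNet maps a window of 3DMM coefficients to a dense flow field and warps the source frame with it; the 'warp' stage skips the editing network. A rough usage sketch (the 73-channel input follows coeff_nc above; the window length of 27 and the availability of utils.flow_util on the import path are assumptions):

# Sketch only, not the repo's inference script.
import torch
from models.DNet import DNet

net = DNet().eval()
source = torch.randn(1, 3, 256, 256)    # source face frame
coeffs = torch.randn(1, 73, 27)         # 3DMM coefficient window (length assumed)
with torch.no_grad():
    out = net(source, coeffs, stage='warp')
# out['flow_field'] is the predicted flow; out['warp_image'] is the warped frame.
# Without stage='warp', out also contains 'fake_image' from EditingNet.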
models/ENet.py
ADDED
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.base_blocks import ResBlock, StyleConv, ToRGB
+
+
+class ENet(nn.Module):
+    def __init__(
+        self,
+        num_style_feat=512,
+        lnet=None,
+        concat=False
+    ):
+        super(ENet, self).__init__()
+
+        self.low_res = lnet
+        for param in self.low_res.parameters():
+            param.requires_grad = False
+
+        channel_multiplier, narrow = 2, 1
+        channels = {
+            '4': int(512 * narrow),
+            '8': int(512 * narrow),
+            '16': int(512 * narrow),
+            '32': int(512 * narrow),
+            '64': int(256 * channel_multiplier * narrow),
+            '128': int(128 * channel_multiplier * narrow),
+            '256': int(64 * channel_multiplier * narrow),
+            '512': int(32 * channel_multiplier * narrow),
+            '1024': int(16 * channel_multiplier * narrow)
+        }
+
+        self.log_size = 8
+        first_out_size = 128
+        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)  # 256 -> 128
+
+        # downsample
+        in_channels = channels[f'{first_out_size}']
+        self.conv_body_down = nn.ModuleList()
+        for i in range(8, 2, -1):
+            out_channels = channels[f'{2**(i - 1)}']
+            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
+            in_channels = out_channels
+
+        self.num_style_feat = num_style_feat
+        linear_out_channel = num_style_feat
+        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
+        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
+
+        self.style_convs = nn.ModuleList()
+        self.to_rgbs = nn.ModuleList()
+        self.noises = nn.Module()
+
+        self.concat = concat
+        if concat:
+            in_channels = 3 + 32  # channels['64']
+        else:
+            in_channels = 3
+
+        for i in range(7, 9):  # 128, 256
+            out_channels = channels[f'{2**i}']
+            self.style_convs.append(
+                StyleConv(
+                    in_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode='upsample'))
+            self.style_convs.append(
+                StyleConv(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode=None))
+            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
+            in_channels = out_channels
+
+    def forward(self, audio_sequences, face_sequences, gt_sequences):
+        B = audio_sequences.size(0)
+        input_dim_size = len(face_sequences.size())
+        inp, ref = torch.split(face_sequences, 3, dim=1)
+
+        if input_dim_size > 4:
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+            inp = torch.cat([inp[:, :, i] for i in range(inp.size(2))], dim=0)
+            ref = torch.cat([ref[:, :, i] for i in range(ref.size(2))], dim=0)
+            gt_sequences = torch.cat([gt_sequences[:, :, i] for i in range(gt_sequences.size(2))], dim=0)
+
+        # get the global style
+        feat = F.leaky_relu_(self.conv_body_first(F.interpolate(ref, size=(256, 256), mode='bilinear')), negative_slope=0.2)
+        for i in range(self.log_size - 2):
+            feat = self.conv_body_down[i](feat)
+        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
+
+        # style code
+        style_code = self.final_linear(feat.reshape(feat.size(0), -1))
+        style_code = style_code.reshape(style_code.size(0), -1, self.num_style_feat)
+
+        LNet_input = torch.cat([inp, gt_sequences], dim=1)
+        LNet_input = F.interpolate(LNet_input, size=(96, 96), mode='bilinear')
+
+        if self.concat:
+            low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input)
+            low_res_img.detach()
+            low_res_feat.detach()
+            out = torch.cat([low_res_img, low_res_feat], dim=1)
+
+        else:
+            low_res_img = self.low_res(audio_sequences, LNet_input)
+            low_res_img.detach()
+            # 96 x 96
+            out = low_res_img
+
+        p2d = (2, 2, 2, 2)
+        out = F.pad(out, p2d, "reflect", 0)
+        skip = out
+
+        for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs):
+            out = conv1(out, style_code)  # 96, 192, 384
+            out = conv2(out, style_code)
+            skip = to_rgb(out, style_code, skip)
+        _outputs = skip
+
+        # remove padding
+        _outputs = _outputs[:, :, 8:-8, 8:-8]
+
+        if input_dim_size > 4:
+            _outputs = torch.split(_outputs, B, dim=0)
+            outputs = torch.stack(_outputs, dim=2)
+            low_res_img = F.interpolate(low_res_img, outputs.size()[3:])
+            low_res_img = torch.split(low_res_img, B, dim=0)
+            low_res_img = torch.stack(low_res_img, dim=2)
+        else:
+            outputs = _outputs
+        return outputs, low_res_img
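ENet extracts a style code from the reference frame, runs the frozen LNet at 96x96, then upsamples twice with StyleConv/ToRGB (96 -> 192 -> 384, with the reflect padding cropped off at the end). A shape-check sketch (the 80x16 mel window and 384x384 frames are assumptions; faces carries the masked input and the reference concatenated on channels):

# Sketch only; wiring inferred from forward() above.
import torch
from models.LNet import LNet
from models.ENet import ENet

enet = ENet(lnet=LNet()).eval()
audio = torch.randn(2, 1, 80, 16)    # mel window per frame (assumed layout)
faces = torch.randn(2, 6, 384, 384)  # masked input + reference, channel-concatenated
gt    = torch.randn(2, 3, 384, 384)  # frames fed to the low-res branch
with torch.no_grad():
    hi_res, low_res = enet(audio, faces, gt)
# hi_res: (2, 3, 384, 384) after the [8:-8] crop; low_res: (2, 3, 96, 96).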
models/LNet.py
ADDED
@@ -0,0 +1,139 @@
+import functools
+import torch
+import torch.nn as nn
+
+from models.transformer import RETURNX, Transformer
+from models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
+    FFCADAINResBlocks, Jump, FinalBlock2d
+
+
+class Visual_Encoder(nn.Module):
+    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(Visual_Encoder, self).__init__()
+        self.layers = layers
+        self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+        self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+        for i in range(layers):
+            in_channels = min(ngf*(2**i), img_f)
+            out_channels = min(ngf*(2**(i+1)), img_f)
+            model_ref = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+            model_inp = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+            if i < 2:
+                ca_layer = RETURNX()
+            else:
+                ca_layer = Transformer(2**(i+1) * ngf, 2, 4, ngf, ngf*4)
+            setattr(self, 'ca' + str(i), ca_layer)
+            setattr(self, 'ref_down' + str(i), model_ref)
+            setattr(self, 'inp_down' + str(i), model_inp)
+        self.output_nc = out_channels * 2
+
+    def forward(self, maskGT, ref):
+        x_maskGT, x_ref = self.first_inp(maskGT), self.first_ref(ref)
+        out = [x_maskGT]
+        for i in range(self.layers):
+            model_ref = getattr(self, 'ref_down'+str(i))
+            model_inp = getattr(self, 'inp_down'+str(i))
+            ca_layer = getattr(self, 'ca'+str(i))
+            x_maskGT, x_ref = model_inp(x_maskGT), model_ref(x_ref)
+            x_maskGT = ca_layer(x_maskGT, x_ref)
+            if i < self.layers - 1:
+                out.append(x_maskGT)
+            else:
+                out.append(torch.cat([x_maskGT, x_ref], dim=1))  # concat ref features !
+        return out
+
+
+class Decoder(nn.Module):
+    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(Decoder, self).__init__()
+        self.layers = layers
+        for i in range(layers)[::-1]:
+            if i == layers-1:
+                in_channels = ngf*(2**(i+1)) * 2
+            else:
+                in_channels = min(ngf*(2**(i+1)), img_f)
+            out_channels = min(ngf*(2**i), img_f)
+            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+            res = FFCADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+
+            setattr(self, 'up' + str(i), up)
+            setattr(self, 'res' + str(i), res)
+            setattr(self, 'jump' + str(i), jump)
+
+        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'sigmoid')
+        self.output_nc = out_channels
+
+    def forward(self, x, z):
+        out = x.pop()
+        for i in range(self.layers)[::-1]:
+            res_model = getattr(self, 'res' + str(i))
+            up_model = getattr(self, 'up' + str(i))
+            jump_model = getattr(self, 'jump' + str(i))
+            out = res_model(out, z)
+            out = up_model(out)
+            out = jump_model(x.pop()) + out
+        out_image = self.final(out)
+        return out_image
+
+
+class LNet(nn.Module):
+    def __init__(
+        self,
+        image_nc=3,
+        descriptor_nc=512,
+        layer=3,
+        base_nc=64,
+        max_nc=512,
+        num_res_blocks=9,
+        use_spect=True,
+        encoder=Visual_Encoder,
+        decoder=Decoder
+    ):
+        super(LNet, self).__init__()
+
+        nonlinearity = nn.LeakyReLU(0.1)
+        norm_layer = functools.partial(LayerNorm2d, affine=True)
+        kwargs = {'norm_layer': norm_layer, 'nonlinearity': nonlinearity, 'use_spect': use_spect}
+        self.descriptor_nc = descriptor_nc
+
+        self.encoder = encoder(image_nc, base_nc, max_nc, layer, **kwargs)
+        self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0),
+        )
+
+    def forward(self, audio_sequences, face_sequences):
+        B = audio_sequences.size(0)
+        input_dim_size = len(face_sequences.size())
+        if input_dim_size > 4:
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+        cropped, ref = torch.split(face_sequences, 3, dim=1)
+
+        vis_feat = self.encoder(cropped, ref)
+        audio_feat = self.audio_encoder(audio_sequences)
+        _outputs = self.decoder(vis_feat, audio_feat)
+
+        if input_dim_size > 4:
+            _outputs = torch.split(_outputs, B, dim=0)
+            outputs = torch.stack(_outputs, dim=2)
+        else:
+            outputs = _outputs
+        return outputs
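LNet is the low-resolution audio-to-lip generator: a two-stream visual encoder with cross-attention from reference features, an audio encoder that pools a mel window to a 512-d descriptor, and an FFC/ADAIN decoder with a sigmoid output. A shape-check sketch (96x96 matches the LNet_input built in ENet above; the mel layout is an assumption):

# Sketch only.
import torch
from models.LNet import LNet

lnet = LNet().eval()
audio = torch.randn(2, 1, 80, 16)   # mel window (assumed layout)
faces = torch.randn(2, 6, 96, 96)   # masked frame + reference on the channel axis
with torch.no_grad():
    out = lnet(audio, faces)        # (2, 3, 96, 96), values in [0, 1]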
models/__init__.py
ADDED
@@ -0,0 +1,37 @@
+import torch
+from models.DNet import DNet
+from models.LNet import LNet
+from models.ENet import ENet
+
+
+def _load(checkpoint_path):
+    map_location = None if torch.cuda.is_available() else torch.device('cpu')
+    checkpoint = torch.load(checkpoint_path, map_location=map_location)
+    return checkpoint
+
+def load_checkpoint(path, model):
+    print("Load checkpoint from: {}".format(path))
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint
+    new_s = {}
+    for k, v in s.items():
+        if 'low_res' in k:
+            continue
+        else:
+            new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s, strict=False)
+    return model
+
+def load_network(args):
+    L_net = LNet()
+    L_net = load_checkpoint(args.LNet_path, L_net)
+    E_net = ENet(lnet=L_net)
+    model = load_checkpoint(args.ENet_path, E_net)
+    return model.eval()
+
+def load_DNet(args):
+    D_Net = DNet()
+    print("Load checkpoint from: {}".format(args.DNet_path))
+    checkpoint = torch.load(args.DNet_path, map_location=lambda storage, loc: storage)
+    D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False)
+    return D_Net.eval()
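The loaders above strip 'module.' prefixes, drop the frozen low_res weights from the ENet state dict, and return models in eval mode. A usage sketch with the checkpoint paths from this upload (any object with these attribute names works in place of argparse args):

# Sketch only.
from types import SimpleNamespace
from models import load_network, load_DNet

args = SimpleNamespace(
    LNet_path='checkpoints/LNet.pth',
    ENet_path='checkpoints/ENet.pth',
    DNet_path='checkpoints/DNet.pt',
)
enet = load_network(args)  # ENet wrapping a frozen LNet
dnet = load_DNet(args)     # loads the 'net_G_ema' weights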
models/__pycache__/DNet.cpython-37.pyc
ADDED
Binary file (4.08 kB)
models/__pycache__/DNet.cpython-38.pyc
ADDED
Binary file (4.04 kB)
models/__pycache__/DNet.cpython-39.pyc
ADDED
Binary file (4.05 kB)
models/__pycache__/ENet.cpython-37.pyc
ADDED
Binary file (3.76 kB)
models/__pycache__/ENet.cpython-38.pyc
ADDED
Binary file (3.78 kB)
models/__pycache__/ENet.cpython-39.pyc
ADDED
Binary file (3.76 kB)
models/__pycache__/LNet.cpython-37.pyc
ADDED
Binary file (4.85 kB)
models/__pycache__/LNet.cpython-38.pyc
ADDED
Binary file (4.83 kB)
models/__pycache__/LNet.cpython-39.pyc
ADDED
Binary file (4.83 kB)
models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.55 kB)
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.54 kB)
models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.54 kB)
models/__pycache__/base_blocks.cpython-37.pyc
ADDED
Binary file (20.9 kB)
models/__pycache__/base_blocks.cpython-38.pyc
ADDED
Binary file (20.2 kB)
models/__pycache__/base_blocks.cpython-39.pyc
ADDED
Binary file (20.2 kB)
models/__pycache__/ffc.cpython-37.pyc
ADDED
Binary file (6.98 kB)
models/__pycache__/ffc.cpython-38.pyc
ADDED
Binary file (7.07 kB)
models/__pycache__/ffc.cpython-39.pyc
ADDED
Binary file (6.96 kB)
models/__pycache__/transformer.cpython-37.pyc
ADDED
Binary file (4.91 kB)
models/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (4.81 kB)
models/__pycache__/transformer.cpython-39.pyc
ADDED
Binary file (4.82 kB)
models/base_blocks.py
ADDED
@@ -0,0 +1,554 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
+
+from models.ffc import FFC
+from basicsr.archs.arch_util import default_init_weights
+
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+        self.residual = residual
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+        return self.act(out)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, mode='down'):
+        super(ResBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
+        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
+        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+        if mode == 'down':
+            self.scale_factor = 0.5
+        elif mode == 'up':
+            self.scale_factor = 2
+
+    def forward(self, x):
+        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
+        # upsample/downsample
+        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
+        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
+        # skip
+        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
+        skip = self.skip(x)
+        out = out + skip
+        return out
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, n_out, affine=True):
+        super(LayerNorm2d, self).__init__()
+        self.n_out = n_out
+        self.affine = affine
+
+        if self.affine:
+            self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
+            self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
+
+    def forward(self, x):
+        normalized_shape = x.size()[1:]
+        if self.affine:
+            return F.layer_norm(x, normalized_shape,
+                                self.weight.expand(normalized_shape),
+                                self.bias.expand(normalized_shape))
+        else:
+            return F.layer_norm(x, normalized_shape)
+
+
+def spectral_norm(module, use_spect=True):
+    if use_spect:
+        return SpectralNorm(module)
+    else:
+        return module
+
+
+class FirstBlock2d(nn.Module):
+    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(FirstBlock2d, self).__init__()
+        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+        if type(norm_layer) == type(None):
+            self.model = nn.Sequential(conv, nonlinearity)
+        else:
+            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+    def forward(self, x):
+        out = self.model(x)
+        return out
+
+
+class DownBlock2d(nn.Module):
+    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(DownBlock2d, self).__init__()
+        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+        pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+        if type(norm_layer) == type(None):
+            self.model = nn.Sequential(conv, nonlinearity, pool)
+        else:
+            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
+
+    def forward(self, x):
+        out = self.model(x)
+        return out
+
+
+class UpBlock2d(nn.Module):
+    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(UpBlock2d, self).__init__()
+        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+        if type(norm_layer) == type(None):
+            self.model = nn.Sequential(conv, nonlinearity)
+        else:
+            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+    def forward(self, x):
+        out = self.model(F.interpolate(x, scale_factor=2))
+        return out
+
+
+class ADAIN(nn.Module):
+    def __init__(self, norm_nc, feature_nc):
+        super().__init__()
+
+        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+
+        nhidden = 128
+        use_bias = True
+
+        self.mlp_shared = nn.Sequential(
+            nn.Linear(feature_nc, nhidden, bias=use_bias),
+            nn.ReLU()
+        )
+        self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
+        self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
+
+    def forward(self, x, feature):
+
+        # Part 1. generate parameter-free normalized activations
+        normalized = self.param_free_norm(x)
+        # Part 2. produce scaling and bias conditioned on feature
+        feature = feature.view(feature.size(0), -1)
+        actv = self.mlp_shared(feature)
+        gamma = self.mlp_gamma(actv)
+        beta = self.mlp_beta(actv)
+
+        # apply scale and bias
+        gamma = gamma.view(*gamma.size()[:2], 1, 1)
+        beta = beta.view(*beta.size()[:2], 1, 1)
+        out = normalized * (1 + gamma) + beta
+        return out
+
+
+class FineADAINResBlock2d(nn.Module):
+    """
+    Define a residual block for different types
+    """
+    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(FineADAINResBlock2d, self).__init__()
+        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+        self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+        self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+        self.norm1 = ADAIN(input_nc, feature_nc)
+        self.norm2 = ADAIN(input_nc, feature_nc)
+        self.actvn = nonlinearity
+
+    def forward(self, x, z):
+        dx = self.actvn(self.norm1(self.conv1(x), z))
+        dx = self.norm2(self.conv2(x), z)
+        out = dx + x
+        return out
+
+
+class FineADAINResBlocks(nn.Module):
+    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(FineADAINResBlocks, self).__init__()
+        self.num_block = num_block
+        for i in range(num_block):
+            model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+            setattr(self, 'res'+str(i), model)
+
+    def forward(self, x, z):
+        for i in range(self.num_block):
+            model = getattr(self, 'res'+str(i))
+            x = model(x, z)
+        return x
+
+
+class ADAINEncoderBlock(nn.Module):
+    def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(ADAINEncoderBlock, self).__init__()
+        kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
+        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
+        self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
+
+        self.norm_0 = ADAIN(input_nc, feature_nc)
+        self.norm_1 = ADAIN(output_nc, feature_nc)
+        self.actvn = nonlinearity
+
+    def forward(self, x, z):
+        x = self.conv_0(self.actvn(self.norm_0(x, z)))
+        x = self.conv_1(self.actvn(self.norm_1(x, z)))
+        return x
+
+
+class ADAINDecoderBlock(nn.Module):
+    def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(ADAINDecoderBlock, self).__init__()
+        # Attributes
+        self.actvn = nonlinearity
+        hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
+
+        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+        if use_transpose:
+            kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1}
+        else:
+            kwargs_up = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+        # create conv layers
+        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
+        if use_transpose:
+            self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
+            self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
+        else:
+            self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
+                                        nn.Upsample(scale_factor=2))
+            self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
+                                        nn.Upsample(scale_factor=2))
+        # define normalization layers
+        self.norm_0 = ADAIN(input_nc, feature_nc)
+        self.norm_1 = ADAIN(hidden_nc, feature_nc)
+        self.norm_s = ADAIN(input_nc, feature_nc)
+
+    def forward(self, x, z):
+        x_s = self.shortcut(x, z)
+        dx = self.conv_0(self.actvn(self.norm_0(x, z)))
+        dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
+        out = x_s + dx
+        return out
+
+    def shortcut(self, x, z):
+        x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
+        return x_s
+
+
+class FineEncoder(nn.Module):
+    """docstring for Encoder"""
+    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(FineEncoder, self).__init__()
+        self.layers = layers
+        self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+        for i in range(layers):
+            in_channels = min(ngf*(2**i), img_f)
+            out_channels = min(ngf*(2**(i+1)), img_f)
+            model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+            setattr(self, 'down' + str(i), model)
+        self.output_nc = out_channels
+
+    def forward(self, x):
+        x = self.first(x)
+        out = [x]
+        for i in range(self.layers):
+            model = getattr(self, 'down'+str(i))
+            x = model(x)
+            out.append(x)
+        return out
+
+
+class FineDecoder(nn.Module):
+    """docstring for FineDecoder"""
+    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(FineDecoder, self).__init__()
+        self.layers = layers
+        for i in range(layers)[::-1]:
+            in_channels = min(ngf*(2**(i+1)), img_f)
+            out_channels = min(ngf*(2**i), img_f)
+            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+            res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+            setattr(self, 'up' + str(i), up)
+            setattr(self, 'res' + str(i), res)
+            setattr(self, 'jump' + str(i), jump)
+        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
+        self.output_nc = out_channels
+
+    def forward(self, x, z):
+        out = x.pop()
+        for i in range(self.layers)[::-1]:
+            res_model = getattr(self, 'res' + str(i))
+            up_model = getattr(self, 'up' + str(i))
+            jump_model = getattr(self, 'jump' + str(i))
+            out = res_model(out, z)
+            out = up_model(out)
+            out = jump_model(x.pop()) + out
+        out_image = self.final(out)
+        return out_image
+
+
+class ADAINEncoder(nn.Module):
+    def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(ADAINEncoder, self).__init__()
+        self.layers = layers
+        self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
+        for i in range(layers):
+            in_channels = min(ngf * (2**i), img_f)
+            out_channels = min(ngf * (2**(i+1)), img_f)
+            model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
+            setattr(self, 'encoder' + str(i), model)
+        self.output_nc = out_channels
+
+    def forward(self, x, z):
+        out = self.input_layer(x)
+        out_list = [out]
+        for i in range(self.layers):
+            model = getattr(self, 'encoder' + str(i))
+            out = model(out, z)
+            out_list.append(out)
+        return out_list
+
+
+class ADAINDecoder(nn.Module):
+    """docstring for ADAINDecoder"""
+    def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
+                 nonlinearity=nn.LeakyReLU(), use_spect=False):
+
+        super(ADAINDecoder, self).__init__()
+        self.encoder_layers = encoder_layers
+        self.decoder_layers = decoder_layers
+        self.skip_connect = skip_connect
+        use_transpose = True
+        for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
+            in_channels = min(ngf * (2**(i+1)), img_f)
+            in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
+            out_channels = min(ngf * (2**i), img_f)
+            model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
+            setattr(self, 'decoder' + str(i), model)
+        self.output_nc = out_channels*2 if self.skip_connect else out_channels
+
+    def forward(self, x, z):
+        out = x.pop() if self.skip_connect else x
+        for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
+            model = getattr(self, 'decoder' + str(i))
+            out = model(out, z)
+            out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
+        return out
+
+
+class ADAINHourglass(nn.Module):
+    def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
+        super(ADAINHourglass, self).__init__()
+        self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
+        self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
+        self.output_nc = self.decoder.output_nc
+
+    def forward(self, x, z):
+        return self.decoder(self.encoder(x, z), z)
+
+
+class FineADAINLama(nn.Module):
+    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(FineADAINLama, self).__init__()
+        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+        self.actvn = nonlinearity
+        ratio_gin = 0.75
+        ratio_gout = 0.75
+        self.ffc = FFC(input_nc, input_nc, 3,
+                       ratio_gin, ratio_gout, 1, 1, 1,
+                       1, False, False, padding_type='reflect')
+        global_channels = int(input_nc * ratio_gout)
+        self.bn_l = ADAIN(input_nc - global_channels, feature_nc)
+        self.bn_g = ADAIN(global_channels, feature_nc)
+
+    def forward(self, x, z):
+        x_l, x_g = self.ffc(x)
+        x_l = self.actvn(self.bn_l(x_l, z))
+        x_g = self.actvn(self.bn_g(x_g, z))
+        return x_l, x_g
+
+
+class FFCResnetBlock(nn.Module):
+    def __init__(self, dim, feature_dim, padding_type='reflect', norm_layer=BatchNorm2d, activation_layer=nn.ReLU, dilation=1,
+                 spatial_transform_kwargs=None, inline=False, **conv_kwargs):
+        super().__init__()
+        self.conv1 = FineADAINLama(dim, feature_dim, **conv_kwargs)
+        self.conv2 = FineADAINLama(dim, feature_dim, **conv_kwargs)
+        self.inline = True
+
+    def forward(self, x, z):
+        if self.inline:
+            x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
+        else:
+            x_l, x_g = x if type(x) is tuple else (x, 0)
+
+        id_l, id_g = x_l, x_g
+        x_l, x_g = self.conv1((x_l, x_g), z)
+        x_l, x_g = self.conv2((x_l, x_g), z)
+
+        x_l, x_g = id_l + x_l, id_g + x_g
+        out = x_l, x_g
+        if self.inline:
+            out = torch.cat(out, dim=1)
+        return out
+
+
+class FFCADAINResBlocks(nn.Module):
+    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(FFCADAINResBlocks, self).__init__()
+        self.num_block = num_block
+        for i in range(num_block):
+            model = FFCResnetBlock(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+            setattr(self, 'res'+str(i), model)
+
+    def forward(self, x, z):
+        for i in range(self.num_block):
+            model = getattr(self, 'res'+str(i))
+            x = model(x, z)
+        return x
+
+
+class Jump(nn.Module):
+    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+        super(Jump, self).__init__()
+        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+        conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+        if type(norm_layer) == type(None):
+            self.model = nn.Sequential(conv, nonlinearity)
+        else:
+            self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
+
+    def forward(self, x):
+        out = self.model(x)
+        return out
+
+
+class FinalBlock2d(nn.Module):
+    def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
+        super(FinalBlock2d, self).__init__()
+        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+        if tanh_or_sigmoid == 'sigmoid':
+            out_nonlinearity = nn.Sigmoid()
+        else:
+            out_nonlinearity = nn.Tanh()
+        self.model = nn.Sequential(conv, out_nonlinearity)
+
+    def forward(self, x):
+        out = self.model(x)
+        return out
+
+
+class ModulatedConv2d(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 eps=1e-8):
+        super(ModulatedConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.demodulate = demodulate
+        self.sample_mode = sample_mode
+        self.eps = eps
+
+        # modulation inside each modulated conv
+        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
+        # initialization
+        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
+
+        self.weight = nn.Parameter(
+            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
+            math.sqrt(in_channels * kernel_size**2))
+        self.padding = kernel_size // 2
+
+    def forward(self, x, style):
+        b, c, h, w = x.shape
+        style = self.modulation(style).view(b, 1, c, 1, 1)
+        weight = self.weight * style
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+        # upsample or downsample if necessary
+        if self.sample_mode == 'upsample':
+            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+        elif self.sample_mode == 'downsample':
+            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+        b, c, h, w = x.shape
+        x = x.view(1, b * c, h, w)
+        out = F.conv2d(x, weight, padding=self.padding, groups=b)
+        out = out.view(b, self.out_channels, *out.shape[2:4])
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+class StyleConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
+        super(StyleConv, self).__init__()
+        self.modulated_conv = ModulatedConv2d(
+            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
+        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
+        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
+        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x, style, noise=None):
+        # modulate
+        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
+        # noise injection
+        if noise is None:
+            b, _, h, w = out.shape
+            noise = out.new_empty(b, 1, h, w).normal_()
+        out = out + self.weight * noise
+        # add bias
+        out = out + self.bias
+        # activation
+        out = self.activate(out)
+        return out
+
+
+class ToRGB(nn.Module):
+    def __init__(self, in_channels, num_style_feat, upsample=True):
+        super(ToRGB, self).__init__()
+        self.upsample = upsample
+        self.modulated_conv = ModulatedConv2d(
+            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+    def forward(self, x, style, skip=None):
+        out = self.modulated_conv(x, style)
+        out = out + self.bias
+        if skip is not None:
+            if self.upsample:
+                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
+            out = out + skip
+        return out
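Two blocks above do most of the conditioning work: ADAIN rescales instance-normalized activations with a scale/shift predicted from a conditioning vector, and ModulatedConv2d applies StyleGAN2-style weight (de)modulation. A shape-check sketch (channel sizes are arbitrary):

# Sketch only.
import torch
from models.base_blocks import ADAIN, ModulatedConv2d

x = torch.randn(2, 64, 32, 32)
z = torch.randn(2, 512)                       # e.g. a pooled audio descriptor
y = ADAIN(norm_nc=64, feature_nc=512)(x, z)   # (2, 64, 32, 32)

mc = ModulatedConv2d(64, 64, kernel_size=3, num_style_feat=512)
y2 = mc(x, torch.randn(2, 512))               # (2, 64, 32, 32)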
models/ffc.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# Fast Fourier Convolution NeurIPS 2020
+# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
+# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from models.modules.squeeze_excitation import SELayer
+import torch.fft
+
+
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        res = x * y.expand_as(x)
+        return res
+
+
+class FFCSE_block(nn.Module):
+    def __init__(self, channels, ratio_g):
+        super(FFCSE_block, self).__init__()
+        in_cg = int(channels * ratio_g)
+        in_cl = channels - in_cg
+        r = 16
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.conv1 = nn.Conv2d(channels, channels // r,
+                               kernel_size=1, bias=True)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
+            channels // r, in_cl, kernel_size=1, bias=True)
+        self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
+            channels // r, in_cg, kernel_size=1, bias=True)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        x = x if type(x) is tuple else (x, 0)
+        id_l, id_g = x
+
+        x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
+        x = self.avgpool(x)
+        x = self.relu1(self.conv1(x))
+
+        x_l = 0 if self.conv_a2l is None else id_l * \
+            self.sigmoid(self.conv_a2l(x))
+        x_g = 0 if self.conv_a2g is None else id_g * \
+            self.sigmoid(self.conv_a2g(x))
+        return x_l, x_g
+
+
+class FourierUnit(nn.Module):
+
+    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
+                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
+        # bn_layer not used
+        super(FourierUnit, self).__init__()
+        self.groups = groups
+
+        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
+                                          out_channels=out_channels * 2,
+                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
+        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
+        self.relu = torch.nn.ReLU(inplace=True)
+
+        # squeeze and excitation block
+        self.use_se = use_se
+        if use_se:
+            if se_kwargs is None:
+                se_kwargs = {}
+            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
+
+        self.spatial_scale_factor = spatial_scale_factor
+        self.spatial_scale_mode = spatial_scale_mode
+        self.spectral_pos_encoding = spectral_pos_encoding
+        self.ffc3d = ffc3d
+        self.fft_norm = fft_norm
+
+    def forward(self, x):
+        batch = x.shape[0]
+
+        if self.spatial_scale_factor is not None:
+            orig_size = x.shape[-2:]
+            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
+
+        r_size = x.size()
+        # (batch, c, h, w/2+1, 2)
+        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
+        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
+        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
+        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
+        ffted = ffted.view((batch, -1,) + ffted.size()[3:])
+
+        if self.spectral_pos_encoding:
+            height, width = ffted.shape[-2:]
+            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
+            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
+            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
+
+        if self.use_se:
+            ffted = self.se(ffted)
+
+        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
+        ffted = self.relu(self.bn(ffted))
+
+        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
+            0, 1, 3, 4, 2).contiguous()  # (batch, c, t, h, w/2+1, 2)
+        ffted = torch.complex(ffted[..., 0], ffted[..., 1])
+
+        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
+        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
+
+        if self.spatial_scale_factor is not None:
+            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
+
+        return output
+
+
+class SpectralTransform(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
+        # bn_layer not used
+        super(SpectralTransform, self).__init__()
+        self.enable_lfu = enable_lfu
+        if stride == 2:
+            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
+        else:
+            self.downsample = nn.Identity()
+
+        self.stride = stride
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels //
+                      2, kernel_size=1, groups=groups, bias=False),
+            nn.BatchNorm2d(out_channels // 2),
+            nn.ReLU(inplace=True)
+        )
+        self.fu = FourierUnit(
+            out_channels // 2, out_channels // 2, groups, **fu_kwargs)
+        if self.enable_lfu:
+            self.lfu = FourierUnit(
+                out_channels // 2, out_channels // 2, groups)
+        self.conv2 = torch.nn.Conv2d(
+            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.conv1(x)
+        output = self.fu(x)
+
+        if self.enable_lfu:
+            n, c, h, w = x.shape
+            split_no = 2
+            split_s = h // split_no
+            xs = torch.cat(torch.split(
+                x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
+            xs = torch.cat(torch.split(xs, split_s, dim=-1),
+                           dim=1).contiguous()
+            xs = self.lfu(xs)
+            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
+        else:
+            xs = 0
+
+        output = self.conv2(x + output + xs)
+        return output
+
+
+class FFC(nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 ratio_gin, ratio_gout, stride=1, padding=0,
+                 dilation=1, groups=1, bias=False, enable_lfu=True,
+                 padding_type='reflect', gated=False, **spectral_kwargs):
+        super(FFC, self).__init__()
+
+        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
+        self.stride = stride
+
+        in_cg = int(in_channels * ratio_gin)
+        in_cl = in_channels - in_cg
+        out_cg = int(out_channels * ratio_gout)
+        out_cl = out_channels - out_cg
+
+        self.ratio_gin = ratio_gin
+        self.ratio_gout = ratio_gout
+        self.global_in_num = in_cg
+
+        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
+        self.convl2l = module(in_cl, out_cl, kernel_size,
+                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
+        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
+        self.convl2g = module(in_cl, out_cg, kernel_size,
+                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
+        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
+        self.convg2l = module(in_cg, out_cl, kernel_size,
+                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
+        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
+        self.convg2g = module(
+            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
+
+        self.gated = gated
+        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
+        self.gate = module(in_channels, 2, 1)
+
+    def forward(self, x):
+        x_l, x_g = x if type(x) is tuple else (x, 0)
+        out_xl, out_xg = 0, 0
+
+        if self.gated:
+            total_input_parts = [x_l]
+            if torch.is_tensor(x_g):
+                total_input_parts.append(x_g)
+            total_input = torch.cat(total_input_parts, dim=1)
+
+            gates = torch.sigmoid(self.gate(total_input))
+            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
+        else:
+            g2l_gate, l2g_gate = 1, 1
+
+        if self.ratio_gout != 1:
+            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
+        if self.ratio_gout != 0:
+            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
+
+        return out_xl, out_xg
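To make the local/global split concrete, a hedged usage sketch (not part of the upload; all sizes illustrative). With ratio_gin = ratio_gout = 0.5, half of the 64 channels flow through ordinary convolutions and half through the spectral (FourierUnit) branch; the module consumes and returns an (x_local, x_global) tuple.

import torch

ffc = FFC(in_channels=64, out_channels=64, kernel_size=3,
          ratio_gin=0.5, ratio_gout=0.5, padding=1)

x_local = torch.randn(2, 32, 64, 64)   # 64 * (1 - ratio_gin) local channels
x_global = torch.randn(2, 32, 64, 64)  # 64 * ratio_gin global channels

out_l, out_g = ffc((x_local, x_global))  # each (2, 32, 64, 64)

Note that when ratio_gin > 0 the global input must be a tensor: a bare tensor input is unpacked as (x, 0), which only suits the all-local configuration.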
models/transformer.py
ADDED
@@ -0,0 +1,119 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+from einops import rearrange
+
+
+class GELU(nn.Module):
+    def __init__(self):
+        super(GELU, self).__init__()
+
+    def forward(self, x):
+        # tanh approximation of GELU (torch.tanh; F.tanh is deprecated)
+        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+# helpers
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+# classes
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+class DualPreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.normx = nn.LayerNorm(dim)
+        self.normy = nn.LayerNorm(dim)
+        self.fn = fn
+
+    def forward(self, x, y, **kwargs):
+        return self.fn(self.normx(x), self.normy(y), **kwargs)
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.attend = nn.Softmax(dim = -1)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_k = nn.Linear(dim, inner_dim, bias = False)
+        self.to_v = nn.Linear(dim, inner_dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self, x, y):
+        # q and k come from x (the degraded/zero feature); v comes from y (the reference features)
+        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads)
+        k = rearrange(self.to_k(x), 'b n (h d) -> b h n d', h = self.heads)
+        v = rearrange(self.to_v(y), 'b n (h d) -> b h n d', h = self.heads)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+        attn = self.attend(dots)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                DualPreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+            ]))
+
+    def forward(self, x, y):  # x is the cropped feature map, y is the foreign reference
+        bs, c, h, w = x.size()
+
+        # image to token embedding: (bs, c, h, w) -> (bs, h*w, c)
+        x = x.view(bs, c, -1).permute(0, 2, 1)
+        y = y.view(bs, c, -1).permute(0, 2, 1)
+
+        for attn, ff in self.layers:
+            x = attn(x, y) + x
+            x = ff(x) + x
+
+        x = x.view(bs, h, w, c).permute(0, 3, 1, 2)
+        return x
+
+class RETURNX(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):  # identity baseline: ignore the reference and return x
+        return x
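A hedged usage sketch of the cross-attention fuser (not part of the upload; all sizes illustrative). dim must match the channel count of both feature maps, and the token count is h*w, so it is intended for small spatial grids.

import torch

fuser = Transformer(dim=256, depth=2, heads=4, dim_head=64, mlp_dim=512)

x = torch.randn(2, 256, 16, 16)  # cropped/degraded features (queries and keys)
y = torch.randn(2, 256, 16, 16)  # foreign reference features (values)

fused = fuser(x, y)  # (2, 256, 16, 16), same shape as x

RETURNX is the matching identity baseline: it keeps the (x, y) interface but ignores the reference, so RETURNX()(x, y) simply returns x.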