Shadhil commited on
Commit
261df59
·
verified ·
1 Parent(s): 1a79cb6

Upload 50 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. checkpoints/30_net_gen.pth +3 -0
  3. checkpoints/BFM/.gitkeep +0 -0
  4. checkpoints/BFM/01_MorphableModel.mat +3 -0
  5. checkpoints/BFM/BFM_exp_idx.mat +0 -0
  6. checkpoints/BFM/BFM_front_idx.mat +0 -0
  7. checkpoints/BFM/BFM_model_front.mat +3 -0
  8. checkpoints/BFM/Exp_Pca.bin +3 -0
  9. checkpoints/BFM/facemodel_info.mat +0 -0
  10. checkpoints/BFM/select_vertex_id.mat +0 -0
  11. checkpoints/BFM/similarity_Lm3D_all.mat +0 -0
  12. checkpoints/BFM/std_exp.txt +1 -0
  13. checkpoints/DNet.pt +3 -0
  14. checkpoints/ENet.pth +3 -0
  15. checkpoints/GFPGANv1.3.pth +3 -0
  16. checkpoints/GPEN-BFR-512.pth +3 -0
  17. checkpoints/LNet.pth +3 -0
  18. checkpoints/ParseNet-latest.pth +3 -0
  19. checkpoints/RetinaFace-R50.pth +3 -0
  20. checkpoints/expression.mat +0 -0
  21. checkpoints/face3d_pretrain_epoch_20.pth +3 -0
  22. checkpoints/shape_predictor_68_face_landmarks.dat +3 -0
  23. models/DNet.py +118 -0
  24. models/ENet.py +139 -0
  25. models/LNet.py +139 -0
  26. models/__init__.py +37 -0
  27. models/__pycache__/DNet.cpython-37.pyc +0 -0
  28. models/__pycache__/DNet.cpython-38.pyc +0 -0
  29. models/__pycache__/DNet.cpython-39.pyc +0 -0
  30. models/__pycache__/ENet.cpython-37.pyc +0 -0
  31. models/__pycache__/ENet.cpython-38.pyc +0 -0
  32. models/__pycache__/ENet.cpython-39.pyc +0 -0
  33. models/__pycache__/LNet.cpython-37.pyc +0 -0
  34. models/__pycache__/LNet.cpython-38.pyc +0 -0
  35. models/__pycache__/LNet.cpython-39.pyc +0 -0
  36. models/__pycache__/__init__.cpython-37.pyc +0 -0
  37. models/__pycache__/__init__.cpython-38.pyc +0 -0
  38. models/__pycache__/__init__.cpython-39.pyc +0 -0
  39. models/__pycache__/base_blocks.cpython-37.pyc +0 -0
  40. models/__pycache__/base_blocks.cpython-38.pyc +0 -0
  41. models/__pycache__/base_blocks.cpython-39.pyc +0 -0
  42. models/__pycache__/ffc.cpython-37.pyc +0 -0
  43. models/__pycache__/ffc.cpython-38.pyc +0 -0
  44. models/__pycache__/ffc.cpython-39.pyc +0 -0
  45. models/__pycache__/transformer.cpython-37.pyc +0 -0
  46. models/__pycache__/transformer.cpython-38.pyc +0 -0
  47. models/__pycache__/transformer.cpython-39.pyc +0 -0
  48. models/base_blocks.py +554 -0
  49. models/ffc.py +233 -0
  50. models/transformer.py +119 -0
.gitattributes CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  temp/temp/temp.wav filter=lfs diff=lfs merge=lfs -text
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  temp/temp/temp.wav filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/BFM/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text
38
+ checkpoints/BFM/BFM_model_front.mat filter=lfs diff=lfs merge=lfs -text
39
+ checkpoints/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
checkpoints/30_net_gen.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4db83e1727128e2c5de27bc80d2929586535e04a709af45016a63e7cf7c46b0c
3
+ size 33877439
checkpoints/BFM/.gitkeep ADDED
File without changes
checkpoints/BFM/01_MorphableModel.mat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
3
+ size 240875364
checkpoints/BFM/BFM_exp_idx.mat ADDED
Binary file (91.9 kB). View file
 
checkpoints/BFM/BFM_front_idx.mat ADDED
Binary file (44.9 kB). View file
 
checkpoints/BFM/BFM_model_front.mat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae3ff544aba3246c5f2c117f2be76fa44a7b76145326aae0bbfbfb564d4f82af
3
+ size 127170280
checkpoints/BFM/Exp_Pca.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726
3
+ size 51086404
checkpoints/BFM/facemodel_info.mat ADDED
Binary file (739 kB). View file
 
checkpoints/BFM/select_vertex_id.mat ADDED
Binary file (62.3 kB). View file
 
checkpoints/BFM/similarity_Lm3D_all.mat ADDED
Binary file (994 Bytes). View file
 
checkpoints/BFM/std_exp.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19
checkpoints/DNet.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41220d2973c0ba2eab6e8f17ed00711aef5a0d76d19808f885dc0e3251df2e80
3
+ size 180424655
checkpoints/ENet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:967ee3ed857619cedd92b6407dc8a124cbfe763cc11cad58316fe21271a8928f
3
+ size 573261168
checkpoints/GFPGANv1.3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70
3
+ size 348632874
checkpoints/GPEN-BFR-512.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1002c41add95b0decad69604d80455576f7187dd99ca16bd611bcfd44c10b51
3
+ size 284085738
checkpoints/LNet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ae06fef0454c421b828cc53e8d4b9c92d990867a858ea7bb9661ab6cf6ab774
3
+ size 1534697728
checkpoints/ParseNet-latest.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
3
+ size 85331193
checkpoints/RetinaFace-R50.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
3
+ size 109497761
checkpoints/expression.mat ADDED
Binary file (1.46 kB). View file
 
checkpoints/face3d_pretrain_epoch_20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d17a6b23457b521801baae583cb6a58f7238fe6721fc3d65d76407460e9149b
3
+ size 288860037
checkpoints/shape_predictor_68_face_landmarks.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
3
+ size 99693937
models/DNet.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO
2
+ import functools
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from utils import flow_util
10
+ from models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
11
+
12
+ # DNet
13
+ class DNet(nn.Module):
14
+ def __init__(self):
15
+ super(DNet, self).__init__()
16
+ self.mapping_net = MappingNet()
17
+ self.warpping_net = WarpingNet()
18
+ self.editing_net = EditingNet()
19
+
20
+ def forward(self, input_image, driving_source, stage=None):
21
+ if stage == 'warp':
22
+ descriptor = self.mapping_net(driving_source)
23
+ output = self.warpping_net(input_image, descriptor)
24
+ else:
25
+ descriptor = self.mapping_net(driving_source)
26
+ output = self.warpping_net(input_image, descriptor)
27
+ output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
28
+ return output
29
+
30
+ class MappingNet(nn.Module):
31
+ def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3):
32
+ super( MappingNet, self).__init__()
33
+
34
+ self.layer = layer
35
+ nonlinearity = nn.LeakyReLU(0.1)
36
+
37
+ self.first = nn.Sequential(
38
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
39
+
40
+ for i in range(layer):
41
+ net = nn.Sequential(nonlinearity,
42
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
43
+ setattr(self, 'encoder' + str(i), net)
44
+
45
+ self.pooling = nn.AdaptiveAvgPool1d(1)
46
+ self.output_nc = descriptor_nc
47
+
48
+ def forward(self, input_3dmm):
49
+ out = self.first(input_3dmm)
50
+ for i in range(self.layer):
51
+ model = getattr(self, 'encoder' + str(i))
52
+ out = model(out) + out[:,:,3:-3]
53
+ out = self.pooling(out)
54
+ return out
55
+
56
+ class WarpingNet(nn.Module):
57
+ def __init__(
58
+ self,
59
+ image_nc=3,
60
+ descriptor_nc=256,
61
+ base_nc=32,
62
+ max_nc=256,
63
+ encoder_layer=5,
64
+ decoder_layer=3,
65
+ use_spect=False
66
+ ):
67
+ super( WarpingNet, self).__init__()
68
+
69
+ nonlinearity = nn.LeakyReLU(0.1)
70
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
71
+ kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}
72
+
73
+ self.descriptor_nc = descriptor_nc
74
+ self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
75
+ max_nc, encoder_layer, decoder_layer, **kwargs)
76
+
77
+ self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
78
+ nonlinearity,
79
+ nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
80
+
81
+ self.pool = nn.AdaptiveAvgPool2d(1)
82
+
83
+ def forward(self, input_image, descriptor):
84
+ final_output={}
85
+ output = self.hourglass(input_image, descriptor)
86
+ final_output['flow_field'] = self.flow_out(output)
87
+
88
+ deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
89
+ final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
90
+ return final_output
91
+
92
+
93
+ class EditingNet(nn.Module):
94
+ def __init__(
95
+ self,
96
+ image_nc=3,
97
+ descriptor_nc=256,
98
+ layer=3,
99
+ base_nc=64,
100
+ max_nc=256,
101
+ num_res_blocks=2,
102
+ use_spect=False):
103
+ super(EditingNet, self).__init__()
104
+
105
+ nonlinearity = nn.LeakyReLU(0.1)
106
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
107
+ kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
108
+ self.descriptor_nc = descriptor_nc
109
+
110
+ # encoder part
111
+ self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
112
+ self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
113
+
114
+ def forward(self, input_image, warp_image, descriptor):
115
+ x = torch.cat([input_image, warp_image], 1)
116
+ x = self.encoder(x)
117
+ gen_image = self.decoder(x, descriptor)
118
+ return gen_image
models/ENet.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from models.base_blocks import ResBlock, StyleConv, ToRGB
6
+
7
+
8
+ class ENet(nn.Module):
9
+ def __init__(
10
+ self,
11
+ num_style_feat=512,
12
+ lnet=None,
13
+ concat=False
14
+ ):
15
+ super(ENet, self).__init__()
16
+
17
+ self.low_res = lnet
18
+ for param in self.low_res.parameters():
19
+ param.requires_grad = False
20
+
21
+ channel_multiplier, narrow = 2, 1
22
+ channels = {
23
+ '4': int(512 * narrow),
24
+ '8': int(512 * narrow),
25
+ '16': int(512 * narrow),
26
+ '32': int(512 * narrow),
27
+ '64': int(256 * channel_multiplier * narrow),
28
+ '128': int(128 * channel_multiplier * narrow),
29
+ '256': int(64 * channel_multiplier * narrow),
30
+ '512': int(32 * channel_multiplier * narrow),
31
+ '1024': int(16 * channel_multiplier * narrow)
32
+ }
33
+
34
+ self.log_size = 8
35
+ first_out_size = 128
36
+ self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1) # 256 -> 128
37
+
38
+ # downsample
39
+ in_channels = channels[f'{first_out_size}']
40
+ self.conv_body_down = nn.ModuleList()
41
+ for i in range(8, 2, -1):
42
+ out_channels = channels[f'{2**(i - 1)}']
43
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
44
+ in_channels = out_channels
45
+
46
+ self.num_style_feat = num_style_feat
47
+ linear_out_channel = num_style_feat
48
+ self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
49
+ self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
50
+
51
+ self.style_convs = nn.ModuleList()
52
+ self.to_rgbs = nn.ModuleList()
53
+ self.noises = nn.Module()
54
+
55
+ self.concat = concat
56
+ if concat:
57
+ in_channels = 3 + 32 # channels['64']
58
+ else:
59
+ in_channels = 3
60
+
61
+ for i in range(7, 9): # 128, 256
62
+ out_channels = channels[f'{2**i}'] #
63
+ self.style_convs.append(
64
+ StyleConv(
65
+ in_channels,
66
+ out_channels,
67
+ kernel_size=3,
68
+ num_style_feat=num_style_feat,
69
+ demodulate=True,
70
+ sample_mode='upsample'))
71
+ self.style_convs.append(
72
+ StyleConv(
73
+ out_channels,
74
+ out_channels,
75
+ kernel_size=3,
76
+ num_style_feat=num_style_feat,
77
+ demodulate=True,
78
+ sample_mode=None))
79
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
80
+ in_channels = out_channels
81
+
82
+ def forward(self, audio_sequences, face_sequences, gt_sequences):
83
+ B = audio_sequences.size(0)
84
+ input_dim_size = len(face_sequences.size())
85
+ inp, ref = torch.split(face_sequences,3,dim=1)
86
+
87
+ if input_dim_size > 4:
88
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
89
+ inp = torch.cat([inp[:, :, i] for i in range(inp.size(2))], dim=0)
90
+ ref = torch.cat([ref[:, :, i] for i in range(ref.size(2))], dim=0)
91
+ gt_sequences = torch.cat([gt_sequences[:, :, i] for i in range(gt_sequences.size(2))], dim=0)
92
+
93
+ # get the global style
94
+ feat = F.leaky_relu_(self.conv_body_first(F.interpolate(ref, size=(256,256), mode='bilinear')), negative_slope=0.2)
95
+ for i in range(self.log_size - 2):
96
+ feat = self.conv_body_down[i](feat)
97
+ feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
98
+
99
+ # style code
100
+ style_code = self.final_linear(feat.reshape(feat.size(0), -1))
101
+ style_code = style_code.reshape(style_code.size(0), -1, self.num_style_feat)
102
+
103
+ LNet_input = torch.cat([inp, gt_sequences], dim=1)
104
+ LNet_input = F.interpolate(LNet_input, size=(96,96), mode='bilinear')
105
+
106
+ if self.concat:
107
+ low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input)
108
+ low_res_img.detach()
109
+ low_res_feat.detach()
110
+ out = torch.cat([low_res_img, low_res_feat], dim=1)
111
+
112
+ else:
113
+ low_res_img = self.low_res(audio_sequences, LNet_input)
114
+ low_res_img.detach()
115
+ # 96 x 96
116
+ out = low_res_img
117
+
118
+ p2d = (2,2,2,2)
119
+ out = F.pad(out, p2d, "reflect", 0)
120
+ skip = out
121
+
122
+ for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs):
123
+ out = conv1(out, style_code) # 96, 192, 384
124
+ out = conv2(out, style_code)
125
+ skip = to_rgb(out, style_code, skip)
126
+ _outputs = skip
127
+
128
+ # remove padding
129
+ _outputs = _outputs[:,:,8:-8,8:-8]
130
+
131
+ if input_dim_size > 4:
132
+ _outputs = torch.split(_outputs, B, dim=0)
133
+ outputs = torch.stack(_outputs, dim=2)
134
+ low_res_img = F.interpolate(low_res_img, outputs.size()[3:])
135
+ low_res_img = torch.split(low_res_img, B, dim=0)
136
+ low_res_img = torch.stack(low_res_img, dim=2)
137
+ else:
138
+ outputs = _outputs
139
+ return outputs, low_res_img
models/LNet.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from models.transformer import RETURNX, Transformer
6
+ from models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
7
+ FFCADAINResBlocks, Jump, FinalBlock2d
8
+
9
+
10
+ class Visual_Encoder(nn.Module):
11
+ def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
12
+ super(Visual_Encoder, self).__init__()
13
+ self.layers = layers
14
+ self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
15
+ self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
16
+ for i in range(layers):
17
+ in_channels = min(ngf*(2**i), img_f)
18
+ out_channels = min(ngf*(2**(i+1)), img_f)
19
+ model_ref = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
20
+ model_inp = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
21
+ if i < 2:
22
+ ca_layer = RETURNX()
23
+ else:
24
+ ca_layer = Transformer(2**(i+1) * ngf,2,4,ngf,ngf*4)
25
+ setattr(self, 'ca' + str(i), ca_layer)
26
+ setattr(self, 'ref_down' + str(i), model_ref)
27
+ setattr(self, 'inp_down' + str(i), model_inp)
28
+ self.output_nc = out_channels * 2
29
+
30
+ def forward(self, maskGT, ref):
31
+ x_maskGT, x_ref = self.first_inp(maskGT), self.first_ref(ref)
32
+ out=[x_maskGT]
33
+ for i in range(self.layers):
34
+ model_ref = getattr(self, 'ref_down'+str(i))
35
+ model_inp = getattr(self, 'inp_down'+str(i))
36
+ ca_layer = getattr(self, 'ca'+str(i))
37
+ x_maskGT, x_ref = model_inp(x_maskGT), model_ref(x_ref)
38
+ x_maskGT = ca_layer(x_maskGT, x_ref)
39
+ if i < self.layers - 1:
40
+ out.append(x_maskGT)
41
+ else:
42
+ out.append(torch.cat([x_maskGT, x_ref], dim=1)) # concat ref features !
43
+ return out
44
+
45
+
46
+ class Decoder(nn.Module):
47
+ def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
48
+ super(Decoder, self).__init__()
49
+ self.layers = layers
50
+ for i in range(layers)[::-1]:
51
+ if i == layers-1:
52
+ in_channels = ngf*(2**(i+1)) * 2
53
+ else:
54
+ in_channels = min(ngf*(2**(i+1)), img_f)
55
+ out_channels = min(ngf*(2**i), img_f)
56
+ up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
57
+ res = FFCADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
58
+ jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
59
+
60
+ setattr(self, 'up' + str(i), up)
61
+ setattr(self, 'res' + str(i), res)
62
+ setattr(self, 'jump' + str(i), jump)
63
+
64
+ self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'sigmoid')
65
+ self.output_nc = out_channels
66
+
67
+ def forward(self, x, z):
68
+ out = x.pop()
69
+ for i in range(self.layers)[::-1]:
70
+ res_model = getattr(self, 'res' + str(i))
71
+ up_model = getattr(self, 'up' + str(i))
72
+ jump_model = getattr(self, 'jump' + str(i))
73
+ out = res_model(out, z)
74
+ out = up_model(out)
75
+ out = jump_model(x.pop()) + out
76
+ out_image = self.final(out)
77
+ return out_image
78
+
79
+
80
+ class LNet(nn.Module):
81
+ def __init__(
82
+ self,
83
+ image_nc=3,
84
+ descriptor_nc=512,
85
+ layer=3,
86
+ base_nc=64,
87
+ max_nc=512,
88
+ num_res_blocks=9,
89
+ use_spect=True,
90
+ encoder=Visual_Encoder,
91
+ decoder=Decoder
92
+ ):
93
+ super(LNet, self).__init__()
94
+
95
+ nonlinearity = nn.LeakyReLU(0.1)
96
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
97
+ kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
98
+ self.descriptor_nc = descriptor_nc
99
+
100
+ self.encoder = encoder(image_nc, base_nc, max_nc, layer, **kwargs)
101
+ self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
102
+ self.audio_encoder = nn.Sequential(
103
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
104
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
105
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
106
+
107
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
108
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
109
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
110
+
111
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
112
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
113
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
114
+
115
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
116
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
117
+
118
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
119
+ Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0),
120
+ )
121
+
122
+ def forward(self, audio_sequences, face_sequences):
123
+ B = audio_sequences.size(0)
124
+ input_dim_size = len(face_sequences.size())
125
+ if input_dim_size > 4:
126
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
127
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
128
+ cropped, ref = torch.split(face_sequences, 3, dim=1)
129
+
130
+ vis_feat = self.encoder(cropped, ref)
131
+ audio_feat = self.audio_encoder(audio_sequences)
132
+ _outputs = self.decoder(vis_feat, audio_feat)
133
+
134
+ if input_dim_size > 4:
135
+ _outputs = torch.split(_outputs, B, dim=0)
136
+ outputs = torch.stack(_outputs, dim=2)
137
+ else:
138
+ outputs = _outputs
139
+ return outputs
models/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models.DNet import DNet
3
+ from models.LNet import LNet
4
+ from models.ENet import ENet
5
+
6
+
7
+ def _load(checkpoint_path):
8
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
9
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
10
+ return checkpoint
11
+
12
+ def load_checkpoint(path, model):
13
+ print("Load checkpoint from: {}".format(path))
14
+ checkpoint = _load(path)
15
+ s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint
16
+ new_s = {}
17
+ for k, v in s.items():
18
+ if 'low_res' in k:
19
+ continue
20
+ else:
21
+ new_s[k.replace('module.', '')] = v
22
+ model.load_state_dict(new_s, strict=False)
23
+ return model
24
+
25
+ def load_network(args):
26
+ L_net = LNet()
27
+ L_net = load_checkpoint(args.LNet_path, L_net)
28
+ E_net = ENet(lnet=L_net)
29
+ model = load_checkpoint(args.ENet_path, E_net)
30
+ return model.eval()
31
+
32
+ def load_DNet(args):
33
+ D_Net = DNet()
34
+ print("Load checkpoint from: {}".format(args.DNet_path))
35
+ checkpoint = torch.load(args.DNet_path, map_location=lambda storage, loc: storage)
36
+ D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False)
37
+ return D_Net.eval()
models/__pycache__/DNet.cpython-37.pyc ADDED
Binary file (4.08 kB). View file
 
models/__pycache__/DNet.cpython-38.pyc ADDED
Binary file (4.04 kB). View file
 
models/__pycache__/DNet.cpython-39.pyc ADDED
Binary file (4.05 kB). View file
 
models/__pycache__/ENet.cpython-37.pyc ADDED
Binary file (3.76 kB). View file
 
models/__pycache__/ENet.cpython-38.pyc ADDED
Binary file (3.78 kB). View file
 
models/__pycache__/ENet.cpython-39.pyc ADDED
Binary file (3.76 kB). View file
 
models/__pycache__/LNet.cpython-37.pyc ADDED
Binary file (4.85 kB). View file
 
models/__pycache__/LNet.cpython-38.pyc ADDED
Binary file (4.83 kB). View file
 
models/__pycache__/LNet.cpython-39.pyc ADDED
Binary file (4.83 kB). View file
 
models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.55 kB). View file
 
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.54 kB). View file
 
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.54 kB). View file
 
models/__pycache__/base_blocks.cpython-37.pyc ADDED
Binary file (20.9 kB). View file
 
models/__pycache__/base_blocks.cpython-38.pyc ADDED
Binary file (20.2 kB). View file
 
models/__pycache__/base_blocks.cpython-39.pyc ADDED
Binary file (20.2 kB). View file
 
models/__pycache__/ffc.cpython-37.pyc ADDED
Binary file (6.98 kB). View file
 
models/__pycache__/ffc.cpython-38.pyc ADDED
Binary file (7.07 kB). View file
 
models/__pycache__/ffc.cpython-39.pyc ADDED
Binary file (6.96 kB). View file
 
models/__pycache__/transformer.cpython-37.pyc ADDED
Binary file (4.91 kB). View file
 
models/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (4.81 kB). View file
 
models/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (4.82 kB). View file
 
models/base_blocks.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.modules.batchnorm import BatchNorm2d
6
+ from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
7
+
8
+ from models.ffc import FFC
9
+ from basicsr.archs.arch_util import default_init_weights
10
+
11
+
12
+ class Conv2d(nn.Module):
13
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
14
+ super().__init__(*args, **kwargs)
15
+ self.conv_block = nn.Sequential(
16
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
17
+ nn.BatchNorm2d(cout)
18
+ )
19
+ self.act = nn.ReLU()
20
+ self.residual = residual
21
+
22
+ def forward(self, x):
23
+ out = self.conv_block(x)
24
+ if self.residual:
25
+ out += x
26
+ return self.act(out)
27
+
28
+
29
+ class ResBlock(nn.Module):
30
+ def __init__(self, in_channels, out_channels, mode='down'):
31
+ super(ResBlock, self).__init__()
32
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
33
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
34
+ self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
35
+ if mode == 'down':
36
+ self.scale_factor = 0.5
37
+ elif mode == 'up':
38
+ self.scale_factor = 2
39
+
40
+ def forward(self, x):
41
+ out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
42
+ # upsample/downsample
43
+ out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
44
+ out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
45
+ # skip
46
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
47
+ skip = self.skip(x)
48
+ out = out + skip
49
+ return out
50
+
51
+
52
+ class LayerNorm2d(nn.Module):
53
+ def __init__(self, n_out, affine=True):
54
+ super(LayerNorm2d, self).__init__()
55
+ self.n_out = n_out
56
+ self.affine = affine
57
+
58
+ if self.affine:
59
+ self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
60
+ self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
61
+
62
+ def forward(self, x):
63
+ normalized_shape = x.size()[1:]
64
+ if self.affine:
65
+ return F.layer_norm(x, normalized_shape, \
66
+ self.weight.expand(normalized_shape),
67
+ self.bias.expand(normalized_shape))
68
+ else:
69
+ return F.layer_norm(x, normalized_shape)
70
+
71
+
72
+ def spectral_norm(module, use_spect=True):
73
+ if use_spect:
74
+ return SpectralNorm(module)
75
+ else:
76
+ return module
77
+
78
+
79
+ class FirstBlock2d(nn.Module):
80
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
81
+ super(FirstBlock2d, self).__init__()
82
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
83
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
84
+
85
+ if type(norm_layer) == type(None):
86
+ self.model = nn.Sequential(conv, nonlinearity)
87
+ else:
88
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
89
+
90
+ def forward(self, x):
91
+ out = self.model(x)
92
+ return out
93
+
94
+
95
+ class DownBlock2d(nn.Module):
96
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
97
+ super(DownBlock2d, self).__init__()
98
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
99
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
100
+ pool = nn.AvgPool2d(kernel_size=(2, 2))
101
+
102
+ if type(norm_layer) == type(None):
103
+ self.model = nn.Sequential(conv, nonlinearity, pool)
104
+ else:
105
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
106
+
107
+ def forward(self, x):
108
+ out = self.model(x)
109
+ return out
110
+
111
+
112
+ class UpBlock2d(nn.Module):
113
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
114
+ super(UpBlock2d, self).__init__()
115
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
116
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
117
+ if type(norm_layer) == type(None):
118
+ self.model = nn.Sequential(conv, nonlinearity)
119
+ else:
120
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
121
+
122
+ def forward(self, x):
123
+ out = self.model(F.interpolate(x, scale_factor=2))
124
+ return out
125
+
126
+
127
+ class ADAIN(nn.Module):
128
+ def __init__(self, norm_nc, feature_nc):
129
+ super().__init__()
130
+
131
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
132
+
133
+ nhidden = 128
134
+ use_bias=True
135
+
136
+ self.mlp_shared = nn.Sequential(
137
+ nn.Linear(feature_nc, nhidden, bias=use_bias),
138
+ nn.ReLU()
139
+ )
140
+ self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
141
+ self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
142
+
143
+ def forward(self, x, feature):
144
+
145
+ # Part 1. generate parameter-free normalized activations
146
+ normalized = self.param_free_norm(x)
147
+ # Part 2. produce scaling and bias conditioned on feature
148
+ feature = feature.view(feature.size(0), -1)
149
+ actv = self.mlp_shared(feature)
150
+ gamma = self.mlp_gamma(actv)
151
+ beta = self.mlp_beta(actv)
152
+
153
+ # apply scale and bias
154
+ gamma = gamma.view(*gamma.size()[:2], 1,1)
155
+ beta = beta.view(*beta.size()[:2], 1,1)
156
+ out = normalized * (1 + gamma) + beta
157
+ return out
158
+
159
+
160
+ class FineADAINResBlock2d(nn.Module):
161
+ """
162
+ Define an Residual block for different types
163
+ """
164
+ def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
165
+ super(FineADAINResBlock2d, self).__init__()
166
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
167
+ self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
168
+ self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
169
+ self.norm1 = ADAIN(input_nc, feature_nc)
170
+ self.norm2 = ADAIN(input_nc, feature_nc)
171
+ self.actvn = nonlinearity
172
+
173
+ def forward(self, x, z):
174
+ dx = self.actvn(self.norm1(self.conv1(x), z))
175
+ dx = self.norm2(self.conv2(x), z)
176
+ out = dx + x
177
+ return out
178
+
179
+
180
+ class FineADAINResBlocks(nn.Module):
181
+ def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
182
+ super(FineADAINResBlocks, self).__init__()
183
+ self.num_block = num_block
184
+ for i in range(num_block):
185
+ model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
186
+ setattr(self, 'res'+str(i), model)
187
+
188
+ def forward(self, x, z):
189
+ for i in range(self.num_block):
190
+ model = getattr(self, 'res'+str(i))
191
+ x = model(x, z)
192
+ return x
193
+
194
+
195
+ class ADAINEncoderBlock(nn.Module):
196
+ def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
197
+ super(ADAINEncoderBlock, self).__init__()
198
+ kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
199
+ kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
200
+
201
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
202
+ self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
203
+
204
+
205
+ self.norm_0 = ADAIN(input_nc, feature_nc)
206
+ self.norm_1 = ADAIN(output_nc, feature_nc)
207
+ self.actvn = nonlinearity
208
+
209
+ def forward(self, x, z):
210
+ x = self.conv_0(self.actvn(self.norm_0(x, z)))
211
+ x = self.conv_1(self.actvn(self.norm_1(x, z)))
212
+ return x
213
+
214
+
215
+ class ADAINDecoderBlock(nn.Module):
216
+ def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
217
+ super(ADAINDecoderBlock, self).__init__()
218
+ # Attributes
219
+ self.actvn = nonlinearity
220
+ hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
221
+
222
+ kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
223
+ if use_transpose:
224
+ kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
225
+ else:
226
+ kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}
227
+
228
+ # create conv layers
229
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
230
+ if use_transpose:
231
+ self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
232
+ self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
233
+ else:
234
+ self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
235
+ nn.Upsample(scale_factor=2))
236
+ self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
237
+ nn.Upsample(scale_factor=2))
238
+ # define normalization layers
239
+ self.norm_0 = ADAIN(input_nc, feature_nc)
240
+ self.norm_1 = ADAIN(hidden_nc, feature_nc)
241
+ self.norm_s = ADAIN(input_nc, feature_nc)
242
+
243
+ def forward(self, x, z):
244
+ x_s = self.shortcut(x, z)
245
+ dx = self.conv_0(self.actvn(self.norm_0(x, z)))
246
+ dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
247
+ out = x_s + dx
248
+ return out
249
+
250
+ def shortcut(self, x, z):
251
+ x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
252
+ return x_s
253
+
254
+
255
+ class FineEncoder(nn.Module):
256
+ """docstring for Encoder"""
257
+ def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
258
+ super(FineEncoder, self).__init__()
259
+ self.layers = layers
260
+ self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
261
+ for i in range(layers):
262
+ in_channels = min(ngf*(2**i), img_f)
263
+ out_channels = min(ngf*(2**(i+1)), img_f)
264
+ model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
265
+ setattr(self, 'down' + str(i), model)
266
+ self.output_nc = out_channels
267
+
268
+ def forward(self, x):
269
+ x = self.first(x)
270
+ out=[x]
271
+ for i in range(self.layers):
272
+ model = getattr(self, 'down'+str(i))
273
+ x = model(x)
274
+ out.append(x)
275
+ return out
276
+
277
+
278
+ class FineDecoder(nn.Module):
279
+ """docstring for FineDecoder"""
280
+ def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
281
+ super(FineDecoder, self).__init__()
282
+ self.layers = layers
283
+ for i in range(layers)[::-1]:
284
+ in_channels = min(ngf*(2**(i+1)), img_f)
285
+ out_channels = min(ngf*(2**i), img_f)
286
+ up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
287
+ res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
288
+ jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
289
+ setattr(self, 'up' + str(i), up)
290
+ setattr(self, 'res' + str(i), res)
291
+ setattr(self, 'jump' + str(i), jump)
292
+ self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
293
+ self.output_nc = out_channels
294
+
295
+ def forward(self, x, z):
296
+ out = x.pop()
297
+ for i in range(self.layers)[::-1]:
298
+ res_model = getattr(self, 'res' + str(i))
299
+ up_model = getattr(self, 'up' + str(i))
300
+ jump_model = getattr(self, 'jump' + str(i))
301
+ out = res_model(out, z)
302
+ out = up_model(out)
303
+ out = jump_model(x.pop()) + out
304
+ out_image = self.final(out)
305
+ return out_image
306
+
307
+
308
+ class ADAINEncoder(nn.Module):
309
+ def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
310
+ super(ADAINEncoder, self).__init__()
311
+ self.layers = layers
312
+ self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
313
+ for i in range(layers):
314
+ in_channels = min(ngf * (2**i), img_f)
315
+ out_channels = min(ngf *(2**(i+1)), img_f)
316
+ model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
317
+ setattr(self, 'encoder' + str(i), model)
318
+ self.output_nc = out_channels
319
+
320
+ def forward(self, x, z):
321
+ out = self.input_layer(x)
322
+ out_list = [out]
323
+ for i in range(self.layers):
324
+ model = getattr(self, 'encoder' + str(i))
325
+ out = model(out, z)
326
+ out_list.append(out)
327
+ return out_list
328
+
329
+
330
+ class ADAINDecoder(nn.Module):
331
+ """docstring for ADAINDecoder"""
332
+ def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
333
+ nonlinearity=nn.LeakyReLU(), use_spect=False):
334
+
335
+ super(ADAINDecoder, self).__init__()
336
+ self.encoder_layers = encoder_layers
337
+ self.decoder_layers = decoder_layers
338
+ self.skip_connect = skip_connect
339
+ use_transpose = True
340
+ for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
341
+ in_channels = min(ngf * (2**(i+1)), img_f)
342
+ in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
343
+ out_channels = min(ngf * (2**i), img_f)
344
+ model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
345
+ setattr(self, 'decoder' + str(i), model)
346
+ self.output_nc = out_channels*2 if self.skip_connect else out_channels
347
+
348
+ def forward(self, x, z):
349
+ out = x.pop() if self.skip_connect else x
350
+ for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
351
+ model = getattr(self, 'decoder' + str(i))
352
+ out = model(out, z)
353
+ out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
354
+ return out
355
+
356
+
357
+ class ADAINHourglass(nn.Module):
358
+ def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
359
+ super(ADAINHourglass, self).__init__()
360
+ self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
361
+ self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
362
+ self.output_nc = self.decoder.output_nc
363
+
364
+ def forward(self, x, z):
365
+ return self.decoder(self.encoder(x, z), z)
366
+
367
+
368
+ class FineADAINLama(nn.Module):
369
+ def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
370
+ super(FineADAINLama, self).__init__()
371
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
372
+ self.actvn = nonlinearity
373
+ ratio_gin = 0.75
374
+ ratio_gout = 0.75
375
+ self.ffc = FFC(input_nc, input_nc, 3,
376
+ ratio_gin, ratio_gout, 1, 1, 1,
377
+ 1, False, False, padding_type='reflect')
378
+ global_channels = int(input_nc * ratio_gout)
379
+ self.bn_l = ADAIN(input_nc - global_channels, feature_nc)
380
+ self.bn_g = ADAIN(global_channels, feature_nc)
381
+
382
+ def forward(self, x, z):
383
+ x_l, x_g = self.ffc(x)
384
+ x_l = self.actvn(self.bn_l(x_l,z))
385
+ x_g = self.actvn(self.bn_g(x_g,z))
386
+ return x_l, x_g
387
+
388
+
389
+ class FFCResnetBlock(nn.Module):
390
+ def __init__(self, dim, feature_dim, padding_type='reflect', norm_layer=BatchNorm2d, activation_layer=nn.ReLU, dilation=1,
391
+ spatial_transform_kwargs=None, inline=False, **conv_kwargs):
392
+ super().__init__()
393
+ self.conv1 = FineADAINLama(dim, feature_dim, **conv_kwargs)
394
+ self.conv2 = FineADAINLama(dim, feature_dim, **conv_kwargs)
395
+ self.inline = True
396
+
397
+ def forward(self, x, z):
398
+ if self.inline:
399
+ x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
400
+ else:
401
+ x_l, x_g = x if type(x) is tuple else (x, 0)
402
+
403
+ id_l, id_g = x_l, x_g
404
+ x_l, x_g = self.conv1((x_l, x_g), z)
405
+ x_l, x_g = self.conv2((x_l, x_g), z)
406
+
407
+ x_l, x_g = id_l + x_l, id_g + x_g
408
+ out = x_l, x_g
409
+ if self.inline:
410
+ out = torch.cat(out, dim=1)
411
+ return out
412
+
413
+
414
+ class FFCADAINResBlocks(nn.Module):
415
+ def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
416
+ super(FFCADAINResBlocks, self).__init__()
417
+ self.num_block = num_block
418
+ for i in range(num_block):
419
+ model = FFCResnetBlock(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
420
+ setattr(self, 'res'+str(i), model)
421
+
422
+ def forward(self, x, z):
423
+ for i in range(self.num_block):
424
+ model = getattr(self, 'res'+str(i))
425
+ x = model(x, z)
426
+ return x
427
+
428
+
429
+ class Jump(nn.Module):
430
+ def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
431
+ super(Jump, self).__init__()
432
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
433
+ conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
434
+ if type(norm_layer) == type(None):
435
+ self.model = nn.Sequential(conv, nonlinearity)
436
+ else:
437
+ self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
438
+
439
+ def forward(self, x):
440
+ out = self.model(x)
441
+ return out
442
+
443
+
444
+ class FinalBlock2d(nn.Module):
445
+ def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
446
+ super(FinalBlock2d, self).__init__()
447
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
448
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
449
+ if tanh_or_sigmoid == 'sigmoid':
450
+ out_nonlinearity = nn.Sigmoid()
451
+ else:
452
+ out_nonlinearity = nn.Tanh()
453
+ self.model = nn.Sequential(conv, out_nonlinearity)
454
+
455
+ def forward(self, x):
456
+ out = self.model(x)
457
+ return out
458
+
459
+
460
+ class ModulatedConv2d(nn.Module):
461
+ def __init__(self,
462
+ in_channels,
463
+ out_channels,
464
+ kernel_size,
465
+ num_style_feat,
466
+ demodulate=True,
467
+ sample_mode=None,
468
+ eps=1e-8):
469
+ super(ModulatedConv2d, self).__init__()
470
+ self.in_channels = in_channels
471
+ self.out_channels = out_channels
472
+ self.kernel_size = kernel_size
473
+ self.demodulate = demodulate
474
+ self.sample_mode = sample_mode
475
+ self.eps = eps
476
+
477
+ # modulation inside each modulated conv
478
+ self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
479
+ # initialization
480
+ default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
481
+
482
+ self.weight = nn.Parameter(
483
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
484
+ math.sqrt(in_channels * kernel_size**2))
485
+ self.padding = kernel_size // 2
486
+
487
+ def forward(self, x, style):
488
+ b, c, h, w = x.shape
489
+ style = self.modulation(style).view(b, 1, c, 1, 1)
490
+ weight = self.weight * style
491
+
492
+ if self.demodulate:
493
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
494
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
495
+
496
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
497
+
498
+ # upsample or downsample if necessary
499
+ if self.sample_mode == 'upsample':
500
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
501
+ elif self.sample_mode == 'downsample':
502
+ x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
503
+
504
+ b, c, h, w = x.shape
505
+ x = x.view(1, b * c, h, w)
506
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
507
+ out = out.view(b, self.out_channels, *out.shape[2:4])
508
+ return out
509
+
510
+ def __repr__(self):
511
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
512
+ f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
513
+
514
+
515
+ class StyleConv(nn.Module):
516
+ def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
517
+ super(StyleConv, self).__init__()
518
+ self.modulated_conv = ModulatedConv2d(
519
+ in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
520
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
521
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
522
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
523
+
524
+ def forward(self, x, style, noise=None):
525
+ # modulate
526
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
527
+ # noise injection
528
+ if noise is None:
529
+ b, _, h, w = out.shape
530
+ noise = out.new_empty(b, 1, h, w).normal_()
531
+ out = out + self.weight * noise
532
+ # add bias
533
+ out = out + self.bias
534
+ # activation
535
+ out = self.activate(out)
536
+ return out
537
+
538
+
539
+ class ToRGB(nn.Module):
540
+ def __init__(self, in_channels, num_style_feat, upsample=True):
541
+ super(ToRGB, self).__init__()
542
+ self.upsample = upsample
543
+ self.modulated_conv = ModulatedConv2d(
544
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
545
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
546
+
547
+ def forward(self, x, style, skip=None):
548
+ out = self.modulated_conv(x, style)
549
+ out = out + self.bias
550
+ if skip is not None:
551
+ if self.upsample:
552
+ skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
553
+ out = out + skip
554
+ return out
models/ffc.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fast Fourier Convolution NeurIPS 2020
2
+ # original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
3
+ # paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ # from models.modules.squeeze_excitation import SELayer
9
+ import torch.fft
10
+
11
+ class SELayer(nn.Module):
12
+ def __init__(self, channel, reduction=16):
13
+ super(SELayer, self).__init__()
14
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
15
+ self.fc = nn.Sequential(
16
+ nn.Linear(channel, channel // reduction, bias=False),
17
+ nn.ReLU(inplace=True),
18
+ nn.Linear(channel // reduction, channel, bias=False),
19
+ nn.Sigmoid()
20
+ )
21
+
22
+ def forward(self, x):
23
+ b, c, _, _ = x.size()
24
+ y = self.avg_pool(x).view(b, c)
25
+ y = self.fc(y).view(b, c, 1, 1)
26
+ res = x * y.expand_as(x)
27
+ return res
28
+
29
+
30
+ class FFCSE_block(nn.Module):
31
+ def __init__(self, channels, ratio_g):
32
+ super(FFCSE_block, self).__init__()
33
+ in_cg = int(channels * ratio_g)
34
+ in_cl = channels - in_cg
35
+ r = 16
36
+
37
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
38
+ self.conv1 = nn.Conv2d(channels, channels // r,
39
+ kernel_size=1, bias=True)
40
+ self.relu1 = nn.ReLU(inplace=True)
41
+ self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
42
+ channels // r, in_cl, kernel_size=1, bias=True)
43
+ self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
44
+ channels // r, in_cg, kernel_size=1, bias=True)
45
+ self.sigmoid = nn.Sigmoid()
46
+
47
+ def forward(self, x):
48
+ x = x if type(x) is tuple else (x, 0)
49
+ id_l, id_g = x
50
+
51
+ x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
52
+ x = self.avgpool(x)
53
+ x = self.relu1(self.conv1(x))
54
+
55
+ x_l = 0 if self.conv_a2l is None else id_l * \
56
+ self.sigmoid(self.conv_a2l(x))
57
+ x_g = 0 if self.conv_a2g is None else id_g * \
58
+ self.sigmoid(self.conv_a2g(x))
59
+ return x_l, x_g
60
+
61
+
62
+ class FourierUnit(nn.Module):
63
+
64
+ def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
65
+ spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
66
+ # bn_layer not used
67
+ super(FourierUnit, self).__init__()
68
+ self.groups = groups
69
+
70
+ self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
71
+ out_channels=out_channels * 2,
72
+ kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
73
+ self.bn = torch.nn.BatchNorm2d(out_channels * 2)
74
+ self.relu = torch.nn.ReLU(inplace=True)
75
+
76
+ # squeeze and excitation block
77
+ self.use_se = use_se
78
+ if use_se:
79
+ if se_kwargs is None:
80
+ se_kwargs = {}
81
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
82
+
83
+ self.spatial_scale_factor = spatial_scale_factor
84
+ self.spatial_scale_mode = spatial_scale_mode
85
+ self.spectral_pos_encoding = spectral_pos_encoding
86
+ self.ffc3d = ffc3d
87
+ self.fft_norm = fft_norm
88
+
89
+ def forward(self, x):
90
+ batch = x.shape[0]
91
+
92
+ if self.spatial_scale_factor is not None:
93
+ orig_size = x.shape[-2:]
94
+ x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
95
+
96
+ r_size = x.size()
97
+ # (batch, c, h, w/2+1, 2)
98
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
99
+ ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
100
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
101
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
102
+ ffted = ffted.view((batch, -1,) + ffted.size()[3:])
103
+
104
+ if self.spectral_pos_encoding:
105
+ height, width = ffted.shape[-2:]
106
+ coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
107
+ coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
108
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
109
+
110
+ if self.use_se:
111
+ ffted = self.se(ffted)
112
+
113
+ ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
114
+ ffted = self.relu(self.bn(ffted))
115
+
116
+ ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
117
+ 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
118
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
119
+
120
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
121
+ output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
122
+
123
+ if self.spatial_scale_factor is not None:
124
+ output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
125
+
126
+ return output
127
+
128
+
129
+ class SpectralTransform(nn.Module):
130
+ def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
131
+ # bn_layer not used
132
+ super(SpectralTransform, self).__init__()
133
+ self.enable_lfu = enable_lfu
134
+ if stride == 2:
135
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
136
+ else:
137
+ self.downsample = nn.Identity()
138
+
139
+ self.stride = stride
140
+ self.conv1 = nn.Sequential(
141
+ nn.Conv2d(in_channels, out_channels //
142
+ 2, kernel_size=1, groups=groups, bias=False),
143
+ nn.BatchNorm2d(out_channels // 2),
144
+ nn.ReLU(inplace=True)
145
+ )
146
+ self.fu = FourierUnit(
147
+ out_channels // 2, out_channels // 2, groups, **fu_kwargs)
148
+ if self.enable_lfu:
149
+ self.lfu = FourierUnit(
150
+ out_channels // 2, out_channels // 2, groups)
151
+ self.conv2 = torch.nn.Conv2d(
152
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
153
+
154
+ def forward(self, x):
155
+ x = self.downsample(x)
156
+ x = self.conv1(x)
157
+ output = self.fu(x)
158
+
159
+ if self.enable_lfu:
160
+ n, c, h, w = x.shape
161
+ split_no = 2
162
+ split_s = h // split_no
163
+ xs = torch.cat(torch.split(
164
+ x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
165
+ xs = torch.cat(torch.split(xs, split_s, dim=-1),
166
+ dim=1).contiguous()
167
+ xs = self.lfu(xs)
168
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
169
+ else:
170
+ xs = 0
171
+
172
+ output = self.conv2(x + output + xs)
173
+ return output
174
+
175
+
176
+ class FFC(nn.Module):
177
+
178
+ def __init__(self, in_channels, out_channels, kernel_size,
179
+ ratio_gin, ratio_gout, stride=1, padding=0,
180
+ dilation=1, groups=1, bias=False, enable_lfu=True,
181
+ padding_type='reflect', gated=False, **spectral_kwargs):
182
+ super(FFC, self).__init__()
183
+
184
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
185
+ self.stride = stride
186
+
187
+ in_cg = int(in_channels * ratio_gin)
188
+ in_cl = in_channels - in_cg
189
+ out_cg = int(out_channels * ratio_gout)
190
+ out_cl = out_channels - out_cg
191
+
192
+ self.ratio_gin = ratio_gin
193
+ self.ratio_gout = ratio_gout
194
+ self.global_in_num = in_cg
195
+
196
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
197
+ self.convl2l = module(in_cl, out_cl, kernel_size,
198
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
199
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
200
+ self.convl2g = module(in_cl, out_cg, kernel_size,
201
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
202
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
203
+ self.convg2l = module(in_cg, out_cl, kernel_size,
204
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
205
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
206
+ self.convg2g = module(
207
+ in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
208
+
209
+ self.gated = gated
210
+ module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
211
+ self.gate = module(in_channels, 2, 1)
212
+
213
+ def forward(self, x):
214
+ x_l, x_g = x if type(x) is tuple else (x, 0)
215
+ out_xl, out_xg = 0, 0
216
+
217
+ if self.gated:
218
+ total_input_parts = [x_l]
219
+ if torch.is_tensor(x_g):
220
+ total_input_parts.append(x_g)
221
+ total_input = torch.cat(total_input_parts, dim=1)
222
+
223
+ gates = torch.sigmoid(self.gate(total_input))
224
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
225
+ else:
226
+ g2l_gate, l2g_gate = 1, 1
227
+
228
+ if self.ratio_gout != 1:
229
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
230
+ if self.ratio_gout != 0:
231
+ out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
232
+
233
+ return out_xl, out_xg
models/transformer.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from einops import rearrange
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+
10
+
11
+ class GELU(nn.Module):
12
+ def __init__(self):
13
+ super(GELU, self).__init__()
14
+ def forward(self, x):
15
+ return 0.5*x*(1+F.tanh(np.sqrt(2/np.pi)*(x+0.044715*torch.pow(x,3))))
16
+
17
+ # helpers
18
+
19
+ def pair(t):
20
+ return t if isinstance(t, tuple) else (t, t)
21
+
22
+ # classes
23
+
24
+ class PreNorm(nn.Module):
25
+ def __init__(self, dim, fn):
26
+ super().__init__()
27
+ self.norm = nn.LayerNorm(dim)
28
+ self.fn = fn
29
+ def forward(self, x, **kwargs):
30
+ return self.fn(self.norm(x), **kwargs)
31
+
32
+ class DualPreNorm(nn.Module):
33
+ def __init__(self, dim, fn):
34
+ super().__init__()
35
+ self.normx = nn.LayerNorm(dim)
36
+ self.normy = nn.LayerNorm(dim)
37
+ self.fn = fn
38
+ def forward(self, x, y, **kwargs):
39
+ return self.fn(self.normx(x), self.normy(y), **kwargs)
40
+
41
+ class FeedForward(nn.Module):
42
+ def __init__(self, dim, hidden_dim, dropout = 0.):
43
+ super().__init__()
44
+ self.net = nn.Sequential(
45
+ nn.Linear(dim, hidden_dim),
46
+ GELU(),
47
+ nn.Dropout(dropout),
48
+ nn.Linear(hidden_dim, dim),
49
+ nn.Dropout(dropout)
50
+ )
51
+ def forward(self, x):
52
+ return self.net(x)
53
+
54
+ class Attention(nn.Module):
55
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
56
+ super().__init__()
57
+ inner_dim = dim_head * heads
58
+ project_out = not (heads == 1 and dim_head == dim)
59
+
60
+ self.heads = heads
61
+ self.scale = dim_head ** -0.5
62
+
63
+ self.attend = nn.Softmax(dim = -1)
64
+
65
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
66
+ self.to_k = nn.Linear(dim, inner_dim, bias = False)
67
+ self.to_v = nn.Linear(dim, inner_dim, bias = False)
68
+
69
+
70
+ self.to_out = nn.Sequential(
71
+ nn.Linear(inner_dim, dim),
72
+ nn.Dropout(dropout)
73
+ ) if project_out else nn.Identity()
74
+
75
+ def forward(self, x, y):
76
+ # qk = self.to_qk(x).chunk(2, dim = -1) #
77
+ q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads) # q,k from the zero feature
78
+ k = rearrange(self.to_k(x), 'b n (h d) -> b h n d', h = self.heads) # v from the reference features
79
+ v = rearrange(self.to_v(y), 'b n (h d) -> b h n d', h = self.heads)
80
+
81
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
82
+
83
+ attn = self.attend(dots)
84
+
85
+ out = torch.matmul(attn, v)
86
+ out = rearrange(out, 'b h n d -> b n (h d)')
87
+ return self.to_out(out)
88
+
89
+ class Transformer(nn.Module):
90
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
91
+ super().__init__()
92
+ self.layers = nn.ModuleList([])
93
+ for _ in range(depth):
94
+ self.layers.append(nn.ModuleList([
95
+ DualPreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
96
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
97
+ ]))
98
+
99
+
100
+ def forward(self, x, y): # x is the cropped, y is the foreign reference
101
+ bs,c,h,w = x.size()
102
+
103
+ # img to embedding
104
+ x = x.view(bs,c,-1).permute(0,2,1)
105
+ y = y.view(bs,c,-1).permute(0,2,1)
106
+
107
+ for attn, ff in self.layers:
108
+ x = attn(x, y) + x
109
+ x = ff(x) + x
110
+
111
+ x = x.view(bs,h,w,c).permute(0,3,1,2)
112
+ return x
113
+
114
+ class RETURNX(nn.Module):
115
+ def __init__(self,):
116
+ super().__init__()
117
+
118
+ def forward(self, x, y): # x is the cropped, y is the foreign reference
119
+ return x