File size: 3,219 Bytes
9f4b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from torch import nn

from nota_wav2lip.models.base import Wav2LipBase
from nota_wav2lip.models.conv import Conv2d, Conv2dTranspose


class NotaWav2Lip(Wav2LipBase):
    def __init__(self, nef=4, naf=8, ndf=8, x_size=96):
        super().__init__()

        assert x_size in [96, 128]
        self.ker_sz_last = x_size // 32

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, nef, kernel_size=7, stride=1, padding=3)),  # 96,96

            nn.Sequential(Conv2d(nef, nef * 2, kernel_size=3, stride=2, padding=1),),  # 48,48

            nn.Sequential(Conv2d(nef * 2, nef * 4, kernel_size=3, stride=2, padding=1),),  # 24,24

            nn.Sequential(Conv2d(nef * 4, nef * 8, kernel_size=3, stride=2, padding=1),),  # 12,12

            nn.Sequential(Conv2d(nef * 8, nef * 16, kernel_size=3, stride=2, padding=1),),  # 6,6

            nn.Sequential(Conv2d(nef * 16, nef * 32, kernel_size=3, stride=2, padding=1),),  # 3,3

            nn.Sequential(Conv2d(nef * 32, nef * 32, kernel_size=self.ker_sz_last, stride=1, padding=0),  # 1, 1
                          Conv2d(nef * 32, nef * 32, kernel_size=1, stride=1, padding=0)), ])

        self.audio_encoder = nn.Sequential(
            Conv2d(1, naf, kernel_size=3, stride=1, padding=1),

            Conv2d(naf, naf * 2, kernel_size=3, stride=(3, 1), padding=1),

            Conv2d(naf * 2, naf * 4, kernel_size=3, stride=3, padding=1),

            Conv2d(naf * 4, naf * 8, kernel_size=3, stride=(3, 2), padding=1),

            Conv2d(naf * 8, naf * 16, kernel_size=3, stride=1, padding=0),
            Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), )

        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), ),

            nn.Sequential(Conv2dTranspose(nef * 32 + naf * 16, ndf * 16, kernel_size=self.ker_sz_last, stride=1, padding=0),),
                          # 3,3 # 512+512 = 1024

            nn.Sequential(
                Conv2dTranspose(nef * 32 + ndf * 16, ndf * 16, kernel_size=3, stride=2, padding=1, output_padding=1),),  # 6, 6
                # 512+512 = 1024

            nn.Sequential(
                Conv2dTranspose(nef * 16 + ndf * 16, ndf * 12, kernel_size=3, stride=2, padding=1, output_padding=1),),  # 12, 12
                # 256+512 = 768

            nn.Sequential(
                Conv2dTranspose(nef * 8 + ndf * 12, ndf * 8, kernel_size=3, stride=2, padding=1, output_padding=1),),  # 24, 24
                # 128+384 = 512

            nn.Sequential(
                Conv2dTranspose(nef * 4 + ndf * 8, ndf * 4, kernel_size=3, stride=2, padding=1, output_padding=1),),  # 48, 48
                # 64+256 = 320

            nn.Sequential(
                Conv2dTranspose(nef * 2 + ndf * 4, ndf * 2, kernel_size=3, stride=2, padding=1, output_padding=1),), # 96,96
                # 32+128 = 160
        ])

        self.output_block = nn.Sequential(Conv2d(nef + ndf * 2, ndf, kernel_size=3, stride=1, padding=1),  # 16+64 = 80
                                          nn.Conv2d(ndf, 3, kernel_size=1, stride=1, padding=0),
                                          nn.Sigmoid())