# This script is modified from https://github.com/EricGuo5513/TM2T
# Licensed under: https://github.com/EricGuo5513/TM2T/blob/main/LICENSE
import torch
import torch.nn as nn


class VQDecoderV3(nn.Module):
    """Upsampling 1-D convolutional decoder.

    The network is a flat ``nn.Sequential`` stored in ``self.main``:

    1. two ``ResBlock`` modules operating on ``args.vae_length`` channels;
    2. ``args.vae_layer`` stages of
       ``Upsample(x2, nearest) -> Conv1d(k=3, p=1) -> LeakyReLU(0.2)``,
       the last of which maps ``args.vae_length`` channels to
       ``args.vae_test_dim`` channels;
    3. a final ``Conv1d(k=3, p=1)`` on ``args.vae_test_dim`` channels.

    Args:
        args: Configuration object exposing the attributes ``vae_layer``
            (number of upsampling stages, must be >= 1), ``vae_length``
            (input/hidden channel count) and ``vae_test_dim`` (output
            channel count).

    Raises:
        ValueError: If ``args.vae_layer`` is smaller than 1.
    """

    def __init__(self, args):
        super().__init__()
        n_up = args.vae_layer
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let a 0-stage decoder be built with a
        # final conv whose channel count does not match the ResBlock output.
        if n_up < 1:
            raise ValueError(
                f"args.vae_layer must be >= 1, got {n_up!r}"
            )

        hidden_dim = args.vae_length
        n_resblk = 2
        # Channel count at each stage boundary: hidden width everywhere
        # except after the last upsampling stage (length n_up + 1).
        channels = [hidden_dim] * n_up + [args.vae_test_dim]

        # NOTE: modules are created in the same order and at the same
        # Sequential indices as before, so state_dict keys stay compatible
        # with existing checkpoints.
        layers = [ResBlock(hidden_dim) for _ in range(n_resblk)]
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(
            nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)
        )
        # A custom ``init_weight`` pass was left commented out in the
        # original code, so all layers keep PyTorch's default initialisation.
        self.main = nn.Sequential(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Decode a channels-last sequence.

        Args:
            inputs: 3-D tensor laid out as ``(batch, length, vae_length)``.

        Returns:
            Tensor laid out as
            ``(batch, length * 2 ** vae_layer, vae_test_dim)``.
        """
        # Conv1d works on (batch, channels, length); swap in and back out.
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)


class ResBlock(nn.Module):
    """Residual block: ``x + Conv1d(LeakyReLU(Conv1d(x)))``.

    Both convolutions use ``kernel_size=3, stride=1, padding=1`` and keep the
    channel count, so the output has the same shape as the input.

    Args:
        channel: Number of input (and output) channels.
    """

    def __init__(self, channel: int):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.model(x) + x`` for a ``(batch, channel, length)`` tensor."""
        out = self.model(x)
        # In-place add on the freshly produced conv output; ``x`` itself is
        # not modified.
        out += x
        return out