File size: 6,381 Bytes
9e275b8
 
70399da
 
 
 
9e275b8
 
 
 
 
 
 
 
 
 
70399da
9e275b8
 
70399da
9e275b8
 
 
70399da
9e275b8
 
70399da
 
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70399da
 
9e275b8
70399da
9e275b8
70399da
9e275b8
70399da
9e275b8
70399da
9e275b8
70399da
9e275b8
70399da
9e275b8
70399da
 
9e275b8
 
 
 
 
 
 
 
 
 
70399da
 
9e275b8
 
 
 
70399da
 
 
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70399da
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70399da
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
Taken and adapted from https://github.com/as-ideas/DeepForcedAligner

Refined with insights from https://www.audiolabs-erlangen.de/resources/NLUI/2023-ICASSP-eval-alignment-tts
"Evaluating Speech–Phoneme Alignment and Its Impact on Neural Text-to-Speech Synthesis"
by Frank Zalkow, Prachi Govalkar, Meinard Müller, Emanuel A. P. Habets, and Christian Dittmar
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing
from torch.nn import CTCLoss
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Utility.utils import make_non_pad_mask


class BatchNormConv(torch.nn.Module):
    """
    Bias-free 1D convolution followed by ReLU and then batch normalisation.

    The module consumes and produces tensors whose axis 1 and axis 2 are
    swapped relative to what Conv1d wants, so it transposes on the way in
    and on the way out.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        # padding of kernel_size // 2 keeps the sequence length unchanged for odd kernel sizes
        half_kernel = kernel_size // 2
        self.conv = torch.nn.Conv1d(in_channels,
                                    out_channels,
                                    kernel_size,
                                    stride=1,
                                    padding=half_kernel,
                                    bias=False)
        # build a regular BatchNorm1d and let torch turn it into its synchronised counterpart
        plain_norm = torch.nn.BatchNorm1d(out_channels)
        self.bnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(plain_norm)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """
        Apply conv -> ReLU -> batch norm along axis 1 of x.

        x has its channels on the last axis (the conv is built with
        in_channels input channels and is applied after swapping axes 1 and 2).
        The result has the same axis layout, with out_channels on the last axis.
        """
        channels_first = x.transpose(1, 2)
        activated = self.relu(self.conv(channels_first))
        normalised = self.bnorm(activated)
        return normalised.transpose(1, 2)


class Aligner(torch.nn.Module):
    """
    Frame-wise symbol classifier used for forced alignment.

    Five BatchNormConv blocks (each followed by dropout) feed two
    bidirectional LSTMs and a linear projection onto num_symbols classes.
    The last class index (num_symbols - 1) is used as the CTC blank.
    """

    def __init__(self,
                 n_features=128,
                 num_symbols=145,
                 conv_dim=512,
                 lstm_dim=512):
        """
        Args:
            n_features: size of the feature axis of the input frames.
            num_symbols: number of output classes, including the CTC blank.
            conv_dim: channel count of every convolutional block.
            lstm_dim: hidden size per direction of both LSTMs.
        """
        super().__init__()
        self.convs = torch.nn.ModuleList([
            BatchNormConv(n_features, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
        ])
        self.rnn1 = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
        self.tf = ArticulatoryCombinedTextFrontend(language="eng")
        # The blank is the last class. Deriving it from num_symbols keeps it inside the
        # projection's output range for any vocabulary size (144 for the default of 145).
        self.ctc_loss = CTCLoss(blank=num_symbols - 1, zero_infinity=True)
        self.vector_to_id = dict()

    def forward(self, x, lens=None):
        """
        Compute per-frame class scores.

        Args:
            x: batch of feature sequences laid out as (batch, frames, n_features).
            lens: optional tensor with the true number of frames of each batch
                entry. When given, the LSTMs run on packed sequences and the
                scores at padded positions are multiplied by zero.

        Returns:
            Tensor of unnormalised scores with num_symbols entries on the last axis.
        """
        for conv in self.convs:
            x = conv(x)
        if lens is not None:
            # pack_padded_sequence wants the lengths on the CPU
            x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        if lens is not None:
            x, _ = pad_packed_sequence(x, batch_first=True)

        x = self.proj(x)
        if lens is not None:
            # presumably the mask marks the non-padded frames of every entry - confirm in Utility.utils
            out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(x.device)
            x = x * out_masks.float()

        return x

    @torch.inference_mode()
    def inference(self, features, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
        """
        Align a single utterance and return the binary frame-to-token path.

        This method does not switch the module to eval mode; the caller is
        responsible for that.

        Args:
            features: feature frames of one utterance, without a batch axis.
            tokens: if train is False, the articulatory vectors of the text,
                which are converted to symbol IDs through the text frontend;
                if train is True, a tensor that already holds symbol IDs.
            save_img_for_debug: optional path; when given, a plot of the
                alignment path is written there.
            train: selects how tokens is interpreted (see above).
            pathfinding: accepted for interface compatibility; not read by this method.
            return_ctc: if True, additionally return the CTC loss of the utterance.

        Returns:
            The alignment matrix produced by binarize_alignment (frames x tokens),
            or a tuple (alignment matrix, CTC loss as float) if return_ctc is True.
        """
        if not train:
            # first we need to convert the articulatory vectors to IDs, so we can select the score column of each token
            tokens_indexed = self.tf.text_vectors_to_id_sequence(text_vector=tokens)
            tokens = np.asarray(tokens_indexed)
        else:
            tokens = tokens.cpu().detach().numpy()

        pred = self(features.unsqueeze(0))
        if return_ctc:
            # keep the targets on the same device as the scores (a no-op on CPU)
            ctc_targets = torch.LongTensor(tokens).to(pred.device)
            ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2),
                                     ctc_targets,
                                     torch.LongTensor([pred.size(1)]),
                                     torch.LongTensor([len(tokens)])).item()
        # Drop only the batch axis. A bare squeeze() would also remove the frame axis of a
        # single-frame utterance and break the column selection below.
        pred = pred.squeeze(0).cpu().detach().numpy()
        pred_max = pred[:, tokens]

        # run monotonic alignment search
        alignment_matrix = binarize_alignment(pred_max)

        if save_img_for_debug is not None:
            phones = [self.tf.id_to_phone[index] for index in tokens]
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

            ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
            ax.set_ylabel("Mel-Frames")
            ax.set_xticks(range(len(pred_max[0])))
            ax.set_xticklabels(labels=phones)
            ax.set_title("MAS Path")

            plt.tight_layout()
            fig.savefig(save_img_for_debug)
            fig.clf()
            # close exactly this figure instead of whichever one happens to be current
            plt.close(fig)

        if return_ctc:
            return alignment_matrix, ctc_loss
        return alignment_matrix


def binarize_alignment(alignment_prob):
    """
    # Implementation by:
    # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
    # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py

    Binarizes alignment with MAS.

    A dynamic program finds the monotonic path through the score matrix that
    starts at the first text position, ends at the last one, visits exactly
    one text position per frame and moves forward by at most one text
    position per frame, maximising the summed log scores.

    Args:
        alignment_prob: 2D numpy array of scores laid out as frames x text.
            Values may be arbitrary real numbers; they are shifted and
            rescaled internally. The input array is not modified.

    Returns:
        Array of the same shape and dtype as the input with a 1 at every
        (frame, text position) cell that lies on the chosen path and 0 elsewhere.
    """
    # assumes features x text
    opt = np.zeros_like(alignment_prob)
    alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0)  # make all numbers positive and add an offset to avoid log of 0 later
    # Normalize to (0, 1]. Every path contains one cell per frame, so the constant scale
    # shifts all path scores equally and cannot change which path wins.
    alignment_prob = alignment_prob * (1.0 / alignment_prob.max())
    attn_map = np.log(alignment_prob)
    # the first frame may only be assigned to the first text position
    attn_map[0, 1:] = -np.inf
    log_p = np.zeros_like(attn_map)
    log_p[0, :] = attn_map[0, :]
    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
    for i in range(1, attn_map.shape[0]):
        for j in range(attn_map.shape[1]):  # for each text dim
            # either stay on the same text position or advance from the previous one;
            # ties are resolved in favour of advancing
            prev_log = log_p[i - 1, j]
            prev_j = j
            if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
                prev_log = log_p[i - 1, j - 1]
                prev_j = j - 1
            log_p[i, j] = attn_map[i, j] + prev_log
            prev_ind[i, j] = prev_j
    # now backtrack
    curr_text_idx = attn_map.shape[1] - 1
    for i in range(attn_map.shape[0] - 1, -1, -1):
        opt[i, curr_text_idx] = 1
        curr_text_idx = prev_ind[i, curr_text_idx]
    # prev_ind[0] is all zeros, so this always marks the first frame on the first text position
    opt[0, curr_text_idx] = 1
    return opt


if __name__ == '__main__':
    # smoke test, part one: report how many trainable parameters a default Aligner has
    trainable_parameter_count = 0
    for parameter in Aligner().parameters():
        if parameter.requires_grad:
            trainable_parameter_count += parameter.numel()
    print(trainable_parameter_count)

    # smoke test, part two: push a padded dummy batch through a fresh model and show the output shape
    fresh_model = Aligner()
    dummy_batch = torch.randn(size=[3, 30, 128])
    dummy_lengths = torch.LongTensor([20, 30, 10])
    print(fresh_model(x=dummy_batch, lens=dummy_lengths).shape)