Spaces:
Running
on
T4
Running
on
T4
File size: 6,381 Bytes
9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da 9e275b8 70399da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
"""
taken and adapted from https://github.com/as-ideas/DeepForcedAligner
refined with insights from https://www.audiolabs-erlangen.de/resources/NLUI/2023-ICASSP-eval-alignment-tts
EVALUATING SPEECH–PHONEME ALIGNMENT AND ITS IMPACT ON NEURAL TEXT-TO-SPEECH SYNTHESIS
by Frank Zalkow, Prachi Govalkar, Meinard Müller, Emanuel A. P. Habets, Christian Dittmar
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing
from torch.nn import CTCLoss
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Utility.utils import make_non_pad_mask
class BatchNormConv(torch.nn.Module):
    """
    Channels-last 1-D convolution block: Conv1d -> ReLU -> (sync) batch norm.

    Input and output are laid out as (batch, time, channels); the block transposes
    to channels-first for Conv1d / BatchNorm1d and transposes back afterwards.
    Note: ReLU is applied *before* the normalization. Keep this order —
    trained checkpoints presumably depend on it.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        # stride 1 with padding kernel_size // 2 keeps the time length for odd kernel sizes
        self.conv = torch.nn.Conv1d(in_channels,
                                    out_channels,
                                    kernel_size,
                                    stride=1,
                                    padding=kernel_size // 2,
                                    bias=False)
        # a plain BatchNorm1d, converted so it can synchronize statistics across processes
        plain_norm = torch.nn.BatchNorm1d(out_channels)
        self.bnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(plain_norm)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        channels_first = x.transpose(1, 2)
        activated = self.relu(self.conv(channels_first))
        return self.bnorm(activated).transpose(1, 2)
class Aligner(torch.nn.Module):
    """
    CTC-trained acoustic model that scores every feature frame against every symbol.

    It is used to align a feature sequence to a token sequence. The architecture is
    five ``BatchNormConv`` layers (each followed by dropout), two bidirectional LSTMs,
    and a linear projection to ``num_symbols`` outputs.
    """

    def __init__(self,
                 n_features=128,
                 num_symbols=145,
                 conv_dim=512,
                 lstm_dim=512):
        super().__init__()
        self.convs = torch.nn.ModuleList([
            BatchNormConv(n_features, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
        ])
        self.rnn1 = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
        self.tf = ArticulatoryCombinedTextFrontend(language="eng")
        # blank=144 is fixed; it is the last output class only for the default num_symbols=145
        self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
        self.vector_to_id = dict()

    def forward(self, x, lens=None):
        """
        Compute un-normalized per-frame symbol scores (no softmax is applied here).

        Args:
            x: padded features of shape (batch, frames, n_features).
            lens: optional tensor with the valid number of frames per batch item. When given,
                the LSTMs run on packed sequences, the result is padded back to ``max(lens)``
                frames, and it is multiplied by ``make_non_pad_mask(lens)`` (presumably zeroing
                the padded positions).

        Returns:
            Tensor of shape (batch, frames, num_symbols); ``frames`` is ``max(lens)`` when
            ``lens`` is given.
        """
        for conv in self.convs:
            x = conv(x)
        if lens is not None:
            x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        if lens is not None:
            x, _ = pad_packed_sequence(x, batch_first=True)
        x = self.proj(x)
        if lens is not None:
            out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(x.device)
            x = x * out_masks.float()
        return x

    @torch.inference_mode()
    def inference(self, features, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
        """
        Align the frames of a single utterance to its tokens with monotonic alignment search.

        Args:
            features: features of one utterance, shape (frames, n_features); a batch axis is added here.
            tokens: if ``train`` is False, articulatory text vectors that are converted to IDs with
                the text frontend; if ``train`` is True, a tensor of token IDs.
            save_img_for_debug: optional file path; if given, a plot of the alignment path is saved there.
            train: selects how ``tokens`` is interpreted (see above).
            pathfinding: currently unused; MAS is always applied.
            return_ctc: if True, additionally return the CTC loss of the prediction against ``tokens``.

        Returns:
            Binary alignment matrix of shape (frames, number of tokens), plus the CTC loss as a
            Python float when ``return_ctc`` is True.
        """
        if train:
            tokens = tokens.cpu().detach().numpy()
        else:
            # convert the articulatory vectors to IDs first, so the matching output columns can be selected
            tokens = np.asarray(self.tf.text_vectors_to_id_sequence(text_vector=tokens))

        pred = self(features.unsqueeze(0))

        if return_ctc:
            ctc = self.ctc_loss(pred.transpose(0, 1).log_softmax(2),
                                torch.LongTensor(tokens),
                                torch.LongTensor([len(pred[0])]),
                                torch.LongTensor([len(tokens)])).item()

        # remove only the batch axis; a bare squeeze() would also drop the frame axis
        # for a single-frame input and break the column selection below
        pred = pred.squeeze(0).cpu().detach().numpy()
        pred_max = pred[:, tokens]  # (frames, number of tokens)

        # run monotonic alignment search
        alignment_matrix = binarize_alignment(pred_max)

        if save_img_for_debug is not None:
            self._plot_alignment(alignment_matrix, tokens, save_img_for_debug)

        if return_ctc:
            return alignment_matrix, ctc
        return alignment_matrix

    def _plot_alignment(self, alignment_matrix, tokens, path):
        """Save a plot of the binary alignment path, labelled with the tokens' phones, to ``path``."""
        phones = [self.tf.id_to_phone[index] for index in tokens]
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
        ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
        ax.set_ylabel("Mel-Frames")
        ax.set_xticks(range(alignment_matrix.shape[1]))
        ax.set_xticklabels(labels=phones)
        ax.set_title("MAS Path")
        fig.tight_layout()
        fig.savefig(path)
        # close exactly the figure created above rather than whatever pyplot considers "current"
        plt.close(fig)
def binarize_alignment(alignment_prob):
    """
    Binarize a soft alignment matrix with monotonic alignment search (MAS).

    Implementation adapted from:
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py

    The selected path starts in cell (0, 0) and ends in the last cell. From one row to
    the next it either stays in the same column or moves exactly one column to the
    right. It maximizes the sum of the log-scores of the visited cells; on a tie,
    moving right is preferred.

    Args:
        alignment_prob: 2-D array-like of real-valued scores, rows = feature frames,
            columns = text tokens. Values may be negative (e.g. raw logits); they are
            shifted to be strictly positive before the log is taken.

    Returns:
        numpy array of the same shape as ``alignment_prob`` with 1 on every cell of the
        selected path and 0 elsewhere. If there are more columns than rows, no valid
        path exists and the result is not a well-formed alignment.
    """
    alignment_prob = np.asarray(alignment_prob)
    # assumes features x text
    opt = np.zeros_like(alignment_prob)
    num_frames, num_tokens = alignment_prob.shape

    # make all numbers >= 1 so the log below is finite ...
    alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0)
    # ... and normalize to (0, 1]. Scaling every cell by the same positive factor adds the
    # same constant to every log-score, and every path visits exactly one cell per row,
    # so this does not change which path wins.
    alignment_prob = alignment_prob * (1.0 / alignment_prob.max())
    attn_map = np.log(alignment_prob)
    # the path has to start in column 0
    attn_map[0, 1:] = -np.inf

    log_p = np.zeros_like(attn_map)
    log_p[0, :] = attn_map[0, :]
    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
    token_ids = np.arange(num_tokens)

    # Forward pass: cell (i, j) is reached from (i-1, j) ("stay") or (i-1, j-1) ("advance").
    # Row i only depends on row i-1, so each row is updated with vector operations.
    for i in range(1, num_frames):
        stay = log_p[i - 1]
        advance = np.empty_like(stay)
        advance[0] = -np.inf
        advance[1:] = stay[:-1]
        take_advance = advance >= stay
        take_advance[0] = False  # column 0 has no left predecessor
        log_p[i] = attn_map[i] + np.where(take_advance, advance, stay)
        prev_ind[i] = np.where(take_advance, token_ids - 1, token_ids)

    # now backtrack from the last cell
    curr_text_idx = num_tokens - 1
    for i in range(num_frames - 1, -1, -1):
        opt[i, curr_text_idx] = 1
        curr_text_idx = prev_ind[i, curr_text_idx]
    opt[0, curr_text_idx] = 1
    return opt
if __name__ == '__main__':
    # Smoke test: print the trainable parameter count and the output shape for a padded batch.
    # A single model instance is reused so the network and its text frontend are built only once.
    model = Aligner()
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
    print(model(x=torch.randn(size=[3, 30, 128]), lens=torch.LongTensor([20, 30, 10])).shape)
|