# Hugging Face Spaces: Running on Zero
"""
Taken and adapted from https://github.com/as-ideas/DeepForcedAligner
Refined with insights from https://www.audiolabs-erlangen.de/resources/NLUI/2023-ICASSP-eval-alignment-tts
EVALUATING SPEECH–PHONEME ALIGNMENT AND ITS IMPACT ON NEURAL TEXT-TO-SPEECH SYNTHESIS
by Frank Zalkow, Prachi Govalkar, Meinard Müller, Emanuel A. P. Habets, Christian Dittmar
"""
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.multiprocessing | |
from torch.nn import CTCLoss | |
from torch.nn.utils.rnn import pack_padded_sequence | |
from torch.nn.utils.rnn import pad_packed_sequence | |
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend | |
from Utility.utils import make_non_pad_mask | |
class BatchNormConv(torch.nn.Module):
    """
    1-D convolution followed by ReLU and batch normalization, applied to
    channels-last sequences.

    Tensors enter and leave as ``(batch, time, channels)``. They are transposed
    to ``(batch, channels, time)`` internally because that is the layout the
    ``Conv1d`` and batch-norm layers operate on.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        # With stride 1, padding of kernel_size // 2 keeps the time length
        # unchanged for odd kernel sizes.
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=False,
        )
        plain_norm = torch.nn.BatchNorm1d(out_channels)
        # Converted to SyncBatchNorm so statistics can be synchronized across
        # processes when training is distributed (per the PyTorch docs).
        self.bnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(plain_norm)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply conv -> ReLU -> batch norm to ``x`` of shape (batch, time, in_channels)."""
        channels_first = x.transpose(1, 2)
        activated = self.relu(self.conv(channels_first))
        return self.bnorm(activated).transpose(1, 2)
class Aligner(torch.nn.Module):
    """
    CTC-trained frame classifier used to align symbol sequences to feature frames.

    ``forward`` maps a batch of feature sequences to per-frame scores over
    ``num_symbols`` symbols. ``inference`` turns those scores into a hard,
    monotonic frame-to-token alignment via ``binarize_alignment``.
    """

    def __init__(self,
                 n_features=128,
                 num_symbols=145,
                 conv_dim=512,
                 lstm_dim=512):
        super().__init__()
        # Five conv blocks, each followed by dropout. This is kept as an
        # explicit list so the ModuleList indices (and therefore the
        # state_dict keys) stay stable.
        self.convs = torch.nn.ModuleList([
            BatchNormConv(n_features, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
            BatchNormConv(conv_dim, conv_dim, 3),
            torch.nn.Dropout(p=0.5),
        ])
        self.rnn1 = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
        self.tf = ArticulatoryCombinedTextFrontend(language="eng")
        # NOTE(review): the CTC blank index is fixed at 144, which is the last
        # output only for the default num_symbols=145. Confirm before changing
        # num_symbols.
        self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
        # Not used anywhere in this class; kept in case external code relies on
        # it (TODO confirm).
        self.vector_to_id = dict()

    def forward(self, x, lens=None):
        """
        Compute per-frame symbol scores.

        Args:
            x: feature batch, presumably (batch, time, n_features). The first
                conv expects n_features channels on the last axis.
            lens: optional 1-D tensor of valid lengths per sequence. When
                given, padded frames are skipped by the LSTMs (via packing) and
                zeroed in the output. The time axis is then padded back only to
                the longest entry in ``lens``.

        Returns:
            Unnormalized scores of shape (batch, time, num_symbols).
        """
        for conv in self.convs:
            x = conv(x)
        if lens is not None:
            x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        if lens is not None:
            x, _ = pad_packed_sequence(x, batch_first=True)

        x = self.proj(x)
        if lens is not None:
            out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(x.device)
            x = x * out_masks.float()

        return x

    def inference(self, features, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
        """
        Align a single utterance's frames to its token sequence.

        Args:
            features: features of one utterance, presumably (time, n_features).
                A batch axis is added before calling ``forward``.
            tokens: if ``train`` is False, articulatory text vectors that are
                converted to symbol IDs with the text frontend. If True, a
                tensor of symbol IDs.
            save_img_for_debug: optional path. When given, a plot of the
                alignment path is saved there.
            train: selects how ``tokens`` is interpreted (see above).
            pathfinding: unused. Monotonic alignment search is always applied.
            return_ctc: if True, also return the CTC loss of ``tokens`` given
                the predictions, as a Python float.

        Returns:
            A binary alignment matrix of shape (frames, len(tokens)), or the
            tuple ``(alignment_matrix, ctc_loss)`` when ``return_ctc`` is True.
        """
        if not train:
            # Convert the articulatory vectors to IDs so they can index the
            # per-frame predictions.
            tokens_indexed = self.tf.text_vectors_to_id_sequence(text_vector=tokens)
            tokens = np.asarray(tokens_indexed)
        else:
            tokens = tokens.cpu().detach().numpy()

        pred = self(features.unsqueeze(0))

        if return_ctc:
            ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2),
                                     torch.LongTensor(tokens),
                                     torch.LongTensor([len(pred[0])]),
                                     torch.LongTensor([len(tokens)])).item()

        # Drop only the batch axis. squeeze() would also collapse the time axis
        # of a single-frame input and break the column indexing below.
        pred = pred[0].cpu().detach().numpy()
        # (frames, len(tokens)): score of each target token at every frame
        pred_max = pred[:, tokens]

        # run monotonic alignment search
        alignment_matrix = binarize_alignment(pred_max)

        if save_img_for_debug is not None:
            phones = [self.tf.id_to_phone[index] for index in tokens]
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
            ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
            ax.set_ylabel("Mel-Frames")
            ax.set_xticks(range(len(tokens)))
            ax.set_xticklabels(labels=phones)
            ax.set_title("MAS Path")
            plt.tight_layout()
            fig.savefig(save_img_for_debug)
            plt.close(fig)

        if return_ctc:
            return alignment_matrix, ctc_loss
        return alignment_matrix
def binarize_alignment(alignment_prob):
    """
    Binarize a soft alignment matrix with monotonic alignment search (MAS).

    Implementation adapted from:
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py

    The path has these properties:
    - It starts on token 0 in the first frame and ends on the last token in the
      last frame.
    - From one frame to the next it either stays on the same token or advances
      by exactly one.
    - It maximizes the summed log-scores along the way.
    - When staying and advancing score equally, the predecessor in the previous
      column (j - 1) wins.

    Args:
        alignment_prob: 2-D array of shape (frames, tokens). Any real values are
            accepted, because they are shifted to be positive and rescaled
            before taking the log.

    Returns:
        Array of the same shape and dtype as ``alignment_prob``, with a 1 on the
        chosen token of every frame. If there are fewer frames than tokens, no
        valid path exists and the result is not a meaningful alignment.
    """
    # assumes features x text
    opt = np.zeros_like(alignment_prob)
    if alignment_prob.size == 0:
        # Nothing to align. Without this guard, max() of an empty array raises.
        return opt
    num_frames, num_tokens = alignment_prob.shape

    # Shift so every entry is >= 1, then normalize to (0, 1]. The log is
    # therefore finite and <= 0.
    alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0)
    alignment_prob = alignment_prob * (1.0 / alignment_prob.max())
    attn_map = np.log(alignment_prob)
    attn_map[0, 1:] = -np.inf  # the path has to start on the first token

    log_p = np.zeros_like(attn_map)
    log_p[0, :] = attn_map[0, :]
    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
    columns = np.arange(num_tokens)
    for i in range(1, num_frames):
        prev = log_p[i - 1]
        # Cell (i, j) can be reached from (i-1, j) or (i-1, j-1). Take the
        # diagonal step when it scores at least as high (ties go to j-1).
        # Column 0 has no left neighbour.
        from_left = np.zeros(num_tokens, dtype=bool)
        from_left[1:] = prev[:-1] >= prev[1:]
        # np.roll(prev, 1)[j] == prev[j - 1] for j >= 1; the wrapped value at
        # j == 0 is never selected because from_left[0] is False.
        best_prev = np.where(from_left, np.roll(prev, 1), prev)
        log_p[i] = attn_map[i] + best_prev
        prev_ind[i] = np.where(from_left, columns - 1, columns)

    # now backtrack from the last token in the last frame
    curr_text_idx = num_tokens - 1
    for i in range(num_frames - 1, -1, -1):
        opt[i, curr_text_idx] = 1
        curr_text_idx = prev_ind[i, curr_text_idx]
    opt[0, curr_text_idx] = 1
    return opt
if __name__ == '__main__':
    # Smoke test: build the model once, report its trainable parameter count,
    # then push a padded batch of three sequences (valid lengths 20, 30 and 10)
    # through it and print the output shape.
    aligner = Aligner()
    print(sum(p.numel() for p in aligner.parameters() if p.requires_grad))
    print(aligner(x=torch.randn(size=[3, 30, 128]), lens=torch.LongTensor([20, 30, 10])).shape)