"""
taken and adapted from https://github.com/as-ideas/DeepForcedAligner
refined with insights from https://www.audiolabs-erlangen.de/resources/NLUI/2023-ICASSP-eval-alignment-tts
EVALUATING SPEECH–PHONEME ALIGNMENT AND ITS IMPACT ON NEURAL TEXT-TO-SPEECH SYNTHESIS
by Frank Zalkow, Prachi Govalkar, Meinard Muller, Emanuel A. P. Habets, Christian Dittmar
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing
from torch.nn import CTCLoss
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Utility.utils import make_non_pad_mask


class BatchNormConv(torch.nn.Module):
def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
super().__init__()
self.conv = torch.nn.Conv1d(
in_channels, out_channels, kernel_size,
stride=1, padding=kernel_size // 2, bias=False)
self.bnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(torch.nn.BatchNorm1d(out_channels))
self.relu = torch.nn.ReLU()

    def forward(self, x):
        # keep (batch, time, channels) at the module boundary; Conv1d
        # expects (batch, channels, time), hence the two transposes
        x = x.transpose(1, 2)
x = self.conv(x)
x = self.relu(x)
x = self.bnorm(x)
x = x.transpose(1, 2)
return x
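
# Shape contract of the block above, as a quick sketch (the numbers are
# illustrative, not taken from this file): input and output both stay in
# (batch, time, channels) layout, and the time axis is preserved because
# the convolution uses stride 1 with "same"-style padding.
#
#   layer = BatchNormConv(in_channels=128, out_channels=512, kernel_size=3)
#   layer(torch.randn(4, 100, 128)).shape  # -> torch.Size([4, 100, 512])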


class Aligner(torch.nn.Module):

    def __init__(self,
n_features=128,
num_symbols=145,
conv_dim=512,
lstm_dim=512):
super().__init__()
self.convs = torch.nn.ModuleList([
BatchNormConv(n_features, conv_dim, 3),
torch.nn.Dropout(p=0.5),
BatchNormConv(conv_dim, conv_dim, 3),
torch.nn.Dropout(p=0.5),
BatchNormConv(conv_dim, conv_dim, 3),
torch.nn.Dropout(p=0.5),
BatchNormConv(conv_dim, conv_dim, 3),
torch.nn.Dropout(p=0.5),
BatchNormConv(conv_dim, conv_dim, 3),
torch.nn.Dropout(p=0.5),
])
self.rnn1 = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
self.tf = ArticulatoryCombinedTextFrontend(language="eng")
        # the CTC blank is the final index of the 145-symbol inventory
        self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
self.vector_to_id = dict()

    def forward(self, x, lens=None):
        # x: (batch, time, n_features) feature frames;
        # lens: true sequence lengths, used for packing and masking
for conv in self.convs:
x = conv(x)
if lens is not None:
x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
x, _ = self.rnn1(x)
x, _ = self.rnn2(x)
if lens is not None:
x, _ = pad_packed_sequence(x, batch_first=True)
x = self.proj(x)
if lens is not None:
out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(x.device)
x = x * out_masks.float()
return x
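
    # A hedged sketch of how these per-frame posteriors pair with the CTC
    # objective defined above (tensor names are illustrative, not from this
    # file; "targets", "input_lens" and "target_lens" are assumed LongTensors):
    #
    #   aligner = Aligner()
    #   feats = torch.randn(3, 30, 128)                        # (batch, frames, feature bins)
    #   posteriors = aligner(feats, lens=torch.LongTensor([30, 25, 20]))
    #   log_probs = posteriors.transpose(0, 1).log_softmax(2)  # (frames, batch, symbols)
    #   loss = aligner.ctc_loss(log_probs, targets, input_lens, target_lens)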

    @torch.inference_mode()
def inference(self, features, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
if not train:
            # first we need to convert the articulatory vectors to IDs,
            # so we can apply Dijkstra or Viterbi
            tokens_indexed = self.tf.text_vectors_to_id_sequence(text_vector=tokens)
            tokens = np.asarray(tokens_indexed)
else:
tokens = tokens.cpu().detach().numpy()
pred = self(features.unsqueeze(0))
        if return_ctc:
            ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2),
                                     torch.LongTensor(tokens),
                                     torch.LongTensor([len(pred[0])]),
                                     torch.LongTensor([len(tokens)])).item()
pred = pred.squeeze().cpu().detach().numpy()
        # restrict the posteriorgram to the tokens that actually occur in
        # this utterance, then run monotonic alignment search on it
        pred_max = pred[:, tokens]
        alignment_matrix = binarize_alignment(pred_max)
if save_img_for_debug is not None:
phones = list()
for index in tokens:
phones.append(self.tf.id_to_phone[index])
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
ax.set_ylabel("Mel-Frames")
ax.set_xticks(range(len(pred_max[0])))
ax.set_xticklabels(labels=phones)
ax.set_title("MAS Path")
plt.tight_layout()
fig.savefig(save_img_for_debug)
fig.clf()
plt.close()
if return_ctc:
return alignment_matrix, ctc_loss
return alignment_matrix
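
# Example use of Aligner.inference (a sketch; "mel" and "text_tensor" are
# assumed inputs, not objects defined in this file):
#
#   aligner = Aligner()
#   alignment = aligner.inference(features=mel,          # (frames, 128) tensor
#                                 tokens=text_tensor,    # articulatory text vectors
#                                 save_img_for_debug="mas_path.png")
#   # "alignment" is a (frames, tokens) matrix with exactly one 1 per frame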


def binarize_alignment(alignment_prob):
    """
    Binarizes the alignment with monotonic alignment search (MAS).

    Implementation adapted from:
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py
    """
# assumes features x text
opt = np.zeros_like(alignment_prob)
    alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0)  # make all numbers positive and add an offset to avoid log of 0 later
    alignment_prob = alignment_prob * (1.0 / alignment_prob.max())  # normalize to (0, 1]
    attn_map = np.log(alignment_prob)
attn_map[0, 1:] = -np.inf
log_p = np.zeros_like(attn_map)
log_p[0, :] = attn_map[0, :]
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
    # dynamic programming: each frame either stays on the current token or
    # advances to the next one (monotonic, no skips)
    for i in range(1, attn_map.shape[0]):
for j in range(attn_map.shape[1]): # for each text dim
prev_log = log_p[i - 1, j]
prev_j = j
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
prev_log = log_p[i - 1, j - 1]
prev_j = j - 1
log_p[i, j] = attn_map[i, j] + prev_log
prev_ind[i, j] = prev_j
# now backtrack
curr_text_idx = attn_map.shape[1] - 1
for i in range(attn_map.shape[0] - 1, -1, -1):
opt[i, curr_text_idx] = 1
curr_text_idx = prev_ind[i, curr_text_idx]
opt[0, curr_text_idx] = 1
return opt
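
# Toy illustration of the hard path MAS produces (values made up for
# demonstration, 4 frames x 2 tokens):
#
#   probs = np.array([[0.9, 0.1],
#                     [0.8, 0.2],
#                     [0.2, 0.8],
#                     [0.1, 0.9]])
#   binarize_alignment(probs)
#   # -> frames 0-1 are assigned to token 0, frames 2-3 to token 1,
#   #    i.e. a monotonic path with exactly one token per frame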


if __name__ == '__main__':
print(sum(p.numel() for p in Aligner().parameters() if p.requires_grad))
print(Aligner()(x=torch.randn(size=[3, 30, 128]), lens=torch.LongTensor([20, 30, 10])).shape)