# Repository path: artst-demo-asr/artst/data/speech_to_speech_dataset.py
# Uploaded by amupd -- "initial commit" (8b33290)
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import logging
import os
from typing import Any, List, Optional
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data.fairseq_dataset import FairseqDataset
logger = logging.getLogger(__name__)
def _collate_frames(
frames: List[torch.Tensor], is_audio_input: bool = False
):
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
length of i-th frame and f_dim is static dimension of features
Returns:
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
max_len = max(frame.size(0) for frame in frames)
if is_audio_input:
out = frames[0].new_zeros((len(frames), max_len))
else:
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
for i, v in enumerate(frames):
out[i, : v.size(0)] = v
return out
def load_audio(manifest_path, max_keep, min_keep):
    """Parse a speech-to-speech manifest and filter rows by source length.

    manifest tsv: src_wav, src_nframe, tgt_wav, tgt_nframe, tgt_spkemb

    The first line of the manifest is the root directory; every following
    line is one tab-separated row. A row is kept when its ``src_nframe`` lies
    within ``[min_keep, max_keep]`` (a bound that is ``None`` is not applied).

    Args:
        manifest_path (str): path of the manifest file.
        max_keep (int or None): rows with ``src_nframe`` above this are skipped.
        min_keep (int or None): rows with ``src_nframe`` below this are skipped.

    Returns:
        tuple: ``(root, src_names, inds, tot, sizes, tgt_names, tgt_sizes,
        spk_embeds)`` where ``inds`` holds the 0-based row numbers of the kept
        rows, ``tot`` is the number of data rows read, ``sizes`` and
        ``tgt_sizes`` are lists of ints, and the remaining lists hold the
        manifest strings of the kept rows.

    Raises:
        ValueError: if a row has fewer than 2 columns, a kept row has fewer
            than 5 columns, or a frame count is not an integer.
    """
    n_long, n_short = 0, 0
    src_names, tgt_names, inds, sizes, tgt_sizes, spk_embeds = [], [], [], [], [], []
    # Initialised up front so a manifest with a root line but no rows yields 0
    # instead of leaving ``tot`` unbound.
    tot = 0
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            tot = ind + 1
            items = line.strip().split("\t")
            if len(items) < 2:
                raise ValueError(
                    f"manifest row {ind} needs at least 2 columns: {line!r}"
                )
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                # Only kept rows read the target columns, so only they are
                # required to carry all five fields.
                if len(items) < 5:
                    raise ValueError(
                        f"manifest row {ind} needs 5 columns "
                        f"(src_wav, src_nframe, tgt_wav, tgt_nframe, tgt_spkemb): {line!r}"
                    )
                src_names.append(items[0])
                tgt_names.append(items[2])
                # Stored as int so target sizes are comparable like ``sizes``.
                tgt_sizes.append(int(items[3]))
                spk_embeds.append(items[4])
                inds.append(ind)
                sizes.append(sz)
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(src_names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes, default=None)}, shortest-loaded={min(sizes, default=None)}"
        )
    )
    return root, src_names, inds, tot, sizes, tgt_names, tgt_sizes, spk_embeds
def logmelfilterbank(
    audio,
    sampling_rate,
    fft_size=1024,
    hop_size=256,
    win_length=None,
    window="hann",
    num_mels=80,
    fmin=80,
    fmax=7600,
    eps=1e-10,
):
    """Compute the log10 Mel filterbank feature of a waveform.

    Adapted from
    (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)

    Args:
        audio (ndarray): Audio signal (T,).
        sampling_rate (int): Sampling rate used to build the mel basis.
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length, forwarded to ``librosa.stft``
            (``None`` is passed through unchanged).
        window (str): Window function type.
        num_mels (int): Number of mel basis.
        fmin (int): Minimum frequency of the mel basis; ``None`` means 0.
        fmax (int): Maximum frequency of the mel basis; ``None`` means
            ``sampling_rate / 2``.
        eps (float): Floor applied before the log to avoid ``-inf``.

    Returns:
        ndarray: Log Mel filterbank feature (#frames, num_mels).
    """
    # Magnitude spectrogram, transposed to (#frames, #bins).
    stft_matrix = librosa.stft(
        audio,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        pad_mode="reflect",
    )
    magnitude = np.abs(stft_matrix).T

    # Resolve open-ended band limits before building the mel basis.
    if fmin is None:
        fmin = 0
    if fmax is None:
        fmax = sampling_rate / 2
    mel_basis = librosa.filters.mel(
        sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax
    )

    mel_energy = np.dot(magnitude, mel_basis.T)
    return np.log10(np.maximum(eps, mel_energy))
class SpeechToSpeechDataset(FairseqDataset):
    """Paired source/target speech dataset for the ``"s2s"`` task.

    Each item pairs a source waveform with the log-Mel filterbank of a target
    utterance and a speaker embedding loaded with ``np.load``. All file paths
    come from the manifest and are resolved against the manifest's root line.
    """

    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        shuffle: bool = True,
        normalize: bool = False,
        reduction_factor: int = 1,
    ):
        """Read the manifest and store the dataset configuration.

        Args:
            manifest_path: manifest tsv consumed by ``load_audio``.
            sample_rate: sample rate every audio file must have; a file with
                a different rate makes ``postprocess`` raise.
            max_keep_sample_size: forwarded to ``load_audio`` as ``max_keep``.
            min_keep_sample_size: forwarded to ``load_audio`` as ``min_keep``.
            shuffle: whether ``ordered_indices`` breaks length ties randomly.
            normalize: whether ``postprocess`` layer-normalizes waveforms.
            reduction_factor: decoder frame reduction factor used in
                ``collater``.
        """
        # The kept-row indices and total row count are not needed here.
        (
            self.audio_root,
            self.audio_names,
            _inds,
            _tot,
            self.wav_sizes,
            self.tgt_audios,
            self.tgt_sizes,
            self.tgt_spkembs,
        ) = load_audio(manifest_path, max_keep_sample_size, min_keep_sample_size)
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.normalize = normalize
        self.reduction_factor = reduction_factor
        logger.info(
            f"reduction_factor={reduction_factor}, normalize={normalize}"
        )

    def get_audio(self, index):
        """Load waveform and log-Mel filterbank for the source and target.

        Multi-channel files are averaged to mono before anything else, so the
        filterbank and the returned waveform describe the same signal. The
        filterbank is computed from the waveform *before* the optional
        normalization, at the file's (validated) sample rate.

        Returns:
            tuple: ``(src_wav, src_fbank, tgt_wav, tgt_fbank)`` of float
            tensors.

        Raises:
            Exception: from ``postprocess`` when a file's sample rate differs
                from ``self.sample_rate``.
        """
        import soundfile as sf

        wav_fbank = []
        for name in (self.audio_names[index], self.tgt_audios[index]):
            wav_path = os.path.join(self.audio_root, name)
            wav, cur_sample_rate = sf.read(wav_path)
            raw_wav = torch.from_numpy(wav).float()
            if raw_wav.dim() == 2:
                # Down-mix here so the filterbank is not built from
                # interleaved channel samples.
                raw_wav = raw_wav.mean(-1)
            # Validates the sample rate (and optionally normalizes) before
            # the more expensive STFT is run.
            wav = self.postprocess(raw_wav, cur_sample_rate)
            fbank = logmelfilterbank(raw_wav.numpy(), cur_sample_rate)
            fbank = torch.from_numpy(fbank).float()
            wav_fbank.append(wav)
            wav_fbank.append(fbank)
        src_wav, src_fbank, tgt_wav, tgt_fbank = wav_fbank
        return src_wav, src_fbank, tgt_wav, tgt_fbank

    def __getitem__(self, index):
        """Return one sample: source waveform, target filterbank, speaker embedding.

        ``audio_name`` joins the source and target names (with ``/`` turned
        into ``.`` and ``.wav`` removed) by ``-`` and appends ``.wav``.
        """
        src_wav, src_fbank, tgt_wav, tgt_fbank = self.get_audio(index)
        spkembs = np.load(os.path.join(self.audio_root, self.tgt_spkembs[index]))
        spkembs = torch.from_numpy(spkembs).float()
        src_stem = self.audio_names[index].replace("/", ".").replace(".wav", "")
        tgt_stem = self.tgt_audios[index].replace("/", ".").replace(".wav", "")
        name = src_stem + "-" + tgt_stem + ".wav"
        return {
            "id": index,
            "source": src_wav,
            "target": tgt_fbank,
            "spkembs": spkembs,
            "audio_name": name,
            "tgt_name": self.tgt_audios[index],
        }

    def __len__(self):
        """Number of utterance pairs kept from the manifest."""
        return len(self.wav_sizes)

    def collater(self, samples):
        """Merge samples into a padded mini-batch.

        Samples whose ``"source"`` is ``None`` are dropped; if none remain an
        empty dict is returned.

        Returns:
            dict: batch with zero-padded source audio and its padding mask,
            zero-padded target filterbanks, decoder inputs (targets thinned
            by the reduction factor and shifted right by one all-zero frame),
            stop-prediction labels and the bookkeeping fields built below.
        """
        samples = [s for s in samples if s["source"] is not None]
        if not samples:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        audio_size = max(audio_sizes)
        collated_audios, padding_mask = self.collater_audio(
            audios, audio_size
        )

        fbanks = [s["target"] for s in samples]
        fbank_sizes = [len(s) for s in fbanks]

        collated_fbanks = _collate_frames(fbanks)
        collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)

        # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
            collated_fbanks_size_in = torch.div(
                collated_fbanks_size, self.reduction_factor, rounding_mode="floor"
            )
        else:
            collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size

        # Decoder input: one all-zero frame followed by the targets without
        # their last frame.
        first_frame = collated_fbanks_in.new_zeros(
            (collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])
        )
        prev_output_tokens = torch.cat(
            [first_frame, collated_fbanks_in[:, :-1]], dim=1
        )

        # make labels for stop prediction: 1.0 from each item's last real
        # frame through the end of the padding, 0.0 before it
        labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
        for i, l in enumerate(fbank_sizes):
            labels[i, l - 1 :] = 1.0

        spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)

        net_input = {
            "source": collated_audios,
            "padding_mask": padding_mask,
            "prev_output_tokens": prev_output_tokens,
            "tgt_lengths": collated_fbanks_size_in,
            "spkembs": spkembs,
            "task_name": "s2s",
        }
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "name": [s["audio_name"] for s in samples],
            "tgt_name": [s["tgt_name"] for s in samples],
            "net_input": net_input,
            "labels": labels,
            "dec_target": collated_fbanks,
            "dec_target_lengths": collated_fbanks_size,
            "src_lengths": torch.LongTensor(audio_sizes),
            "task_name": "s2s",
            "ntokens": sum(audio_sizes),
            "target": collated_fbanks,
        }

        return batch

    def collater_audio(self, audios, audio_size):
        """Zero-pad waveforms to ``audio_size`` and build the padding mask.

        Args:
            audios (list): 1D waveform tensors, none longer than ``audio_size``.
            audio_size (int): length every row is padded to.

        Returns:
            tuple: ``(collated_audios, padding_mask)``; the boolean mask is
            True on padded positions.

        Raises:
            Exception: if a waveform is longer than ``audio_size``.
        """
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = torch.zeros(collated_audios.shape, dtype=torch.bool)
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff > 0:
                raise Exception("Diff should not be larger than 0")
            # The buffer is already zero-filled, so only the real samples
            # need to be written.
            collated_audios[i, : len(audio)] = audio
            if diff < 0:
                padding_mask[i, diff:] = True
        return collated_audios, padding_mask

    def num_tokens(self, index):
        """Source length (``src_nframe`` from the manifest) of item ``index``."""
        return self.wav_sizes[index]

    def size(self, index):
        """Return ``(source size, target size)`` of item ``index`` as ints."""
        # int() so the pair is numeric even when the stored target size is
        # still the raw manifest string.
        return self.wav_sizes[index], int(self.tgt_sizes[index])

    @property
    def sizes(self):
        """Source sizes of all items as a NumPy array."""
        return np.array(self.wav_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """No cache dataset if dataset is large-scale. Cache dataset for small dataset.

        This implementation always returns True.
        """
        return True

    def ordered_indices(self):
        """Return item indices sorted by source size, longest first.

        Items of equal size are ordered by a random permutation when
        ``self.shuffle`` is set, otherwise by their index.
        """
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        # np.lexsort treats the last key as the primary one.
        order.append(self.wav_sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        """Down-mix to mono, check the sample rate and optionally normalize.

        Args:
            wav (Tensor): 1D waveform, or 2D waveform that is averaged over
                its last dimension.
            cur_sample_rate: sample rate the waveform was read with.

        Returns:
            Tensor: 1D waveform, layer-normalized over its full length when
            ``self.normalize`` is set.

        Raises:
            Exception: if ``cur_sample_rate`` differs from ``self.sample_rate``.
        """
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav