from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

from training.tools import pad_1D, pad_2D, pad_3D


class LibriTTSMMDatasetAcoustic(Dataset):
    def __init__(self, file_path: str):
        r"""A PyTorch dataset for loading preprocessed acoustic data stored in memory-mapped files.

        Args:
            file_path (str): Path to the memory-mapped file.
        """
        self.data = torch.load(file_path)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        return self.data[idx]

    def __len__(self):
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of data samples (dictionaries as returned by ``__getitem__``).

        Returns:
            List: The collated batch as a list of padded tensors and metadata.
        """
        data_size = len(data)

        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(11)]
        (
            ids,
            speakers,
            texts,
            raw_texts,
            mels,
            pitches,
            attn_priors,
            langs,
            src_lens,
            mel_lens,
            wavs,
        ) = empty_lists

        # Extract fields from each data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            ids.append(data_entry["id"])
            speakers.append(data_entry["speaker"])
            texts.append(data_entry["text"])
            raw_texts.append(data_entry["raw_text"])
            mels.append(data_entry["mel"])
            pitches.append(data_entry["pitch"])
            attn_priors.append(data_entry["attn_prior"].numpy())
            langs.append(data_entry["lang"])
            src_lens.append(data_entry["text"].shape[0])
            mel_lens.append(data_entry["mel"].shape[1])
            wavs.append(data_entry["wav"].numpy())

        # Convert langs, src_lens, and mel_lens to numpy arrays
        langs = np.array(langs)
        src_lens = np.array(src_lens)
        mel_lens = np.array(mel_lens)

        # NOTE: Pitch statistics are computed per batch rather than over the whole dataset.
        # Keep only the min and max values for pitch.
        pitches_stat = list(self.normalize_pitch(pitches)[:2])

        # Pad variable-length fields to the maximum length in the batch
        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        attn_priors = pad_3D(attn_priors, len(idxs), max(src_lens), max(mel_lens))

        # Broadcast per-utterance speaker and language IDs across the text length
        speakers = np.repeat(
            np.expand_dims(np.array(speakers), axis=1),
            texts.shape[1],
            axis=1,
        )
        langs = np.repeat(
            np.expand_dims(np.array(langs), axis=1),
            texts.shape[1],
            axis=1,
        )

        wavs = pad_2D(wavs)

        return [
            ids,
            raw_texts,
            torch.from_numpy(speakers),
            torch.from_numpy(texts).int(),
            torch.from_numpy(src_lens),
            torch.from_numpy(mels),
            torch.from_numpy(pitches),
            pitches_stat,
            torch.from_numpy(mel_lens),
            torch.from_numpy(langs),
            torch.from_numpy(attn_priors),
            torch.from_numpy(wavs),
        ]

    def normalize_pitch(
        self,
        pitches: List[torch.Tensor],
    ) -> Tuple[float, float, float, float]:
        r"""Computes pitch statistics over a batch of pitch tensors.

        Args:
            pitches (List[torch.Tensor]): A list of pitch values.

        Returns:
            Tuple: A tuple of (min, max, mean, std) of the concatenated pitch values.
        """
        pitches_t = torch.cat(pitches)

        min_value = torch.min(pitches_t).item()
        max_value = torch.max(pitches_t).item()

        mean = torch.mean(pitches_t).item()
        std = torch.std(pitches_t).item()

        return min_value, max_value, mean, std
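

# Usage sketch (an assumption, not part of the original module): the dataset is
# presumably wrapped in a torch.utils.data.DataLoader with `collate_fn` passed
# explicitly, since the default collation cannot pad the variable-length fields.
# The file path and batch size below are hypothetical.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Hypothetical path to a preprocessed acoustic data file
    dataset = LibriTTSMMDatasetAcoustic("checkpoints/libritts_preprocessed.pt")

    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    for batch in loader:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            mels,
            pitches,
            pitches_stat,
            mel_lens,
            langs,
            attn_priors,
            wavs,
        ) = batch
        # Padded tensors share the batch dimension; shapes depend on the longest sample
        print(texts.shape, mels.shape, wavs.shape)
        break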