from dataclasses import dataclass
from typing import Any, Dict, List, Union

import numpy as np
import torch
from transformers.feature_extraction_utils import BatchFeature

from .vits_output import VitsTextEncoderOutput

#.............................................


@dataclass
class DataCollatorTTSWithPadding:
    """
    Data collator that dynamically pads the inputs received.

    Args:
        tokenizer ([`VitsTokenizer`])
            The tokenizer used for processing the text inputs.
        feature_extractor ([`VitsFeatureExtractor`])
            The feature extractor used for processing the audio inputs.
        forward_attention_mask (`bool`)
            Whether to return an attention mask for the padded input ids.
    """

    tokenizer: Any
    feature_extractor: Any
    forward_attention_mask: bool

    def pad_waveform(self, raw_speech):
        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            # shape each waveform as (num_samples, 1) so the feature extractor can pad along time
            raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype == np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        # always return a batch
        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        batched_speech = BatchFeature({"input_features": raw_speech})

        # convert into the correct format for padding
        padded_inputs = self.feature_extractor.pad(
            batched_speech,
            padding=True,
            return_attention_mask=False,
            return_tensors="pt",
        )["input_features"]

        return padded_inputs

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        model_input_name = "input_ids"

        # each example stores its input_ids as a 1-element batch, hence the [0]
        input_ids = [{model_input_name: feature[model_input_name][0]} for feature in features]

        # pad input tokens
        batch = self.tokenizer.pad(input_ids, return_tensors="pt", return_attention_mask=self.forward_attention_mask)

        # pad waveform
        waveforms = [np.array(feature["waveform"]) for feature in features]
        batch["waveform"] = self.pad_waveform(waveforms)

        # pad spectrogram
        label_features = [np.array(feature["labels"]) for feature in features]
        labels_batch = self.feature_extractor.pad(
            {"input_features": [i.T for i in label_features]}, return_tensors="pt", return_attention_mask=True
        )

        labels = labels_batch["input_features"].transpose(1, 2)
        batch["labels"] = labels
        batch["labels_attention_mask"] = labels_batch["attention_mask"]

        # pad mel spectrogram
        mel_scaled_input_features = {
            "input_features": [np.array(feature["mel_scaled_input_features"]).squeeze().T for feature in features]
        }
        mel_scaled_input_features = self.feature_extractor.pad(
            mel_scaled_input_features, return_tensors="pt", return_attention_mask=True
        )["input_features"].transpose(1, 2)

        batch["mel_scaled_input_features"] = mel_scaled_input_features
        batch["speaker_id"] = (
            torch.tensor([feature["speaker_id"] for feature in features]) if "speaker_id" in features[0] else None
        )

        # NOTE: only the text-encoder output of the first example is used, so this
        # effectively assumes a batch size of 1. A per-example variant (kept below,
        # commented out) would build one entry per feature instead.
        # text_encoder_output = [
        #     {
        #         "last_hidden_state": torch.tensor(feature["text_encoder_output"]["last_hidden_state"]),
        #         "prior_log_variances": torch.tensor(feature["text_encoder_output"]["prior_log_variances"]),
        #         "prior_means": torch.tensor(feature["text_encoder_output"]["prior_means"]),
        #     }
        #     for feature in features
        # ]
        batch["text_encoder_output"] = VitsTextEncoderOutput(
            last_hidden_state=torch.tensor(features[0]["text_encoder_output"]["last_hidden_state"]),
            prior_means=torch.tensor(features[0]["text_encoder_output"]["prior_means"]),
            prior_log_variances=torch.tensor(features[0]["text_encoder_output"]["prior_log_variances"]),
        )

        return batch

#.............................................................................................
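# Usage sketch (illustrative only, not part of the collator). It assumes a
# `VitsTokenizer` from `transformers` and the `VitsFeatureExtractor` defined in
# this project; the checkpoint name and the per-example keys ("input_ids",
# "waveform", "labels", "mel_scaled_input_features", "text_encoder_output") are
# assumptions inferred from how `__call__` reads each feature dict above.
#
#   from transformers import VitsTokenizer
#
#   collator = DataCollatorTTSWithPadding(
#       tokenizer=VitsTokenizer.from_pretrained("facebook/mms-tts-eng"),
#       feature_extractor=feature_extractor,  # a VitsFeatureExtractor instance
#       forward_attention_mask=True,
#   )
#   # `examples` is a list of preprocessed feature dicts holding the keys above;
#   # the collator returns a single padded batch ready for the model.
#   batch = collator(examples)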