from typing import Any, Dict, List, Union

import torch


class TTSDataCollatorWithPadding:
    """Collate text-to-speech training examples into one padded batch.

    Each example is a mapping with the keys ``"input_ids"``, ``"labels"`` and
    ``"speaker_embeddings"``. Padding is delegated to ``processor.pad``; the
    model is consulted only for ``model.config.reduction_factor``.
    """

    def __init__(self, model, processor):
        """Store the model and processor used when collating.

        Args:
            model: Object exposing ``config.reduction_factor``.
            processor: Object exposing ``pad(input_ids=..., labels=...,
                return_tensors=...)``. The returned batch must support item
                access and deletion and expose ``decoder_attention_mask``
                as an attribute.
        """
        self.model = model
        self.processor = processor

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build a padded batch from ``features``.

        Args:
            features: Examples to collate; each must provide ``"input_ids"``,
                ``"labels"`` and ``"speaker_embeddings"``.

        Returns:
            The batch produced by ``processor.pad`` with:
            * ``"labels"`` padding positions replaced by ``-100`` and, when
              the reduction factor is greater than 1, truncated along
              dimension 1 to the longest target length rounded down to a
              multiple of the reduction factor;
            * ``"decoder_attention_mask"`` removed;
            * ``"speaker_embeddings"`` added as a single tensor.
        """
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]

        # Collate the inputs and targets into a batch.
        batch = self.processor.pad(
            input_ids=input_ids,
            labels=label_features,
            return_tensors="pt",
        )

        # Replace padding with -100 so the loss ignores it. The mask gets a
        # trailing axis so that it broadcasts over the last dimension of
        # the labels.
        padding_mask = batch.decoder_attention_mask.unsqueeze(-1).ne(1)
        batch["labels"] = batch["labels"].masked_fill(padding_mask, -100)

        # Not used during fine-tuning.
        del batch["decoder_attention_mask"]

        # Round target lengths down to a multiple of the reduction factor.
        # Plain integer arithmetic avoids building a tensor and iterating it
        # element by element just to compute one slice bound.
        reduction_factor = self.model.config.reduction_factor
        if reduction_factor > 1:
            target_lengths = [
                len(feature["input_values"]) for feature in label_features
            ]
            max_length = max(
                length - length % reduction_factor for length in target_lengths
            )
            batch["labels"] = batch["labels"][:, :max_length]

        # Add the speaker embeddings. ``torch.tensor`` cannot convert a list
        # of multi-element tensors, so stack when tensors are present and
        # keep the plain conversion for lists / array-likes.
        if any(isinstance(emb, torch.Tensor) for emb in speaker_features):
            batch["speaker_embeddings"] = torch.stack(
                [torch.as_tensor(emb) for emb in speaker_features]
            )
        else:
            batch["speaker_embeddings"] = torch.tensor(speaker_features)

        return batch