konverner's picture
Initial commit
899cf32
raw
history blame
No virus
1.69 kB
import torch
from typing import Any, Dict, List, Union
class TTSDataCollatorWithPadding:
    """Collate text-to-speech training examples into a single padded batch.

    Padding is delegated to ``processor.pad``. This class then masks the
    padded label positions with ``-100``, optionally trims the label length
    according to ``model.config.reduction_factor``, and attaches the
    per-example speaker embeddings as one stacked tensor.
    """

    def __init__(self, model, processor):
        """Keep references to the objects consulted during collation.

        Args:
            model: Object exposing ``config.reduction_factor``.
            processor: Object exposing
                ``pad(input_ids=..., labels=..., return_tensors=...)``.
        """
        self.model = model
        self.processor = processor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build one batch from a list of examples.

        Args:
            features: Examples, each providing the keys ``"input_ids"``,
                ``"labels"`` and ``"speaker_embeddings"``. Speaker embeddings
                may be nested lists, NumPy arrays or tensors, but must share
                one shape across the batch so they can be stacked.

        Returns:
            The batch returned by ``processor.pad`` with ``"labels"`` masked
            (and possibly trimmed), ``"decoder_attention_mask"`` removed and
            ``"speaker_embeddings"`` added.
        """
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]

        # collate the inputs and targets into a batch
        batch = self.processor.pad(
            input_ids=input_ids,
            labels=label_features,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly; the mask gets a
        # trailing axis so it broadcasts over the last dimension of the labels
        # (assumes labels have exactly one more trailing dim than the mask)
        padding_mask = batch["decoder_attention_mask"].unsqueeze(-1).ne(1)
        batch["labels"] = batch["labels"].masked_fill(padding_mask, -100)

        # not used during fine-tuning
        del batch["decoder_attention_mask"]

        # round down target lengths to multiple of reduction factor
        reduction_factor = self.model.config.reduction_factor
        if reduction_factor > 1:
            # Plain-int arithmetic: only the largest rounded length is needed
            # to slice the labels, so no intermediate tensor is required.
            max_length = max(
                length - length % reduction_factor
                for length in (len(feature["input_values"]) for feature in label_features)
            )
            batch["labels"] = batch["labels"][:, :max_length]

        # add the speaker embeddings; stacking per-example tensors also works
        # when the embeddings are already tensors, where torch.tensor(list)
        # would raise for multi-element entries
        batch["speaker_embeddings"] = torch.stack(
            [torch.as_tensor(embedding) for embedding in speaker_features]
        )

        return batch