import random
import re

import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset

class DataCollator:
    def __init__(self, processor, padding, device, augment):
        self.processor = processor
        self.padding = padding
        self.device = device
        self.sampling_rate = 16000
        self.augment = augment

        # Tempo factors and ffmpeg filter chains used for waveform augmentation.
        atempos = (0.8, 1.0, 1.25)  # audio tempo: atempo=<factor>
        audio_effects = (
            ("highpass=frequency=1500",),
            (
                "vibrato=f=5:d=0.4",
                "volume=1.5",
            ),
            (
                "aecho=0.8:0.88:30:0.3",
                "volume=1.5",
            ),
        )

        # One AudioEffector per (tempo, effect chain) combination; None means
        # "no augmentation" and is also a valid random choice.
        self.effectors = [None]
        for atempo in atempos:
            for audio_effect in audio_effects:
                effect = f"atempo={atempo}," + ",".join(audio_effect)
                self.effectors.append(torchaudio.io.AudioEffector(effect=effect))
    def __call__(self, data):
        waveforms, lm_labels, accent_labels, gender_labels = zip(*data)
        accent_labels = torch.tensor(accent_labels, device=self.device)
        gender_labels = torch.tensor(gender_labels, device=self.device)

        # Optionally augment each waveform, then pad the batch to a common length.
        input_features = [
            {"input_values": self.random_augment(waveform).squeeze()}
            for waveform in waveforms
        ]
        label_features = [{"input_ids": lm_label} for lm_label in lm_labels]

        padded_waveforms = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )["input_values"]
        padded_waveforms = padded_waveforms.to(self.device)

        with self.processor.as_target_processor():
            padded_lm_labels = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )
        # Replace padding token ids with -100 so they are ignored by the CTC loss.
        padded_lm_labels = padded_lm_labels["input_ids"].masked_fill(
            padded_lm_labels.attention_mask.ne(1), -100
        )
        padded_lm_labels = padded_lm_labels.to(self.device)

        return padded_waveforms, padded_lm_labels, accent_labels, gender_labels
    def random_augment(self, waveform):
        if not self.augment:
            return waveform

        # AudioEffector expects a (time, channel) tensor.
        waveform = torch.tensor(waveform)
        waveform = torch.transpose(waveform, 0, 1)

        effector = random.choice(self.effectors)
        if effector is None:
            return waveform

        augmented_waveform = effector.apply(waveform, self.sampling_rate)
        # Fall back to the un-augmented waveform if the effect produced invalid values.
        if augmented_waveform.isnan().any() or augmented_waveform.isinf().any():
            return waveform
        return augmented_waveform
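
# Illustration (not part of the original Space): a minimal sketch of what one of the
# augmentation chains built in DataCollator.__init__ does to a waveform. The helper
# name, the dummy sine tone, and the chosen effect string are placeholders for
# demonstration only.
def _demo_audio_effector():
    sample_rate = 16000
    # One second of a 440 Hz sine tone, shaped (time, channel) as AudioEffector expects.
    time_axis = torch.arange(sample_rate) / sample_rate
    waveform = torch.sin(2 * torch.pi * 440.0 * time_axis).unsqueeze(1)

    # Same construction as in DataCollator: a tempo change followed by an ffmpeg filter.
    effector = torchaudio.io.AudioEffector(effect="atempo=0.8,highpass=frequency=1500")
    augmented = effector.apply(waveform, sample_rate)

    # Slowing the tempo to 0.8x makes the output roughly 1/0.8 times longer.
    print(waveform.shape, augmented.shape)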

class L2ArcticDataset(Dataset):
    def __init__(self, processor, audio_paths, lm_labels, accent_labels, gender_labels):
        orig_sampling_rate = 44100
        new_sampling_rate = 16000
        resample_transform = torchaudio.transforms.Resample(
            orig_sampling_rate, new_sampling_rate
        )

        self.waveforms = []
        self.lm_labels = []
        self.accent_labels = accent_labels
        self.gender_labels = gender_labels

        # Load every clip, resample from 44.1 kHz to 16 kHz, and pre-compute the
        # processor features once so __getitem__ stays cheap.
        for audio_path in audio_paths:
            waveform, _ = torchaudio.load(audio_path)
            waveform = resample_transform(waveform)
            self.waveforms.append(
                processor(waveform, sampling_rate=new_sampling_rate).input_values[0]
            )

        # Tokenize the transcripts into CTC label ids.
        with processor.as_target_processor():
            for lm_label in lm_labels:
                self.lm_labels.append(processor(lm_label).input_ids)

    def __getitem__(self, index):
        return (
            self.waveforms[index],
            self.lm_labels[index],
            self.accent_labels[index],
            self.gender_labels[index],
        )

    def __len__(self):
        return len(self.waveforms)
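
# Illustration (not part of the original Space): one way to wire the dataset and the
# collator into a DataLoader. The file paths, transcripts, and integer label encodings
# are hypothetical placeholders, and the processor checkpoint is only assumed to be a
# standard Wav2Vec2Processor such as "facebook/wav2vec2-base-960h".
def _demo_dataloader():
    from torch.utils.data import DataLoader
    from transformers import Wav2Vec2Processor

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    dataset = L2ArcticDataset(
        processor=processor,
        audio_paths=["clip_0001.wav", "clip_0002.wav"],  # placeholder paths
        lm_labels=["HELLO WORLD", "GOOD MORNING"],  # placeholder transcripts
        accent_labels=[0, 3],  # placeholder accent class ids
        gender_labels=[1, 0],  # placeholder gender ids
    )
    collator = DataCollator(processor=processor, padding=True, device="cpu", augment=True)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collator)

    waveforms, lm_labels, accent_labels, gender_labels = next(iter(loader))
    print(waveforms.shape, lm_labels.shape, accent_labels.shape, gender_labels.shape)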

class MultiTaskWav2Vec2(nn.Module):
    def __init__(
        self,
        wav2vec2_backbone,
        backbone_hidden_size,
        projection_hidden_size,
        num_accent_class,
    ):
        super().__init__()
        self.wav2vec2 = wav2vec2_backbone
        self.accent_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.accent_classifier = nn.Linear(projection_hidden_size, num_accent_class)
        self.gender_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.gender_classifier = nn.Linear(projection_hidden_size, 2)

    def forward(self, waveform, lm_labels=None):
        if lm_labels is not None:
            # The Hugging Face wav2vec2 CTC model computes the CTC loss when labels are given.
            wav2vec2_output = self.wav2vec2(input_values=waveform, labels=lm_labels)
            ctc_loss = wav2vec2_output.loss
        else:
            wav2vec2_output = self.wav2vec2(input_values=waveform)
            ctc_loss = None

        # Last hidden state of the backbone (requires output_hidden_states=True).
        features = wav2vec2_output.hidden_states[-1]
        # Character-level logits from the CTC head.
        lm_logits = wav2vec2_output.logits

        # Accent head: project, mean-pool over time, then classify.
        accent_projected = self.accent_projector(features)
        accent_projected = accent_projected.mean(dim=1)
        accent_logits = self.accent_classifier(accent_projected)

        # Gender head: same structure with a binary classifier.
        gender_projected = self.gender_projector(features)
        gender_projected = gender_projected.mean(dim=1)
        gender_logits = self.gender_classifier(gender_projected)

        return ctc_loss, lm_logits, accent_logits, gender_logits
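
# Illustration (not part of the original Space): a minimal training-step sketch for the
# multi-task model. It assumes the backbone is a Hugging Face Wav2Vec2ForCTC loaded with
# output_hidden_states=True (forward() above reads hidden_states[-1]); the checkpoint,
# projection width, and equal loss weighting are placeholder choices, not values from
# the original code.
def _demo_training_step(padded_waveforms, padded_lm_labels, accent_labels, gender_labels):
    from transformers import Wav2Vec2ForCTC

    backbone = Wav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-base-960h", output_hidden_states=True
    )
    model = MultiTaskWav2Vec2(
        wav2vec2_backbone=backbone,
        backbone_hidden_size=768,  # hidden size of the wav2vec2-base checkpoint
        projection_hidden_size=256,  # placeholder projection width
        num_accent_class=6,  # L2-ARCTIC covers six L1 accents
    )

    # The arguments are the tensors produced by DataCollator.__call__ above.
    ctc_loss, lm_logits, accent_logits, gender_logits = model(
        padded_waveforms, lm_labels=padded_lm_labels
    )
    cross_entropy = nn.CrossEntropyLoss()
    # Sum the CTC loss with the two classification losses (equal weighting is a placeholder).
    total_loss = (
        ctc_loss
        + cross_entropy(accent_logits, accent_labels)
        + cross_entropy(gender_logits, gender_labels)
    )
    total_loss.backward()
    return total_loss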