import random
import re
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset
class DataCollator:
def __init__(self, processor, padding, device, augment):
self.processor = processor
self.padding = padding
self.device = device
self.sampling_rate = 16000
self.augment = augment
        atempos = (0.8, 1.0, 1.25)  # tempo factors for the ffmpeg "atempo" filter
audio_effects = (
("highpass=frequency=1500",),
(
"vibrato=f=5:d=0.4",
"volume=1.5",
),
(
"aecho=0.8:0.88:30:0.3",
"volume=1.5",
),
)
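        # each effector pairs one tempo with one filter chain, yielding effect
        # strings like "atempo=0.8,highpass=frequency=1500"; None (below) keeps
        # the clean, un-augmented signal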
self.effectors = [None]
for atempo in atempos:
for audio_effect in audio_effects:
effect = f"atempo={atempo}," + ",".join(audio_effect)
self.effectors.append(torchaudio.io.AudioEffector(effect=effect))
def __call__(self, data):
waveforms, lm_labels, accent_labels, gender_labels = zip(*data)
accent_labels = torch.tensor(accent_labels, device=self.device)
gender_labels = torch.tensor(gender_labels, device=self.device)
input_features = [
{"input_values": self.random_augment(waveform).squeeze()}
for waveform in waveforms
]
label_features = [{"input_ids": lm_label} for lm_label in lm_labels]
padded_waveforms = self.processor.pad(
input_features,
            padding=self.padding,
return_tensors="pt",
)["input_values"]
padded_waveforms = padded_waveforms.to(self.device)
with self.processor.as_target_processor():
padded_lm_labels = self.processor.pad(
label_features,
                padding=self.padding,
return_tensors="pt",
)
        # replace padding token ids with -100 so the CTC loss ignores them,
        # e.g. [12, 5, pad, pad] becomes [12, 5, -100, -100]
padded_lm_labels = padded_lm_labels["input_ids"].masked_fill(
padded_lm_labels.attention_mask.ne(1), -100
)
padded_lm_labels = padded_lm_labels.to(self.device)
return padded_waveforms, padded_lm_labels, accent_labels, gender_labels
def random_augment(self, waveform):
if not self.augment:
return waveform
        # AudioEffector expects (time, channels); the stored waveform is (channels, time)
        waveform = torch.as_tensor(waveform)
        waveform = torch.transpose(waveform, 0, 1)
effector = random.choice(self.effectors)
if effector is None:
return waveform
augmented_waveform = effector.apply(waveform, self.sampling_rate)
        # some filter chains can produce NaN/Inf samples; fall back to the clean input
        if augmented_waveform.isnan().any() or augmented_waveform.isinf().any():
return waveform
return augmented_waveform
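
# A sketch of how the collator is presumably consumed (an assumption; the
# training loop is not part of this file): pass it to a DataLoader, e.g.
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collator)
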
class L2ArcticDataset(Dataset):
def __init__(self, processor, audio_paths, lm_labels, accent_labels, gender_labels):
        # L2-Arctic audio is 44.1 kHz; Wav2Vec2 expects 16 kHz
        orig_sampling_rate = 44100
        new_sampling_rate = 16000
        resample_transform = torchaudio.transforms.Resample(
            orig_sampling_rate, new_sampling_rate
        )
self.waveforms = []
self.lm_labels = []
self.accent_labels = accent_labels
self.gender_labels = gender_labels
for audio_path in audio_paths:
waveform, _ = torchaudio.load(audio_path)
waveform = resample_transform(waveform)
self.waveforms.append(
processor(waveform, sampling_rate=new_sampling_rate).input_values[0]
)
        # tokenize the transcripts with the tokenizer side of the processor
        with processor.as_target_processor():
            for lm_label in lm_labels:
                self.lm_labels.append(processor(lm_label).input_ids)
def __getitem__(self, index):
return (
self.waveforms[index],
self.lm_labels[index],
self.accent_labels[index],
self.gender_labels[index],
)
def __len__(self):
return len(self.waveforms)
class MultiTaskWav2Vec2(nn.Module):
def __init__(
self,
wav2vec2_backbone,
backbone_hidden_size,
projection_hidden_size,
num_accent_class,
):
super().__init__()
self.wav2vec2 = wav2vec2_backbone
        # accent head: projection + classifier over num_accent_class classes
        self.accent_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.accent_classifier = nn.Linear(projection_hidden_size, num_accent_class)
        # gender head: projection + binary classifier
        self.gender_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.gender_classifier = nn.Linear(projection_hidden_size, 2)
def forward(self, waveform, lm_labels=None):
        if lm_labels is not None:
            # run the Hugging Face Wav2Vec2 backbone; with labels it also
            # computes the CTC loss over the lm_head logits
            wav2vec2_output = self.wav2vec2(input_values=waveform, labels=lm_labels)
            ctc_loss = wav2vec2_output.loss
        else:
            # no labels (inference), so no CTC loss
            wav2vec2_output = self.wav2vec2(input_values=waveform)
            ctc_loss = None
        # last hidden state of the backbone, shared by the classification heads
        # (requires the backbone to be configured with output_hidden_states=True)
        features = wav2vec2_output.hidden_states[-1]
        # character-level logits from the CTC head
        lm_logits = wav2vec2_output.logits
        # accent head: project, mean-pool over time, then classify
        accent_projected = self.accent_projector(features)
        accent_projected = accent_projected.mean(dim=1)
accent_logits = self.accent_classifier(accent_projected)
        # gender head: project, mean-pool over time, then classify
        gender_projected = self.gender_projector(features)
        gender_projected = gender_projected.mean(dim=1)
gender_logits = self.gender_classifier(gender_projected)
return ctc_loss, lm_logits, accent_logits, gender_logits
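
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the uploaded model code: it
    # assumes a standard Hugging Face CTC checkpoint, a hypothetical 44.1 kHz
    # "sample.wav", and placeholder transcript/accent/gender labels; the real
    # training script would supply the L2-Arctic file lists instead.
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    backbone = Wav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-base-960h",
        output_hidden_states=True,  # forward() reads hidden_states[-1]
    )
    model = MultiTaskWav2Vec2(
        wav2vec2_backbone=backbone,
        backbone_hidden_size=backbone.config.hidden_size,
        projection_hidden_size=256,  # assumed; not specified in this file
        num_accent_class=6,  # the six L1 groups in L2-Arctic
    ).to(device)

    dataset = L2ArcticDataset(
        processor,
        audio_paths=["sample.wav"],  # hypothetical path
        lm_labels=["HELLO WORLD"],
        accent_labels=[0],
        gender_labels=[1],
    )
    collator = DataCollator(processor, padding=True, device=device, augment=True)
    waveforms, lm_labels, accent_labels, gender_labels = collator([dataset[0]])

    ctc_loss, lm_logits, accent_logits, gender_logits = model(
        waveforms, lm_labels=lm_labels
    )
    # one plausible way to combine the three objectives; the original loss
    # weighting is not shown in this file
    ce = nn.CrossEntropyLoss()
    loss = ctc_loss + ce(accent_logits, accent_labels) + ce(gender_logits, gender_labels)
    loss.backward()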