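"""Multi-task wav2vec 2.0 components for L2-Arctic: a data collator with optional
ffmpeg-based audio augmentation, a dataset that resamples and tokenizes the corpus,
and a model wrapper that jointly predicts CTC transcriptions, accent, and gender."""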
import random
import re
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset


class DataCollator:
    def __init__(self, processor, padding, device, augment):
        self.processor = processor
        self.padding = padding
        self.device = device
        self.sampling_rate = 16000
        self.augment = augment
        # Tempo factors for the ffmpeg "atempo" filter, combined below with
        # extra effect chains for on-the-fly audio augmentation.
        atempos = (0.8, 1.0, 1.25)
        audio_effects = (
            ("highpass=frequency=1500",),
            (
                "vibrato=f=5:d=0.4",
                "volume=1.5",
            ),
            (
                "aecho=0.8:0.88:30:0.3",
                "volume=1.5",
            ),
        )
        # None means "no augmentation"; every other entry is a tempo + effect
        # chain applied through torchaudio's AudioEffector.
        self.effectors = [None]
        for atempo in atempos:
            for audio_effect in audio_effects:
                effect = f"atempo={atempo}," + ",".join(audio_effect)
                self.effectors.append(torchaudio.io.AudioEffector(effect=effect))

    def __call__(self, data):
        # Each dataset item is (waveform, lm_label, accent_label, gender_label).
        waveforms, lm_labels, accent_labels, gender_labels = zip(*data)
        accent_labels = torch.tensor(accent_labels, device=self.device)
        gender_labels = torch.tensor(gender_labels, device=self.device)
        input_features = [
            {"input_values": self.random_augment(waveform).squeeze()}
            for waveform in waveforms
        ]
        label_features = [{"input_ids": lm_label} for lm_label in lm_labels]
        padded_waveforms = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )["input_values"]
        padded_waveforms = padded_waveforms.to(self.device)
        with self.processor.as_target_processor():
            padded_lm_labels = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )
        # Replace padding token ids with -100 so they are ignored by the CTC loss.
        padded_lm_labels = padded_lm_labels["input_ids"].masked_fill(
            padded_lm_labels.attention_mask.ne(1), -100
        )
        padded_lm_labels = padded_lm_labels.to(self.device)
        return padded_waveforms, padded_lm_labels, accent_labels, gender_labels

    def random_augment(self, waveform):
        if not self.augment:
            return waveform
        # AudioEffector expects a (time, channel) tensor.
        waveform = torch.tensor(waveform)
        waveform = torch.transpose(waveform, 0, 1)
        effector = random.choice(self.effectors)
        if effector is None:
            return waveform
        augmented_waveform = effector.apply(waveform, self.sampling_rate)
        # Fall back to the clean waveform if the effect produced NaN or Inf values.
        if augmented_waveform.isnan().any() or augmented_waveform.isinf().any():
            return waveform
        return augmented_waveform


class L2ArcticDataset(Dataset):
    def __init__(self, processor, audio_paths, lm_labels, accent_labels, gender_labels):
        # L2-Arctic audio is recorded at 44.1 kHz; wav2vec 2.0 expects 16 kHz input.
        orig_sampling_rate = 44100
        new_sampling_rate = 16000
        resample_transform = torchaudio.transforms.Resample(
            orig_sampling_rate, new_sampling_rate
        )
        self.waveforms = []
        self.lm_labels = []
        self.accent_labels = accent_labels
        self.gender_labels = gender_labels
        for audio_path in audio_paths:
            waveform, _ = torchaudio.load(audio_path)
            waveform = resample_transform(waveform)
            self.waveforms.append(
                processor(waveform, sampling_rate=new_sampling_rate).input_values[0]
            )
        # Tokenize the transcripts with the processor's tokenizer.
        with processor.as_target_processor():
            for lm_label in lm_labels:
                self.lm_labels.append(processor(lm_label).input_ids)

    def __getitem__(self, index):
        return (
            self.waveforms[index],
            self.lm_labels[index],
            self.accent_labels[index],
            self.gender_labels[index],
        )

    def __len__(self):
        return len(self.waveforms)


class MultiTaskWav2Vec2(nn.Module):
    def __init__(
        self,
        wav2vec2_backbone,
        backbone_hidden_size,
        projection_hidden_size,
        num_accent_class,
    ):
        super().__init__()
        # The backbone must be configured with output_hidden_states=True so that
        # forward() can read the last hidden state for the classification heads.
        self.wav2vec2 = wav2vec2_backbone
        self.accent_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.accent_classifier = nn.Linear(projection_hidden_size, num_accent_class)
        self.gender_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.gender_classifier = nn.Linear(projection_hidden_size, 2)

    def forward(self, waveform, lm_labels=None):
        if lm_labels is not None:
            # The Hugging Face wav2vec 2.0 model computes the CTC loss itself
            # when labels are provided.
            wav2vec2_output = self.wav2vec2(input_values=waveform, labels=lm_labels)
            ctc_loss = wav2vec2_output.loss
        else:
            wav2vec2_output = self.wav2vec2(input_values=waveform)
            ctc_loss = None
        # Last hidden state of the backbone, shared by both classification heads.
        features = wav2vec2_output.hidden_states[-1]
        # Character-level logits from the CTC head.
        lm_logits = wav2vec2_output.logits
        # Accent logits: project, mean-pool over time, classify.
        accent_projected = self.accent_projector(features)
        accent_projected = accent_projected.mean(dim=1)
        accent_logits = self.accent_classifier(accent_projected)
        # Gender logits: same projection/pooling scheme with a binary classifier.
        gender_projected = self.gender_projector(features)
        gender_projected = gender_projected.mean(dim=1)
        gender_logits = self.gender_classifier(gender_projected)
        return ctc_loss, lm_logits, accent_logits, gender_logits
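

# --- Usage sketch (illustrative, not part of the original training script). It shows
# one way the pieces above might be wired together. The checkpoint name, file paths,
# transcripts, label values, and hyperparameters below are assumptions for the example.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    # output_hidden_states=True is required because MultiTaskWav2Vec2.forward reads
    # wav2vec2_output.hidden_states[-1].
    backbone = Wav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-base-960h", output_hidden_states=True
    )

    # Placeholder metadata: paths to 44.1 kHz L2-Arctic recordings with their
    # transcripts, accent ids, and gender ids.
    audio_paths = ["speaker1/arctic_a0001.wav", "speaker2/arctic_a0001.wav"]
    transcripts = ["AUTHOR OF THE DANGER TRAIL", "PHILIP STEELS ET CETERA"]
    accent_labels = [0, 1]
    gender_labels = [0, 1]

    dataset = L2ArcticDataset(processor, audio_paths, transcripts, accent_labels, gender_labels)
    collator = DataCollator(processor, padding=True, device="cpu", augment=True)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collator)

    model = MultiTaskWav2Vec2(
        wav2vec2_backbone=backbone,
        backbone_hidden_size=backbone.config.hidden_size,
        projection_hidden_size=256,
        num_accent_class=6,
    )

    waveforms, lm_labels, accents, genders = next(iter(loader))
    ctc_loss, lm_logits, accent_logits, gender_logits = model(waveforms, lm_labels)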