import random

import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset


class DataCollator:
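    """Batch collator that pads waveforms and CTC label ids for training.

    When ``augment`` is enabled, each waveform is passed through a randomly
    chosen ffmpeg effect chain (tempo, highpass, vibrato, echo) built with
    torchaudio.io.AudioEffector before padding.
    """
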
    def __init__(self, processor, padding, device, augment):
        self.processor = processor
        self.padding = padding
        self.device = device
        self.sampling_rate = 16000
        self.augment = augment

        atempos = (0.8, 1.0, 1.25)  # playback-speed factors for ffmpeg's atempo filter
        audio_effects = (
            ("highpass=frequency=1500",),
            (
                "vibrato=f=5:d=0.4",
                "volume=1.5",
            ),
            (
                "aecho=0.8:0.88:30:0.3",
                "volume=1.5",
            ),
        )

        # None means "keep the clean waveform"; the rest are tempo/effect combinations
        self.effectors = [None]
        for atempo in atempos:
            for audio_effect in audio_effects:
                effect = f"atempo={atempo}," + ",".join(audio_effect)
                self.effectors.append(torchaudio.io.AudioEffector(effect=effect))

    def __call__(self, data):
        waveforms, lm_labels, accent_labels, gender_labels = zip(*data)
        accent_labels = torch.tensor(accent_labels, device=self.device)
        gender_labels = torch.tensor(gender_labels, device=self.device)

        input_features = [
            {"input_values": self.random_augment(waveform).squeeze()}
            for waveform in waveforms
        ]
        label_features = [{"input_ids": lm_label} for lm_label in lm_labels]

        padded_waveforms = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )["input_values"]
        padded_waveforms = padded_waveforms.to(self.device)

        with self.processor.as_target_processor():
            padded_lm_labels = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding token ids with -100 so they are ignored by the CTC loss
        padded_lm_labels = padded_lm_labels["input_ids"].masked_fill(
            padded_lm_labels.attention_mask.ne(1), -100
        )
        padded_lm_labels = padded_lm_labels.to(self.device)

        return padded_waveforms, padded_lm_labels, accent_labels, gender_labels

    def random_augment(self, waveform):
        if not self.augment:
            return waveform

        # AudioEffector expects (frames, channels), so move time to dim 0
        waveform = torch.tensor(waveform)
        waveform = torch.transpose(waveform, 0, 1)
        effector = random.choice(self.effectors)
        if effector is None:
            return waveform

        augmented_waveform = effector.apply(waveform, self.sampling_rate)
        # fall back to the clean waveform if the effect produced NaN/Inf samples
        if augmented_waveform.isnan().any() or augmented_waveform.isinf().any():
            return waveform

        return augmented_waveform


class L2ArcticDataset(Dataset):
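    """L2-ARCTIC utterances paired with CTC, accent, and gender labels.

    Audio is loaded, resampled from 44.1 kHz to 16 kHz, and run through the
    Wav2Vec2 processor up front; transcripts are tokenized to CTC label ids.
    """
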
    def __init__(self, processor, audio_paths, lm_labels, accent_labels, gender_labels):
        orig_sampling_rate = 44100
        new_sampling_rate = 16000
        resample_transform = torchaudio.transforms.Resample(
            orig_sampling_rate, new_sampling_rate
        )

        self.waveforms = []
        self.lm_labels = []
        self.accent_labels = accent_labels
        self.gender_labels = gender_labels

        for audio_path in audio_paths:
            waveform, _ = torchaudio.load(audio_path)
            waveform = resample_transform(waveform)
            self.waveforms.append(
                processor(waveform, sampling_rate=new_sampling_rate).input_values[0]
            )

        with processor.as_target_processor():
            for lm_label in lm_labels:
                self.lm_labels.append(processor(lm_label).input_ids)

    def __getitem__(self, index):
        return (
            self.waveforms[index],
            self.lm_labels[index],
            self.accent_labels[index],
            self.gender_labels[index],
        )

    def __len__(self):
        return len(self.waveforms)


class MultiTaskWav2Vec2(nn.Module):
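    """Wav2Vec2 CTC backbone with additional accent and gender heads.

    Both heads mean-pool the backbone's last hidden state, so the backbone
    must be configured to return hidden states (output_hidden_states=True).
    """
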
    def __init__(
        self,
        wav2vec2_backbone,
        backbone_hidden_size,
        projection_hidden_size,
        num_accent_class,
    ):
        super().__init__()
        self.wav2vec2 = wav2vec2_backbone
        self.accent_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.accent_classifier = nn.Linear(projection_hidden_size, num_accent_class)
        self.gender_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
        self.gender_classifier = nn.Linear(projection_hidden_size, 2)

    def forward(self, waveform, lm_labels=None):
        if lm_labels is not None:
            # run the Hugging Face Wav2Vec2ForCTC backbone with labels so it
            # also computes the CTC loss
            wav2vec2_output = self.wav2vec2(input_values=waveform, labels=lm_labels)

            # partial loss from the backbone's LM head (the CTC loss)
            ctc_loss = wav2vec2_output.loss

        else:
            # run the backbone without labels, e.g. at inference time
            wav2vec2_output = self.wav2vec2(input_values=waveform)
            ctc_loss = None

        # last hidden state of the backbone, shared by both classification heads
        # (requires the backbone to be created with output_hidden_states=True)
        features = wav2vec2_output.hidden_states[-1]

        # character-level logits from the CTC head
        lm_logits = wav2vec2_output.logits

        # accent head: project, mean-pool over time, classify
        accent_projected = self.accent_projector(features)
        accent_projected = accent_projected.mean(dim=1)
        accent_logits = self.accent_classifier(accent_projected)

        # gender head: project, mean-pool over time, classify
        gender_projected = self.gender_projector(features)
        gender_projected = gender_projected.mean(dim=1)
        gender_logits = self.gender_classifier(gender_projected)

        return ctc_loss, lm_logits, accent_logits, gender_logits
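

# Minimal wiring sketch (illustrative only): the checkpoint name, the generated
# audio file, the transcript, and the label values are placeholders rather than
# the original training configuration.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    backbone = Wav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-base-960h", output_hidden_states=True
    )

    # write one second of noise to disk as a stand-in for an L2-ARCTIC recording
    sample_path = "example.wav"
    torchaudio.save(sample_path, torch.randn(1, 44100) * 0.1, 44100)

    dataset = L2ArcticDataset(
        processor,
        audio_paths=[sample_path],
        lm_labels=["HELLO WORLD"],  # placeholder transcript
        accent_labels=[0],
        gender_labels=[1],
    )
    collator = DataCollator(processor, padding=True, device="cpu", augment=False)
    loader = DataLoader(dataset, batch_size=1, collate_fn=collator)

    model = MultiTaskWav2Vec2(
        backbone,
        backbone_hidden_size=backbone.config.hidden_size,
        projection_hidden_size=256,  # placeholder head size
        num_accent_class=6,  # placeholder number of accent classes
    )
    waveforms, lm_labels, accent_labels, gender_labels = next(iter(loader))
    ctc_loss, lm_logits, accent_logits, gender_logits = model(waveforms, lm_labels)
    print(ctc_loss, lm_logits.shape, accent_logits.shape, gender_logits.shape)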