bel32123 committed
Commit 32dba0a · 1 Parent(s): b615647

Upload model code for multitask model

Files changed (1)
  1. wav2vecasr/models.py +167 -0
wav2vecasr/models.py ADDED
@@ -0,0 +1,167 @@
+ import random
+ import re
+ import torch
+ import torch.nn as nn
+ import torchaudio
+ from torch.utils.data import Dataset
+
+
+ class DataCollator:
+     def __init__(self, processor, padding, device, augment):
+         self.processor = processor
+         self.padding = padding
+         self.device = device
+         self.sampling_rate = 16000
+         self.augment = augment
+
+         atempos = (0.8, 1.0, 1.25)  # audio tempo (atempo=tempo)
+         audio_effects = (
+             ("highpass=frequency=1500",),
+             (
+                 "vibrato=f=5:d=0.4",
+                 "volume=1.5",
+             ),
+             (
+                 "aecho=0.8:0.88:30:0.3",
+                 "volume=1.5",
+             ),
+         )
+
+         # one effector per (tempo, effect chain) combination; None means no augmentation
+         self.effectors = [None]
+         for atempo in atempos:
+             for audio_effect in audio_effects:
+                 effect = f"atempo={atempo}," + ",".join(audio_effect)
+                 self.effectors.append(torchaudio.io.AudioEffector(effect=effect))
+
+     def __call__(self, data):
+         waveforms, lm_labels, accent_labels, gender_labels = zip(*data)
+         accent_labels = torch.tensor(accent_labels, device=self.device)
+         gender_labels = torch.tensor(gender_labels, device=self.device)
+
+         input_features = [
+             {"input_values": self.random_augment(waveform).squeeze()}
+             for waveform in waveforms
+         ]
+         label_features = [{"input_ids": lm_label} for lm_label in lm_labels]
+
+         padded_waveforms = self.processor.pad(
+             input_features,
+             padding=True,
+             return_tensors="pt",
+         )["input_values"]
+         padded_waveforms = padded_waveforms.to(self.device)
+
+         with self.processor.as_target_processor():
+             padded_lm_labels = self.processor.pad(
+                 label_features,
+                 padding=True,
+                 return_tensors="pt",
+             )
+
+         # replace padding with -100 so padded positions are ignored by the CTC loss
+         padded_lm_labels = padded_lm_labels["input_ids"].masked_fill(
+             padded_lm_labels.attention_mask.ne(1), -100
+         )
+         padded_lm_labels = padded_lm_labels.to(self.device)
+
+         return padded_waveforms, padded_lm_labels, accent_labels, gender_labels
+
+     def random_augment(self, waveform):
+         if not self.augment:
+             return waveform
+
+         # AudioEffector expects a (frames, channels) tensor
+         waveform = torch.tensor(waveform)
+         waveform = torch.transpose(waveform, 0, 1)
+         effector = random.choice(self.effectors)
+         if effector is None:
+             return waveform
+
+         augmented_waveform = effector.apply(waveform, self.sampling_rate)
+         if augmented_waveform.isnan().any() | augmented_waveform.isinf().any():
+             return waveform
+
+         return augmented_waveform
+
+
+ class L2ArcticDataset(Dataset):
+     def __init__(self, processor, audio_paths, lm_labels, accent_labels, gender_labels):
+         orig_sampling_rate = 44100
+         new_sampling_rate = 16000
+         resample_transform = torchaudio.transforms.Resample(
+             orig_sampling_rate, new_sampling_rate
+         )
+
+         self.waveforms = []
+         self.lm_labels = []
+         self.accent_labels = accent_labels
+         self.gender_labels = gender_labels
+
+         # resample each clip to 16 kHz and pre-compute processor features
+         for audio_path in audio_paths:
+             waveform, _ = torchaudio.load(audio_path)
+             waveform = resample_transform(waveform)
+             self.waveforms.append(
+                 processor(waveform, sampling_rate=new_sampling_rate).input_values[0]
+             )
+
+         # tokenize the transcripts with the target (tokenizer) side of the processor
+         with processor.as_target_processor():
+             for lm_label in lm_labels:
+                 self.lm_labels.append(processor(lm_label).input_ids)
+
+     def __getitem__(self, index):
+         return (
+             self.waveforms[index],
+             self.lm_labels[index],
+             self.accent_labels[index],
+             self.gender_labels[index],
+         )
+
+     def __len__(self):
+         return len(self.waveforms)
+
+
+ class MultiTaskWav2Vec2(nn.Module):
+     def __init__(
+         self,
+         wav2vec2_backbone,
+         backbone_hidden_size,
+         projection_hidden_size,
+         num_accent_class,
+     ):
+         super().__init__()
+         self.wav2vec2 = wav2vec2_backbone
+         self.accent_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
+         self.accent_classifier = nn.Linear(projection_hidden_size, num_accent_class)
+         self.gender_projector = nn.Linear(backbone_hidden_size, projection_hidden_size)
+         self.gender_classifier = nn.Linear(projection_hidden_size, 2)
+
+     def forward(self, waveform, lm_labels=None):
+         if lm_labels is not None:
+             # use the Hugging Face wav2vec2 backbone
+             wav2vec2_output = self.wav2vec2(input_values=waveform, labels=lm_labels)
+
+             # partial loss from the lm_head (the CTC loss)
+             ctc_loss = wav2vec2_output.loss
+
+         else:
+             # use the Hugging Face wav2vec2 backbone
+             wav2vec2_output = self.wav2vec2(input_values=waveform)
+             ctc_loss = None
+
+         # get features from wav2vec2 (last hidden state of the backbone)
+         features = wav2vec2_output.hidden_states[-1]
+
+         # get output lm logits
+         lm_logits = wav2vec2_output.logits
+
+         # get output accent logits (mean-pool over time, then project and classify)
+         accent_projected = self.accent_projector(features)
+         accent_projected = accent_projected.mean(dim=1)
+         accent_logits = self.accent_classifier(accent_projected)
+
+         # get output gender logits
+         gender_projected = self.gender_projector(features)
+         gender_projected = gender_projected.mean(dim=1)
+         gender_logits = self.gender_classifier(gender_projected)
+
+         return ctc_loss, lm_logits, accent_logits, gender_logits
+
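
For reference, a minimal usage sketch (not part of this commit) of how the three classes above are intended to fit together. It assumes the backbone is a transformers Wav2Vec2ForCTC loaded with output_hidden_states=True, since MultiTaskWav2Vec2.forward reads hidden_states[-1] as well as .loss and .logits; the checkpoint name, audio path, transcript, label values, projection size, and accent-class count are all illustrative placeholders.

import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from wav2vecasr.models import DataCollator, L2ArcticDataset, MultiTaskWav2Vec2

device = "cuda" if torch.cuda.is_available() else "cpu"

# assumed checkpoint; any Wav2Vec2 CTC checkpoint with a matching processor should work
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
backbone = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    output_hidden_states=True,  # forward() uses hidden_states[-1]
)

# hypothetical data: one 44.1 kHz clip, its transcript, and integer task labels
audio_paths = ["data/sample_0001.wav"]
transcripts = ["HELLO WORLD"]
accent_labels = [0]  # accent class id
gender_labels = [1]  # e.g. 0 = female, 1 = male

dataset = L2ArcticDataset(processor, audio_paths, transcripts, accent_labels, gender_labels)
collator = DataCollator(processor, padding=True, device=device, augment=True)
loader = DataLoader(dataset, batch_size=1, collate_fn=collator)

model = MultiTaskWav2Vec2(
    wav2vec2_backbone=backbone,
    backbone_hidden_size=backbone.config.hidden_size,
    projection_hidden_size=256,  # placeholder
    num_accent_class=6,          # placeholder
).to(device)

waveforms, lm_labels, accent_batch, gender_batch = next(iter(loader))
ctc_loss, lm_logits, accent_logits, gender_logits = model(waveforms, lm_labels)

In this sketch the collator pads waveforms and transcripts per batch, and the model returns the CTC loss alongside logits for all three tasks, so a training loop can combine it with cross-entropy losses on accent_logits and gender_logits.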