osanseviero commited on
Commit
fc89401
·
1 Parent(s): c5f0da9

Add base code

Browse files
Files changed (1) hide show
  1. expert.py +56 -0
expert.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from packaging import version
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.rnn import pad_sequence
7
+
8
+ import fairseq
9
+ from s3prl.interfaces import UpstreamBase
10
+
11
+
12
+ SAMPLE_RATE = 16000
13
+ EXAMPLE_SEC = 5
14
+
15
+ class UpstreamExpert(UpstreamBase):
16
+ def __init__(self, ckpt, **kwargs):
17
+ super().__init__(**kwargs)
18
+ assert version.parse(fairseq.__version__) > version.parse(
19
+ "0.10.2"
20
+ ), "Please install the fairseq master branch."
21
+
22
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
23
+ [ckpt]
24
+ )
25
+ self.model = model[0]
26
+ self.task = task
27
+
28
+ if len(self.hooks) == 0:
29
+ module_name = "self.model.encoder.layers"
30
+ for module_id in range(len(eval(module_name))):
31
+ self.add_hook(
32
+ f"{module_name}[{module_id}]",
33
+ lambda input, output: input[0].transpose(0, 1),
34
+ )
35
+ self.add_hook("self.model.encoder", lambda input, output: output[0])
36
+
37
+ def forward(self, wavs):
38
+ if self.task.cfg.normalize:
39
+ wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
40
+
41
+ device = wavs[0].device
42
+ wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
43
+ wav_padding_mask = ~torch.lt(
44
+ torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
45
+ wav_lengths.unsqueeze(1),
46
+ )
47
+ padded_wav = pad_sequence(wavs, batch_first=True)
48
+
49
+ features, feat_padding_mask = self.model.extract_features(
50
+ padded_wav,
51
+ padding_mask=wav_padding_mask,
52
+ mask=None,
53
+ )
54
+ return {
55
+ "default": features,
56
+ }