File size: 1,244 Bytes
bd282c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667dd63
bd282c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn

from torch.nn.functional import pad
from collections import OrderedDict


class Encoder(nn.Module):
    """Build a pretrained encoder from ``config`` and delegate calls to it.

    Only the ``config.model.task == "audio"`` with
    ``config.model.encoder.name == "wavlm"`` (case-insensitive) combination is
    constructed here. For any other configuration ``self.encoder`` stays
    ``None`` and every method that needs the encoder raises ``RuntimeError``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Assigned by the branches below; stays None for unsupported configs.
        self.encoder = None
        # Not used inside this class; presumably read/assigned elsewhere —
        # NOTE(review): verify before removing.
        self.succeeding_layers = None

        # AUDIO
        if self.config.model.task == "audio":
            if self.config.model.encoder.name.lower() == "wavlm":
                # Local import: the WavLM module is only imported when this
                # branch is actually taken.
                from manipulate_model.encoder.wavlm.WavLM import WavLM, WavLMConfig

                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints from trusted sources.
                ckpt = torch.load(
                    config.model.encoder.pretrained_path, map_location="cpu"
                )
                # NOTE(review): the whole checkpoint object is handed to
                # WavLMConfig and no state dict is loaded here. Upstream WavLM
                # usage is WavLMConfig(ckpt["cfg"]) + load_state_dict(ckpt["model"]);
                # confirm the local WavLM handles weight loading itself.
                cfg = WavLMConfig(ckpt)
                self.encoder = WavLM(cfg)

    def _require_encoder(self):
        """Return the wrapped encoder, or raise if the config did not build one.

        Raises:
            RuntimeError: if ``self.encoder`` is ``None`` (unsupported config).
        """
        if self.encoder is None:
            raise RuntimeError(
                f"{type(self).__name__}: no encoder was built for this "
                "configuration (supported: model.task='audio' with "
                "model.encoder.name='wavlm')."
            )
        return self.encoder

    def forward(self, x):
        """Run the wrapped encoder on ``x``.

        For WavLM, ``config.model.encoder.output_layer`` is forwarded as the
        ``output_layer`` keyword; any other encoder is called with ``x`` alone.
        The expected shape/type of ``x`` is defined by the wrapped encoder —
        NOTE(review): not established in this file.

        Raises:
            RuntimeError: if no encoder was built for this configuration.
        """
        encoder = self._require_encoder()
        if self.config.model.encoder.name.lower() == "wavlm":
            return encoder(x, output_layer=self.config.model.encoder.output_layer)
        return encoder(x)

    def get_encoding_dim(self):
        """Return the wrapped encoder's ``get_encoding_dim()`` value.

        Raises:
            RuntimeError: if no encoder was built for this configuration.
        """
        return self._require_encoder().get_encoding_dim()

    def get_temporal_dim(self):
        """Return the wrapped encoder's temporal dim for ``config.data.window_size``.

        Raises:
            RuntimeError: if no encoder was built for this configuration.
        """
        return self._require_encoder().get_temporal_dim(
            window_size=self.config.data.window_size
        )