import torch
import torch.nn as nn
from torch.nn.functional import pad
from collections import OrderedDict


class Encoder(nn.Module):
    """Config-driven wrapper around a backbone encoder.

    The backbone is selected from ``config.model.task`` and
    ``config.model.encoder.name``. In this class only the combination
    task ``"audio"`` + encoder name ``"wavlm"`` (case-insensitive) builds a
    backbone; for any other combination ``self.encoder`` stays ``None`` and
    the methods that need it raise ``RuntimeError``.
    """

    def __init__(self, config):
        """Store ``config`` and build the backbone it selects.

        Args:
            config: Attribute-style config object. Reads ``model.task`` and
                ``model.encoder.name``; for WavLM also reads
                ``model.encoder.pretrained_path``.
        """
        super().__init__()
        self.config = config
        self.encoder = None
        self.succeeding_layers = None

        # AUDIO
        if self.config.model.task == "audio":
            if self.config.model.encoder.name.lower() == "wavlm":
                # Imported lazily so the WavLM package is only required when
                # this backbone is actually selected.
                from manipulate_model.encoder.wavlm.WavLM import WavLM, WavLMConfig

                # NOTE: torch.load unpickles the file; only point
                # pretrained_path at checkpoints from a trusted source.
                ckpt = torch.load(
                    config.model.encoder.pretrained_path, map_location="cpu"
                )
                # NOTE(review): the checkpoint is only used to build the
                # config here; no load_state_dict call is made in this block.
                # Confirm the weights are loaded elsewhere if that is intended.
                cfg = WavLMConfig(ckpt)
                self.encoder = WavLM(cfg)

    def _require_encoder(self):
        """Return the backbone, or raise a clear error if none was built."""
        if self.encoder is None:
            raise RuntimeError(
                "Encoder backbone was not constructed for "
                f"task={self.config.model.task!r}, "
                f"encoder name={self.config.model.encoder.name!r}"
            )
        return self.encoder

    def forward(self, x):
        """Run ``x`` through the backbone.

        When the configured encoder name is ``"wavlm"`` (case-insensitive),
        ``config.model.encoder.output_layer`` is passed as ``output_layer``;
        every other backbone is called with ``x`` alone.

        Raises:
            RuntimeError: If no backbone was constructed.
        """
        encoder = self._require_encoder()
        if self.config.model.encoder.name.lower() == "wavlm":
            return encoder(x, output_layer=self.config.model.encoder.output_layer)
        return encoder(x)

    def get_encoding_dim(self):
        """Return the backbone's ``get_encoding_dim()`` result.

        Raises:
            RuntimeError: If no backbone was constructed.
        """
        return self._require_encoder().get_encoding_dim()

    def get_temporal_dim(self):
        """Return the backbone's temporal dim for ``config.data.window_size``.

        Raises:
            RuntimeError: If no backbone was constructed.
        """
        return self._require_encoder().get_temporal_dim(
            window_size=self.config.data.window_size
        )