import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Wrapper module that builds a decoder chosen by config and delegates to it.

    The concrete decoder is selected by ``config.model.decoder.name``, compared
    case-insensitively. The only name recognised here is ``"aasist"``.
    """

    def __init__(self, config):
        """Build the decoder named in ``config``.

        Args:
            config: Configuration object. ``config.model.decoder.name`` selects
                the decoder, and the whole ``config`` is passed on to the
                selected decoder's constructor.

        Raises:
            ValueError: If ``config.model.decoder.name`` is not a recognised
                decoder name. The message includes the rejected name.
        """
        super().__init__()
        self.config = config

        name = config.model.decoder.name
        if name.lower() == "aasist":
            # Imported inside the branch so the module is only loaded when
            # this decoder is actually selected.
            from manipulate_model.decoder.aasist.aasist import AASIST

            self.decoder = AASIST(config)
        else:
            # Report the offending value so misconfigurations are easy to spot.
            raise ValueError(f"Invalid decoder name: {name!r}")

    def forward(self, x):
        """Run the wrapped decoder on ``x`` and return its output unchanged."""
        return self.decoder(x)