File size: 490 Bytes
bd282c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Config-driven wrapper that instantiates and delegates to a concrete decoder.

    The concrete backend is chosen by ``config.model.decoder.name``, compared
    case-insensitively. Currently only ``"aasist"`` is supported.

    Args:
        config: Configuration object exposing ``config.model.decoder.name``
            via attribute access. The whole object is stored on
            ``self.config`` and passed to the selected backend's constructor.

    Raises:
        ValueError: If the configured decoder name is not supported.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        name = config.model.decoder.name
        if name.lower() == "aasist":
            # Imported lazily so the backend module is only loaded when this
            # backend is actually selected.
            from manipulate_model.decoder.aasist.aasist import AASIST

            self.decoder = AASIST(config)
        else:
            # Report the offending value and the valid choices so config
            # typos are easy to spot.
            raise ValueError(
                f"Invalid decoder name: {name!r} (supported: 'aasist')"
            )

    def forward(self, x):
        """Run the selected backend decoder on ``x`` and return its output unchanged."""
        return self.decoder(x)