Arnab Das
bug fix
948bfd2
raw
history blame contribute delete
822 Bytes
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from manipulate_model.encoder.encoder import Encoder
from manipulate_model.decoder.decoder import Decoder
class Model(nn.Module, PyTorchModelHubMixin):
def __init__(self, config):
super(Model, self).__init__()
self.config = config
self.encoder = Encoder(self.config)
self.config.model.decoder.temporal_dim = self.encoder.get_temporal_dim()
self.config.model.decoder.encoding_dim = self.encoder.get_encoding_dim()
self.decoder = Decoder(self.config)
def forward(self, x):
if self.config.model.encoder_freeze:
with torch.no_grad():
x = self.encoder(x)
else:
x = self.encoder(x)
x = self.decoder(x)
return x