Spaces:
Running
Running
File size: 822 Bytes
bd282c4 948bfd2 bd282c4 948bfd2 bd282c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from manipulate_model.encoder.encoder import Encoder
from manipulate_model.decoder.decoder import Decoder
class Model(nn.Module, PyTorchModelHubMixin):
    """Encoder-decoder network built from a single nested ``config`` object.

    ``PyTorchModelHubMixin`` adds the Hugging Face Hub helpers
    (``save_pretrained`` / ``from_pretrained`` / ``push_to_hub``).

    NOTE(review): whether ``config`` survives a Hub save/load round-trip
    depends on the huggingface_hub version and on the config's type (this
    class reads it by attribute, not as a plain dict) — verify before
    relying on ``from_pretrained``.
    """

    def __init__(self, config):
        """Build the encoder, then a decoder sized from the encoder.

        Args:
            config: Nested configuration object that must support attribute
                access and assignment (``config.model.decoder.<field>``,
                ``config.model.encoder_freeze``). The same object is stored
                on ``self.config`` and handed to both ``Encoder`` and
                ``Decoder``.

        Side effects:
            Mutates ``config`` in place. ``config.model.decoder.temporal_dim``
            and ``config.model.decoder.encoding_dim`` are overwritten with
            the values reported by the freshly built encoder.
        """
        super().__init__()
        self.config = config
        self.encoder = Encoder(self.config)

        # Publish the encoder's output dimensions into the decoder section of
        # the config before the decoder is built from that same config
        # (presumably so Decoder sizes its input layers from them).
        decoder_cfg = self.config.model.decoder
        decoder_cfg.temporal_dim = self.encoder.get_temporal_dim()
        decoder_cfg.encoding_dim = self.encoder.get_encoding_dim()

        self.decoder = Decoder(self.config)

    def forward(self, x):
        """Encode ``x`` and decode the result.

        When ``config.model.encoder_freeze`` is truthy, the encoder runs with
        autograd disabled. No graph is recorded through it, so its parameters
        receive no gradients. The decoder always runs under the caller's grad
        mode.

        NOTE(review): "freezing" here only disables gradient tracking. The
        encoder is not put into eval mode, so any dropout / batch-norm layers
        it may contain still follow the model's train/eval mode — confirm
        that is intended.
        """
        freeze_encoder = self.config.model.encoder_freeze
        # Freezing forces grad mode off (equivalent to torch.no_grad()).
        # Otherwise the caller's current mode is kept unchanged, so an outer
        # torch.no_grad()/inference_mode() context still applies.
        keep_grad = torch.is_grad_enabled() and not freeze_encoder
        with torch.set_grad_enabled(keep_grad):
            features = self.encoder(x)
        return self.decoder(features)
|