import torch


class Decoder(torch.nn.Module):
    def __init__(self, cfg):
        super(Decoder, self).__init__()
        self.cfg = cfg

        # Layer Definition
        self.layer1 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(2048, 512, kernel_size=4, stride=2,
                                     bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(512),
            torch.nn.ReLU()
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(512, 128, kernel_size=4, stride=2,
                                     bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(128),
            torch.nn.ReLU()
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(128, 32, kernel_size=4, stride=2,
                                     bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(32),
            torch.nn.ReLU()
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(32, 8, kernel_size=4, stride=2,
                                     bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(8),
            torch.nn.ReLU()
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(8, 1, kernel_size=1, bias=cfg.NETWORK.TCONV_USE_BIAS),
            torch.nn.Sigmoid()
        )

    def forward(self, image_features):
        # [batch_size, n_views, C, H, W] -> n_views tensors of [1, batch_size, C, H, W]
        image_features = image_features.permute(1, 0, 2, 3, 4).contiguous()
        image_features = torch.split(image_features, 1, dim=0)
        gen_volumes = []
        raw_features = []

        for features in image_features:
            # Reinterpret each view's feature map as a coarse 2x2x2 volume with 2048 channels
            gen_volume = features.view(-1, 2048, 2, 2, 2)
            # print(gen_volume.size())   # torch.Size([batch_size, 2048, 2, 2, 2])
            gen_volume = self.layer1(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 512, 4, 4, 4])
            gen_volume = self.layer2(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 128, 8, 8, 8])
            gen_volume = self.layer3(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 32, 16, 16, 16])
            gen_volume = self.layer4(gen_volume)
            raw_feature = gen_volume
            # print(gen_volume.size())   # torch.Size([batch_size, 8, 32, 32, 32])
            gen_volume = self.layer5(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 1, 32, 32, 32])
            raw_feature = torch.cat((raw_feature, gen_volume), dim=1)
            # print(raw_feature.size())  # torch.Size([batch_size, 9, 32, 32, 32])

            gen_volumes.append(torch.squeeze(gen_volume, dim=1))
            raw_features.append(raw_feature)

        gen_volumes = torch.stack(gen_volumes).permute(1, 0, 2, 3, 4).contiguous()
        raw_features = torch.stack(raw_features).permute(1, 0, 2, 3, 4, 5).contiguous()
        # print(gen_volumes.size())      # torch.Size([batch_size, n_views, 32, 32, 32])
        # print(raw_features.size())     # torch.Size([batch_size, n_views, 9, 32, 32, 32])
        return raw_features, gen_volumes


class DummyCfg:
    class NETWORK:
        TCONV_USE_BIAS = False


cfg = DummyCfg()

# Instantiate the decoder
decoder = Decoder(cfg)

# Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
n_views = 1
batch_size = 64
img_c, img_h, img_w = 256, 8, 8
dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)

# Run the decoder
print(dummy_input.shape)
raw_features, gen_volumes = decoder(dummy_input)

# Output shapes
print("raw_features shape:", raw_features.shape)  # Expected: [64, 1, 9, 32, 32, 32]
print("gen_volumes shape:", gen_volumes.shape)    # Expected: [64, 1, 32, 32, 32]
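
# --- Optional multi-view sanity check (a minimal sketch; n_views_multi = 5 is an
# arbitrary illustrative value, not part of the original script). The decoder splits
# the input along the view dimension, decodes each view independently, and stacks the
# results back, so the second dimension of both outputs tracks n_views.
n_views_multi = 5
multi_view_input = torch.randn(batch_size, n_views_multi, img_c, img_h, img_w)
raw_features_mv, gen_volumes_mv = decoder(multi_view_input)
print("multi-view raw_features shape:", raw_features_mv.shape)  # Expected: [64, 5, 9, 32, 32, 32]
print("multi-view gen_volumes shape:", gen_volumes_mv.shape)    # Expected: [64, 5, 32, 32, 32]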