from typing import Dict, Optional, Tuple

import torch.nn as nn
from torch import Tensor
from torchvision.models import ResNet50_Weights, resnet50


class Encoder(nn.Module):
    """ResNet-50 backbone split into five named stages.

    The torchvision ResNet-50 is cut into consecutive stages that are
    stored, in execution order, in ``self.stages``:

    * ``"block1"``: ``conv1`` -> ``bn1`` -> ``relu``
    * ``"block2"``: ``maxpool`` -> ``layer1``
    * ``"block3"``: ``layer2``
    * ``"block4"``: ``layer3``
    * ``"block5"``: ``layer4``

    The classification head of the ResNet (``avgpool`` and ``fc``) is not
    part of the encoder.

    Args:
        weights: Weights passed to ``torchvision.models.resnet50``.
            Defaults to the ImageNet-1k V1 checkpoint; pass ``None`` to
            build the backbone without loading pretrained weights.
        freeze: When true (the default), ``requires_grad`` is switched off
            on every backbone parameter so an optimizer will not update
            them. Keyword-only.

    NOTE(review): switching off ``requires_grad`` does not put the BatchNorm
    layers into evaluation mode. While this module is in training mode their
    running statistics still update on every forward pass; call ``.eval()``
    on the encoder if the backbone must stay completely fixed.
    """

    def __init__(
        self,
        weights: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
        *,
        freeze: bool = True,
    ) -> None:
        super().__init__()
        resnet = resnet50(weights=weights)

        if freeze:
            # Freeze before the sub-modules are re-registered below; the
            # stages share these very parameter objects with ``resnet``.
            for param in resnet.parameters():
                param.requires_grad_(False)

        # Insertion order is the execution order used by ``forward``; the
        # keys also become part of the ``state_dict`` parameter names.
        self.stages = nn.ModuleDict(
            {
                "block1": nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),
                "block2": nn.Sequential(resnet.maxpool, resnet.layer1),
                "block3": resnet.layer2,
                "block4": resnet.layer3,
                "block5": resnet.layer4,
            }
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Run ``x`` through every stage and collect the intermediate outputs.

        Args:
            x: Input tensor handed to the first ResNet convolution
                (presumably an image batch of shape ``(N, 3, H, W)`` --
                confirm against the caller).

        Returns:
            A tuple ``(out, features)`` where ``out`` is the output of the
            last stage and ``features`` maps each stage name to that stage's
            output. ``features["block5"]`` is the same tensor as ``out``.
        """
        features: Dict[str, Tensor] = {}
        for name, stage in self.stages.items():
            x = stage(x)
            features[name] = x
        return x, features