# File-viewer page residue, kept as comments so the module parses as Python:
# dawn17 · "Upload 35 files" · commit bcc0f94 · 848 Bytes
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
class Encoder(nn.Module):
    """Frozen ResNet-50 feature extractor exposing five intermediate stages.

    The backbone is loaded with ``ResNet50_Weights.IMAGENET1K_V1`` and split
    into named stages that are applied in order:

    * ``block1``: ``conv1`` -> ``bn1`` -> ``relu``
    * ``block2``: ``maxpool`` -> ``layer1``
    * ``block3``: ``layer2``
    * ``block4``: ``layer3``
    * ``block5``: ``layer4``

    The backbone's ``avgpool`` and ``fc`` layers are not used.

    Every parameter has ``requires_grad`` disabled, and the module is pinned
    to eval mode (see :meth:`train`) so that the batch-norm layers keep using
    their stored running statistics instead of updating them.
    """

    def __init__(self):
        super().__init__()
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # Freeze all weights; the encoder is used as a fixed feature extractor.
        for param in resnet.parameters():
            param.requires_grad_(False)

        self.stages = nn.ModuleDict(
            {
                "block1": nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),
                "block2": nn.Sequential(resnet.maxpool, resnet.layer1),
                "block3": resnet.layer2,
                "block4": resnet.layer3,
                "block5": resnet.layer4,
            }
        )

        # nn.Module starts in training mode. Disabling gradients does not stop
        # batch norm from normalising with batch statistics and overwriting
        # its running mean/var, so switch to eval mode right away.
        self.eval()

    def train(self, mode=True):
        """Keep the encoder in eval mode regardless of ``mode``.

        A parent module's ``.train()`` call recurses into its children; since
        no parameter here is trainable, entering training mode would only
        change the batch-norm behaviour and corrupt the stored running
        statistics. The argument is accepted for API compatibility and
        ignored.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        return super().train(False)

    def forward(self, x):
        """Run ``x`` through every stage in order and collect each output.

        Args:
            x: Input tensor fed to the first stage.
                (presumably an image batch shaped ``(N, 3, H, W)`` -- confirm
                against the caller).

        Returns:
            A tuple ``(out, features)`` where ``out`` is the output of the
            last stage and ``features`` maps each stage name to that stage's
            output, in the order the stages were applied.
        """
        features = {}
        for name, stage in self.stages.items():
            x = stage(x)
            features[name] = x
        return x, features