import torch
from torch import nn
from torchvision.models import resnet18, ResNet18_Weights

from config import InferenceConfig


class EyeResNet(nn.Module):
    """ResNet-18 truncated after ``layer3`` with a single-output linear head.

    The backbone keeps ResNet-18's modules up to and including ``layer3``
    (``layer4``, the original average pool and the original classifier are
    dropped), followed by global average pooling, flattening and a
    ``Linear(256, 1)`` head. Every kept backbone module except ``layer3`` is
    frozen (``requires_grad=False``), so only ``layer3`` and the head receive
    gradients.

    Args:
        pretrained: Initialise the backbone with torchvision's default
            ImageNet weights (``ResNet18_Weights.DEFAULT``). Pass ``False``
            when a full checkpoint will be loaded right afterwards, to avoid
            downloading weights that are immediately overwritten.
    """

    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        resnet = resnet18(weights=weights)
        # resnet.children() yields: conv1, bn1, relu, maxpool, layer1, layer2,
        # layer3, layer4, avgpool, fc. Dropping the last three keeps the
        # network up to and including layer3.
        backbone = nn.Sequential(*list(resnet.children())[:-3])
        # Freeze every module before the final kept stage (layer3). Slicing an
        # nn.Sequential shares the same submodule objects, so this freezes the
        # parameters inside `backbone` itself. (The previous `model[:-3][-1]`
        # selected only the parameter-free maxpool, so nothing was frozen.)
        # NOTE: frozen BatchNorm layers still update their running statistics
        # while the module is in train() mode.
        backbone[:-1].requires_grad_(False)
        self.resnet_cut = backbone
        self.avg_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # layer3 of ResNet-18 outputs 256 channels.
        self.fc = nn.Linear(in_features=256, out_features=1, bias=True)
        self.flatten = nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to one unactivated score per image.

        Args:
            x: Input batch; presumably ``(batch, 3, H, W)`` RGB images as
                expected by the ResNet stem — TODO confirm against callers.

        Returns:
            Tensor of shape ``(batch, 1)`` straight from the linear head
            (no sigmoid or other activation is applied here).
        """
        x = self.resnet_cut(x)
        x = self.avg_pooling(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


if __name__ == "__main__":
    cfg = InferenceConfig()
    # Every parameter is overwritten by the checkpoint below, so skip the
    # pretrained-weight download.
    model = EyeResNet(pretrained=False)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(cfg.eyes_model_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])