| from model import common | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
class VGG(nn.Module):
    """Perceptual loss: MSE between truncated VGG-19 feature maps.

    Both images are normalised by ``common.MeanShift`` and passed through
    the leading modules of torchvision's pretrained VGG-19 ``features``
    stack. All parameters of this module are frozen.

    Args:
        conv_index: String selecting the truncation point. If it contains
            ``'22'`` the first 8 feature modules are kept; otherwise, if it
            contains ``'54'``, the first 35 are kept. (Presumably these end
            at conv2_2 and conv5_4 respectively -- confirm against the
            torchvision layer listing.)
        rgb_range: Forwarded to ``common.MeanShift`` and used to scale the
            per-channel standard deviation.

    Raises:
        ValueError: If ``conv_index`` contains neither ``'22'`` nor ``'54'``.
    """

    # (tag, number of leading VGG modules kept). Checked in order, so '22'
    # takes precedence when both tags are present.
    _LAYER_DEPTHS = (('22', 8), ('54', 35))

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()

        # Validate before loading any weights so that a bad configuration
        # fails immediately instead of leaving ``self.vgg`` undefined and
        # surfacing later as an AttributeError inside ``forward``.
        depth = next(
            (n for tag, n in self._LAYER_DEPTHS if tag in conv_index),
            None,
        )
        if depth is None:
            raise ValueError(
                "conv_index must contain '22' or '54', got {!r}".format(
                    conv_index
                )
            )

        # NOTE(review): newer torchvision releases prefer a ``weights=``
        # argument over ``pretrained=True`` -- confirm against the pinned
        # torchvision version before changing.
        feature_layers = list(models.vgg19(pretrained=True).features)
        self.vgg = nn.Sequential(*feature_layers[:depth])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = tuple(s * rgb_range for s in (0.229, 0.224, 0.225))
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)

        # The network is a fixed feature extractor; never train it.
        for param in self.parameters():
            param.requires_grad = False

    def _features(self, x):
        """Normalise ``x`` and return its truncated-VGG feature map."""
        return self.vgg(self.sub_mean(x))

    def forward(self, sr, hr):
        """Return the MSE between the feature maps of ``sr`` and ``hr``.

        Gradients flow back through ``sr`` only; the ``hr`` branch is
        detached and evaluated under ``torch.no_grad()``.
        """
        sr_features = self._features(sr)
        with torch.no_grad():
            hr_features = self._features(hr.detach())

        return F.mse_loss(sr_features, hr_features)