Spaces: Time-TravelRephotography (status: Runtime error)
File: Time_TravelRephotography/models/encoder4editing/criteria/lpips/networks.py
from typing import Sequence | |
from itertools import chain | |
import torch | |
import torch.nn as nn | |
from torchvision import models | |
from criteria.lpips.utils import normalize_activation | |
def get_network(net_type: str):
    """Instantiate the LPIPS backbone selected by *net_type*.

    Accepted names are ``'alex'`` (AlexNet), ``'squeeze'`` (SqueezeNet) and
    ``'vgg'`` (VGG16); they are checked in that order with ``==``.

    Raises:
        NotImplementedError: if *net_type* matches none of the names above.
    """
    # Factories are wrapped in lambdas so a backbone class is only looked up
    # and constructed when it is the one requested.
    candidates = (
        ('alex', lambda: AlexNet()),
        ('squeeze', lambda: SqueezeNet()),
        ('vgg', lambda: VGG16()),
    )
    for name, factory in candidates:
        if net_type == name:
            return factory()
    raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
class LinLayers(nn.ModuleList):
    """Frozen list of per-feature-map linear heads.

    For every channel count ``nc`` in *n_channels_list* the list holds
    ``nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))``,
    i.e. a bias-free 1x1 convolution reducing ``nc`` channels to one.
    All parameters have ``requires_grad`` disabled after construction.
    """

    def __init__(self, n_channels_list: Sequence[int]):
        heads = []
        for channels in n_channels_list:
            # The leading Identity keeps the conv at index 1 of each
            # Sequential, which fixes its state_dict key as "<i>.1.weight"
            # (presumably to match pretrained LPIPS weight files -- confirm).
            projection = nn.Conv2d(channels, 1, 1, 1, 0, bias=False)
            heads.append(nn.Sequential(nn.Identity(), projection))
        super().__init__(heads)

        for weight in self.parameters():
            weight.requires_grad = False
class BaseNet(nn.Module):
    """Common base for the frozen LPIPS backbone feature extractors.

    Registers fixed per-channel shift (``mean``) and scale (``std``) buffers
    that ``z_score`` applies to the input. ``forward`` relies on two
    attributes that every subclass must set; BaseNet itself does not set them:

    * ``layers`` -- an iterable module container (the subclasses in this file
      assign a torchvision ``.features`` trunk), applied in order;
    * ``target_layers`` -- 1-based positions in ``layers`` whose outputs are
      collected.
    """

    def __init__(self):
        super(BaseNet, self).__init__()
        # Stored with shape (1, 3, 1, 1) so they broadcast across the channel
        # axis; assumes inputs are (N, 3, H, W) -- TODO confirm with callers.
        # NOTE(review): these look like the LPIPS "scaling layer" constants,
        # which expect images in [-1, 1]; confirm the input range.
        self.register_buffer(
            'mean', torch.tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad`` to *state* on every parameter and buffer."""
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        """Return ``(x - mean) / std`` using the registered buffers."""
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        """Run *x* through ``self.layers`` and collect the tapped activations.

        After ``z_score``, each layer is applied in order. Whenever the layer's
        1-based position is in ``self.target_layers``, its output is passed
        through ``normalize_activation`` and appended to the result. The loop
        stops as soon as the result holds ``len(self.target_layers)`` tensors,
        so layers after the last target are never evaluated.

        Returns:
            list of tensors, in layer order.
        """
        x = self.z_score(x)
        output = []
        # Iterate the container through its public iterator instead of the
        # private ``_modules`` dict. The modules and their order are the same.
        # Numbering starts at 1 to match ``target_layers``.
        for i, layer in enumerate(self.layers, 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output
class SqueezeNet(BaseNet):
    """SqueezeNet 1.1 backbone for LPIPS.

    Uses the ``features`` trunk of torchvision's ImageNet-pretrained
    ``squeezenet1_1`` as ``self.layers``. All parameters and buffers are
    frozen at the end of construction.
    """

    def __init__(self):
        super(SqueezeNet, self).__init__()
        weights_enum = getattr(models, 'SqueezeNet1_1_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: passing a bool positionally (the old
            # ``squeezenet1_1(True)``) is deprecated. Per torchvision's docs,
            # IMAGENET1K_V1 is what the legacy ``pretrained=True`` selected.
            backbone = models.squeezenet1_1(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision has no *_Weights enums; use the legacy flag.
            backbone = models.squeezenet1_1(pretrained=True)
        self.layers = backbone.features
        # 1-based positions in self.layers whose outputs forward() collects.
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        # Presumably the channel count of each tapped activation, consumed by
        # LinLayers -- confirm against callers.
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
        self.set_requires_grad(False)
class AlexNet(BaseNet):
    """AlexNet backbone for LPIPS.

    Uses the ``features`` trunk of torchvision's ImageNet-pretrained
    ``alexnet`` as ``self.layers``. All parameters and buffers are frozen at
    the end of construction.
    """

    def __init__(self):
        super(AlexNet, self).__init__()
        weights_enum = getattr(models, 'AlexNet_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: passing a bool positionally (the old
            # ``alexnet(True)``) is deprecated. Per torchvision's docs,
            # IMAGENET1K_V1 is what the legacy ``pretrained=True`` selected.
            backbone = models.alexnet(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision has no *_Weights enums; use the legacy flag.
            backbone = models.alexnet(pretrained=True)
        self.layers = backbone.features
        # 1-based positions in self.layers whose outputs forward() collects.
        self.target_layers = [2, 5, 8, 10, 12]
        # Presumably the channel count of each tapped activation, consumed by
        # LinLayers -- confirm against callers.
        self.n_channels_list = [64, 192, 384, 256, 256]
        self.set_requires_grad(False)
class VGG16(BaseNet):
    """VGG16 backbone for LPIPS.

    Uses the ``features`` trunk of torchvision's ImageNet-pretrained
    ``vgg16`` as ``self.layers``. All parameters and buffers are frozen at
    the end of construction.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        weights_enum = getattr(models, 'VGG16_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: passing a bool positionally (the old
            # ``vgg16(True)``) is deprecated. Per torchvision's docs,
            # IMAGENET1K_V1 is what the legacy ``pretrained=True`` selected.
            backbone = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision has no *_Weights enums; use the legacy flag.
            backbone = models.vgg16(pretrained=True)
        self.layers = backbone.features
        # 1-based positions in self.layers whose outputs forward() collects.
        self.target_layers = [4, 9, 16, 23, 30]
        # Presumably the channel count of each tapped activation, consumed by
        # LinLayers -- confirm against callers.
        self.n_channels_list = [64, 128, 256, 512, 512]
        self.set_requires_grad(False)