Spaces:
Runtime error
Runtime error
File size: 1,983 Bytes
7999e5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import torch.nn as nn
# Layer specs consumed by make_block(): the first entry is the number of input
# channels; every following entry is an int (3x3 conv + ReLU with that many
# output channels), "M" (2x2 max-pool) or "U" (2x nearest upsample).
vgg19_cfg = [
    3,  # input channels
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, 256, "M",
    512, 512, 512, 512, "M",
    512, 512, 512, 512, "M",
]
decoder_cfg = [
    512,  # input channels
    256, "U",
    256, 256, 256, 128, "U",
    128, 64, "U",
    64, 3,
]
def vgg19(weights=None):
    """Construct the VGG-19 feature network, optionally loading weights.

    The network is a 1x1 convolution over the 3 input channels, followed by
    the layers that ``make_block`` builds from ``vgg19_cfg``.

    Args:
        weights (dict): state dict handed to ``load_state_dict``; loading is
            skipped when it is falsy (``None`` or empty).

    Return:
        layers (nn.Sequential): the assembled VGG-19 layers.
    """
    # Build the config-driven body before the stem so module creation order
    # (and therefore random initialization) matches the layer order.
    body = make_block(vgg19_cfg)
    stem = nn.Conv2d(3, 3, kernel_size=1)
    layers = nn.Sequential(stem, *body.children())

    if weights:
        layers.load_state_dict(weights)
    return layers
def decoder(weights=None):
    """Construct the decoder network, optionally loading weights.

    The layers come from ``make_block(decoder_cfg)`` with the final ReLU
    removed.

    Args:
        weights (dict): state dict handed to ``load_state_dict``; loading is
            skipped when it is falsy (``None`` or empty).

    Return:
        layers (nn.Sequential): the assembled decoder layers.
    """
    block = make_block(decoder_cfg)
    children = list(block.children())
    # make_block appends a ReLU after every conv; drop the trailing one so the
    # last convolution's output is returned without rectification.
    kept = children[:-1]
    layers = nn.Sequential(*kept)

    if weights:
        layers.load_state_dict(weights)
    return layers
def make_block(config):
    """Build an ``nn.Sequential`` from a compact layer specification.

    Args:
        config (list): ``config[0]`` is the number of input channels; each
            following entry describes one layer:
            "M" - 2x2 max pooling layer (stride 2, no padding).
            "U" - 2x nearest-neighbour upsampling layer.
            i (int) - 3x3 convolution (padding 1) with ``i`` output filters,
            followed by an in-place ReLU activation.

    Return:
        layers (nn.Sequential): block layers.

    Raises:
        ValueError: if ``config`` is empty, or if an entry after the first is
            not "M", "U" or a positive int.
    """
    if not config:
        raise ValueError("config must start with the number of input channels")

    in_channels = config[0]
    layers = []
    for spec in config[1:]:
        if spec == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        elif spec == "U":
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
        # bool is a subclass of int, so exclude it explicitly; channel counts
        # must be positive to form a usable convolution.
        elif isinstance(spec, int) and not isinstance(spec, bool) and spec > 0:
            layers.append(nn.Conv2d(in_channels, spec, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            # The next conv consumes this conv's output channels.
            in_channels = spec
        else:
            # Explicit raise rather than ``assert`` so the check is not
            # stripped when Python runs with -O.
            raise ValueError(
                f"unsupported layer spec {spec!r}; expected 'M', 'U' or a positive int"
            )
    return nn.Sequential(*layers)
|