Spaces:
Runtime error
Runtime error
from argparse import Namespace, ArgumentParser | |
from functools import partial | |
from torch import nn | |
from .resnet import ResNetBasicBlock, activation_func, norm_module, Conv2dAuto | |
def add_arguments(parser: ArgumentParser) -> ArgumentParser:
    """Register the encoder's command-line options on ``parser``.

    Args:
        parser: Argument parser that is extended in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    parser.add_argument(
        "--latent_size",
        default=512,
        type=int,
        help="latent size",
    )
    return parser
def create_model(args) -> nn.Module:
    """Build an ``Encoder`` from parsed options.

    Args:
        args: Parsed options. ``encoder_size`` and ``latent_size`` are read
            as attributes. ``rgb`` is optional: it is first looked up with
            ``in`` and only then read as an attribute, so ``args`` must
            support membership tests.

    Returns:
        An ``Encoder`` with 3 input channels when ``rgb`` is present and
        truthy, and 1 input channel otherwise.
    """
    use_rgb = "rgb" in args and args.rgb
    if use_rgb:
        in_channels = 3
    else:
        in_channels = 1
    return Encoder(in_channels, args.encoder_size, latent_size=args.latent_size)
class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def forward(self, input_):
        """Return ``input_`` viewed as ``(input_.size(0), -1)``.

        ``view`` is used, so no data is copied; the tensor must therefore be
        viewable with that shape.
        """
        leading = input_.size(0)
        return input_.view(leading, -1)
class Encoder(nn.Module):
    """Convolutional encoder that maps an input batch to a latent vector.

    The network is a 5x5 convolution stem (conv + norm + activation), an
    average pool, a stack of ``ResNetBasicBlock`` stages each followed by an
    average pool, and a head that applies ReLU, a 4x4 average pool, a
    flatten and a linear projection to ``latent_size``.

    Args:
        in_channels: Number of channels given to the first convolution.
        size: Only used to pick the depth: when ``size >= 256`` a fifth
            512-channel residual stage is appended.
        latent_size: Output features of the final linear layer.
        activation: Name forwarded to ``activation_func`` and to every
            ``ResNetBasicBlock``.
        norm: Name forwarded to ``norm_module`` and to every
            ``ResNetBasicBlock``.
    """

    def __init__(
        self, in_channels: int, size: int, latent_size: int = 512,
        activation: str = 'leaky_relu', norm: str = "instance"
    ):
        super().__init__()

        # Stem: convolution -> normalisation -> activation.
        stem_channels = 64
        make_norm = norm_module(norm)
        self.conv0 = nn.Sequential(
            Conv2dAuto(in_channels, stem_channels, kernel_size=5),
            make_norm(stem_channels),
            activation_func(activation),
        )

        pool_kernel = 2
        self.pool = nn.AvgPool2d(pool_kernel)

        stage_channels = [128, 256, 512, 512]
        # FIXME: this is a hack
        if size >= 256:
            stage_channels.append(512)

        make_block = partial(
            ResNetBasicBlock, activation=activation, norm=norm, bias=True
        )

        # Each stage consumes the previous stage's width, starting at the stem.
        widths = [stem_channels, *stage_channels]
        stages = []
        for width_in, width_out in zip(widths, widths[1:]):
            stages += [make_block(width_in, width_out), nn.AvgPool2d(pool_kernel)]
        self.residual_blocks = nn.Sequential(*stages)

        # The linear layer takes ``stage_channels[-1]`` features, so the
        # flattened map must hold exactly one value per channel after the
        # 4x4 pool.
        # NOTE(review): presumably the map is 4x4 at this point for the
        # supported input sizes -- confirm against the striding of
        # Conv2dAuto / ResNetBasicBlock.
        head_features = stage_channels[-1]
        self.last = nn.Sequential(
            nn.ReLU(),
            nn.AvgPool2d(4),  # TODO: not sure whether this would cause problem
            Flatten(),
            nn.Linear(head_features, latent_size, bias=True)
        )

    def forward(self, input_):
        """Run ``input_`` through stem, pool, residual stages and head."""
        out = input_
        for stage in (self.conv0, self.pool, self.residual_blocks, self.last):
            out = stage(out)
        return out