# Flux9665's picture
# initial commit
# 6faeba1
# raw
# history blame
# 5.53 kB
import numpy as np
import torch
import torch.utils.data
import torch.utils.data.distributed
from torch import nn
class ResNet_G(nn.Module):
    """ResNet-style generator mapping a latent vector to a ``data_dim`` vector.

    The latent vector is projected to an ``nf0 x 4 x 4`` feature map, passed
    through a stack of residual blocks interleaved with 2x upsampling, reduced
    to a 3-channel map by a 3x3 convolution, flattened, and finally projected
    to ``data_dim`` features by a linear layer.

    Args:
        data_dim: Number of features in each generated output vector.
        z_dim: Number of features in the latent input vector.
        size: Side length of the internal 3-channel map. The network upsamples
            ``int(log2(size / 4))`` times starting from 4x4, so ``size`` has to
            be 4 times a power of two; otherwise the flattened map does not
            match the input width of ``fc_out`` and ``forward`` fails.
        nfilter: Base number of convolution channels.
        nfilter_max: Upper bound on the number of channels of any block.
        bn: Whether batch normalisation is used (here and in the blocks).
        res_ratio: Scale applied to the residual branch of each block.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, data_dim, z_dim, size, nfilter=64, nfilter_max=512, bn=True, res_ratio=0.1, **kwargs):
        super().__init__()
        self.input_dim = z_dim
        # NOTE(review): forward() returns data_dim features, yet this attribute
        # records z_dim. The value is kept unchanged for backward
        # compatibility -- confirm whether any caller relies on it.
        self.output_dim = z_dim
        self.dropout_rate = 0
        s0 = self.s0 = 4
        nf = self.nf = nfilter
        nf_max = self.nf_max = nfilter_max
        self.bn = bn
        self.z_dim = z_dim

        # Submodules
        nlayers = int(np.log2(size / s0))
        # Channel count of the initial s0 x s0 feature map.
        self.nf0 = min(nf_max, nf * 2 ** (nlayers + 1))

        self.fc = nn.Linear(z_dim, self.nf0 * s0 * s0)
        if self.bn:
            self.bn1d = nn.BatchNorm1d(self.nf0 * s0 * s0)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        # Each stage halves the (capped) channel count and doubles the
        # spatial resolution.
        blocks = []
        for i in range(nlayers, 0, -1):
            nf0 = min(nf * 2 ** (i + 1), nf_max)
            nf1 = min(nf * 2 ** i, nf_max)
            blocks += [
                ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
                nn.Upsample(scale_factor=2)
            ]

        # Two final blocks at full resolution bring the channels down to nf.
        nf0 = min(nf * 2, nf_max)
        nf1 = min(nf, nf_max)
        blocks += [
            ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
            ResNetBlock(nf1, nf1, bn=self.bn, res_ratio=res_ratio)
        ]

        self.resnet = nn.Sequential(*blocks)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
        self.fc_out = nn.Linear(3 * size * size, data_dim)

    def forward(self, z, return_intermediate=False):
        """Generate output vectors from latent vectors.

        Args:
            z: Latent batch; its first dimension is the batch size and each
                sample must hold ``z_dim`` features.
            return_intermediate: If true, additionally return a detached copy
                of the activations after the first linear layer (including
                batch norm, if enabled, and the LeakyReLU).

        Returns:
            The generated tensor with ``data_dim`` features per sample, or a
            tuple ``(generated, first_layer_activations)`` when
            ``return_intermediate`` is true.
        """
        batch_size = z.size(0)

        out = self.fc(z)
        if self.bn:
            out = self.bn1d(out)
        out = self.relu(out)

        if return_intermediate:
            # Detached clone, so later in-place ops cannot alter it and no
            # gradient flows through the returned copy.
            l_1 = out.detach().clone()

        out = out.view(batch_size, self.nf0, self.s0, self.s0)
        out = self.resnet(out)
        out = self.relu(self.conv_img(out))
        out = self.fc_out(out.flatten(1))

        if return_intermediate:
            return out, l_1
        return out

    def sample_latent(self, n_samples, z_size, temperature=0.7):
        """Draw ``n_samples`` latent vectors of width ``z_size``.

        Samples come from a standard normal distribution scaled by
        ``temperature``. No device is passed to ``torch.randn``, so the result
        is not automatically placed on the device of this module.
        """
        return torch.randn((n_samples, z_size)) * temperature
class ResNet_D(nn.Module):
    """ResNet-style network mapping a ``data_dim`` vector to a single score.

    Presumably the discriminator/critic counterpart of ``ResNet_G`` (judging
    by the name). The input vector is projected to a ``3 x size x size`` map,
    lifted to ``nfilter`` channels by a 3x3 convolution, passed through
    residual blocks interleaved with stride-2 average pooling down to a 4x4
    map, flattened, and reduced to one value per sample by a linear layer.

    Args:
        data_dim: Number of features in each input vector.
        size: Side length of the internal 3-channel map. It is halved
            ``int(log2(size / 4))`` times and must end up at 4x4, so ``size``
            has to be 4 times a power of two.
        nfilter: Base number of convolution channels.
        nfilter_max: Upper bound on the number of channels of any block.
        res_ratio: Scale applied to the residual branch of each block.
    """

    def __init__(self, data_dim, size, nfilter=64, nfilter_max=512, res_ratio=0.1):
        super().__init__()
        s0 = self.s0 = 4
        nf = self.nf = nfilter
        nf_max = self.nf_max = nfilter_max
        self.size = size

        # Submodules
        nlayers = int(np.log2(size / s0))
        # Channel count emitted by the last residual block. The block stack
        # below ends at min(nf * 2 ** (nlayers + 1), nf_max) channels, so the
        # exponent must be nlayers + 1. Using nlayers only matched when the
        # value was already capped by nf_max, and otherwise made forward()
        # fail on the reshape before the final linear layer.
        self.nf0 = min(nf_max, nf * 2 ** (nlayers + 1))

        nf0 = min(nf, nf_max)
        nf1 = min(nf * 2, nf_max)
        blocks = [
            ResNetBlock(nf0, nf0, bn=False, res_ratio=res_ratio),
            ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio)
        ]

        self.fc_input = nn.Linear(data_dim, 3 * size * size)

        # Each stage halves the spatial resolution and doubles the (capped)
        # channel count.
        for i in range(1, nlayers + 1):
            nf0 = min(nf * 2 ** i, nf_max)
            nf1 = min(nf * 2 ** (i + 1), nf_max)
            blocks += [
                nn.AvgPool2d(3, stride=2, padding=1),
                ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio),
            ]

        self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.resnet = nn.Sequential(*blocks)
        self.fc = nn.Linear(self.nf0 * s0 * s0, 1)

    def forward(self, x):
        """Score a batch of input vectors.

        Args:
            x: Input batch; its first dimension is the batch size and each
                sample must hold ``data_dim`` features.

        Returns:
            Tensor of shape ``(batch_size, 1)`` with one unbounded score per
            sample (no output activation is applied).
        """
        batch_size = x.size(0)

        out = self.fc_input(x)
        out = self.relu(out).view(batch_size, 3, self.size, self.size)
        out = self.relu(self.conv_img(out))
        out = self.resnet(out)
        out = out.view(batch_size, self.nf0 * self.s0 * self.s0)
        out = self.fc(out)
        return out
class ResNetBlock(nn.Module):
    """Residual block built from two 3x3 convolutions and a scaled residual.

    The block computes ``LeakyReLU(shortcut(x) + res_ratio * branch(x))``,
    where ``branch`` is conv -> (batch norm) -> LeakyReLU -> conv ->
    (batch norm). The shortcut is the identity when ``fin == fout`` and a
    bias-free 1x1 convolution (followed by batch norm if enabled) otherwise.

    Args:
        fin: Number of input channels.
        fout: Number of output channels.
        fhidden: Channels between the two convolutions; defaults to
            ``min(fin, fout)`` when ``None``.
        bn: Whether to use batch normalisation. The 3x3 convolutions carry a
            bias only when batch normalisation is disabled.
        res_ratio: Scale applied to the residual branch before the addition.
    """

    def __init__(self, fin, fout, fhidden=None, bn=True, res_ratio=0.1):
        super().__init__()

        # Attributes
        self.fin, self.fout = fin, fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden
        self.res_ratio = res_ratio
        self.bn = bn
        self.is_bias = not bn
        self.learned_shortcut = fin != fout

        # Submodules (created in a fixed order so parameter registration and
        # initialisation stay stable)
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, bias=self.is_bias)

        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, **conv3x3)
        if self.bn:
            self.bn2d_0 = nn.BatchNorm2d(self.fhidden)

        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, **conv3x3)
        if self.bn:
            self.bn2d_1 = nn.BatchNorm2d(self.fout)

        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, kernel_size=1, stride=1, padding=0, bias=False)
            if self.bn:
                self.bn2d_s = nn.BatchNorm2d(self.fout)

        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Apply the block to a 4-D feature map with ``fin`` channels."""
        skip = self._shortcut(x)

        residual = self.conv_0(x)
        if self.bn:
            residual = self.bn2d_0(residual)
        residual = self.conv_1(self.relu(residual))
        if self.bn:
            residual = self.bn2d_1(residual)

        return self.relu(skip + self.res_ratio * residual)

    def _shortcut(self, x):
        """Return the skip-path tensor: ``x`` itself or its 1x1 projection."""
        if not self.learned_shortcut:
            return x
        projected = self.conv_s(x)
        return self.bn2d_s(projected) if self.bn else projected