#import yaml | |
#from munch import DefaultMunch | |
import torch | |
import torch.nn as nn | |
def norm(img):
    """Min-max rescale ``img`` in place so its values lie in [0, 1].

    The tensor is shifted by its own minimum and then divided by its value
    range. The divisor is floored at 1e-5, so a constant input becomes all
    zeros instead of dividing by zero. Nothing is returned; the caller's
    tensor is modified directly via the in-place ``sub_`` / ``div_`` calls.
    """
    lo, hi = float(img.min()), float(img.max())
    span = hi - lo
    # Floor the range to avoid division by (near) zero on flat inputs.
    if span < 1e-5:
        span = 1e-5
    img.sub_(lo)
    img.div_(span)
def random_sample(batch_size, z_dim, device):
    """Draw a batch of standard-normal latent vectors for the generator.

    Returns a tensor of shape ``(batch_size, z_dim, 1, 1)``: ``z_dim``
    channels over a 1x1 spatial map. The noise is created with torch's
    default device/dtype and then moved to ``device``.
    """
    latent_shape = (batch_size, z_dim, 1, 1)
    noise = torch.randn(latent_shape)
    return noise.to(device)
def init_weight(m):
    """Initialise the parameters of a single module, DCGAN-style.

    Dispatches on the module type (e.g. for use with
    ``model.apply(init_weight)``):

    - ``nn.Linear`` / ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights are drawn
      from N(0, 0.02) and biases, when present, are set to zero.
    - ``nn.BatchNorm2d``: weight is set to 1 and bias to 0, but only when the
      layer actually has affine parameters.

    Any other module is left untouched. Parameters are modified in place;
    nothing is returned.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        # the weights are initialized according to
        # https://arxiv.org/abs/1511.06434
        nn.init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:  # e.g. Conv2d(..., bias=False)
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None; skip
        # them instead of crashing on attribute access.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)