# import yaml
# from munch import DefaultMunch
import torch
import torch.nn as nn


def norm(img, eps=1e-5):
    """Min-max normalize ``img`` in place so its values span [0, 1].

    Args:
        img: Floating-point tensor; modified in place.
        eps: Lower bound on the divisor. It stops a constant tensor
            (max == min) from dividing by zero; such a tensor becomes all
            zeros.

    Returns:
        The same tensor object, for convenience. Existing callers that ignore
        the return value are unaffected. An empty tensor is returned
        unchanged: ``min()``/``max()`` are undefined for it.
    """
    if img.numel() == 0:
        return img
    low = float(img.min())
    high = float(img.max())
    img.sub_(low).div_(max(high - low, eps))
    return img


def random_sample(batch_size, z_dim, device):
    """Return standard-normal noise of shape (batch_size, z_dim, 1, 1) on ``device``.

    This is the input to the generator: z_dim channels at 1x1 pixels.
    """
    # Sampled with the default (CPU) generator and then moved, so a given
    # torch.manual_seed yields the same noise regardless of ``device``.
    return torch.randn(batch_size, z_dim, 1, 1).to(device)


def init_weight(m):
    """Initialize one module's parameters in place; meant for ``Module.apply``.

    - Linear / Conv2d / ConvTranspose2d: weight ~ N(0, 0.02) and bias = 0 (if
      present). This follows https://arxiv.org/abs/1511.06434 (DCGAN).
    - BatchNorm2d: weight = 1 and bias = 0 (if the layer has affine params).
    - Every other module is left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False, BatchNorm has no weight/bias (both are None);
        # guard so initialization does not crash on such layers.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)