# NOTE(review): removed web-scrape artifacts that were pasted above the module
# (page status lines, file-size line, commit hashes, and a line-number gutter)
# — none of it was Python and it would not parse.
# WGAN-GP model definitions (Generator + Critic) and inference helpers.
import torch
import torchvision
import torch.nn as nn
from torchvision.utils import save_image
from torchvision.transforms import transforms  # NOTE(review): re-imported again further down; the usual form is `import torchvision.transforms as transforms`
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyperparameters shared by the Generator / Critic constructors below.
noise_dim = 100        # length of the latent noise vector fed to the generator
img_channels = 3       # RGB images
gen_features = 64      # base channel width of the generator
critic_features = 64   # base channel width of the critic
class Generator(nn.Module):
    """DCGAN-style generator: maps a latent noise vector to a 256x256 image.

    Input:  N x noise_dim x 1 x 1
    Output: N x img_channels x 256 x 256, values in [-1, 1] (Tanh).
    """

    def __init__(self, noise_dim, img_channels, gen_features):
        super(Generator, self).__init__()
        # Channel widths per stage: each upsampling stage doubles the spatial
        # size and halves the channel count.
        widths = [gen_features * m for m in (16, 8, 4, 2, 1)]
        widths.append(gen_features // 2)
        stages = [self._block(noise_dim, widths[0], 4, 1, 0)]  # 1x1 -> 4x4
        for w_in, w_out in zip(widths, widths[1:]):
            stages.append(self._block(w_in, w_out, 4, 2, 1))  # doubles H and W
        stages.append(
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1)
        )
        stages.append(nn.Tanh())  # squash pixels into [-1, 1]
        self.gen = nn.Sequential(*stages)

    def _block(self, in_c, out_c, k_size, s_size, p_size):
        # ConvTranspose -> BatchNorm -> ReLU: the standard DCGAN upsampling
        # unit. (Pattern borrowed from https://github.com/aladdinpersson)
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, k_size, s_size, p_size),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
        )

    def forward(self, x):
        """Generate images from latent codes x of shape (N, noise_dim, 1, 1)."""
        return self.gen(x)
class Critic(nn.Module):
    """WGAN critic (the discriminator is called a "critic" in WGAN papers).

    Input:  N x img_channels x 256 x 256
    Output: N x 1 x 1 x 1 unbounded realness score (no sigmoid — WGAN).

    NOTE(review): the downsampling blocks use BatchNorm2d; the WGAN-GP paper
    advises against batch norm in the critic (it couples samples within a
    batch, interfering with the per-sample gradient penalty) — confirm, but
    changing it would invalidate existing checkpoints.
    """

    def __init__(self, img_channels, critic_features):
        super(Critic, self).__init__()
        f = critic_features
        self.critic = nn.Sequential(
            # First stage: plain conv + LeakyReLU, no norm (DCGAN convention).
            nn.Conv2d(img_channels, f, kernel_size=4, stride=2, padding=1),  # 256 -> 128
            nn.LeakyReLU(0.2),
            self._block(f, f * 2, 4, 2, 1),        # 128 -> 64
            self._block(f * 2, f * 4, 4, 2, 1),    # 64 -> 32
            self._block(f * 4, f * 8, 4, 2, 1),    # 32 -> 16
            self._block(f * 8, f * 16, 4, 2, 1),   # 16 -> 8
            self._block(f * 16, f * 32, 4, 2, 1),  # 8 -> 4
            nn.Conv2d(f * 32, 1, kernel_size=4, stride=1, padding=0),  # 4 -> 1
        )

    def _block(self, in_c, out_c, k_size, s_size, p_size):
        # Conv -> BatchNorm -> LeakyReLU downsampling unit.
        # (Pattern borrowed from https://github.com/aladdinpersson)
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, k_size, s_size, p_size),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Score a batch of images; returns shape (N, 1, 1, 1)."""
        return self.critic(x)
def weights_init(m):
    """Apply DCGAN weight initialization (Radford et al., 2016) to module m.

    Conv / linear weights ~ N(0, 0.02); BatchNorm scales ~ N(1.0, 0.02);
    all biases zero. Intended for use via model.apply(weights_init).
    """
    if isinstance(m, nn.BatchNorm2d):
        # Fix: BatchNorm scales must be centered at 1, not 0 — a zero-mean
        # scale would zero out activations at initialization.
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        # Fix: layers constructed with bias=False have bias None — guard it.
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
# Gradient penalty adapted from:
# https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/GANs/4.%20WGAN-GP/utils.py
def gradient_penalty(critic, real, fake, device):
    """Compute the WGAN-GP penalty E[(||grad critic(x_hat)||_2 - 1)^2].

    Args:
        critic: critic network (callable: image batch -> per-sample scores).
        real: batch of real images, shape (N, C, H, W).
        fake: batch of generated images, same shape as real.
        device: torch device for the interpolation coefficients.

    Returns:
        Scalar tensor: mean squared deviation of the gradient norms from 1.
    """
    batch_size, c, h, w = real.shape
    # One interpolation coefficient per sample, broadcast across C, H, W.
    alpha = torch.rand((batch_size, 1, 1, 1)).repeat(1, c, h, w).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)
    # Fix (present in the upstream source this was taken from): gradients are
    # taken w.r.t. the interpolated images, so they must require grad —
    # without this, autograd.grad fails when real/fake are detached tensors.
    interpolated_images.requires_grad_(True)

    # Critic scores at the interpolated points.
    mixed_scores = critic(interpolated_images)

    # Gradient of the scores w.r.t. the interpolated images.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
def load_model(model_type, model_path):
    """Load a trained Generator or Critic checkpoint, ready for inference.

    Args:
        model_type: 'generator' or 'critic'.
        model_path: path to a saved state-dict file.

    Returns:
        The model moved to the available device, in eval() mode.

    Raises:
        ValueError: if model_type is not one of the two known kinds.
    """
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Dispatch table instead of an if/elif chain; constructors use the
    # module-level hyperparameters the checkpoints were trained with.
    builders = {
        'generator': lambda: Generator(noise_dim, img_channels, gen_features),
        'critic': lambda: Critic(img_channels, critic_features),
    }
    builder = builders.get(model_type)
    if builder is None:
        raise ValueError(f"Invalid model_type: {model_type}")
    model = builder()
    model.load_state_dict(torch.load(model_path, map_location=target))
    model.to(target)
    model.eval()
    return model
# NOTE(review): `transforms` is already imported at the top of the file, and
# `Image` is unused in the visible code — both candidates for removal.
import torchvision.transforms as transforms
from PIL import Image
def generate_random_img(model):
    """Sample one image from a trained generator and return it as a PIL image.

    Args:
        model: a trained Generator (output in [-1, 1] from its final Tanh).

    Returns:
        A PIL image of the single generated sample.
    """
    # One latent vector -> one image; the leading 1 is the batch size.
    noise = torch.randn(1, noise_dim, 1, 1).to(device)
    # No gradients needed for inference.
    with torch.no_grad():
        generated_image = model(noise)
    generated_image = generated_image.cpu().squeeze(0)
    # Fix: the generator's Tanh output lies in [-1, 1], but ToPILImage expects
    # float tensors in [0, 1]; map [-1, 1] -> [0, 1] (and clamp any numerical
    # spill) so the dark half of the range is not mangled in the conversion.
    generated_image = ((generated_image + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(generated_image)
if __name__ == "__main__":
    # Load the trained generator checkpoint and sample a single image.
    generator = load_model('generator', 'generator_model_epoch_94.pth')
    # NOTE(review): the returned PIL image is discarded — presumably it should
    # be saved or displayed; confirm the intended usage.
    generate_random_img(generator)