# ECS7022P-WGAN-GP / utils.py
# Author: HuseynG
# Last commit: dc1b525 ("Update utils.py")
import time
import random
import threading
from gradio_client import Client
import torch
import torchvision
import torch.nn as nn
# import torch.optim as optim
from torchvision.utils import save_image
from torchvision.transforms import transforms
# Module-wide configuration shared by the models and helpers in this file.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
noise_dim = 100        # channels of the latent input (N x noise_dim x 1 x 1)
img_channels = 3       # channels of the generated / critiqued images
gen_features = 64      # base feature width of the Generator
critic_features = 64   # base feature width of the Critic
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent tensor to a 256x256 image batch.

    Args:
        noise_dim: channels of the latent input (input is N x noise_dim x 1 x 1).
        img_channels: channels of the generated image.
        gen_features: base width; the first block outputs gen_features * 16
            channels and every later block halves the width.

    The layout of ``self.gen`` (its indices and the nested ``nn.Sequential``
    blocks) defines the state_dict keys, so it must stay as-is for saved
    checkpoints to keep loading.
    """

    def __init__(self, noise_dim, img_channels, gen_features):
        super().__init__()
        # (in_channels, out_channels, kernel, stride, padding) per upsampling block.
        # The first block turns 1x1 into 4x4; each later one doubles H and W.
        block_specs = [
            (noise_dim, gen_features * 16, 4, 1, 0),
            (gen_features * 16, gen_features * 8, 4, 2, 1),
            (gen_features * 8, gen_features * 4, 4, 2, 1),
            (gen_features * 4, gen_features * 2, 4, 2, 1),
            (gen_features * 2, gen_features, 4, 2, 1),
            (gen_features, gen_features // 2, 4, 2, 1),
        ]
        upsampling = [self._block(*spec) for spec in block_specs]
        self.gen = nn.Sequential(
            *upsampling,
            # Final upsampling (128x128 -> 256x256) projected to img_channels.
            nn.ConvTranspose2d(gen_features // 2, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # squashes pixel values into [-1, 1]
        )

    def _block(self, in_c, out_c, k_size, s_size, p_size):
        """Return ConvTranspose2d -> BatchNorm2d -> ReLU wrapped in one nn.Sequential."""
        # Block-builder pattern adapted from https://github.com/aladdinpersson
        upconv = nn.ConvTranspose2d(in_c, out_c, k_size, s_size, p_size)
        return nn.Sequential(upconv, nn.BatchNorm2d(out_c), nn.ReLU())

    def forward(self, x):
        """Map latent x (N x noise_dim x 1 x 1) to images (N x img_channels x 256 x 256)."""
        return self.gen(x)
class Critic(nn.Module):
    """Discriminator network, called the "critic" in the WGAN literature.

    Maps an N x img_channels x 256 x 256 batch to an N x 1 x 1 x 1 score tensor.

    Args:
        img_channels: channels of the input images.
        critic_features: base width; doubled at each of the five downsampling blocks.

    NOTE(review): the WGAN-GP paper recommends against batch normalization in
    the critic, because the gradient penalty is computed per sample. This model
    uses BatchNorm2d. Changing it would alter the state_dict and break loading
    existing checkpoints.
    """

    def __init__(self, img_channels, critic_features):
        super().__init__()
        # Stem: 256x256 -> 128x128 (built first to keep parameter-init order).
        stem = nn.Conv2d(img_channels, critic_features, kernel_size=4, stride=2, padding=1)
        # Widths f, 2f, 4f, 8f, 16f, 32f; consecutive pairs give the five blocks,
        # each halving H and W (128 -> 4).
        widths = [critic_features * 2 ** i for i in range(6)]
        downsampling = [self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])]
        # Head: 4x4 -> 1x1 single score per sample.
        head = nn.Conv2d(critic_features * 32, 1, kernel_size=4, stride=1, padding=0)
        self.critic = nn.Sequential(stem, nn.LeakyReLU(0.2), *downsampling, head)

    def _block(self, in_c, out_c, k_size, s_size, p_size):
        """Return Conv2d -> BatchNorm2d -> LeakyReLU(0.2) wrapped in one nn.Sequential."""
        # Block-builder pattern adapted from https://github.com/aladdinpersson
        conv = nn.Conv2d(in_c, out_c, k_size, s_size, p_size)
        return nn.Sequential(conv, nn.BatchNorm2d(out_c), nn.LeakyReLU(0.2))

    def forward(self, x):
        """Score images x (N x img_channels x 256 x 256); returns N x 1 x 1 x 1."""
        return self.critic(x)
def weights_init(m):
    """Initialise one module in place; suitable for ``model.apply(weights_init)``.

    Conv2d, ConvTranspose2d, Linear and BatchNorm2d modules get their weight
    drawn from N(0, 0.02) and their bias set to 0. All other modules are left
    untouched.

    Modules without a weight or bias tensor (``bias=False`` layers,
    ``BatchNorm2d(affine=False)``) are skipped for that tensor instead of
    raising AttributeError on ``None.data``.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear, nn.BatchNorm2d)):
        return
    if m.weight is not None:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0)
# The following gradient penalty function is adapted from:
# https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/GANs/4.%20WGAN-GP/utils.py
def gradient_penalty(critic, real, fake, device):
    """Compute the WGAN-GP penalty mean((||d critic(x) / dx||_2 - 1) ** 2).

    ``x`` is a random per-sample interpolation between ``real`` and ``fake``.

    Args:
        critic: callable mapping a batch to per-sample scores.
        real: batch of real samples; dimension 0 is the batch dimension.
        fake: batch of generated samples, same shape as ``real``.
        device: device on which the interpolation weights are created; should
            match the device of ``real`` / ``fake``.

    Returns:
        A scalar tensor. The gradient is taken with ``create_graph=True``, so
        the penalty can itself be back-propagated.
    """
    batch_size = real.shape[0]
    # One interpolation weight per sample, broadcast over all other dimensions
    # (works for any input rank, not only N x C x H x W).
    alpha = torch.rand((batch_size,) + (1,) * (real.dim() - 1), device=device)
    interpolated_images = real * alpha + fake * (1 - alpha)
    # Ensure autograd tracks the interpolation even when both real and fake are
    # detached; this is a no-op when it already requires grad.
    interpolated_images.requires_grad_(True)
    # Critic scores for the interpolated samples.
    mixed_scores = critic(interpolated_images)
    # Gradient of the scores with respect to the interpolated inputs.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
def load_model(model_type, model_path):
    """Build a Generator or Critic, load its saved weights and switch it to eval mode.

    Args:
        model_type: either 'generator' or 'critic'.
        model_path: path of the saved state_dict file.

    Returns:
        The model, moved to CUDA when available (CPU otherwise) and in eval mode.

    Raises:
        ValueError: if ``model_type`` is not one of the two supported names.
    """
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if model_type not in ('generator', 'critic'):
        raise ValueError(f"Invalid model_type: {model_type}")

    if model_type == 'generator':
        net = Generator(noise_dim, img_channels, gen_features)
    else:
        net = Critic(img_channels, critic_features)

    # NOTE(review): torch.load unpickles the file, which can run arbitrary code;
    # only load trusted checkpoints (consider weights_only=True on torch
    # versions that support it).
    state_dict = torch.load(model_path, map_location=target_device)
    net.load_state_dict(state_dict)
    net.to(target_device)
    net.eval()
    return net
import torchvision.transforms as transforms
from PIL import Image
def generate_random_img(model):
    """Sample one latent vector, run it through ``model`` and return a PIL image.

    Args:
        model: generator taking N x noise_dim x 1 x 1 noise and returning
            N x C x H x W images in [-1, 1] (the Generator in this module ends
            in nn.Tanh). The model's train/eval mode is not changed here.

    Returns:
        A PIL image of the single generated sample.
    """
    # Create the noise on the device that holds the model's weights, so this
    # also works for models not moved to the module-level ``device``.
    first_param = next(model.parameters(), None)
    noise_device = first_param.device if first_param is not None else device
    noise = torch.randn(1, noise_dim, 1, 1, device=noise_device)  # a single image
    with torch.no_grad():
        generated_image = model(noise)
    # Tanh output lies in [-1, 1], but ToPILImage multiplies float tensors by
    # 255 and expects [0, 1]. Map back to [0, 1] first; otherwise negative
    # values do not convert to valid pixel bytes.
    generated_image = ((generated_image.squeeze(0).cpu() + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(generated_image)
def schedule_function():
    """Call the companion "dummy" Space forever, once every 3 to 5 hours.

    This Space and the dummy Space call each other so that neither is put to
    sleep (this is an academic project). The loop never returns, so callers
    should run it off the main thread.

    A failed call (network error, dummy Space unavailable) is logged and
    retried on the next cycle instead of terminating the keep-alive loop.
    """
    while True:
        # Random wait between 3 and 5 hours, in seconds.
        wait_time = random.uniform(3 * 60 * 60, 5 * 60 * 60)
        time.sleep(wait_time)
        try:
            # Call the dummy Space.
            client = Client("https://huseyng-dummyspace.hf.space/")
            result = client.predict(
                f"Howdy! {wait_time}",  # value for the remote 'name' Textbox component
                api_name="/predict",
            )
        except Exception as exc:  # best-effort keep-alive: log and retry next cycle
            print(f"Keep-alive call to dummy space failed: {exc!r}")
            continue
        print(result)
if __name__ == "__main__":
    # Manual smoke test: load the trained generator checkpoint and sample one
    # image (the returned PIL image is not used).
    checkpoint = 'generator_model_epoch_94.pth'
    generate_random_img(load_model('generator', checkpoint))