# CBN_cGAN / app.py
# Author: KhadgaA
# Commit d5da7cb: Update generator_model_40.pth loading in app.py
import gradio as gr
import os
import PIL.Image as Image
import torchvision.transforms as T
import torch
from Discriminator import *
from Generator import *
from torchvision import models, transforms
from torch.autograd import Variable
from segmenatation_model import *
# from inference_kathbadh import inference_kathbadh
def Generate_Fakes(sketches, classof):
    """Build class-conditioned generator inputs for a batch of sketches.

    Every sketch in the batch is paired with the same class ``classof``.
    ``num_classes`` planes shaped like one sketch are placed behind each
    sketch along a new dimension 1; the plane at index ``classof`` is filled
    with ones, all others with zeros. The sketches themselves are used as-is
    (no noise is added).

    Args:
        sketches: Tensor whose first dimension is the batch.
        classof: Class index applied to every sketch; must be a valid index
            into ``num_classes`` planes.

    Returns:
        A ``(conditioned, labels)`` tuple. ``conditioned`` has shape
        ``(batch, 1 + num_classes, *sketches.shape[1:])`` and lives on
        ``sketches.device``. ``labels`` is the float one-hot encoding of the
        class, shape ``(batch, num_classes)``, moved to the global ``device``.
    """
    batch = sketches.size(0)
    fake_labels = (
        torch.ones(batch, device=sketches.device, dtype=torch.long) * classof
    )

    # One block of class planes for the whole batch instead of a Python loop
    # over samples: planes[b, fake_labels[b]] is set to 1 for every b.
    planes = torch.zeros(
        size=(batch, num_classes, *sketches.shape[1:]), device=sketches.device
    )
    planes[torch.arange(batch, device=sketches.device), fake_labels] = 1.0
    conditioned = torch.cat((sketches.unsqueeze(1), planes), dim=1)

    # Use num_classes (not a literal 7) so the encoding always matches the
    # number of planes above. The labels are 1-D, so one_hot already yields
    # (batch, num_classes) and no squeeze is needed.
    one_hot = torch.nn.functional.one_hot(fake_labels, num_classes=num_classes)
    return conditioned, one_hot.float().to(device)
# Side length (pixels) used when resizing and center-cropping input sketches.
image_size = 256
batch_size = 1
# Normalisation statistics, written as (mean, std) pairs in the argument
# order of torchvision's Normalize.
stats_image = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
# The std needs a trailing comma to be a 1-tuple like the mean; a bare `(1)`
# is just the int 1.
stats_sketch = (0,), (1,)
# Number of classes the networks are conditioned on.
num_classes = 7
# Number of visible CUDA devices (0 on a CPU-only host).
ngpu = torch.cuda.device_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Instantiate the conditional GAN networks on the selected device.
discriminator = Discriminator(num_classes=num_classes, ngpu=ngpu).to(device)
generator = Generator(ngpu=ngpu, num_classes=num_classes).to(device)

# Segmentation network: MobileNetV2 feature layers as the encoder, followed by
# the project's Decoder with a single output class.
model_unfreeze = models.mobilenet_v2(pretrained=True)
model_unfreeze.classifier = nn.Identity()
model_unfrozen = model_unfreeze.features
decoder = Decoder(num_encoder_features=1280, num_classes=1)
seg_model = SegmentationModel(encoder=model_unfrozen, decoder=decoder)
seg_model_saved = 'segmentation_model.pth'
# map_location remaps tensors stored in the checkpoint onto `device`, so a
# checkpoint saved on a GPU still loads on a CPU-only host.
seg_model.load_state_dict(torch.load(seg_model_saved, map_location=device))
seg_model.to(device)

# Spread the models over all GPUs when more than one is available.
if (device.type == 'cuda') and (ngpu > 1):
    generator = nn.DataParallel(generator, list(range(ngpu)))
    discriminator = nn.DataParallel(discriminator, list(range(ngpu)))
    seg_model = nn.DataParallel(seg_model, list(range(ngpu)))

# NOTE(review): when the DataParallel branch above ran, the wrapper expects
# state-dict keys prefixed with 'module.' -- confirm the checkpoint's key
# format matches the way the generator is wrapped here.
generator.load_state_dict(
    torch.load('generator_model_40.pth', map_location=device)
)
def inference(sketch_path, label):
    """Generate an image from a sketch file, conditioned on a class index.

    Args:
        sketch_path: Filesystem path of the sketch image to translate.
        label: Class index in ``[0, num_classes)``; any value ``int()``
            accepts (int, float, numeric string).

    Returns:
        The first generated image as a ``uint8`` NumPy array in
        height x width x channel layout, which is the layout an image output
        component can display.

    Raises:
        gr.Error: If no sketch was supplied or ``label`` is not a valid
            class index.
    """
    if not sketch_path:
        raise gr.Error("Please upload a sketch image.")
    try:
        class_index = int(label)
    except (TypeError, ValueError) as err:
        raise gr.Error(
            f"Class label must be an integer, got {label!r}."
        ) from err
    if not 0 <= class_index < num_classes:
        raise gr.Error(
            f"Class label must be between 0 and {num_classes - 1}, "
            f"got {class_index}."
        )

    transform_sketch = T.Compose(
        [
            T.Resize(image_size),
            T.CenterCrop(image_size),
            # Generate_Fakes needs a tensor, not a PIL image.
            T.ToTensor(),
        ]
    )
    # Close the file handle as soon as the pixels are loaded. The sketch is
    # read as a single grayscale channel (stats_sketch has one entry per
    # statistic) -- TODO confirm the generator was trained on 1-channel input.
    with Image.open(sketch_path) as sketch_file:
        sketch = transform_sketch(sketch_file.convert("L"))

    # ToTensor yields (1, H, W); the leading axis is used as a batch of one,
    # so Generate_Fakes returns (1, 1 + num_classes, H, W).
    # NOTE(review): assumes the generator takes this 4-D layout -- confirm.
    with torch.no_grad():  # inference only: no autograd graph needed
        latent_input, _ = Generate_Fakes(sketches=sketch, classof=class_index)
        fake_images = generator(latent_input.to(device))

    # Undo the stats_image normalisation and convert to an H x W x C uint8
    # array. Assumes the generator emits (batch, C, H, W) values normalised
    # with stats_image -- TODO confirm against the Generator definition.
    mean, std = (torch.tensor(stat).view(-1, 1, 1) for stat in stats_image)
    image = fake_images[0].detach().cpu().float() * std + mean
    image = image.clamp(0.0, 1.0).permute(1, 2, 0)
    return (image * 255).round().to(torch.uint8).numpy()
# Input widgets: the sketch to translate and the class to condition on.
# `inference(sketch_path, label)` takes a file path and a class label, so the
# second input is a class selector rather than a second image upload.
sketch_input = gr.Image(sources=["upload"], type="filepath", label="Sketch")
label_input = gr.Slider(
    minimum=0,
    maximum=num_classes - 1,
    step=1,
    value=0,
    label="Class index",
)
image_out = gr.Image(label="Generated Image")

demo = gr.Interface(
    fn=inference,
    inputs=[sketch_input, label_input],
    outputs=image_out,
    title="GAN",
    description="CGAN generation.",
)
demo.launch()