|
import gradio as gr |
|
import os |
|
import PIL.Image as Image |
|
import torchvision.transforms as T |
|
import torch |
|
from Discriminator import * |
|
from Generator import * |
|
from torchvision import models, transforms |
|
from torch.autograd import Variable |
|
from segmenatation_model import * |
|
|
|
|
|
def Generate_Fakes(sketches, classof, num_classes=7):
    """Build the conditional-generator input for a batch of sketches.

    Appends ``num_classes`` one-hot condition planes to every sketch so the
    generator receives ``(1 + num_classes, H, W)`` per sample, and builds the
    matching one-hot label matrix.

    Parameters
    ----------
    sketches : torch.Tensor
        Batch of single-channel sketches, shape ``(N, H, W)``.
    classof : int
        Class index to condition every sample on (``0 .. num_classes - 1``).
    num_classes : int, optional
        Number of condition classes. Defaults to 7, matching the module-level
        ``num_classes`` constant (was hard-coded inside the body before).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(noisy_sketchs, fake_labels)`` — ``noisy_sketchs`` has shape
        ``(N, 1 + num_classes, H, W)``; ``fake_labels`` is the float one-hot
        matrix of shape ``(N, num_classes)`` on the sketches' device.
    """
    fake_labels = torch.full(
        (sketches.size(0),), classof, device=sketches.device, dtype=torch.long
    )

    conditioned = []
    for sketch, fake_label in zip(sketches, fake_labels):
        # One plane per class; only the conditioned class's plane is all-ones.
        channels = torch.zeros(
            size=(num_classes, *sketch.shape), device=sketch.device
        )
        channels[fake_label] = 1.0
        # Channel 0 is the sketch itself; channels 1..num_classes are the
        # one-hot condition planes.
        conditioned.append(torch.cat((sketch.unsqueeze(0), channels), dim=0))

    noisy_sketchs = torch.stack(conditioned)

    # Fixed: was F.one_hot(..., num_classes=7) with a no-op .squeeze(1) and a
    # move to the global `device`; now respects `num_classes` and stays on the
    # sketches' own device for consistency with the tensors built above.
    fake_labels = (
        torch.nn.functional.one_hot(fake_labels, num_classes=num_classes)
        .float()
        .to(sketches.device)
    )

    return noisy_sketchs, fake_labels
|
|
|
# --- Pre/post-processing configuration --------------------------------------
image_size = 256   # side length images are resized / center-cropped to
batch_size = 1     # inference handles a single sketch at a time

# Normalization stats as (mean, std) pairs mapping pixels into [-1, 1].
stats_image = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
# Fixed: the std was written as `(1)` — a plain int, not the intended
# 1-tuple `(1,)` matching the single-channel mean tuple.
stats_sketch = (0,), (1,)

num_classes = 7    # number of sketch classes the CGAN conditions on
ngpu = torch.cuda.device_count()

# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
# --- Model construction -----------------------------------------------------
# GAN pair; both project-defined classes take the class count and GPU count.
discriminator = Discriminator(num_classes=num_classes, ngpu=ngpu).to(device)

generator = Generator(ngpu=ngpu, num_classes=num_classes).to(device)


# Segmentation model: pretrained MobileNetV2 backbone with the classifier
# head replaced by Identity so only the convolutional features are used,
# feeding a project-defined Decoder.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in
# favour of the `weights=` argument — confirm against the pinned version.
model_unfreeze = models.mobilenet_v2(pretrained=True)

model_unfreeze.classifier = nn.Identity()

model_unfrozen = model_unfreeze.features

decoder = Decoder(num_encoder_features=1280, num_classes=1)  # 1280 = MobileNetV2 final feature width

seg_model = SegmentationModel(encoder=model_unfrozen, decoder=decoder)


# Restore trained segmentation weights, then move the model to the device.
seg_model_saved = 'segmentation_model.pth'

seg_model.load_state_dict(torch.load(seg_model_saved))

seg_model.to(device)


# Multi-GPU: wrap every model in DataParallel when more than one GPU exists.
if (device.type == 'cuda') and (ngpu > 1):

    generator = nn.DataParallel(generator, list(range(ngpu)))

    discriminator = nn.DataParallel(discriminator, list(range(ngpu)))

    seg_model = nn.DataParallel(seg_model, list(range(ngpu)))

# NOTE(review): the generator checkpoint is loaded *after* the optional
# DataParallel wrap — on a multi-GPU machine the state-dict keys would need
# the 'module.' prefix; confirm the checkpoint was saved accordingly.
generator.load_state_dict(torch.load('generator_model_40.pth'))
|
|
|
def inference(sketch_path, label):
    """Generate a fake image from an uploaded sketch, conditioned on a class.

    Parameters
    ----------
    sketch_path : str
        Filesystem path of the uploaded sketch image.
    label : int | float
        Class index (``0 .. num_classes - 1``) to condition the generator on.

    Returns
    -------
    numpy.ndarray
        Generated image as an ``(H, W, C)`` float array in ``[0, 1]``,
        displayable by the ``gr.Image`` output component.
    """
    transform_sketch = T.Compose(
        [
            T.Resize(image_size),
            T.CenterCrop(image_size),
            # Fixed: the pipeline previously stopped at a PIL image, but
            # Generate_Fakes needs a tensor batch (it calls .size(0)).
            T.ToTensor(),
        ]
    )
    # Sketches are single-channel (see stats_sketch); forcing grayscale makes
    # ToTensor yield a (1, H, W) tensor.
    sketch = Image.open(sketch_path).convert("L")
    sketch = transform_sketch(sketch)

    # Fixed: `label` was accepted but never forwarded, which also made this
    # call raise a TypeError (required `classof` argument was missing).
    latent_input, gen_labels = Generate_Fakes(sketches=sketch, classof=int(label))

    # torch.no_grad() replaces the deprecated Variable wrapper for inference.
    with torch.no_grad():
        fake_images = generator(latent_input.to(device))

    # (1, C, H, W) -> (H, W, C); assumes generator output is in [-1, 1]
    # per stats_image — TODO confirm against the generator's final activation.
    image = fake_images[0].detach().cpu()
    image = (image * 0.5 + 0.5).clamp(0, 1)
    return image.permute(1, 2, 0).numpy()
|
|
|
|
|
# --- Gradio UI --------------------------------------------------------------
# Fixed: both inputs were misnamed `audio_*`, and the class label was wired
# as a second image upload even though `inference` expects a sketch path plus
# a class index; the second input is now an integer slider.
sketch_input = gr.Image(sources="upload", type="filepath", label="Sketch")
class_input = gr.Slider(
    minimum=0, maximum=num_classes - 1, step=1, value=0, label="Class index"
)

image_out = gr.Image(label="Generated Image")

gr.Interface(
    fn=inference,
    inputs=[sketch_input, class_input],
    outputs=image_out,
    title="GAN",
    description="CGAN generation.",
).launch()