Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import DCGAN | |
import SRGAN | |
from Utils import color_histogram_mapping, denormalize_images | |
import torch.nn as nn | |
import random | |
# Sampling and display settings used by the generation step further down.
latent_size = 100
display_width = 450
checkpoint_path = "Checkpoints/150epochs.chkpt"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Main page header.
st.title("Generating Abstract Art")
st.text("start generating (left side bar)")
st.text("Made by Xingyu B.")

# Sidebar controls: RNG seed, batch size, optional enhancement, trigger button.
st.sidebar.subheader("Configurations")
seed = st.sidebar.slider("Seed", min_value=-10000, max_value=10000, value=0)
num_images = st.sidebar.slider(
    "Number of Images", min_value=1, max_value=10, value=1
)
use_srgan = st.sidebar.selectbox(
    "Apply image enhancement",
    options=("Yes", "No"),
)
generate = st.sidebar.button("Generate")
# Streamlit re-executes this script on every widget interaction, so the
# expensive model loading is cached and reused instead of hitting the disk
# on each "Generate" click.
@st.cache_resource
def load_dcgan():
    """Load the TorchScript DCGAN generator from ``Checkpoints/dcgan.pt``.

    Returns:
        The module produced by ``torch.jit.load``, mapped onto ``device``.
        Because the result is cached, later calls receive the same object.
    """
    return torch.jit.load('Checkpoints/dcgan.pt', map_location=device)
# Cached for the same reason as any heavy resource in a Streamlit app: the
# script is re-run on every interaction and the checkpoint should only be
# read from disk once.
@st.cache_resource
def load_esrgan():
    """Load the ESRGAN checkpoint from ``Checkpoints/esrgan.pt``.

    Returns:
        The object stored in the checkpoint, mapped onto ``device``. The
        caller passes it to ``load_state_dict``, so it is presumably a plain
        state dict -- confirm against how the checkpoint was saved. Because
        the result is cached, later calls receive the same object.
    """
    return torch.load("Checkpoints/esrgan.pt", map_location=device)
def _show_images(images):
    """Render every image of a batch in the main page area.

    Args:
        images: Iterable of CPU image tensors. Each one is passed through
            ``denormalize_images`` and permuted with ``(1, 2, 0)`` before
            being displayed (presumably channels-first to channels-last --
            confirm against the model outputs).
    """
    for image in images:
        st.image(
            denormalize_images(image).permute(1, 2, 0).numpy(),
            width=display_width,
        )


# Only generate when the user pressed the button on this script run.
if generate:
    # Seed both RNGs so the same seed value yields the same sampled noise.
    torch.manual_seed(seed)
    random.seed(seed)
    sampled_noise = torch.randn(num_images, latent_size, 1, 1, device=device)

    generator = load_dcgan()
    generator.eval()
    # Tensors produced under no_grad never require grad, so no extra
    # detach() is needed before converting them to NumPy later on.
    with torch.no_grad():
        fakes = generator(sampled_noise)

    if use_srgan == "Yes":
        # Upscale the DCGAN output with the ESRGAN generator restored from
        # its checkpoint.
        st.write("Using DCGAN then ESRGAN upscale...")
        esrgan_generator = SRGAN.GeneratorRRDB(
            channels=3, filters=64, num_res_blocks=23
        ).to(device)
        esrgan_generator.load_state_dict(load_esrgan())
        esrgan_generator.eval()
        with torch.no_grad():
            enhanced_fakes = esrgan_generator(fakes).cpu()
        # NOTE(review): presumably re-maps the colours of the upscaled
        # images onto those of the DCGAN output -- confirm in Utils.
        color_match = color_histogram_mapping(enhanced_fakes, fakes.cpu())
        _show_images(color_match)
    else:
        # The select box only offers "Yes" and "No", so this is the plain
        # DCGAN path.
        st.write("Using DCGAN Model...")
        _show_images(fakes.cpu())