# LITnet / app.py
import utils
import torch
import gradio as gr
import numpy as np
from PIL import Image
from network import ImageTransformNet_dpws
from torch.autograd import Variable
from torchvision import transforms
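
# Gradio demo UI: pick one of four pre-trained styles, upload a content image,
# and run it through the compressed ImageTransformNet_dpws style-transfer network.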
with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML('<h1 style="text-align: center;">Style Transfer</h1>')
    with gr.Row():
        with gr.Column():
            style_radio = gr.Radio(['La muse', 'Mosaic', 'Starry Night Crop', 'Wave Crop'], label='Choose the style you want!')
            image_input = gr.Image(label='Content image')
            convert_button = gr.Button('Convert!')
        with gr.Column():
            result_image = gr.Image(label='Result image')
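
    # Callback: stylizes the uploaded content image with the selected pre-trained style model.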
    def transform_image(style, img):
        dtype = torch.FloatTensor

        # preprocess the content image
        img_transform_512 = transforms.Compose([
            transforms.Resize(512),               # scale shortest side to 512
            transforms.ToTensor(),                # map pixel values from [0, 255] to [0, 1]
            utils.normalize_tensor_transform(),   # normalize with ImageNet mean/std
        ])
        content = Image.fromarray(img)
        content = img_transform_512(content)
        content = content.unsqueeze(0)            # add batch dimension
        content = Variable(content, requires_grad=False).type(dtype)

        # load the style model that matches the selected radio choice,
        # e.g. 'Starry Night Crop' -> models/starry_night_crop/compressed.model
        model_folder_name = '_'.join(style.lower().split())
        model_path = 'models/' + model_folder_name + '/compressed.model'
        checkpoint_lw = torch.load(model_path, map_location='cpu')
        style_model = ImageTransformNet_dpws().type(dtype)
        style_model.load_state_dict(checkpoint_lw)
        style_model.eval()

        # stylize the input image and save the result to disk
        with torch.no_grad():
            stylized = style_model(content).cpu()
        utils.save_image('results.jpg', stylized.data[0])
        return 'results.jpg'
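
    # Wire the convert button to the callback.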
    convert_button.click(
        transform_image,
        inputs=[style_radio, image_input],
        outputs=[result_image],
    )

demo.launch()