Spaces:
Sleeping
Sleeping
import functools

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

import utils
from network import ImageTransformNet_dpws
# Styles offered in the UI. Each name maps to the checkpoint
# models/<name lowercased, words joined by "_">/compressed.model,
# e.g. 'Starry Night Crop' -> models/starry_night_crop/compressed.model.
STYLE_CHOICES = ['La muse', 'Mosaic', 'Starry Night Crop', 'Wave Crop']

# Tensor type used for both the input batch and the network (CPU float32).
_DTYPE = torch.FloatTensor

# Content-image preprocessing, built once instead of on every request:
# resize so the shorter side is 512 px (aspect ratio kept), convert to a
# [0, 1] float tensor, then apply the project's normalization
# (utils.normalize_tensor_transform; described upstream as ImageNet values).
_CONTENT_TRANSFORM = transforms.Compose([
    transforms.Resize(512),
    transforms.ToTensor(),
    utils.normalize_tensor_transform(),
])


@functools.lru_cache(maxsize=len(STYLE_CHOICES))
def _load_style_model(style):
    """Build the transform network for *style* and load its weights (cached).

    Loading is done once per style rather than on every click. The checkpoint
    is mapped to the CPU so weights saved from a GPU run still load on a
    CPU-only host (the model itself runs as CPU float32, see _DTYPE).
    """
    folder = '_'.join(style.lower().split())
    checkpoint = torch.load(f'models/{folder}/compressed.model', map_location='cpu')
    model = ImageTransformNet_dpws().type(_DTYPE)
    model.load_state_dict(checkpoint)
    # NOTE(review): the model stays in training mode, as in the original code.
    # If ImageTransformNet_dpws contains BatchNorm/Dropout, consider .eval();
    # confirm in network.py.
    return model


def transform_image(style, img):
    """Apply the selected style to a content image and write the result.

    Args:
        style: The selected style name; expected to be one of STYLE_CHOICES
            (None when nothing is selected in the radio).
        img: The uploaded content image as a numpy array (gr.Image's default
            output), or None when nothing was uploaded.

    Returns:
        The path of the written stylized image, 'results.jpg'.

    Raises:
        gr.Error: If no valid style is selected or no image was provided.
    """
    # Validate before touching the filesystem: the style string becomes part
    # of a checkpoint path fed to torch.load, so only accept the UI's choices.
    if style not in STYLE_CHOICES:
        raise gr.Error('스타일을 선택해 주세요.')  # "Please select a style."
    if img is None:
        raise gr.Error('콘텐츠 이미지를 업로드해 주세요.')  # "Please upload a content image."

    # Force 3-channel RGB so grayscale/RGBA arrays still yield a 3-channel
    # tensor (assumes the network takes RGB input; confirm against network.py).
    content = Image.fromarray(img).convert('RGB')
    # ToTensor yields (C, H, W); unsqueeze adds the batch dim -> (1, C, H, W).
    batch = _CONTENT_TRANSFORM(content).unsqueeze(0).type(_DTYPE)

    style_model = _load_style_model(style)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        stylized = style_model(batch).cpu()

    # NOTE(review): this fixed output path is shared by all requests; if the
    # app processes requests concurrently, results could overwrite each other.
    utils.save_image('results.jpg', stylized[0])
    return 'results.jpg'


# UI labels are Korean. Title: "Style Converter"; radio: "Choose the style you
# want!"; input: "Content image"; button: "Convert!"; output: "Result image".
with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML('<h1 style="text-align: center;">스타일 변환기</h1>')
    with gr.Row():
        with gr.Column():
            style_radio = gr.Radio(STYLE_CHOICES, label='원하는 스타일 선택!')
            image_input = gr.Image(label='콘텐츠 이미지')
            convert_button = gr.Button('변환!')
        with gr.Column():
            result_image = gr.Image(label='결과 이미지')

    # Event listeners must be registered inside the Blocks context.
    convert_button.click(
        transform_image,
        inputs=[style_radio, image_input],
        outputs=[result_image],
    )

demo.launch()