# Spaces:
# Runtime error
# Runtime error
import io
import json
import os
import pdb
import urllib.request

import numpy as np
import onnxruntime
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.init as init
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from resizeimage import resizeimage
import numpy as np
import onnx
import gradio as gr
class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network that upscales single-channel images.

    Three ReLU-activated convolutions extract features; a fourth produces
    ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle`` rearranges
    into a single channel enlarged ``upscale_factor`` times in height and
    width.

    Args:
        upscale_factor: Integer factor by which height and width grow.
        inplace: Passed to ``nn.ReLU``; if true, activations overwrite their
            inputs.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()
        # Attribute names become state_dict keys ("conv1.weight", ...), so
        # they must stay in sync with any checkpoint loaded into this model.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        """Map an (N, 1, H, W) batch to (N, 1, H * r, W * r), r = upscale_factor."""
        # Every convolution is stride 1 with "same" padding, so only the
        # final pixel shuffle changes the spatial size.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise conv weights (ReLU gain for hidden layers)."""
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        # The output layer feeds PixelShuffle rather than a ReLU, so it keeps
        # orthogonal_'s default gain of 1.
        init.orthogonal_(self.conv4.weight)
# Instantiate the PyTorch super-resolution network (upscale_factor=3) and
# load the pretrained weights published at model_url.
# NOTE(review): torch_model, x and batch_size are not referenced by
# inference() in this file, which runs the downloaded ONNX model instead;
# they look like leftovers from an export step and are kept for compatibility.
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1  # just a random number
torch_model = SuperResolutionNet(upscale_factor=3)

# Without CUDA, the callable keeps every deserialized storage on the CPU;
# with CUDA, map_location=None leaves torch.load's default placement.
map_location = None if torch.cuda.is_available() else (lambda storage, loc: storage)
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

torch_model.eval()
# Random (1, 1, 224, 224) tensor in the N, C, H, W layout nn.Conv2d expects.
x = torch.randn(1, 1, 224, 224, requires_grad=True)
# Fetch the pre-exported ONNX model that inference() runs, but only when the
# file is absent. The previous `os.system("wget ...")` needed wget on PATH
# and ignored its exit status, so a failed fetch surfaced later as a
# confusing ONNX Runtime error. It also saved another numbered copy
# ("super-resolution-10.onnx.1", ...) on every restart instead of reusing
# the existing file.
onnx_model_path = "super-resolution-10.onnx"
onnx_model_url = (
    "https://github.com/AK391/models/raw/main/vision/super_resolution/"
    "sub_pixel_cnn_2016/model/super-resolution-10.onnx"
)
if not os.path.exists(onnx_model_path):
    # Write to a temporary name and rename on success so an interrupted
    # download is never mistaken for a complete model on the next start.
    _partial_path = onnx_model_path + ".part"
    urllib.request.urlretrieve(onnx_model_url, _partial_path)
    os.replace(_partial_path, onnx_model_path)

# Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
# other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
# based on the build flags) when instantiating InferenceSession.
# Providers are therefore listed explicitly so the same code runs on CPU-only
# and GPU builds: CUDA first when this build offers it, with CPU as fallback.
_providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
    _providers.insert(0, "CUDAExecutionProvider")
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=_providers)
def inference(img):
    """Super-resolve an image with the ONNX model held in ``ort_session``.

    The image is resized to cover 224x224 (via ``resizeimage.resize_cover``)
    and converted to YCbCr. Only the luma (Y) channel goes through the
    model. The chroma (Cb, Cr) channels are enlarged with bicubic
    interpolation to the model's output size and merged back in.

    Args:
        img: Anything ``PIL.Image.open`` accepts; the Gradio interface in
            this file passes a file path.

    Returns:
        PIL.Image.Image: The super-resolved image in RGB mode.
    """
    # Close the uploaded file deterministically. convert("RGB") returns an
    # independent, fully loaded copy, so nothing reads the file afterwards.
    # It also sends every input mode down one RGB path; palette ("P") images
    # would otherwise be resized with nearest-neighbour sampling.
    with Image.open(img) as source:
        rgb = source.convert("RGB")
    resized = resizeimage.resize_cover(rgb, [224, 224], validate=False)
    y_channel, cb_channel, cr_channel = resized.convert("YCbCr").split()

    # Model input: float32 array of shape (1, 1, 224, 224) scaled to [0, 1].
    y_input = np.asarray(y_channel, dtype=np.float32)[np.newaxis, np.newaxis] / 255.0
    input_name = ort_session.get_inputs()[0].name
    y_output = ort_session.run(None, {input_name: y_input})[0]

    # Assumes the model returns (1, 1, H_out, W_out) — TODO confirm. Take the
    # single plane, rescale to 8-bit and clip (np.uint8 truncates).
    y_pixels = np.uint8((y_output[0] * 255.0).clip(0, 255)[0])
    # A 2-D uint8 array maps to mode "L" automatically; the explicit ``mode``
    # argument to fromarray is deprecated in recent Pillow releases.
    y_upscaled = Image.fromarray(y_pixels)

    size = y_upscaled.size
    return Image.merge(
        "YCbCr",
        [
            y_upscaled,
            cb_channel.resize(size, Image.BICUBIC),
            cr_channel.resize(size, Image.BICUBIC),
        ],
    ).convert("RGB")
title = "sub_pixel_cnn_2016"
description = "The Super Resolution machine learning model sharpens and upscales the input image to refine the details and improve quality."

# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4,
# where referencing them fails at startup. The top-level gr.Image component
# serves as both input and output in Gradio 3.x and 4.x. type="filepath"
# hands inference() a path; type="pil" accepts the PIL image it returns.
demo = gr.Interface(
    inference,
    gr.Image(type="filepath"),
    gr.Image(type="pil"),
    title=title,
    description=description,
)
demo.launch()