# U2net-with-rgba / app.py
# Akbartus's picture
# Update app.py
# 69ca83d
# raw
# history blame
# 1.85 kB
import os
import copy
import time
import cv2 as cv
import numpy as np
import onnxruntime
from PIL import Image
import gradio
def run_inference(onnx_session, input_size, image):
    """Run the segmentation model on one image and return an 8-bit mask.

    Args:
        onnx_session: Object exposing ``get_inputs()``, ``get_outputs()`` and
            ``run()`` (an ``onnxruntime.InferenceSession`` in this file).
        input_size: Side length, in pixels, of the square the image is
            resized to before being fed to the model.
        image: Image array in OpenCV's BGR channel order. It is not modified.

    Returns:
        ``uint8`` array holding the model's first output with singleton
        dimensions removed, min-max rescaled to the range 0..255. If the
        output is constant (or not finite) an all-zero mask is returned.
    """
    # Resize. cv.resize returns a new array, so no defensive copy is needed.
    resized = cv.resize(image, dsize=(input_size, input_size))
    x = cv.cvtColor(resized, cv.COLOR_BGR2RGB)

    # Pre-processing: scale to [0, 1], then normalise each channel with the
    # mean / std the model was trained with (the usual ImageNet statistics).
    x = np.array(x, dtype=np.float32)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    # HWC -> CHW, then add the leading batch dimension.
    x = x.transpose(2, 0, 1).astype('float32')
    x = x.reshape(-1, 3, input_size, input_size)

    # Inference: only the first declared output is requested.
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    # Post-processing: min-max normalise to 0..255.
    mask = np.array(onnx_result).squeeze()
    min_value = np.min(mask)
    max_value = np.max(mask)
    value_range = max_value - min_value
    if value_range > 0:
        mask = (mask - min_value) / value_range
        mask *= 255
    else:
        # A constant output would divide by zero and cast NaN to uint8;
        # report "nothing detected" instead.
        mask = np.zeros_like(mask)
    return mask.astype('uint8')
# Build the ONNX Runtime session once at import time; create_rgba reuses it
# for every call instead of reloading the model.
_MODEL_PATH = "u2net.onnx"
onnx_session = onnxruntime.InferenceSession(_MODEL_PATH)
def create_rgba(mode, image):
    """Attach a predicted foreground mask to an image as its alpha channel.

    Args:
        mode: ``"binary"`` thresholds the mask so every pixel is either fully
            transparent or fully opaque; any other value (the UI offers
            ``"smooth"``) keeps the graded mask.
        image: RGB image array indexed as (row, column, channel), as
            delivered by the Gradio image input.

    Returns:
        A PIL ``RGBA`` image: the input pixels with the mask as alpha.
    """
    binary_threshold = 125

    # run_inference expects OpenCV's BGR order (it converts BGR -> RGB
    # itself), while this function receives RGB, so swap the channels first.
    bgr_image = cv.cvtColor(image, cv.COLOR_RGB2BGR)
    out = run_inference(
        onnx_session,
        320,
        bgr_image,
    )

    # Scale the model-sized mask back to the input's width x height.
    alpha = cv.resize(out, dsize=(image.shape[1], image.shape[0]))

    if mode == "binary":
        # The previous `> 255` test could never match a uint8 array, leaving
        # values in 125..255 graded; threshold both sides explicitly.
        alpha = np.where(alpha < binary_threshold, 0, 255).astype(np.uint8)

    mask = Image.fromarray(alpha)
    rgba_image = Image.fromarray(image).convert('RGBA')
    rgba_image.putalpha(mask)
    return rgba_image
# Gradio wiring: a mask-mode selector and an image input are passed, in that
# order, to create_rgba; the PIL image it returns is shown as the output.
inputs = [
    gradio.inputs.Radio(["binary", "smooth"]),
    gradio.inputs.Image(),
]
outputs = gradio.outputs.Image(type="pil")
iface = gradio.Interface(
    fn=create_rgba,
    inputs=inputs,
    outputs=outputs,
    api_name="add",
)
iface.launch()