Spaces:
Runtime error
Runtime error
File size: 2,517 Bytes
76f797b cbe6be9 76f797b cbe6be9 76f797b cbe6be9 76f797b cbe6be9 76f797b cbe6be9 76f797b cbe6be9 76f797b 92bdcd3 d7d76c4 cbe6be9 76f797b cbe6be9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import codecs
import os
import urllib.request

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from unetplusplus import NestedUNet
# Seed torch's RNG so repeated runs are reproducible.
torch.manual_seed(0)
if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels when a GPU is present.
    torch.backends.cudnn.deterministic = True

# Device: inference is pinned to the CPU even when CUDA is available.
DEVICE = "cpu"
print(DEVICE)

# Load color map used by `inference` to colorize predicted class indices.
cmap = np.load("cmap.npy")

# Make directories. `exist_ok=True` makes this idempotent; the previous
# `mkdir` shell call printed an error on every restart once the folder existed.
os.makedirs("./models", exist_ok=True)

# Get model weights (downloaded once, then reused from disk).
_WEIGHTS_PATH = "./models/masksupnyu39.31d.pth"
_WEIGHTS_URL = (
    "https://github.com/hasibzunair/masksup-segmentation/releases/download/"
    "v0.1/masksupnyu39.31iou.pth"
)
if not os.path.exists(_WEIGHTS_PATH):
    # Download to a temporary name and rename only after success. `wget -O`
    # creates its output file even when the download fails, which left an
    # empty file that this existence check would then accept forever.
    # urlretrieve also removes the dependency on an external `wget` binary
    # and raises on HTTP errors instead of failing silently.
    _tmp_path = _WEIGHTS_PATH + ".part"
    urllib.request.urlretrieve(_WEIGHTS_URL, _tmp_path)
    os.replace(_tmp_path, _WEIGHTS_PATH)
# Load model: build the 40-class NestedUNet, restore the released weights
# (deserialized onto the CPU first), then move the network to DEVICE and
# switch it to evaluation mode for inference.
model = NestedUNet(num_classes=40)
checkpoint = torch.load("./models/masksupnyu39.31d.pth", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint)
# `Module.eval()` returns the module itself, so the two calls can be chained.
model = model.to(DEVICE).eval()
# Main inference function
def inference(img_path):
image = Image.open(img_path).convert("RGB")
transforms_image = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.CenterCrop((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
image = transforms_image(image)
image = image[None, :] # batch dimension
# Predict
with torch.no_grad():
output = torch.sigmoid(model(image.to(DEVICE).float()))
output = (
torch.softmax(output, dim=1)
.argmax(dim=1)[0]
.float()
.cpu()
.numpy()
.astype(np.uint8)
)
pred = cmap[output]
return pred
# App
title = "Masked Supervised Learning for Semantic Segmentation"
# Read the HTML description with a context manager so the file is closed.
with open("description.html", "r", encoding="utf-8") as _description_file:
    description = _description_file.read()
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>"

# `gr.inputs` / `gr.outputs` and `launch(enable_queue=...)` were removed in
# newer Gradio releases; `gr.Image` and `Interface.queue()` are the supported
# replacements. `allow_flagging` takes a string mode, not a boolean.
gr.Interface(
    inference,
    gr.Image(type="filepath", label="Input Image"),
    gr.Image(type="numpy", label="Predicted Output"),
    examples=[
        "./sample_images/a.png",
        "./sample_images/b.png",
        "./sample_images/c.png",
        "./sample_images/d.png",
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
    analytics_enabled=False,
).queue().launch(debug=True)
|