import os

import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from torchvision import transforms

from pipeline.resnet_csra import ResNet_CSRA
from utils.evaluation.eval import class_dict

# Fix the random seed and make cuDNN deterministic so repeated runs give the same predictions.
torch.manual_seed(0)

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

# Run inference on the GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Make sure the checkpoint directory exists.
os.makedirs("./models", exist_ok=True)

# Download the pretrained MSL VOC checkpoint on first run.
if not os.path.exists("./models/msl_c_voc.pth"):
    os.system(
        "wget -O ./models/msl_c_voc.pth https://github.com/hasibzunair/msl-recognition/releases/download/v1.0-models/msl_c_voc.pth"
    )
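
# The wget call above assumes wget is installed and on PATH. A pure-Python
# alternative (a sketch, not part of the original script) would be:
#
#   torch.hub.download_url_to_file(
#       "https://github.com/hasibzunair/msl-recognition/releases/download/v1.0-models/msl_c_voc.pth",
#       "./models/msl_c_voc.pth",
#   )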

# Build the CSRA (class-specific residual attention) model with a ResNet backbone
# for the 20 PASCAL VOC classes, then load the MSL-trained weights.
model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
# Identity normalization: inputs stay in the [0, 1] range produced by ToTensor().
normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
model.to(DEVICE)
print("Loading weights from {}".format("./models/msl_c_voc.pth"))
model.load_state_dict(torch.load("./models/msl_c_voc.pth", map_location=DEVICE))


def inference(img_path):
    image = Image.open(img_path).convert("RGB")

    # Preprocess: resize to 448x448, convert to a tensor in [0, 1], apply the (identity) normalization.
    transforms_image = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        normalize,
    ])

    image = transforms_image(image)
    image = image.unsqueeze(0)  # add a batch dimension

    result = []
    model.eval()
    with torch.no_grad():
        image = image.to(DEVICE)
        logit = model(image).squeeze(0)
        logit = nn.Sigmoid()(logit)

    # Report every class whose sigmoid score exceeds 0.5.
    pos = torch.where(logit > 0.5)[0].cpu().numpy()
    for k in pos:
        result.append(str(class_dict["voc07"][k]))
    return result
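
# Quick sanity check outside the Gradio UI (a sketch; assumes one of the repo's
# demo images is available locally):
#
#   preds = inference("demo_images/000001.jpg")
#   print(preds)  # list of VOC class names whose sigmoid score exceeds 0.5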

# Gradio passes the uploaded image to `inference` as a file path.
inputs = gr.Image(type="filepath", label="Input Image")

title = "Learning to Recognize Occluded and Small Objects with Partial Inputs"
description = """
Try this demo for <a href="https://github.com/hasibzunair/msl-recognition">MSL</a>,
introduced in <a href="ADD_PAPER_LINK">Learning to Recognize Occluded and Small Objects with Partial Inputs</a>.
\n\nMSL aims to explicitly focus on context from neighbouring regions around objects. This also enables learning
a distribution of associations across classes, ideally handling in-the-wild situations where only part of an
object class is visible, but where we humans might readily use context to infer the class's presence.

You can use this demo to get a list of the objects present in your image.

To use it, simply upload an image of your choice and hit submit. You will get one or more names of objects present
in your image, drawn from this list: ("aeroplane", "bicycle", "bird", "boat", "bottle",
"bus", "car", "cat", "chair", "cow", "diningtable",
"dog", "horse", "motorbike", "person", "pottedplant",
"sheep", "sofa", "train", "tvmonitor")

\n\n<a href="https://hasibzunair.github.io/msl-recognition/">Project Page</a>
"""
article = "<p style='text-align: center'><a href='https://hasibzunair.github.io/msl-recognition/' target='_blank'>Learning to Recognize Occluded and Small Objects with Partial Inputs</a> | <a href='https://github.com/hasibzunair/msl-recognition' target='_blank'>Github Repo</a></p>"

# PASCAL VOC 2007 class names recognized by the demo.
voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle",
               "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant",
               "sheep", "sofa", "train", "tvmonitor")

gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs="text",
    examples=["demo_images/000001.jpg", "demo_images/000006.jpg", "demo_images/000009.jpg"],
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
).launch()
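
# To serve the demo locally (assuming gradio is installed and the repo's
# demo_images/ folder is present), run this script with Python, e.g.:
#
#   python app.py   # file name assumed; use this script's actual name
#
# Gradio prints a local URL (http://127.0.0.1:7860 by default) to open in a browser.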