File size: 4,147 Bytes
46fdf2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9629bd9
 
 
 
 
46fdf2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9629bd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46fdf2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import torch
import gradio as gr
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms

from pipeline.resnet_csra import ResNet_CSRA
from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA
from pipeline.dataset import DataSet
from torchvision.transforms import transforms
from utils.evaluation.eval import voc_classes, wider_classes, coco_classes, class_dict

# Fix the global RNG seed so repeated launches behave the same way.
torch.manual_seed(0)

# Device: use the GPU if available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    # Request deterministic cuDNN kernels alongside the fixed seed.
    torch.backends.cudnn.deterministic = True
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Where the checkpoint lives on disk and where to fetch it from when missing.
WEIGHTS_PATH = "./models/msl_c_voc.pth"
WEIGHTS_URL = (
    "https://github.com/hasibzunair/msl-recognition/releases/download/"
    "v1.0-models/msl_c_voc.pth"
)

# Make directories without spawning a shell; an existing directory is fine.
os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)

# Get model weights
if not os.path.exists(WEIGHTS_PATH):
    # Downloading through torch.hub removes the dependency on a `wget` binary
    # and raises on a failed download instead of silently carrying on.
    torch.hub.download_url_to_file(WEIGHTS_URL, WEIGHTS_PATH)

# Load model
model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
# mean=0 / std=1 makes this normalisation a no-op on the tensor values;
# it is kept as a named transform because inference() composes it.
normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
model.to(DEVICE)
print(f"Loading weights from {WEIGHTS_PATH}")
# map_location remaps the stored tensors onto DEVICE, so a checkpoint that was
# saved from a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))

# Inference!
def inference(img_path):
    """Run the multi-label classifier on one image and name the hits.

    Args:
        img_path: Path to an image file that PIL can open.

    Returns:
        A list of strings taken from ``class_dict["voc07"]``, one for every
        class whose sigmoid score is above 0.5. The list is empty when no
        class passes the threshold.
    """
    # Decode inside a context manager so the file handle is released as soon
    # as the pixels have been converted.
    with Image.open(img_path) as raw:
        image = raw.convert("RGB")

    # Image pre-process: fixed 448x448 resize, tensor conversion, normalise.
    preprocess = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        normalize,
    ])
    # Add the batch dimension the model expects and move to the model device.
    batch = preprocess(image).unsqueeze(0).to(DEVICE)

    # Predict
    model.eval()
    with torch.no_grad():
        scores = torch.sigmoid(model(batch).squeeze(0))

    # Indices of every class scoring above the 0.5 decision threshold,
    # converted to plain Python ints for the label lookup.
    hits = torch.where(scores > 0.5)[0].tolist()
    labels = class_dict["voc07"]
    return [str(labels[k]) for k in hits]


# Input component: the uploaded image is handed to `inference` as a file path
# (type="filepath"), which matches the `Image.open(img_path)` call there.
# NOTE(review): `gr.inputs` is a legacy Gradio namespace — confirm it is still
# available in the Gradio version this app is deployed with.
inputs = gr.inputs.Image(type="filepath", label="Input Image")

# Text shown on the demo page: title, HTML description and footer article.
# NOTE(review): the description still contains the placeholder href
# "ADD_PAPER_LINK" — replace it with the real paper URL.
title = "Learning to Recognize Occluded and Small Objects with Partial Inputs"
description = """
Try this demo for <a href="https://github.com/hasibzunair/msl-recognition">MSL</a>,
introduced in <a href="ADD_PAPER_LINK">Learning to Recognize Occluded and Small Objects with Partial Inputs</a>.
\n\n MSL aims to explicitly focus on context from neighbouring regions around 
objects. Further, this also enables to learn a distribution of association across classes. Ideally to handle situations in-the-wild where only part of some object class is visible, but where us humans might readily use context to infer the classes presence.

You can use this demo to get the a list of objects present in your images.

To use it, simply upload an image of your choice and hit submit. You will get one or more names of objects present
in your images from this list: ("aeroplane", "bicycle", "bird", "boat", "bottle",
           "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant",
           "sheep", "sofa", "train", "tvmonitor")

\n\n<a href="https://hasibzunair.github.io/msl-recognition/">Project Page</a>
"""
# NOTE(review): the arXiv id in this href (1512.03385) does not look like it
# belongs to the paper named in the link text — confirm and update the URL.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1512.03385' target='_blank'>Learning to Recognize Occluded and Small Objects with Partial Inputs</a> | <a href='https://github.com/hasibzunair/msl-recognition' target='_blank'>Github Repo</a></p>"

# NOTE(review): this rebinds `voc_classes`, which is also imported from
# utils.evaluation.eval at the top of the file, and the name is not used in
# the rest of this section (inference() reads class_dict["voc07"] instead).
# Presumably safe to drop — verify nothing else relies on it.
voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle",
           "bus", "car", "cat", "chair", "cow", "diningtable",
           "dog", "horse", "motorbike", "person", "pottedplant",
           "sheep", "sofa", "train", "tvmonitor")

# Build the web UI around `inference` and start serving it. The list returned
# by `inference` is rendered through a plain text output, and the three
# example paths are offered as ready-made inputs on the page.
gr.Interface(inference, 
            inputs, 
            outputs="text", 
            examples=["demo_images/000001.jpg", "demo_images/000006.jpg", "demo_images/000009.jpg"], 
            title=title, 
            description=description, 
            article=article,
            analytics_enabled=False).launch()