venkyyuvy commited on
Commit
2efd69c
·
1 Parent(s): 61e9251

init commit

Browse files
README.md CHANGED
@@ -10,4 +10,6 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
10
  license: mit
11
  ---
12
 
13
+ This application classifies a given image into one of the ten classes of the CIFAR-10 dataset. It also shows sample images from the test dataset that the model misclassified.
14
+
15
+ The app also provides the option of visualizing the GradCAM (Gradient-weighted Class Activation Mapping) output for model explainability. The user can choose which layer is used for the visualization.
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ import torch
3
+ from torchvision import transforms
4
+ import gradio as gr
5
+ from pytorch_grad_cam import GradCAM
6
+ from pytorch_grad_cam.utils.image import show_cam_on_image
7
+ import gradio as gr
8
+ import model
9
+ from data_loader import CIFAR_CLASS_LABELS, TEST_TRANSFORM
10
+ import matplotlib
11
+ matplotlib.use('agg')
12
+ from matplotlib import pyplot as plt
13
+
14
# Build the Lightning wrapper and restore the trained weights on CPU.
resnet_18 = model.LitResnet()
state_dict = torch.load("saved_model.pth", map_location=torch.device('cpu'))
resnet_18.load_state_dict(state_dict)

# The app only needs the wrapped network, not the training wrapper.
resnet_18_model = resnet_18.model
# Modules start in training mode; switch to eval so the BatchNorm layers use
# their stored running statistics instead of single-image batch statistics.
resnet_18_model.eval()

# Short display names; position i labels output index i of the network.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
21
+
22
+
23
def inference(input_img, n_top_classes,
              apply_gradcam, transparency=0.5,
              target_layer_number=-1):
    """Classify an image and optionally build a GradCAM overlay for it.

    Args:
        input_img: Image array from the Gradio image input. It is passed to
            TEST_TRANSFORM and, for GradCAM, divided by 255 for the overlay.
        n_top_classes: Number of highest-probability classes to report.
        apply_gradcam: When truthy, the GradCAM overlay is also computed.
        transparency: Forwarded to show_cam_on_image as ``image_weight``.
        target_layer_number: Index into ``resnet_18_model.layer3`` selecting
            the block GradCAM is attached to.

    Returns:
        A tuple of two ``gr.update`` objects: the class-to-probability
        mapping for the label output, and either the visible overlay image
        or an instruction to hide the image output.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        # Without this guard the transform fails with an opaque error.
        raise gr.Error("Please provide an input image.")

    # Slider values are used for slicing/indexing, which requires ints.
    n_top_classes = int(n_top_classes)
    target_layer_number = int(target_layer_number)

    org_img = input_img
    batch = TEST_TRANSFORM(image=input_img)['image'].unsqueeze(0)

    # BatchNorm must use its running statistics for a single-image batch.
    resnet_18_model.eval()
    # The prediction pass needs no autograd graph; GradCAM below runs its own
    # forward/backward pass.
    with torch.no_grad():
        logits = resnet_18_model(batch)
    probabilities = torch.softmax(logits.flatten(), dim=0).tolist()

    ranked = sorted(zip(classes, probabilities),
                    key=operator.itemgetter(1), reverse=True)
    confidences = dict(ranked[:n_top_classes])

    if not apply_gradcam:
        return (gr.update(value=confidences),
                gr.update(visible=False))

    target_layers = [resnet_18_model.layer3[target_layer_number]]
    cam = GradCAM(model=resnet_18_model, target_layers=target_layers, use_cuda=False)
    # targets=None leaves the choice of target class to the GradCAM library.
    grayscale_cam = cam(input_tensor=batch, targets=None)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(
        org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return (gr.update(value=confidences),
            gr.update(value=visualization, visible=True))
47
+
48
def show_misclasif(see_misclassif, n_images):
    """Render a two-row grid of stored misclassified images to ``plot.png``.

    Args:
        see_misclassif: Checkbox state; nothing is rendered when falsy.
        n_images: Number of images requested; the grid has two rows of
            ``n_images // 2`` columns.

    Returns:
        A ``gr.update`` that shows the saved plot, or one that hides the
        image output when ``see_misclassif`` is falsy.
    """
    if not see_misclassif:
        # Explicit result instead of falling off the end and returning None.
        return gr.update(visible=False)

    n_images = int(n_images)
    # NOTE(review): torch.load unpickles the file; only safe because the file
    # ships with the app.
    subset = torch.load('misclassified_images.pt')
    images, actuals, preds = torch.tensor(subset[0])[:20], subset[1], subset[2]

    nrows = 2
    ncols = max(1, n_images // 2)  # subplots() rejects a zero column count
    fig, axes = plt.subplots(nrows, ncols, figsize=(n_images, 4))
    fig.suptitle('misclassified images', weight='bold', size=10)

    # zip stops at the shortest input, so at most nrows * ncols images are drawn.
    for img, actual, pred, ax in zip(images, actuals, preds, axes.ravel()):
        ax.imshow(img)
        ax.set_title(
            f'Prediction={CIFAR_CLASS_LABELS[pred]}\n Actual={CIFAR_CLASS_LABELS[actual]}',
            fontsize=8)
        ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
        ax.axis('off')

    image_path = "plot.png"
    fig.savefig(image_path)
    # Close this figure specifically rather than whichever one is "current".
    plt.close(fig)
    return gr.update(value=image_path, visible=True)
69
+
70
+
71
# Gradio UI: classification controls and outputs in the first row, the
# misclassified-image viewer in the second row.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            n_top_classes = gr.Slider(maximum=10, minimum=1, value=3, step=1,
                                      label="Top n classes to show", interactive=True)
            require_gradcam = gr.Checkbox(label="Apply GradCAM",
                                          info="Do you want see the GRAD-CAM visualization")
            opacity_gradcam = gr.Slider(0, 1, value=0.5,
                                        label="Opacity of GradCAM")
            layer_gradcam = gr.Slider(-2, -1, value=-2, step=1,
                                      label="Which Layer?")
            submit = gr.Button("Submit")
        with gr.Column():
            pred_classes = gr.Label()
            # Hidden until inference() returns a GradCAM overlay.
            grad_cam = gr.Image(shape=(32, 32),
                                label="Output", visible=False)\
                .style(width=128, height=128)
    with gr.Row():
        with gr.Column():
            see_misclassif = gr.Checkbox(label="View misclassified images",
                                         info="Do you want see the miscassified images in the test dataset")
            # These three start hidden; the checkbox handler reveals the
            # slider and button, and Render reveals the image.
            n_misclasif = gr.Slider(maximum=20, minimum=2, value=10, step=2,
                                    label="Number of misclassified images to show",
                                    interactive=True, visible=False)
            render = gr.Button("Render", visible=False)
            misclasif_display = gr.Image(visible=False)

    # The original called n_top_classes.postprocess(...) here and discarded
    # the result; that statement had no effect and has been removed.
    submit.click(inference,
                 inputs=[input_image, n_top_classes, require_gradcam,
                         opacity_gradcam, layer_gradcam],
                 outputs=[pred_classes, grad_cam]
                 )

    def turn_on_misclasif(see_misclassif):
        """Show or hide the viewer controls to match the checkbox state.

        Returns updates for (slider, render button, image). The image is
        always hidden here; it only becomes visible after Render is clicked.
        """
        controls_visible = bool(see_misclassif)
        return (gr.update(visible=controls_visible),
                gr.update(visible=controls_visible),
                gr.update(visible=False))

    see_misclassif.change(turn_on_misclasif, see_misclassif, [n_misclasif, render, misclasif_display])
    render.click(show_misclasif, [see_misclassif, n_misclasif], misclasif_display)

    # Each example supplies: image path, top-n value, GradCAM checkbox state.
    gr.Examples(
        examples=[
            ["examples/truck.jpg", 3, True],
            ["examples/ship.jpg", 3, True],
            ["examples/dog.jpg", 3, True],
            ["examples/cat.jpg", 3, True],
            ["examples/horse.jpg", 3, True],
            ["examples/airplane.jpg", 3, True],
            ["examples/parrot.jpg", 3, True],
            ["examples/automobile.jpg", 3, True],
            ["examples/deer.jpg", 3, True],
            ["examples/frog.jpg", 3, True],
        ],
        inputs=[input_image, n_top_classes, require_gradcam],
        outputs=[pred_classes, grad_cam],
        fn=inference,
    )
demo.launch()
data_loader.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import datasets
2
+ import albumentations as A
3
+ from albumentations.pytorch import ToTensorV2
4
+
5
+
6
# Per-channel normalisation constants used by both transforms below
# (presumably the CIFAR-10 training-set statistics — TODO confirm).
NORM_DATA_MEAN = (0.49139968, 0.48215841, 0.44653091)
NORM_DATA_STD = (0.24703223, 0.24348513, 0.26158784)

# Human-readable class names, indexed by integer class label.
CIFAR_CLASS_LABELS = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Training pipeline: normalise, random horizontal flip, then pad to 40x40,
# cut out one 16x16 hole and random-crop back to 32x32, and finally convert
# to a tensor.
# NOTE(review): CoarseDropout runs after Normalize but fills the hole with the
# un-normalised NORM_DATA_MEAN — confirm this is intended.
TRAIN_TRANSFORM = A.Compose([
    A.Normalize(
        mean=NORM_DATA_MEAN,
        std=NORM_DATA_STD,
    ),
    A.HorizontalFlip(),
    A.Compose([
        A.PadIfNeeded(min_height=40, min_width=40, p=1.0),
        A.CoarseDropout(max_holes=1, max_height=16, max_width=16,
                        min_holes=1, min_height=16, min_width=16,
                        fill_value=NORM_DATA_MEAN, mask_fill_value=None, p=1.0),
        A.RandomCrop(p=1.0, height=32, width=32)
    ]),
    ToTensorV2(),
])

# Evaluation/inference pipeline: normalisation and tensor conversion only.
TEST_TRANSFORM = A.Compose([
    A.Normalize(
        mean=NORM_DATA_MEAN,
        std=NORM_DATA_STD,
    ),
    ToTensorV2(),
])
37
+
38
class CifarAlbumentationsDataset(datasets.CIFAR10):
    """CIFAR10 dataset whose transform is invoked Albumentations-style.

    The transform is called as ``transform(image=img)`` and the result is
    read from the returned mapping's ``'image'`` key, instead of being
    called with the image positionally.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``datasets.CIFAR10``."""
        super().__init__(*args, **kwargs)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        The image is the transformed sample when a transform is configured;
        otherwise the stored sample is returned unchanged (previously this
        case raised ``UnboundLocalError``).
        """
        img, target = self.data[idx], self.targets[idx]
        if self.transform:
            img = self.transform(image=img)['image']
        return img, target
47
+
48
+
examples/airplane.jpg ADDED
examples/automobile.jpg ADDED
examples/cat.jpg ADDED
examples/deer.jpg ADDED
examples/dog.jpg ADDED
examples/frog.jpg ADDED
examples/horse.jpg ADDED
examples/ship.jpg ADDED
examples/truck.jpg ADDED
misclassified_images.pt ADDED
Binary file (450 kB). View file
 
model.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ https://github.com/kuangliu/pytorch-cifar
3
+
4
+ ResNet in PyTorch.
5
+
6
+ For Pre-activation ResNet, see 'preact_resnet.py'.
7
+
8
+ Reference:
9
+ [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
10
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
11
+ '''
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ from torch_lr_finder import LRFinder
16
+
17
+
18
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The shortcut is the identity unless the stride or channel count changes,
    in which case it is a strided 1x1 convolution followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is only needed when the residual's shape would not
        # match the main path's output.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut of ``x``, then relu."""
        residual = self.shortcut(x)
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)
44
+
45
class ResNet(nn.Module):
    """ResNet backbone: a 3x3 stem, four residual stages and a linear head.

    Args:
        block: Block class; called as ``block(in_planes, planes, stride)``
            and expected to expose an ``expansion`` attribute.
        num_blocks: Sequence with the number of blocks for each of the four
            stages.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running input-channel count; _make_layer advances it per block.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, 4x4 average pooling and the head."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        flat = out.view(out.size(0), -1)
        return self.linear(flat)
77
+
78
+
79
def ResNet18():
    """Return a ResNet built from BasicBlock with two blocks in each stage."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage)
81
+
82
+ import torch.nn as nn
83
+ from torch.optim.lr_scheduler import OneCycleLR
84
+ from torch.utils.data import DataLoader
85
+ import matplotlib.pyplot as plt
86
+
87
+ from data_loader import CifarAlbumentationsDataset,\
88
+ CIFAR_CLASS_LABELS, TRAIN_TRANSFORM, TEST_TRANSFORM
89
+ import model
90
+ from torch_lr_finder import LRFinder
91
+
92
+ import torch
93
+ import torch.nn as nn
94
+ import torch.nn.functional as F
95
+ from pytorch_lightning import LightningModule
96
+ from torch.optim.lr_scheduler import OneCycleLR
97
+ from torchmetrics.functional import accuracy
98
+
99
class LitResnet(LightningModule):
    """LightningModule that trains and evaluates ResNet18 on CIFAR-10.

    Args:
        lr: Initial learning rate handed to the SGD optimizer.
        batch_size: Batch size used by every dataloader.
    """

    def __init__(self, lr=0.03, batch_size=512):
        super().__init__()

        # Makes lr and batch_size available as self.hparams.*.
        self.save_hyperparameters()
        self.criterion = nn.CrossEntropyLoss()
        self.model = ResNet18()

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss and accuracy of one batch; return the loss."""
        x, y = batch
        output = self.forward(x)
        loss = self.criterion(output, y)
        self.log("train_loss", loss)
        acc = accuracy(torch.argmax(output, dim=1),
                       y, 'multiclass', num_classes=10)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy of one batch; log them when ``stage`` is set."""
        x, y = batch
        output = self.forward(x)
        loss = self.criterion(output, y)
        preds = torch.argmax(output, dim=1)
        acc = accuracy(preds, y, 'multiclass', num_classes=10)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    # todo
    # change the default for num_iter
    def lr_finder(self, optimizer, num_iter=200,):
        """Run an exponential LR range test and return the suggested rate.

        The finder is reset before returning.
        """
        finder = LRFinder(self, optimizer, self.criterion,
                          device=self.device)
        finder.range_test(
            self.train_dataloader(), end_lr=1,
            num_iter=num_iter, step_mode='exp',
        )
        _, suggested_lr = finder.plot(suggest_lr=True)
        # todo
        # how to log maplotlib images
        # self.logger.experiment.add_image('lr_finder', plt.gcf(), 0)
        finder.reset()
        return suggested_lr

    def configure_optimizers(self):
        """Create SGD plus a per-step OneCycleLR peaking at the suggested LR."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        suggested_lr = self.lr_finder(optimizer)
        steps_per_epoch = len(self.train_dataloader())
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer, max_lr=suggested_lr,
                steps_per_epoch=steps_per_epoch,
                epochs=self.trainer.max_epochs,
                # Warm-up fraction equal to 5 epochs out of max_epochs.
                # NOTE(review): exceeds 1 when max_epochs < 5 — confirm
                # callers always train for at least 5 epochs.
                pct_start=5/self.trainer.max_epochs,
                three_phase=False,
                div_factor=100,
                final_div_factor=100,
                anneal_strategy='linear',
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self, data_path='../data'):
        """Instantiate both splits with ``download=True`` so the data is on disk."""
        CifarAlbumentationsDataset(
            data_path, train=True, download=True)
        CifarAlbumentationsDataset(
            data_path, train=False, download=True)

    def setup(self, stage=None, data_dir='../data'):
        """Create the datasets needed for ``stage``.

        The training set is built for "fit" (or when no stage is given). The
        evaluation set is built for every stage so the validation and test
        dataloaders always have a dataset to read from.
        """
        if stage == "fit" or stage is None:
            self.train_dataset = CifarAlbumentationsDataset(data_dir, train=True, transform=TRAIN_TRANSFORM)
        self.test_dataset = CifarAlbumentationsDataset(data_dir, train=False, transform=TEST_TRANSFORM)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size,
                          shuffle=True, pin_memory=True)  # num_workers=4,

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size,
                          shuffle=False, pin_memory=True)

    def test_dataloader(self):
        """Serve the same held-out split as validation so test_step has data."""
        return self.val_dataloader()
199
+
requirements.txt ADDED
File without changes