Spaces:
Runtime error
Runtime error
Orpheous1
commited on
Commit
·
5dc90b6
1
Parent(s):
9d9aad0
probs
Browse files- app.py +21 -2
- code/datamodules/__pycache__/__init__.cpython-38.pyc +0 -0
- code/datamodules/__pycache__/base.cpython-38.pyc +0 -0
- code/datamodules/__pycache__/transformations.cpython-38.pyc +0 -0
- code/models/__pycache__/interpretation.cpython-39.pyc +0 -0
- code/models/interpretation.py +2 -1
- requirements.txt +2 -1
app.py
CHANGED
@@ -9,6 +9,8 @@ from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
import torch
|
|
|
|
|
12 |
|
13 |
# Load Vision Transformer
|
14 |
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
|
@@ -59,19 +61,36 @@ def get_mask(image, model_name: str):
|
|
59 |
dm_image = feature_extractor(image).unsqueeze(0)
|
60 |
dm_out = diffmask_model.get_mask(dm_image)
|
61 |
mask = dm_out["mask"][0].detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
pred = dm_out["pred_class"][0].detach()
|
63 |
pred = diffmask_model.model.config.id2label[pred.item()]
|
64 |
|
65 |
masked_img = draw_mask(image, mask)
|
66 |
heatmap = draw_heatmap(image, mask)
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
# Launch demo interface
|
70 |
gr.Interface(
|
71 |
get_mask,
|
72 |
inputs=[gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
|
73 |
gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"])],
|
74 |
-
outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")
|
|
|
75 |
title="Vision DiffMask Demo",
|
76 |
live=True,
|
77 |
).launch()
|
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
import torch
|
12 |
+
import seaborn as sns
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
|
15 |
# Load Vision Transformer
|
16 |
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
|
|
|
61 |
dm_image = feature_extractor(image).unsqueeze(0)
|
62 |
dm_out = diffmask_model.get_mask(dm_image)
|
63 |
mask = dm_out["mask"][0].detach()
|
64 |
+
logits = dm_out["logits"][0].detach().softmax(dim=-1)
|
65 |
+
logits_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
|
66 |
+
# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
|
67 |
+
# sns.displot(logits_orig.cpu().numpy().flatten(), kind="kde", label="Original", ax=ax)
|
68 |
+
top5logits_orig = logits_orig.topk(5, dim=-1)
|
69 |
+
idx = top5logits_orig.indices
|
70 |
+
# keep the top 5 classes from the indices of the top 5 logits
|
71 |
+
top5logits_orig = top5logits_orig.values
|
72 |
+
top5logits = logits[idx]
|
73 |
+
|
74 |
pred = dm_out["pred_class"][0].detach()
|
75 |
pred = diffmask_model.model.config.id2label[pred.item()]
|
76 |
|
77 |
masked_img = draw_mask(image, mask)
|
78 |
heatmap = draw_heatmap(image, mask)
|
79 |
+
orig_probs = {diffmask_model.model.config.id2label[i]: top5logits_orig[i].item() for i in range(5)}
|
80 |
+
pred_probs = {diffmask_model.model.config.id2label[i]: top5logits[i].item() for i in range(5)}
|
81 |
+
|
82 |
+
return np.hstack((masked_img, heatmap)), pred, orig_probs, pred_probs
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
|
87 |
# Launch demo interface
|
88 |
gr.Interface(
|
89 |
get_mask,
|
90 |
inputs=[gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
|
91 |
gr.inputs.Dropdown(["DiffMask-CiFAR-10", "DiffMask-ImageNet"])],
|
92 |
+
outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction"),
|
93 |
+
gr.Label(label="Original Probabilities"), gr.Label(label="Predicted Probabilities")],
|
94 |
title="Vision DiffMask Demo",
|
95 |
live=True,
|
96 |
).launch()
|
code/datamodules/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (388 Bytes). View file
|
|
code/datamodules/__pycache__/base.cpython-38.pyc
ADDED
Binary file (4.57 kB). View file
|
|
code/datamodules/__pycache__/transformations.cpython-38.pyc
ADDED
Binary file (1.88 kB). View file
|
|
code/models/__pycache__/interpretation.cpython-39.pyc
CHANGED
Binary files a/code/models/__pycache__/interpretation.cpython-39.pyc and b/code/models/__pycache__/interpretation.cpython-39.pyc differ
|
|
code/models/interpretation.py
CHANGED
@@ -277,7 +277,8 @@ class ImageInterpretationNet(pl.LightningModule):
|
|
277 |
mask = F.interpolate(mask, scale_factor=S)
|
278 |
mask = mask.reshape(B, H, W)
|
279 |
|
280 |
-
return {"mask": mask, "kl_div": kl_div, "pred_class": pred_class
|
|
|
281 |
|
282 |
def forward(self, x: Tensor) -> Tensor:
|
283 |
return self.model(x).logits
|
|
|
277 |
mask = F.interpolate(mask, scale_factor=S)
|
278 |
mask = mask.reshape(B, H, W)
|
279 |
|
280 |
+
return {"mask": mask, "kl_div": kl_div, "pred_class": pred_class,
|
281 |
+
"logits": logits, "logits_orig": logits_orig}
|
282 |
|
283 |
def forward(self, x: Tensor) -> Tensor:
|
284 |
return self.model(x).logits
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ pytorch_lightning
|
|
4 |
torch
|
5 |
torchvision
|
6 |
transformers
|
7 |
-
|
|
|
|
4 |
torch
|
5 |
torchvision
|
6 |
transformers
|
7 |
+
seaborn
|
8 |
+
matplotlib
|