nabeelraza commited on
Commit
019483f
·
1 Parent(s): 35bf553

Add: initial code

Browse files
Files changed (3) hide show
  1. explain.py +214 -0
  2. model/load_model.py +14 -0
  3. requirements.txt +10 -0
explain.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ import matplotlib.pyplot as plt
7
+
8
+ from glob import glob
9
+ from PIL import Image
10
+ from model.load_model import get_model
11
+ from torchvision import transforms
12
+
13
+ from pytorch_grad_cam import GradCAM, GuidedBackpropReLUModel
14
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
15
+ from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
16
+
17
+ from ultralytics import YOLO
18
+ from rembg import remove
19
+ import uuid
20
+
21
+
22
+ # Static variables
23
+ model_path = "efficientnet-b0-best.pth"
24
+ model_name = "efficientnet_b0"
25
+ YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
26
+ classes = ["Healthy", "Resistant", "Susceptible"]
27
+ resizing_transforms = transforms.Compose([transforms.CenterCrop(224)])
28
+
29
+
30
+ # Function definitions
31
+ def reproduce(seed=42):
32
+ random.seed(seed)
33
+ np.random.seed(seed)
34
+ torch.manual_seed(seed)
35
+ torch.cuda.manual_seed_all(seed)
36
+ cudnn.deterministic = True
37
+ cudnn.benchmark = False
38
+
39
+
40
+ def get_grad_cam_results(image, transformed_image, class_index=0):
41
+ with GradCAM(model=model, target_layers=target_layers) as cam:
42
+ targets = [ClassifierOutputTarget(class_index)]
43
+ grayscale_cam = cam(
44
+ input_tensor=transformed_image.unsqueeze(0), targets=targets
45
+ )
46
+ grayscale_cam = grayscale_cam[0, :]
47
+
48
+ visualization = show_cam_on_image(
49
+ np.array(image) / 255.0, grayscale_cam, use_rgb=True
50
+ )
51
+ return visualization, grayscale_cam
52
+
53
+
54
+ def get_backpropagation_results(transformed_image, class_index=0):
55
+ transformed_image = transformed_image.unsqueeze(0)
56
+ backpropagation = gbp_model(transformed_image, target_category=class_index)
57
+ bp_deprocessed = deprocess_image(backpropagation)
58
+ return backpropagation, bp_deprocessed
59
+
60
+
61
+ def get_guided_gradcam(image, cam_grayscale, bp):
62
+ cam_mask = np.expand_dims(cam_grayscale, axis=-1)
63
+ cam_mask = np.repeat(cam_mask, 3, axis=-1)
64
+ img = show_cam_on_image(
65
+ np.array(image) / 255.0, deprocess_image(cam_mask * bp), use_rgb=False
66
+ )
67
+ return img
68
+
69
+
70
+ def explain_results(image, class_index=0):
71
+ transformed_image = image_transform(image)
72
+ image = resizing_transforms(image)
73
+
74
+ visualization, cam_mask = get_grad_cam_results(
75
+ image, transformed_image, class_index
76
+ )
77
+ backpropagation, bp_deprocessed = get_backpropagation_results(
78
+ transformed_image, class_index
79
+ )
80
+ guided_gradcam = get_guided_gradcam(image, cam_mask, backpropagation)
81
+
82
+ return visualization, bp_deprocessed, guided_gradcam
83
+
84
+
85
+ def make_prediction_and_explain(image):
86
+ transformed_image = image_transform(image)
87
+ transformed_image = transformed_image.unsqueeze(0)
88
+ model.eval()
89
+ with torch.no_grad():
90
+ output = model(transformed_image)
91
+ output = torch.nn.functional.softmax(output, dim=1)
92
+
93
+ predictions = [round(x, 4) * 100 for x in output[0].tolist()]
94
+ results = {}
95
+
96
+ for i, k in enumerate(classes):
97
+ gradcam, bp_deprocessed, guided_gradcam = explain_results(image, class_index=i)
98
+
99
+ results[k] = {
100
+ "original_image": image,
101
+ "prediction": f"{k} ({predictions[i]}%)",
102
+ "gradcam": gradcam,
103
+ "backpropagation": bp_deprocessed,
104
+ "guided_gradcam": guided_gradcam,
105
+ }
106
+
107
+ return results
108
+
109
+
110
+ def save_explanation_results(res, path):
111
+ fig, ax = plt.subplots(3, 4, figsize=(15, 15))
112
+ for i, (k, v) in enumerate(res.items()):
113
+ ax[i, 0].imshow(v["original_image"])
114
+ ax[i, 0].set_title(f"Original Image (class: {v['prediction']}")
115
+ ax[i, 0].axis("off")
116
+
117
+ ax[i, 1].imshow(v["gradcam"])
118
+ ax[i, 1].set_title("GradCAM")
119
+ ax[i, 1].axis("off")
120
+
121
+ ax[i, 2].imshow(v["backpropagation"])
122
+ ax[i, 2].set_title("Backpropagation")
123
+ ax[i, 2].axis("off")
124
+
125
+ ax[i, 3].imshow(v["guided_gradcam"])
126
+ ax[i, 3].set_title("Guided GradCAM")
127
+ ax[i, 3].axis("off")
128
+
129
+ plt.tight_layout()
130
+ plt.savefig(path, bbox_inches="tight")
131
+ plt.close(fig)
132
+
133
+
134
+ # load stuff
135
+ reproduce()
136
+
137
+ model, image_transform = get_model(model_name)
138
+ model.load_state_dict(torch.load(model_path))
139
+ model.train()
140
+ target_layers = [model.conv_head]
141
+ gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
142
+
143
+ yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
144
+
145
+
146
+ def get_results(img_path=None, img_for_testing=None):
147
+ if img_path is None and img_for_testing is None:
148
+ raise ValueError("Either img_path or img_for_testing should be provided.")
149
+
150
+ if img_path is not None:
151
+ results = yolo_model(img_path)
152
+ image = Image.open(img_path)
153
+
154
+ if img_for_testing is not None:
155
+ results = yolo_model(img_for_testing)
156
+ image = Image.fromarray(img_for_testing)
157
+
158
+ result_paths = []
159
+
160
+ for i, result in enumerate(results):
161
+ unique_id = uuid.uuid4().hex
162
+ save_path = f"/tmp/with-white-bg-result-{unique_id}.png"
163
+ bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
164
+ bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
165
+
166
+ bbox_image = remove(bbox_image).convert("RGB")
167
+
168
+ res = make_prediction_and_explain(bbox_image)
169
+ save_explanation_results(res, save_path)
170
+
171
+ result_paths.append(save_path)
172
+
173
+ return result_paths
174
+
175
+
176
+ if __name__ == "__main__":
177
+ # Actual logic
178
+ reproduce()
179
+
180
+ model, image_transform = get_model(model_name)
181
+ model.load_state_dict(torch.load(model_path))
182
+ model.train()
183
+ target_layers = [model.conv_head]
184
+ gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
185
+
186
+ yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
187
+
188
+ for IMAGE_PATH in glob("samples/*"):
189
+ start = time.perf_counter()
190
+
191
+ results = yolo_model(IMAGE_PATH)
192
+ image = Image.open(IMAGE_PATH)
193
+
194
+ for i, result in enumerate(results):
195
+ save_path = IMAGE_PATH.replace(
196
+ "samples/", f"sample-results/with-white-bg-result-{i:02d}-"
197
+ )
198
+ bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
199
+ bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
200
+
201
+ bbox_image = remove(bbox_image).convert("RGB")
202
+ bbox_image = Image.fromarray(
203
+ np.where(
204
+ np.array(bbox_image) == [0, 0, 0],
205
+ [255, 255, 255],
206
+ np.array(bbox_image),
207
+ ).astype(np.uint8)
208
+ )
209
+
210
+ res = make_prediction_and_explain(bbox_image)
211
+ save_explanation_results(res, save_path)
212
+
213
+ end = time.perf_counter() - start
214
+ print(f"Completed in {end}s")
model/load_model.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ from timm.data import resolve_data_config
3
+ from timm.data.transforms_factory import create_transform
4
+
5
+
6
+ CLASSES = ["Healthy", "Resistant", "Susceptible"]
7
+
8
+
9
+ def get_model(model_name):
10
+ model = timm.create_model(model_name, pretrained=True, num_classes=len(CLASSES))
11
+ config = resolve_data_config({}, model=model)
12
+ image_transform = create_transform(**config)
13
+
14
+ return model, image_transform
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ huggingface-hub
5
+ transformers
6
+ accelerate
7
+ wandb
8
+ seaborn
9
+ matplotlib
10
+ ultralytics