nabeelraza commited on
Commit
af5b380
·
1 Parent(s): 8f8af36

updated: pipeline

Browse files
Files changed (3) hide show
  1. app.py +7 -4
  2. explain.py +12 -11
  3. requirements.txt +2 -1
app.py CHANGED
@@ -4,10 +4,10 @@ from PIL import Image
4
  from explain import get_results, reproduce
5
 
6
 
7
- def classify_and_explain(image, object_detection=False):
8
  reproduce()
9
  # This function will classify the image and return a list of image paths
10
- list_of_images = get_results(img_for_testing=image, od=object_detection)
11
  return list_of_images
12
 
13
 
@@ -23,8 +23,8 @@ def get_examples():
23
  ]
24
 
25
  return [
26
- [Image.open(i), True] for i in glob("samples/*") if i not in od_off_examples
27
- ] + [[Image.open(i), False] for i in glob("samples/*") if i in od_off_examples]
28
 
29
 
30
  demo = gr.Interface(
@@ -34,6 +34,9 @@ demo = gr.Interface(
34
  gr.Checkbox(
35
  label="Extract Leaves", info="What to extract leafs before classification"
36
  ),
 
 
 
37
  ],
38
  outputs="gallery",
39
  examples=get_examples(),
 
4
  from explain import get_results, reproduce
5
 
6
 
7
+ def classify_and_explain(image, object_detection=False, remove_bg=False):
8
  reproduce()
9
  # This function will classify the image and return a list of image paths
10
+ list_of_images = get_results(img_for_testing=image, od=object_detection, remove_bg=remove_bg)
11
  return list_of_images
12
 
13
 
 
23
  ]
24
 
25
  return [
26
+ [Image.open(i), True, False] for i in glob("samples/*") if i not in od_off_examples
27
+ ] + [[Image.open(i), False, False] for i in glob("samples/*") if i in od_off_examples]
28
 
29
 
30
  demo = gr.Interface(
 
34
  gr.Checkbox(
35
  label="Extract Leaves", info="What to extract leafs before classification"
36
  ),
37
+ gr.Checkbox(
38
+ label="Remove Background", info="Remove the background and make it white"
39
+ ),
40
  ],
41
  outputs="gallery",
42
  examples=get_examples(),
explain.py CHANGED
@@ -16,7 +16,7 @@ from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
16
 
17
  from ultralytics import YOLO
18
 
19
- # from rembg import remove
20
  import uuid
21
 
22
 
@@ -25,7 +25,7 @@ model_path = "efficientnet-b2.pth"
25
  model_name = "efficientnet_b2"
26
  YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
27
  classes = ["Healthy", "Resistant", "Susceptible"]
28
- resizing_transforms = transforms.Compose([transforms.CenterCrop(224)])
29
 
30
 
31
  # Function definitions
@@ -141,7 +141,7 @@ gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
141
  yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
142
 
143
 
144
- def get_results(img_path=None, img_for_testing=None, od=False):
145
  if img_path is None and img_for_testing is None:
146
  raise ValueError("Either img_path or img_for_testing should be provided.")
147
 
@@ -161,14 +161,15 @@ def get_results(img_path=None, img_for_testing=None, od=False):
161
  bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
162
  bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
163
 
164
- # bbox_image = remove(bbox_image).convert("RGB")
165
- # bbox_image = Image.fromarray(
166
- # np.where(
167
- # np.array(bbox_image) == [0, 0, 0],
168
- # [255, 255, 255],
169
- # np.array(bbox_image),
170
- # ).astype(np.uint8)
171
- # )
 
172
 
173
  res = make_prediction_and_explain(bbox_image)
174
  save_explanation_results(res, save_path)
 
16
 
17
  from ultralytics import YOLO
18
 
19
+ from rembg import remove
20
  import uuid
21
 
22
 
 
25
  model_name = "efficientnet_b2"
26
  YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
27
  classes = ["Healthy", "Resistant", "Susceptible"]
28
+ resizing_transforms = transforms.Compose([transforms.CenterCrop(256)])
29
 
30
 
31
  # Function definitions
 
141
  yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
142
 
143
 
144
+ def get_results(img_path=None, img_for_testing=None, od=False, remove_bg=False):
145
  if img_path is None and img_for_testing is None:
146
  raise ValueError("Either img_path or img_for_testing should be provided.")
147
 
 
161
  bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
162
  bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
163
 
164
+ if remove_bg:
165
+ bbox_image = remove(bbox_image).convert("RGB")
166
+ bbox_image = Image.fromarray(
167
+ np.where(
168
+ np.array(bbox_image) == [0, 0, 0],
169
+ [255, 255, 255],
170
+ np.array(bbox_image),
171
+ ).astype(np.uint8)
172
+ )
173
 
174
  res = make_prediction_and_explain(bbox_image)
175
  save_explanation_results(res, save_path)
requirements.txt CHANGED
@@ -8,4 +8,5 @@ wandb
8
  seaborn
9
  matplotlib
10
  ultralytics
11
- grad-cam
 
 
8
  seaborn
9
  matplotlib
10
  ultralytics
11
+ grad-cam
12
+ rembg[cpu]