Commit
·
af5b380
1
Parent(s):
8f8af36
updated: pipeline
Browse files- app.py +7 -4
- explain.py +12 -11
- requirements.txt +2 -1
app.py
CHANGED
@@ -4,10 +4,10 @@ from PIL import Image
|
|
4 |
from explain import get_results, reproduce
|
5 |
|
6 |
|
7 |
-
def classify_and_explain(image, object_detection=False):
|
8 |
reproduce()
|
9 |
# This function will classify the image and return a list of image paths
|
10 |
-
list_of_images = get_results(img_for_testing=image, od=object_detection)
|
11 |
return list_of_images
|
12 |
|
13 |
|
@@ -23,8 +23,8 @@ def get_examples():
|
|
23 |
]
|
24 |
|
25 |
return [
|
26 |
-
[Image.open(i), True] for i in glob("samples/*") if i not in od_off_examples
|
27 |
-
] + [[Image.open(i), False] for i in glob("samples/*") if i in od_off_examples]
|
28 |
|
29 |
|
30 |
demo = gr.Interface(
|
@@ -34,6 +34,9 @@ demo = gr.Interface(
|
|
34 |
gr.Checkbox(
|
35 |
label="Extract Leaves", info="What to extract leafs before classification"
|
36 |
),
|
|
|
|
|
|
|
37 |
],
|
38 |
outputs="gallery",
|
39 |
examples=get_examples(),
|
|
|
4 |
from explain import get_results, reproduce
|
5 |
|
6 |
|
7 |
+
def classify_and_explain(image, object_detection=False, remove_bg=False):
|
8 |
reproduce()
|
9 |
# This function will classify the image and return a list of image paths
|
10 |
+
list_of_images = get_results(img_for_testing=image, od=object_detection, remove_bg=remove_bg)
|
11 |
return list_of_images
|
12 |
|
13 |
|
|
|
23 |
]
|
24 |
|
25 |
return [
|
26 |
+
[Image.open(i), True, False] for i in glob("samples/*") if i not in od_off_examples
|
27 |
+
] + [[Image.open(i), False, False] for i in glob("samples/*") if i in od_off_examples]
|
28 |
|
29 |
|
30 |
demo = gr.Interface(
|
|
|
34 |
gr.Checkbox(
|
35 |
label="Extract Leaves", info="What to extract leafs before classification"
|
36 |
),
|
37 |
+
gr.Checkbox(
|
38 |
+
label="Remove Background", info="Remove the background and make it white"
|
39 |
+
),
|
40 |
],
|
41 |
outputs="gallery",
|
42 |
examples=get_examples(),
|
explain.py
CHANGED
@@ -16,7 +16,7 @@ from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
|
|
16 |
|
17 |
from ultralytics import YOLO
|
18 |
|
19 |
-
|
20 |
import uuid
|
21 |
|
22 |
|
@@ -25,7 +25,7 @@ model_path = "efficientnet-b2.pth"
|
|
25 |
model_name = "efficientnet_b2"
|
26 |
YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
|
27 |
classes = ["Healthy", "Resistant", "Susceptible"]
|
28 |
-
resizing_transforms = transforms.Compose([transforms.CenterCrop(
|
29 |
|
30 |
|
31 |
# Function definitions
|
@@ -141,7 +141,7 @@ gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
|
|
141 |
yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
|
142 |
|
143 |
|
144 |
-
def get_results(img_path=None, img_for_testing=None, od=False):
|
145 |
if img_path is None and img_for_testing is None:
|
146 |
raise ValueError("Either img_path or img_for_testing should be provided.")
|
147 |
|
@@ -161,14 +161,15 @@ def get_results(img_path=None, img_for_testing=None, od=False):
|
|
161 |
bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
|
162 |
bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
172 |
|
173 |
res = make_prediction_and_explain(bbox_image)
|
174 |
save_explanation_results(res, save_path)
|
|
|
16 |
|
17 |
from ultralytics import YOLO
|
18 |
|
19 |
+
from rembg import remove
|
20 |
import uuid
|
21 |
|
22 |
|
|
|
25 |
model_name = "efficientnet_b2"
|
26 |
YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
|
27 |
classes = ["Healthy", "Resistant", "Susceptible"]
|
28 |
+
resizing_transforms = transforms.Compose([transforms.CenterCrop(256)])
|
29 |
|
30 |
|
31 |
# Function definitions
|
|
|
141 |
yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
|
142 |
|
143 |
|
144 |
+
def get_results(img_path=None, img_for_testing=None, od=False, remove_bg=False):
|
145 |
if img_path is None and img_for_testing is None:
|
146 |
raise ValueError("Either img_path or img_for_testing should be provided.")
|
147 |
|
|
|
161 |
bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
|
162 |
bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
|
163 |
|
164 |
+
if remove_bg:
|
165 |
+
bbox_image = remove(bbox_image).convert("RGB")
|
166 |
+
bbox_image = Image.fromarray(
|
167 |
+
np.where(
|
168 |
+
np.array(bbox_image) == [0, 0, 0],
|
169 |
+
[255, 255, 255],
|
170 |
+
np.array(bbox_image),
|
171 |
+
).astype(np.uint8)
|
172 |
+
)
|
173 |
|
174 |
res = make_prediction_and_explain(bbox_image)
|
175 |
save_explanation_results(res, save_path)
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ wandb
|
|
8 |
seaborn
|
9 |
matplotlib
|
10 |
ultralytics
|
11 |
-
grad-cam
|
|
|
|
8 |
seaborn
|
9 |
matplotlib
|
10 |
ultralytics
|
11 |
+
grad-cam
|
12 |
+
rembg[cpu]
|