curt-park committed on
Commit
4ddb621
1 Parent(s): acb3eab

Add threshold

Browse files
Files changed (1) hide show
  1. app.py +75 -18
app.py CHANGED
@@ -1,18 +1,19 @@
1
  import os
2
- import PIL
3
  from functools import lru_cache
4
-
5
  from random import randint
6
- import gradio as gr
 
7
  import cv2
8
- import torch
9
  import numpy as np
10
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
11
- from typing import List
 
12
 
13
  # File name of the SAM checkpoint; presumably read by the mask-generator
  # loader defined outside this view — TODO confirm.
  CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
14
  # Presumably the key used to pick a model from sam_model_registry — verify.
  MODEL_TYPE = "default"
15
  # Size limits in pixels; adjust_image_size shrinks images wider than MAX_WIDTH.
  MAX_WIDTH = MAX_HEIGHT = 800
 
16
  # Use the GPU when CUDA is available, otherwise fall back to the CPU.
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
 
@@ -32,37 +33,63 @@ def adjust_image_size(image: np.ndarray) -> np.ndarray:
32
  if width > MAX_WIDTH:
33
  height, width = int(MAX_WIDTH / width * height), MAX_WIDTH
34
  image = cv2.resize(image, (width, height))
35
- print(image.shape)
36
  return image
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def draw_masks(
40
  image: np.ndarray, masks: List[np.ndarray], alpha: float = 0.7
41
  ) -> np.ndarray:
42
  for mask in masks:
43
  color = [randint(127, 255) for _ in range(3)]
44
- segmentation = mask["segmentation"]
45
 
46
  # draw mask overlay
47
- colored_seg = np.expand_dims(segmentation, 0).repeat(3, axis=0)
48
- colored_seg = np.moveaxis(colored_seg, 0, -1)
49
- masked = np.ma.MaskedArray(image, mask=colored_seg, fill_value=color)
50
  image_overlay = masked.filled()
51
  image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
52
 
53
  # draw contour
54
  contours, _ = cv2.findContours(
55
- np.uint8(segmentation), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
56
  )
57
  cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
58
  return image
59
 
60
 
61
- def segment(image_path: str, query: str) -> PIL.ImageFile.ImageFile:
 
 
 
 
 
 
62
  mask_generator = load_mask_generator()
63
  # reduce the size to save gpu memory
64
  image = adjust_image_size(cv2.imread(image_path))
65
  masks = mask_generator.generate(image)
 
 
 
66
  image = draw_masks(image, masks)
67
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
68
  image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
@@ -71,15 +98,45 @@ def segment(image_path: str, query: str) -> PIL.ImageFile.ImageFile:
71
 
72
  demo = gr.Interface(
73
  fn=segment,
74
- inputs=[gr.Image(type="filepath"), "text"],
 
 
 
 
 
 
75
  outputs="image",
76
  allow_flagging="never",
77
  title="Segment Anything with CLIP",
78
  examples=[
79
- [os.path.join(os.path.dirname(__file__), "examples/dog.jpg"), ""],
80
- [os.path.join(os.path.dirname(__file__), "examples/city.jpg"), ""],
81
- [os.path.join(os.path.dirname(__file__), "examples/food.jpg"), ""],
82
- [os.path.join(os.path.dirname(__file__), "examples/horse.jpg"), ""],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  ],
84
  )
85
 
 
1
  import os
 
2
  from functools import lru_cache
 
3
  from random import randint
4
+ from typing import Dict, List
5
+
6
  import cv2
7
+ import gradio as gr
8
  import numpy as np
9
+ import PIL
10
+ import torch
11
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
12
 
13
  # File name of the SAM checkpoint; presumably read by the mask-generator
  # loader defined outside this view — TODO confirm.
  CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
14
  # Presumably the key used to pick a model from sam_model_registry — verify.
  MODEL_TYPE = "default"
15
  # Size limits in pixels; adjust_image_size shrinks images wider than MAX_WIDTH.
  MAX_WIDTH = MAX_HEIGHT = 800
16
  # NOTE(review): not referenced in the visible code; the value equals the
  # default of the clip_threshold slider — confirm the intended use.
+ THRESHOLD = 0.05
17
  # Use the GPU when CUDA is available, otherwise fall back to the CPU.
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
 
 
33
  if width > MAX_WIDTH:
34
  height, width = int(MAX_WIDTH / width * height), MAX_WIDTH
35
  image = cv2.resize(image, (width, height))
 
36
  return image
37
 
38
 
39
def filter_masks(
    masks: List[Dict[str, np.ndarray]],
    predicted_iou_threshold: float,
    stability_score_threshold: float,
    query: str,
    clip_threshold: float,
) -> List[np.ndarray]:
    """Return the segmentations of the masks that pass both score thresholds.

    A mask is dropped when its ``"predicted_iou"`` is strictly below
    ``predicted_iou_threshold`` or its ``"stability_score"`` is strictly below
    ``stability_score_threshold``; a score equal to its threshold is kept.

    Args:
        masks: Mask records; each must provide the keys ``"predicted_iou"``,
            ``"stability_score"`` and ``"segmentation"``.
        predicted_iou_threshold: Lower bound for ``"predicted_iou"``.
        stability_score_threshold: Lower bound for ``"stability_score"``.
        query: Accepted but not used by the filtering yet.
        clip_threshold: Accepted but not used by the filtering yet.

    Returns:
        The ``"segmentation"`` entry of every kept mask, in input order.
    """
    return [
        mask["segmentation"]
        for mask in masks
        if not (
            mask["predicted_iou"] < predicted_iou_threshold
            or mask["stability_score"] < stability_score_threshold
        )
    ]
56
+
57
+
58
  def draw_masks(
59
  image: np.ndarray, masks: List[np.ndarray], alpha: float = 0.7
60
  ) -> np.ndarray:
61
  for mask in masks:
62
  color = [randint(127, 255) for _ in range(3)]
 
63
 
64
  # draw mask overlay
65
+ colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
66
+ colored_mask = np.moveaxis(colored_mask, 0, -1)
67
+ masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
68
  image_overlay = masked.filled()
69
  image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
70
 
71
  # draw contour
72
  contours, _ = cv2.findContours(
73
+ np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
74
  )
75
  cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
76
  return image
77
 
78
 
79
+ def segment(
80
+ predicted_iou_threshold: float,
81
+ stability_score_threshold: float,
82
+ clip_threshold: float,
83
+ image_path: str,
84
+ query: str,
85
+ ) -> PIL.ImageFile.ImageFile:
86
  mask_generator = load_mask_generator()
87
  # reduce the size to save gpu memory
88
  image = adjust_image_size(cv2.imread(image_path))
89
  masks = mask_generator.generate(image)
90
+ masks = filter_masks(
91
+ masks, predicted_iou_threshold, stability_score_threshold, query, clip_threshold
92
+ )
93
  image = draw_masks(image, masks)
94
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
95
  image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
 
98
 
99
  demo = gr.Interface(
100
  fn=segment,
101
+ inputs=[
102
+ gr.Slider(0, 1, value=0.9, label="predicted_iou_threshold"),
103
+ gr.Slider(0, 1, value=0.8, label="stability_score_threshold"),
104
+ gr.Slider(0, 1, value=0.05, label="clip_threshold"),
105
+ gr.Image(type="filepath"),
106
+ "text",
107
+ ],
108
  outputs="image",
109
  allow_flagging="never",
110
  title="Segment Anything with CLIP",
111
  examples=[
112
+ [
113
+ 0.9,
114
+ 0.8,
115
+ 0.05,
116
+ os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
117
+ "",
118
+ ],
119
+ [
120
+ 0.9,
121
+ 0.8,
122
+ 0.05,
123
+ os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
124
+ "",
125
+ ],
126
+ [
127
+ 0.9,
128
+ 0.8,
129
+ 0.05,
130
+ os.path.join(os.path.dirname(__file__), "examples/food.jpg"),
131
+ "",
132
+ ],
133
+ [
134
+ 0.9,
135
+ 0.8,
136
+ 0.05,
137
+ os.path.join(os.path.dirname(__file__), "examples/horse.jpg"),
138
+ "",
139
+ ],
140
  ],
141
  )
142