Martin Tomov committed on
Commit dcce1a5 · verified · 1 Parent(s): 746e19a

@spaces.GPU sam_utils.py

Files changed (1)
  1. sam_utils.py +5 -17
sam_utils.py CHANGED

@@ -1,5 +1,4 @@
 import os
-
 import random
 from dataclasses import dataclass
 from typing import Any, List, Dict, Optional, Union, Tuple
@@ -12,7 +11,7 @@ import matplotlib.pyplot as plt
 from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
 import gradio as gr
 import json
-
+import spaces
 
 @dataclass
 class BoundingBox:
@@ -24,6 +23,7 @@ class BoundingBox:
     @property
     def xyxy(self) -> List[float]:
         return [self.xmin, self.ymin, self.xmax, self.ymax]
+
 @dataclass
 class DetectionResult:
     score: float
@@ -63,12 +63,10 @@ def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[Dete
 
     return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
 
-
 def plot_detections(image: Union[Image.Image, np.ndarray], detections: List[DetectionResult], include_bboxes: bool = True) -> np.ndarray:
     annotated_image = annotate(image, detections, include_bboxes)
     return annotated_image
 
-
 def load_image(image: Union[str, Image.Image]) -> Image.Image:
     if isinstance(image, str) and image.startswith("http"):
         image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
@@ -78,7 +76,6 @@ def load_image(image: Union[str, Image.Image]) -> Image.Image:
         image = image.convert("RGB")
     return image
 
-
 def get_boxes(detection_results: List[DetectionResult]) -> List[List[List[float]]]:
     boxes = []
     for result in detection_results:
@@ -86,7 +83,6 @@ def get_boxes(detection_results: List[DetectionResult]) -> List[List[List[float]
         boxes.append(xyxy)
     return [boxes]
 
-
 def mask_to_polygon(mask: np.ndarray) -> np.ndarray:
     contours, _ = cv2.findContours(
         mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -95,7 +91,6 @@ def mask_to_polygon(mask: np.ndarray) -> np.ndarray:
     largest_contour = max(contours, key=cv2.contourArea)
     return largest_contour
 
-
 def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
     masks = masks.cpu().float().permute(0, 2, 3, 1).mean(
         axis=-1).numpy().astype(np.uint8)
@@ -108,7 +103,7 @@ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> L
             np.zeros(shape, dtype=np.uint8), [polygon], 1)
     return list(masks)
 
-
+@spaces.GPU
 def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[Dict[str, Any]]:
     detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
     object_detector = pipeline(
@@ -118,7 +113,7 @@ def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detect
         image, candidate_labels=labels, threshold=threshold)
     return [DetectionResult.from_dict(result) for result in results]
 
-
+@spaces.GPU
 def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
     segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
     segmentator = AutoModelForMaskGeneration.from_pretrained(
@@ -135,19 +130,16 @@ def segment(image: Image.Image, detection_results: List[DetectionResult], polygo
         detection_result.mask = mask
     return detection_results
 
-
 def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
     image = load_image(image)
     detections = detect(image, labels, threshold, detector_id)
     detections = segment(image, detections, polygon_refinement, segmenter_id)
     return np.array(image), detections
 
-
 def mask_to_min_max(mask: np.ndarray) -> Tuple[int, int, int, int]:
     y, x = np.where(mask)
     return x.min(), y.min(), x.max(), y.max()
 
-
 def extract_and_paste_insect(original_image: np.ndarray, detection: DetectionResult, background: np.ndarray) -> None:
     mask = detection.mask
     xmin, ymin, xmax, ymax = mask_to_min_max(mask)
@@ -162,7 +154,6 @@ def extract_and_paste_insect(original_image: np.ndarray, detection: DetectionRes
     insect_area = background[y_offset:y_end, x_offset:x_end]
     insect_area[mask_crop == 1] = insect[mask_crop == 1]
 
-
 def create_yellow_background_with_insects(image: np.ndarray) -> np.ndarray:
     labels = ["insect"]
 
@@ -179,14 +170,13 @@
     yellow_background = cv2.cvtColor(yellow_background, cv2.COLOR_BGR2RGB)
     return yellow_background
 
-
 def run_length_encoding(mask):
     pixels = mask.flatten()
     rle = []
     last_val = 0
     count = 0
     for pixel in pixels:
-        if pixel == last_val:
+        if pixel == last_val:
             count += 1
         else:
             if count > 0:
@@ -197,7 +187,6 @@ def run_length_encoding(mask):
         rle.append(count)
     return rle
 
-
 def detections_to_json(detections):
     detections_list = []
     for detection in detections:
@@ -214,7 +203,6 @@ def detections_to_json(detections):
         detections_list.append(detection_dict)
     return detections_list
 
-
 def crop_bounding_boxes_with_yellow_background(image: np.ndarray, yellow_background: np.ndarray, detections: List[DetectionResult]) -> List[np.ndarray]:
     crops = []
     for detection in detections:
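
The substance of the commit is the two +@spaces.GPU lines: on a Hugging Face Space backed by ZeroGPU hardware, a GPU is attached to the process only while a function decorated with @spaces.GPU is executing, so the decorator goes on the model-facing entry points (detect and segment here). A minimal sketch of the pattern, assuming the code runs on a Space where the spaces package is available:

import spaces
import torch

@spaces.GPU  # a GPU is allocated for this call and released when it returns
def heavy_inference(x: torch.Tensor) -> torch.Tensor:
    # CUDA is only guaranteed to be usable inside the decorated function
    return x.to("cuda").softmax(dim=-1).cpu()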
 
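For context, both decorated functions are driven by grounded_segmentation in the same file. A hypothetical invocation, using only the signature visible in the diff (the image URL is a placeholder):

from sam_utils import grounded_segmentation

# Placeholder input URL; the label and threshold mirror the defaults in the diff.
image_array, detections = grounded_segmentation(
    image="https://example.com/insect.jpg",
    labels=["insect"],
    threshold=0.3,
    polygon_refinement=True,
)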