Martin Tomov commited on
Commit
6dc32ee
·
verified ·
1 Parent(s): 96523cc

Update gsl_utils.py

Browse files
Files changed (1) hide show
  1. gsl_utils.py +12 -3
gsl_utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import torch
3
  import numpy as np
@@ -14,7 +15,7 @@ def load_groundingdino_model(device='cpu'):
14
  return model
15
 
16
  groundingdino_model = load_groundingdino_model(device=device)
17
- sam_predictor = None
18
  simple_lama = SimpleLama()
19
 
20
  def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
@@ -23,6 +24,7 @@ def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.
23
  return results
24
 
25
  def segment(image, sam_model, boxes):
 
26
  sam_model.set_image(image)
27
  H, W, _ = image.shape
28
  boxes_xyxy = torch.Tensor(boxes) * torch.Tensor([W, H, W, H])
@@ -57,8 +59,15 @@ def dilate_mask(mask, dilate_factor=15):
57
  return mask
58
 
59
  def gsl_process_image(image):
60
- image_source = Image.fromarray(image)
61
- detected_boxes = detect(image_source, groundingdino_model)
 
 
 
 
 
 
 
62
  boxes = [[d['box']['xmin'], d['box']['ymin'], d['box']['xmax'], d['box']['ymax']] for d in detected_boxes]
63
  segmented_frame_masks = segment(image, sam_predictor, boxes)
64
 
 
1
+ # GSL
2
  import os
3
  import torch
4
  import numpy as np
 
15
  return model
16
 
17
  groundingdino_model = load_groundingdino_model(device=device)
18
+ sam_predictor = None # Initialize this properly using build_sam
19
  simple_lama = SimpleLama()
20
 
21
  def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
 
24
  return results
25
 
26
  def segment(image, sam_model, boxes):
27
+ # sam_moded initialized with build_sam
28
  sam_model.set_image(image)
29
  H, W, _ = image.shape
30
  boxes_xyxy = torch.Tensor(boxes) * torch.Tensor([W, H, W, H])
 
59
  return mask
60
 
61
  def gsl_process_image(image):
62
+ # image numpy array
63
+ if not isinstance(image, np.ndarray):
64
+ image = np.array(image)
65
+
66
+ # load as a PIL
67
+ image_pil = Image.fromarray(image)
68
+
69
+ # detect insects
70
+ detected_boxes = detect(image_pil, groundingdino_model)
71
  boxes = [[d['box']['xmin'], d['box']['ymin'], d['box']['xmax'], d['box']['ymax']] for d in detected_boxes]
72
  segmented_frame_masks = segment(image, sam_predictor, boxes)
73