Martin Tomov committed on
Commit
9fcd716
•
1 Parent(s): 81b2e04

Update gsl_utils.py

Browse files
Files changed (1) hide show
  1. gsl_utils.py +15 -14
gsl_utils.py CHANGED
@@ -1,4 +1,5 @@
1
  # GSL
 
2
  import os
3
  import torch
4
  import numpy as np
@@ -15,12 +16,12 @@ def load_groundingdino_model(device='cpu'):
15
  model = pipeline(model="IDEA-Research/grounding-dino-base", task="zero-shot-object-detection", device=device)
16
  return model
17
 
18
- def load_sam_model(device='cpu'):
19
- sam_model = build_sam(checkpoint='sam_vit_h_4b8939.pth').to(device)
20
  return SamPredictor(sam_model)
21
 
22
  groundingdino_model = load_groundingdino_model(device=device)
23
- sam_predictor = load_sam_model(device=device)
24
  simple_lama = SimpleLama()
25
 
26
  def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
@@ -63,11 +64,11 @@ def dilate_mask(mask, dilate_factor=15):
63
  return mask
64
 
65
  def gsl_process_image(image):
66
- # img numpy array
67
  if not isinstance(image, np.ndarray):
68
  image = np.array(image)
69
 
70
- # load img as a PIL
71
  image_pil = Image.fromarray(image)
72
 
73
  detected_boxes = detect(image_pil, groundingdino_model)
@@ -84,16 +85,16 @@ def gsl_process_image(image):
84
  annotated_frame_with_mask = draw_mask(final_mask, image)
85
 
86
  mask = final_mask.numpy()
87
- mask = mask.astype(np.uint8) * 255
88
- mask = dilate_mask(mask)
89
- dilated_image_mask_pil = Image.fromarray(mask)
90
 
91
- result = simple_lama(image, dilated_image_mask_pil)
92
 
93
- diff = ImageChops.difference(result, Image.fromarray(image))
94
- threshold = 7
95
- diff2 = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
96
- img3 = Image.new('RGB', Image.fromarray(image).size, (255, 236, 10))
97
- diff3 = Image.composite(Image.fromarray(image), img3, diff2)
98
 
99
  return diff3
 
1
  # GSL
2
+
3
  import os
4
  import torch
5
  import numpy as np
 
16
  model = pipeline(model="IDEA-Research/grounding-dino-base", task="zero-shot-object-detection", device=device)
17
  return model
18
 
19
+ def load_sam_model(checkpoint_path, device='cpu'):
20
+ sam_model = build_sam(checkpoint=checkpoint_path).to(device)
21
  return SamPredictor(sam_model)
22
 
23
  groundingdino_model = load_groundingdino_model(device=device)
24
+ sam_predictor = load_sam_model(checkpoint_path="models/sam_vit_h_4b8939.pth", device=device)
25
  simple_lama = SimpleLama()
26
 
27
  def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
 
64
  return mask
65
 
66
  def gsl_process_image(image):
67
+ # numpy array
68
  if not isinstance(image, np.ndarray):
69
  image = np.array(image)
70
 
71
+ # load image as a PIL
72
  image_pil = Image.fromarray(image)
73
 
74
  detected_boxes = detect(image_pil, groundingdino_model)
 
85
  annotated_frame_with_mask = draw_mask(final_mask, image)
86
 
87
  mask = final_mask.numpy()
88
+ mask is mask.astype(np.uint8) * 255
89
+ mask is dilate_mask(mask)
90
+ dilated_image_mask_pil is Image.fromarray(mask)
91
 
92
+ result is simple_lama(image, dilated_image_mask_pil)
93
 
94
+ diff is ImageChops.difference(result, Image.fromarray(image))
95
+ threshold is 7
96
+ diff2 is diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
97
+ img3 is Image.new('RGB', Image.fromarray(image).size, (255, 236, 10))
98
+ diff3 is Image.composite(Image.fromarray(image), img3, diff2)
99
 
100
  return diff3