Martin Tomov commited on
Commit
3657d52
·
verified ·
1 Parent(s): cd7cf5e

HF IDEA-Research/grounding-dino-base

Browse files
Files changed (1) hide show
  1. gsl_utils.py +20 -55
gsl_utils.py CHANGED
@@ -1,57 +1,32 @@
1
- # GSL
2
-
3
  import os
4
  import torch
5
  import numpy as np
6
  from PIL import Image, ImageChops, ImageEnhance
7
  import cv2
8
  from simple_lama_inpainting import SimpleLama
9
- from segment_anything import build_sam, SamPredictor
10
- from GroundingDINO.groundingdino.util import box_ops
11
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
12
- from GroundingDINO.groundingdino.util.utils import clean_state_dict
13
- from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict
14
  from huggingface_hub import hf_hub_download
15
 
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
18
- def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
19
- cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
20
- args = SLConfig.fromfile(cache_config_file)
21
- args.device = device
22
- model = build_model(args)
23
- cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
24
- checkpoint = torch.load(cache_file, map_location=device)
25
- model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
26
- model.eval()
27
  return model
28
 
29
- groundingdino_model = load_model_hf(
30
- repo_id="ShilongLiu/GroundingDINO",
31
- filename="groundingdino_swinb_cogcoor.pth",
32
- ckpt_config_filename="GroundingDINO_SwinB.cfg.py",
33
- device=device
34
- )
35
-
36
- sam_predictor = SamPredictor(build_sam(checkpoint='sam_vit_h_4b8939.pth').to(device))
37
  simple_lama = SimpleLama()
38
 
39
  def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
40
- boxes, logits, phrases = predict(
41
- image=image,
42
- model=model,
43
- caption=text_prompt,
44
- box_threshold=box_threshold,
45
- text_threshold=text_threshold
46
- )
47
- annotated_frame = annotate(image_source=image, boxes=boxes, logits=logits, phrases=phrases)
48
- annotated_frame = annotated_frame[..., ::-1] # BGR to RGB
49
- return annotated_frame, boxes, phrases
50
 
51
  def segment(image, sam_model, boxes):
52
  sam_model.set_image(image)
53
  H, W, _ = image.shape
54
- boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
 
55
  transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(device), image.shape[:2])
56
  masks, _, _ = sam_model.predict_torch(
57
  point_coords=None,
@@ -81,18 +56,12 @@ def dilate_mask(mask, dilate_factor=15):
81
  )
82
  return mask
83
 
84
- def gsl_process_image(local_image_path):
85
- # Load image
86
- image_source, image = load_image(local_image_path)
87
-
88
- # Detect insects
89
- annotated_frame, detected_boxes, phrases = detect(image, model=groundingdino_model)
90
- indices = [i for i, s in enumerate(phrases) if 'insect' in s]
91
-
92
- # Segment insects
93
- segmented_frame_masks = segment(image_source, sam_predictor, detected_boxes[indices])
94
 
95
- # Combine masks
96
  final_mask = None
97
  for i in range(len(segmented_frame_masks) - 1):
98
  if final_mask is None:
@@ -100,23 +69,19 @@ def gsl_process_image(local_image_path):
100
  else:
101
  final_mask = np.bitwise_or(final_mask, segmented_frame_masks[i + 1][0].cpu())
102
 
103
- # Draw mask
104
- annotated_frame_with_mask = draw_mask(final_mask, image_source)
105
 
106
- # Dilate mask
107
  mask = final_mask.numpy()
108
  mask = mask.astype(np.uint8) * 255
109
  mask = dilate_mask(mask)
110
  dilated_image_mask_pil = Image.fromarray(mask)
111
 
112
- # Inpainting
113
- result = simple_lama(image_source, dilated_image_mask_pil)
114
 
115
- # Difference and composite
116
- diff = ImageChops.difference(result, Image.fromarray(image_source))
117
  threshold = 7
118
  diff2 = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
119
- img3 = Image.new('RGB', Image.fromarray(image_source).size, (255, 236, 10))
120
- diff3 = Image.composite(Image.fromarray(image_source), img3, diff2)
121
 
122
  return diff3
 
 
 
1
  import os
2
  import torch
3
  import numpy as np
4
  from PIL import Image, ImageChops, ImageEnhance
5
  import cv2
6
  from simple_lama_inpainting import SimpleLama
7
+ from transformers import pipeline
 
 
 
 
8
  from huggingface_hub import hf_hub_download
9
 
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
+ def load_groundingdino_model(device='cpu'):
13
+ model = pipeline(model="IDEA-Research/grounding-dino-base", task="zero-shot-object-detection", device=device)
 
 
 
 
 
 
 
14
  return model
15
 
16
+ groundingdino_model = load_groundingdino_model(device=device)
17
+ sam_predictor = None
 
 
 
 
 
 
18
  simple_lama = SimpleLama()
19
 
20
  def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
21
+ labels = [label if label.endswith('.') else label + '.' for label in text_prompt.split('.')]
22
+ results = model(image, candidate_labels=labels, threshold=box_threshold)
23
+ return results
 
 
 
 
 
 
 
24
 
25
  def segment(image, sam_model, boxes):
26
  sam_model.set_image(image)
27
  H, W, _ = image.shape
28
+ boxes_xyxy = torch.Tensor(boxes) * torch.Tensor([W, H, W, H])
29
+
30
  transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(device), image.shape[:2])
31
  masks, _, _ = sam_model.predict_torch(
32
  point_coords=None,
 
56
  )
57
  return mask
58
 
59
+ def gsl_process_image(image):
60
+ image_source = Image.fromarray(image)
61
+ detected_boxes = detect(image_source, groundingdino_model)
62
+ boxes = [[d['box']['xmin'], d['box']['ymin'], d['box']['xmax'], d['box']['ymax']] for d in detected_boxes]
63
+ segmented_frame_masks = segment(image, sam_predictor, boxes)
 
 
 
 
 
64
 
 
65
  final_mask = None
66
  for i in range(len(segmented_frame_masks) - 1):
67
  if final_mask is None:
 
69
  else:
70
  final_mask = np.bitwise_or(final_mask, segmented_frame_masks[i + 1][0].cpu())
71
 
72
+ annotated_frame_with_mask = draw_mask(final_mask, image)
 
73
 
 
74
  mask = final_mask.numpy()
75
  mask = mask.astype(np.uint8) * 255
76
  mask = dilate_mask(mask)
77
  dilated_image_mask_pil = Image.fromarray(mask)
78
 
79
+ result = simple_lama(image, dilated_image_mask_pil)
 
80
 
81
+ diff = ImageChops.difference(result, Image.fromarray(image))
 
82
  threshold = 7
83
  diff2 = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
84
+ img3 = Image.new('RGB', Image.fromarray(image).size, (255, 236, 10))
85
+ diff3 = Image.composite(Image.fromarray(image), img3, diff2)
86
 
87
  return diff3