liuyizhang commited on
Commit
9403943
·
1 Parent(s): c419c35
GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_B_384_22k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
GroundingDINO/groundingdino/datasets/__init__.py ADDED
File without changes
GroundingDINO/groundingdino/util/inference.py CHANGED
@@ -13,6 +13,10 @@ from groundingdino.util.misc import clean_state_dict
13
  from groundingdino.util.slconfig import SLConfig
14
  from groundingdino.util.utils import get_phrases_from_posmap
15
 
 
 
 
 
16
 
17
  def preprocess_caption(caption: str) -> str:
18
  result = caption.lower().strip()
@@ -96,3 +100,143 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
96
  annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
97
  annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
98
  return annotated_frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from groundingdino.util.slconfig import SLConfig
14
  from groundingdino.util.utils import get_phrases_from_posmap
15
 
16
+ # ----------------------------------------------------------------------------------------------------------------------
17
+ # OLD API
18
+ # ----------------------------------------------------------------------------------------------------------------------
19
+
20
 
21
  def preprocess_caption(caption: str) -> str:
22
  result = caption.lower().strip()
 
100
  annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
101
  annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
102
  return annotated_frame
103
+
104
+
105
+ # ----------------------------------------------------------------------------------------------------------------------
106
+ # NEW API
107
+ # ----------------------------------------------------------------------------------------------------------------------
108
+
109
+
110
+ class Model:
111
+
112
+ def __init__(
113
+ self,
114
+ model_config_path: str,
115
+ model_checkpoint_path: str,
116
+ device: str = "cuda"
117
+ ):
118
+ self.model = load_model(
119
+ model_config_path=model_config_path,
120
+ model_checkpoint_path=model_checkpoint_path,
121
+ device=device
122
+ ).to(device)
123
+ self.device = device
124
+
125
+ def predict_with_caption(
126
+ self,
127
+ image: np.ndarray,
128
+ caption: str,
129
+ box_threshold: float = 0.35,
130
+ text_threshold: float = 0.25
131
+ ) -> Tuple[sv.Detections, List[str]]:
132
+ """
133
+ import cv2
134
+
135
+ image = cv2.imread(IMAGE_PATH)
136
+
137
+ model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
138
+ detections, labels = model.predict_with_caption(
139
+ image=image,
140
+ caption=caption,
141
+ box_threshold=BOX_THRESHOLD,
142
+ text_threshold=TEXT_THRESHOLD
143
+ )
144
+
145
+ import supervision as sv
146
+
147
+ box_annotator = sv.BoxAnnotator()
148
+ annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
149
+ """
150
+ processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
151
+ boxes, logits, phrases = predict(
152
+ model=self.model,
153
+ image=processed_image,
154
+ caption=caption,
155
+ box_threshold=box_threshold,
156
+ text_threshold=text_threshold)
157
+ source_h, source_w, _ = image.shape
158
+ detections = Model.post_process_result(
159
+ source_h=source_h,
160
+ source_w=source_w,
161
+ boxes=boxes,
162
+ logits=logits)
163
+ return detections, phrases
164
+
165
+ def predict_with_classes(
166
+ self,
167
+ image: np.ndarray,
168
+ classes: List[str],
169
+ box_threshold: float,
170
+ text_threshold: float
171
+ ) -> sv.Detections:
172
+ """
173
+ import cv2
174
+
175
+ image = cv2.imread(IMAGE_PATH)
176
+
177
+ model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
178
+ detections = model.predict_with_classes(
179
+ image=image,
180
+ classes=CLASSES,
181
+ box_threshold=BOX_THRESHOLD,
182
+ text_threshold=TEXT_THRESHOLD
183
+ )
184
+
185
+
186
+ import supervision as sv
187
+
188
+ box_annotator = sv.BoxAnnotator()
189
+ annotated_image = box_annotator.annotate(scene=image, detections=detections)
190
+ """
191
+ caption = ", ".join(classes)
192
+ processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
193
+ boxes, logits, phrases = predict(
194
+ model=self.model,
195
+ image=processed_image,
196
+ caption=caption,
197
+ box_threshold=box_threshold,
198
+ text_threshold=text_threshold)
199
+ source_h, source_w, _ = image.shape
200
+ detections = Model.post_process_result(
201
+ source_h=source_h,
202
+ source_w=source_w,
203
+ boxes=boxes,
204
+ logits=logits)
205
+ class_id = Model.phrases2classes(phrases=phrases, classes=classes)
206
+ detections.class_id = class_id
207
+ return detections
208
+
209
+ @staticmethod
210
+ def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
211
+ transform = T.Compose(
212
+ [
213
+ T.RandomResize([800], max_size=1333),
214
+ T.ToTensor(),
215
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
216
+ ]
217
+ )
218
+ image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
219
+ image_transformed, _ = transform(image_pillow, None)
220
+ return image_transformed
221
+
222
+ @staticmethod
223
+ def post_process_result(
224
+ source_h: int,
225
+ source_w: int,
226
+ boxes: torch.Tensor,
227
+ logits: torch.Tensor
228
+ ) -> sv.Detections:
229
+ boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
230
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
231
+ confidence = logits.numpy()
232
+ return sv.Detections(xyxy=xyxy, confidence=confidence)
233
+
234
+ @staticmethod
235
+ def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
236
+ class_ids = []
237
+ for phrase in phrases:
238
+ try:
239
+ class_ids.append(classes.index(phrase))
240
+ except ValueError:
241
+ class_ids.append(None)
242
+ return np.array(class_ids)
GroundingDINO/groundingdino/util/slconfig.py CHANGED
@@ -2,13 +2,13 @@
2
  # Modified from mmcv
3
  # ==========================================================
4
  import ast
 
5
  import os.path as osp
6
  import shutil
7
  import sys
8
  import tempfile
9
  from argparse import Action
10
  from importlib import import_module
11
- import platform
12
 
13
  from addict import Dict
14
  from yapf.yapflib.yapf_api import FormatCode
@@ -81,7 +81,7 @@ class SLConfig(object):
81
  with tempfile.TemporaryDirectory() as temp_config_dir:
82
  temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
83
  temp_config_name = osp.basename(temp_config_file.name)
84
- if platform.system() == 'Windows':
85
  temp_config_file.close()
86
  shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
87
  temp_module_name = osp.splitext(temp_config_name)[0]
 
2
  # Modified from mmcv
3
  # ==========================================================
4
  import ast
5
+ import os
6
  import os.path as osp
7
  import shutil
8
  import sys
9
  import tempfile
10
  from argparse import Action
11
  from importlib import import_module
 
12
 
13
  from addict import Dict
14
  from yapf.yapflib.yapf_api import FormatCode
 
81
  with tempfile.TemporaryDirectory() as temp_config_dir:
82
  temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
83
  temp_config_name = osp.basename(temp_config_file.name)
84
+ if os.name == 'nt':
85
  temp_config_file.close()
86
  shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
87
  temp_module_name = osp.splitext(temp_config_name)[0]
GroundingDINO/groundingdino/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.1.0'
 
1
+ __version__ = "0.1.0"