ShilongLiu commited on
Commit
af720a1
·
1 Parent(s): 27486e3
Files changed (2) hide show
  1. app.py +3 -3
  2. groundingdino/util/inference.py +7 -6
app.py CHANGED
@@ -34,10 +34,10 @@ ckpt_repo_id = "ShilongLiu/GroundingDINO"
34
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
35
 
36
 
37
- def load_model_hf(model_config_path, repo_id, filename):
38
  args = SLConfig.fromfile(model_config_path)
39
- args.device = 'cuda'
40
  model = build_model(args)
 
41
 
42
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
43
  checkpoint = torch.load(cache_file, map_location='cpu')
@@ -72,7 +72,7 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
72
  image_pil: Image = image_transform_grounding_for_vis(init_image)
73
 
74
  # run grounidng
75
- boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold)
76
  annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
77
  image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
78
 
 
34
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
35
 
36
 
37
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
38
  args = SLConfig.fromfile(model_config_path)
 
39
  model = build_model(args)
40
+ args.device = device
41
 
42
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
43
  checkpoint = torch.load(cache_file, map_location='cpu')
 
72
  image_pil: Image = image_transform_grounding_for_vis(init_image)
73
 
74
  # run grounidng
75
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
76
  annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
77
  image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
78
 
groundingdino/util/inference.py CHANGED
@@ -21,9 +21,9 @@ def preprocess_caption(caption: str) -> str:
21
  return result + "."
22
 
23
 
24
- def load_model(model_config_path: str, model_checkpoint_path: str):
25
  args = SLConfig.fromfile(model_config_path)
26
- args.device = "cuda"
27
  model = build_model(args)
28
  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
29
  model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -50,12 +50,13 @@ def predict(
50
  image: torch.Tensor,
51
  caption: str,
52
  box_threshold: float,
53
- text_threshold: float
 
54
  ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
55
  caption = preprocess_caption(caption=caption)
56
-
57
- model = model.cuda()
58
- image = image.cuda()
59
 
60
  with torch.no_grad():
61
  outputs = model(image[None], captions=[caption])
 
21
  return result + "."
22
 
23
 
24
+ def load_model(model_config_path: str, model_checkpoint_path: str, device='cuda'):
25
  args = SLConfig.fromfile(model_config_path)
26
+ args.device = device
27
  model = build_model(args)
28
  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
29
  model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
 
50
  image: torch.Tensor,
51
  caption: str,
52
  box_threshold: float,
53
+ text_threshold: float,
54
+ device='cuda',
55
  ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
56
  caption = preprocess_caption(caption=caption)
57
+
58
+ model = model.to(device)
59
+ image = image.to(device)
60
 
61
  with torch.no_grad():
62
  outputs = model(image[None], captions=[caption])