yansong1616 commited on
Commit
539d1de
·
verified ·
1 Parent(s): 83bfcdf

Update SAM2/sam2/build_sam.py

Browse files
Files changed (1) hide show
  1. SAM2/sam2/build_sam.py +94 -90
SAM2/sam2/build_sam.py CHANGED
@@ -1,90 +1,94 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import logging
8
-
9
- import torch
10
- from hydra import compose
11
- from hydra.utils import instantiate
12
- from omegaconf import OmegaConf
13
-
14
-
15
- def build_sam2(
16
- config_file,
17
- ckpt_path=None,
18
- device="cuda",
19
- mode="eval",
20
- hydra_overrides_extra=[],
21
- apply_postprocessing=True,
22
- ):
23
-
24
- if apply_postprocessing:
25
- hydra_overrides_extra = hydra_overrides_extra.copy()
26
- hydra_overrides_extra += [
27
- # dynamically fall back to multi-mask if the single mask is not stable
28
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
29
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
30
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
31
- ]
32
- # Read config and init model
33
- cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
34
- OmegaConf.resolve(cfg)
35
- model = instantiate(cfg.model, _recursive_=True)
36
- _load_checkpoint(model, ckpt_path)
37
- model = model.to(device)
38
- if mode == "eval":
39
- model.eval()
40
- return model
41
-
42
-
43
- def build_sam2_video_predictor(
44
- config_file,
45
- ckpt_path=None,
46
- device="cuda",
47
- mode="eval",
48
- hydra_overrides_extra=[],
49
- apply_postprocessing=True,
50
- ):
51
- print('... loading SAM2_Video from', ckpt_path)
52
- hydra_overrides = [
53
- "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
54
- ]
55
- if apply_postprocessing:
56
- hydra_overrides_extra = hydra_overrides_extra.copy()
57
- hydra_overrides_extra += [
58
- # dynamically fall back to multi-mask if the single mask is not stable
59
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
60
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
61
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
62
- # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
63
- "++model.binarize_mask_from_pts_for_mem_enc=true",
64
- # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
65
- "++model.fill_hole_area=8",
66
- ]
67
- hydra_overrides.extend(hydra_overrides_extra)
68
-
69
- # Read config and init model
70
- cfg = compose(config_name=config_file, overrides=hydra_overrides)
71
- OmegaConf.resolve(cfg)
72
- model = instantiate(cfg.model, _recursive_=True)
73
- _load_checkpoint(model, ckpt_path)
74
- model = model.to(device)
75
- if mode == "eval":
76
- model.eval()
77
- return model
78
-
79
-
80
- def _load_checkpoint(model, ckpt_path):
81
- if ckpt_path is not None:
82
- sd = torch.load(ckpt_path, map_location="cpu")["model"]
83
- missing_keys, unexpected_keys = model.load_state_dict(sd)
84
- if missing_keys:
85
- logging.error(missing_keys)
86
- raise RuntimeError()
87
- if unexpected_keys:
88
- logging.error(unexpected_keys)
89
- raise RuntimeError()
90
- logging.info("Loaded checkpoint sucessfully")
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import torch
10
+ from hydra import compose
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+
14
+
15
+ def build_sam2(
16
+ config_file,
17
+ ckpt_path=None,
18
+ device="cuda",
19
+ mode="eval",
20
+ hydra_overrides_extra=[],
21
+ apply_postprocessing=True,
22
+ ):
23
+
24
+ if apply_postprocessing:
25
+ hydra_overrides_extra = hydra_overrides_extra.copy()
26
+ hydra_overrides_extra += [
27
+ # dynamically fall back to multi-mask if the single mask is not stable
28
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
29
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
30
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
31
+ ]
32
+ # Read config and init model
33
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
34
+ OmegaConf.resolve(cfg)
35
+ model = instantiate(cfg.model, _recursive_=True)
36
+ _load_checkpoint(model, ckpt_path)
37
+ model = model.to(device)
38
+ if mode == "eval":
39
+ model.eval()
40
+ return model
41
+
42
+
43
+ def build_sam2_video_predictor(
44
+ config_file,
45
+ ckpt_path=None,
46
+ device="cuda",
47
+ mode="eval",
48
+ hydra_overrides_extra=[],
49
+ apply_postprocessing=True,
50
+ ):
51
+ print('... loading SAM2_Video from', ckpt_path)
52
+ hydra_overrides = [
53
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
54
+ ]
55
+ if apply_postprocessing:
56
+ hydra_overrides_extra = hydra_overrides_extra.copy()
57
+ hydra_overrides_extra += [
58
+ # dynamically fall back to multi-mask if the single mask is not stable
59
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
60
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
61
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
62
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
63
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
64
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
65
+ "++model.fill_hole_area=8",
66
+ ]
67
+ hydra_overrides.extend(hydra_overrides_extra)
68
+
69
+
70
+ import os
71
+ config_file = os.path.join(os.path.dirname(__file__), 'SAM2', 'sam2_configs')
72
+
73
+ # Read config and init model
74
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
75
+ OmegaConf.resolve(cfg)
76
+ model = instantiate(cfg.model, _recursive_=True)
77
+ _load_checkpoint(model, ckpt_path)
78
+ model = model.to(device)
79
+ if mode == "eval":
80
+ model.eval()
81
+ return model
82
+
83
+
84
+ def _load_checkpoint(model, ckpt_path):
85
+ if ckpt_path is not None:
86
+ sd = torch.load(ckpt_path, map_location="cpu")["model"]
87
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
88
+ if missing_keys:
89
+ logging.error(missing_keys)
90
+ raise RuntimeError()
91
+ if unexpected_keys:
92
+ logging.error(unexpected_keys)
93
+ raise RuntimeError()
94
+ logging.info("Loaded checkpoint sucessfully")