geopavlakos commited on
Commit
6cacefa
·
1 Parent(s): d7a991a

Expand bbox

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -46,6 +46,12 @@ DEFAULT_CHECKPOINT='_DATA/hamer_ckpts/checkpoints/hamer.ckpt'
46
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
47
  model_cfg = str(Path(DEFAULT_CHECKPOINT).parent.parent / 'model_config.yaml')
48
  model_cfg = get_config(model_cfg)
 
 
 
 
 
 
49
  model = HAMER.load_from_checkpoint(DEFAULT_CHECKPOINT, strict=False, cfg=model_cfg).to(device)
50
  model.eval()
51
 
 
46
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
47
  model_cfg = str(Path(DEFAULT_CHECKPOINT).parent.parent / 'model_config.yaml')
48
  model_cfg = get_config(model_cfg)
49
+ # Override some config values, to crop bbox correctly
50
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
51
+ model_cfg.defrost()
52
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
53
+ model_cfg.MODEL.BBOX_SHAPE = [192,256]
54
+ model_cfg.freeze()
55
  model = HAMER.load_from_checkpoint(DEFAULT_CHECKPOINT, strict=False, cfg=model_cfg).to(device)
56
  model.eval()
57