Spaces:
Paused
Paused
# Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# -------------------------------------------------------- | |
# CroCo model for downstream tasks | |
# -------------------------------------------------------- | |
import torch | |
from .croco import CroCoNet | |
def croco_args_from_ckpt(ckpt):
    """Recover the CroCoNet constructor keyword arguments stored in a checkpoint.

    Three checkpoint layouts are recognised:
      * CroCo v2 released models: the kwargs are stored under 'croco_kwargs';
      * models pretrained with the official code release: ``ckpt['args'].model``
        holds the construction string, e.g.
        "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)";
      * CroCo v1 released models: nothing is stored, an empty dict is returned.

    Args:
        ckpt: checkpoint mapping (supports ``in`` and ``[]``).

    Returns:
        dict of keyword arguments for CroCoNet (empty for CroCo v1 models).

    Raises:
        ValueError: if ``ckpt['args'].model`` is not of the form "CroCoNet(...)".
    """
    if 'croco_kwargs' in ckpt:  # CroCo v2 released models
        return ckpt['croco_kwargs']
    if 'args' in ckpt and hasattr(ckpt['args'], 'model'):  # pretrained using the official code release
        prefix = 'CroCoNet'
        s = ckpt['args'].model.strip()
        # Explicit check rather than `assert`: asserts are removed under `python -O`,
        # which would let an arbitrary string reach eval() below.
        if not s.startswith(prefix + '('):
            raise ValueError(
                f"Unexpected model string in checkpoint args: {s!r} (expected 'CroCoNet(...)')")
        # SECURITY NOTE(review): eval() executes text read from the checkpoint;
        # only load checkpoints from trusted sources.
        # "CroCoNet(a=1, b=2)" -> "dict(a=1, b=2)" -> {'a': 1, 'b': 2}
        return eval('dict' + s[len(prefix):])
    # CroCo v1 released models
    return dict()
class CroCoDownstreamMonocularEncoder(CroCoNet):
    """CroCo network for monocular (single-image) downstream tasks, encoder only.

    The mask-generator, mask-token, decoder and prediction-head setters of
    CroCoNet are overridden as no-ops. Per the constructor's NOTE, this relies on
    CroCoNet.__init__ calling these setters. The task-specific ``head`` is called
    with the encoder features and an ``img_info`` dict.
    """

    def __init__(self,
                 head,
                 **kwargs):
        """ Build network for monocular downstream task, only using the encoder.

        It takes an extra argument head, that is called with the features
        and a dictionary img_info containing 'width' and 'height' keys.
        The head is set up with the croconet arguments in this init function.
        NOTE: It works by calling super().__init__() but with redefined setters.

        Args:
            head: downstream head; must provide ``setup(net)`` and be callable as
                ``head(features, img_info)``. An optional ``return_all_blocks``
                attribute requests the outputs of all encoder blocks.
            **kwargs: forwarded to CroCoNet.__init__.
        """
        super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
        # the head receives the fully built network so it can configure itself from it
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """ No mask generator """
        return

    def _set_mask_token(self, *args, **kwargs):
        """ No mask token """
        self.mask_token = None
        return

    def _set_decoder(self, *args, **kwargs):
        """ No decoder """
        return

    def _set_prediction_head(self, *args, **kwargs):
        """ No 'prediction head' for downstream tasks."""
        return

    def forward(self, img):
        """Encode ``img`` without masking and apply the downstream head.

        Args:
            img: image batch of size batch_size x 3 x h x w.

        Returns:
            Whatever ``self.head(features, img_info)`` returns, where img_info is
            ``{'height': h, 'width': w}``.
        """
        # unpacking into four names also rejects inputs that are not 4-D
        _, _, H, W = img.size()
        img_info = {'height': H, 'width': W}
        # a head may ask for the output of every encoder block, not just the last
        need_all_layers = getattr(self.head, 'return_all_blocks', False)
        out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers)
        return self.head(out, img_info)
class CroCoDownstreamBinocular(CroCoNet):
    """CroCo network (encoder + decoder) for binocular (image-pair) downstream tasks.

    The mask-generator, mask-token and prediction-head setters of CroCoNet are
    overridden as no-ops; the decoder is kept. The task-specific ``head`` is
    called with the decoder output and an ``img_info`` dict.
    """

    def __init__(self,
                 head,
                 **kwargs):
        """ Build network for binocular downstream task.

        It takes an extra argument head, that is called with the features
        and a dictionary img_info containing 'width' and 'height' keys.
        The head is set up with the croconet arguments in this init function.

        Args:
            head: downstream head; must provide ``setup(net)`` and be callable as
                ``head(features, img_info)``. An optional ``return_all_blocks``
                attribute requests per-block encoder and decoder outputs.
            **kwargs: forwarded to CroCoNet.__init__.
        """
        super(CroCoDownstreamBinocular, self).__init__(**kwargs)
        # the head receives the fully built network so it can configure itself from it
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """ No mask generator """
        return

    def _set_mask_token(self, *args, **kwargs):
        """ No mask token """
        self.mask_token = None
        return

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head for downstream tasks, define your own head """
        return

    def encode_image_pairs(self, img1, img2, return_all_blocks=False):
        """Run the encoder (without masking) on a pair of image batches.

        The two batches are concatenated along the batch dimension and encoded
        in a single pass, which is ~5% faster than encoding them separately. The
        result is then split back using the actual batch size of each input.
        Because of the concatenation, img1 and img2 must match in every
        dimension except the batch one.

        Args:
            img1, img2: image batches.
            return_all_blocks: if True, keep the outputs of every encoder block
                for img1 (only the last block is kept for img2).

        Returns:
            (out, out2, pos, pos2): encoder features of img1 (a list with one
            entry per block when return_all_blocks is True), the last-block
            features of img2, and the positions returned by the encoder for
            img1 and img2.
        """
        # Equivalent to calling self._encode_image on img1 (with return_all_blocks)
        # and on img2 (with return_all_blocks=False) separately, but faster.
        sizes = [img1.size(0), img2.size(0)]
        out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0), do_mask=False,
                                         return_all_blocks=return_all_blocks)
        if return_all_blocks:
            # split every block's output, then regroup into (img1 blocks, img2 blocks)
            out, out2 = [list(part) for part in zip(*[o.split(sizes, dim=0) for o in out])]
            out2 = out2[-1]
        else:
            out, out2 = out.split(sizes, dim=0)
        pos, pos2 = pos.split(sizes, dim=0)
        return out, out2, pos, pos2

    def forward(self, img1, img2):
        """Encode the image pair, decode, and apply the downstream head.

        Args:
            img1, img2: image batches of size batch_size x 3 x h x w.

        Returns:
            Whatever ``self.head(decout, img_info)`` returns, where img_info is
            ``{'height': h, 'width': w}`` taken from img1. When the head requests
            all blocks, ``decout`` is img1's per-block encoder outputs followed by
            the per-block decoder outputs.
        """
        # unpacking into four names also rejects inputs that are not 4-D
        _, _, H, W = img1.size()
        img_info = {'height': H, 'width': W}
        return_all_blocks = getattr(self.head, 'return_all_blocks', False)
        out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks)
        # NOTE(review): the third decoder argument (None) presumably is the mask -- confirm in CroCoNet._decoder
        if return_all_blocks:
            # the decoder consumes only the last encoder block of img1
            decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks)
            decout = out + decout
        else:
            decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks)
        return self.head(decout, img_info)