Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import torchvision | |
from timm.models.vision_transformer import Block | |
import math | |
import gazelle.utils as utils | |
from gazelle.backbone import DinoV2Backbone | |
class GazeLLE(nn.Module):
    """Gaze-target heatmap model built on top of an image backbone.

    The backbone feature map is projected to ``dim`` channels with a 1x1 conv,
    summed with a fixed 2D sin/cos positional encoding and with a learned
    head-token embedding painted over each person's head box, then passed
    through ``num_layers`` timm transformer blocks and decoded into one
    heatmap per person. With ``inout=True`` an extra learned token is
    prepended to the token sequence and decoded by a small MLP into one
    sigmoid score per person (presumably "gaze target is inside the frame" --
    confirm against the training code).

    :param backbone: module exposing ``get_out_size(in_size)``,
        ``get_dimension()`` and ``forward(images)``; its output is assumed to
        be a ``[B, C, featmap_h, featmap_w]`` feature map (TODO confirm).
    :param inout: if True, also predict a per-person in/out score.
    :param dim: transformer width; must be divisible by 4 (required by
        ``positionalencoding2d``).
    :param num_layers: number of transformer blocks.
    :param in_size: (H, W) input image size, used to derive the feature map size.
    :param out_size: (H, W) size the output heatmaps are resized to.
    """

    def __init__(self, backbone, inout=False, dim=256, num_layers=3, in_size=(448, 448), out_size=(64, 64)):
        super().__init__()
        self.backbone = backbone
        self.dim = dim
        self.num_layers = num_layers
        self.featmap_h, self.featmap_w = backbone.get_out_size(in_size)
        self.in_size = in_size
        self.out_size = out_size
        self.inout = inout
        # 1x1 conv projecting backbone channels down to the transformer width.
        self.linear = nn.Conv2d(backbone.get_dimension(), self.dim, 1)
        # Fixed positional encoding of shape [dim, featmap_h, featmap_w]. It is a
        # buffer (not a parameter) so it moves with .to(device) and is saved in
        # state_dict without being trained.
        self.register_buffer("pos_embed", positionalencoding2d(self.dim, self.featmap_h, self.featmap_w))
        self.transformer = nn.Sequential(*[
            Block(dim=self.dim, num_heads=8, mlp_ratio=4, drop_path=0.1)
            for _ in range(num_layers)
        ])
        # Upsample 2x, collapse to a single channel, squash to [0, 1].
        self.heatmap_head = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
            nn.Conv2d(dim, 1, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        # Learned embedding added at every feature-map cell covered by a head box.
        self.head_token = nn.Embedding(1, self.dim)
        if self.inout:
            self.inout_head = nn.Sequential(
                nn.Linear(self.dim, 128),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(128, 1),
                nn.Sigmoid(),
            )
            # Learned token prepended to the sequence; its output feeds inout_head.
            self.inout_token = nn.Embedding(1, self.dim)

    def forward(self, input):
        """Predict a gaze heatmap (and optionally an in/out score) per person.

        :param input: dict with
            ``"images"``: ``[B, 3, H, W]`` tensor of images;
            ``"bboxes"``: one list per image of head boxes
            ``(xmin, ymin, xmax, ymax)`` in normalized image coordinates, or
            ``None`` for a person without a box.
        :return: dict with ``"heatmap"``: the ``[sum(N_i), *out_size]`` heatmaps
            re-split per image by ``utils.split_tensors``; ``"inout"``: the
            per-person scores split the same way when ``inout=True``, else None.
        """
        num_ppl_per_img = [len(bbox_list) for bbox_list in input["bboxes"]]
        x = self.backbone.forward(input["images"])
        x = self.linear(x)
        x = x + self.pos_embed
        # Repeat each image's features once per person in that image.
        x = utils.repeat_tensors(x, num_ppl_per_img)
        # Binary head maps, [sum(N_i), featmap_h, featmap_w]. Cast to x's dtype
        # so a half-precision model is not silently promoted to float32.
        head_maps = torch.cat(self.get_input_head_maps(input["bboxes"]), dim=0).to(device=x.device, dtype=x.dtype)
        # [N, 1, h, w] * [1, dim, 1, 1] -> head token painted over each head box.
        head_map_embeddings = head_maps.unsqueeze(dim=1) * self.head_token.weight.unsqueeze(-1).unsqueeze(-1)
        x = x + head_map_embeddings
        x = x.flatten(start_dim=2).permute(0, 2, 1)  # "b c h w -> b (h w) c"

        if self.inout:
            inout_token = self.inout_token.weight.unsqueeze(dim=0).repeat(x.shape[0], 1, 1)
            x = torch.cat([inout_token, x], dim=1)
        x = self.transformer(x)

        inout_preds = None
        if self.inout:
            inout_preds = self.inout_head(x[:, 0, :]).squeeze(dim=-1)
            inout_preds = utils.split_tensors(inout_preds, num_ppl_per_img)
            x = x[:, 1:, :]  # drop the in/out token, keep the scene tokens

        # "b (h w) c -> b c h w"
        x = x.reshape(x.shape[0], self.featmap_h, self.featmap_w, x.shape[2]).permute(0, 3, 1, 2)
        x = self.heatmap_head(x).squeeze(dim=1)
        x = torchvision.transforms.functional.resize(x, self.out_size)
        heatmap_preds = utils.split_tensors(x, num_ppl_per_img)  # re-split per image
        return {"heatmap": heatmap_preds, "inout": inout_preds}

    def get_input_head_maps(self, bboxes):
        """Rasterize head boxes into binary masks at feature-map resolution.

        :param bboxes: one list per image of ``(xmin, ymin, xmax, ymax)`` boxes
            in normalized image coordinates, or ``None`` for "no box" (all-zero map).
        :return: list with one ``[N_i, featmap_h, featmap_w]`` float tensor per
            image; an image with no people yields a ``[0, featmap_h, featmap_w]``
            tensor instead of failing in ``torch.stack``.
        """
        height, width = self.featmap_h, self.featmap_w

        def to_cell(coord, size):
            # Clamp to [0, size]: a negative value would otherwise act as a
            # negative slice index and wrap around to the far edge of the map.
            return min(max(round(coord * size), 0), size)

        head_maps = []
        for bbox_list in bboxes:
            img_head_maps = []
            for bbox in bbox_list:
                head_map = torch.zeros((height, width))
                if bbox is not None:
                    xmin, ymin, xmax, ymax = bbox
                    head_map[to_cell(ymin, height):to_cell(ymax, height),
                             to_cell(xmin, width):to_cell(xmax, width)] = 1
                img_head_maps.append(head_map)
            if img_head_maps:
                head_maps.append(torch.stack(img_head_maps))
            else:
                head_maps.append(torch.zeros((0, height, width)))
        return head_maps

    def get_gazelle_state_dict(self, include_backbone=False):
        """Return the state dict, dropping ``backbone*`` entries unless ``include_backbone``."""
        if include_backbone:
            return self.state_dict()
        return {k: v for k, v in self.state_dict().items() if not k.startswith("backbone")}

    def load_gazelle_state_dict(self, ckpt_state_dict, include_backbone=False):
        """Copy matching entries of ``ckpt_state_dict`` into the model.

        Keys starting with ``"backbone"`` are ignored on both sides unless
        ``include_backbone`` is True. Keys present on only one side are
        reported with a printed warning and otherwise skipped; loading uses
        ``strict=False``.
        """
        current_state_dict = self.state_dict()

        def relevant(keys):
            return {k for k in keys if include_backbone or not k.startswith("backbone")}

        model_keys = relevant(current_state_dict.keys())
        ckpt_keys = relevant(ckpt_state_dict.keys())
        if ckpt_keys - model_keys:
            print("WARNING unused keys in provided state dict: ", ckpt_keys - model_keys)
        if model_keys - ckpt_keys:
            print("WARNING provided state dict does not have values for keys: ", model_keys - ckpt_keys)
        for k in model_keys & ckpt_keys:
            current_state_dict[k] = ckpt_state_dict[k]
        self.load_state_dict(current_state_dict, strict=False)
# From https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
def positionalencoding2d(d_model, height, width):
    """
    Build a fixed 2D sinusoidal positional encoding.

    The first half of the channels encodes the column (x) position and the
    second half the row (y) position; within each half, even channels hold
    sin and odd channels hold cos of the position times a geometric series
    of frequencies.

    :param d_model: number of channels; must be divisible by 4
    :param height: height of the positions
    :param width: width of the positions
    :return: float tensor of shape [d_model, height, width]
    :raises ValueError: if d_model is not divisible by 4
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with a dimension "
                         f"that is not divisible by 4 (got dim={d_model})")
    pe = torch.zeros(d_model, height, width)
    # Each spatial axis uses half of the channels.
    half = d_model // 2
    div_term = torch.exp(torch.arange(0., half, 2) * -(math.log(10000.0) / half))
    pos_w = torch.arange(0., width).unsqueeze(1)   # [W, 1]
    pos_h = torch.arange(0., height).unsqueeze(1)  # [H, 1]
    # [half/2, 1, W] broadcasts over rows; [half/2, H, 1] broadcasts over columns.
    pe[0:half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1)
    pe[1:half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1)
    pe[half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2)
    pe[half + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2)
    return pe
# models
def get_gazelle_model(model_name):
    """
    Build a GazeLLE variant by name.

    :param model_name: one of "gazelle_dinov2_vitb14", "gazelle_dinov2_vitl14",
        "gazelle_dinov2_vitb14_inout", "gazelle_dinov2_vitl14_inout"
    :return: whatever the matching factory returns -- for the factories in this
        file, a (model, transform) pair
    :raises ValueError: if model_name is not one of the names above
    """
    factory = {
        "gazelle_dinov2_vitb14": gazelle_dinov2_vitb14,
        "gazelle_dinov2_vitl14": gazelle_dinov2_vitl14,
        "gazelle_dinov2_vitb14_inout": gazelle_dinov2_vitb14_inout,
        "gazelle_dinov2_vitl14_inout": gazelle_dinov2_vitl14_inout,
    }
    if model_name not in factory:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError(f"invalid model name: {model_name!r} (expected one of {sorted(factory)})")
    return factory[model_name]()
def gazelle_dinov2_vitb14():
    """
    GazeLLE on the 'dinov2_vitb14' backbone, without the in/out head.

    :return: (model, transform) where transform is the backbone's
        get_transform((448, 448))
    """
    dino = DinoV2Backbone('dinov2_vitb14')
    preprocess = dino.get_transform((448, 448))
    return GazeLLE(dino), preprocess
def gazelle_dinov2_vitl14():
    """
    GazeLLE on the 'dinov2_vitl14' backbone, without the in/out head.

    :return: (model, transform) where transform is the backbone's
        get_transform((448, 448))
    """
    dino = DinoV2Backbone('dinov2_vitl14')
    preprocess = dino.get_transform((448, 448))
    return GazeLLE(dino), preprocess
def gazelle_dinov2_vitb14_inout():
    """
    GazeLLE on the 'dinov2_vitb14' backbone, with the in/out head enabled.

    :return: (model, transform) where transform is the backbone's
        get_transform((448, 448))
    """
    dino = DinoV2Backbone('dinov2_vitb14')
    preprocess = dino.get_transform((448, 448))
    return GazeLLE(dino, inout=True), preprocess
def gazelle_dinov2_vitl14_inout():
    """
    GazeLLE on the 'dinov2_vitl14' backbone, with the in/out head enabled.

    :return: (model, transform) where transform is the backbone's
        get_transform((448, 448))
    """
    dino = DinoV2Backbone('dinov2_vitl14')
    preprocess = dino.get_transform((448, 448))
    return GazeLLE(dino, inout=True), preprocess