# Gaze-LLE: gazelle/model.py
import torch
import torch.nn as nn
import torchvision
from timm.models.vision_transformer import Block
import math
import gazelle.utils as utils
from gazelle.backbone import DinoV2Backbone
class GazeLLE(nn.Module):
def __init__(self, backbone, inout=False, dim=256, num_layers=3, in_size=(448, 448), out_size=(64, 64)):
super().__init__()
self.backbone = backbone
self.dim = dim
self.num_layers = num_layers
self.featmap_h, self.featmap_w = backbone.get_out_size(in_size)
self.in_size = in_size
self.out_size = out_size
self.inout = inout
self.linear = nn.Conv2d(backbone.get_dimension(), self.dim, 1)
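        # fixed 2D sinusoidal positional embedding over the backbone feature grid,
        # registered as a buffer so it moves with the module but is not trained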
self.register_buffer("pos_embed", positionalencoding2d(self.dim, self.featmap_h, self.featmap_w).squeeze(dim=0).squeeze(dim=0))
self.transformer = nn.Sequential(*[
Block(
dim=self.dim,
num_heads=8,
mlp_ratio=4,
drop_path=0.1)
for i in range(num_layers)
])
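        # heatmap decoder: 2x transposed-conv upsampling of the feature map, then a 1x1 conv
        # and a sigmoid to produce a single-channel gaze heatmap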
self.heatmap_head = nn.Sequential(
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
nn.Conv2d(dim, 1, kernel_size=1, bias=False),
nn.Sigmoid()
)
self.head_token = nn.Embedding(1, self.dim)
if self.inout:
self.inout_head = nn.Sequential(
nn.Linear(self.dim, 128),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(128, 1),
nn.Sigmoid()
)
self.inout_token = nn.Embedding(1, self.dim)
def forward(self, input):
# input["images"]: [B, 3, H, W] tensor of images
# input["bboxes"]: list of lists of bbox tuples [[(xmin, ymin, xmax, ymax)]] per image in normalized image coords
num_ppl_per_img = [len(bbox_list) for bbox_list in input["bboxes"]]
x = self.backbone.forward(input["images"])
x = self.linear(x)
x = x + self.pos_embed
x = utils.repeat_tensors(x, num_ppl_per_img) # repeat image features along people dimension per image
        head_maps = torch.cat(self.get_input_head_maps(input["bboxes"]), dim=0).to(x.device) # [sum(N_p), featmap_h, featmap_w]
head_map_embeddings = head_maps.unsqueeze(dim=1) * self.head_token.weight.unsqueeze(-1).unsqueeze(-1)
x = x + head_map_embeddings
x = x.flatten(start_dim=2).permute(0, 2, 1) # "b c h w -> b (h w) c"
if self.inout:
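            # prepend a learned in/out-of-frame query token so the transformer can also predict
            # whether each person's gaze target lies inside the frame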
x = torch.cat([self.inout_token.weight.unsqueeze(dim=0).repeat(x.shape[0], 1, 1), x], dim=1)
x = self.transformer(x)
if self.inout:
inout_tokens = x[:, 0, :]
inout_preds = self.inout_head(inout_tokens).squeeze(dim=-1)
inout_preds = utils.split_tensors(inout_preds, num_ppl_per_img)
x = x[:, 1:, :] # slice off inout tokens from scene tokens
x = x.reshape(x.shape[0], self.featmap_h, self.featmap_w, x.shape[2]).permute(0, 3, 1, 2) # b (h w) c -> b c h w
x = self.heatmap_head(x).squeeze(dim=1)
x = torchvision.transforms.functional.resize(x, self.out_size)
heatmap_preds = utils.split_tensors(x, num_ppl_per_img) # resplit per image
return {"heatmap": heatmap_preds, "inout": inout_preds if self.inout else None}
def get_input_head_maps(self, bboxes):
# bboxes: [[(xmin, ymin, xmax, ymax)]] - list of list of head bboxes per image
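        # e.g. with a 32x32 feature grid, a bbox of (0.25, 0.25, 0.5, 0.5) yields a 32x32 map
        # that is 1 inside the [8:16, 8:16] slice and 0 elsewhere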
head_maps = []
for bbox_list in bboxes:
img_head_maps = []
for bbox in bbox_list:
if bbox is None: # no bbox provided, use empty head map
img_head_maps.append(torch.zeros(self.featmap_h, self.featmap_w))
else:
xmin, ymin, xmax, ymax = bbox
width, height = self.featmap_w, self.featmap_h
xmin = round(xmin * width)
ymin = round(ymin * height)
xmax = round(xmax * width)
ymax = round(ymax * height)
head_map = torch.zeros((height, width))
head_map[ymin:ymax, xmin:xmax] = 1
img_head_maps.append(head_map)
head_maps.append(torch.stack(img_head_maps))
return head_maps
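    # state dict helpers: by default the DINOv2 backbone weights are excluded from saving and
    # loading, so checkpoints only need to carry the Gaze-LLE-specific weights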
def get_gazelle_state_dict(self, include_backbone=False):
if include_backbone:
return self.state_dict()
else:
return {k: v for k, v in self.state_dict().items() if not k.startswith("backbone")}
def load_gazelle_state_dict(self, ckpt_state_dict, include_backbone=False):
current_state_dict = self.state_dict()
keys1 = current_state_dict.keys()
keys2 = ckpt_state_dict.keys()
if not include_backbone:
keys1 = set([k for k in keys1 if not k.startswith("backbone")])
keys2 = set([k for k in keys2 if not k.startswith("backbone")])
else:
keys1 = set(keys1)
keys2 = set(keys2)
        if len(keys2 - keys1) > 0:
            print("WARNING: unused keys in provided state dict: ", keys2 - keys1)
        if len(keys1 - keys2) > 0:
            print("WARNING: provided state dict is missing values for keys: ", keys1 - keys2)
for k in list(keys1 & keys2):
current_state_dict[k] = ckpt_state_dict[k]
self.load_state_dict(current_state_dict, strict=False)
# From https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
def positionalencoding2d(d_model, height, width):
"""
:param d_model: dimension of the model
:param height: height of the positions
:param width: width of the positions
:return: d_model*height*width position matrix
"""
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "a dimension that is not divisible by 4 (got dim={:d})".format(d_model))
pe = torch.zeros(d_model, height, width)
    # each spatial axis (width, height) uses half of d_model
d_model = int(d_model / 2)
div_term = torch.exp(torch.arange(0., d_model, 2) *
-(math.log(10000.0) / d_model))
pos_w = torch.arange(0., width).unsqueeze(1)
pos_h = torch.arange(0., height).unsqueeze(1)
pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
return pe
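# e.g. positionalencoding2d(256, 32, 32) returns a [256, 32, 32] tensor whose first 128 channels
# encode horizontal position and whose last 128 channels encode vertical position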
# models
def get_gazelle_model(model_name):
factory = {
"gazelle_dinov2_vitb14": gazelle_dinov2_vitb14,
"gazelle_dinov2_vitl14": gazelle_dinov2_vitl14,
"gazelle_dinov2_vitb14_inout": gazelle_dinov2_vitb14_inout,
"gazelle_dinov2_vitl14_inout": gazelle_dinov2_vitl14_inout,
}
    assert model_name in factory, "invalid model name: {}".format(model_name)
return factory[model_name]()
def gazelle_dinov2_vitb14():
backbone = DinoV2Backbone('dinov2_vitb14')
transform = backbone.get_transform((448, 448))
model = GazeLLE(backbone)
return model, transform
def gazelle_dinov2_vitl14():
backbone = DinoV2Backbone('dinov2_vitl14')
transform = backbone.get_transform((448, 448))
model = GazeLLE(backbone)
return model, transform
def gazelle_dinov2_vitb14_inout():
backbone = DinoV2Backbone('dinov2_vitb14')
transform = backbone.get_transform((448, 448))
model = GazeLLE(backbone, inout=True)
return model, transform
def gazelle_dinov2_vitl14_inout():
backbone = DinoV2Backbone('dinov2_vitl14')
transform = backbone.get_transform((448, 448))
model = GazeLLE(backbone, inout=True)
return model, transform
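# Minimal usage sketch (not part of the original module). It assumes the DINOv2 backbone weights
# can be fetched via torch.hub and that the transform returned by the backbone accepts a PIL image;
# the commented-out checkpoint path is hypothetical.
if __name__ == "__main__":
    from PIL import Image

    model, transform = get_gazelle_model("gazelle_dinov2_vitb14")
    # model.load_gazelle_state_dict(torch.load("gazelle_dinov2_vitb14.pt", map_location="cpu"))  # hypothetical checkpoint
    model.eval()

    image = Image.new("RGB", (640, 480))  # placeholder image; use a real photo in practice
    images = transform(image).unsqueeze(dim=0)  # [1, 3, 448, 448]
    bboxes = [[(0.4, 0.2, 0.6, 0.5)]]  # one normalized (xmin, ymin, xmax, ymax) head bbox for image 0

    with torch.no_grad():
        preds = model({"images": images, "bboxes": bboxes})
    print(preds["heatmap"][0].shape)  # [1, 64, 64] gaze heatmap for the single person in image 0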