# GazeLLE model definition. (Web-viewer page header and line-number gutter
# from the scraped copy removed; they were not part of the source file.)
import torch
import torch.nn as nn
import torchvision
from timm.models.vision_transformer import Block
import math
import gazelle.utils as utils
from gazelle.backbone import DinoV2Backbone
class GazeLLE(nn.Module):
    """Gaze-LLE head on top of a frozen vision backbone.

    Given a batch of images and, per image, a list of head bounding boxes,
    predicts one gaze-target heatmap per head and (optionally, when
    ``inout=True``) a scalar in/out-of-frame probability per head.

    Architecture (as wired below): backbone features -> 1x1 conv to ``dim``
    channels -> fixed 2D sinusoidal positional embedding -> per-head
    bbox-mask embedding -> small ViT-block transformer -> upsampling
    heatmap head (and an optional MLP on a prepended in/out token).
    """

    def __init__(self, backbone, inout=False, dim=256, num_layers=3, in_size=(448, 448), out_size=(64, 64)):
        """
        Args:
            backbone: feature extractor exposing ``forward``, ``get_out_size``
                and ``get_dimension`` (see gazelle.backbone).
            inout: if True, also predict an in/out-of-frame score per head.
            dim: internal transformer channel width.
            num_layers: number of transformer blocks.
            in_size: (H, W) of input images fed to the backbone.
            out_size: (H, W) of the returned heatmaps.
        """
        super().__init__()
        self.backbone = backbone
        self.dim = dim
        self.num_layers = num_layers
        # Spatial size of the backbone feature map for the given input size.
        self.featmap_h, self.featmap_w = backbone.get_out_size(in_size)
        self.in_size = in_size
        self.out_size = out_size
        self.inout = inout
        # 1x1 conv projecting backbone channels down to the transformer width.
        self.linear = nn.Conv2d(backbone.get_dimension(), self.dim, 1)
        # Fixed (non-learned) 2D sinusoidal positional encoding, registered as a
        # buffer so it moves with the module but is excluded from optimization.
        # NOTE(review): the two squeeze(dim=0) calls are no-ops on the
        # [dim, h, w] tensor returned by positionalencoding2d (dim > 1).
        self.register_buffer("pos_embed", positionalencoding2d(self.dim, self.featmap_h, self.featmap_w).squeeze(dim=0).squeeze(dim=0))
        self.transformer = nn.Sequential(*[
            Block(
                dim=self.dim,
                num_heads=8,
                mlp_ratio=4,
                drop_path=0.1)
            for i in range(num_layers)
        ])
        # Upsample 2x then collapse to a single-channel heatmap in [0, 1].
        self.heatmap_head = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
            nn.Conv2d(dim, 1, kernel_size=1, bias=False),
            nn.Sigmoid()
        )
        # Learned embedding that marks which feature-map cells fall inside a
        # person's head bbox (scaled by the binary head map in forward()).
        self.head_token = nn.Embedding(1, self.dim)
        if self.inout:
            # MLP applied to the prepended in/out token after the transformer.
            self.inout_head = nn.Sequential(
                nn.Linear(self.dim, 128),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(128, 1),
                nn.Sigmoid()
            )
            self.inout_token = nn.Embedding(1, self.dim)

    def forward(self, input):
        """Predict gaze heatmaps (and optional in/out scores).

        Args:
            input: dict with
                "images": [B, 3, H, W] tensor of images
                "bboxes": list (per image) of lists of head bbox tuples
                    (xmin, ymin, xmax, ymax) in normalized [0, 1] image
                    coords; an entry may be None for "no bbox".

        Returns:
            dict with
                "heatmap": list (per image) of [N_p, out_h, out_w] tensors
                "inout": list (per image) of [N_p] tensors, or None when
                    the model was built with inout=False.
        """
        num_ppl_per_img = [len(bbox_list) for bbox_list in input["bboxes"]]
        x = self.backbone.forward(input["images"])
        x = self.linear(x)
        x = x + self.pos_embed
        x = utils.repeat_tensors(x, num_ppl_per_img) # repeat image features along people dimension per image
        # Binary masks marking each person's head region at feature-map resolution.
        head_maps = torch.cat(self.get_input_head_maps(input["bboxes"]), dim=0).to(x.device) # [sum(N_p), 32, 32]
        # Broadcast the learned head token over the masked cells only.
        head_map_embeddings = head_maps.unsqueeze(dim=1) * self.head_token.weight.unsqueeze(-1).unsqueeze(-1)
        x = x + head_map_embeddings
        x = x.flatten(start_dim=2).permute(0, 2, 1) # "b c h w -> b (h w) c"
        if self.inout:
            # Prepend one in/out token per person-instance (position 0).
            x = torch.cat([self.inout_token.weight.unsqueeze(dim=0).repeat(x.shape[0], 1, 1), x], dim=1)
        x = self.transformer(x)
        if self.inout:
            inout_tokens = x[:, 0, :]
            inout_preds = self.inout_head(inout_tokens).squeeze(dim=-1)
            inout_preds = utils.split_tensors(inout_preds, num_ppl_per_img)
            x = x[:, 1:, :] # slice off inout tokens from scene tokens
        x = x.reshape(x.shape[0], self.featmap_h, self.featmap_w, x.shape[2]).permute(0, 3, 1, 2) # b (h w) c -> b c h w
        x = self.heatmap_head(x).squeeze(dim=1)
        x = torchvision.transforms.functional.resize(x, self.out_size)
        heatmap_preds = utils.split_tensors(x, num_ppl_per_img) # resplit per image
        # inout_preds is only bound when self.inout is True; the conditional
        # expression below guards the reference accordingly.
        return {"heatmap": heatmap_preds, "inout": inout_preds if self.inout else None}

    def get_input_head_maps(self, bboxes):
        """Rasterize normalized head bboxes into binary feature-map masks.

        Args:
            bboxes: list (per image) of lists of (xmin, ymin, xmax, ymax)
                tuples in normalized coords, or None entries.

        Returns:
            list of [N_p, featmap_h, featmap_w] float tensors, one per image.
        """
        head_maps = []
        for bbox_list in bboxes:
            img_head_maps = []
            for bbox in bbox_list:
                if bbox is None: # no bbox provided, use empty head map
                    img_head_maps.append(torch.zeros(self.featmap_h, self.featmap_w))
                else:
                    xmin, ymin, xmax, ymax = bbox
                    width, height = self.featmap_w, self.featmap_h
                    # Scale normalized coords to feature-map cell indices.
                    xmin = round(xmin * width)
                    ymin = round(ymin * height)
                    xmax = round(xmax * width)
                    ymax = round(ymax * height)
                    head_map = torch.zeros((height, width))
                    head_map[ymin:ymax, xmin:xmax] = 1
                    img_head_maps.append(head_map)
            head_maps.append(torch.stack(img_head_maps))
        return head_maps

    def get_gazelle_state_dict(self, include_backbone=False):
        """Return the state dict, by default excluding frozen backbone weights."""
        if include_backbone:
            return self.state_dict()
        else:
            return {k: v for k, v in self.state_dict().items() if not k.startswith("backbone")}

    def load_gazelle_state_dict(self, ckpt_state_dict, include_backbone=False):
        """Load a (possibly backbone-less) checkpoint, warning on key mismatches.

        Only keys present in both the current model and the checkpoint are
        copied; missing/extra keys are reported but do not raise.
        """
        current_state_dict = self.state_dict()
        keys1 = current_state_dict.keys()
        keys2 = ckpt_state_dict.keys()

        if not include_backbone:
            keys1 = set([k for k in keys1 if not k.startswith("backbone")])
            keys2 = set([k for k in keys2 if not k.startswith("backbone")])
        else:
            keys1 = set(keys1)
            keys2 = set(keys2)

        if len(keys2 - keys1) > 0:
            print("WARNING unused keys in provided state dict: ", keys2 - keys1)
        if len(keys1 - keys2) > 0:
            print("WARNING provided state dict does not have values for keys: ", keys1 - keys2)

        for k in list(keys1 & keys2):
            current_state_dict[k] = ckpt_state_dict[k]

        self.load_state_dict(current_state_dict, strict=False)
# From https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
def positionalencoding2d(d_model, height, width):
    """Build a fixed 2D sinusoidal positional encoding.

    The first half of the channels encodes the x (width) position and the
    second half the y (height) position, each as interleaved sin/cos pairs
    at geometrically spaced frequencies (Transformer-style).

    :param d_model: number of channels; must be divisible by 4 (half the
        channels per axis, and sin/cos pairs within each half)
    :param height: height of the positions
    :param width: width of the positions
    :return: [d_model, height, width] position tensor
    :raises ValueError: if d_model is not divisible by 4
    """
    if d_model % 4 != 0:
        # Fixed error message: the check rejects any d_model not divisible
        # by 4, not merely odd dimensions (e.g. 6 is even but invalid).
        raise ValueError("Cannot use sin/cos positional encoding with a "
                         "dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each dimension use half of d_model
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0., d_model, 2) *
                         -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    # First half of channels: x-position, sin on even / cos on odd channels.
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    # Second half of channels: y-position, same interleaving.
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe
# models
def get_gazelle_model(model_name):
    """Build a GazeLLE model by registry name.

    :param model_name: one of the registered constructor names below
    :return: (model, transform) tuple from the matching constructor
    :raises AssertionError: if model_name is not registered
    """
    builders = {
        "gazelle_dinov2_vitb14": gazelle_dinov2_vitb14,
        "gazelle_dinov2_vitl14": gazelle_dinov2_vitl14,
        "gazelle_dinov2_vitb14_inout": gazelle_dinov2_vitb14_inout,
        "gazelle_dinov2_vitl14_inout": gazelle_dinov2_vitl14_inout,
    }
    assert model_name in builders, "invalid model name"
    return builders[model_name]()
def gazelle_dinov2_vitb14():
    """GazeLLE on a DINOv2 ViT-B/14 backbone (heatmap only); returns (model, transform)."""
    bb = DinoV2Backbone('dinov2_vitb14')
    preprocess = bb.get_transform((448, 448))
    return GazeLLE(bb), preprocess
def gazelle_dinov2_vitl14():
    """GazeLLE on a DINOv2 ViT-L/14 backbone (heatmap only); returns (model, transform)."""
    bb = DinoV2Backbone('dinov2_vitl14')
    preprocess = bb.get_transform((448, 448))
    return GazeLLE(bb), preprocess
def gazelle_dinov2_vitb14_inout():
    """GazeLLE on a DINOv2 ViT-B/14 backbone with in/out-of-frame head; returns (model, transform)."""
    bb = DinoV2Backbone('dinov2_vitb14')
    preprocess = bb.get_transform((448, 448))
    return GazeLLE(bb, inout=True), preprocess
def gazelle_dinov2_vitl14_inout():
    """GazeLLE on a DINOv2 ViT-L/14 backbone with in/out-of-frame head; returns (model, transform)."""
    bb = DinoV2Backbone('dinov2_vitl14')
    preprocess = bb.get_transform((448, 448))
    return GazeLLE(bb, inout=True), preprocess