import warnings
import math
import random

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class ResidualBlock(nn.Module):
    """3x3 convolutional residual block with an identity skip connection."""

    def __init__(self, channel):
        super(ResidualBlock, self).__init__()
        self.channel = channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=channel,
                      out_channels=channel,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.BatchNorm2d(channel),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out += x  # identity skip connection
        out = F.relu(out)
        return out


class ResNet(nn.Module):

    def __init__(self):
        super(ResNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=448, kernel_size=5),
            nn.BatchNorm2d(448),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=448, out_channels=640, kernel_size=5),
            nn.BatchNorm2d(640),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=640, out_channels=1024, kernel_size=5),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.reslayer1 = ResidualBlock(448)
        self.reslayer2 = ResidualBlock(640)
        self.reslayer3 = ResidualBlock(1024)

    def forward(self, x):
        out = self.conv1(x)
        out = self.reslayer1(out)

        out = self.conv2(out)
        out = self.reslayer2(out)

        out = self.conv3(out)
        out = self.reslayer3(out)

        return out


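# Hedged sanity-check sketch (not part of the model): traces the spatial shapes
# produced by ResNet for a 384-channel, 64x64 input. The input size is an
# illustrative assumption, not a value taken from any config in this repository.
def _demo_resnet_shapes():
    x = torch.randn(1, 384, 64, 64)
    out = ResNet()(x)
    # conv1 (k=5, no pad) -> 60, pool -> 30; conv2 -> 26, pool -> 13; conv3 -> 9, pool -> 4
    assert out.shape == (1, 1024, 4, 4)
    return out.shape

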
class Loss_Function(nn.Module):
    """Mean-squared reconstruction loss against detached (frozen) target features."""

    def __init__(self, args):
        super().__init__()
        self.mse = nn.MSELoss(reduction="none")

        self.token_num = args.model.SLOTS.token_num
        self.num_slots = args.model.SLOTS.num_slots

        self.epsilon = 1e-8

    def forward(self, reconstruction, masks, target):
        # The targets come from a frozen encoder, so no gradient flows into them.
        target = target.detach()
        loss = self.mse(reconstruction, target).mean()

        return loss


class MLP(nn.Module):
    """Two-layer MLP with optional residual connection and pre-/post-LayerNorm."""

    def __init__(self, input_dim, hidden_dim, output_dim, residual=False, layer_order="none"):
        super().__init__()
        self.residual = residual
        self.layer_order = layer_order
        if residual:
            assert input_dim == output_dim

        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.1)

        if layer_order in ["pre", "post"]:
            self.norm = nn.LayerNorm(input_dim)
        else:
            assert layer_order == "none"

    def forward(self, x):
        skip = x

        if self.layer_order == "pre":
            x = self.norm(x)

        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.dropout(x)

        if self.residual:
            x = x + skip
        if self.layer_order == "post":
            x = self.norm(x)

        return x


class Visual_Encoder(nn.Module):
    """Frozen pretrained ViT backbone that returns patch tokens (CLS token removed)."""

    def __init__(self, args):
        super().__init__()

        self.resize_to = args.resize_to
        self.token_num = args.token_num

        self.encoder = args.encoder

        self.model = self.load_model(args)

    def load_model(self, args):
        # The input resolution must be an integer number of patches.
        assert args.resize_to[0] % args.patch_size == 0
        assert args.resize_to[1] % args.patch_size == 0

        if args.encoder == "dino-vitb-8":
            model = torch.hub.load("facebookresearch/dino:main", "dino_vitb8")
        elif args.encoder == "dino-vitb-16":
            model = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
        elif args.encoder == "dinov2-vitb-14":
            model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
        elif args.encoder == "sup-vitb-16":
            model = timm.create_model("vit_base_patch16_224", pretrained=True,
                                      img_size=(args.resize_to[0], args.resize_to[1]))
        else:
            raise ValueError(f"Unknown encoder: {args.encoder}")

        # The backbone stays frozen; only the slot modules are trained.
        for p in model.parameters():
            p.requires_grad = False

        return model

    @torch.no_grad()
    def forward(self, frames):
        B = frames.shape[0]

        self.model.eval()

        if self.encoder.startswith("dinov2-"):
            x = self.model.prepare_tokens_with_masks(frames)
        elif self.encoder.startswith("sup-"):
            x = self.model.patch_embed(frames)
            x = self.model._pos_embed(x)
        else:
            x = self.model.prepare_tokens(frames)

        for blk in self.model.blocks:
            x = blk(x)
        x = x[:, 1:]  # drop the CLS token, keep only patch tokens

        assert x.shape[0] == B
        assert x.shape[1] == self.token_num
        # NOTE: ViT-B backbones emit 768-dim tokens; this check (and the 1024-dim
        # feature size assumed downstream) expects a 1024-dim backbone.
        assert x.shape[2] == 1024

        return x


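# Hedged sketch (illustrative values only): how the token count relates to the input
# resolution and ViT patch size used by Visual_Encoder. resize_to=(224, 224) and
# patch_size=16 are assumptions for this example, not values taken from the repo.
def _demo_token_count(resize_to=(224, 224), patch_size=16):
    tokens = (resize_to[0] // patch_size) * (resize_to[1] // patch_size)
    # e.g. (224 / 16) * (224 / 16) = 14 * 14 = 196 patch tokens, which must match args.token_num
    return tokens

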
class Decoder(nn.Module):
    """Spatial-broadcast MLP decoder: maps each broadcast slot token to a feature
    vector (1024 channels) plus one alpha logit used for the slot mixture weights."""

    def __init__(self, args):
        super().__init__()

        slot_dim = args['slot_dim']
        hidden_dim = 2048

        self.layer1 = nn.Linear(slot_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, hidden_dim)
        self.layer4 = nn.Linear(hidden_dim, 1024 + 1)  # 1024 feature channels + 1 alpha logit
        self.relu = nn.ReLU(inplace=True)

    def forward(self, slot_maps):
        slot_maps = self.relu(self.layer1(slot_maps))
        slot_maps = self.relu(self.layer2(slot_maps))
        slot_maps = self.relu(self.layer3(slot_maps))

        slot_maps = self.layer4(slot_maps)

        return slot_maps


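# Hedged sketch (not part of the pipeline): feeds random broadcast slot maps through
# the decoder to show the output layout. The batch size, slot count, token count and
# slot_dim below are illustrative assumptions.
def _demo_decoder_output_layout():
    decoder = Decoder({'slot_dim': 256})
    slot_maps = torch.randn(2, 7, 196, 256)    # (B, S, token_num, slot_dim)
    out = decoder(slot_maps)
    assert out.shape == (2, 7, 196, 1024 + 1)  # per-token features + alpha logit
    return out.shape

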
class ISA(nn.Module):
    """Invariant Slot Attention: slot attention with per-slot position (S_p) and
    scale (S_s) statistics that define a relative coordinate grid for each slot."""

    def __init__(self, args, input_dim):
        super().__init__()

        self.num_slots = args.num_slots
        self.scale = args.slot_dim ** -0.5
        self.iters = args.slot_att_iter
        self.slot_dim = args.slot_dim
        self.query_opt = args.query_opt

        self.res_h = args.resize_to[0] // args.patch_size
        self.res_w = args.resize_to[1] // args.patch_size
        self.token = int(self.res_h * self.res_w)

        # Absolute coordinate grid over the patch tokens, in [-1, 1] x [-1, 1].
        self.sigma = 5
        xs = torch.linspace(-1, 1, steps=self.res_w)
        ys = torch.linspace(-1, 1, steps=self.res_h)

        xs, ys = torch.meshgrid(xs, ys, indexing='xy')
        xs = xs.reshape(1, 1, -1, 1)
        ys = ys.reshape(1, 1, -1, 1)
        self.abs_grid = nn.Parameter(torch.cat([xs, ys], dim=-1), requires_grad=False)
        assert self.abs_grid.shape[2] == self.token

        self.h = nn.Linear(2, self.slot_dim)

        # Slot initialisation: either learned queries or samples from a learned Gaussian.
        if self.query_opt:
            self.slots = nn.Parameter(torch.Tensor(1, self.num_slots, self.slot_dim))
            init.xavier_uniform_(self.slots)
        else:
            self.slots_mu = nn.Parameter(torch.randn(1, 1, self.slot_dim))
            self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.slot_dim))
            init.xavier_uniform_(self.slots_mu)
            init.xavier_uniform_(self.slots_logsigma)

        # Initial per-slot scale and position statistics.
        self.S_s = nn.Parameter(torch.Tensor(1, self.num_slots, 1, 2))
        self.S_p = nn.Parameter(torch.Tensor(1, self.num_slots, 1, 2))

        init.normal_(self.S_s, mean=0., std=.02)
        init.normal_(self.S_p, mean=0., std=.02)

        # Slot-side projection and update modules.
        self.Q = nn.Linear(self.slot_dim, self.slot_dim, bias=False)
        self.norm = nn.LayerNorm(self.slot_dim)
        self.gru = nn.GRUCell(self.slot_dim, self.slot_dim)
        self.mlp = MLP(self.slot_dim, 4 * self.slot_dim, self.slot_dim,
                       residual=True, layer_order="pre")

        # Input-side projections; g/f inject the relative grid into keys and values.
        self.K = nn.Linear(self.slot_dim, self.slot_dim, bias=False)
        self.V = nn.Linear(self.slot_dim, self.slot_dim, bias=False)

        self.g = nn.Linear(2, self.slot_dim)
        self.f = nn.Sequential(nn.Linear(self.slot_dim, self.slot_dim),
                               nn.ReLU(inplace=True),
                               nn.Linear(self.slot_dim, self.slot_dim))

        self.initial_mlp = nn.Sequential(nn.LayerNorm(input_dim),
                                         nn.Linear(input_dim, input_dim),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(input_dim, self.slot_dim),
                                         nn.LayerNorm(self.slot_dim))

        self.final_layer = nn.Linear(self.slot_dim, self.slot_dim)

    def get_rel_grid(self, attn):
        # attn: (B, S, N) attention of each slot over the N tokens.
        B, S = attn.shape[:2]
        attn = attn.unsqueeze(dim=2)

        abs_grid = self.abs_grid.expand(B, S, self.token, 2)

        # S_p: attention-weighted mean position of each slot.
        S_p = torch.einsum('bsjd,bsij->bsd', abs_grid, attn)
        S_p = S_p.unsqueeze(dim=2)

        # S_s: attention-weighted standard deviation around S_p.
        values_ss = torch.pow(abs_grid - S_p, 2)
        S_s = torch.einsum('bsjd,bsij->bsd', values_ss, attn)
        S_s = torch.sqrt(S_s)
        S_s = S_s.unsqueeze(dim=2)

        # Relative grid: absolute coordinates normalised by each slot's position and scale.
        rel_grid = (abs_grid - S_p) / (S_s * self.sigma)
        rel_grid = self.h(rel_grid)

        return rel_grid

    def forward(self, inputs):
        # inputs: (B, N, D) frozen visual tokens.
        B, N, D = inputs.shape
        S = self.num_slots
        D_slot = self.slot_dim
        epsilon = 1e-8

        if self.query_opt:
            slots = self.slots.expand(B, S, D_slot)
        else:
            mu = self.slots_mu.expand(B, S, D_slot)
            sigma = self.slots_logsigma.exp().expand(B, S, D_slot)
            slots = mu + sigma * torch.randn(mu.shape, device=sigma.device, dtype=sigma.dtype)

        slots_init = slots
        inputs = self.initial_mlp(inputs).unsqueeze(dim=1)
        inputs = inputs.expand(B, S, N, D_slot)

        abs_grid = self.abs_grid.expand(B, S, self.token, 2)

        assert torch.sum(torch.isnan(abs_grid)) == 0

        S_s = self.S_s.expand(B, S, 1, 2)
        S_p = self.S_p.expand(B, S, 1, 2)

        # One extra pass (iters + 1): the final iteration only refreshes attn, S_p and S_s
        # without pushing the slots through the GRU update.
        for t in range(self.iters + 1):

            assert torch.sum(torch.isnan(slots)) == 0, f"Iteration {t}"
            assert torch.sum(torch.isnan(S_s)) == 0, f"Iteration {t}"
            assert torch.sum(torch.isnan(S_p)) == 0, f"Iteration {t}"

            if self.query_opt and (t == self.iters - 1):
                # Straight-through trick: detach the iterated slots but keep the
                # gradient path to the learned initial queries.
                slots = slots.detach() + slots_init - slots_init.detach()

            slots_prev = slots
            slots = self.norm(slots)

            # Keys and values are conditioned on each slot's relative coordinate grid.
            rel_grid = (abs_grid - S_p) / (S_s * self.sigma)
            k = self.f(self.K(inputs) + self.g(rel_grid))
            v = self.f(self.V(inputs) + self.g(rel_grid))

            q = self.Q(slots).unsqueeze(dim=-1)

            dots = torch.einsum('bsdi,bsjd->bsj', q, k)
            dots *= self.scale
            # Softmax over the slot dimension so that the slots compete for tokens.
            attn = dots.softmax(dim=1) + epsilon

            attn = attn / attn.sum(dim=-1, keepdim=True)
            attn = attn.unsqueeze(dim=2)
            updates = torch.einsum('bsjd,bsij->bsd', v, attn)

            # Update per-slot position (weighted mean) and scale (weighted std).
            S_p = torch.einsum('bsjd,bsij->bsd', abs_grid, attn)
            S_p = S_p.unsqueeze(dim=2)

            values_ss = torch.pow(abs_grid - S_p, 2)
            S_s = torch.einsum('bsjd,bsij->bsd', values_ss, attn)
            S_s = torch.sqrt(S_s)
            S_s = S_s.unsqueeze(dim=2)

            if t != self.iters:
                slots = self.gru(
                    updates.reshape(-1, self.slot_dim),
                    slots_prev.reshape(-1, self.slot_dim))

                slots = slots.reshape(B, -1, self.slot_dim)
                slots = self.mlp(slots)

        # Project the slots from the last update step (pre-norm) to the output space.
        slots = self.final_layer(slots_prev)
        attn = attn.squeeze(dim=2)

        return slots, attn


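# Hedged, self-contained sketch of the attention-weighted position statistics computed
# in ISA.get_rel_grid (S_p = weighted mean of the grid, S_s = weighted std around it).
# The 2x2 grid and the attention weights below are toy values chosen for illustration.
def _demo_weighted_position_stats():
    abs_grid = torch.tensor([[-1., -1.], [1., -1.], [-1., 1.], [1., 1.]]).view(1, 1, 4, 2)  # (B, S, N, 2)
    attn = torch.tensor([0.7, 0.1, 0.1, 0.1]).view(1, 1, 1, 4)                              # (B, S, 1, N)
    S_p = torch.einsum('bsjd,bsij->bsd', abs_grid, attn)            # weighted mean position, here (-0.6, -0.6)
    values = torch.pow(abs_grid - S_p.unsqueeze(2), 2)
    S_s = torch.sqrt(torch.einsum('bsjd,bsij->bsd', values, attn))  # weighted std around S_p
    return S_p, S_s

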
class SA(nn.Module):
    """Standard Slot Attention over a set of input tokens."""

    def __init__(self, args, input_dim):
        super().__init__()
        self.num_slots = args['num_slots']
        self.scale = args['slot_dim'] ** -0.5  # scale by slot dimension, as in ISA above
        self.iters = args['slot_att_iter']
        self.slot_dim = args['slot_dim']
        self.query_opt = args['query_opt']

        # Slot initialisation: either learned queries or samples from a learned Gaussian.
        if self.query_opt:
            self.slots = nn.Parameter(torch.Tensor(1, self.num_slots, self.slot_dim))
            init.xavier_uniform_(self.slots)
        else:
            self.slots_mu = nn.Parameter(torch.randn(1, 1, self.slot_dim))
            self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.slot_dim))
            init.xavier_uniform_(self.slots_mu)
            init.xavier_uniform_(self.slots_logsigma)

        # Slot-side projection and update modules.
        self.Q = nn.Linear(self.slot_dim, self.slot_dim, bias=False)
        self.norm = nn.LayerNorm(self.slot_dim)

        self.gru = nn.GRUCell(self.slot_dim, self.slot_dim)
        self.mlp = MLP(self.slot_dim, 4 * self.slot_dim, self.slot_dim,
                       residual=True, layer_order="pre")

        # Input-side projections.
        self.K = nn.Linear(self.slot_dim, self.slot_dim, bias=False)
        self.V = nn.Linear(self.slot_dim, self.slot_dim, bias=False)

        self.initial_mlp = nn.Sequential(nn.LayerNorm(input_dim),
                                         nn.Linear(input_dim, input_dim),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(input_dim, self.slot_dim),
                                         nn.LayerNorm(self.slot_dim))

        self.final_layer = nn.Linear(self.slot_dim, self.slot_dim)

    def forward(self, inputs):
        # inputs: (B, N, D) frozen visual tokens.
        B = inputs.shape[0]
        S = self.num_slots
        D_slot = self.slot_dim
        epsilon = 1e-8

        if self.query_opt:
            slots = self.slots.expand(B, S, D_slot)
        else:
            mu = self.slots_mu.expand(B, S, D_slot)
            sigma = self.slots_logsigma.exp().expand(B, S, D_slot)
            slots = mu + sigma * torch.randn(mu.shape, device=sigma.device, dtype=sigma.dtype)

        slots_init = slots
        inputs = self.initial_mlp(inputs)

        keys = self.K(inputs)
        values = self.V(inputs)

        for t in range(self.iters):
            assert torch.sum(torch.isnan(slots)) == 0, f"Iteration {t}"

            if t == self.iters - 1 and self.query_opt:
                # Straight-through trick: detach the iterated slots but keep the
                # gradient path to the learned initial queries.
                slots = slots.detach() + slots_init - slots_init.detach()

            slots_prev = slots
            slots = self.norm(slots)
            queries = self.Q(slots)

            dots = torch.einsum('bsd,btd->bst', queries, keys)
            dots *= self.scale
            # Softmax over the slot dimension so that the slots compete for tokens.
            attn = dots.softmax(dim=1) + epsilon
            attn = attn / attn.sum(dim=-1, keepdim=True)

            updates = torch.einsum('bst,btd->bsd', attn, values)

            slots = self.gru(
                updates.reshape(-1, self.slot_dim),
                slots_prev.reshape(-1, self.slot_dim))

            slots = slots.reshape(B, -1, self.slot_dim)
            slots = self.mlp(slots)

        slots = self.final_layer(slots)

        return slots


class DINOSAURpp(nn.Module):
    """DINOSAUR-style model: slot attention (SA or ISA) over frozen ViT features,
    followed by a spatial-broadcast decoder that reconstructs those features."""

    def __init__(self, args):
        super().__init__()

        self.slot_dim = args['slot_dim']
        self.slot_num = args['num_slots']
        self.token_num = args['token_num']

        self.ISA = args['ISA']
        if self.ISA:
            self.slot_encoder = ISA(args, input_dim=1024)
        else:
            self.slot_encoder = SA(args, input_dim=1024)

        self.slot_decoder = Decoder(args)

        # Learned positional embedding added to the broadcast slots before decoding.
        self.pos_dec = nn.Parameter(torch.Tensor(1, self.token_num, self.slot_dim))
        init.normal_(self.pos_dec, mean=0., std=.02)

    def sbd_slots(self, slots):
        # Spatial broadcast: tile each slot over all tokens and add positional embeddings.
        B, S, D_slot = slots.shape

        slots = slots.view(-1, 1, D_slot)
        slots = slots.tile(1, self.token_num, 1)

        pos_embed = self.pos_dec.expand(slots.shape)
        slots = slots + pos_embed
        slots = slots.view(B, S, self.token_num, D_slot)
        pos_embed = pos_embed.view(B, S, self.token_num, D_slot)

        return slots, pos_embed

    def reconstruct_feature_map(self, slot_maps):
        # slot_maps: (B, S, N, 1024 + 1) per-slot features and alpha logits.
        B = slot_maps.shape[0]

        channels, masks = torch.split(slot_maps, [1024, 1], dim=-1)
        masks = masks.softmax(dim=1)  # mixture weights over slots, per token

        reconstruction = torch.sum(channels * masks, dim=1)
        masks = masks.squeeze(dim=-1)

        return reconstruction, masks

    def forward(self, features):
        # features: (B, token_num, 1024) frozen visual tokens.
        B, token, _ = features.shape

        if self.ISA:
            slots, attn = self.slot_encoder(features)
            assert torch.sum(torch.isnan(slots)) == 0
            assert torch.sum(torch.isnan(attn)) == 0

            rel_grid = self.slot_encoder.get_rel_grid(attn)

            slot_maps, _ = self.sbd_slots(slots)
            slot_maps = slot_maps + rel_grid
            slot_maps = self.slot_decoder(slot_maps)

        else:
            slots = self.slot_encoder(features)
            assert torch.sum(torch.isnan(slots)) == 0

            slot_maps, pos_maps = self.sbd_slots(slots)
            slot_maps = self.slot_decoder(slot_maps)

        reconstruction, masks = self.reconstruct_feature_map(slot_maps)

        return reconstruction, slots, masks
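

# Hedged smoke test, run only when this file is executed as a script. The config values
# below (7 slots, 256-dim slots, 196 tokens, 3 iterations, standard SA) are illustrative
# assumptions, not the training configuration of this repository.
if __name__ == "__main__":
    demo_args = {
        'slot_dim': 256,
        'num_slots': 7,
        'token_num': 196,
        'slot_att_iter': 3,
        'query_opt': False,
        'ISA': False,
    }
    model = DINOSAURpp(demo_args)
    features = torch.randn(2, 196, 1024)  # (B, token_num, feature_dim)
    reconstruction, slots, masks = model(features)
    print(reconstruction.shape, slots.shape, masks.shape)
    # expected: torch.Size([2, 196, 1024]) torch.Size([2, 7, 256]) torch.Size([2, 7, 196])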