from functools import partial |
|
from typing import List, Tuple |
|
|
|
import logging |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from timm.layers import to_2tuple |
|
from timm.models.vision_transformer import Block |
|
|
|
|
|
def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): |
|
""" |
|
Create 3D sin/cos positional embeddings. |
|
|
|
Args: |
|
embed_dim (int): |
|
Embedding dimension. |
|
grid_size (tuple[int, int, int] | list[int]): |
|
            The grid size along the time, height and width axes.
|
add_cls_token (bool, *optional*, defaults to False): |
|
Whether or not to add a classification (CLS) token. |
|
|
|
Returns: |
|
        (`np.ndarray` of shape `(grid_size[0]*grid_size[1]*grid_size[2], embed_dim)` or
        `(1 + grid_size[0]*grid_size[1]*grid_size[2], embed_dim)`): the position embeddings (with or without CLS token)
|
""" |
|
|
|
assert embed_dim % 16 == 0 |
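    # The embedding dimension is split across the three axes: 6/16 each for width and height
    # and 4/16 for time, so the three 1D sin/cos embeddings concatenate back to embed_dim.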
|
|
|
t_size, h_size, w_size = grid_size |
|
|
|
w_embed_dim = embed_dim // 16 * 6 |
|
h_embed_dim = embed_dim // 16 * 6 |
|
t_embed_dim = embed_dim // 16 * 4 |
|
|
|
w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size)) |
|
h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size)) |
|
t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size)) |
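    # Broadcast each 1D embedding over the full (t, h, w) grid: w varies fastest and t slowest,
    # matching the flattening order produced by PatchEmbed.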
|
|
|
w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1)) |
|
h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1)) |
|
t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0) |
|
|
|
pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1) |
|
|
|
if add_cls_token: |
|
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) |
|
return pos_embed |
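# Illustrative example: get_3d_sincos_pos_embed(16, (2, 3, 3)) returns an array of shape (18, 16);
# with add_cls_token=True the shape is (19, 16), with an all-zero row prepended for the CLS token.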
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): |
|
""" |
|
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded, of shape (M,)
    out: (M, D)
|
""" |
|
if embed_dim % 2 != 0: |
|
raise ValueError("embed_dim must be even") |
|
|
|
omega = np.arange(embed_dim // 2, dtype=float) |
|
omega /= embed_dim / 2.0 |
|
omega = 1.0 / 10000**omega |
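    # Standard transformer frequency schedule: omega_d = 1 / 10000^(2d / embed_dim).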
|
|
|
pos = pos.reshape(-1) |
|
out = np.einsum("m,d->md", pos, omega) |
|
|
|
emb_sin = np.sin(out) |
|
emb_cos = np.cos(out) |
|
|
|
emb = np.concatenate([emb_sin, emb_cos], axis=1) |
|
return emb |
|
|
|
|
|
def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor): |
|
""" This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However, |
|
it was modified to cast omega values to pos.dtype which must be float (and not int as in |
|
regular positional embeddings). This was required in order to allow for native FSDP mixed |
|
precision support: modify omega to appropriate dtype (pos carries the correct float dtype), |
|
instead of manually forcing float32. |
|
|
|
embed_dim: output dimension for each position |
|
pos: a list of positions to be encoded: size (M,) - must be float dtype! |
|
out: (M, D) |
|
""" |
|
assert embed_dim % 2 == 0 |
|
assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16] |
|
|
|
omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device) |
|
omega /= embed_dim / 2.0 |
|
omega = 1.0 / 10000**omega |
|
|
|
pos = pos.reshape(-1) |
|
out = torch.einsum("m,d->md", pos, omega) |
|
|
|
emb_sin = torch.sin(out) |
|
emb_cos = torch.cos(out) |
|
|
|
emb = torch.cat([emb_sin, emb_cos], dim=1) |
|
|
|
return emb |
|
|
|
|
|
def _init_weights(module): |
|
"""Initialize the weights""" |
|
if isinstance(module, nn.Linear): |
|
nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
|
|
|
|
class PatchEmbed(nn.Module): |
|
"""3D version of timm.models.vision_transformer.PatchEmbed""" |
|
def __init__( |
|
self, |
|
input_size: Tuple[int, int, int] = (1, 224, 224), |
|
patch_size: Tuple[int, int, int] = (1, 16, 16), |
|
in_chans: int = 3, |
|
embed_dim: int = 768, |
|
norm_layer: nn.Module | None = None, |
|
flatten: bool = True, |
|
bias: bool = True, |
|
): |
|
super().__init__() |
|
self.input_size = input_size |
|
self.patch_size = patch_size |
|
self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)] |
|
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] |
|
self.flatten = flatten |
|
|
|
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) |
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() |
|
|
|
def forward(self, x): |
|
B, C, T, H, W = x.shape |
|
|
|
        if T % self.patch_size[0] or H % self.patch_size[1] or W % self.patch_size[2]:
            logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}. "
                            f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
|
|
|
x = self.proj(x) |
|
if self.flatten: |
|
x = x.flatten(2).transpose(1, 2) |
|
x = self.norm(x) |
|
return x |
|
|
|
|
|
class TemporalEncoder(nn.Module): |
|
def __init__(self, embed_dim: int, trainable_scale: bool = False): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.year_embed_dim = embed_dim // 2 |
|
self.julian_day_embed_dim = embed_dim - self.year_embed_dim |
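        # The metadata embedding is scaled before being added to the patch tokens: either a
        # learnable scalar initialized to 0.1, or a fixed buffer of 1.0.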
|
|
|
|
|
if trainable_scale: |
|
self.scale = nn.Parameter(torch.full((1,), 0.1)) |
|
else: |
|
self.register_buffer('scale', torch.ones(1)) |
|
|
|
def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None): |
|
""" |
|
temporal_coords: year and day-of-year info with shape (B, T, 2). |
|
tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be |
|
repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim). |
|
""" |
|
shape = temporal_coords.shape[:2] + (-1,) |
|
|
|
year = _get_1d_sincos_embed_from_grid_torch( |
|
self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape) |
|
julian_day = _get_1d_sincos_embed_from_grid_torch( |
|
self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape) |
|
|
|
embedding = self.scale * torch.cat([year, julian_day], dim=-1) |
|
|
|
if tokens_per_frame is not None: |
|
embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1) |
|
|
|
return embedding |
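    # Illustrative example: with embed_dim=768, float coords of shape (2, 3, 2) and
    # tokens_per_frame=196, the returned embedding has shape (2, 3 * 196, 768) = (2, 588, 768).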
|
|
|
|
|
class LocationEncoder(nn.Module): |
|
def __init__(self, embed_dim: int, trainable_scale: bool = False): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.lat_embed_dim = embed_dim // 2 |
|
self.lon_embed_dim = embed_dim - self.lat_embed_dim |
|
|
|
|
|
if trainable_scale: |
|
self.scale = nn.Parameter(torch.full((1,), 0.1)) |
|
else: |
|
self.register_buffer('scale', torch.ones(1)) |
|
|
|
def forward(self, location_coords: torch.Tensor): |
|
""" |
|
location_coords: lat and lon info with shape (B, 2). |
|
""" |
|
shape = location_coords.shape[:1] + (1, -1) |
|
|
|
lat = _get_1d_sincos_embed_from_grid_torch( |
|
self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape) |
|
lon = _get_1d_sincos_embed_from_grid_torch( |
|
self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape) |
|
|
|
embedding = self.scale * torch.cat([lat, lon], dim=-1) |
|
|
|
return embedding |
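    # Illustrative example: with embed_dim=768 and float lat/lon coords of shape (B, 2), the
    # returned embedding has shape (B, 1, 768) and broadcasts over the token dimension when added.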
|
|
|
|
|
class PrithviViT(nn.Module): |
|
""" Prithvi ViT Encoder""" |
|
def __init__(self, |
|
img_size: int | Tuple[int, int] = 224, |
|
patch_size: int | Tuple[int, int, int] = (1, 16, 16), |
|
num_frames: int = 1, |
|
in_chans: int = 3, |
|
embed_dim: int = 1024, |
|
depth: int = 24, |
|
num_heads: int = 16, |
|
mlp_ratio: float = 4., |
|
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6), |
|
coords_encoding: List[str] | None = None, |
|
coords_scale_learn: bool = False, |
|
encoder_only: bool = True, |
|
                 **kwargs,
|
): |
|
super().__init__() |
|
|
|
self.feature_info = [] |
|
self.encoder_only = encoder_only |
|
self.in_chans = in_chans |
|
self.num_frames = num_frames |
|
self.embed_dim = embed_dim |
|
self.img_size = to_2tuple(img_size) |
|
if isinstance(patch_size, int): |
|
patch_size = (1, patch_size, patch_size) |
|
|
|
|
|
self.patch_embed = PatchEmbed( |
|
input_size=(num_frames,) + self.img_size, |
|
patch_size=patch_size, |
|
in_chans=in_chans, |
|
embed_dim=embed_dim, |
|
) |
|
|
|
|
|
coords_encoding = coords_encoding or [] |
|
self.temporal_encoding = 'time' in coords_encoding |
|
self.location_encoding = 'location' in coords_encoding |
|
if self.temporal_encoding: |
|
assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}" |
|
self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn) |
|
if self.location_encoding: |
|
self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn) |
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|
self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) |
|
|
|
|
|
self.blocks = [] |
|
for i in range(depth): |
|
self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)) |
|
self.feature_info.append( |
|
{"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"} |
|
) |
|
self.blocks = nn.ModuleList(self.blocks) |
|
|
|
self.norm = norm_layer(embed_dim) |
|
|
|
self.initialize_weights() |
|
|
|
def initialize_weights(self): |
|
|
|
pos_embed = get_3d_sincos_pos_embed( |
|
self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True |
|
) |
|
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.patch_embed.proj.weight.data |
|
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
torch.nn.init.normal_(self.cls_token, std=0.02) |
|
self.apply(_init_weights) |
|
|
|
def random_masking(self, sequence, mask_ratio, noise=None): |
|
""" |
|
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random |
|
noise. |
|
|
|
Args: |
|
sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`) |
|
mask_ratio (float): mask ratio to use. |
|
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Noise tensor, mainly used for testing purposes to control randomness and ensure reproducibility.
|
""" |
|
batch_size, seq_length, dim = sequence.shape |
|
len_keep = int(seq_length * (1 - mask_ratio)) |
|
|
|
if noise is None: |
|
noise = torch.rand(batch_size, seq_length, device=sequence.device) |
|
|
|
|
|
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) |
|
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device) |
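        # Sorting random noise yields a random permutation; ids_restore is its inverse and is
        # used later by the decoder to scatter mask tokens back to their original positions.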
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)) |
|
|
|
|
|
mask = torch.ones([batch_size, seq_length], device=sequence.device) |
|
mask[:, :len_keep] = 0 |
|
|
|
mask = torch.gather(mask, dim=1, index=ids_restore) |
|
|
|
return sequence_unmasked, mask, ids_restore |
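    # Illustrative example: for a (2, 588, 768) token sequence and mask_ratio=0.75, random_masking
    # keeps int(588 * 0.25) = 147 tokens, returning sequence_unmasked of shape (2, 147, 768),
    # mask of shape (2, 588) with 441 ones per row, and ids_restore of shape (2, 588).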
|
|
|
def _get_pos_embed(self, x): |
|
t, h, w = x.shape[-3:] |
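        # Recompute the sin/cos position embeddings on the fly when the input size differs from
        # the size the model was configured with (e.g. larger tiles at inference time).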
|
|
|
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed( |
|
self.embed_dim, |
|
( |
|
t // self.patch_embed.patch_size[0], |
|
h // self.patch_embed.patch_size[1], |
|
w // self.patch_embed.patch_size[2], |
|
), |
|
add_cls_token=True, |
|
)).float().unsqueeze(0).to(x) |
|
|
|
return pos_embed |
|
|
|
|
|
def forward( |
|
self, x: torch.Tensor, |
|
temporal_coords: None | torch.Tensor = None, |
|
location_coords: None | torch.Tensor = None, |
|
mask_ratio=0.75 |
|
): |
|
if x.shape[-3:] != self.patch_embed.input_size: |
|
|
|
pos_embed = self._get_pos_embed(x) |
|
else: |
|
pos_embed = self.pos_embed |
|
|
|
|
|
x = self.patch_embed(x) |
|
|
|
|
|
x = x + pos_embed[:, 1:, :] |
|
|
|
if self.temporal_encoding: |
|
num_tokens_per_frame = x.shape[1] // self.num_frames |
|
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) |
|
x = x + temporal_encoding |
|
if self.location_encoding: |
|
location_encoding = self.location_embed_enc(location_coords) |
|
x = x + location_encoding |
|
|
|
|
|
x, mask, ids_restore = self.random_masking(x, mask_ratio) |
|
|
|
|
|
cls_token = self.cls_token + pos_embed[:, :1, :] |
|
cls_tokens = cls_token.expand(x.shape[0], -1, -1) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
|
|
|
for block in self.blocks: |
|
x = block(x) |
|
x = self.norm(x) |
|
|
|
return x, mask, ids_restore |
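    # Illustrative example: with img_size=224, patch_size=(1, 16, 16), num_frames=3 and
    # embed_dim=1024, forward() on a (B, C, 3, 224, 224) input and mask_ratio=0.75 returns
    # tokens of shape (B, 1 + 147, 1024), a (B, 588) mask and a (B, 588) ids_restore tensor.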
|
|
|
def forward_features( |
|
self, |
|
x: torch.Tensor, |
|
temporal_coords: None | torch.Tensor = None, |
|
location_coords: None | torch.Tensor = None, |
|
) -> list[torch.Tensor]: |
|
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: |
|
|
|
x = x.unsqueeze(2) |
|
|
|
if x.shape[-3:] != self.patch_embed.input_size: |
|
pos_embed = self._get_pos_embed(x) |
|
else: |
|
pos_embed = self.pos_embed |
|
|
|
|
|
x = self.patch_embed(x) |
|
|
|
|
|
x = x + pos_embed[:, 1:, :] |
|
|
|
if self.temporal_encoding: |
|
            num_tokens_per_frame = x.shape[1] // self.num_frames
|
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) |
|
x = x + temporal_encoding |
|
if self.location_encoding: |
|
location_encoding = self.location_embed_enc(location_coords) |
|
x = x + location_encoding |
|
|
|
|
|
cls_token = self.cls_token + pos_embed[:, :1, :] |
|
cls_tokens = cls_token.expand(x.shape[0], -1, -1) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
|
|
|
out = [] |
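        # Collect the full (unmasked) token sequence after every block; the last entry is
        # layer-normalized below.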
|
for block in self.blocks: |
|
x = block(x) |
|
out.append(x.clone()) |
|
|
|
x = self.norm(x) |
|
out[-1] = x |
|
return out |
|
|
|
def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]: |
|
out = [] |
|
effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0] |
|
for x in features: |
|
x_no_token = x[:, 1:, :] |
|
number_of_tokens = x_no_token.shape[1] |
|
tokens_per_timestep = number_of_tokens // effective_time_dim |
|
h = int(np.sqrt(tokens_per_timestep)) |
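            # Assumes a square spatial token grid (H == W) per timestep.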
|
encoded = rearrange( |
|
x_no_token, |
|
"batch (t h w) e -> batch (t e) h w", |
|
e=self.embed_dim, |
|
t=effective_time_dim, |
|
h=h, |
|
) |
|
out.append(encoded) |
|
return out |
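    # Illustrative example: for 224x224 inputs with patch_size=(1, 16, 16), num_frames=3 and
    # embed_dim=1024, each (B, 1 + 588, 1024) feature tensor is reshaped to (B, 3 * 1024, 14, 14).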
|
|
|
|
|
class MAEDecoder(nn.Module): |
|
""" Transformer Decoder used in the Prithvi MAE""" |
|
def __init__(self, |
|
patch_size: int | Tuple[int, int, int] = (1, 16, 16), |
|
grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14), |
|
in_chans: int = 3, |
|
encoder_embed_dim: int = 1024, |
|
decoder_embed_dim: int = 512, |
|
depth: int = 8, |
|
num_heads: int = 16, |
|
mlp_ratio: float = 4., |
|
norm_layer: nn.Module = nn.LayerNorm, |
|
coords_encoding: List[str] | None = None, |
|
coords_scale_learn: bool = False, |
|
): |
|
super().__init__() |
|
|
|
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) |
|
self.decoder_embed_dim = decoder_embed_dim |
|
self.grid_size = grid_size |
|
if isinstance(patch_size, int): |
|
patch_size = (1, patch_size, patch_size) |
|
self.patch_size = patch_size |
|
self.num_frames = self.grid_size[0] * patch_size[0] |
|
num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] |
|
|
|
|
|
coords_encoding = coords_encoding or [] |
|
self.temporal_encoding = 'time' in coords_encoding |
|
self.location_encoding = 'location' in coords_encoding |
|
if self.temporal_encoding: |
|
self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn) |
|
if self.location_encoding: |
|
self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn) |
|
|
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) |
|
|
|
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)) |
|
|
|
self.decoder_blocks = nn.ModuleList( |
|
[Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)] |
|
) |
|
|
|
self.decoder_norm = norm_layer(decoder_embed_dim) |
|
self.decoder_pred = nn.Linear(decoder_embed_dim, |
|
patch_size[0] * patch_size[1] * patch_size[2] * in_chans, |
|
bias=True) |
|
|
|
self.initialize_weights() |
|
|
|
def initialize_weights(self): |
|
|
|
decoder_pos_embed = get_3d_sincos_pos_embed( |
|
self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True |
|
) |
|
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
torch.nn.init.normal_(self.mask_token, std=0.02) |
|
self.apply(_init_weights) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
ids_restore: torch.Tensor, |
|
temporal_coords: None | torch.Tensor = None, |
|
location_coords: None | torch.Tensor = None, |
|
        input_size: list[int] | None = None,
|
): |
|
|
|
x = self.decoder_embed(hidden_states) |
|
|
|
t, h, w = input_size[-3:] |
|
decoder_pos_embed = torch.from_numpy( |
|
get_3d_sincos_pos_embed( |
|
self.decoder_embed_dim, |
|
( |
|
t // self.patch_size[0], |
|
h // self.patch_size[1], |
|
w // self.patch_size[2], |
|
), |
|
add_cls_token=True, |
|
) |
|
).to(x) |
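        # Append one mask token per dropped patch, unshuffle with ids_restore so every token sits
        # at its original grid position, then re-attach the CLS token.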
|
|
|
|
|
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) |
|
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) |
|
|
|
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)) |
|
x = torch.cat([x[:, :1, :], x_], dim=1) |
|
|
|
x = x + decoder_pos_embed |
|
|
|
|
|
x_ = x[:, 1:, :] |
|
|
|
if self.temporal_encoding: |
|
num_tokens_per_frame = x_.shape[1] // self.num_frames |
|
temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame) |
|
|
|
x_ = x_ + temporal_encoding |
|
if self.location_encoding: |
|
location_encoding = self.location_embed_dec(location_coords) |
|
|
|
x_ = x_ + location_encoding |
|
|
|
|
|
x = torch.cat([x[:, :1, :], x_], dim=1) |
|
|
|
|
|
for block in self.decoder_blocks: |
|
x = block(x) |
|
x = self.decoder_norm(x) |
|
|
|
|
|
pred = self.decoder_pred(x) |
|
|
|
|
|
pred = pred[:, 1:, :] |
|
|
|
return pred |
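    # Illustrative example: with patch_size=(1, 16, 16) and in_chans=6, decoder_pred outputs
    # 1 * 16 * 16 * 6 = 1536 values per patch, so pred has shape (B, num_patches, 1536) after
    # the CLS token is dropped.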
|
|
|
|
|
class PrithviMAE(nn.Module): |
|
""" Prithvi Masked Autoencoder""" |
|
|
|
def __init__(self, |
|
img_size: int | Tuple[int, int] = 224, |
|
patch_size: int | Tuple[int, int, int] = (1, 16, 16), |
|
num_frames: int = 3, |
|
in_chans: int = 3, |
|
embed_dim: int = 1024, |
|
depth: int = 24, |
|
num_heads: int = 16, |
|
decoder_embed_dim: int = 512, |
|
decoder_depth: int = 8, |
|
decoder_num_heads: int = 16, |
|
mlp_ratio: float = 4., |
|
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6), |
|
norm_pix_loss: bool = False, |
|
coords_encoding: List[str] | None = None, |
|
coords_scale_learn: bool = False, |
|
encoder_only: bool = False, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
|
|
self.encoder = PrithviViT( |
|
img_size=img_size, |
|
num_frames=num_frames, |
|
patch_size=patch_size, |
|
in_chans=in_chans, |
|
embed_dim=embed_dim, |
|
depth=depth, |
|
num_heads=num_heads, |
|
mlp_ratio=mlp_ratio, |
|
norm_layer=norm_layer, |
|
coords_encoding=coords_encoding, |
|
coords_scale_learn=coords_scale_learn, |
|
) |
|
|
|
self.encoder_only = encoder_only |
|
|
|
if not encoder_only: |
|
self.decoder = MAEDecoder( |
|
patch_size=patch_size, |
|
grid_size=self.encoder.patch_embed.grid_size, |
|
in_chans=in_chans, |
|
encoder_embed_dim=embed_dim, |
|
decoder_embed_dim=decoder_embed_dim, |
|
depth=decoder_depth, |
|
num_heads=decoder_num_heads, |
|
mlp_ratio=mlp_ratio, |
|
norm_layer=norm_layer, |
|
coords_encoding=coords_encoding, |
|
coords_scale_learn=coords_scale_learn, |
|
) |
|
else: |
|
self.decoder = nn.Identity() |
|
|
|
self.norm_pix_loss = norm_pix_loss |
|
|
|
def patchify(self, pixel_values): |
|
""" |
|
Args: |
|
pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`): |
|
Pixel values. |
|
|
|
Returns: |
|
torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: |
|
Patchified pixel values. |
|
""" |
|
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size |
|
num_channels = self.encoder.in_chans |
|
|
|
|
|
patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', |
|
c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w) |
|
|
|
|
|
return patchified_pixel_values |
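    # Illustrative example: patchify on a (B, 6, 3, 224, 224) input with patch_size=(1, 16, 16)
    # yields (B, 3 * 14 * 14, 1 * 16 * 16 * 6) = (B, 588, 1536).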
|
|
|
def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None): |
|
""" |
|
Args: |
|
patchified_pixel_values (`torch.FloatTensor` of shape |
|
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
|
Patchified pixel values. |
|
image_size (`Tuple[int, int]`, *optional*): |
|
Original image size. |
|
|
|
Returns: |
|
            `torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`:
|
Pixel values. |
|
""" |
|
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size |
|
image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size |
|
original_height, original_width = image_size |
|
num_patches_h = original_height // patch_size_h |
|
num_patches_w = original_width // patch_size_w |
|
num_channels = self.encoder.in_chans |
|
|
|
pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', |
|
c=num_channels, h=num_patches_h, w=num_patches_w, |
|
s=patch_size_t, p=patch_size_h, q=patch_size_w) |
|
return pixel_values |
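    # Note: unpatchify is the inverse of patchify when the spatial size matches image_size.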
|
|
|
def forward_loss(self, pixel_values, pred, mask): |
|
""" |
|
Args: |
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`): |
|
Pixel values. |
|
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
|
Predicted pixel values. |
|
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): |
|
Tensor indicating which patches are masked (1) and which are not (0). |
|
|
|
Returns: |
|
`torch.FloatTensor`: Pixel reconstruction loss. |
|
""" |
|
target = self.patchify(pixel_values) |
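        # With norm_pix_loss, targets are normalized per patch (MAE-style), so the decoder
        # predicts locally standardized pixel values.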
|
if self.norm_pix_loss: |
|
mean = target.mean(dim=-1, keepdim=True) |
|
var = target.var(dim=-1, keepdim=True) |
|
target = (target - mean) / (var + 1.0e-6) ** 0.5 |
|
|
|
loss = (pred - target) ** 2 |
|
loss = loss.mean(dim=-1) |
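        # Average the per-patch MSE over masked patches only (mask == 1 marks removed patches).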
|
loss = (loss * mask).sum() / mask.sum() |
|
return loss |
|
|
|
def forward( |
|
self, |
|
pixel_values: torch.Tensor, |
|
temporal_coords: None | torch.Tensor = None, |
|
location_coords: None | torch.Tensor = None, |
|
mask_ratio: float = 0.75 |
|
): |
|
if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1: |
|
|
|
pixel_values = pixel_values.unsqueeze(2) |
|
|
|
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio) |
|
pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape) |
|
loss = self.forward_loss(pixel_values, pred, mask) |
|
return loss, pred, mask |
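    # Illustrative usage sketch (assumed example values, e.g. 6 input channels):
    #   model = PrithviMAE(img_size=224, patch_size=(1, 16, 16), num_frames=3, in_chans=6)
    #   pixels = torch.randn(2, 6, 3, 224, 224)
    #   loss, pred, mask = model(pixels, mask_ratio=0.75)
    #   # loss is a scalar, pred has shape (2, 588, 1536), mask has shape (2, 588)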
|
|
|
def forward_features( |
|
self, |
|
x: torch.Tensor, |
|
temporal_coords: None | torch.Tensor = None, |
|
location_coords: None | torch.Tensor = None, |
|
) -> List[torch.Tensor]: |
|
return self.encoder.forward_features(x, temporal_coords, location_coords) |
|
|