diff --git a/README.md b/README.md index d4bbdadaed6703a1fe608206ef44f262430588d1..ba6a7194ca3c27d21df9bf2abe7a8d2a6c04612d 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,12 @@ --- title: Flash3d -emoji: 📊 +emoji: 🌍 colorFrom: green -colorTo: purple +colorTo: indigo sdk: gradio -sdk_version: 5.0.2 +sdk_version: 4.36.0 app_file: app.py pinned: false -license: apache-2.0 -short_description: ' If you run the demo online, the first example you upload sh' --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0427192b65657b414d9e63a8719752c8b4b51de5 --- /dev/null +++ b/app.py @@ -0,0 +1,134 @@ +import sys +import spaces +sys.path.append("flash3d") + +from omegaconf import OmegaConf +import gradio as gr +import torch +import torchvision.transforms as TT +import torchvision.transforms.functional as TTF +from huggingface_hub import hf_hub_download + +from networks.gaussian_predictor import GaussianPredictor +from util.vis3d import save_ply + +def main(): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + + model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", + filename="config_re10k_v1.yaml") + model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", + filename="model_re10k_v1.pth") + + cfg = OmegaConf.load(model_cfg_path) + model = GaussianPredictor(cfg) + device = torch.device(device) + model.to(device) + model.load_model(model_path) + + pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug)) + to_tensor = TT.ToTensor() + + def check_input_image(input_image): + if input_image is None: + raise gr.Error("No image uploaded!") + + def preprocess(image): + image = TTF.resize( + image, (cfg.dataset.height, cfg.dataset.width), + interpolation=TT.InterpolationMode.BICUBIC + ) + image = pad_border_fn(image) + return image + + @spaces.GPU(duration=120) + def reconstruct_and_export(image): + """ + Passes image through model, outputs reconstruction in form of a dict of tensors. + """ + image = to_tensor(image).to(device).unsqueeze(0) + inputs = { + ("color_aug", 0, 0): image, + } + + outputs = model(inputs) + + # export reconstruction to ply + save_ply(outputs, ply_out_path, num_gauss=2) + + return ply_out_path + + ply_out_path = f'./mesh.ply' + + css = """ + h1 { + text-align: center; + display:block; + } + """ + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Flash3D + """ + ) + with gr.Row(variant="panel"): + with gr.Column(scale=1): + with gr.Row(): + input_image = gr.Image( + label="Input Image", + image_mode="RGBA", + sources="upload", + type="pil", + elem_id="content_image", + ) + with gr.Row(): + submit = gr.Button("Generate", elem_id="generate", variant="primary") + + with gr.Row(variant="panel"): + gr.Examples( + examples=[ + './demo_examples/bedroom_01.png', + './demo_examples/kitti_02.png', + './demo_examples/kitti_03.png', + './demo_examples/re10k_04.jpg', + './demo_examples/re10k_05.jpg', + './demo_examples/re10k_06.jpg', + ], + inputs=[input_image], + cache_examples=False, + label="Examples", + examples_per_page=20, + ) + + with gr.Row(): + processed_image = gr.Image(label="Processed Image", interactive=False) + + with gr.Column(scale=2): + with gr.Row(): + with gr.Tab("Reconstruction"): + output_model = gr.Model3D( + height=512, + label="Output Model", + interactive=False + ) + + submit.click(fn=check_input_image, inputs=[input_image]).success( + fn=preprocess, + inputs=[input_image], + outputs=[processed_image], + ).success( + fn=reconstruct_and_export, + inputs=[processed_image], + outputs=[output_model], + ) + + demo.queue(max_size=1) + demo.launch(share=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demo_examples/demo_examples_bedroom_01.png b/demo_examples/demo_examples_bedroom_01.png new file mode 100644 index 0000000000000000000000000000000000000000..5e1e7f4940a28cde585f8be4e337e50e71e3a0ac Binary files /dev/null and b/demo_examples/demo_examples_bedroom_01.png differ diff --git a/demo_examples/demo_examples_kitti_02.png b/demo_examples/demo_examples_kitti_02.png new file mode 100644 index 0000000000000000000000000000000000000000..e4bcf249f1280a2df1553c4ae6bb8d123a0500c9 Binary files /dev/null and b/demo_examples/demo_examples_kitti_02.png differ diff --git a/demo_examples/demo_examples_kitti_03.png b/demo_examples/demo_examples_kitti_03.png new file mode 100644 index 0000000000000000000000000000000000000000..4e25037fff2e08c5f28e7dff1df50e72b0ede003 Binary files /dev/null and b/demo_examples/demo_examples_kitti_03.png differ diff --git a/demo_examples/demo_examples_re10k_04.jpg b/demo_examples/demo_examples_re10k_04.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3e87f75b23d49f1c1ad15c3fd5bf7dde20a33a26 Binary files /dev/null and b/demo_examples/demo_examples_re10k_04.jpg differ diff --git a/demo_examples/demo_examples_re10k_05 (1).jpg b/demo_examples/demo_examples_re10k_05 (1).jpg new file mode 100644 index 0000000000000000000000000000000000000000..e76a0c1d5febff7836fb7806ef93e052c11be3f9 Binary files /dev/null and b/demo_examples/demo_examples_re10k_05 (1).jpg differ diff --git a/demo_examples/demo_examples_re10k_05.jpg b/demo_examples/demo_examples_re10k_05.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e76a0c1d5febff7836fb7806ef93e052c11be3f9 Binary files /dev/null and b/demo_examples/demo_examples_re10k_05.jpg differ diff --git a/demo_examples/demo_examples_re10k_06.jpg b/demo_examples/demo_examples_re10k_06.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f8532c7f5b6b7b1e50c6da24e7abdbe093ee7e1 Binary files /dev/null and b/demo_examples/demo_examples_re10k_06.jpg differ diff --git a/flash3d/networks/depth_decoder.py b/flash3d/networks/depth_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..059d8a8714f61bb1773aee60f786788d605f4eec --- /dev/null +++ b/flash3d/networks/depth_decoder.py @@ -0,0 +1,81 @@ +# Copyright Niantic 2019. Patent Pending. All rights reserved. +# +# This software is licensed under the terms of the Monodepth2 licence +# which allows for non-commercial use only, the full terms of which are made +# available in the LICENSE file. + +import numpy as np +import torch +import torch.nn as nn + +from collections import OrderedDict +from networks.layers import upsample, ConvBlock, Conv3x3 + +from einops import rearrange + + +class DepthDecoder(nn.Module): + def __init__(self, cfg, num_ch_enc, num_output_channels=1, use_skips=True): + super(DepthDecoder, self).__init__() + + self.cfg = cfg + depth_num = cfg.model.gaussians_per_pixel - 1 if "unidepth" in cfg.model.name else cfg.model.gaussians_per_pixel + self.num_output_channels = num_output_channels * depth_num + self.use_skips = use_skips + self.upsample_mode = 'nearest' + self.scales = cfg.model.scales + + self.num_ch_enc = num_ch_enc + self.num_ch_dec = np.array([16, 32, 64, 128, 256]) + + # decoder + self.convs = OrderedDict() + for i in range(4, -1, -1): + # upconv_0 + num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) + + # upconv_1 + num_ch_in = self.num_ch_dec[i] + if self.use_skips and i > 0: + num_ch_in += self.num_ch_enc[i - 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) + + for s in self.scales: + out = Conv3x3(self.num_ch_dec[s], self.num_output_channels) + self.convs[("dispconv", s)] = out + nn.init.xavier_uniform_(out.conv.weight, cfg.model.depth_scale) + nn.init.constant_(out.conv.bias, cfg.model.depth_bias) + + self.decoder = nn.ModuleList(list(self.convs.values())) + if cfg.model.depth_type in ["disp", "disp_inc"]: + self.activate = nn.Sigmoid() + elif cfg.model.depth_type == "depth": + self.activate = nn.Softplus() + elif cfg.model.depth_type == "depth_inc": + self.activate = torch.exp + + def forward(self, input_features): + outputs = {} + x = input_features[-1] + for i in range(4, -1, -1): + x = self.convs[("upconv", i, 0)](x) + x = [upsample(x)] + if self.use_skips and i > 0: + x += [input_features[i - 1]] + x = torch.cat(x, 1) + x = self.convs[("upconv", i, 1)](x) + if i in self.scales: + depth_num = self.cfg.model.gaussians_per_pixel - 1 if "unidepth" in self.cfg.model.name else self.cfg.model.gaussians_per_pixel + if self.cfg.model.depth_type == "depth_inc": + outputs[("depth", i)] = rearrange(self.activate(torch.clamp(self.convs[("dispconv", i)](x), min=-10.0, max=6.0)), + 'b (n c) ...-> (b n) c ...', n = depth_num) + elif self.cfg.model.depth_type in ["disp", "disp_inc"]: + outputs[("disp", i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)), + 'b (n c) ...-> (b n) c ...', n = depth_num) + else: + outputs[(self.cfg.model.depth_type, i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)), + 'b (n c) ...-> (b n) c ...', n = depth_num) + return outputs diff --git a/flash3d/networks/gaussian_decoder.py b/flash3d/networks/gaussian_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..38be79f6175a11be06463822cc24a1862212aa5f --- /dev/null +++ b/flash3d/networks/gaussian_decoder.py @@ -0,0 +1,196 @@ +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +def upsample(x): + """Upsample input tensor by a factor of 2 + """ + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class Conv3x3(nn.Module): + """Layer to pad and convolve input + """ + def __init__(self, in_channels, out_channels, use_refl=True): + super(Conv3x3, self).__init__() + + if use_refl: + self.pad = nn.ReflectionPad2d(1) + else: + self.pad = nn.ZeroPad2d(1) + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) + + def forward(self, x): + out = self.pad(x) + out = self.conv(out) + return out + + +class ConvBlock(nn.Module): + """Layer to perform a convolution followed by ELU + """ + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + + self.conv = Conv3x3(in_channels, out_channels) + self.nonlin = nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + return out + + +def get_splits_and_inits(cfg): + split_dimensions = [] + scale_inits = [] + bias_inits = [] + + for g_idx in range(cfg.model.gaussians_per_pixel): + if cfg.model.predict_offset: + split_dimensions += [3] + scale_inits += [cfg.model.xyz_scale] + bias_inits += [cfg.model.xyz_bias] + + split_dimensions += [1, 3, 4, 3] + scale_inits += [cfg.model.opacity_scale, + cfg.model.scale_scale, + 1.0, + 5.0] + bias_inits += [cfg.model.opacity_bias, + np.log(cfg.model.scale_bias), + 0.0, + 0.0] + + if cfg.model.max_sh_degree != 0: + sh_num = (cfg.model.max_sh_degree + 1) ** 2 - 1 + sh_num_rgb = sh_num * 3 + split_dimensions.append(sh_num_rgb) + scale_inits.append(cfg.model.sh_scale) + bias_inits.append(0.0) + if not cfg.model.one_gauss_decoder: + break + + return split_dimensions, scale_inits, bias_inits, + + +class GaussianDecoder(nn.Module): + def __init__(self, cfg, num_ch_enc, use_skips=True): + super(GaussianDecoder, self).__init__() + + self.cfg = cfg + self.use_skips = use_skips + self.upsample_mode = 'nearest' + + self.num_ch_enc = num_ch_enc + self.num_ch_dec = np.array(cfg.model.num_ch_dec) + + split_dimensions, scale, bias = get_splits_and_inits(cfg) + + # [offset], opacity, scaling, rotation, feat_dc + assert not cfg.model.unified_decoder + + self.split_dimensions = split_dimensions + + self.num_output_channels = sum(self.split_dimensions) + + # decoder + self.convs = OrderedDict() + for i in range(4, -1, -1): + # upconv_0 + num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) + + # upconv_1 + num_ch_in = self.num_ch_dec[i] + if self.use_skips and i > 0: + num_ch_in += self.num_ch_enc[i - 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) + + self.out = nn.Conv2d(self.num_ch_dec[0], self.num_output_channels, 1) + + out_channels = self.split_dimensions + start_channels = 0 + for out_channel, b, s in zip(out_channels, bias, scale): + nn.init.xavier_uniform_( + self.out.weight[start_channels:start_channels+out_channel, + :, :, :], s) + nn.init.constant_( + self.out.bias[start_channels:start_channels+out_channel], b) + start_channels += out_channel + + self.decoder = nn.ModuleList(list(self.convs.values())) + + self.scaling_activation = torch.exp + self.opacity_activation = torch.sigmoid + self.rotation_activation = torch.nn.functional.normalize + self.scaling_lambda = cfg.model.scale_lambda + self.sigmoid = nn.Sigmoid() + + def forward(self, input_features): + self.outputs = {} + + # decoder + x = input_features[-1] + for i in range(4, -1, -1): + x = self.convs[("upconv", i, 0)](x) + x = [upsample(x)] + if self.use_skips and i > 0: + x += [input_features[i - 1]] + x = torch.cat(x, 1) + x = self.convs[("upconv", i, 1)](x) + + x = self.out(x) + + split_network_outputs = x.split(self.split_dimensions, dim=1) + + offset_list = [] + opacity_list = [] + scaling_list = [] + rotation_list = [] + feat_dc_list = [] + feat_rest_list = [] + + assert not self.cfg.model.unified_decoder + + for i in range(self.cfg.model.gaussians_per_pixel): + assert self.cfg.model.max_sh_degree > 0 + if self.cfg.model.predict_offset: + offset_s, opacity_s, scaling_s, \ + rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*6:(i+1)*6] + offset_list.append(offset_s[:, None, ...]) + else: + opacity_s, scaling_s, rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*5:(i+1)*5] + opacity_list.append(opacity_s[:, None, ...]) + scaling_list.append(scaling_s[:, None, ...]) + rotation_list.append(rotation_s[:, None, ...]) + feat_dc_list.append(feat_dc_s[:, None, ...]) + feat_rest_list.append(features_rest_s[:, None, ...]) + if not self.cfg.model.one_gauss_decoder: + break + + # squeezing will remove dimension if there is only one gaussian per pixel + opacity = torch.cat(opacity_list, dim=1).squeeze(1) + scaling = torch.cat(scaling_list, dim=1).squeeze(1) + rotation = torch.cat(rotation_list, dim=1).squeeze(1) + feat_dc = torch.cat(feat_dc_list, dim=1).squeeze(1) + features_rest = torch.cat(feat_rest_list, dim=1).squeeze(1) + + out = { + ("gauss_opacity", 0): self.opacity_activation(opacity), + ("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda, + ("gauss_rotation", 0): self.rotation_activation(rotation), + ("gauss_features_dc", 0): feat_dc, + ("gauss_features_rest", 0): features_rest + } + + if self.cfg.model.predict_offset: + offset = torch.cat(offset_list, dim=1).squeeze(1) + out[("gauss_offset", 0)] = offset + return out + diff --git a/flash3d/networks/gaussian_predictor.py b/flash3d/networks/gaussian_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..c0291462a0b793e22c69d8c73192ae3cff7b7ef3 --- /dev/null +++ b/flash3d/networks/gaussian_predictor.py @@ -0,0 +1,293 @@ +from pathlib import Path +import logging + +import torch +import torch.nn as nn +from einops import rearrange + +from networks.layers import BackprojectDepth, disp_to_depth +from networks.resnet_encoder import ResnetEncoder +from networks.depth_decoder import DepthDecoder +from networks.gaussian_decoder import GaussianDecoder + + +def default_param_group(model): + return [{'params': model.parameters()}] + + +def to_device(inputs, device): + for key, ipt in inputs.items(): + if isinstance(ipt, torch.Tensor): + inputs[key] = ipt.to(device) + return inputs + + +class GaussianPredictor(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + # checking height and width are multiples of 32 + # assert cfg.dataset.width % 32 == 0, "'width' must be a multiple of 32" + + models = {} + self.parameters_to_train = [] + + self.num_scales = len(cfg.model.scales) + + assert cfg.model.frame_ids[0] == 0, "frame_ids must start with 0" + + if cfg.model.use_stereo: + cfg.model.frame_ids.append("s") + + model_name = cfg.model.name + if model_name == "resnet": + models["encoder"] = ResnetEncoder( + cfg.model.num_layers, + cfg.model.weights_init == "pretrained", + cfg.model.resnet_bn_order + ) + self.parameters_to_train += default_param_group(models["encoder"]) + if not cfg.model.unified_decoder: + models["depth"] = DepthDecoder( + cfg, models["encoder"].num_ch_enc) + self.parameters_to_train += default_param_group(models["depth"]) + if cfg.model.gaussian_rendering: + for i in range(cfg.model.gaussians_per_pixel): + gauss_decoder = GaussianDecoder( + cfg, models["encoder"].num_ch_enc, + ) + self.parameters_to_train += default_param_group(gauss_decoder) + models["gauss_decoder_"+str(i)] = gauss_decoder + elif model_name == "unidepth": + from networks.unidepth import UniDepthSplatter + models["unidepth"] = UniDepthSplatter(cfg) + self.parameters_to_train += models["unidepth"].get_parameter_groups() + elif model_name in ["unidepth_unprojector_vit", "unidepth_unprojector_cnvnxtl"]: + from networks.unidepth import UniDepthUnprojector + models["unidepth"] = UniDepthUnprojector(cfg) + self.parameters_to_train += models["unidepth"].get_parameter_groups() + elif model_name in ["unidepth_extension_vit", "unidepth_extension_cnvnxtl"]: + from networks.unidepth_extension import UniDepthExtended + models["unidepth_extended"] = UniDepthExtended(cfg) + self.parameters_to_train += models["unidepth_extended"].get_parameter_groups() + + self.models = nn.ModuleDict(models) + + backproject_depth = {} + H = cfg.dataset.height + W = cfg.dataset.width + for scale in cfg.model.scales: + h = H // (2 ** scale) + w = W // (2 ** scale) + if cfg.model.shift_rays_half_pixel == "zero": + shift_rays_half_pixel = 0 + elif cfg.model.shift_rays_half_pixel == "forward": + shift_rays_half_pixel = 0.5 + elif cfg.model.shift_rays_half_pixel == "backward": + shift_rays_half_pixel = -0.5 + else: + raise NotImplementedError + backproject_depth[str(scale)] = BackprojectDepth( + cfg.optimiser.batch_size * cfg.model.gaussians_per_pixel, + # backprojection can be different if padding was used + h + 2 * self.cfg.dataset.pad_border_aug, + w + 2 * self.cfg.dataset.pad_border_aug, + shift_rays_half_pixel=shift_rays_half_pixel + ) + self.backproject_depth = nn.ModuleDict(backproject_depth) + + def set_train(self): + """Convert all models to training mode + """ + for m in self.models.values(): + m.train() + self._is_train = True + + def set_eval(self): + """Convert all models to testing/evaluation mode + """ + for m in self.models.values(): + m.eval() + self._is_train = False + + def is_train(self): + return self._is_train + + def forward(self, inputs): + cfg = self.cfg + B = cfg.optimiser.batch_size + + if cfg.model.name == "resnet": + do_flip = self.is_train() and \ + cfg.train.lazy_flip_augmentation and \ + (torch.rand(1) > .5).item() + # Otherwise, we only feed the image with frame_id 0 through the depth encoder + input_img = inputs["color_aug", 0, 0] + if do_flip: + input_img = torch.flip(input_img, dims=(-1, )) + features = self.models["encoder"](input_img) + if not cfg.model.unified_decoder: + outputs = self.models["depth"](features) + else: + outputs = dict() + + if self.cfg.model.gaussian_rendering: + # gauss_feats = self.models["gauss_encoder"](inputs["color_aug", 0, 0]) + input_f_id = 0 + gauss_feats = features + gauss_outs = dict() + for i in range(self.cfg.model.gaussians_per_pixel): + outs = self.models["gauss_decoder_"+str(i)](gauss_feats) + for key, v in outs.items(): + gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1) + for key, v in gauss_outs.items(): + gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...') + outputs |= gauss_outs + outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()} + else: + for scale in cfg.model.scales: + outputs[("disp", 0, scale)] = outputs[("disp", scale)] + + # unflip all outputs + if do_flip: + for k, v in outputs.items(): + outputs[k] = torch.flip(v, dims=(-1, )) + elif "unidepth" in cfg.model.name: + if cfg.model.name in ["unidepth", + "unidepth_unprojector_vit", + "unidepth_unprojector_cnvnxtl"]: + outputs = self.models["unidepth"](inputs) + elif cfg.model.name in ["unidepth_extension_vit", + "unidepth_extension_cnvnxtl"]: + outputs = self.models["unidepth_extended"](inputs) + + input_f_id = 0 + outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()} + + input_f_id = 0 + scale = 0 + if not ("depth", input_f_id, scale) in outputs: + disp = outputs[("disp", input_f_id, scale)] + _, depth = disp_to_depth(disp, cfg.model.min_depth, cfg.model.max_depth) + outputs[("depth", input_f_id, scale)] = depth + + self.compute_gauss_means(inputs, outputs) + + return outputs + + def target_tensor_image_dims(self, inputs): + B, _, H, W = inputs["color", 0, 0].shape + return B, H, W + + def compute_gauss_means(self, inputs, outputs): + cfg = self.cfg + input_f_id = 0 + scale = 0 + depth = outputs[("depth", input_f_id, scale)] + B, _, H, W = depth.shape + if ("inv_K_src", scale) in inputs: + inv_K = inputs[("inv_K_src", scale)] + else: + inv_K = outputs[("inv_K_src", input_f_id, scale)] + if self.cfg.model.gaussians_per_pixel > 1: + inv_K = rearrange(inv_K[:,None,...]. + repeat(1, self.cfg.model.gaussians_per_pixel, 1, 1), + 'b n ... -> (b n) ...') + xyz = self.backproject_depth[str(scale)]( + depth, inv_K + ) + inputs[("inv_K_src", scale)] = inv_K + if cfg.model.predict_offset: + offset = outputs[("gauss_offset", input_f_id, scale)] + if cfg.model.scaled_offset: + offset = offset * depth.detach() + offset = offset.view(B, 3, -1) + zeros = torch.zeros(B, 1, H * W, device=depth.device) + offset = torch.cat([offset, zeros], 1) + xyz = xyz + offset # [B, 4, W*H] + outputs[("gauss_means", input_f_id, scale)] = xyz + + def checkpoint_dir(self): + return Path("checkpoints") + + def save_model(self, optimizer, step, ema=None): + """Save model weights to disk + """ + save_folder = self.checkpoint_dir() + save_folder.mkdir(exist_ok=True, parents=True) + + save_path = save_folder / f"model_{step:07}.pth" + logging.info(f"saving checkpoint to {str(save_path)}") + + model = ema.ema_model if ema is not None else self + save_dict = { + "model": model.state_dict(), + "version": "1.0", + "optimiser": optimizer.state_dict(), + "step": step + } + torch.save(save_dict, save_path) + + num_ckpts = self.cfg.optimiser.num_keep_ckpts + ckpts = sorted(list(save_folder.glob("model_*.pth")), reverse=True) + if len(ckpts) > num_ckpts: + for ckpt in ckpts[num_ckpts:]: + ckpt.unlink() + + def load_model(self, weights_path, optimizer=None): + """Load model(s) from disk + """ + weights_path = Path(weights_path) + + # determine if it is an old or new saving format + if weights_path.is_dir() and weights_path.joinpath("encoder.pth").exists(): + self.load_model_old(weights_path, optimizer) + return + + logging.info(f"Loading weights from {weights_path}...") + state_dict = torch.load(weights_path) + if "version" in state_dict and state_dict["version"] == "1.0": + new_dict = {} + for k, v in state_dict["model"].items(): + if "backproject_depth" in k: + new_dict[k] = self.state_dict()[k].clone() + else: + new_dict[k] = v.clone() + # for k, v in state_dict["model"].items(): + # if "backproject_depth" in k and ("pix_coords" in k or "ones" in k): + # # model has these parameters set as a function of batch size + # # when batch size changes in eval this results in a loading error + # state_dict["model"][k] = v[:1, ...] + self.load_state_dict(new_dict, strict=False) + else: + # TODO remove loading according to the old format + for name in self.cfg.train.models_to_load: + if name not in self.models: + continue + self.models[name].load_state_dict(state_dict[name]) + + # loading adam state + if optimizer is not None: + optimizer.load_state_dict(state_dict["optimiser"]) + self.step = state_dict["step"] + + def load_model_old(self, weights_folder, optimizer=None): + for n in self.cfg.train.models_to_load: + print(f"Loading {n} weights...") + path = weights_folder / f"{n}.pth" + if n not in self.models: + continue + model_dict = self.models[n].state_dict() + pretrained_dict = torch.load(path) + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + model_dict.update(pretrained_dict) + self.models[n].load_state_dict(model_dict) + + # loading adam state + optimizer_load_path = weights_folder / "adam.pth" + if optimizer is not None and optimizer_load_path.is_file(): + print("Loading Adam weights") + optimizer_state = torch.load(optimizer_load_path) + optimizer.load_state_dict(optimizer_state["adam"]) + self.step = optimizer_state["step"] diff --git a/flash3d/networks/layers.py b/flash3d/networks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9b89b6b4e9f58e30fcde6ad9f81bdf7bf07caa --- /dev/null +++ b/flash3d/networks/layers.py @@ -0,0 +1,295 @@ +# Copyright Niantic 2019. Patent Pending. All rights reserved. +# +# This software is licensed under the terms of the Monodepth2 licence +# which allows for non-commercial use only, the full terms of which are made +# available in the LICENSE file. + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def disp_to_depth(disp, min_depth, max_depth): + """Convert network's sigmoid output into depth prediction + The formula for this conversion is given in the 'additional considerations' + section of the paper. + """ + min_disp = 1 / max_depth + max_disp = 1 / min_depth + scaled_disp = min_disp + (max_disp - min_disp) * disp + depth = 1 / scaled_disp + return scaled_disp, depth + + +def transformation_from_parameters(axisangle, translation, invert=False): + """Convert the network's (axisangle, translation) output into a 4x4 matrix + """ + R = rot_from_axisangle(axisangle) + t = translation.clone() + + if invert: + R = R.transpose(1, 2) + t *= -1 + + T = get_translation_matrix(t) + + if invert: + M = torch.matmul(R, T) + else: + M = torch.matmul(T, R) + + return M + + +def get_translation_matrix(translation_vector): + """Convert a translation vector into a 4x4 transformation matrix + """ + T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) + + t = translation_vector.contiguous().view(-1, 3, 1) + + T[:, 0, 0] = 1 + T[:, 1, 1] = 1 + T[:, 2, 2] = 1 + T[:, 3, 3] = 1 + T[:, :3, 3, None] = t + + return T + + +def rot_from_axisangle(vec): + """Convert an axisangle rotation into a 4x4 transformation matrix + (adapted from https://github.com/Wallacoloo/printipi) + Input 'vec' has to be Bx1x3 + """ + angle = torch.norm(vec, 2, 2, True) + axis = vec / (angle + 1e-7) + + ca = torch.cos(angle) + sa = torch.sin(angle) + C = 1 - ca + + x = axis[..., 0].unsqueeze(1) + y = axis[..., 1].unsqueeze(1) + z = axis[..., 2].unsqueeze(1) + + xs = x * sa + ys = y * sa + zs = z * sa + xC = x * C + yC = y * C + zC = z * C + xyC = x * yC + yzC = y * zC + zxC = z * xC + + rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) + + rot[:, 0, 0] = torch.squeeze(x * xC + ca) + rot[:, 0, 1] = torch.squeeze(xyC - zs) + rot[:, 0, 2] = torch.squeeze(zxC + ys) + rot[:, 1, 0] = torch.squeeze(xyC + zs) + rot[:, 1, 1] = torch.squeeze(y * yC + ca) + rot[:, 1, 2] = torch.squeeze(yzC - xs) + rot[:, 2, 0] = torch.squeeze(zxC - ys) + rot[:, 2, 1] = torch.squeeze(yzC + xs) + rot[:, 2, 2] = torch.squeeze(z * zC + ca) + rot[:, 3, 3] = 1 + + return rot + + +class ConvBlock(nn.Module): + """Layer to perform a convolution followed by ELU + """ + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + + self.conv = Conv3x3(in_channels, out_channels) + self.nonlin = nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + return out + + +class Conv3x3(nn.Module): + """Layer to pad and convolve input + """ + def __init__(self, in_channels, out_channels, use_refl=True): + super(Conv3x3, self).__init__() + + if use_refl: + self.pad = nn.ReflectionPad2d(1) + else: + self.pad = nn.ZeroPad2d(1) + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) + + def forward(self, x): + out = self.pad(x) + out = self.conv(out) + return out + + +class BackprojectDepth(nn.Module): + """Layer to transform a depth image into a point cloud + """ + def __init__(self, batch_size, height, width, shift_rays_half_pixel=0): + super(BackprojectDepth, self).__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + + meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') + id_coords = np.stack(meshgrid, axis=0).astype(np.float32) + id_coords = torch.from_numpy(id_coords) + + ones = torch.ones(self.batch_size, 1, self.height * self.width) + + pix_coords = torch.unsqueeze(torch.stack( + [id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0) + pix_coords = pix_coords.repeat(batch_size, 1, 1) + pix_coords = torch.cat([pix_coords + shift_rays_half_pixel, + ones], 1) + self.register_buffer("pix_coords", pix_coords) + self.register_buffer("id_coords", id_coords) + self.register_buffer("ones", ones) + # self.pix_coords = pix_coords + # self.ones = ones + + def forward(self, depth, inv_K): + cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords.to(depth.device)) + cam_points = depth.view(self.batch_size, 1, -1) * cam_points + cam_points = torch.cat([cam_points, self.ones.to(depth.device)], 1) + + return cam_points + + +class Project3D(nn.Module): + """Layer which projects 3D points into a camera with intrinsics K and at position T + """ + def __init__(self, batch_size, height, width, eps=1e-7): + super(Project3D, self).__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + self.eps = eps + + def forward(self, points, K, T=None): + if T is None: + P = K + else: + P = torch.matmul(K, T) + P = P[:, :3, :] + + cam_points = torch.matmul(P, points) + + pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) + pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) + pix_coords = pix_coords.permute(0, 2, 3, 1) + pix_coords[..., 0] /= self.width - 1 + pix_coords[..., 1] /= self.height - 1 + pix_coords = (pix_coords - 0.5) * 2 + return pix_coords + + +class Project3DSimple(nn.Module): + """Layer which projects 3D points into a camera with intrinsics K and at position T + """ + def __init__(self, batch_size, height, width, eps=1e-7): + super(Project3DSimple, self).__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + self.eps = eps + + def forward(self, points, K): + K = K[:, :3, :] + + cam_points = torch.matmul(K, points) + + pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) + pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) + pix_coords = pix_coords.permute(0, 2, 3, 1) + return pix_coords + +def upsample(x): + """Upsample input tensor by a factor of 2 + """ + return F.interpolate(x, scale_factor=2, mode="nearest") + + +def get_smooth_loss(disp, img): + """Computes the smoothness loss for a disparity image + The color image is used for edge-aware smoothness + """ + grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) + grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) + + grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) + grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) + + grad_disp_x *= torch.exp(-grad_img_x) + grad_disp_y *= torch.exp(-grad_img_y) + + return grad_disp_x.mean() + grad_disp_y.mean() + + +class SSIM(nn.Module): + """Layer to compute the SSIM loss between a pair of images + """ + def __init__(self): + super(SSIM, self).__init__() + self.mu_x_pool = nn.AvgPool2d(3, 1) + self.mu_y_pool = nn.AvgPool2d(3, 1) + self.sig_x_pool = nn.AvgPool2d(3, 1) + self.sig_y_pool = nn.AvgPool2d(3, 1) + self.sig_xy_pool = nn.AvgPool2d(3, 1) + + self.refl = nn.ReflectionPad2d(1) + + self.C1 = 0.01 ** 2 + self.C2 = 0.03 ** 2 + + def forward(self, x, y): + x = self.refl(x) + y = self.refl(y) + + mu_x = self.mu_x_pool(x) + mu_y = self.mu_y_pool(y) + + sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 + sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 + sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y + + SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) + SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) + + return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) + + +def compute_depth_errors(gt, pred): + """Computation of error metrics between predicted and ground truth depths + """ + thresh = torch.max((gt / pred), (pred / gt)) + a1 = (thresh < 1.25 ).float().mean() + a2 = (thresh < 1.25 ** 2).float().mean() + a3 = (thresh < 1.25 ** 3).float().mean() + + rmse = (gt - pred) ** 2 + rmse = torch.sqrt(rmse.mean()) + + rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 + rmse_log = torch.sqrt(rmse_log.mean()) + + abs_rel = torch.mean(torch.abs(gt - pred) / gt) + + sq_rel = torch.mean((gt - pred) ** 2 / gt) + + return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 diff --git a/flash3d/networks/resnet_encoder.py b/flash3d/networks/resnet_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..dad6056ebc32f59f7fbbdbc844bfb017b0160360 --- /dev/null +++ b/flash3d/networks/resnet_encoder.py @@ -0,0 +1,115 @@ +# Copyright Niantic 2019. Patent Pending. All rights reserved. +# +# This software is licensed under the terms of the Monodepth2 licence +# which allows for non-commercial use only, the full terms of which are made +# available in the LICENSE file. + +import numpy as np + +import torch +import torch.nn as nn +import torchvision.models as models + + +RESNETS = {18: (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1), + 50: (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2)} + + +class ResNetMultiImageInput(models.ResNet): + """Constructs a resnet model with varying number of input images. + Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + """ + def __init__(self, block, layers, num_classes=1000, num_input_images=1): + super(ResNetMultiImageInput, self).__init__(block, layers) + self.inplanes = 64 + self.conv1 = nn.Conv2d( + num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): + """Constructs a ResNet model. + Args: + num_layers (int): Number of resnet layers. Must be 18 or 50 + pretrained (bool): If True, returns a model pre-trained on ImageNet + num_input_images (int): Number of frames stacked as input + """ + assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" + blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] + block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] + model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) + model, weigths = RESNETS[num_layers] + + if pretrained: + loaded = torch.hub.load_state_dict_from_url(weigths.url) + loaded['conv1.weight'] = torch.cat( + [loaded['conv1.weight']] * num_input_images, 1) / num_input_images + model.load_state_dict(loaded) + return model + + +class ResnetEncoder(nn.Module): + """Pytorch module for a resnet encoder + """ + def __init__(self, num_layers, pretrained, bn_order, num_input_images=1): + super(ResnetEncoder, self).__init__() + + self.num_ch_enc = np.array([64, 64, 128, 256, 512]) + self.bn_order = bn_order + + if num_layers not in RESNETS: + raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) + + if num_input_images > 1: + self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) + else: + model, weights = RESNETS[num_layers] + self.encoder = model(weights=weights) + + if num_layers > 34: + self.num_ch_enc[1:] *= 4 + + def forward(self, input_image): + encoder = self.encoder + features = [] + x = (input_image - 0.45) / 0.225 + x = encoder.conv1(x) + + if self.bn_order == "pre_bn": + # Concatenating pre-norm features allows us to + # keep the scale and shift of RGB colours + # and recover them at output + features.append(x) + x = encoder.bn1(x) + x = encoder.relu(x) + features.append(encoder.layer1(encoder.maxpool(x))) + elif self.bn_order == "monodepth": + # Batchnorm gets rid of constants due to colour shift + # will make the network not able to recover absolute colour shift + # of the input image + # used in old models + x = encoder.bn1(x) + x = encoder.relu(x) + features.append(x) + features.append(encoder.layer1(encoder.maxpool(x))) + else: + assert False + + features.append(encoder.layer2(features[-1])) + features.append(encoder.layer3(features[-1])) + features.append(encoder.layer4(features[-1])) + + return features diff --git a/flash3d/networks/unidepth.py b/flash3d/networks/unidepth.py new file mode 100644 index 0000000000000000000000000000000000000000..346e382643867ef00bcdd410a5b58d60e9bdb574 --- /dev/null +++ b/flash3d/networks/unidepth.py @@ -0,0 +1,577 @@ +import json +from pathlib import Path +from typing import List, Tuple +from math import ceil +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from einops import rearrange + +from unidepth.models.unidepthv1 import UniDepthV1 +from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD +from unidepth.utils.geometric import ( + generate_rays, + spherical_zbuffer_to_euclidean, + flat_interpolate, +) +from unidepth.layers import ( + MLP, + AttentionBlock, + NystromBlock, + PositionEmbeddingSine, + ConvUpsample, +) +from unidepth.utils.sht import rsh_cart_8 + +from networks.gaussian_decoder import get_splits_and_inits + + +# inference helpers +def _paddings(image_shape, network_shape): + cur_h, cur_w = image_shape + h, w = network_shape + pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 + pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 + return pad_left, pad_right, pad_top, pad_bottom + + +def _shapes(image_shape, network_shape): + h, w = image_shape + input_ratio = w / h + output_ratio = network_shape[1] / network_shape[0] + if output_ratio > input_ratio: + ratio = network_shape[0] / h + elif output_ratio <= input_ratio: + ratio = network_shape[1] / w + return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio + + +def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): + (pad_left, pad_right, pad_top, pad_bottom) = pads + rgbs = F.interpolate( + rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True + ) + rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") + if intrinsics is not None: + intrinsics = intrinsics.clone() + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top + return rgbs, intrinsics + return rgbs, None + + +def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes): + + (pad_left, pad_right, pad_top, pad_bottom) = pads + # pred mean, trim paddings, and upsample to input dim + predictions = sum( + [ + F.interpolate( + x, + size=shapes, + mode="bilinear", + align_corners=False, + antialias=True, + ) + for x in predictions + ] + ) / len(predictions) + + shapes = predictions.shape[2:] + predictions = predictions[ + ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right + ] + + predictions = F.interpolate( + predictions, + size=original_shapes, + mode="bilinear", + align_corners=False, + antialias=True, + ) + + if intrinsics is not None: + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio + intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio + intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio + + return predictions, intrinsics + + +def scale_intrinsics_xy(intrinsics, x_ratio, y_ratio): + intrinsics = intrinsics.clone() + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * x_ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * y_ratio + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * x_ratio + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * y_ratio + return intrinsics + + +def scale_intrinsics(intrinsics, ratio): + intrinsics = intrinsics.clone() + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + return intrinsics + + +def unidepthv1_forward(model, rgbs, intrinsics, skip_camera, + return_raw_preds=False): + B, _, H, W = rgbs.shape + + rgbs = TF.normalize( + rgbs, + mean=IMAGENET_DATASET_MEAN, + std=IMAGENET_DATASET_STD, + ) + + (h, w), ratio = _shapes((H, W), model.image_shape) + pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), model.image_shape) + rgbs, gt_intrinsics = _preprocess( + rgbs, + intrinsics, + (h, w), + (pad_left, pad_right, pad_top, pad_bottom), + ratio, + model.image_shape, + ) + + encoder_outputs, cls_tokens = model.pixel_encoder(rgbs) + if "dino" in model.pixel_encoder.__class__.__name__.lower(): + encoder_outputs = [ + (x + y.unsqueeze(1)).contiguous() + for x, y in zip(encoder_outputs, cls_tokens) + ] + + # get data for decoder and adapt to given camera + inputs = {} + inputs["encoder_outputs"] = encoder_outputs + inputs["cls_tokens"] = cls_tokens + inputs["image"] = rgbs + if gt_intrinsics is not None: + rays, angles = generate_rays( + gt_intrinsics, model.image_shape, noisy=False + ) + inputs["rays"] = rays + inputs["angles"] = angles + inputs["K"] = gt_intrinsics + model.pixel_decoder.test_fixed_camera = True + model.pixel_decoder.skip_camera = skip_camera + + # decode all + pred_intrinsics, predictions, features, rays = model.pixel_decoder(inputs, {}) + + pads = (pad_left, pad_right, pad_top, pad_bottom) + + # undo the reshaping and get original image size (slow) + predictions, pred_intrinsics = _postprocess( + predictions, + pred_intrinsics, + model.image_shape, + pads, + ratio, + (H, W), + ) + + if return_raw_preds: + return inputs, predictions + + # final 3D points backprojection + intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics + angles = generate_rays(intrinsics, (H, W), noisy=False)[-1] + angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W) + points_3d = torch.cat((angles, predictions), dim=1) + points_3d = spherical_zbuffer_to_euclidean( + points_3d.permute(0, 2, 3, 1) + ).permute(0, 3, 1, 2) + + # output data + outputs = { + "intrinsics": intrinsics, + "points": points_3d, + "depth": predictions[:, -1:], + "depth_feats": features, + "rays": rays, + "padding": pads + } + model.pixel_decoder.test_fixed_camera = False + model.pixel_decoder.skip_camera = False + return inputs, outputs + +class UniDepthDepth(nn.Module): + def __init__( + self, + cfg, + return_raw_preds=False + ): + super().__init__() + + self.cfg = cfg + self.return_raw_preds = return_raw_preds + + if "cnvnxtl" in cfg.model.name: + self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl") + elif "vit" in cfg.model.name: + self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14") + + self.skip_camera = True + + def get_depth(self, img, intrinsics): + depth_inputs, outputs = unidepthv1_forward( + self.depth_prediction_model, + img, + intrinsics, + self.skip_camera, + return_raw_preds=self.return_raw_preds) + return outputs + + def forward(self, inputs): + input_img = inputs["color_aug", 0, 0] + # here we need the intrinsics of the source image to condition on + # the depth prediction. needs to account for padding + if ("K_src", 0) in inputs: + intrinsics = inputs[("K_src", 0)] + else: + intrinsics = None + + depth_inputs, outputs = unidepthv1_forward( + self.depth_prediction_model, + input_img, + intrinsics, + self.skip_camera, + return_raw_preds=self.return_raw_preds) + + return depth_inputs, outputs + +class UniDepthUnprojector(nn.Module): + def __init__( + self, + cfg + ): + super().__init__() + + self.cfg = cfg + + if cfg.model.name == "unidepth_unprojector_cnvnxtl": + model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl") + elif cfg.model.name == "unidepth_unprojector_vit": + model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14") + self.unidepth = model + + self.skip_camera = True + + self.register_buffer("gauss_opacity", torch.ones(1, 1, 1).float()) + self.register_buffer("gauss_scaling", torch.ones(3, 1, 1).float()) + self.register_buffer("gauss_rotation", torch.ones(4, 1, 1).float() * 0.5) + self.register_buffer("gauss_features_rest", torch.zeros(9, 1, 1).float()) + self.register_buffer("gauss_offset", torch.zeros(3, 1, 1).float()) + + self.all_params = nn.ParameterDict({ + "opacity_scaling": nn.Parameter(torch.tensor(cfg.model.opacity_bias).float()), + "scale_scaling": nn.Parameter(torch.tensor(cfg.model.scale_bias).float()), + "colour_scaling": nn.Parameter(torch.tensor(self.cfg.model.colour_scale).float())}) + + + self.scaling_activation = torch.exp + self.opacity_activation = torch.sigmoid + self.relu = nn.ReLU() + + def get_parameter_groups(self): + # tune scalars for size, opacity and colour modulation + return [{'params': self.all_params.parameters()}] + + def forward(self, inputs): + model = self.unidepth + input_img = inputs["color_aug", 0, 0] + # here we need the intrinsics of the source image to condition on + # the depth prediction. needs to account for padding + intrinsics = inputs[("K_src", 0)] + b, c, h, w = inputs["color_aug", 0, 0].shape + + with torch.no_grad(): + _, depth_outs = unidepthv1_forward(model, input_img, intrinsics, self.skip_camera) + + outs = {} + + outs[("gauss_opacity", 0)] = self.gauss_opacity.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \ + * self.opacity_activation(self.all_params["opacity_scaling"]) + if not self.cfg.model.scale_with_depth: + outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \ + * self.scaling_activation(self.all_params["scale_scaling"]) + else: + outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \ + * self.scaling_activation(self.all_params["scale_scaling"]) * depth_outs["depth"] / 10.0 + outs[("gauss_rotation", 0)] = self.gauss_rotation.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) + outs[("gauss_offset", 0)] = self.gauss_offset.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) + outs[("gauss_features_rest", 0)] = self.gauss_features_rest.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) + # rendering adds 0.5 to go from rendered colours to output + outs[("gauss_features_dc", 0)] = (input_img - 0.5)* self.relu(self.all_params["colour_scaling"]) + + outs[("depth", 0)] = depth_outs["depth"] + + return outs + +class UniDepthSplatter(nn.Module): + def __init__( + self, + cfg + ): + super().__init__() + + self.cfg = cfg + + config_path = Path("/work/eldar/src/UniDepth") + with open(config_path / "configs/config_v1_cnvnxtl.json") as f: + config = json.load(f) + self.unidepth = UniDepthDepth(self.cfg) + + hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] + expansion = config["model"]["expansion"] + depth = config["model"]["pixel_decoder"]["depths"] + num_heads = config["model"]["num_heads"] + dropout = config["model"]["pixel_decoder"]["dropout"] + layer_scale = 1.0 + self.splat_decoder = GaussSplatHead( + cfg, + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depths=depth, + camera_dim=81, + dropout=dropout, + layer_scale=layer_scale, + ) + + self.skip_camera = True + + def get_parameter_groups(self): + base_lr = self.cfg.optimiser.learning_rate + return [ + {'params': self.unidepth.parameters(), "lr": base_lr * 0.05}, + {'params': self.splat_decoder.parameters()} + ] + + def forward(self, inputs): + gauss_head = self.splat_decoder + + depth_inputs, depth_outs = self.unidepth(inputs) + depth_feats = depth_outs["depth_feats"] + rays = depth_outs["rays"] + padding = depth_outs["padding"] + + B, _, H, W = depth_inputs["image"].shape + + # TODO remove hardcoded shapes + common_shape = (28, 38) + gauss_head.set_shapes(common_shape) + gauss_head.set_original_shapes((H, W)) + + depth_feats = rearrange(depth_feats, "b c h w -> b (h w) c") + outs = gauss_head( + latents_16=depth_feats, + rays_hr=rays, + ) + for k, v in outs.items(): + pred, _ = _postprocess([v], None, self.unidepth.depth_prediction_model.image_shape, + padding, None, inputs["color_aug", 0, 0].shape[2:4]) + outs[k] = pred + outs[("depth", 0)] = depth_outs["depth"] + + return outs + + +class GaussSplatHead(nn.Module): + def __init__( + self, + cfg, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depths: int | list[int] = 4, + camera_dim: int = 256, + dropout: float = 0.0, + layer_scale: float = 1.0, + ) -> None: + super().__init__() + + self.cfg = cfg + + if isinstance(depths, int): + depths = [depths] * 3 + assert len(depths) == 3 + + self.project_rays16 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim + ) + self.project_rays8 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2 + ) + self.project_rays4 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4 + ) + + self.layers_8 = nn.ModuleList([]) + self.layers_4 = nn.ModuleList([]) + layers_16 = nn.ModuleList([]) + + self.up8 = ConvUpsample( + hidden_dim, expansion=expansion, layer_scale=layer_scale + ) + self.up4 = ConvUpsample( + hidden_dim // 2, expansion=expansion, layer_scale=layer_scale + ) + self.up2 = ConvUpsample( + hidden_dim // 4, expansion=expansion, layer_scale=layer_scale + ) + + split_dimensions, scale, bias = get_splits_and_inits(cfg) + start = 1 + self.split_dimensions = split_dimensions[start:] + scale = scale[start:] + bias = bias[start:] + + self.num_output_channels = sum(self.split_dimensions) + + self.out2 = nn.Conv2d(hidden_dim // 8, self.num_output_channels, 3, padding=1) + # self.out4 = nn.Conv2d(hidden_dim // 4, self.num_output_channels, 3, padding=1) + # self.out8 = nn.Conv2d(hidden_dim // 2, self.num_output_channels, 3, padding=1) + + start_channels = 0 + for out_channel, b, s in zip(self.split_dimensions, bias, scale): + nn.init.xavier_uniform_( + self.out2.weight[start_channels:start_channels+out_channel, + :, :, :], s) + nn.init.constant_( + self.out2.bias[start_channels:start_channels+out_channel], b) + start_channels += out_channel + + for i, (blk_lst, depth) in enumerate( + zip([layers_16, self.layers_8, self.layers_4], depths) + ): + if i == 0: + continue + attn_cls = AttentionBlock if i == 0 else NystromBlock + for _ in range(depth): + blk_lst.append( + attn_cls( + hidden_dim // (2**i), + num_heads=num_heads // (2**i), + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + ) + + self.scaling_activation = torch.exp + self.opacity_activation = torch.sigmoid + self.rotation_activation = torch.nn.functional.normalize + self.scaling_lambda = cfg.model.scale_lambda + self.sigmoid = nn.Sigmoid() + + def set_original_shapes(self, shapes: Tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: Tuple[int, int]): + self.shapes = shapes + + def forward( + self, latents_16: torch.Tensor, rays_hr: torch.Tensor + ) -> torch.Tensor: + shapes = self.shapes + + # camera_embedding + # torch.cuda.synchronize() + # start = time() + rays_embedding_16 = F.normalize( + flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1 + ) + rays_embedding_8 = F.normalize( + flat_interpolate( + rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes] + ), + dim=-1, + ) + rays_embedding_4 = F.normalize( + flat_interpolate( + rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes] + ), + dim=-1, + ) + rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16)) + rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8)) + rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4)) + + # Block 16 - Out 8 + latents_8 = self.up8( + rearrange( + latents_16 + rays_embedding_16, + "b (h w) c -> b c h w", + h=shapes[0], + w=shapes[1], + ).contiguous() + ) + # out8 = self.out8( + # rearrange( + # latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2 + # ) + # ) + + # Block 8 - Out 4 + for layer in self.layers_8: + latents_8 = layer(latents_8, pos_embed=rays_embedding_8) + latents_4 = self.up4( + rearrange( + latents_8 + rays_embedding_8, + "b (h w) c -> b c h w", + h=shapes[0] * 2, + w=shapes[1] * 2, + ).contiguous() + ) + # out4 = self.out4( + # rearrange( + # latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4 + # ) + # ) + + # Block 4 - Out 2 + for layer in self.layers_4: + latents_4 = layer(latents_4, pos_embed=rays_embedding_4) + latents_2 = self.up2( + rearrange( + latents_4 + rays_embedding_4, + "b (h w) c -> b c h w", + h=shapes[0] * 4, + w=shapes[1] * 4, + ).contiguous() + ) + out2 = self.out2( + rearrange( + latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8 + ) + ) + + split_network_outputs = out2.split(self.split_dimensions, dim=1) + last = 5 + offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:last] + + out = { + ("gauss_opacity", 0): self.opacity_activation(opacity), + ("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda, + ("gauss_rotation", 0): self.rotation_activation(rotation), + ("gauss_features_dc", 0): feat_dc + } + + if self.cfg.model.max_sh_degree > 0: + features_rest = split_network_outputs[last] + out[("gauss_features_rest", 0)] = features_rest + + if self.cfg.model.predict_offset: + out[("gauss_offset", 0)] = offset + + return out + # return out8, out4, out2, proj_latents_16 diff --git a/flash3d/networks/unidepth_extension.py b/flash3d/networks/unidepth_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb1cffb9aaff211f55bd589e36d3713bfe42200 --- /dev/null +++ b/flash3d/networks/unidepth_extension.py @@ -0,0 +1,205 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .unidepth import UniDepthDepth +from unidepth.models import UniDepthV1 +from .resnet_encoder import ResnetEncoder +from .gaussian_decoder import GaussianDecoder +from .depth_decoder import DepthDecoder + +from networks.layers import disp_to_depth +from networks.gaussian_decoder import get_splits_and_inits + + +class UniDepthExtended(nn.Module): + def __init__(self,cfg): + super().__init__() + + self.cfg = cfg + + self.unidepth = UniDepthDepth(cfg) + # self.unidepth = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14") + + self.parameters_to_train = [] + if self.cfg.model.splat_branch == "resnet": + self.encoder = ResnetEncoder(cfg.model.num_layers, + cfg.model.weights_init == "pretrained", + cfg.model.resnet_bn_order + ) + # change encoder to take depth as conditioning + if self.cfg.model.depth_cond: + self.encoder.encoder.conv1 = nn.Conv2d( + 4, + self.encoder.encoder.conv1.out_channels, + kernel_size = self.encoder.encoder.conv1.kernel_size, + padding = self.encoder.encoder.conv1.padding, + stride = self.encoder.encoder.conv1.stride + ) + self.parameters_to_train += [{"params": self.encoder.parameters()}] + + # use depth branch only for more gaussians + if cfg.model.gaussians_per_pixel > 1: + models ={} + models["depth"] = DepthDecoder(cfg, self.encoder.num_ch_enc) + self.parameters_to_train +=[{"params": models["depth"].parameters()}] + for i in range(cfg.model.gaussians_per_pixel): + models["gauss_decoder_"+str(i)] = GaussianDecoder(cfg, self.encoder.num_ch_enc) + self.parameters_to_train += [{"params": models["gauss_decoder_"+str(i)].parameters()}] + if cfg.model.one_gauss_decoder: + break + self.models = nn.ModuleDict(models) + else: + self.gauss_decoder = GaussianDecoder(cfg, self.encoder.num_ch_enc) + self.parameters_to_train += [{"params": self.gauss_decoder.parameters()}] + + elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl": + self.splat_branch = UniDepthDepth(cfg, + return_raw_preds=True) + # modify the head to output the channels for Gaussian parameters + self.init_ouput_head_splat_branch() + self.parameters_to_train +=[{"params": self.splat_branch.parameters()}] + + self.scaling_activation = torch.exp + self.opacity_activation = torch.sigmoid + self.rotation_activation = torch.nn.functional.normalize + + def init_ouput_head_splat_branch(self): + split_dimensions, scale, bias = get_splits_and_inits(self.cfg) + # the first dim in the output is for depth - we don't use that in this branch + self.split_dimensions = split_dimensions[1:] + scale = scale[1:] + bias = bias[1:] + + self.num_output_channels = sum(self.split_dimensions) + + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2 = \ + nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.in_channels, + self.num_output_channels, + kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.kernel_size, + padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.padding) + + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4 = \ + nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.in_channels, + self.num_output_channels, + kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.kernel_size, + padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.padding) + + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8 = \ + nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.in_channels, + self.num_output_channels, + kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.kernel_size, + padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.padding) + + start_channels = 0 + for out_channel, b, s in zip(split_dimensions, bias, scale): + nn.init.xavier_uniform_( + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.weight[start_channels:start_channels+out_channel, + :, :, :], s) + nn.init.constant_( + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.bias[start_channels:start_channels+out_channel], b) + start_channels += out_channel + + start_channels = 0 + for out_channel, b, s in zip(split_dimensions, bias, scale): + nn.init.xavier_uniform_( + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.weight[start_channels:start_channels+out_channel, + :, :, :], s) + nn.init.constant_( + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.bias[start_channels:start_channels+out_channel], b) + start_channels += out_channel + + start_channels = 0 + for out_channel, b, s in zip(split_dimensions, bias, scale): + nn.init.xavier_uniform_( + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.weight[start_channels:start_channels+out_channel, + :, :, :], s) + nn.init.constant_( + self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.bias[start_channels:start_channels+out_channel], b) + start_channels += out_channel + + def get_parameter_groups(self): + # only the resnet encoder and gaussian parameter decoder are optimisable + return self.parameters_to_train + + def forward(self, inputs): + if ('unidepth', 0, 0) in inputs.keys() and inputs[('unidepth', 0, 0)] is not None: + depth_outs = dict() + depth_outs["depth"] = inputs[('unidepth', 0, 0)] + else: + with torch.no_grad(): + # if self.training and self.cfg.dataset.pad_border_aug > 0: + # pad = self.cfg.dataset.pad_border_aug + # input = inputs["color_aug", 0, 0][:,:,pad:-pad, pad:-pad] + # intrincs = inputs[("K_tgt", 0)] + # else: + # input = inputs["color_aug", 0, 0] + # intrincs = inputs[("K_src", 0)] + _, depth_outs = self.unidepth(inputs) + # depth_outs = self.unidepth.infer(input, intrincs) + # if self.training and self.cfg.dataset.pad_border_aug > 0: + # depth_outs["depth"] = F.pad(depth_outs["depth"], (pad,pad,pad,pad), mode="replicate") + + outputs_gauss = {} + + K = depth_outs["intrinsics"] + outputs_gauss[("K_src", 0)] = K + outputs_gauss[("inv_K_src", 0)] = torch.linalg.inv(K) + + if self.cfg.model.splat_branch == "resnet": + if self.cfg.model.depth_cond: + # division by 20 is to put depth in a similar range to RGB + resnet_input = torch.cat([inputs["color_aug", 0, 0], + depth_outs["depth"] / 20.0], dim=1) + else: + resnet_input = inputs["color_aug", 0, 0] + resnet_features = self.encoder(resnet_input) + if self.cfg.model.gaussians_per_pixel > 1: + pred_depth = dict() + depth = self.models["depth"](resnet_features) + if self.cfg.model.depth_type == "disp": + for key, v in depth.items(): + _, pred_depth[("depth", key[1])] = disp_to_depth(v, self.cfg.model.min_depth, self.cfg.model.max_depth) + elif self.cfg.model.depth_type in ["depth", "depth_inc"]: + pred_depth = depth + pred_depth[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "(b n) ... -> b n ...", n=self.cfg.model.gaussians_per_pixel - 1) + if self.cfg.model.depth_type in ["depth_inc", "disp_inc"]: + pred_depth[("depth", 0)] = torch.cumsum(torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1), dim=1) + else: + pred_depth[("depth", 0)] = torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1) + outputs_gauss[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "b n c ... -> (b n) c ...", n = self.cfg.model.gaussians_per_pixel) + gauss_outs = dict() + for i in range(self.cfg.model.gaussians_per_pixel): + outs = self.models["gauss_decoder_"+str(i)](resnet_features) + if not self.cfg.model.one_gauss_decoder: + for key, v in outs.items(): + gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1) + else: + gauss_outs |= outs + for key, v in gauss_outs.items(): + gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...') + outputs_gauss |= gauss_outs + else: + outputs_gauss[("depth", 0)] = depth_outs["depth"] + outputs_gauss |= self.gauss_decoder(resnet_features) + elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl": + split_network_outputs = self.splat_branch(inputs)[1].split(self.split_dimensions, dim=1) + offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:5] + + outputs_gauss |= { + ("gauss_opacity", 0): self.opacity_activation(opacity), + ("gauss_scaling", 0): self.scaling_activation(scaling), + ("gauss_rotation", 0): self.rotation_activation(rotation), + ("gauss_features_dc", 0): feat_dc + } + + if self.cfg.model.max_sh_degree > 0: + features_rest = split_network_outputs[5] + outputs_gauss[("gauss_features_rest", 0)] = features_rest + + assert self.cfg.model.predict_offset + outputs_gauss[("gauss_offset", 0)] = offset + + return outputs_gauss + diff --git a/flash3d/unidepth/layers/__init__.py b/flash3d/unidepth/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4acb2508c715186fadfa3b0441b8a0e981bd41e3 --- /dev/null +++ b/flash3d/unidepth/layers/__init__.py @@ -0,0 +1,21 @@ +from .activation import SwiGLU, GEGLU +from .convnext import CvnxtBlock +from .attention import AttentionBlock, AttentionDecoderBlock +from .nystrom_attention import NystromBlock +from .positional_encoding import PositionEmbeddingSine +from .upsample import ConvUpsample, ConvUpsampleShuffle +from .mlp import MLP + + +__all__ = [ + "SwiGLU", + "GEGLU", + "CvnxtBlock", + "AttentionBlock", + "NystromBlock", + "PositionEmbeddingSine", + "ConvUpsample", + "MLP", + "ConvUpsampleShuffle", + "AttentionDecoderBlock", +] diff --git a/flash3d/unidepth/layers/activation.py b/flash3d/unidepth/layers/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..f5787a340013ba59e2956b6b829f724d9cfb7fcc --- /dev/null +++ b/flash3d/unidepth/layers/activation.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SwiGLU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gates = x.chunk(2, dim=-1) + return x * F.silu(gates) + + +class GEGLU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gates = x.chunk(2, dim=-1) + return x * F.gelu(gates) diff --git a/flash3d/unidepth/layers/attention.py b/flash3d/unidepth/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c9fc5f79003e28815e65f9f8fe71474b7ed021a1 --- /dev/null +++ b/flash3d/unidepth/layers/attention.py @@ -0,0 +1,308 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .layer_scale import LayerScale +from .mlp import MLP + + +class SimpleAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + dropout: float = 0.0, + cosine: bool = False, + context_dim: int | None = None, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + context_dim = context_dim or dim + + self.kv = nn.Linear(context_dim, dim * 2, bias=False) + self.q = nn.Linear(dim, dim, bias=False) + self.norm_attnx = nn.LayerNorm(dim) + self.norm_attnctx = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out = nn.Linear(dim, dim) + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b h n d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + context_dim = context_dim or dim + self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) + self.kv = nn.Linear(context_dim, dim * 2) + self.q = nn.Linear(dim, dim) + self.norm_attnx = nn.LayerNorm(dim) + self.norm_attnctx = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out = nn.Linear(dim, dim) + self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + + def attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b h n d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = ( + self.ls1( + self.attn( + x, + rope=rope, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + + x + ) + x = self.ls2(self.mlp(x)) + x + return x + + +class AttentionDecoderBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + single_head_ca: bool = True, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + self.single_head_ca = single_head_ca + context_dim = context_dim or dim + self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) + self.kv_ca = nn.Linear(context_dim, dim * 2) + self.q_ca = nn.Linear(dim, dim) + self.kv_sa = nn.Linear(dim, dim * 2) + self.q_sa = nn.Linear(dim, dim) + self.norm_x_sa = nn.LayerNorm(dim) + self.norm_x_ca = nn.LayerNorm(dim) + self.norm_ctx_ca = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out_ca = nn.Linear(dim, dim) + self.out_sa = nn.Linear(dim, dim) + self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + + def cross_attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + num_heads = 1 if self.single_head_ca else self.num_heads + x = self.norm_x_ca(x) + context = self.norm_ctx_ca(context) + k, v = rearrange( + self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out_ca(x) + return x + + def self_attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + x = self.norm_x_sa(x) + k, v = rearrange( + self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + elif pos_embed is not None: + pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads) + q = q + pos_embed + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out_sa(x) + return x + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = ( + self.ls1( + self.cross_attn( + x, + rope=rope, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + + x + ) + x = ( + self.ls2( + self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed) + ) + + x + ) + x = self.ls3(self.mlp(x)) + x + return x diff --git a/flash3d/unidepth/layers/convnext.py b/flash3d/unidepth/layers/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..12a4e9a15e25433418d6b066f15a39a205f5aa81 --- /dev/null +++ b/flash3d/unidepth/layers/convnext.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn + + +class CvnxtBlock(nn.Module): + def __init__( + self, + dim, + kernel_size=7, + layer_scale=1.0, + expansion=4, + dilation=1, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding="same", + groups=dim, + dilation=dilation, + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, expansion * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(expansion * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0 + ) + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + x = self.gamma * x + x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + return x diff --git a/flash3d/unidepth/layers/drop_path.py b/flash3d/unidepth/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..781ff566500c923b1f199542b0c7dfb862a077ca --- /dev/null +++ b/flash3d/unidepth/layers/drop_path.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/flash3d/unidepth/layers/layer_scale.py b/flash3d/unidepth/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..01b6662490d7296725f103d1abf8790cac84d0f8 --- /dev/null +++ b/flash3d/unidepth/layers/layer_scale.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float | torch.Tensor = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/flash3d/unidepth/layers/mlp.py b/flash3d/unidepth/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..074b7e3949a12233e88a08877738e0ce2ca53acf --- /dev/null +++ b/flash3d/unidepth/layers/mlp.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + +from unidepth.utils.misc import default +from .activation import SwiGLU + + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + expansion: int = 4, + dropout: float = 0.0, + gated: bool = False, + output_dim: int | None = None, + ): + super().__init__() + if gated: + expansion = int(expansion * 2 / 3) + hidden_dim = int(input_dim * expansion) + output_dim = default(output_dim, input_dim) + self.norm = nn.LayerNorm(input_dim) + self.proj1 = nn.Linear(input_dim, hidden_dim) + self.proj2 = nn.Linear(hidden_dim, output_dim) + self.act = nn.GELU() if not gated else SwiGLU() + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.proj1(x) + x = self.act(x) + x = self.proj2(x) + x = self.dropout(x) + return x diff --git a/flash3d/unidepth/layers/nystrom_attention.py b/flash3d/unidepth/layers/nystrom_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7476f114a68617bf64bc4cb51eec6c98445df5 --- /dev/null +++ b/flash3d/unidepth/layers/nystrom_attention.py @@ -0,0 +1,74 @@ +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from xformers.components.attention import NystromAttention + +from .attention import AttentionBlock + + +class NystromBlock(AttentionBlock): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + cosine=cosine, + gated=gated, + layer_scale=layer_scale, + context_dim=context_dim, + ) + self.attention_fn = NystromAttention( + num_landmarks=128, num_heads=num_heads, dropout=dropout + ) + + def attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b n h d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = self.attention_fn(q, k, v, key_padding_mask=attn_bias) + x = rearrange(x, "b n h d -> b n (h d)") + x = self.out(x) + return x diff --git a/flash3d/unidepth/layers/positional_encoding.py b/flash3d/unidepth/layers/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..616dc197cf2e602e85085dc6f05957920f115cb5 --- /dev/null +++ b/flash3d/unidepth/layers/positional_encoding.py @@ -0,0 +1,228 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from math import pi +from typing import Optional + +import torch +import torch.nn as nn + +from einops import rearrange, repeat + + +class PositionEmbeddingSine(nn.Module): + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * pi + self.scale = scale + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class LearnedSinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def generate_fourier_features(x, max_freq=64, num_bands=16): + x = x.unsqueeze(-1) + device, dtype, orig_x = x.device, x.dtype, x + + scales = torch.linspace( + -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype + ) + scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] + + x = x * scales * pi + x = torch.cat([x.sin(), x.cos()], dim=-1) + x = torch.cat((x, orig_x), dim=-1) + return x.flatten(-2) + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + print("======== shape of rope freq", self.freqs_cos.shape, "========") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward(self, t): + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin diff --git a/flash3d/unidepth/layers/upsample.py b/flash3d/unidepth/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..b0162e76c3319c8ac802b68795d6f32793e693b6 --- /dev/null +++ b/flash3d/unidepth/layers/upsample.py @@ -0,0 +1,69 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import torch +import torch.nn as nn +from einops import rearrange + +from .convnext import CvnxtBlock + + +class ConvUpsample(nn.Module): + def __init__( + self, + hidden_dim, + num_layers: int = 2, + expansion: int = 4, + layer_scale: float = 1.0, + kernel_size: int = 7, + **kwargs + ): + super().__init__() + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + CvnxtBlock( + hidden_dim, + kernel_size=kernel_size, + expansion=expansion, + layer_scale=layer_scale, + ) + ) + self.up = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0), + nn.UpsamplingBilinear2d(scale_factor=2), + nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1), + ) + + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x + + +class ConvUpsampleShuffle(nn.Module): + def __init__( + self, hidden_dim, expansion: int = 4, layer_scale: float = 1.0, **kwargs + ): + super().__init__() + self.conv1 = CvnxtBlock( + hidden_dim, expansion=expansion, layer_scale=layer_scale + ) + self.conv2 = CvnxtBlock( + hidden_dim, expansion=expansion, layer_scale=layer_scale + ) + self.up = nn.Sequential( + nn.PixelShuffle(2), + nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1), + ) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) + x = self.conv2(x) + x = self.up(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x diff --git a/flash3d/unidepth/models/__init__.py b/flash3d/unidepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1781bda94cdc13b0c0c805e7cde0872defc20cd3 --- /dev/null +++ b/flash3d/unidepth/models/__init__.py @@ -0,0 +1,5 @@ +from .unidepthv1 import UniDepthV1 + +__all__ = [ + "UniDepthV1", +] diff --git a/flash3d/unidepth/models/backbones/__init__.py b/flash3d/unidepth/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55f55cde8b365f0f63f98fcc21ea64166c7ce33c --- /dev/null +++ b/flash3d/unidepth/models/backbones/__init__.py @@ -0,0 +1,9 @@ +from .convnext2 import ConvNeXtV2 +from .convnext import ConvNeXt +from .dinov2 import _make_dinov2_model + +__all__ = [ + "ConvNeXt", + "ConvNeXtV2", + "_make_dinov2_model", +] diff --git a/flash3d/unidepth/models/backbones/convnext.py b/flash3d/unidepth/models/backbones/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..b745415724df69347697efc9987c3b8a6c9cb849 --- /dev/null +++ b/flash3d/unidepth/models/backbones/convnext.py @@ -0,0 +1,590 @@ +from collections import OrderedDict +from functools import partial +from typing import Callable, Optional, Tuple, Union, Sequence + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from timm.layers import ( + trunc_normal_, + AvgPool2dSame, + DropPath, + Mlp, + GlobalResponseNormMlp, + LayerNorm2d, + LayerNorm, + create_conv2d, + get_act_layer, + make_divisible, + to_ntuple, +) + + +def get_num_layer_for_convnext(var_name): + """ + Divide [3, 3, 27, 3] layers into 12 groups; each group is three + consecutive blocks, including possible neighboring downsample layers; + adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py + """ + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + elif stage_id == 3: + layer_id = 12 + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[3]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = 12 + + elif var_name.startswith("stem"): + return 0 + else: + layer_id = 12 + return layer_id + 1 + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None): + parameter_group_names = {} + parameter_group_vars = {} + skip = set() + if skip_list is not None: + skip = skip_list + if hasattr(model, "no_weight_decay"): + skip.update(model.no_weight_decay()) + num_layers = 12 + layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip: + group_name = "no_decay" + this_wd = 0.0 + else: + group_name = "decay" + this_wd = wd + + layer_id = get_num_layer_for_convnext(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_wd, + "weight_decay_init": this_wd, + "weight_decay_base": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_wd, + "weight_decay_init": this_wd, + "weight_decay_base": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + if this_wd == 0.0: + parameter_group_names[group_name]["weight_decay_final"] = 0.0 + parameter_group_vars[group_name]["weight_decay_final"] = 0.0 + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + # from unidepth.utils import is_main_process + # import json + # if is_main_process(): + # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class Downsample(nn.Module): + def __init__(self, in_chs, out_chs, stride=1, dilation=1): + super().__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = ( + AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + ) + self.pool = avg_pool_fn( + 2, avg_stride, ceil_mode=True, count_include_pad=False + ) + else: + self.pool = nn.Identity() + + if in_chs != out_chs: + self.conv = create_conv2d(in_chs, out_chs, 1, stride=1) + else: + self.conv = nn.Identity() + + def forward(self, x): + x = self.pool(x) + x = self.conv(x) + return x + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block + There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + + Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate + choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear + is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = (1, 1), + mlp_ratio: float = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + ls_init_value: Optional[float] = 1e-6, + act_layer: Union[str, Callable] = "gelu", + norm_layer: Optional[Callable] = None, + drop_path: float = 0.0, + ): + """ + + Args: + in_chs: Block input channels. + out_chs: Block output channels (same as in_chs if None). + kernel_size: Depthwise convolution kernel size. + stride: Stride of depthwise convolution. + dilation: Tuple specifying input and output dilation of block. + mlp_ratio: MLP expansion ratio. + conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True. + conv_bias: Apply bias for all convolution (linear) layers. + use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2) + ls_init_value: Layer-scale init values, layer-scale applied if not None. + act_layer: Activation layer. + norm_layer: Normalization layer (defaults to LN if not specified). + drop_path: Stochastic depth probability. + """ + super().__init__() + out_chs = out_chs or in_chs + dilation = to_ntuple(2)(dilation) + act_layer = get_act_layer(act_layer) + if not norm_layer: + norm_layer = LayerNorm2d if conv_mlp else LayerNorm + mlp_layer = partial( + GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp + ) + self.use_conv_mlp = conv_mlp + self.conv_dw = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation[0], + depthwise=True, + bias=conv_bias, + ) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) + self.gamma = ( + nn.Parameter(ls_init_value * torch.ones(out_chs)) + if ls_init_value is not None + else None + ) + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = Downsample( + in_chs, out_chs, stride=stride, dilation=dilation[0] + ) + else: + self.shortcut = nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x.contiguous()) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + else: + x = x.permute(0, 2, 3, 1).contiguous() + x = self.norm(x) + x = self.mlp(x) + x = x.permute(0, 3, 1, 2).contiguous() + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + + x = self.drop_path(x) + self.shortcut(shortcut) + return x.contiguous() + + +class ConvNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=7, + stride=2, + depth=2, + dilation=(1, 1), + drop_path_rates=None, + ls_init_value=1.0, + conv_mlp=False, + conv_bias=True, + use_grn=False, + act_layer="gelu", + norm_layer=None, + norm_layer_cl=None, + ): + super().__init__() + self.grad_checkpointing = False + + if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]: + ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1 + pad = ( + "same" if dilation[1] > 1 else 0 + ) # same padding needed if dilation used + self.downsample = nn.Sequential( + norm_layer(in_chs), + create_conv2d( + in_chs, + out_chs, + kernel_size=ds_ks, + stride=stride, + dilation=dilation[0], + padding=pad, + bias=conv_bias, + ), + ) + in_chs = out_chs + else: + self.downsample = nn.Identity() + + drop_path_rates = drop_path_rates or [0.0] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append( + ConvNeXtBlock( + in_chs=in_chs, + out_chs=out_chs, + kernel_size=kernel_size, + dilation=dilation[1], + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer if conv_mlp else norm_layer_cl, + ) + ) + in_chs = out_chs + self.blocks = nn.ModuleList(stage_blocks) + + def forward(self, x): + xs = [] + x = self.downsample(x) + for block in self.blocks: + if self.grad_checkpointing: + x = checkpoint(block, x) + else: + x = block(x) + xs.append(x) + return xs + + +class ConvNeXt(nn.Module): + def __init__( + self, + in_chans: int = 3, + output_stride: int = 32, + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] = (96, 192, 384, 768), + kernel_sizes: Union[int, Tuple[int, ...]] = 7, + ls_init_value: Optional[float] = 1e-6, + stem_type: str = "patch", + patch_size: int = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + act_layer: Union[str, Callable] = "gelu", + norm_layer: Optional[Union[str, Callable]] = None, + norm_eps: Optional[float] = None, + drop_path_rate: float = 0.0, + output_idx=[], + use_checkpoint=False, + ): + """ + Args: + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + global_pool: Global pooling type. + output_stride: Output stride of network, one of (8, 16, 32). + depths: Number of blocks at each stage. + dims: Feature dimension at each stage. + kernel_sizes: Depthwise convolution kernel-sizes for each stage. + ls_init_value: Init value for Layer Scale, disabled if None. + stem_type: Type of stem. + patch_size: Stem patch size for patch stem. + head_init_scale: Init scaling value for classifier weights and biases. + head_norm_first: Apply normalization before global pool + head. + head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False. + conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last. + conv_bias: Use bias layers w/ all convolutions. + use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP. + act_layer: Activation layer type. + norm_layer: Normalization layer type. + drop_rate: Head pre-classifier dropout rate. + drop_path_rate: Stochastic depth drop rate. + """ + super().__init__() + self.num_layers = len(depths) + self.depths = output_idx + self.embed_dims = [ + int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) + ] + self.embed_dim = dims[0] + + assert output_stride in (8, 16, 32) + kernel_sizes = to_ntuple(4)(kernel_sizes) + if norm_layer is None: + norm_layer = LayerNorm2d + norm_layer_cl = norm_layer if conv_mlp else LayerNorm + if norm_eps is not None: + norm_layer = partial(norm_layer, eps=norm_eps) + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + else: + assert ( + conv_mlp + ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input" + norm_layer_cl = norm_layer + if norm_eps is not None: + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + + self.feature_info = [] + + assert stem_type in ("patch", "overlap", "overlap_tiered") + if stem_type == "patch": + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + self.stem = nn.Sequential( + nn.Conv2d( + in_chans, + dims[0], + kernel_size=patch_size, + stride=patch_size, + bias=conv_bias, + ), + norm_layer(dims[0]), + ) + stem_stride = patch_size + else: + mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0] + self.stem = nn.Sequential( + nn.Conv2d( + in_chans, + mid_chs, + kernel_size=3, + stride=2, + padding=1, + bias=conv_bias, + ), + nn.Conv2d( + mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias + ), + norm_layer(dims[0]), + ) + stem_stride = 4 + + self.stages = nn.Sequential() + dp_rates = [ + x.tolist() + for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths) + ] + stages = [] + prev_chs = dims[0] + curr_stride = stem_stride + dilation = 1 + # 4 feature resolution stages, each consisting of multiple residual blocks + for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + if curr_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + curr_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + out_chs = dims[i] + stages.append( + ConvNeXtStage( + prev_chs, + out_chs, + kernel_size=kernel_sizes[i], + stride=stride, + dilation=(first_dilation, dilation), + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + ) + ) + prev_chs = out_chs + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + self.feature_info += [ + dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}") + ] + self.stages = nn.ModuleList(stages) + self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1)) + self.num_features = prev_chs + self.apply(self._init_weights) + self.set_grad_checkpointing(use_checkpoint) + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + nn.init.zeros_(module.bias) + + def forward(self, x, masks=None): + outs = [] + x = self.stem(x) + if masks is not None: + masks = torch.nn.functional.interpolate( + masks.float(), size=x.shape[-2:], mode="nearest" + ) + x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous() + for stage in self.stages: + xs = stage(x) + outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs]) + x = xs[-1] + return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r"^stem", + blocks=( + r"^stages\.(\d+)" + if coarse + else [ + (r"^stages\.(\d+)\.downsample", (0,)), # blocks + (r"^stages\.(\d+)\.blocks\.(\d+)", None), + (r"^norm_pre", (99999,)), + ] + ), + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def no_weight_decay(self): + return {"mask_token"} + + @classmethod + def build(cls, config): + obj = globals()[config["model"]["encoder"]["name"]](config) + return obj + + +def checkpoint_filter_fn(state_dict, model): + """Remap FB checkpoints -> timm""" + if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict: + return state_dict # non-FB checkpoint + if "model" in state_dict: + state_dict = state_dict["model"] + + out_dict = {} + if "visual.trunk.stem.0.weight" in state_dict: + out_dict = { + k.replace("visual.trunk.", ""): v + for k, v in state_dict.items() + if k.startswith("visual.trunk.") + } + if "visual.head.proj.weight" in state_dict: + out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"] + out_dict["head.fc.bias"] = torch.zeros( + state_dict["visual.head.proj.weight"].shape[0] + ) + elif "visual.head.mlp.fc1.weight" in state_dict: + out_dict["head.pre_logits.fc.weight"] = state_dict[ + "visual.head.mlp.fc1.weight" + ] + out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"] + out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"] + out_dict["head.fc.bias"] = torch.zeros( + state_dict["visual.head.mlp.fc2.weight"].shape[0] + ) + return out_dict + + import re + + for k, v in state_dict.items(): + k = k.replace("downsample_layers.0.", "stem.") + k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k) + k = re.sub( + r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k + ) + k = k.replace("dwconv", "conv_dw") + k = k.replace("pwconv", "mlp.fc") + if "grn" in k: + k = k.replace("grn.beta", "mlp.grn.bias") + k = k.replace("grn.gamma", "mlp.grn.weight") + v = v.reshape(v.shape[-1]) + k = k.replace("head.", "head.fc.") + if k.startswith("norm."): + k = k.replace("norm", "head.norm") + if v.ndim == 2 and "head" not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + + return out_dict + + +HF_URL = { + "convnext_xxlarge_pt": ( + "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup", + "open_clip_pytorch_model.bin", + ), + "convnext_large_pt": ( + "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup", + "open_clip_pytorch_model.bin", + ), + "convnext_large": ( + "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384", + "pytorch_model.bin", + ), +} diff --git a/flash3d/unidepth/models/backbones/convnext2.py b/flash3d/unidepth/models/backbones/convnext2.py new file mode 100644 index 0000000000000000000000000000000000000000..793538172b043f683d0856ddd68e48355774ca46 --- /dev/null +++ b/flash3d/unidepth/models/backbones/convnext2.py @@ -0,0 +1,288 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_, DropPath + + +def get_num_layer_for_convnext_single(var_name, depths): + """ + Each layer is assigned distinctive layer ids + """ + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + layer_id = sum(depths[:stage_id]) + 1 + return layer_id + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[2]) + layer_id = sum(depths[:stage_id]) + block_id + 1 + return layer_id + + else: + return sum(depths) + 1 + + +def get_num_layer_for_convnext(var_name): + """ + Divide [3, 3, 27, 3] layers into 12 groups; each group is three + consecutive blocks, including possible neighboring downsample layers; + adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py + """ + num_max_layer = 12 + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + elif stage_id == 3: + layer_id = 12 + return layer_id + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[2]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = 12 + return layer_id + else: + return num_max_layer + 1 + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): + parameter_group_names = {} + parameter_group_vars = {} + skip = {} + if skip_list is not None: + skip = skip_list + elif hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + num_layers = 12 # sum(model.depths) + layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if ( + len(param.shape) == 1 + or name.endswith(".bias") + or name in skip + or name.endswith(".gamma") + or name.endswith(".beta") + ): + group_name = "no_decay" + this_weight_decay = 0.0 + else: + group_name = "decay" + this_weight_decay = wd + + # layer_id = get_num_layer_for_convnext_single(name, model.depths) + layer_id = get_num_layer_for_convnext(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + "lr": cur_lr, + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + # if is_main_process(): + # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class LayerNorm(nn.Module): + """LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class Block(nn.Module): + """ConvNeXtV2 Block. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + """ + + def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False): + super().__init__() + self.dwconv = nn.Conv2d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, mult * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(mult * dim) + self.pwconv2 = nn.Linear(mult * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.use_checkpoint = use_checkpoint + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXtV2(nn.Module): + """ConvNeXt V2 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, + in_chans=3, + depths=[3, 3, 9, 3], + dims=96, + drop_path_rate=0.0, + output_idx=[], + use_checkpoint=False, + ): + super().__init__() + self.num_layers = len(depths) + self.depths = output_idx + self.embed_dims = [ + int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) + ] + self.embed_dim = dims[0] + + self.downsample_layers = ( + nn.ModuleList() + ) # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = ( + nn.ModuleList() + ) # 4 feature resolution stages, each consisting of multiple residual blocks + self.out_norms = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.ModuleList( + [ + Block( + dim=dims[i], + drop_path=dp_rates[cur + j], + use_checkpoint=use_checkpoint, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + outs = [] + for i in range(4): + x = self.downsample_layers[i](x) + for stage in self.stages[i]: + x = stage(x) + outs.append(x.permute(0, 2, 3, 1)) + cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] + return outs, cls_tokens + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + @classmethod + def build(cls, config): + obj = globals()[config["model"]["encoder"]["name"]](config) + return obj diff --git a/flash3d/unidepth/models/backbones/dinov2.py b/flash3d/unidepth/models/backbones/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c0a25e2b5091d6eb435fb56f68cf292bebebf4 --- /dev/null +++ b/flash3d/unidepth/models/backbones/dinov2.py @@ -0,0 +1,552 @@ +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ + +from .metadinov2 import ( + Mlp, + PatchEmbed, + SwiGLUFFNFused, + MemEffAttention, + NestedTensorBlock as Block, +) + + +logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): + parameter_group_names = {} + parameter_group_vars = {} + skip = {} + if skip_list is not None: + skip = skip_list + elif hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + + num_layers = model.n_blocks + layer_scale = list(ld ** (num_layers - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if len(param.shape) == 1: # norm + group_name = "no_decay" + this_wd = 0.0 + # layer scale, bias beta? + elif ( + name in skip + or name.endswith(".gamma") + or name.endswith(".beta") + or name.endswith(".bias") + ): + group_name = "no_decay" + this_wd = 0.0 + elif "cls_token" in name or "pos_embed" in name or "mask_token" in name: + group_name = "no_decay" + this_wd = 0.0 + else: + group_name = "decay" + this_wd = wd + + if name.startswith("blocks"): + layer_id = int(name.split(".")[1]) + elif name.startswith("patch_embed"): + layer_id = 0 + else: + layer_id = 0 + + group_name = f"layer_{layer_id}_{group_name}" + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + output_idx=[5, 12, 18, 24], + checkpoint: bool = False, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.embed_dims = [embed_dim] * output_idx[-1] + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.depths = output_idx + self.checkpoint = checkpoint + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = nn.Parameter( + torch.zeros(1, max(1, num_register_tokens), embed_dim) + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + # self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.num_register_tokens: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape( + 1, int(math.sqrt(N)), int(math.sqrt(N)), dim + ).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert ( + int(w0) == patch_pos_embed.shape[-2] + and int(h0) == patch_pos_embed.shape[-1] + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + masks = masks.bool().view(B, -1, 1) + x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.num_register_tokens: + x = torch.cat( + (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), + dim=1, + ) + + return x + + def forward_features(self, x, masks=None): + # if isinstance(x, list): + # return self.forward_features_list(x, masks) + shapes = [val // self.patch_size for val in x.shape[-2:]] + batch_size = x.shape[0] + x = self.prepare_tokens_with_masks(x, masks) + output, cls_tokens = [], [] + + for i, blk in enumerate(self.blocks): + x = blk(x) + cls_token = x[:, :1] + + out = x[:, self.num_register_tokens + 1 :] + # was like this before, add cls to dense features + # out = out + cls_token + + output.append(out.view(batch_size, *shapes, -1)) + cls_tokens.append(cls_token) + return (output, cls_tokens) + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self.mask_token.requires_grad = False + self.register_tokens.requires_grad = False + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +import torch +import torch.nn as nn + + +dependencies = ["torch"] + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + return f"dinov2_{compact_arch_name}{patch_size}" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + pretrained: str = "", + output_idx: Sequence[int] = [], + num_register_tokens: int = 0, + drop_path_rate: float = 0.0, + **kwargs, +): + model_name = _make_dinov2_model_name(arch_name, patch_size) + print("Instantiate:", model_name) + + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + output_idx=output_idx, + drop_path_rate=drop_path_rate, + num_register_tokens=num_register_tokens, + ) + vit_kwargs.update(**kwargs) + model = eval(arch_name)(**vit_kwargs) + if pretrained == "": + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}" + if num_register_tokens > 0: + url += "_reg4" + url += "_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + ) + info = model.load_state_dict(state_dict, strict=False) + print(info) + elif pretrained is not None: + state_dict = torch.load(pretrained, map_location="cpu") + info = model.load_state_dict(state_dict, strict=False) + print(f"loading from {pretrained} with:", info) + return model + + # def forward_features_list(self, x_list, masks_list): + # x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + # for blk in self.blocks: + # x = blk(x) + + # all_x = x + # output = [] + # for x, masks in zip(all_x, masks_list): + # x_norm = self.norm(x) + # output.append( + # { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_patchtokens": x_norm[:, 1:], + # "x_prenorm": x, + # "masks": masks, + # } + # ) + # return output + + # def _get_intermediate_layers_not_chunked(self, x, n=1): + # x = self.prepare_tokens_with_masks(x) + # # If n is an int, take the n last blocks. If it's a list, take them + # output, total_block_len = [], len(self.blocks) + # blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + # for i, blk in enumerate(self.blocks): + # x = blk(x) + # if i in blocks_to_take: + # output.append(x) + # assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + # return output + + # def _get_intermediate_layers_chunked(self, x, n=1): + # x = self.prepare_tokens_with_masks(x) + # output, i, total_block_len = [], 0, len(self.blocks[-1]) + # # If n is an int, take the n last blocks. If it's a list, take them + # blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + # for block_chunk in self.blocks: + # for blk in block_chunk[i:]: # Passing the nn.Identity() + # x = blk(x) + # if i in blocks_to_take: + # output.append(x) + # i += 1 + # assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + # return output + + # def get_intermediate_layers( + # self, + # x: torch.Tensor, + # n: Union[int, Sequence] = 1, # Layers or n last layers to take + # reshape: bool = False, + # return_class_token: bool = False, + # norm=True, + # ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + # if self.chunked_blocks: + # outputs = self._get_intermediate_layers_chunked(x, n) + # else: + # outputs = self._get_intermediate_layers_not_chunked(x, n) + # if norm: + # outputs = [self.norm(out) for out in outputs] + # class_tokens = [out[:, 0] for out in outputs] + # outputs = [out[:, 1:] for out in outputs] + # if reshape: + # B, _, w, h = x.shape + # outputs = [ + # out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + # for out in outputs + # ] + # if return_class_token: + # return tuple(zip(outputs, class_tokens)) + # return tuple(outputs) diff --git a/flash3d/unidepth/models/backbones/metadinov2/__init__.py b/flash3d/unidepth/models/backbones/metadinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466 --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/flash3d/unidepth/models/backbones/metadinov2/attention.py b/flash3d/unidepth/models/backbones/metadinov2/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..38efc12df276fff129441805d260f9a8107a06d6 --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/attention.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +import torch.nn as nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/flash3d/unidepth/models/backbones/metadinov2/block.py b/flash3d/unidepth/models/backbones/metadinov2/block.py new file mode 100644 index 0000000000000000000000000000000000000000..c568363443383aa107c07ec65b4bd2ec901575c0 --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/block.py @@ -0,0 +1,284 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +import torch.nn as nn + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def attn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: torch.Tensor, + residual_func: Callable[[torch.Tensor], torch.Tensor], + sample_drop_ratio: float = 0.0, +) -> torch.Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[torch.Tensor], + residual_func: Callable[[torch.Tensor, Any], torch.Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> torch.Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls1.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls2.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + return x_list + else: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, torch.Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert ( + XFORMERS_AVAILABLE + ), "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/flash3d/unidepth/models/backbones/metadinov2/dino_head.py b/flash3d/unidepth/models/backbones/metadinov2/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1147dd3a3c046aee8d427b42b1055f38a218275b --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/dino_head.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp( + nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True +): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/flash3d/unidepth/models/backbones/metadinov2/drop_path.py b/flash3d/unidepth/models/backbones/metadinov2/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..35b1a620d06ba862ea05297d271d8c2c625b5f93 --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/drop_path.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +import torch.nn as nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/flash3d/unidepth/models/backbones/metadinov2/layer_scale.py b/flash3d/unidepth/models/backbones/metadinov2/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..08c29476cff85a85ab5f071139175f6ac8ba19b2 --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +import torch.nn as nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/flash3d/unidepth/models/backbones/metadinov2/mlp.py b/flash3d/unidepth/models/backbones/metadinov2/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/flash3d/unidepth/models/backbones/metadinov2/patch_embed.py b/flash3d/unidepth/models/backbones/metadinov2/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..837f952cf9a463444feeb146e0d5b539102ee26c --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/patch_embed.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py b/flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/flash3d/unidepth/models/encoder.py b/flash3d/unidepth/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e302df27bb4bf0f82c1b0baff9c682dc2d2b9e9f --- /dev/null +++ b/flash3d/unidepth/models/encoder.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn + +from unidepth.models.backbones import ConvNeXtV2, _make_dinov2_model, ConvNeXt + + +class ModelWrap(nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.backbone = model + + def forward(self, x, *args, **kwargs): + features = [] + for layer in self.backbone.features: + x = layer(x) + features.append(x) + return features + + +def convnextv2_base(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[128, 256, 512, 1024], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_large(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_large_mae(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_huge(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[352, 704, 1408, 2816], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_huge_mae(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[352, 704, 1408, 2816], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnext_large_pt(config, **kwargs): + model = ConvNeXt( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + from unidepth.models.backbones.convnext import HF_URL, checkpoint_filter_fn + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import disable_progress_bars + + disable_progress_bars() + repo_id, filename = HF_URL["convnext_large_pt"] + state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename)) + state_dict = checkpoint_filter_fn(state_dict, model) + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnext_large(config, **kwargs): + model = ConvNeXt( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + **kwargs, + ) + return model + + +def dinov2_vitb14(config, pretrained: bool = True, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + vit = _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + output_idx=config.get("output_idx", [3, 6, 9, 12]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + **kwargs, + ) + return vit + + +def dinov2_vitl14(config, pretrained: str = "", **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + vit = _make_dinov2_model( + arch_name="vit_large", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [5, 12, 18, 24]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + **kwargs, + ) + return vit + + +def dinov2_vitg14(config, pretrained: bool = True, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + vit = _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + pretrained=pretrained, + output_idx=config.get("output_idx", [10, 20, 30, 40]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + **kwargs, + ) + return vit diff --git a/flash3d/unidepth/models/unidepthv1/__init__.py b/flash3d/unidepth/models/unidepthv1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1781bda94cdc13b0c0c805e7cde0872defc20cd3 --- /dev/null +++ b/flash3d/unidepth/models/unidepthv1/__init__.py @@ -0,0 +1,5 @@ +from .unidepthv1 import UniDepthV1 + +__all__ = [ + "UniDepthV1", +] diff --git a/flash3d/unidepth/models/unidepthv1/decoder.py b/flash3d/unidepth/models/unidepthv1/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e1b0fc750ac8798575c401a23994318914cf80f0 --- /dev/null +++ b/flash3d/unidepth/models/unidepthv1/decoder.py @@ -0,0 +1,542 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from typing import List, Tuple + +from einops import rearrange +from timm.models.layers import trunc_normal_ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from unidepth.layers import ( + MLP, + AttentionBlock, + NystromBlock, + PositionEmbeddingSine, + ConvUpsample, +) +from unidepth.utils.sht import rsh_cart_8 +from unidepth.utils.geometric import ( + generate_rays, + flat_interpolate, +) +from unidepth.utils.misc import max_stack + + +class ListAdapter(nn.Module): + def __init__(self, input_dims: List[int], hidden_dim: int): + super().__init__() + self.input_adapters = nn.ModuleList([]) + self.num_chunks = len(input_dims) + for input_dim in input_dims: + self.input_adapters.append( + nn.Sequential( + nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU() + ) + ) + + def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor: + xs = torch.split(x, splits.int().tolist(), dim=-1) + xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)] + return torch.cat(xs, dim=-1) + + +class CameraHead(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depth: int = 4, + dropout: float = 0.0, + layer_scale: float = 1.0, + **kwargs, + ): + super().__init__() + + self.aggregate = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + self.latents_pos = nn.Parameter( + torch.randn(1, 4, hidden_dim), requires_grad=True + ) + self.layers = nn.ModuleList([]) + self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout) + for _ in range(depth): + blk = AttentionBlock( + hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + self.layers.append(blk) + self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) + self.cls_project = nn.Sequential( + nn.LayerNorm(input_dim), + nn.Linear(input_dim, hidden_dim // 2), + nn.GELU(), + nn.Linear(hidden_dim // 2, hidden_dim), + ) + + def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: + features = features.unbind(dim=-1) + cls_tokens = self.cls_project(cls_tokens) + features_stack = torch.cat(features, dim=1) + features_stack = features_stack + pos_embed + latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) + features_stack = self.in_features(features_stack) + features = torch.cat((features_stack, cls_tokens), dim=1) + cls_tokens = self.aggregate(cls_tokens, context=features, pos_embed=latents_pos) + for i, layer in enumerate(self.layers): + cls_tokens = layer(cls_tokens, pos_embed=latents_pos) + + # project + x = self.out(cls_tokens).squeeze(-1) + camera_intrinsics = torch.zeros( + x.shape[0], 3, 3, device=x.device, requires_grad=False + ) + camera_intrinsics[:, 0, 0] = x[:, 0].exp() + camera_intrinsics[:, 1, 1] = x[:, 1].exp() + camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid() + camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid() + camera_intrinsics[:, 2, 2] = 1.0 + return camera_intrinsics + + def set_shapes(self, shapes: Tuple[int, int]): + self.shapes = shapes + + +class DepthHead(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depths: int | list[int] = 4, + camera_dim: int = 256, + num_resolutions: int = 4, + dropout: float = 0.0, + layer_scale: float = 1.0, + **kwargs, + ) -> None: + super().__init__() + if isinstance(depths, int): + depths = [depths] * 3 + assert len(depths) == 3 + + self.project_rays16 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim + ) + self.project_rays8 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2 + ) + self.project_rays4 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4 + ) + self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout) + + self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim) + + self.up8 = ConvUpsample( + hidden_dim, expansion=expansion, layer_scale=layer_scale + ) + self.up4 = ConvUpsample( + hidden_dim // 2, expansion=expansion, layer_scale=layer_scale + ) + self.up2 = ConvUpsample( + hidden_dim // 4, expansion=expansion, layer_scale=layer_scale + ) + + self.layers_16 = nn.ModuleList([]) + self.layers_8 = nn.ModuleList([]) + self.layers_4 = nn.ModuleList([]) + self.aggregate_16 = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + context_dim=hidden_dim, + ) + self.prompt_camera = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + context_dim=hidden_dim, + ) + for i, (blk_lst, depth) in enumerate( + zip([self.layers_16, self.layers_8, self.layers_4], depths) + ): + attn_cls = AttentionBlock if i == 0 else NystromBlock + for _ in range(depth): + blk_lst.append( + attn_cls( + hidden_dim // (2**i), + num_heads=num_heads // (2**i), + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + ) + + self.out2 = nn.Conv2d(hidden_dim // 8, 1, 3, padding=1) + self.out4 = nn.Conv2d(hidden_dim // 4, 1, 3, padding=1) + self.out8 = nn.Conv2d(hidden_dim // 2, 1, 3, padding=1) + + def set_original_shapes(self, shapes: Tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: Tuple[int, int]): + self.shapes = shapes + + def forward( + self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed + ) -> torch.Tensor: + features = features.unbind(dim=-1) + shapes = self.shapes + + # camera_embedding + # torch.cuda.synchronize() + # start = time() + rays_embedding_16 = F.normalize( + flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1 + ) + rays_embedding_8 = F.normalize( + flat_interpolate( + rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes] + ), + dim=-1, + ) + rays_embedding_4 = F.normalize( + flat_interpolate( + rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes] + ), + dim=-1, + ) + rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16)) + rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8)) + rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4)) + # torch.cuda.synchronize() + # print(f"camera_embedding took {time() - start} seconds") + features_tokens = torch.cat(features, dim=1) + features_tokens_pos = pos_embed + level_embed + + # Generate latents with init as pooled features + features_channels = torch.cat(features, dim=-1) + features_16 = self.features_channel_cat(features_channels) + latents_16 = self.to_latents( + flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False) + ) + + # Aggregate features: F -> D + latents_16 = self.aggregate_16( + latents_16, context=features_tokens, pos_embed_context=features_tokens_pos + ) + + # Aggregate camera: D- > D|E + latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16) + + # Block 16 - Out 8 + for layer in self.layers_16: + latents_16 = layer(latents_16, pos_embed=rays_embedding_16) + latents_8 = self.up8( + rearrange( + latents_16 + rays_embedding_16, + "b (h w) c -> b c h w", + h=shapes[0], + w=shapes[1], + ).contiguous() + ) + out8 = self.out8( + rearrange( + latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2 + ) + ) + + # Block 8 - Out 4 + for layer in self.layers_8: + latents_8 = layer(latents_8, pos_embed=rays_embedding_8) + latents_4 = self.up4( + rearrange( + latents_8 + rays_embedding_8, + "b (h w) c -> b c h w", + h=shapes[0] * 2, + w=shapes[1] * 2, + ).contiguous() + ) + out4 = self.out4( + rearrange( + latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4 + ) + ) + + # Block 4 - Out 2 + for layer in self.layers_4: + latents_4 = layer(latents_4, pos_embed=rays_embedding_4) + latents_2 = self.up2( + rearrange( + latents_4 + rays_embedding_4, + "b (h w) c -> b c h w", + h=shapes[0] * 4, + w=shapes[1] * 4, + ).contiguous() + ) + out2 = self.out2( + rearrange( + latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8 + ) + ) + + # Depth features + proj_latents_16 = rearrange( + latents_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1] + ).contiguous() + + # MS Outputs + out2 = out2.clamp(-10.0, 10.0).exp() + out4 = out4.clamp(-10.0, 10.0).exp() + out8 = out8.clamp(-10.0, 10.0).exp() + + return out8, out4, out2, proj_latents_16 + + +class Decoder(nn.Module): + def __init__( + self, + config, + *args, + **kwargs, + ): + super().__init__() + self.build(config) + self.apply(self._init_weights) + self.test_fixed_camera = False + self.skip_camera = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_adapted_features(self, features_flat, splits): + features_flat_cat = torch.cat(features_flat, dim=-1) + features_projected = self.input_adapter( + features_flat_cat, splits + ) # list [b hw c] shapes + features = torch.chunk(features_projected, len(splits), dim=-1) + return features + + def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays): + # get cls tokens projections + cls_tokens_splits = torch.tensor( + [x.shape[-1] for x in cls_tokens], + device=features.device, + requires_grad=False, + dtype=features.dtype, + ) + cls_tokens = torch.cat(cls_tokens, dim=-1) + cls_tokens = self.token_adapter(cls_tokens, cls_tokens_splits) + cls_tokens = torch.cat( + torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1 + ) + + # camera layer + intrinsics = self.camera_layer( + features=features, cls_tokens=cls_tokens, pos_embed=pos_embed + ) + intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0] + intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1] + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1] + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0] + if not self.test_fixed_camera: + rays, _ = generate_rays(intrinsics, original_shapes, noisy=False) + + return intrinsics, rays + + def forward(self, inputs, image_metas) -> torch.Tensor: + B, _, H, W = inputs["image"].shape + device = inputs["image"].device + + # make stride happy? + original_encoder_outputs = [x.contiguous() for x in inputs["encoder_outputs"]] + cls_tokens = [x.contiguous() for x in inputs["cls_tokens"]] + + # collect features and tokens + original_encoder_outputs = [ + max_stack(original_encoder_outputs[i:j]) + for i, j in self.slices_encoder_range + ] + cls_tokens = [cls_tokens[-i - 1] for i in range(len(self.slices_encoder_range))] + + # get features in b n d format + # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions + resolutions = [ + tuple(sorted([x.shape[1], x.shape[2]])) for x in original_encoder_outputs + ] + level_shapes = sorted(list(set(resolutions)))[::-1] + + if len(level_shapes) == 1: + level_shapes = level_shapes * self.num_resolutions + input_shapes = [ + level_shapes[i] + for i, (start, end) in enumerate(self.slices_encoder) + for _ in range(end - start) + ] + common_shape = level_shapes[-2] + + # input shapes repeat shapes for each level, times the amount of the layers: + features_flat = [ + flat_interpolate( + rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape + ) + for x, input_shape in zip(original_encoder_outputs, input_shapes) + ] + features_splits = torch.tensor( + [x.shape[-1] for x in features_flat], + device=device, + requires_grad=False, + dtype=torch.float32, + ) + + # input adapter, then do mean of features in same blocks + features = self.get_adapted_features(features_flat, features_splits) + features = torch.stack(features, dim=-1) + + # positional embeddings, spatial and level + level_embed = torch.cat( + [ + self.level_embed_layer(self.level_embeds)[i : i + 1] + .unsqueeze(0) + .repeat(B, common_shape[0] * common_shape[1], 1) + for i in range(self.num_resolutions) + ], + dim=1, + ) + pos_embed = self.pos_embed( + torch.zeros( + B, + 1, + common_shape[0], + common_shape[1], + device=device, + requires_grad=False, + ) + ) + pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( + 1, self.num_resolutions, 1 + ) + + self.camera_layer.set_shapes(common_shape) + intrinsics, rays = ( + self.run_camera( + cls_tokens, + features=features, + pos_embed=pos_embed + level_embed, + original_shapes=(H, W), + rays=inputs.get("rays", None), + ) + if not self.skip_camera + else (inputs["K"], inputs["rays"]) + ) + + # run bulk of the model + self.depth_layer.set_shapes(common_shape) + self.depth_layer.set_original_shapes((H, W)) + out8, out4, out2, depth_features = self.depth_layer( + features=features, + rays_hr=rays, + pos_embed=pos_embed, + level_embed=level_embed, + ) + + return intrinsics, [out8, out4, out2], depth_features, rays + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"latents_pos", "level_embeds"} + + def build(self, config): + depth = config["model"]["pixel_decoder"]["depths"] + input_dims = config["model"]["pixel_encoder"]["embed_dims"] + hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] + num_heads = config["model"]["num_heads"] + expansion = config["model"]["expansion"] + dropout = config["model"]["pixel_decoder"]["dropout"] + depths_encoder = config["model"]["pixel_encoder"]["depths"] + num_steps = config["model"].get("num_steps", 100000) + layer_scale = 1.0 + + self.depth = depth + self.dim = hidden_dim + self.downsample = 4 + self.num_heads = num_heads + self.num_resolutions = len(depths_encoder) + self.depths_encoder = depths_encoder + + self.slices_encoder_single = list( + zip([d - 1 for d in self.depths_encoder], self.depths_encoder) + ) + self.slices_encoder_range = list( + zip([0, *self.depths_encoder[:-1]], self.depths_encoder) + ) + cls_token_input_dims = [input_dims[-i - 1] for i in range(len(depths_encoder))] + + input_dims = [input_dims[d - 1] for d in depths_encoder] + self.slices_encoder = self.slices_encoder_single + + # adapt from encoder features, just project + self.input_adapter = ListAdapter(input_dims, hidden_dim) + self.token_adapter = ListAdapter(cls_token_input_dims, hidden_dim) + + # camera layer + self.camera_layer = CameraHead( + input_dim=hidden_dim, + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depth=2, + dropout=dropout, + layer_scale=layer_scale, + ) + + self.depth_layer = DepthHead( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depths=depth, + dropout=dropout, + camera_dim=81, + num_resolutions=self.num_resolutions, + layer_scale=layer_scale, + ) + + # transformer part + self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + self.level_embeds = nn.Parameter( + torch.randn(len(input_dims), hidden_dim), requires_grad=True + ) + self.level_embed_layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + ) \ No newline at end of file diff --git a/flash3d/unidepth/models/unidepthv1/unidepthv1.py b/flash3d/unidepth/models/unidepthv1/unidepthv1.py new file mode 100644 index 0000000000000000000000000000000000000000..bf0207120b85f033b6479d4db7003eee1563c868 --- /dev/null +++ b/flash3d/unidepth/models/unidepthv1/unidepthv1.py @@ -0,0 +1,329 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from copy import deepcopy +import importlib +from typing import Any, Dict, Tuple +from math import ceil + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from einops import rearrange + +from unidepth.utils.geometric import ( + generate_rays, + spherical_zbuffer_to_euclidean, +) +from unidepth.utils.misc import get_params +from unidepth.utils.distributed import is_main_process +from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD +from unidepth.models.unidepthv1.decoder import Decoder + +from huggingface_hub import PyTorchModelHubMixin + + +MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"} + + +# inference helpers +def _paddings(image_shape, network_shape): + cur_h, cur_w = image_shape + h, w = network_shape + pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 + pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 + return pad_left, pad_right, pad_top, pad_bottom + + +def _shapes(image_shape, network_shape): + h, w = image_shape + input_ratio = w / h + output_ratio = network_shape[1] / network_shape[0] + if output_ratio > input_ratio: + ratio = network_shape[0] / h + elif output_ratio <= input_ratio: + ratio = network_shape[1] / w + return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio + + +def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): + (pad_left, pad_right, pad_top, pad_bottom) = pads + rgbs = F.interpolate( + rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True + ) + rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") + if intrinsics is not None: + intrinsics = intrinsics.clone() + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top + return rgbs, intrinsics + return rgbs, None + + +def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes): + (pad_left, pad_right, pad_top, pad_bottom) = pads + # pred mean, trim paddings, and upsample to input dim + predictions = sum( + [ + F.interpolate( + x.clone(), + size=shapes, + mode="bilinear", + align_corners=False, + antialias=True, + ) + for x in predictions + ] + ) / len(predictions) + predictions = predictions[ + ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right + ] + predictions = F.interpolate( + predictions, + size=original_shapes, + mode="bilinear", + align_corners=False, + antialias=True, + ) + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio + intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio + intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio + return predictions, intrinsics + + +class UniDepthV1(nn.Module, + PyTorchModelHubMixin, + library_name="UniDepth", + repo_url="https://github.com/lpiccinelli-eth/UniDepth", + tags=["monocular-metric-depth-estimation"]): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super().__init__() + self.build(config) + self.eps = eps + + def forward(self, inputs, image_metas): + rgbs = inputs["image"] + gt_intrinsics = inputs.get("K") + H, W = rgbs.shape[-2:] + + # Encode + encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) + if "dino" in self.pixel_encoder.__class__.__name__.lower(): + encoder_outputs = [ + (x + y.unsqueeze(1)).contiguous() + for x, y in zip(encoder_outputs, cls_tokens) + ] + inputs["encoder_outputs"] = encoder_outputs + inputs["cls_tokens"] = cls_tokens + + # Get camera infos, if any + if gt_intrinsics is not None: + rays, angles = generate_rays( + gt_intrinsics, self.image_shape, noisy=self.training + ) + inputs["rays"] = rays + inputs["angles"] = angles + inputs["K"] = gt_intrinsics + self.pixel_decoder.test_fixed_camera = True # use GT camera in fwd + + # Decode + pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {}) + predictions = sum( + [ + F.interpolate( + x.clone(), + size=self.image_shape, + mode="bilinear", + align_corners=False, + antialias=True, + ) + for x in predictions + ] + ) / len(predictions) + + # Final 3D points backprojection + pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1] + # You may want to use inputs["angles"] if available? + pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W) + points_3d = torch.cat((pred_angles, predictions), dim=1) + points_3d = spherical_zbuffer_to_euclidean( + points_3d.permute(0, 2, 3, 1) + ).permute(0, 3, 1, 2) + + # Output data, use for loss computation + outputs = { + "angles": pred_angles, + "intrinsics": pred_intrinsics, + "points": points_3d, + "depth": predictions[:, -1:], + } + self.pixel_decoder.test_fixed_camera = False + return outputs + + @torch.no_grad() + def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False): + if rgbs.ndim == 3: + rgbs = rgbs.unsqueeze(0) + if intrinsics is not None and intrinsics.ndim == 2: + intrinsics = intrinsics.unsqueeze(0) + B, _, H, W = rgbs.shape + + rgbs = rgbs.to(self.device) + if intrinsics is not None: + intrinsics = intrinsics.to(self.device) + + # process image and intrinsiscs (if any) to match network input (slow?) + if rgbs.max() > 5 or rgbs.dtype == torch.uint8: + rgbs = TF.normalize( + rgbs.to(torch.float32).div(255), + mean=IMAGENET_DATASET_MEAN, + std=IMAGENET_DATASET_STD, + ) + else: + pass + # print("Image not normalized, was it already normalized?") + (h, w), ratio = _shapes((H, W), self.image_shape) + pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape) + rgbs, gt_intrinsics = _preprocess( + rgbs, + intrinsics, + (h, w), + (pad_left, pad_right, pad_top, pad_bottom), + ratio, + self.image_shape, + ) + + # run encoder + encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) + if "dino" in self.pixel_encoder.__class__.__name__.lower(): + encoder_outputs = [ + (x + y.unsqueeze(1)).contiguous() + for x, y in zip(encoder_outputs, cls_tokens) + ] + + # get data for decoder and adapt to given camera + inputs = {} + inputs["encoder_outputs"] = encoder_outputs + inputs["cls_tokens"] = cls_tokens + inputs["image"] = rgbs + if gt_intrinsics is not None: + rays, angles = generate_rays( + gt_intrinsics, self.image_shape, noisy=self.training + ) + inputs["rays"] = rays + inputs["angles"] = angles + inputs["K"] = gt_intrinsics + self.pixel_decoder.test_fixed_camera = True + self.pixel_decoder.skip_camera = skip_camera + + # decode all + pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {}) + + # undo the reshaping and get original image size (slow) + predictions, pred_intrinsics = _postprocess( + predictions, + pred_intrinsics, + self.image_shape, + (pad_left, pad_right, pad_top, pad_bottom), + ratio, + (H, W), + ) + + # final 3D points backprojection + intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics + angles = generate_rays(intrinsics, (H, W), noisy=False)[-1] + angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W) + points_3d = torch.cat((angles, predictions), dim=1) + points_3d = spherical_zbuffer_to_euclidean( + points_3d.permute(0, 2, 3, 1) + ).permute(0, 3, 1, 2) + + # output data + outputs = { + "intrinsics": pred_intrinsics, + "points": points_3d, + "depth": predictions[:, -1:], + } + self.pixel_decoder.test_fixed_camera = False + self.pixel_decoder.skip_camera = False + return outputs + + def load_pretrained(self, model_file): + device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + dict_model = torch.load(model_file, map_location=device) + + if "model" in dict_model: + dict_model = dict_model["model"] + new_state_dict = deepcopy( + {k.replace("module.", ""): v for k, v in dict_model.items()} + ) + + info = self.load_state_dict(new_state_dict, strict=False) + if is_main_process(): + print( + f"Loaded from {model_file} for {self.__class__.__name__} results in:", + info, + ) + + def get_params(self, config): + if hasattr(self.pixel_encoder, "get_params"): + encoder_p, encoder_lr = self.pixel_encoder.get_params( + config["model"]["pixel_encoder"]["lr"], + config["training"]["wd"], + config["training"]["ld"], + ) + else: + encoder_p, encoder_lr = get_params( + self.pixel_encoder, + config["model"]["pixel_encoder"]["lr"], + config["training"]["wd"], + ) + decoder_p, decoder_lr = get_params( + self.pixel_decoder, config["training"]["lr"], config["training"]["wd"] + ) + return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr] + + @property + def device(self): + return next(self.parameters()).device + + def build(self, config: Dict[str, Dict[str, Any]]): + mod = importlib.import_module("unidepth.models.encoder") + pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"]) + pixel_encoder_config = { + **config["training"], + **config["data"], + **config["model"]["pixel_encoder"], + } + pixel_encoder = pixel_encoder_factory(pixel_encoder_config) + + config["model"]["pixel_encoder"]["patch_size"] = ( + 14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16 + ) + pixel_encoder_embed_dims = ( + pixel_encoder.embed_dims + if hasattr(pixel_encoder, "embed_dims") + else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)] + ) + config["model"]["pixel_encoder"]["embed_dim"] = getattr( + pixel_encoder, "embed_dim" + ) + config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims + config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths + + self.pixel_encoder = pixel_encoder + self.pixel_decoder = Decoder(config) + self.image_shape = config["data"]["image_shape"] diff --git a/flash3d/unidepth/ops/__init__.py b/flash3d/unidepth/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..412c242f25ec2e4c496cde55bbbf0b47a8580081 --- /dev/null +++ b/flash3d/unidepth/ops/__init__.py @@ -0,0 +1,9 @@ +from .losses import SILog, MSE, SelfCons +from .scheduler import CosineScheduler + +__all__ = [ + "SILog", + "MSE", + "SelfCons", + "CosineScheduler", +] diff --git a/flash3d/unidepth/ops/losses.py b/flash3d/unidepth/ops/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..6dda902708aca5b15e680d1dfd88bd040e68ef6f --- /dev/null +++ b/flash3d/unidepth/ops/losses.py @@ -0,0 +1,429 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from typing import Any, Optional, Dict, Tuple, List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +FNS = { + "sqrt": torch.sqrt, + "log": torch.log, + "log1": lambda x: torch.log(x + 1), + "linear": lambda x: x, + "square": torch.square, + "disp": lambda x: 1 / x, +} + + +FNS_INV = { + "sqrt": torch.square, + "log": torch.exp, + "log1": lambda x: torch.exp(x) - 1, + "linear": lambda x: x, + "square": torch.sqrt, + "disp": lambda x: 1 / x, +} + + +def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + mask_var = torch.sum( + mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + return mask_mean.squeeze(dim), mask_var.squeeze(dim) + + +def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean + + +def masked_mae(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): + if mask is None: + return data.abs().mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data.abs() * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean + + +def masked_mse(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): + if mask is None: + return (data**2).mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum((data**2) * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean + + +def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + ndim = data.ndim + data = data.flatten(ndim - len(dim)) + mask = mask.flatten(ndim - len(dim)) + mask_median = torch.median(data[mask], dim=-1).values + return mask_median + + +def masked_median_mad(data: torch.Tensor, mask: torch.Tensor): + data = data.flatten() + mask = mask.flatten() + mask_median = torch.median(data[mask]) + n_samples = torch.clamp(torch.sum(mask.float()), min=1.0) + mask_mad = torch.sum((data[mask] - mask_median).abs()) / n_samples + return mask_median, mask_mad + + +def masked_weighted_mean_var( + data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] +): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum( + mask * weights, dim=dim, keepdim=True + ).clamp(min=1.0) + # V1**2 - V2, V1: sum w_i, V2: sum w_i**2 + denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum( + (mask * weights).square(), dim=dim, keepdim=True + ) + # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd) + correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp( + min=1.0 + ) + mask_var = correction_factor * torch.sum( + weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) + return mask_mean, mask_var + + +def masked_mean_var_q(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + mask_var = torch.sum( + mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + return mask_mean, mask_var + + +class SILog(nn.Module): + def __init__( + self, + weight: float, + scale_pred_weight: float = 0.15, + output_fn: str = "sqrt", + input_fn: str = "log", + legacy: bool = False, + abs_rel: bool = False, + norm: bool = False, + eps: float = 1e-5, + ): + super().__init__() + assert output_fn in FNS + self.name: str = self.__class__.__name__ + self.weight: float = weight + + self.scale_pred_weight: float = scale_pred_weight + self.dims = (-4, -3, -2, -1) if legacy else (-2, -1) + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.abs_rel = abs_rel + self.norm = norm + self.eps: float = eps + + @torch.cuda.amp.autocast(enabled=False) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None, + interpolate: bool = True, + scale_inv: torch.Tensor | None = None, + ss_inv: torch.Tensor | None = None, + **kwargs + ) -> torch.Tensor: + if interpolate: + input = F.interpolate( + input, target.shape[-2:], mode="bilinear", align_corners=False + ) + if mask is not None: + mask = mask.to(torch.bool) + if ss_inv is not None: + ss_inv = ~ss_inv + + if input.shape[1] > 1: + input_ = torch.cat( + [input[:, :-1], self.input_fn(input[:, -1:].clamp(min=self.eps))], dim=1 + ) + target_ = torch.cat( + [target[:, :-1], self.input_fn(target[:, -1:].clamp(min=self.eps))], + dim=1, + ) + error = torch.norm(input_ - target_, dim=1, keepdim=True) + else: + input_ = self.input_fn(input.clamp(min=self.eps)) + target_ = self.input_fn(target.clamp(min=self.eps)) + error = input_ - target_ + + mean_error, var_error = masked_mean_var(data=error, mask=mask, dim=self.dims) + + # prevoiusly was inverted!! + if self.abs_rel: + scale_error = (input - target).abs()[:, -1:] / target[:, -1:].clip( + min=self.eps + ) + scale_error = masked_mean(data=scale_error, mask=mask, dim=self.dims) + else: + scale_error = mean_error**2 + + if var_error.ndim > 1: + var_error = var_error.sum(dim=1) + scale_error = scale_error.sum(dim=1) + + # if scale inv -> mask scale error, if scale/shift, mask the full loss + if scale_inv is not None: + scale_error = (1 - scale_inv.int()) * scale_error + scale_error = self.scale_pred_weight * scale_error + loss = var_error + scale_error + out_loss = self.output_fn(loss.clamp(min=self.eps)) + out_loss = masked_mean(data=out_loss, mask=ss_inv, dim=(0,)) + return out_loss.mean() + + @classmethod + def build(cls, config: Dict[str, Any]): + obj = cls( + weight=config["weight"], + legacy=config["legacy"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + norm=config.get("norm", False), + scale_pred_weight=config.get("gamma", 0.15), + abs_rel=config.get("abs_rel", False), + ) + return obj + + +class MSE(nn.Module): + def __init__( + self, + weight: float = 1.0, + input_fn: str = "linear", + output_fn: str = "linear", + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.weight: float = weight + self.eps = 1e-6 + + @torch.cuda.amp.autocast(enabled=False) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + batch_mask: torch.Tensor | None = None, + **kwargs + ) -> torch.Tensor: + input = input[..., : target.shape[-1]] # B N C or B H W C + error = self.input_fn(input + self.eps) - self.input_fn(target + self.eps) + abs_error = torch.square(error).sum(dim=-1) + mean_error = masked_mean(data=abs_error, mask=mask, dim=(-1,)).mean(dim=-1) + batched_error = masked_mean( + self.output_fn(mean_error.clamp(self.eps)), batch_mask, dim=(0,) + ) + return batched_error.mean(), mean_error.detach() + + @classmethod + def build(cls, config: Dict[str, Any]): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + ) + return obj + + +class SelfCons(nn.Module): + def __init__( + self, + weight: float, + scale_pred_weight: float = 0.15, + output_fn: str = "sqrt", + input_fn: str = "log", + abs_rel: bool = False, + norm: bool = False, + eps: float = 1e-5, + ): + super().__init__() + assert output_fn in FNS + self.name: str = self.__class__.__name__ + self.weight: float = weight + + self.scale_pred_weight: float = scale_pred_weight + self.dims = (-2, -1) + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.abs_rel = abs_rel + self.norm = norm + self.eps: float = eps + + @torch.cuda.amp.autocast(enabled=False) + def forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + metas: List[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + chunks = input.shape[0] // 2 + device = input.device + mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest") + + rescales = input.shape[-2] / torch.tensor( + [x["resized_shape"][0] for x in metas], device=device + ) + cams = torch.cat([x["K_target"] for x in metas], dim=0).to(device) + flips = torch.tensor([x["flip"] for x in metas], device=device) + + iters = zip( + input.chunk(chunks), + mask.chunk(chunks), + cams.chunk(chunks), + rescales.chunk(chunks), + flips.chunk(chunks), + ) + inputs0, inputs1, masks = [], [], [] + for i, (pair_input, pair_mask, pair_cam, pair_rescale, pair_flip) in enumerate( + iters + ): + mask0, mask1 = pair_mask + input0, input1 = pair_input + cam0, cam1 = pair_cam + rescale0, rescale1 = pair_rescale + flip0, flip1 = pair_flip + + fx_0 = cam0[0, 0] * rescale0 + fx_1 = cam1[0, 0] * rescale1 + cx_0 = (cam0[0, 2] - 0.5) * rescale0 + 0.5 + cx_1 = (cam1[0, 2] - 0.5) * rescale1 + 0.5 + cy_0 = (cam0[1, 2] - 0.5) * rescale0 + 0.5 + cy_1 = (cam1[1, 2] - 0.5) * rescale1 + 0.5 + + # flip image + if flip0 ^ flip1: + input0 = torch.flip(input0, dims=(2,)) + mask0 = torch.flip(mask0, dims=(2,)) + cx_0 = input0.shape[-1] - cx_0 + + # calc zoom + zoom_x = float(fx_1 / fx_0) + + # apply zoom + input0 = F.interpolate( + input0.unsqueeze(0), + scale_factor=zoom_x, + mode="bilinear", + align_corners=True, + ).squeeze(0) + mask0 = F.interpolate( + mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest" + ).squeeze(0) + + # calc translation + change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5) + change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5) + change_right = input1.shape[-1] - change_left - input0.shape[-1] + change_bottom = input1.shape[-2] - change_top - input0.shape[-2] + + # apply translation + pad_left = max(0, change_left) + pad_right = max(0, change_right) + pad_top = max(0, change_top) + pad_bottom = max(0, change_bottom) + + crop_left = max(0, -change_left) + crop_right = max(0, -change_right) + crop_top = max(0, -change_top) + crop_bottom = max(0, -change_bottom) + + input0 = F.pad( + input0, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=0, + ) + mask0 = F.pad( + mask0, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=0, + ) + input0 = input0[ + :, + crop_top : input0.shape[-2] - crop_bottom, + crop_left : input0.shape[-1] - crop_right, + ] + mask0 = mask0[ + :, + crop_top : mask0.shape[-2] - crop_bottom, + crop_left : mask0.shape[-1] - crop_right, + ] + + mask = torch.logical_and(mask0, mask1) + + inputs0.append(input0) + inputs1.append(input1) + masks.append(mask) + + inputs0 = torch.stack(inputs0, dim=0) + inputs1 = torch.stack(inputs1, dim=0) + masks = torch.stack(masks, dim=0) + loss1 = self.loss(inputs0, inputs1.detach(), masks) + loss2 = self.loss(inputs1, inputs0.detach(), masks) + return torch.cat([loss1, loss2], dim=0).mean() + + def loss( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + loss = masked_mean( + (input - target).square().mean(dim=1), mask=mask, dim=(-2, -1) + ) + return self.output_fn(loss + self.eps) + + @classmethod + def build(cls, config: Dict[str, Any]): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + ) + return obj diff --git a/flash3d/unidepth/ops/scheduler.py b/flash3d/unidepth/ops/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a182ff6e204ab445a67846314a8bea087119685e --- /dev/null +++ b/flash3d/unidepth/ops/scheduler.py @@ -0,0 +1,70 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import numpy as np + + +class CosineScheduler(object): + def __init__( + self, + optimizer, + warmup_iters, + total_iters, + key, + overwrite=False, + init_value=None, + base_value=None, + final_value=None, + step_init=-1, + ): + super().__init__() + self.iter = step_init + self.overwrite = overwrite + self.optimizer = optimizer + self.base_value = base_value + self.init_value = init_value + self.final_value = final_value + self.total_iters = total_iters + self.warmup_iters = warmup_iters + self.key = key + self.schedulers = [ + self.get_schedulers(group) for group in optimizer.param_groups + ] + + def get_schedulers(self, group): + init_value = group.get(self.key + "_init", self.init_value) + base_value = group.get(self.key + "_base", self.base_value) + final_value = group.get(self.key + "_final", self.final_value) + warmup_iters = self.warmup_iters + total_iters = self.total_iters + if self.overwrite: + final_value = self.final_value + + # normalize in 0,1, then apply function (power) and denormalize + normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) + normalized_schedule = np.power(normalized_schedule, 2) + warmup_schedule = (base_value - init_value) * normalized_schedule + init_value + + # main scheduling + iters = np.arange(total_iters - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * iters / len(iters)) + ) + return np.concatenate((warmup_schedule, schedule)) + + def step(self): + self.iter = self.iter + 1 + vals = self[self.iter] + for group, val in zip(self.optimizer.param_groups, vals): + if isinstance(group[self.key], (tuple, list)): + val = (val, *group[self.key][1:]) + group[self.key] = val + + def __getitem__(self, it): + it = min(it, self.total_iters - 1) + return [scheduler[it] for scheduler in self.schedulers] + + def get(self): + return [group[self.key] for group in self.optimizer.param_groups] diff --git a/flash3d/unidepth/utils/__init__.py b/flash3d/unidepth/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3153806457cace0aa1ffd2820e60f4d87c8180 --- /dev/null +++ b/flash3d/unidepth/utils/__init__.py @@ -0,0 +1,35 @@ +from .evaluation_depth import eval_depth, DICT_METRICS +from .visualization import colorize, image_grid, log_train_artifacts +from .misc import format_seconds, remove_padding, get_params, identity +from .distributed import ( + is_main_process, + setup_multi_processes, + setup_slurm, + sync_tensor_across_gpus, + barrier, + get_rank, + get_dist_info, +) +from .geometric import unproject_points, spherical_zbuffer_to_euclidean + +__all__ = [ + "eval_depth", + "DICT_METRICS", + "colorize", + "image_grid", + "log_train_artifacts", + "format_seconds", + "remove_padding", + "get_params", + "identity", + "is_main_process", + "setup_multi_processes", + "setup_slurm", + "sync_tensor_across_gpus", + "barrier", + "get_rank", + "unproject_points", + "spherical_zbuffer_to_euclidean", + "validate", + "get_dist_info", +] diff --git a/flash3d/unidepth/utils/constants.py b/flash3d/unidepth/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..7b23481335056a8bfe756bf6d0772bbf10c2ca22 --- /dev/null +++ b/flash3d/unidepth/utils/constants.py @@ -0,0 +1,21 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import math +import torch + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DATASET_STD = (0.229, 0.224, 0.225) +DEPTH_BINS = torch.cat( + ( + torch.logspace(math.log10(0.1), math.log10(180.0), steps=512), + torch.tensor([260.0]), + ), + dim=0, +) +LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1) +LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1) diff --git a/flash3d/unidepth/utils/distributed.py b/flash3d/unidepth/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd8a501582808fe967f54f51fb645b667137d02 --- /dev/null +++ b/flash3d/unidepth/utils/distributed.py @@ -0,0 +1,179 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import os +import platform +import warnings +import subprocess + +import cv2 + +import torch +import torch.utils.data.distributed +from torch import multiprocessing as mp +from torch import distributed as dist + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def barrier(): + if not is_dist_avail_and_initialized(): + return + dist.barrier() + + +def is_main_process(): + return get_rank() == 0 + + +def is_rank_zero(args): + return args.rank == 0 + + +def get_dist_info(): + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + # set multi-process start method as `fork` to speed up the training + if platform.system() != "Windows": + mp_start_method = cfg.get("mp_start_method", "fork") + current_method = mp.get_start_method(allow_none=True) + if current_method is not None and current_method != mp_start_method: + warnings.warn( + f"Multi-processing start method `{mp_start_method}` is " + f"different from the previous setting `{current_method}`." + f"It will be force set to `{mp_start_method}`. You can change " + f"this behavior by changing `mp_start_method` in your config." + ) + mp.set_start_method(mp_start_method, force=True) + + # disable opencv multithreading to avoid system being overloaded + opencv_num_threads = cfg.get("opencv_num_threads", 0) + cv2.setNumThreads(opencv_num_threads) + + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + workers_per_gpu = cfg.get("workers_per_gpu", 4) + + if "OMP_NUM_THREADS" not in os.environ and workers_per_gpu > 1: + omp_num_threads = 1 + warnings.warn( + f"Setting OMP_NUM_THREADS environment variable for each process " + f"to be {omp_num_threads} in default, to avoid your system being " + f"overloaded, please further tune the variable for optimal " + f"performance in your application as needed." + ) + os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + + # setup MKL threads + if "MKL_NUM_THREADS" not in os.environ and workers_per_gpu > 1: + mkl_num_threads = os.environ.get("OMP_NUM_THREADS", 1) + warnings.warn( + f"Setting MKL_NUM_THREADS environment variable for each process " + f"to be {mkl_num_threads} in default, to avoid your system being " + f"overloaded, please further tune the variable for optimal " + f"performance in your application as needed." + ) + os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads) + + +def setup_slurm(backend: str, port: str) -> None: + """Initialize slurm distributed training environment. + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + + num_gpus = torch.cuda.device_count() + + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + os.environ["MASTER_PORT"] = str(port) + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(ntasks) + os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) + os.environ["RANK"] = str(proc_id) + print( + proc_id, + ntasks, + num_gpus, + proc_id % num_gpus, + node_list, + addr, + os.environ["MASTER_PORT"], + os.system("nvidia-smi -L"), + ) + dist.init_process_group(backend, rank=proc_id, world_size=ntasks) + + +def sync_tensor_across_gpus(t, dim=0, cat=True): + if t is None or not (dist.is_available() and dist.is_initialized()): + return t + t = torch.atleast_1d(t) + group = dist.group.WORLD + group_size = torch.distributed.get_world_size(group) + + local_size = torch.tensor(t.size(dim), device=t.device) + all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)] + dist.all_gather(all_sizes, local_size) + max_size = max(all_sizes) + size_diff = max_size.item() - local_size.item() + if size_diff: + padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype) + t = torch.cat((t, padding)) + + gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)] + dist.all_gather(gather_t_tensor, t) + all_ts = [] + for t, size in zip(gather_t_tensor, all_sizes): + all_ts.append(t[:size]) + if cat: + return torch.cat(all_ts, dim=0) + return all_ts + + +import pickle + + +def sync_string_across_gpus(keys: list[str], device, dim=0): + keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL) + keys_serialized_tensor = torch.frombuffer(keys_serialized, dtype=torch.uint8).to( + device + ) + keys_serialized_tensor = sync_tensor_across_gpus( + keys_serialized_tensor, dim=0, cat=False + ) + keys = [ + key + for keys in keys_serialized_tensor + for key in pickle.loads(bytes(keys.cpu().tolist())) + ] + return keys diff --git a/flash3d/unidepth/utils/ema_torch.py b/flash3d/unidepth/utils/ema_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a37461f017fb220c88c82416eb3d2c371b5a9e8 --- /dev/null +++ b/flash3d/unidepth/utils/ema_torch.py @@ -0,0 +1,342 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from __future__ import division +from __future__ import unicode_literals + +from typing import Iterable, Optional +import weakref +import copy +import contextlib +from math import tanh + +import torch + + +class DummyExponentialMovingAverage: + def __init__(self, *args, **kwargs): + pass + + def _get_parameters(self, *args, **kwargs): + pass + + def get_current_decay(self, *args, **kwargs): + pass + + def update(self, *args, **kwargs): + pass + + def copy_to(self, *args, **kwargs): + pass + + def store(self, *args, **kwargs): + return + + def restore(self, *args, **kwargs): + return + + @contextlib.contextmanager + def average_parameters(self, *args, **kwargs): + try: + yield + finally: + pass + + def to(self, *args, **kwargs): + pass + + def state_dict(self, *args, **kwargs): + pass + + def load_state_dict(self, *args, **kwargs): + pass + + +class ExponentialMovingAverage: + """ + Maintains (exponential) moving average of a set of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter` (typically from + `model.parameters()`). + Note that EMA is computed on *all* provided parameters, + regardless of whether or not they have `requires_grad = True`; + this allows a single EMA object to be consistantly used even + if which parameters are trainable changes step to step. + + If you want to some parameters in the EMA, do not pass them + to the object in the first place. For example: + + ExponentialMovingAverage( + parameters=[p for p in model.parameters() if p.requires_grad], + decay=0.9 + ) + + will ignore parameters that do not require grad. + + decay: The exponential decay. + + use_num_updates: Whether to use number of updates when computing + averages. + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float, + use_num_updates: bool = True, + update_after_step: int = 10000, + tau: int = 20000, + switch: bool = False, + ): + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + self.decay = decay + self.switch = switch # fi keeping EMA params in model after epochs + self.num_updates = 0 if use_num_updates else None + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + self.collected_params = None + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. + self._params_refs = [weakref.ref(p) for p in parameters] + self.update_after_step = update_after_step + self.tau = tau + + def _get_parameters( + self, parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected." + ) + return parameters + else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) + return parameters + + def get_current_decay(self): + epoch = max(self.num_updates - self.update_after_step - 1, 0.0) + if epoch <= 0: + return 0.0 + value = tanh(epoch / self.tau) * self.decay + return value + + def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: + """ + Update currently maintained parameters. + + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + decay = self.get_current_decay() + if self.num_updates is not None: + self.num_updates += 1 + + one_minus_decay = 1.0 - decay + with torch.no_grad(): + for s_param, param in zip(self.shadow_params, parameters): + tmp = s_param - param + # tmp will be a new tensor so we can do in-place + tmp.mul_(one_minus_decay) + s_param.sub_(tmp) + + def copy_to( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: + """ + Save the current parameters for restoring later. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.collected_params = [param.detach().clone() for param in parameters] + + def restore( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) + parameters = self._get_parameters(parameters) + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + + @contextlib.contextmanager + def average_parameters( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + if not self.switch: + self.restore(parameters) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.collected_params + ] + return + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance( + self.num_updates, int + ), "Invalid num_updates" + + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + assert isinstance( + self.collected_params, list + ), "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len( + self.shadow_params + ), "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." + ) diff --git a/flash3d/unidepth/utils/evaluation_depth.py b/flash3d/unidepth/utils/evaluation_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..9f84ca0591efae3f13f78ab4e10b5069ae2f74eb --- /dev/null +++ b/flash3d/unidepth/utils/evaluation_depth.py @@ -0,0 +1,173 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" +# We prefer not to install PyTorch3D in the package +# Code commented is how 3D metrics are computed + +from collections import defaultdict +from functools import partial + +import torch +import torch.nn.functional as F + +# from chamfer_distance import ChamferDistance + +from unidepth.utils.constants import DEPTH_BINS + + +# chamfer_cls = ChamferDistance() + + +# def chamfer_dist(tensor1, tensor2): +# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) +# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) +# dist1, dist2, idx1, idx2 = chamfer_cls( +# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths +# ) +# return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2 + + +# def auc(tensor1, tensor2, thresholds): +# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) +# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) +# dist1, dist2, idx1, idx2 = chamfer_cls( +# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths +# ) +# # compute precision recall +# precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] +# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] +# auc_value = torch.trapz( +# torch.tensor(precisions, device=tensor1.device), +# torch.tensor(recalls, device=tensor1.device), +# ) +# return auc_value + + +def delta(tensor1, tensor2, exponent): + inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) + return (inlier < 1.25**exponent).to(torch.float32).mean() + + +def ssi(tensor1, tensor2, qtl=0.05): + stability_mat = 1e-9 * torch.eye(2, device=tensor1.device) + error = (tensor1 - tensor2).abs() + mask = error < torch.quantile(error, 1 - qtl) + tensor1_mask = tensor1[mask] + tensor2_mask = tensor2[mask] + tensor2_one = torch.stack( + [tensor2_mask.detach(), torch.ones_like(tensor2_mask).detach()], dim=1 + ) + scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( + tensor2_one.T @ tensor1_mask.unsqueeze(1) + ) + scale, shift = scale_shift.squeeze().chunk(2, dim=0) + return tensor2 * scale + shift + # tensor2_one = torch.stack([tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1) + # scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (tensor2_one.T @ tensor1.unsqueeze(1)) + # scale, shift = scale_shift.squeeze().chunk(2, dim=0) + # return tensor2 * scale + shift + + +def d1_ssi(tensor1, tensor2): + delta_ = delta(tensor1, ssi(tensor1, tensor2), 1.0) + return delta_ + + +def d_auc(tensor1, tensor2): + exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device) + deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents] + return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0 + + +# def f1_score(tensor1, tensor2, thresholds): +# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) +# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) +# dist1, dist2, idx1, idx2 = chamfer_cls( +# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths +# ) +# # compute precision recall +# precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] +# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] +# precisions = torch.tensor(precisions, device=tensor1.device) +# recalls = torch.tensor(recalls, device=tensor1.device) +# f1_thresholds = 2 * precisions * recalls / (precisions + recalls) +# f1_thresholds = torch.where( +# torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds +# ) +# f1_value = torch.trapz(f1_thresholds) / len(thresholds) +# return f1_value + + +DICT_METRICS = { + "d1": partial(delta, exponent=1.0), + "d2": partial(delta, exponent=2.0), + "d3": partial(delta, exponent=3.0), + "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()), + "rmselog": lambda gt, pred: torch.sqrt( + ((torch.log(gt) - torch.log(pred)) ** 2).mean() + ), + "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(), + "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(), + "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(), + "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(), + "medianlog": lambda gt, pred: 100 + * (torch.log(pred) - torch.log(gt)).median().abs(), + "d_auc": d_auc, + "d1_ssi": d1_ssi, +} + + +# DICT_METRICS_3D = { +# "chamfer": lambda gt, pred, thresholds: chamfer_dist( +# gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) +# ), +# "F1": lambda gt, pred, thresholds: f1_score( +# gt.unsqueeze(0).permute(0, 2, 1), +# pred.unsqueeze(0).permute(0, 2, 1), +# thresholds=thresholds, +# ), +# } + + +DICT_METRICS_D = { + "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to( + torch.float32 + ), + "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt), +} + + +def eval_depth( + gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None +): + summary_metrics = defaultdict(list) + preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") + for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): + if max_depth is not None: + mask = torch.logical_and(mask, gt <= max_depth) + for name, fn in DICT_METRICS.items(): + summary_metrics[name].append(fn(gt[mask], pred[mask]).mean()) + return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} + + +# def eval_3d( +# gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None +# ): +# summary_metrics = defaultdict(list) +# w_max = min(gts.shape[-1] // 4, 400) +# gts = F.interpolate( +# gts, (int(w_max * gts.shape[-2] / gts.shape[-1]), w_max), mode="nearest" +# ) +# preds = F.interpolate(preds, gts.shape[-2:], mode="nearest") +# masks = F.interpolate( +# masks.to(torch.float32), gts.shape[-2:], mode="nearest" +# ).bool() +# for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): +# if not torch.any(mask): +# continue +# for name, fn in DICT_METRICS_3D.items(): +# summary_metrics[name].append( +# fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean() +# ) +# return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} diff --git a/flash3d/unidepth/utils/geometric.py b/flash3d/unidepth/utils/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..b942beb1af3c58e8910a1e113bfd974a1fd169ce --- /dev/null +++ b/flash3d/unidepth/utils/geometric.py @@ -0,0 +1,248 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from typing import Tuple + +import torch +from torch.nn import functional as F + + +def generate_rays( + camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False +): + batch_size, device, dtype = ( + camera_intrinsics.shape[0], + camera_intrinsics.device, + camera_intrinsics.dtype, + ) + height, width = image_shape + # Generate grid of pixel coordinates + pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype) + pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype) + if noisy: + pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 + pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 + pixel_coords = torch.stack( + [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2 + ) # (H, W, 2) + pixel_coords = pixel_coords + 0.5 + + # Calculate ray directions + intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype) # (B, 3, 3) + homogeneous_coords = torch.cat( + [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2 + ) # (H, W, 3) + ray_directions = torch.matmul( + intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1) + ) # (3, H*W) + ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W) + ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3) + + theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1]) + phi = torch.acos(ray_directions[..., 1]) + # pitch = torch.asin(ray_directions[..., 1]) + # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1]) + angles = torch.stack([theta, phi], dim=-1) + return ray_directions, angles + + +@torch.jit.script +def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: + theta = spherical_tensor[..., 0] # Extract polar angle + phi = spherical_tensor[..., 1] # Extract azimuthal angle + z = spherical_tensor[..., 2] # Extract zbuffer depth + + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + # => + # r = z / sin(phi) / cos(theta) + # y = z / (sin(phi) / cos(phi)) / cos(theta) + # x = z * sin(theta) / cos(theta) + x = z * torch.tan(theta) + y = z / torch.tan(phi) / torch.cos(theta) + + euclidean_tensor = torch.stack((x, y, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: + theta = spherical_tensor[..., 0] # Extract polar angle + phi = spherical_tensor[..., 1] # Extract azimuthal angle + r = spherical_tensor[..., 2] # Extract radius + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + x = r * torch.sin(phi) * torch.sin(theta) + y = r * torch.cos(phi) + z = r * torch.cos(theta) * torch.sin(phi) + + euclidean_tensor = torch.stack((x, y, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor: + x = spherical_tensor[..., 0] # Extract polar angle + y = spherical_tensor[..., 1] # Extract azimuthal angle + z = spherical_tensor[..., 2] # Extract radius + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + r = torch.sqrt(x**2 + y**2 + z**2) + theta = torch.atan2(x / r, z / r) + phi = torch.acos(y / r) + + euclidean_tensor = torch.stack((theta, phi, r), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor: + pitch = torch.asin(euclidean_tensor[..., 1]) + yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1]) + z = euclidean_tensor[..., 2] # Extract zbuffer depth + euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def unproject_points( + depth: torch.Tensor, camera_intrinsics: torch.Tensor +) -> torch.Tensor: + """ + Unprojects a batch of depth maps to 3D point clouds using camera intrinsics. + + Args: + depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W). + camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3). + + Returns: + torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W). + """ + batch_size, _, height, width = depth.shape + device = depth.device + + # Create pixel grid + y_coords, x_coords = torch.meshgrid( + torch.arange(height, device=device), + torch.arange(width, device=device), + indexing="ij", + ) + pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2) + + # Get homogeneous coords (u v 1) + pixel_coords_homogeneous = torch.cat( + (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1 + ) + pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten( + 1 + ) # (3, H*W) + # Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W] + unprojected_points = torch.matmul( + torch.inverse(camera_intrinsics), pixel_coords_homogeneous + ) # (B, 3, H*W) + unprojected_points = unprojected_points.view( + batch_size, 3, height, width + ) # (B, 3, H, W) + unprojected_points = unprojected_points * depth # (B, 3, H, W) + return unprojected_points + + +@torch.jit.script +def project_points( + points_3d: torch.Tensor, + intrinsic_matrix: torch.Tensor, + image_shape: Tuple[int, int], +) -> torch.Tensor: + # Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T + points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2)) + + # Normalize projected points: (u v w) -> (u / w, v / w, 1) + points_2d = points_2d[..., :2] / points_2d[..., 2:] + + # To pixels (rounding!!!), no int as it breaks gradient + points_2d = points_2d.round() + + # pointa need to be inside the image (can it diverge onto all points out???) + valid_mask = ( + (points_2d[..., 0] >= 0) + & (points_2d[..., 0] < image_shape[1]) + & (points_2d[..., 1] >= 0) + & (points_2d[..., 1] < image_shape[0]) + ) + + # Calculate the flat indices of the valid pixels + flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1] + flat_indices = flat_points_2d.long() + + # Create depth maps and counts using scatter_add, (B, H, W) + depth_maps = torch.zeros( + [points_3d.shape[0], *image_shape], device=points_3d.device + ) + counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device) + + # Loop over batches to apply masks and accumulate depth/count values + for i in range(points_3d.shape[0]): + valid_indices = flat_indices[i, valid_mask[i]] + depth_maps[i].view(-1).scatter_add_( + 0, valid_indices, points_3d[i, valid_mask[i], 2] + ) + counts[i].view(-1).scatter_add_( + 0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2]) + ) + + # Calculate mean depth for each pixel in each batch + mean_depth_maps = depth_maps / counts.clamp(min=1.0) + return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W) + + +@torch.jit.script +def downsample(data: torch.Tensor, downsample_factor: int = 2): + N, _, H, W = data.shape + data = data.view( + N, + H // downsample_factor, + downsample_factor, + W // downsample_factor, + downsample_factor, + 1, + ) + data = data.permute(0, 1, 3, 5, 2, 4).contiguous() + data = data.view(-1, downsample_factor * downsample_factor) + data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data) + data = torch.min(data_tmp, dim=-1).values + data = data.view(N, 1, H // downsample_factor, W // downsample_factor) + data = torch.where(data > 1000, torch.zeros_like(data), data) + return data + + +@torch.jit.script +def flat_interpolate( + flat_tensor: torch.Tensor, + old: Tuple[int, int], + new: Tuple[int, int], + antialias: bool = True, + mode: str = "bilinear", +) -> torch.Tensor: + if old[0] == new[0] and old[1] == new[1]: + return flat_tensor + tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute( + 0, 3, 1, 2 + ) # b c h w + tensor_interp = F.interpolate( + tensor, + size=(new[0], new[1]), + mode=mode, + align_corners=False, + antialias=antialias, + ) + flat_tensor_interp = tensor_interp.view( + flat_tensor.shape[0], -1, new[0] * new[1] + ).permute( + 0, 2, 1 + ) # b (h w) c + return flat_tensor_interp.contiguous() diff --git a/flash3d/unidepth/utils/misc.py b/flash3d/unidepth/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..23ae7adc0048b313dca985175a5713644be5e75d --- /dev/null +++ b/flash3d/unidepth/utils/misc.py @@ -0,0 +1,403 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from functools import wraps + +import numpy as np +from scipy import interpolate + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat, reduce + + +def max_stack(tensors): + return torch.stack(tensors, dim=-1).max(dim=-1)[0] + + +def softmax_stack(tensors, temperature=1.0): + return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1) + + +def mean_stack(tensors): + if len(tensors) == 1: + return tensors[0] + return torch.stack(tensors, dim=-1).mean(dim=-1) + + +def sum_stack(tensors): + return torch.stack(tensors, dim=-1).sum(dim=-1) + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def format_seconds(seconds): + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + return f"{hours:d}:{minutes:02d}:{seconds:02d}" + + +def get_params(module, lr, wd): + skip_list = {} + skip_keywords = {} + if hasattr(module, "no_weight_decay"): + skip_list = module.no_weight_decay() + if hasattr(module, "no_weight_decay_keywords"): + skip_keywords = module.no_weight_decay_keywords() + has_decay = [] + no_decay = [] + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if ( + (name in skip_list) + or any((kw in name for kw in skip_keywords)) + or len(param.shape) == 1 + ): + # if (name in skip_list) or any((kw in name for kw in skip_keywords)): + # print(name, skip_keywords) + no_decay.append(param) + else: + has_decay.append(param) + + group1 = { + "params": has_decay, + "weight_decay": wd, + "lr": lr, + "weight_decay_init": wd, + "weight_decay_base": wd, + "lr_init": lr, + "lr_base": lr, + } + group2 = { + "params": no_decay, + "weight_decay": 0.0, + "lr": lr, + "weight_decay_init": 0.0, + "weight_decay_base": 0.0, + "weight_decay_final": 0.0, + "lr_init": lr, + "lr_base": lr, + } + return [group1, group2], [lr, lr] + + +def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage): + if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("layers"): + if var_name.split(".")[2] == "blocks": + stage_id = int(var_name.split(".")[1]) + layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id]) + return layer_id + 1 + elif var_name.split(".")[2] == "downsample": + stage_id = int(var_name.split(".")[1]) + layer_id = sum(layers_per_stage[: stage_id + 1]) + return layer_id + else: + return num_max_layer - 1 + + +def get_params_layerdecayswin(module, lr, wd, ld): + skip_list = {} + skip_keywords = {} + if hasattr(module, "no_weight_decay"): + skip_list = module.no_weight_decay() + if hasattr(module, "no_weight_decay_keywords"): + skip_keywords = module.no_weight_decay_keywords() + layers_per_stage = module.depths + num_layers = sum(layers_per_stage) + 1 + lrs = [] + params = [] + for name, param in module.named_parameters(): + if not param.requires_grad: + print(f"{name} frozen") + continue # frozen weights + layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage) + lr_cur = lr * ld ** (num_layers - layer_id - 1) + # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"): + if (name in skip_list) or any((kw in name for kw in skip_keywords)): + wd_cur = 0.0 + else: + wd_cur = wd + params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur}) + lrs.append(lr_cur) + return params, lrs + + +def log(t, eps: float = 1e-5): + return torch.log(t.clamp(min=eps)) + + +def l2norm(t): + return F.normalize(t, dim=-1) + + +def exists(val): + return val is not None + + +def identity(t, *args, **kwargs): + return t + + +def divisible_by(numer, denom): + return (numer % denom) == 0 + + +def first(arr, d=None): + if len(arr) == 0: + return d + return arr[0] + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def maybe(fn): + @wraps(fn) + def inner(x): + if not exists(x): + return x + return fn(x) + + return inner + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +def _many(fn): + @wraps(fn) + def inner(tensors, pattern, **kwargs): + return (fn(tensor, pattern, **kwargs) for tensor in tensors) + + return inner + + +rearrange_many = _many(rearrange) +repeat_many = _many(repeat) +reduce_many = _many(reduce) + + +def load_pretrained(state_dict, checkpoint): + checkpoint_model = checkpoint["model"] + if any([True if "encoder." in k else False for k in checkpoint_model.keys()]): + checkpoint_model = { + k.replace("encoder.", ""): v + for k, v in checkpoint_model.items() + if k.startswith("encoder.") + } + print("Detect pre-trained model, remove [encoder.] prefix.") + else: + print("Detect non-pre-trained model, pass without doing anything.") + print(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") + checkpoint = load_checkpoint_swin(state_dict, checkpoint_model) + + +def load_checkpoint_swin(model, checkpoint_model): + state_dict = model.state_dict() + # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size + all_keys = list(checkpoint_model.keys()) + for key in all_keys: + if "relative_position_bias_table" in key: + relative_position_bias_table_pretrained = checkpoint_model[key] + relative_position_bias_table_current = state_dict[key] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if nH1 != nH2: + print(f"Error in loading {key}, passing......") + else: + if L1 != L2: + print(f"{key}: Interpolate relative_position_bias_table using geo.") + src_size = int(L1**0.5) + dst_size = int(L2**0.5) + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + print("Original positions = %s" % str(x)) + print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(nH1): + z = ( + relative_position_bias_table_pretrained[:, i] + .view(src_size, src_size) + .float() + .numpy() + ) + f_cubic = interpolate.interp2d(x, y, z, kind="cubic") + all_rel_pos_bias.append( + torch.Tensor(f_cubic(dx, dy)) + .contiguous() + .view(-1, 1) + .to(relative_position_bias_table_pretrained.device) + ) + + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + checkpoint_model[key] = new_rel_pos_bias + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in checkpoint_model.keys() if "relative_position_index" in k + ] + for k in relative_position_index_keys: + del checkpoint_model[k] + + # delete relative_coords_table since we always re-init it + relative_coords_table_keys = [ + k for k in checkpoint_model.keys() if "relative_coords_table" in k + ] + for k in relative_coords_table_keys: + del checkpoint_model[k] + + # # re-map keys due to name change + rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k] + for k in rpe_mlp_keys: + checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k) + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] + for k in attn_mask_keys: + del checkpoint_model[k] + + encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")] + for k in encoder_keys: + checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k) + + return checkpoint_model + + +def add_padding_metas(out, image_metas): + device = out.device + # left, right, top, bottom + paddings = [img_meta.get("padding_size", [0] * 4) for img_meta in image_metas] + paddings = torch.stack(paddings).to(device) + outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)] + return torch.stack(outs) + + +def remove_padding(out, paddings): + B, C, H, W = out.shape + device = out.device + # left, right, top, bottom + paddings = torch.stack(paddings).to(device) + outs = [ + o[:, padding[1] : H - padding[3], padding[0] : W - padding[2]] + for padding, o in zip(paddings, out) + ] + return torch.stack(outs) + + +def remove_padding_metas(out, image_metas): + B, C, H, W = out.shape + device = out.device + # left, right, top, bottom + paddings = [ + torch.tensor(img_meta.get("padding_size", [0] * 4)) for img_meta in image_metas + ] + return remove_padding(out, paddings) + + +def ssi_helper(tensor1, tensor2): + stability_mat = 1e-4 * torch.eye(2, device=tensor1.device) + tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1) + scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( + tensor2_one.T @ tensor1.unsqueeze(1) + ) + scale, shift = scale_shift.squeeze().chunk(2, dim=0) + return scale, shift + + +def calculate_mean_values(names, values): + # Create a defaultdict to store sum and count for each name + name_values = {name: {} for name in names} + + # Iterate through the lists and accumulate values for each name + for name, value in zip(names, values): + name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value + name_values[name]["count"] = name_values[name].get("count", 0.0) + 1 + + # Calculate mean values and create the output dictionary + output_dict = { + name: name_values[name]["sum"] / name_values[name]["count"] + for name in name_values + } + + return output_dict + + +def remove_leading_dim(infos): + if isinstance(infos, dict): + return {k: remove_leading_dim(v) for k, v in infos.items()} + elif isinstance(infos, torch.Tensor): + return infos.squeeze(0) + else: + return infos diff --git a/flash3d/unidepth/utils/positional_embedding.py b/flash3d/unidepth/utils/positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..8883076cb8d10896322d2973d6d0cd5df35e6943 --- /dev/null +++ b/flash3d/unidepth/utils/positional_embedding.py @@ -0,0 +1,274 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from math import pi +from typing import Optional + +import torch +import torch.nn as nn + +from einops import rearrange, repeat + + +class PositionEmbeddingSine(nn.Module): + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * pi + self.scale = scale + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class LearnedSinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + print("======== shape of rope freq", self.freqs_cos.shape, "========") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward(self, t): + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin + + +from math import log2 + + +def generate_fourier_features( + x: torch.Tensor, + dim: int = 512, + max_freq: int = 64, + use_cos: bool = False, + use_log: bool = False, + cat_orig: bool = False, +): + x_orig = x + device, dtype, input_dim = x.device, x.dtype, x.shape[-1] + num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim + + if use_log: + scales = 2.0 ** torch.linspace( + 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype + ) + else: + scales = torch.linspace( + 1.0, max_freq / 2, num_bands, device=device, dtype=dtype + ) + + x = x.unsqueeze(-1) + scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] + + x = x * scales * pi + x = torch.cat( + ( + [x.sin(), x.cos()] + if use_cos + else [ + x.sin(), + ] + ), + dim=-1, + ) + x = x.flatten(-2) + if cat_orig: + return torch.cat((x, x_orig), dim=-1) + return x + + +# from PIL import Image +# from unidepth.utils import image_grid, colorize +# if __name__ == "__main__": +# H, W = 512, 512 +# resolution = 128 +# mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)) +# mesh = torch.stack(mesh, dim=0).unsqueeze(0) +# mesh = mesh.view(1, 2, -1).permute(0, 2, 1) + +# features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True) +# channels = features.shape[-1] +# print(features.shape) + +# features = features[0].view(H, W, channels).permute(2, 0, 1).numpy() +# Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png") diff --git a/flash3d/unidepth/utils/sht.py b/flash3d/unidepth/utils/sht.py new file mode 100644 index 0000000000000000000000000000000000000000..4b89273a8f20b4da5ba296b175c856c974df0984 --- /dev/null +++ b/flash3d/unidepth/utils/sht.py @@ -0,0 +1,1637 @@ +"""Real spherical harmonics in Cartesian form for PyTorch. + +This is an autogenerated file. See +https://github.com/cheind/torch-spherical-harmonics +for more information. +""" + +import torch + + +def rsh_cart_0(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 0. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,1) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + ], + -1, + ) + + +def rsh_cart_1(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 1. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,4) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + ], + -1, + ) + + +def rsh_cart_2(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 2. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,9) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + ], + -1, + ) + + +def rsh_cart_3(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 3. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,16) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + ], + -1, + ) + + +def rsh_cart_4(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 4. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,25) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + ], + -1, + ) + + +def rsh_cart_5(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 5. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,36) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + ], + -1, + ) + + +def rsh_cart_6(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 6. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,49) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + ], + -1, + ) + + +def rsh_cart_7(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 7. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,64) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + -0.707162732524596 + * y + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 9.98394571852353e-5 + * y + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00239614697244565 + * xy + * (x2 - y2) + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), + 0.00397356022507413 + * y + * (3.0 * x2 - y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.0561946276120613 + * xy + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.206472245902897 + * y + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 1.24862677781952 * z * (1.5 * z2 - 0.5) + - 1.68564615005635 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 2.02901851395672 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.499450711127808 * z, + 0.206472245902897 + * x + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 0.0280973138060306 + * (x2 - y2) + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.00397356022507413 + * x + * (x2 - 3.0 * y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.000599036743111412 + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + * (-6.0 * x2 * y2 + x4 + y4), + 9.98394571852353e-5 + * x + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -0.707162732524596 + * x + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + ], + -1, + ) + + +# @torch.jit.script +def rsh_cart_8(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 8. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,81) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + # z4 = z2**2 + return torch.stack( + [ + 0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + -0.707162732524596 + * y + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 9.98394571852353e-5 + * y + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00239614697244565 + * xy + * (x2 - y2) + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), + 0.00397356022507413 + * y + * (3.0 * x2 - y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.0561946276120613 + * xy + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.206472245902897 + * y + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 1.24862677781952 * z * (1.5 * z2 - 0.5) + - 1.68564615005635 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 2.02901851395672 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.499450711127808 * z, + 0.206472245902897 + * x + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 0.0280973138060306 + * (x2 - y2) + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.00397356022507413 + * x + * (x2 - 3.0 * y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.000599036743111412 + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + * (-6.0 * x2 * y2 + x4 + y4), + 9.98394571852353e-5 + * x + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -0.707162732524596 + * x + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3), + -2.91570664069932 + * yz + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 5.10587282657803e-5 + * y + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00147275890257803 + * xy + * (x2 - y2) + * ( + 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 0.0028519853513317 + * y + * (3.0 * x2 - y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.0463392770473559 + * xy + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.193851103820053 + * y + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 1.48417251362228 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.86581687426801 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 2.1808249179756 + * z + * ( + 1.14285714285714 * z * (1.5 * z2 - 0.5) + - 1.54285714285714 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 1.85714285714286 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.457142857142857 * z + ) + - 0.954110901614325 * z2 + + 0.318036967204775, + 0.193851103820053 + * x + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 0.0231696385236779 + * (x2 - y2) + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.0028519853513317 + * x + * (x2 - 3.0 * y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.000368189725644507 + * (-6.0 * x2 * y2 + x4 + y4) + * ( + 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 5.10587282657803e-5 + * x + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -2.91570664069932 + * xz + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + -20.4099464848952 * x2**3 * y2 + - 20.4099464848952 * x2 * y2**3 + + 0.72892666017483 * x4**2 + + 51.0248662122381 * x4 * y4 + + 0.72892666017483 * y4**2, + ], + -1, + ) + + +__all__ = [ + "rsh_cart_0", + "rsh_cart_1", + "rsh_cart_2", + "rsh_cart_3", + "rsh_cart_4", + "rsh_cart_5", + "rsh_cart_6", + "rsh_cart_7", + "rsh_cart_8", +] + + +from typing import Optional +import torch + + +class SphHarm(torch.nn.Module): + def __init__(self, m, n, dtype=torch.float32) -> None: + super().__init__() + self.dtype = dtype + m = torch.tensor(list(range(-m + 1, m))) + n = torch.tensor(list(range(n))) + self.is_normalized = False + vals = torch.cartesian_prod(m, n).T + vals = vals[:, vals[0] <= vals[1]] + m, n = vals.unbind(0) + + self.register_buffer("m", tensor=m) + self.register_buffer("n", tensor=n) + self.register_buffer("l_max", tensor=torch.max(self.n)) + + f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre() + self.register_buffer("f_a", tensor=f_a) + self.register_buffer("f_b", tensor=f_b) + self.register_buffer("d0_mask_3d", tensor=d0_mask_3d) + self.register_buffer("d1_mask_3d", tensor=d1_mask_3d) + self.register_buffer("initial_value", tensor=initial_value) + + @property + def device(self): + return next(self.buffers()).device + + def forward(self, points: torch.Tensor) -> torch.Tensor: + """Computes the spherical harmonics.""" + # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi) + B, N, D = points.shape + dtype = points.dtype + theta, phi = points.view(-1, D).to(self.dtype).unbind(-1) + cos_colatitude = torch.cos(phi) + legendre = self._gen_associated_legendre(cos_colatitude) + vals = torch.stack([self.m.abs(), self.n], dim=0) + vals = torch.cat( + [ + vals.repeat(1, theta.shape[0]), + torch.arange(theta.shape[0], device=theta.device) + .unsqueeze(0) + .repeat_interleave(vals.shape[1], dim=1), + ], + dim=0, + ) + legendre_vals = legendre[vals[0], vals[1], vals[2]] + legendre_vals = legendre_vals.reshape(-1, theta.shape[0]) + angle = torch.outer(self.m.abs(), theta) + vandermonde = torch.complex(torch.cos(angle), torch.sin(angle)) + harmonics = torch.complex( + legendre_vals * torch.real(vandermonde), + legendre_vals * torch.imag(vandermonde), + ) + + # Negative order. + m = self.m.unsqueeze(-1) + harmonics = torch.where( + m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics + ) + harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype) + return harmonics + + def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]: + """Generates mask for recurrence relation on the remaining entries. + + The remaining entries are with respect to the diagonal and offdiagonal + entries. + + Args: + l_max: see `gen_normalized_legendre`. + Returns: + torch.Tensors representing the mask used by the recurrence relations. + """ + + # Computes all coefficients. + m_mat, l_mat = torch.meshgrid( + torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), + torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), + indexing="ij", + ) + if self.is_normalized: + c0 = l_mat * l_mat + c1 = m_mat * m_mat + c2 = 2.0 * l_mat + c3 = (l_mat - 1.0) * (l_mat - 1.0) + d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) + d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) + else: + d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) + d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) + + d0_mask_indices = torch.triu_indices(self.l_max + 1, 1) + d1_mask_indices = torch.triu_indices(self.l_max + 1, 2) + + d_zeros = torch.zeros( + (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device + ) + d_zeros[d0_mask_indices] = d0[d0_mask_indices] + d0_mask = d_zeros + + d_zeros = torch.zeros( + (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device + ) + d_zeros[d1_mask_indices] = d1[d1_mask_indices] + d1_mask = d_zeros + + # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. + i = torch.arange(self.l_max + 1, device=self.device)[:, None, None] + j = torch.arange(self.l_max + 1, device=self.device)[None, :, None] + k = torch.arange(self.l_max + 1, device=self.device)[None, None, :] + mask = (i + j - k == 0).to(self.dtype) + d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask) + d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask) + return (d0_mask_3d, d1_mask_3d) + + def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + coeff_0 = self.d0_mask_3d[i] + coeff_1 = self.d1_mask_3d[i] + h = torch.einsum( + "ij,ijk->ijk", + coeff_0, + torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x), + ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1)) + p_val = p_val + h + return p_val + + def _init_legendre(self): + a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device) + b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device) + if self.is_normalized: + # The initial value p(0,0). + initial_value: torch.Tensor = torch.tensor( + 0.5 / (torch.pi**0.5), device=self.device + ) + f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0) + f_b = torch.sqrt(2.0 * b_idx + 3.0) + else: + # The initial value p(0,0). + initial_value = torch.tensor(1.0, device=self.device) + f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0) + f_b = 2.0 * b_idx + 1.0 + + d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask() + return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d + + def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor: + r"""Computes associated Legendre functions (ALFs) of the first kind. + + The ALFs of the first kind are used in spherical harmonics. The spherical + harmonic of degree `l` and order `m` can be written as + `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the + normalization factor and θ and φ are the colatitude and longitude, + repectively. `N_l^m` is chosen in the way that the spherical harmonics form + a set of orthonormal basis function of L^2(S^2). For the computational + efficiency of spherical harmonics transform, the normalization factor is + used in the computation of the ALFs. In addition, normalizing `P_l^m` + avoids overflow/underflow and achieves better numerical stability. Three + recurrence relations are used in the computation. + + Args: + l_max: The maximum degree of the associated Legendre function. Both the + degrees and orders are `[0, 1, 2, ..., l_max]`. + x: A vector of type `float32`, `float64` containing the sampled points in + spherical coordinates, at which the ALFs are computed; `x` is essentially + `cos(θ)`. For the numerical integration used by the spherical harmonics + transforms, `x` contains the quadrature points in the interval of + `[-1, 1]`. There are several approaches to provide the quadrature points: + Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev + method (`scipy.special.roots_chebyu`), and Driscoll & Healy + method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier + transforms and convolutions on the 2-sphere." Advances in applied + mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature + points are nearly equal-spaced along θ and provide exact discrete + orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose + operation, `W` is a diagonal matrix containing the quadrature weights, + and `I` is the identity matrix. The Gauss-Chebyshev points are equally + spaced, which only provide approximate discrete orthogonality. The + Driscoll & Healy qudarture points are equally spaced and provide the + exact discrete orthogonality. The number of sampling points is required to + be twice as the number of frequency points (modes) in the Driscoll & Healy + approach, which enables FFT and achieves a fast spherical harmonics + transform. + is_normalized: True if the associated Legendre functions are normalized. + With normalization, `N_l^m` is applied such that the spherical harmonics + form a set of orthonormal basis functions of L^2(S^2). + + Returns: + The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values + of the ALFs at `x`; the dimensions in the sequence of order, degree, and + evalution points. + """ + p = torch.zeros( + (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device + ) + p[0, 0] = self.initial_value + + # Compute the diagonal entries p(l,l) with recurrence. + y = torch.cumprod( + torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0 + ) + p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y) + # torch.diag_indices(l_max + 1) + diag_indices = torch.stack( + [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0 + ) + p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag + + diag_indices = torch.stack( + [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0 + ) + + # Compute the off-diagonal entries with recurrence. + p_offdiag = torch.einsum( + "ij,ij->ij", + torch.einsum("i,j->ij", self.f_b, x), + p[(diag_indices[0], diag_indices[1])], + ) # p[torch.diag_indices(l_max)]) + p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = ( + p_offdiag + ) + + # Compute the remaining entries with recurrence. + if self.l_max > 1: + for i in range(2, self.l_max + 1): + p = self._recursive(i, p, x) + return p diff --git a/flash3d/unidepth/utils/visualization.py b/flash3d/unidepth/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..8504ec0430924847a1c2123be0dcea6e00c6945d --- /dev/null +++ b/flash3d/unidepth/utils/visualization.py @@ -0,0 +1,201 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import os + +import numpy as np +from PIL import Image +import matplotlib.cm +import wandb +import torch + +from unidepth.utils.misc import ssi_helper + + +def colorize( + value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r" +): + # if already RGB, do nothing + if value.ndim > 2: + if value.shape[-1] > 1: + return value + value = value[..., 0] + invalid_mask = value < 0.0001 + # normalize + vmin = value.min() if vmin is None else vmin + vmax = value.max() if vmax is None else vmax + value = (value - vmin) / (vmax - vmin) # vmin..vmax + + # set color + cmapper = matplotlib.cm.get_cmap(cmap) + value = cmapper(value, bytes=True) # (nxmx4) + value[invalid_mask] = 0 + img = value[..., :3] + return img + + +def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray: + if not len(imgs): + return None + assert len(imgs) == rows * cols + h, w = imgs[0].shape[:2] + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste( + Image.fromarray(img.astype(np.uint8)).resize( + (w, h), resample=Image.BILINEAR + ), + box=(i % cols * w, i // cols * h), + ) + + return np.array(grid) + + +def get_pointcloud_from_rgbd( + image: np.array, + depth: np.array, + mask: np.ndarray, + intrinsic_matrix: np.array, + extrinsic_matrix: np.array = None, +): + depth = np.array(depth).squeeze() + mask = np.array(mask).squeeze() + # Mask the depth array + masked_depth = np.ma.masked_where(mask == False, depth) + # masked_depth = np.ma.masked_greater(masked_depth, 8000) + # Create idx array + idxs = np.indices(masked_depth.shape) + u_idxs = idxs[1] + v_idxs = idxs[0] + # Get only non-masked depth and idxs + z = masked_depth[~masked_depth.mask] + compressed_u_idxs = u_idxs[~masked_depth.mask] + compressed_v_idxs = v_idxs[~masked_depth.mask] + image = np.stack( + [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1 + ) + + # Calculate local position of each point + # Apply vectorized math to depth using compressed arrays + cx = intrinsic_matrix[0, 2] + fx = intrinsic_matrix[0, 0] + x = (compressed_u_idxs - cx) * z / fx + cy = intrinsic_matrix[1, 2] + fy = intrinsic_matrix[1, 1] + # Flip y as we want +y pointing up not down + y = -((compressed_v_idxs - cy) * z / fy) + + # # Apply camera_matrix to pointcloud as to get the pointcloud in world coords + # if extrinsic_matrix is not None: + # # Calculate camera pose from extrinsic matrix + # camera_matrix = np.linalg.inv(extrinsic_matrix) + # # Create homogenous array of vectors by adding 4th entry of 1 + # # At the same time flip z as for eye space the camera is looking down the -z axis + # w = np.ones(z.shape) + # x_y_z_eye_hom = np.vstack((x, y, -z, w)) + # # Transform the points from eye space to world space + # x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3] + # return x_y_z_world.T + # else: + x_y_z_local = np.stack((x, y, z), axis=-1) + return np.concatenate([x_y_z_local, image], axis=-1) + + +def save_file_ply(xyz, rgb, pc_file): + if rgb.max() < 1.001: + rgb = rgb * 255.0 + rgb = rgb.astype(np.uint8) + # print(rgb) + with open(pc_file, "w") as f: + # headers + f.writelines( + [ + "ply\n" "format ascii 1.0\n", + "element vertex {}\n".format(xyz.shape[0]), + "property float x\n", + "property float y\n", + "property float z\n", + "property uchar red\n", + "property uchar green\n", + "property uchar blue\n", + "end_header\n", + ] + ) + + for i in range(xyz.shape[0]): + str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format( + xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2] + ) + f.write(str_v) + + +# really awful fct... FIXME +def log_train_artifacts(rgbs, gts, preds, ds_name, step, infos={}): + rgbs = [ + (127.5 * (rgb + 1)) + .clip(0, 255) + .to(torch.uint8) + .cpu() + .detach() + .permute(1, 2, 0) + .numpy() + for rgb in rgbs + ] + + new_gts, new_preds = [], [] + if len(gts) > 0: + for i, gt in enumerate(gts): + scale, shift = ssi_helper( + gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach() + ) + gt = gts[i].cpu().detach().squeeze().numpy() + pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy() + vmin = gt[gt > 0].min() if (gt > 0).any() else 0.0 + vmax = gt.max() if (gt > 0).any() else 0.1 + new_gts.append(colorize(gt, vmin=vmin, vmax=vmax)) + new_preds.append(colorize(pred, vmin=vmin, vmax=vmax)) + gts, preds = new_gts, new_preds + else: + preds = [ + colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0) + for i, pred in enumerate(preds) + ] + + num_additional, additionals = 0, [] + for name, info in infos.items(): + num_additional += 1 + if info.shape[1] == 3: + additionals.extend( + [ + (127.5 * (x + 1)) + .clip(0, 255) + .to(torch.uint8) + .cpu() + .detach() + .permute(1, 2, 0) + .numpy() + for x in info[:4] + ] + ) + else: + additionals.extend( + [ + colorize(x.cpu().detach().squeeze().numpy()) + for i, x in enumerate(info[:4]) + ] + ) + + num_rows = 2 + int(len(gts) > 0) + num_additional + artifacts_grid = image_grid( + [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs) + ) + try: + wandb.log({f"{ds_name}_training": [wandb.Image(artifacts_grid)]}, step=step) + except: + Image.fromarray(artifacts_grid).save( + os.path.join(os.environ["HOME"], "Workspace", f"art_grid{step}.png") + ) + print("Logging training images failed") diff --git a/flash3d/util/vis3d.py b/flash3d/util/vis3d.py new file mode 100644 index 0000000000000000000000000000000000000000..deb53d03ffa735e4352e1f6eda38f12164b1fd71 --- /dev/null +++ b/flash3d/util/vis3d.py @@ -0,0 +1,135 @@ +from pathlib import Path +from jaxtyping import Float +import numpy as np +from scipy.spatial.transform import Rotation as R +from plyfile import PlyData, PlyElement +import torch +from torch import Tensor +from einops import rearrange, einsum + + +def construct_list_of_attributes(num_rest: int) -> list[str]: + attributes = ["x", "y", "z", "nx", "ny", "nz"] + for i in range(3): + attributes.append(f"f_dc_{i}") + for i in range(num_rest): + attributes.append(f"f_rest_{i}") + attributes.append("opacity") + for i in range(3): + attributes.append(f"scale_{i}") + for i in range(4): + attributes.append(f"rot_{i}") + return attributes + + +def export_ply( + means: Float[Tensor, "gaussian 3"], + scales: Float[Tensor, "gaussian 3"], + rotations: Float[Tensor, "gaussian 4"], + harmonics: Float[Tensor, "gaussian 3 d_sh"], + opacities: Float[Tensor, "gaussian"], + path: Path, +): + path = Path(path) + # Shift the scene so that the median Gaussian is at the origin. + means = means - means.median(dim=0).values + + # Rescale the scene so that most Gaussians are within range [-1, 1]. + scale_factor = means.abs().quantile(0.95, dim=0).max() + means = means / scale_factor + scales = scales / scale_factor + scales = scales * 4.0 + scales = torch.clamp(scales, 0, 0.0075) + + # Define a rotation that makes +Z be the world up vector. + # rotation = [ + # [0, 0, 1], + # [-1, 0, 0], + # [0, -1, 0], + # ] + rotation = [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ] + rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device) + + # The Polycam viewer seems to start at a 45 degree angle. Since we want to be + # looking directly at the object, we compose a 45 degree rotation onto the above + # rotation. + # adjustment = torch.tensor( + # R.from_rotvec([0, 0, -45], True).as_matrix(), + # dtype=torch.float32, + # device=means.device, + # ) + # rotation = adjustment @ rotation + + # We also want to see the scene in camera space (as the default view). We therefore + # compose the w2c rotation onto the above rotation. + # rotation = rotation @ extrinsics[:3, :3].inverse() + + # Apply the rotation to the means (Gaussian positions). + means = einsum(rotation, means, "i j, ... j -> ... i") + + # Apply the rotation to the Gaussian rotations. + rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() + rotations = rotation.detach().cpu().numpy() @ rotations + rotations = R.from_matrix(rotations).as_quat() + x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") + rotations = np.stack((w, x, y, z), axis=-1) + + # Since our axes are swizzled for the spherical harmonics, we only export the DC + # band. + harmonics_view_invariant = harmonics + + dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)] + elements = np.empty(means.shape[0], dtype=dtype_full) + attributes = ( + means.detach().cpu().numpy(), + torch.zeros_like(means).detach().cpu().numpy(), + harmonics_view_invariant.detach().cpu().contiguous().numpy(), + opacities.detach().cpu().numpy(), + scales.log().detach().cpu().numpy(), + rotations, + ) + attributes = np.concatenate(attributes, axis=1) + elements[:] = list(map(tuple, attributes)) + path.parent.mkdir(exist_ok=True, parents=True) + PlyData([PlyElement.describe(elements, "vertex")]).write(path) + + +def save_ply(outputs, path, num_gauss=3): + pad = 32 + + def crop_r(t): + h, w = 256, 384 + H = h + pad * 2 + W = w + pad * 2 + t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W) + t = t[..., pad:H-pad, pad:W-pad] + t = rearrange(t, "b c h w -> b c (h w)") + return t + + def crop(t): + h, w = 256, 384 + H = h + pad * 2 + W = w + pad * 2 + t = t[..., pad:H-pad, pad:W-pad] + return t + + # import pdb + # pdb.set_trace() + means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3] + scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0] + rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0] + opacities = rearrange(crop(outputs[('gauss_opacity', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0] + harmonics = rearrange(crop(outputs[('gauss_features_dc', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0] + + export_ply( + means, + scales, + rotations, + harmonics, + opacities, + path + ) \ No newline at end of file diff --git a/pre-requirements.txt b/pre-requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..11e33b0618d21e50d42e80e939c5441017e7844a --- /dev/null +++ b/pre-requirements.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +torch +torchvision +torchaudio +xformers==0.0.25.post1 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..401f350fb13e757170daa0d678ed1c2b8edd297a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +einops +huggingface-hub>=0.22.0 +opencv-python +imageio +matplotlib +safetensors +scipy +timm +tqdm +wandb +neptune +scikit-image +plyfile +omegaconf +jaxtyping +gradio +spaces +numpy +torch +torchvision +torchaudio +xformers \ No newline at end of file