import dataclasses
import os
from typing import Any, Callable, Optional, Tuple, List

import numpy as np
import torch
import torch.nn.functional as F
from absl import flags
import gin

from internal import utils

gin.add_config_file_search_path('configs/')

# Map of module name -> callables to expose to gin config files.
_CONFIGURABLES = {
    'torch': [
        torch.reciprocal, torch.log, torch.log1p, torch.exp, torch.sqrt,
        torch.square
    ],
}

# NOTE: the original loop rebound its own source dict name
# (`for module, configurables in configurables.items()`); distinct names are
# used here so the registry is not clobbered while it is being iterated.
for _module, _fns in _CONFIGURABLES.items():
    for _fn in _fns:
        gin.config.external_configurable(_fn, module=_module)


@gin.configurable()
@dataclasses.dataclass
class Config:
    """Configuration flags for everything."""
    # `seed` was previously unannotated, which made it a plain class attribute
    # rather than a dataclass field (and therefore not bindable via gin like
    # every other option). Annotated so it behaves like the rest.
    seed: int = 0
    dataset_loader: str = 'llff'  # The type of dataset loader to use.
    batching: str = 'all_images'  # Batch composition, [single_image, all_images].
    batch_size: int = 2 ** 16  # The number of rays/pixels in each batch.
    patch_size: int = 1  # Resolution of patches sampled for training batches.
    factor: int = 4  # The downsample factor of images, 0 for no downsampling.
    multiscale: bool = False  # use multiscale data for training.
    multiscale_levels: int = 4  # number of multiscale levels.
    # ordering (affects heldout test set).
    forward_facing: bool = False  # Set to True for forward-facing LLFF captures.
    render_path: bool = False  # If True, render a path. Used only by LLFF.
    llffhold: int = 8  # Use every Nth image for the test set. Used only by LLFF.
    # If true, use all input images for training.
    llff_use_all_images_for_training: bool = False
    llff_use_all_images_for_testing: bool = False
    use_tiffs: bool = False  # If True, use 32-bit TIFFs. Used only by Blender.
    compute_disp_metrics: bool = False  # If True, load and compute disparity MSE.
    compute_normal_metrics: bool = False  # If True, load and compute normal MAE.
    disable_multiscale_loss: bool = False  # If True, disable multiscale loss.
    randomized: bool = True  # Use randomized stratified sampling.
    near: float = 2.  # Near plane distance.
    far: float = 6.  # Far plane distance.
    exp_name: str = "test"  # experiment name
    data_dir: Optional[str] = "/SSD_DISK/datasets/360_v2/bicycle"  # Input data directory.
    vocab_tree_path: Optional[str] = None  # Path to vocab tree for COLMAP.
    render_chunk_size: int = 65536  # Chunk size for whole-image renderings.
    num_showcase_images: int = 5  # The number of test-set images to showcase.
    deterministic_showcase: bool = True  # If True, showcase the same images.
    vis_num_rays: int = 16  # The number of rays to visualize.
    # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage.
    vis_decimate: int = 0

    # Only used by train.py:
    max_steps: int = 25000  # The number of optimization steps.
    early_exit_steps: Optional[int] = None  # Early stopping, for debugging.
    checkpoint_every: int = 5000  # The number of steps to save a checkpoint.
    resume_from_checkpoint: bool = True  # whether to resume from checkpoint.
    checkpoints_total_limit: int = 1
    # If True, scale gradients as in https://gradient-scaling.github.io/.
    gradient_scaling: bool = False
    print_every: int = 100  # The number of steps between reports to tensorboard.
    train_render_every: int = 500  # Steps between test set renders when training
    data_loss_type: str = 'charb'  # What kind of loss to use ('mse' or 'charb').
    charb_padding: float = 0.001  # The padding used for Charbonnier loss.
    data_loss_mult: float = 1.0  # Mult for the finest data term in the loss.
    data_coarse_loss_mult: float = 0.  # Multiplier for the coarser data terms.
    interlevel_loss_mult: float = 0.0  # Mult. for the loss on the proposal MLP.
    # Mult. for the anti-aliased interlevel loss on the proposal MLP.
    anti_interlevel_loss_mult: float = 0.01
    # Pulse widths used by the anti-aliased interlevel loss, one per proposal
    # level. Previously an unannotated mutable class attribute shared by all
    # Config instances; now a proper per-instance dataclass field.
    pulse_width: List[float] = dataclasses.field(
        default_factory=lambda: [0.03, 0.003])
    orientation_loss_mult: float = 0.0  # Multiplier on the orientation loss.
    orientation_coarse_loss_mult: float = 0.0  # Coarser orientation loss weights.
    # What that loss is imposed on, options are 'normals' or 'normals_pred'.
    orientation_loss_target: str = 'normals_pred'
    predicted_normal_loss_mult: float = 0.0  # Mult. on the predicted normal loss.
    # Mult. on the coarser predicted normal loss.
    predicted_normal_coarse_loss_mult: float = 0.0
    hash_decay_mults: float = 0.1
    lr_init: float = 0.01  # The initial learning rate.
    lr_final: float = 0.001  # The final learning rate.
    lr_delay_steps: int = 5000  # The number of "warmup" learning steps.
    lr_delay_mult: float = 1e-8  # How severe the "warmup" should be.
    adam_beta1: float = 0.9  # Adam's beta1 hyperparameter.
    adam_beta2: float = 0.99  # Adam's beta2 hyperparameter.
    adam_eps: float = 1e-15  # Adam's epsilon hyperparameter.
    grad_max_norm: float = 0.  # Gradient clipping magnitude, disabled if == 0.
    grad_max_val: float = 0.  # Gradient clipping value, disabled if == 0.
    distortion_loss_mult: float = 0.005  # Multiplier on the distortion loss.
    opacity_loss_mult: float = 0.  # Multiplier on the opacity loss.

    # Only used by eval.py:
    eval_only_once: bool = True  # If True evaluate the model only once, ow loop.
    eval_save_output: bool = True  # If True save predicted images to disk.
    eval_save_ray_data: bool = False  # If True save individual ray traces.
    eval_render_interval: int = 1  # The interval between images saved to disk.
    eval_dataset_limit: int = np.iinfo(np.int32).max  # Num test images to eval.
    eval_quantize_metrics: bool = True  # If True, run metrics on 8-bit images.
    eval_crop_borders: int = 0  # Ignore c border pixels in eval (x[c:-c, c:-c]).

    # Only used by render.py
    render_video_fps: int = 60  # Framerate in frames-per-second.
    render_video_crf: int = 18  # Constant rate factor for ffmpeg video quality.
    render_path_frames: int = 120  # Number of frames in render path.
    z_variation: float = 0.  # How much height variation in render path.
    z_phase: float = 0.  # Phase offset for height variation in render path.
    render_dist_percentile: float = 0.5  # How much to trim from near/far planes.
    render_dist_curve_fn: Callable[..., Any] = np.log  # How depth is curved.
    render_path_file: Optional[str] = None  # Numpy render pose file to load.
    render_resolution: Optional[Tuple[int, int]] = None  # Render resolution, as
    # (width, height).
    render_focal: Optional[float] = None  # Render focal length.
    render_camtype: Optional[str] = None  # 'perspective', 'fisheye', or 'pano'.
    render_spherical: bool = False  # Render spherical 360 panoramas.
    render_save_async: bool = True  # Save to CNS using a separate thread.
    render_spline_keyframes: Optional[str] = None  # Text file containing names of
    # images to be used as spline
    # keyframes, OR directory
    # containing those images.
    render_spline_n_interp: int = 30  # Num. frames to interpolate per keyframe.
    render_spline_degree: int = 5  # Polynomial degree of B-spline interpolation.
    render_spline_smoothness: float = .03  # B-spline smoothing factor, 0 for
    # exact interpolation of keyframes.
    # Interpolate per-frame exposure value from spline keyframes.
    render_spline_interpolate_exposure: bool = False

    # Flags for raw datasets.
    rawnerf_mode: bool = False  # Load raw images and train in raw color space.
    exposure_percentile: float = 97.  # Image percentile to expose as white.
    num_border_pixels_to_mask: int = 0  # During training, discard N-pixel border
    # around each input image.
    apply_bayer_mask: bool = False  # During training, apply Bayer mosaic mask.
    autoexpose_renders: bool = False  # During rendering, autoexpose each image.
    # For raw test scenes, use affine raw-space color correction.
    eval_raw_affine_cc: bool = False
    zero_glo: bool = False

    # marching cubes
    valid_weight_thresh: float = 0.05
    isosurface_threshold: float = 20
    mesh_voxels: int = 512 ** 3
    visibility_resolution: int = 512
    mesh_radius: float = 1.0  # mesh radius * 2 = in contract space
    mesh_max_radius: float = 10.0  # in world space
    std_value: float = 0.0  # std of the sampled points
    compute_visibility: bool = False
    extract_visibility: bool = True
    decimate_target: int = -1
    vertex_color: bool = True
    vertex_projection: bool = True

    # tsdf
    tsdf_radius: float = 2.0
    tsdf_resolution: int = 512
    truncation_margin: float = 5.0
    tsdf_max_radius: float = 10.0  # in world space


def define_common_flags():
    """Define the absl flags used by both train.py and eval.py."""
    flags.DEFINE_string('mode', None, 'Required by GINXM, not used.')
    flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.')
    flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
    flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.')


def load_config():
    """Load the config, and optionally checkpoint it.

    Returns:
        A gin-bound `Config` instance parsed from the --gin_configs files and
        --gin_bindings command-line overrides.
    """
    gin.parse_config_files_and_bindings(
        flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True)
    config = Config()
    return config