|
import dataclasses |
|
import os |
|
from typing import Any, Callable, Optional, Tuple, List |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from absl import flags |
|
import gin |
|
from internal import utils |
|
|
|
gin.add_config_file_search_path('configs/')

# Functions made available to gin config files, keyed by the module name they
# are registered under.
configurables = {
    'torch': [torch.reciprocal, torch.log, torch.log1p, torch.exp, torch.sqrt,
              torch.square],
}

# Register every function with gin. The loop variables deliberately differ from
# `configurables`: reusing that name (as before) rebound the module-level dict
# to the last list of functions once the loop finished.
for module, module_fns in configurables.items():
    for fn in module_fns:
        gin.config.external_configurable(fn, module=module)
|
|
|
|
|
@gin.configurable()
@dataclasses.dataclass
class Config:
    """Configuration flags for everything.

    Every setting is an annotated dataclass field. Each one can therefore be
    passed as a keyword argument to `Config(...)` or bound through gin
    (`Config.<name> = ...` in a config file or a `--gin_bindings` flag).
    """

    # Annotated so it is a real dataclass field (and hence gin-configurable)
    # instead of a plain class attribute.
    seed: int = 0
    dataset_loader: str = 'llff'
    batching: str = 'all_images'
    batch_size: int = 2 ** 16
    patch_size: int = 1
    factor: int = 4
    multiscale: bool = False
    multiscale_levels: int = 4

    forward_facing: bool = False
    render_path: bool = False
    llffhold: int = 8

    llff_use_all_images_for_training: bool = False
    llff_use_all_images_for_testing: bool = False
    use_tiffs: bool = False
    compute_disp_metrics: bool = False
    compute_normal_metrics: bool = False
    disable_multiscale_loss: bool = False
    randomized: bool = True
    near: float = 2.
    far: float = 6.
    exp_name: str = "test"
    # NOTE(review): machine-specific absolute default; presumably overridden
    # per run via gin — confirm before relying on it.
    data_dir: Optional[str] = "/SSD_DISK/datasets/360_v2/bicycle"
    vocab_tree_path: Optional[str] = None
    render_chunk_size: int = 65536
    num_showcase_images: int = 5
    deterministic_showcase: bool = True
    vis_num_rays: int = 16

    vis_decimate: int = 0

    max_steps: int = 25000
    early_exit_steps: Optional[int] = None
    checkpoint_every: int = 5000
    resume_from_checkpoint: bool = True
    checkpoints_total_limit: int = 1
    gradient_scaling: bool = False
    print_every: int = 100
    train_render_every: int = 500
    data_loss_type: str = 'charb'
    charb_padding: float = 0.001
    data_loss_mult: float = 1.0
    data_coarse_loss_mult: float = 0.
    interlevel_loss_mult: float = 0.0
    anti_interlevel_loss_mult: float = 0.01
    # Built per instance: dataclasses reject a bare list default on an
    # annotated field, and an unannotated list was shared by all instances
    # and invisible to gin.
    pulse_width: List[float] = dataclasses.field(
        default_factory=lambda: [0.03, 0.003])
    orientation_loss_mult: float = 0.0
    orientation_coarse_loss_mult: float = 0.0

    orientation_loss_target: str = 'normals_pred'
    predicted_normal_loss_mult: float = 0.0

    predicted_normal_coarse_loss_mult: float = 0.0
    hash_decay_mults: float = 0.1

    lr_init: float = 0.01
    lr_final: float = 0.001
    lr_delay_steps: int = 5000
    lr_delay_mult: float = 1e-8
    adam_beta1: float = 0.9
    adam_beta2: float = 0.99
    adam_eps: float = 1e-15
    grad_max_norm: float = 0.
    grad_max_val: float = 0.
    distortion_loss_mult: float = 0.005
    opacity_loss_mult: float = 0.

    eval_only_once: bool = True
    eval_save_output: bool = True
    eval_save_ray_data: bool = False
    eval_render_interval: int = 1
    eval_dataset_limit: int = np.iinfo(np.int32).max
    eval_quantize_metrics: bool = True
    eval_crop_borders: int = 0

    render_video_fps: int = 60
    render_video_crf: int = 18
    render_path_frames: int = 120
    z_variation: float = 0.
    z_phase: float = 0.
    render_dist_percentile: float = 0.5
    render_dist_curve_fn: Callable[..., Any] = np.log
    render_path_file: Optional[str] = None
    render_resolution: Optional[Tuple[int, int]] = None

    render_focal: Optional[float] = None
    render_camtype: Optional[str] = None
    render_spherical: bool = False
    render_save_async: bool = True

    render_spline_keyframes: Optional[str] = None

    render_spline_n_interp: int = 30
    render_spline_degree: int = 5
    render_spline_smoothness: float = .03

    render_spline_interpolate_exposure: bool = False

    rawnerf_mode: bool = False
    exposure_percentile: float = 97.
    num_border_pixels_to_mask: int = 0

    apply_bayer_mask: bool = False
    autoexpose_renders: bool = False

    eval_raw_affine_cc: bool = False

    zero_glo: bool = False

    valid_weight_thresh: float = 0.05
    isosurface_threshold: float = 20
    mesh_voxels: int = 512 ** 3
    visibility_resolution: int = 512
    mesh_radius: float = 1.0
    mesh_max_radius: float = 10.0
    std_value: float = 0.0
    compute_visibility: bool = False
    extract_visibility: bool = True
    decimate_target: int = -1
    vertex_color: bool = True
    vertex_projection: bool = True

    tsdf_radius: float = 2.0
    tsdf_resolution: int = 512
    truncation_margin: float = 5.0
    tsdf_max_radius: float = 10.0
|
|
|
|
|
def define_common_flags():
    """Register the absl flags shared by the entry points.

    Defines `--mode` and `--base_folder` (accepted but unused), followed by
    the repeatable `--gin_bindings` and `--gin_configs` flags.
    """
    # Placeholder flags: present only because GINXM requires them.
    for placeholder in ('mode', 'base_folder'):
        flags.DEFINE_string(placeholder, None, 'Required by GINXM, not used.')

    # Repeatable flags consumed when the gin configuration is parsed.
    flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
    flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.')
|
|
|
|
|
def load_config():
    """Parse gin configs and bindings from the command line and build a Config.

    Reads the `--gin_configs` files and `--gin_bindings` strings (registered by
    `define_common_flags`), so absl flags must already have been parsed. The
    config is not saved or checkpointed here.

    Returns:
        A `Config` whose fields reflect the parsed gin bindings.
    """
    # skip_unknown=True: configurables named in the config files that are not
    # registered with gin are ignored instead of raising an error.
    gin.parse_config_files_and_bindings(
        flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True)
    return Config()
|
|