|
import abc |
|
import copy |
|
import json |
|
import os |
|
import cv2 |
|
from internal import camera_utils |
|
from internal import configs |
|
from internal import image as lib_image |
|
from internal import raw_utils |
|
from internal import utils |
|
from collections import defaultdict |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import torch |
|
from tqdm import tqdm |
|
|
|
import sys |
|
|
|
sys.path.insert(0, 'internal/pycolmap') |
|
sys.path.insert(0, 'internal/pycolmap/pycolmap') |
|
import pycolmap |
|
|
|
|
|
def load_dataset(split, train_dir, config: configs.Config):
  """Loads a split of a dataset using the data_loader specified by `config`.

  Args:
    split: str, which split to load ('train' or 'test', see utils.DataSplit).
    train_dir: str, root directory of the dataset on disk.
    config: configs.Config, user-specified config parameters.

  Returns:
    An instance of the Dataset subclass selected by `config.dataset_loader`.

  Raises:
    ValueError: if `config.dataset_loader` is not a recognized loader name
      for the current `config.multiscale` setting.
  """
  if config.multiscale:
    # Only the LLFF-style loader has a multiscale variant.
    dataset_dict = {
        'llff': MultiLLFF,
    }
  else:
    dataset_dict = {
        'blender': Blender,
        'llff': LLFF,
        'tat_nerfpp': TanksAndTemplesNerfPP,
        'tat_fvs': TanksAndTemplesFVS,
        'dtu': DTU,
    }
  if config.dataset_loader not in dataset_dict:
    # Fail with an actionable message instead of a bare KeyError.
    raise ValueError(
        f'Unknown dataset_loader {config.dataset_loader!r} '
        f'(multiscale={config.multiscale}); valid choices: '
        f'{sorted(dataset_dict)}.')
  return dataset_dict[config.dataset_loader](split, train_dir, config)
|
|
|
|
|
class NeRFSceneManager(pycolmap.SceneManager):
  """COLMAP pose loader.

  Minor NeRF-specific extension to the third_party Python COLMAP loader:
  google3/third_party/py/pycolmap/scene_manager.py
  """

  def process(self):
    """Applies NeRF-specific postprocessing to the loaded pose data.

    Returns:
      a tuple [image_names, poses, pixtocam, distortion_params, camtype].
      image_names: contains the only the basename of the images.
      poses: [N, 4, 4] array containing the camera to world matrices.
      pixtocam: [N, 3, 3] array containing the camera to pixel space matrices.
      distortion_params: mapping of distortion param name to distortion
        parameters. Cameras share intrinsics. Valid keys are k1, k2, p1 and p2.
      camtype: camera_utils.ProjectionType, perspective or fisheye projection.

    Raises:
      ValueError: if the COLMAP camera model is not one of the supported
        types (SIMPLE_PINHOLE, PINHOLE, SIMPLE_RADIAL, RADIAL, OPENCV,
        OPENCV_FISHEYE).
    """

    self.load_cameras()
    self.load_images()

    # Assume shared intrinsics between all cameras (COLMAP camera ids are
    # 1-based, so camera 1 is the first one).
    cam = self.cameras[1]

    # Extract focal lengths and principal point parameters.
    fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
    pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))

    # Extract extrinsic matrices in world-to-camera format.
    imdata = self.images
    w2c_mats = []
    bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
    for k in imdata:
      im = imdata[k]
      rot = im.R()
      trans = im.tvec.reshape(3, 1)
      w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
      w2c_mats.append(w2c)
    w2c_mats = np.stack(w2c_mats, axis=0)

    # Convert extrinsics to camera-to-world.
    c2w_mats = np.linalg.inv(w2c_mats)
    poses = c2w_mats[:, :3, :4]

    # Image names from COLMAP, in the same (dict iteration) order as poses.
    names = [imdata[k].name for k in imdata]

    # Flip the y and z axes to switch from the COLMAP camera frame to the
    # NeRF convention.
    poses = poses @ np.diag([1, -1, -1, 1])

    # Map the COLMAP camera model to distortion parameters and projection
    # type. Previously the SIMPLE_RADIAL test began a *new* `if` chain and
    # unsupported models left `params`/`camtype` unbound (NameError); the
    # chain is now a single if/elif with an explicit error fallback.
    type_ = cam.camera_type

    if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
      # No lens distortion for pinhole models.
      params = None
      camtype = camera_utils.ProjectionType.PERSPECTIVE

    elif type_ == 1 or type_ == 'PINHOLE':
      params = None
      camtype = camera_utils.ProjectionType.PERSPECTIVE

    elif type_ == 2 or type_ == 'SIMPLE_RADIAL':
      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
      params['k1'] = cam.k1
      camtype = camera_utils.ProjectionType.PERSPECTIVE

    elif type_ == 3 or type_ == 'RADIAL':
      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
      params['k1'] = cam.k1
      params['k2'] = cam.k2
      camtype = camera_utils.ProjectionType.PERSPECTIVE

    elif type_ == 4 or type_ == 'OPENCV':
      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
      params['k1'] = cam.k1
      params['k2'] = cam.k2
      params['p1'] = cam.p1
      params['p2'] = cam.p2
      camtype = camera_utils.ProjectionType.PERSPECTIVE

    elif type_ == 5 or type_ == 'OPENCV_FISHEYE':
      params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}
      params['k1'] = cam.k1
      params['k2'] = cam.k2
      params['k3'] = cam.k3
      params['k4'] = cam.k4
      camtype = camera_utils.ProjectionType.FISHEYE

    else:
      raise ValueError(f'Unsupported COLMAP camera type: {type_!r}')

    return names, poses, pixtocam, params, camtype
|
|
|
|
|
def load_blender_posedata(data_dir, split=None):
  """Load poses from `transforms.json` file, as used in Blender/NGP datasets."""
  suffix = '' if split is None else f'_{split}'
  pose_file = os.path.join(data_dir, f'transforms{suffix}.json')
  with utils.open_file(pose_file, 'r') as fp:
    meta = json.load(fp)

  # Keep only frames whose image file actually exists on disk.
  names = []
  poses = []
  for frame in meta['frames']:
    filepath = os.path.join(data_dir, frame['file_path'])
    if utils.file_exists(filepath):
      names.append(frame['file_path'].split('/')[-1])
      poses.append(np.array(frame['transform_matrix'], dtype=np.float32))
  poses = np.stack(poses, axis=0)

  # Intrinsics: principal point defaults to the image center, focal lengths
  # either come directly from the file or are derived from the field of view.
  w = meta['w']
  h = meta['h']
  cx = meta.get('cx', w / 2.)
  cy = meta.get('cy', h / 2.)
  fx = (meta['fl_x'] if 'fl_x' in meta
        else 0.5 * w / np.tan(0.5 * float(meta['camera_angle_x'])))
  fy = (meta['fl_y'] if 'fl_y' in meta
        else 0.5 * h / np.tan(0.5 * float(meta['camera_angle_y'])))
  pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))

  # Distortion: only build a parameter dict if at least one coefficient is
  # present; missing coefficients default to zero.
  coeffs = ['k1', 'k2', 'p1', 'p2']
  if any(c in meta for c in coeffs):
    params = {c: meta.get(c, 0.) for c in coeffs}
  else:
    params = None
  camtype = camera_utils.ProjectionType.PERSPECTIVE
  return names, poses, pixtocam, params, camtype
|
|
|
|
|
class Dataset(torch.utils.data.Dataset):
  """Dataset Base Class.

  Base class for a NeRF dataset. Creates batches of ray and color data used for
  training or rendering a NeRF model.

  Each subclass is responsible for loading images and camera poses from disk by
  implementing the _load_renderings() method. This data is used to generate
  train and test batches of ray + color data for feeding through the NeRF model.
  The ray parameters are calculated in _generate_rays().

  The public interface mimics the behavior of a standard machine learning
  pipeline dataset provider that can provide infinite batches of data to the
  training/testing pipelines without exposing any details of how the batches
  are loaded/created or how this is parallelized. The initializer runs all
  setup, including data loading from disk using _load_renderings().

  This is a map-style torch.utils.data.Dataset: each __getitem__ (or
  collate_fn, when used with a DataLoader) call dispatches to _next_train for
  training splits (a batch of randomly sampled rays/patches) or _next_test
  otherwise (all rays of one full image per item).

  Attributes:
    alphas: np.ndarray, optional array of alpha channel data.
    cameras: tuple summarizing all camera extrinsic/intrinsic/distortion params.
    camtoworlds: np.ndarray, a list of extrinsic camera pose matrices.
    camtype: camera_utils.ProjectionType, fisheye or perspective camera.
    data_dir: str, location of the dataset on disk.
    disp_images: np.ndarray, optional array of disparity (inverse depth) data.
    distortion_params: dict, the camera distortion model parameters.
    exposures: optional per-image exposure value (shutter * ISO / 1000).
    far: float, far plane value for rays.
    focal: float, focal length from camera intrinsics.
    height: int, height of images.
    images: np.ndarray, array of RGB image data.
    metadata: dict, optional metadata for raw datasets.
    near: float, near plane value for rays.
    normal_images: np.ndarray, optional array of surface normal vector data.
    pixtocams: np.ndarray, one or a list of inverse intrinsic camera matrices.
    pixtocam_ndc: np.ndarray, the inverse intrinsic matrix used for NDC space.
    poses: np.ndarray, optional array of auxiliary camera pose data.
    rays: utils.Rays, ray data for every pixel in the dataset.
    render_exposures: optional list of exposure values for the render path.
    render_path: bool, indicates if a smooth camera path should be generated.
    size: int, number of images in the dataset.
    split: str, indicates if this is a "train" or "test" dataset.
    width: int, width of images.
  """

  def __init__(self,
               split: str,
               data_dir: str,
               config: configs.Config):
    super().__init__()

    # Sampling parameters. The per-process batch size is the global batch
    # divided across distributed processes.
    self._patch_size = max(config.patch_size, 1)
    self._batch_size = config.batch_size // config.world_size
    if self._patch_size ** 2 > self._batch_size:
      raise ValueError(f'Patch size {self._patch_size}^2 too large for ' +
                       f'per-process batch size {self._batch_size}')
    self._batching = utils.BatchingMethod(config.batching)
    self._use_tiffs = config.use_tiffs
    self._load_disps = config.compute_disp_metrics
    self._load_normals = config.compute_normal_metrics
    self._num_border_pixels_to_mask = config.num_border_pixels_to_mask
    self._apply_bayer_mask = config.apply_bayer_mask
    self._render_spherical = False

    self.config = config
    self.global_rank = config.global_rank
    self.world_size = config.world_size
    self.split = utils.DataSplit(split)
    self.data_dir = data_dir
    self.near = config.near
    self.far = config.far
    self.render_path = config.render_path
    self.distortion_params = None
    self.disp_images = None
    self.normal_images = None
    self.alphas = None
    self.poses = None
    self.pixtocam_ndc = None
    self.metadata = None
    self.camtype = camera_utils.ProjectionType.PERSPECTIVE
    self.exposures = None
    self.render_exposures = None

    # These attributes must be set by _load_renderings() in every subclass
    # (see the abstract method's docstring for required shapes).
    self.images: np.ndarray = None
    self.camtoworlds: np.ndarray = None
    self.pixtocams: np.ndarray = None
    self.height: int = None
    self.width: int = None

    # Load data from disk using provided config parameters.
    self._load_renderings(config)

    if self.render_path:
      # Optionally override cameras/resolution/focal/projection for rendering
      # a smooth camera path instead of the loaded image poses.
      if config.render_path_file is not None:
        with utils.open_file(config.render_path_file, 'rb') as fp:
          render_poses = np.load(fp)
        self.camtoworlds = render_poses
      if config.render_resolution is not None:
        self.width, self.height = config.render_resolution
      if config.render_focal is not None:
        self.focal = config.render_focal
      if config.render_camtype is not None:
        if config.render_camtype == 'pano':
          # 360-degree spherical (panorama) rendering.
          self._render_spherical = True
        else:
          self.camtype = camera_utils.ProjectionType(config.render_camtype)

      # Render paths use ideal pinhole intrinsics with no lens distortion.
      self.distortion_params = None
      self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
                                                 self.height)

    self._n_examples = self.camtoworlds.shape[0]

    self.cameras = (self.pixtocams,
                    self.camtoworlds,
                    self.distortion_params,
                    self.pixtocam_ndc)

    # Pick the batch sampler: random rays for training, one full image per
    # item otherwise (including visibility computation).
    if self.split == utils.DataSplit.TRAIN and not config.compute_visibility:
      self._next_fn = self._next_train
    else:
      self._next_fn = self._next_test

  @property
  def size(self):
    # Number of cameras/images in the dataset.
    return self._n_examples

  def __len__(self):
    # Training uses a fixed nominal epoch length of 1000 items (each item is
    # a freshly sampled random batch); otherwise iterate each example once.
    if self.split == utils.DataSplit.TRAIN and not self.config.compute_visibility:
      return 1000
    else:
      return self._n_examples

  @abc.abstractmethod
  def _load_renderings(self, config):
    """Load images and poses from disk.

    Args:
      config: utils.Config, user-specified config parameters.
    In inherited classes, this method must set the following public attributes:
      images: [N, height, width, 3] array for RGB images.
      disp_images: [N, height, width] array for depth data (optional).
      normal_images: [N, height, width, 3] array for normals (optional).
      camtoworlds: [N, 3, 4] array of extrinsic pose matrices.
      poses: [..., 3, 4] array of auxiliary pose data (optional).
      pixtocams: [N, 3, 4] array of inverse intrinsic matrices.
      distortion_params: dict, camera lens distortion model parameters.
      height: int, height of images.
      width: int, width of images.
      focal: float, focal length to use for ideal pinhole rendering.
    """

  def _make_ray_batch(self,
                      pix_x_int,
                      pix_y_int,
                      cam_idx,
                      lossmult=None
                      ):
    """Creates ray data batch from pixel coordinates and camera indices.

    All arguments must have broadcastable shapes. If the arguments together
    broadcast to a shape [a, b, c, ..., z] then the returned utils.Rays object
    will have array attributes with shape [a, b, c, ..., z, N], where N=3 for
    3D vectors and N=1 for per-ray scalar attributes.

    Args:
      pix_x_int: int array, x coordinates of image pixels.
      pix_y_int: int array, y coordinates of image pixels.
      cam_idx: int or int array, camera indices.
      lossmult: float array, weight to apply to each ray when computing loss fn.

    Returns:
      A dict mapping from strings utils.Rays or arrays of image data.
      This is the batch provided for one NeRF train or test iteration.
    """

    broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
    ray_kwargs = {
        'lossmult': broadcast_scalar(1.) if lossmult is None else lossmult,
        'near': broadcast_scalar(self.near),
        'far': broadcast_scalar(self.far),
        'cam_idx': broadcast_scalar(cam_idx),
    }

    if self.metadata is not None:
      # Per-camera exposure metadata (raw datasets). When rendering a path,
      # camera index 0's metadata is reused for every frame.
      for key in ['exposure_idx', 'exposure_values']:
        idx = 0 if self.render_path else cam_idx
        ray_kwargs[key] = broadcast_scalar(self.metadata[key][idx])
    if self.exposures is not None:
      idx = 0 if self.render_path else cam_idx
      ray_kwargs['exposure_values'] = broadcast_scalar(self.exposures[idx])
    if self.render_path and self.render_exposures is not None:
      # Render-path exposures take precedence over EXIF-derived ones.
      ray_kwargs['exposure_values'] = broadcast_scalar(
          self.render_exposures[cam_idx])

    pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs)

    # Compute the ray origins/directions for the selected pixels (numpy, CPU).
    batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype)
    # Per-ray camera direction: negated third column of each ray's
    # camera-to-world rotation matrix.
    batch['cam_dirs'] = -self.camtoworlds[ray_kwargs['cam_idx'][..., 0]][..., :3, 2]

    if not self.render_path:
      # Attach ground-truth pixel data for supervision/metrics.
      batch['rgb'] = self.images[cam_idx, pix_y_int, pix_x_int]
      if self._load_disps:
        batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int]
      if self._load_normals:
        batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int]
        batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int]
    # Convert everything to float torch tensors; keep None entries as None.
    return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}

  def _next_train(self, item):
    """Sample next training batch (random rays)."""
    # All images are assumed to share one resolution, so the same
    # width/height bounds apply to every sampled pixel coordinate.
    num_patches = self._batch_size // self._patch_size ** 2
    lower_border = self._num_border_pixels_to_mask
    upper_border = self._num_border_pixels_to_mask + self._patch_size - 1
    # Random top-left x coordinates for each patch.
    pix_x_int = np.random.randint(lower_border, self.width - upper_border,
                                  (num_patches, 1, 1))
    # Random top-left y coordinates for each patch.
    pix_y_int = np.random.randint(lower_border, self.height - upper_border,
                                  (num_patches, 1, 1))

    # Add per-patch pixel offsets; broadcasts to
    # (num_patches, _patch_size, _patch_size).
    patch_dx_int, patch_dy_int = camera_utils.pixel_coordinates(
        self._patch_size, self._patch_size)
    pix_x_int = pix_x_int + patch_dx_int
    pix_y_int = pix_y_int + patch_dy_int

    # Random camera indices: one per patch, or a single shared camera.
    if self._batching == utils.BatchingMethod.ALL_IMAGES:
      cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))
    else:
      cam_idx = np.random.randint(0, self._n_examples, (1,))

    if self._apply_bayer_mask:
      # Compute the per-pixel Bayer mosaic loss weights (raw datasets).
      lossmult = raw_utils.pixels_to_bayer_mask(pix_x_int, pix_y_int)
    else:
      lossmult = None

    return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx,
                                lossmult=lossmult)

  def generate_ray_batch(self, cam_idx: int):
    """Generate ray batch for a specified camera in the dataset."""
    if self._render_spherical:
      # Panorama rendering bypasses the pinhole camera model entirely.
      camtoworld = self.camtoworlds[cam_idx]
      rays = camera_utils.cast_spherical_rays(
          camtoworld, self.height, self.width, self.near, self.far)
      return rays
    else:
      # One ray for every pixel of the image.
      pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
          self.width, self.height)
      return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx)

  def _next_test(self, item):
    """Sample next test batch (one full image)."""
    return self.generate_ray_batch(item)

  def collate_fn(self, item):
    # DataLoader hands us a list; each batch is fully built by _next_fn, so
    # only the first (and only) index is used.
    return self._next_fn(item[0])

  def __getitem__(self, item):
    return self._next_fn(item)
|
|
|
|
|
class Blender(Dataset):
  """Blender Dataset."""

  def _load_renderings(self, config):
    """Load images from disk.

    Reads `transforms_{split}.json` and the per-frame image files it lists,
    optionally with disparity/normal maps, and sets the camera attributes
    required by the base class.

    Raises:
      ValueError: if config.render_path is set (unsupported for Blender).
    """
    if config.render_path:
      raise ValueError('render_path cannot be used for the blender dataset.')
    pose_file = os.path.join(self.data_dir, f'transforms_{self.split.value}.json')
    with utils.open_file(pose_file, 'r') as fp:
      meta = json.load(fp)
    images = []
    disp_images = []
    normal_images = []
    cams = []
    for idx, frame in enumerate(tqdm(meta['frames'], desc='Loading Blender dataset', disable=self.global_rank != 0, leave=False)):
      fprefix = os.path.join(self.data_dir, frame['file_path'])

      # `fprefix=fprefix` binds the current loop value as a default,
      # avoiding the late-binding closure pitfall.
      def get_img(f, fprefix=fprefix):
        image = utils.load_img(fprefix + f)
        if config.factor > 1:
          image = lib_image.downsample(image, config.factor)
        return image

      if self._use_tiffs:
        channels = [get_img(f'_{ch}.tiff') for ch in ['R', 'G', 'B', 'A']]
        # TIFF channels are loaded as linear values; convert to sRGB.
        image = lib_image.linear_to_srgb_np(np.stack(channels, axis=-1))
      else:
        # PNGs are 8-bit; normalize to [0, 1].
        image = get_img('.png') / 255.
      images.append(image)

      if self._load_disps:
        disp_image = get_img('_disp.tiff')
        disp_images.append(disp_image)
      if self._load_normals:
        # Map 8-bit normal components [0, 255] to vectors in [-1, 1].
        normal_image = get_img('_normal.png')[..., :3] * 2. / 255. - 1.
        normal_images.append(normal_image)

      cams.append(np.array(frame['transform_matrix'], dtype=np.float32))

    self.images = np.stack(images, axis=0)
    if self._load_disps:
      self.disp_images = np.stack(disp_images, axis=0)
    if self._load_normals:
      self.normal_images = np.stack(normal_images, axis=0)
      self.alphas = self.images[..., -1]

    # Composite RGBA onto a white background using the alpha channel.
    rgb, alpha = self.images[..., :3], self.images[..., -1:]
    self.images = rgb * alpha + (1. - alpha)
    self.height, self.width = self.images.shape[1:3]
    self.camtoworlds = np.stack(cams, axis=0)
    # Focal length derived from the horizontal field of view.
    self.focal = .5 * self.width / np.tan(.5 * float(meta['camera_angle_x']))
    self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
                                               self.height)
|
|
|
|
|
class LLFF(Dataset):
  """LLFF Dataset."""

  def _load_renderings(self, config):
    """Load images from disk.

    Loads COLMAP poses/intrinsics from `sparse/0/`, images from an
    (optionally downsampled) `images*` directory, and builds the render
    path plus train/test split indices.
    """

    # Set up scaling factor. Use the downsampled image directory unless
    # training in rawnerf mode, where training images are kept at full
    # resolution.
    image_dir_suffix = ''
    if config.factor > 0 and not (config.rawnerf_mode and
                                  self.split == utils.DataSplit.TRAIN):
      image_dir_suffix = f'_{config.factor}'
      factor = config.factor
    else:
      factor = 1

    # Load COLMAP pose data.
    colmap_dir = os.path.join(self.data_dir, 'sparse/0/')
    if utils.file_exists(colmap_dir):
      pose_data = NeRFSceneManager(colmap_dir).process()
    else:
      raise ValueError('COLMAP data not found.')
    image_names, poses, pixtocam, distortion_params, camtype = pose_data

    # Sort images and poses by filename so train/test splits are
    # deterministic regardless of COLMAP's internal ordering.
    inds = np.argsort(image_names)
    image_names = [image_names[i] for i in inds]
    poses = poses[inds]

    # Load scene bounds if available (used below for forward-facing scenes).
    posefile = os.path.join(self.data_dir, 'poses_bounds.npy')
    if utils.file_exists(posefile):
      with utils.open_file(posefile, 'rb') as fp:
        poses_arr = np.load(fp)
      bounds = poses_arr[:, -2:]
    else:
      bounds = np.array([0.01, 1.])
    self.colmap_to_world_transform = np.eye(4)

    # Scale the inverse intrinsics matrix by the image downsampling factor.
    pixtocam = pixtocam @ np.diag([factor, factor, 1.])
    self.pixtocams = pixtocam.astype(np.float32)
    self.focal = 1. / self.pixtocams[0, 0]
    self.distortion_params = distortion_params
    self.camtype = camtype

    # Separate handling for forward-facing (NDC) vs 360-style scenes.
    if config.forward_facing:
      # Set the projective matrix defining the NDC space.
      self.pixtocam_ndc = self.pixtocams.reshape(-1, 3, 3)[0]
      # Rescale poses so the nearest bound maps to a fixed depth
      # (bd-factor-style rescaling: 1 / (min_bound * 0.75)).
      scale = 1. / (bounds.min() * .75)
      poses[:, :3, 3] *= scale
      self.colmap_to_world_transform = np.diag([scale] * 3 + [1])
      bounds *= scale
      # Recenter poses and fold the recentering into the COLMAP->world map.
      poses, transform = camera_utils.recenter_poses(poses)
      self.colmap_to_world_transform = (
          transform @ self.colmap_to_world_transform)
      # Forward-facing spiral render path.
      self.render_poses = camera_utils.generate_spiral_path(
          poses, bounds, n_frames=config.render_path_frames)
    else:
      # Rotate/scale poses via PCA to normalize the scene orientation.
      poses, transform = camera_utils.transform_poses_pca(poses)
      self.colmap_to_world_transform = transform
      if config.render_spline_keyframes is not None:
        # Spline-interpolated render path through selected keyframes.
        rets = camera_utils.create_render_spline_path(config, image_names,
                                                      poses, self.exposures)
        self.spline_indices, self.render_poses, self.render_exposures = rets
      else:
        # Default: automatically generated elliptical render path.
        self.render_poses = camera_utils.generate_ellipse_path(
            poses,
            n_frames=config.render_path_frames,
            z_variation=config.z_variation,
            z_phase=config.z_phase)

    # Select the train/test split: by default every llffhold-th image is
    # held out for testing.
    all_indices = np.arange(len(image_names))
    if config.llff_use_all_images_for_training:
      train_indices = all_indices
    else:
      train_indices = all_indices % config.llffhold != 0
    if config.llff_use_all_images_for_testing:
      test_indices = all_indices
    else:
      test_indices = all_indices % config.llffhold == 0
    split_indices = {
        utils.DataSplit.TEST: all_indices[test_indices],
        utils.DataSplit.TRAIN: all_indices[train_indices],
    }
    indices = split_indices[self.split]
    # All per-image quantities must be re-indexed using the split indices.
    image_names = [image_names[i] for i in indices]
    poses = poses[indices]

    raw_testscene = False
    if config.rawnerf_mode:
      # Load raw images and exposure metadata instead of processed PNG/JPEGs.
      images, metadata, raw_testscene = raw_utils.load_raw_dataset(
          self.split,
          self.data_dir,
          image_names,
          config.exposure_percentile,
          factor)
      self.metadata = metadata

    else:
      # Load images from the (possibly downsampled) image directory.
      colmap_image_dir = os.path.join(self.data_dir, 'images')
      image_dir = os.path.join(self.data_dir, 'images' + image_dir_suffix)
      for d in [image_dir, colmap_image_dir]:
        if not utils.file_exists(d):
          raise ValueError(f'Image folder {d} does not exist.')
      # Downsampled images may have different filenames than the originals
      # used for COLMAP, so map between the two sorted directory listings.
      colmap_files = sorted(utils.listdir(colmap_image_dir))
      image_files = sorted(utils.listdir(image_dir))
      colmap_to_image = dict(zip(colmap_files, image_files))
      image_paths = [os.path.join(image_dir, colmap_to_image[f])
                     for f in image_names]
      images = [utils.load_img(x) for x in tqdm(image_paths, desc='Loading LLFF dataset', disable=self.global_rank != 0, leave=False)]
      images = np.stack(images, axis=0) / 255.

      # EXIF data comes from the original (full-resolution) images; derive a
      # per-image exposure value = shutter * ISO / 1000 when available.
      jpeg_paths = [os.path.join(colmap_image_dir, f) for f in image_names]
      exifs = [utils.load_exif(x) for x in jpeg_paths]
      self.exifs = exifs
      if 'ExposureTime' in exifs[0] and 'ISOSpeedRatings' in exifs[0]:
        gather_exif_value = lambda k: np.array([float(x[k]) for x in exifs])
        shutters = gather_exif_value('ExposureTime')
        isos = gather_exif_value('ISOSpeedRatings')
        self.exposures = shutters * isos / 1000.

    if raw_testscene:
      # In the raw test scene layout, the first pose is the held-out test
      # camera and the remainder are the training cameras.
      raw_testscene_poses = {
          utils.DataSplit.TEST: poses[:1],
          utils.DataSplit.TRAIN: poses[1:],
      }
      poses = raw_testscene_poses[self.split]

    self.poses = poses
    self.images = images
    self.camtoworlds = self.render_poses if config.render_path else poses
    self.height, self.width = images.shape[1:3]
|
|
|
|
|
class TanksAndTemplesNerfPP(Dataset):
  """Subset of Tanks and Temples Dataset as processed by NeRF++."""

  def _load_renderings(self, config):
    """Load images from disk."""
    # Rendering a path reads from the 'camera_path' directory instead of a
    # train/test split directory.
    split_str = 'camera_path' if config.render_path else self.split.value
    basedir = os.path.join(self.data_dir, split_str)

    def load_files(dirname, load_fn, shape=None):
      # Load every file in `basedir/dirname` (sorted), stacking the results
      # into one array and optionally reshaping each entry.
      dirpath = os.path.join(basedir, dirname)
      paths = [os.path.join(dirpath, f)
               for f in sorted(utils.listdir(dirpath))]
      stacked = np.array([load_fn(utils.open_file(p, 'rb')) for p in paths])
      if shape is not None:
        stacked = stacked.reshape(stacked.shape[:1] + shape)
      return stacked

    poses = load_files('pose', np.loadtxt, (4, 4))
    # Flip y/z axes of the loaded camera-to-world matrices.
    poses = poses @ np.diag(np.array([1, -1, -1, 1]))

    intrinsics = load_files('intrinsics', np.loadtxt, (4, 4))

    if config.render_path:
      # No ground-truth RGB for the render path; take the image resolution
      # from the first test image instead.
      rgb_dir = os.path.join(self.data_dir, 'test', 'rgb')
      first = os.path.join(rgb_dir, sorted(utils.listdir(rgb_dir))[0])
      self.height, self.width = utils.load_img(first).shape[:2]
      self.images = None
    else:
      self.images = load_files('rgb', lambda f: np.array(Image.open(f))) / 255.
      self.height, self.width = self.images.shape[1:3]

    self.camtoworlds = poses
    self.focal = intrinsics[0, 0, 0]
    self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
                                               self.height)
|
|
|
|
|
class TanksAndTemplesFVS(Dataset):
  """Subset of Tanks and Temples Dataset as processed by Free View Synthesis."""

  def _load_renderings(self, config):
    """Load images from disk."""
    # Render-only mode: generate a camera path, using just one image's data.
    render_only = config.render_path and self.split == utils.DataSplit.TEST

    basedir = os.path.join(self.data_dir, 'dense')
    sizes = [f for f in sorted(utils.listdir(basedir)) if f.startswith('ibr3d')]
    # Reverse so that `config.factor` indexes from the largest resolution
    # downward (presumably the ibr3d_* dirs sort smallest-first; confirm).
    sizes = sizes[::-1]

    if config.factor >= len(sizes):
      raise ValueError(f'Factor {config.factor} larger than {len(sizes)}')

    basedir = os.path.join(basedir, sizes[config.factor])
    open_fn = lambda f: utils.open_file(os.path.join(basedir, f), 'rb')

    files = [f for f in sorted(utils.listdir(basedir)) if f.startswith('im_')]
    if render_only:
      files = files[:1]
    images = np.array([np.array(Image.open(open_fn(f))) for f in files]) / 255.

    # Per-camera intrinsics (Ks), rotations (Rs) and translations (ts).
    names = ['Ks', 'Rs', 'ts']
    intrinsics, rot, trans = (np.load(open_fn(f'{n}.npy')) for n in names)

    # Build world-to-camera matrices, invert to camera-to-world, and flip
    # the y/z axes.
    w2c = np.concatenate([rot, trans[..., None]], axis=-1)
    c2w_colmap = np.linalg.inv(camera_utils.pad_poses(w2c))[:, :3, :4]
    c2w = c2w_colmap @ np.diag(np.array([1, -1, -1, 1]))

    # Rotate/scale poses via PCA to normalize the scene orientation.
    poses, _ = camera_utils.transform_poses_pca(c2w)
    self.poses = poses

    self.images = images
    self.height, self.width = self.images.shape[1:3]
    self.camtoworlds = poses

    # All cameras share the first camera's focal length.
    self.focal = intrinsics[0, 0, 0]
    self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
                                               self.height)

    if render_only:
      # Replace cameras with an automatically generated elliptical path.
      render_path = camera_utils.generate_ellipse_path(
          poses,
          config.render_path_frames,
          z_variation=config.z_variation,
          z_phase=config.z_phase)
      self.images = None
      self.camtoworlds = render_path
      self.render_poses = render_path
    else:
      # Hold out every llffhold-th image for testing.
      all_indices = np.arange(images.shape[0])
      indices = {
          utils.DataSplit.TEST:
              all_indices[all_indices % config.llffhold == 0],
          utils.DataSplit.TRAIN:
              all_indices[all_indices % config.llffhold != 0],
      }[self.split]

      self.images = self.images[indices]
      self.camtoworlds = self.camtoworlds[indices]
|
|
|
|
|
class DTU(Dataset):
  """DTU Dataset."""

  def _load_renderings(self, config):
    """Load images from disk.

    Raises:
      ValueError: if config.render_path is set (unsupported for DTU).
    """
    if config.render_path:
      raise ValueError('render_path cannot be used for the DTU dataset.')

    images = []
    pixtocams = []
    camtoworlds = []

    # Infer the number of views from the file count; the scan directory
    # apparently holds 8 files per view (TODO confirm against the data).
    n_images = len(utils.listdir(self.data_dir)) // 8

    # Image filenames are 1-indexed.
    for i in range(1, n_images + 1):
      # Build the light-condition suffix of the image filename.
      if config.dtu_light_cond < 7:
        light_str = f'{config.dtu_light_cond}_r' + ('5000'
                                                    if i < 50 else '7000')
      else:
        light_str = 'max'

      # Load image and normalize 8-bit values to [0, 1].
      fname = os.path.join(self.data_dir, f'rect_{i:03d}_{light_str}.png')
      image = utils.load_img(fname) / 255.
      if config.factor > 1:
        image = lib_image.downsample(image, config.factor)
      images.append(image)

      # Load the 3x4 projection matrix from the shared calibration folder.
      fname = os.path.join(self.data_dir, f'../../cal18/pos_{i:03d}.txt')
      with utils.open_file(fname, 'rb') as f:
        projection = np.loadtxt(f, dtype=np.float32)

      # Decompose projection matrix into intrinsics and pose, normalizing
      # the camera matrix so its bottom-right entry is 1.
      camera_mat, rot_mat, t = cv2.decomposeProjectionMatrix(projection)[:3]
      camera_mat = camera_mat / camera_mat[2, 2]
      pose = np.eye(4, dtype=np.float32)
      pose[:3, :3] = rot_mat.transpose()
      # Camera center from the homogeneous translation component.
      pose[:3, 3] = (t[:3] / t[3])[:, 0]
      pose = pose[:3]
      camtoworlds.append(pose)

      if config.factor > 0:
        # Scale the camera matrix to match the downsampled resolution.
        camera_mat = np.diag([1. / config.factor, 1. / config.factor, 1.
                              ]).astype(np.float32) @ camera_mat
      pixtocams.append(np.linalg.inv(camera_mat))

    pixtocams = np.stack(pixtocams)
    camtoworlds = np.stack(camtoworlds)
    images = np.stack(images)

    def rescale_poses(poses):
      """Rescales camera poses according to maximum x/y/z value."""
      s = np.max(np.abs(poses[:, :3, -1]))
      out = np.copy(poses)
      out[:, :3, -1] /= s
      return out

    # Recenter poses, then scale translations into the unit cube.
    camtoworlds, _ = camera_utils.recenter_poses(camtoworlds)
    camtoworlds = rescale_poses(camtoworlds)
    # Flip the y and z axes of the pose matrices.
    camtoworlds = camtoworlds @ np.diag([1., -1., -1., 1.]).astype(np.float32)

    # Hold out every dtuhold-th image for testing.
    all_indices = np.arange(images.shape[0])
    split_indices = {
        utils.DataSplit.TEST: all_indices[all_indices % config.dtuhold == 0],
        utils.DataSplit.TRAIN: all_indices[all_indices % config.dtuhold != 0],
    }
    indices = split_indices[self.split]

    self.images = images[indices]
    self.height, self.width = images.shape[1:3]
    self.camtoworlds = camtoworlds[indices]
    self.pixtocams = pixtocams[indices]
|
|
|
|
|
class Multicam(Dataset):
  """Multiscale dataset wrapper.

  Expands the data loaded by the parent class into a multiscale pyramid:
  every source image is replicated `config.multiscale_levels` times, each
  level halving the previous resolution (with intrinsics rescaled to match),
  and all rays are precomputed in _generate_rays().
  """

  def __init__(self,
               split: str,
               data_dir: str,
               config: configs.Config):
    # Parent __init__ runs _load_renderings() and sets the single-scale
    # attributes that are expanded below.
    super().__init__(split, data_dir, config)

    self.multiscale_levels = config.multiscale_levels

    # Keep the single-scale data, then rebuild the per-camera containers
    # with one entry per (image, scale) pair.
    images, camtoworlds, pixtocams, pixtocam_ndc = \
        self.images, self.camtoworlds, self.pixtocams, self.pixtocam_ndc
    self.heights, self.widths, self.focals, self.images, self.camtoworlds, self.pixtocams, self.lossmults = [], [], [], [], [], [], []
    if pixtocam_ndc is not None:
      self.pixtocam_ndc = []
    else:
      self.pixtocam_ndc = None

    for i in range(self._n_examples):
      for j in range(self.multiscale_levels):
        # Level j halves the resolution j times.
        self.heights.append(self.height // 2 ** j)
        self.widths.append(self.width // 2 ** j)
        # Scale the inverse intrinsics to the downsampled resolution.
        self.pixtocams.append(pixtocams @ np.diag([self.height / self.heights[-1],
                                                   self.width / self.widths[-1],
                                                   1.]))
        self.focals.append(1. / self.pixtocams[-1][0, 0])
        if config.forward_facing:
          # NDC projection keeps the original full-resolution intrinsics.
          self.pixtocam_ndc.append(pixtocams.reshape(3, 3))

        self.camtoworlds.append(camtoworlds[i])
        # Per-level loss multiplier of 2^j (compensates for the 4x fewer
        # pixels per halving; confirm intended weighting against the loss).
        self.lossmults.append(2. ** j)
        self.images.append(self.down2(images[i], (self.heights[-1], self.widths[-1])))
    self.pixtocams = np.stack(self.pixtocams)
    self.camtoworlds = np.stack(self.camtoworlds)
    self.cameras = (self.pixtocams,
                    self.camtoworlds,
                    self.distortion_params,
                    np.stack(self.pixtocam_ndc) if self.pixtocam_ndc is not None else None)
    self._generate_rays()

    if self.split == utils.DataSplit.TRAIN:
      # Flatten each per-image ray array to [num_rays, dim]; when batching
      # over all images, additionally concatenate across images.
      def flatten(x):
        if x[0] is not None:
          x = [y.reshape([-1, y.shape[-1]]) for y in x]
          if self._batching == utils.BatchingMethod.ALL_IMAGES:
            x = np.concatenate(x, axis=0)
          return x
        else:
          return None

      self.batches = {k: flatten(v) for k, v in self.batches.items()}
    # The camera count changed (one camera per image/scale pair).
    self._n_examples = len(self.camtoworlds)

    # Re-bind the batch sampler for the expanded dataset.
    if self.split == utils.DataSplit.TRAIN:
      self._next_fn = self._next_train
    else:
      self._next_fn = self._next_test

  def _generate_rays(self):
    """Precompute ray data for every (image, scale) camera into self.batches."""
    if self.global_rank == 0:
      tbar = tqdm(range(len(self.camtoworlds)), desc='Generating rays', leave=False)
    else:
      tbar = range(len(self.camtoworlds))

    # Maps batch key -> list with one array per camera.
    self.batches = defaultdict(list)
    for cam_idx in tbar:
      # All pixel coordinates of this camera's (downsampled) image.
      pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
          self.widths[cam_idx], self.heights[cam_idx])
      broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
      ray_kwargs = {
          'lossmult': broadcast_scalar(self.lossmults[cam_idx]),
          'near': broadcast_scalar(self.near),
          'far': broadcast_scalar(self.far),
          'cam_idx': broadcast_scalar(cam_idx),
      }

      pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs)

      batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype)
      if not self.render_path:
        # Whole image is the RGB target (already downsampled per level).
        batch['rgb'] = self.images[cam_idx]
        if self._load_disps:
          batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int]
        if self._load_normals:
          batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int]
          batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int]
      for k, v in batch.items():
        self.batches[k].append(v)

  def _next_train(self, item):
    """Sample next training batch (random rays)."""
    # Number of ray groups to draw from the precomputed ray arrays.
    num_patches = self._batch_size // self._patch_size ** 2

    if self._batching == utils.BatchingMethod.ALL_IMAGES:
      # Rays are pre-concatenated across all cameras; sample flat indices.
      ray_indices = np.random.randint(0, self.batches['origins'].shape[0], (num_patches, 1, 1))
      batch = {k: v[ray_indices] if v is not None else None for k, v in self.batches.items()}
    else:
      # Sample one camera, then rays within that camera's array.
      image_index = np.random.randint(0, self._n_examples, ())
      ray_indices = np.random.randint(0, self.batches['origins'][image_index].shape[0], (num_patches,))
      batch = {k: v[image_index][ray_indices] if v is not None else None for k, v in self.batches.items()}
    # Camera direction: negated third column of each ray's camera-to-world
    # rotation matrix.
    batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2]
    return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}

  def _next_test(self, item):
    """Sample next test batch (one full image)."""
    batch = {k: v[item] for k, v in self.batches.items()}
    batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2]
    return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}

  @staticmethod
  def down2(img, sh):
    # cv2.resize expects (width, height), hence the reversed shape tuple.
    return cv2.resize(img, sh[::-1], interpolation=cv2.INTER_CUBIC)
|
|
|
|
|
class MultiLLFF(Multicam, LLFF):
  """Multiscale LLFF dataset.

  Via the MRO, Multicam.__init__ drives construction while
  LLFF._load_renderings supplies the COLMAP-posed images that Multicam then
  expands into a multiscale pyramid.
  """
  pass
|
|
|
|
|
if __name__ == '__main__':
  # Manual smoke test: load the LLFF test split of a local 360_v2 scene and
  # iterate over every item once. Requires the `accelerate` package and the
  # dataset at the hard-coded, machine-specific path below.
  from internal import configs
  import accelerate

  config = configs.Config()
  accelerator = accelerate.Accelerator()
  # Mirror the distributed setup the training pipeline would provide.
  config.world_size = accelerator.num_processes
  config.global_rank = accelerator.process_index
  config.factor = 8
  dataset = LLFF('test', '/SSD_DISK/datasets/360_v2/bicycle', config)
  print(len(dataset))
  for _ in tqdm(dataset):
    pass
  print('done')
|
|
|
|