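# NOTE: the following lines are residue from the web file viewer this module was
# copied from (uploader, commit and size information); they are not Python code.
#   alvin888's picture
#   Upload 1461 files
#   e2ebf5f verified
#   raw
#   history blame contribute delete
#   18.8 kB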
import os
import json
import numpy as np
import torch
from PIL import Image
from pytorch3d.renderer import FoVPerspectiveCameras as P3DCameras
from pytorch3d.renderer.cameras import _get_sfm_calibration_matrix
from sugar.sugar_utils.graphics_utils import focal2fov, fov2focal, getWorld2View2, getProjectionMatrix
from sugar.sugar_utils.general_utils import PILtoTorch
def _compute_downscale_factor(image_resolution, longest_side, max_img_size):
    """Return the factor by which an image's sides are divided before use."""
    # Only the resolutions 1, 2, 4 and 8 are honoured; anything else means "no
    # user-requested downscaling".
    downscale_factor = image_resolution if image_resolution in (1, 2, 4, 8) else 1
    # Images whose longest side exceeds `max_img_size` get an additional factor,
    # multiplied on top of the requested one.
    if longest_side > max_img_size:
        additional_downscale_factor = longest_side / max_img_size
        downscale_factor = additional_downscale_factor * downscale_factor
    return downscale_factor


def load_gs_cameras(source_path, gs_output_path, image_resolution=1,
                    load_gt_images=True, max_img_size=1920):
    """Loads Gaussian Splatting camera parameters from a `cameras.json` file.

    Args:
        source_path (str): Path to the source data. Images are looked up in
            its `images` sub-directory.
        gs_output_path (str): Path to the Gaussian Splatting output directory
            containing `cameras.json`. A trailing path separator is optional.
        image_resolution (int, optional): Factor by which to downscale the images.
            Only 1, 2, 4 and 8 are honoured; other values are treated as 1.
            Defaults to 1.
        load_gt_images (bool, optional): If True, loads the ground truth images.
            If False, only the (downscaled) image size is computed from the
            width/height stored in `cameras.json`. Defaults to True.
        max_img_size (int, optional): Maximum size of the longest image side.
            Larger images are downscaled further. Defaults to 1920.

    Returns:
        List of GSCameras: List of Gaussian Splatting cameras, sorted by image name.

    Raises:
        FileNotFoundError: If `cameras.json` cannot be opened.
        IndexError: If the `images` directory is empty.
    """
    image_dir = os.path.join(source_path, 'images')

    # Plain string concatenation is tried first so that existing callers (which
    # pass a path ending with a separator, or a file-name prefix) behave exactly
    # as before; a path without trailing separator now also works.
    cameras_json_path = gs_output_path + 'cameras.json'
    if not os.path.isfile(cameras_json_path):
        cameras_json_path = os.path.join(gs_output_path, 'cameras.json')
    with open(cameras_json_path) as f:
        unsorted_camera_transforms = json.load(f)
    camera_transforms = sorted(unsorted_camera_transforms, key=lambda x: x['img_name'])

    # The extension is read from the first directory entry and is assumed to be
    # shared by every image referenced in `cameras.json`.
    extension = '.' + os.listdir(image_dir)[0].split('.')[-1]
    if extension not in ['.jpg', '.png', '.JPG', '.PNG']:
        print(f"Warning: image extension {extension} not supported.")
    else:
        print(f"Found image extension {extension}")

    cam_list = []
    for camera_transform in camera_transforms:
        # Extrinsics: the 4x4 matrix assembled from the stored rotation/position
        # is inverted to obtain Rt; R is kept transposed, as GSCamera expects.
        rot = np.array(camera_transform['rotation'])
        pos = np.array(camera_transform['position'])
        pose = np.zeros((4, 4))
        pose[:3, :3] = rot
        pose[:3, 3] = pos
        pose[3, 3] = 1

        Rt = np.linalg.inv(pose)
        T = Rt[:3, 3]
        R = Rt[:3, :3].transpose()

        # Intrinsics
        width = camera_transform['width']
        height = camera_transform['height']
        fy = camera_transform['fy']
        fx = camera_transform['fx']
        fov_y = focal2fov(fy, height)
        fov_x = focal2fov(fx, width)

        # GT data
        cam_id = camera_transform['id']
        name = camera_transform['img_name']
        image_path = os.path.join(image_dir, name + extension)

        if load_gt_images:
            # The context manager releases the file handle once the resized
            # tensor has been produced (assumes PILtoTorch returns a fully
            # materialised tensor and keeps no reference to the PIL image).
            with Image.open(image_path) as image:
                orig_w, orig_h = image.size
                downscale_factor = _compute_downscale_factor(
                    image_resolution, max(orig_h, orig_w), max_img_size)
                resolution = round(orig_w / downscale_factor), round(orig_h / downscale_factor)
                resized_image_rgb = PILtoTorch(image, resolution)
            gt_image = resized_image_rgb[:3, ...]
            # The camera reads its size from the image tensor in this case.
            image_height, image_width = None, None
        else:
            gt_image = None
            # `downscale_factor` is now always defined here, even when
            # `image_resolution` is not one of 1, 2, 4, 8.
            downscale_factor = _compute_downscale_factor(
                image_resolution, max(height, width), max_img_size)
            image_height, image_width = round(height / downscale_factor), round(width / downscale_factor)

        gs_camera = GSCamera(
            colmap_id=cam_id, image=gt_image, gt_alpha_mask=None,
            R=R, T=T, FoVx=fov_x, FoVy=fov_y,
            image_name=name, uid=cam_id,
            image_height=image_height, image_width=image_width,)
        cam_list.append(gs_camera)

    return cam_list
class GSCamera(torch.nn.Module):
    """Holds the parameters of a single Gaussian Splatting camera.

    Besides the raw parameters, the constructor precomputes the world-to-view
    matrix, the projection matrix, their product and the camera center.
    """

    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                 image_name, uid,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda",
                 image_height=None, image_width=None,
                 ):
        """
        Args:
            colmap_id (int): ID of the camera in the COLMAP reconstruction.
            R (np.array): Rotation matrix, stored as given and forwarded to
                getWorld2View2.
            T (np.array): Translation vector, stored as given and forwarded to
                getWorld2View2.
            FoVx (float): Field of view in the x direction.
            FoVy (float): Field of view in the y direction.
            image (torch.Tensor or None): GT image. Its dimension 1 is used as the
                height and its dimension 2 as the width. May be None, in which
                case image_height and image_width are required.
            gt_alpha_mask (torch.Tensor or None): Optional mask multiplied in place
                into the GT image.
            image_name (str): Name of the image.
            uid: Identifier stored on the camera.
            trans (np.array, optional): Forwarded to getWorld2View2.
                Defaults to np.array([0.0, 0.0, 0.0]).
            scale (float, optional): Forwarded to getWorld2View2. Defaults to 1.0.
            data_device (str, optional): Device on which the GT image is stored.
                The camera matrices are always created with `.cuda()`, whatever
                this value is. Defaults to "cuda".
            image_height (int, optional): Image height, used only when image is
                None. Defaults to None.
            image_width (int, optional): Image width, used only when image is
                None. Defaults to None.

        Raises:
            ValueError: If image is None and image_height or image_width is
                missing.
        """
        super().__init__()

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.image_name = image_name

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        if image is not None:
            # The image size is always read back from the stored tensor.
            self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
            self.image_width = self.original_image.shape[2]
            self.image_height = self.original_image.shape[1]

            if gt_alpha_mask is None:
                mask = torch.ones((1, self.image_height, self.image_width), device=self.data_device)
            else:
                mask = gt_alpha_mask.to(self.data_device)
            self.original_image *= mask
        elif image_height is None or image_width is None:
            raise ValueError("Either image or image_height and image_width must be specified")
        else:
            self.image_height = image_height
            self.image_width = image_width

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        # Both matrices are stored transposed and on the default CUDA device.
        world_view = torch.tensor(getWorld2View2(R, T, trans, scale))
        self.world_view_transform = world_view.transpose(0, 1).cuda()
        projection = getProjectionMatrix(
            znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy)
        self.projection_matrix = projection.transpose(0, 1).cuda()
        self.full_proj_transform = torch.bmm(
            self.world_view_transform.unsqueeze(0),
            self.projection_matrix.unsqueeze(0),
        ).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]

    @property
    def device(self):
        """Device on which the world-to-view matrix currently lives."""
        return self.world_view_transform.device

    def to(self, device):
        """Moves the four precomputed camera tensors to `device` and returns self.

        The GT image (`original_image`) is not moved by this method.
        """
        for attr_name in ("world_view_transform", "projection_matrix",
                          "full_proj_transform", "camera_center"):
            setattr(self, attr_name, getattr(self, attr_name).to(device))
        return self
def create_p3d_cameras(R=None, T=None, K=None, znear=0.0001):
    """Creates pytorch3d-compatible camera object from R, T, K matrices.

    Args:
        R (torch.Tensor, optional): Rotation matrix. Defaults to Identity.
        T (torch.Tensor, optional): Translation vector. Defaults to Zero.
        K (torch.Tensor, optional): Camera intrinsics. Defaults to None.
        znear (float, optional): Near clipping plane. Defaults to 0.0001.

    Returns:
        pytorch3d.renderer.cameras.FoVPerspectiveCameras: pytorch3d-compatible camera object.
            When K is not given, the returned object's `K` attribute is set to the
            transposed matrix of its projection transform.
    """
    if R is None:
        R = torch.eye(3)[None]
    if T is None:
        T = torch.zeros(3)[None]

    # `znear` is forwarded to the camera; it used to be accepted but replaced by
    # a hard-coded 0.0001 (which is still the default value).
    if K is not None:
        return P3DCameras(R=R, T=T, K=K, znear=znear)

    p3d_cameras = P3DCameras(R=R, T=T, znear=znear)
    p3d_cameras.K = p3d_cameras.get_projection_transform().get_matrix().transpose(-1, -2)
    return p3d_cameras
def convert_camera_from_gs_to_pytorch3d(gs_cameras, device='cuda'):
    """
    From Gaussian Splatting camera parameters,
    computes R, T, K matrices and outputs pytorch3d-compatible camera object.

    The intrinsics matrix K is built from the FIRST camera only (its focal
    lengths and image size) and shared by all N cameras.

    Args:
        gs_cameras (List of GSCamera): List of Gaussian Splatting cameras.
            Must not be empty.
        device (str or torch.device, optional): Device on which the camera
            tensors are created. Defaults to 'cuda'.

    Returns:
        p3d_cameras: pytorch3d-compatible camera object.

    Raises:
        ValueError: If gs_cameras is empty.
    """
    N = len(gs_cameras)
    if N == 0:
        raise ValueError("gs_cameras must contain at least one camera")

    R = torch.Tensor(np.array([gs_camera.R for gs_camera in gs_cameras])).to(device)
    T = torch.Tensor(np.array([gs_camera.T for gs_camera in gs_cameras])).to(device)
    fx = torch.Tensor(np.array([fov2focal(gs_camera.FoVx, gs_camera.image_width) for gs_camera in gs_cameras])).to(device)
    fy = torch.Tensor(np.array([fov2focal(gs_camera.FoVy, gs_camera.image_height) for gs_camera in gs_cameras])).to(device)
    image_height = torch.tensor(np.array([gs_camera.image_height for gs_camera in gs_cameras]), dtype=torch.int).to(device)
    image_width = torch.tensor(np.array([gs_camera.image_width for gs_camera in gs_cameras]), dtype=torch.int).to(device)
    # The principal point is placed at the image center.
    cx = image_width / 2.
    cy = image_height / 2.

    # GSCamera stores R transposed, hence the transpose when assembling the
    # world-to-camera matrices.
    w2c = torch.zeros(N, 4, 4).to(device)
    w2c[:, :3, :3] = R.transpose(-1, -2)
    w2c[:, :3, 3] = T
    w2c[:, 3, 3] = 1

    # Camera-to-world matrices with the second and third rotation columns negated.
    c2w = w2c.inverse()
    c2w[:, :3, 1:3] *= -1
    c2w = c2w[:, :3, :]

    # Pytorch3d-compatible camera matrices
    # Intrinsics, normalised by half of the shortest image side.
    image_size = torch.Tensor(
        [image_width[0], image_height[0]],
    )[None].to(device)
    scale = image_size.min(dim=1, keepdim=True)[0] / 2.0
    c0 = image_size / 2.0
    p0_pytorch3d = -(torch.Tensor((cx[0], cy[0]))[None].to(device) - c0) / scale
    focal_pytorch3d = torch.Tensor([fx[0], fy[0]])[None].to(device) / scale
    K = _get_sfm_calibration_matrix(
        1, "cpu", focal_pytorch3d, p0_pytorch3d, orthographic=False
    )
    K = K.expand(N, -1, -1)

    # Extrinsics
    # NOTE: in exact arithmetic the two inversions cancel and the net effect of
    # this section is to negate the first two columns of R and the first two
    # entries of T. The explicit round trip is kept so that results stay
    # numerically identical to the previous implementation.
    line = torch.Tensor([[0.0, 0.0, 0.0, 1.0]]).to(device).expand(N, -1, -1)
    cam2world = torch.cat([c2w, line], dim=1)
    world2cam = cam2world.inverse()
    R, T = world2cam.split([3, 1], dim=-1)
    axis_flip = torch.Tensor([-1.0, 1.0, -1]).to(device)
    R = R[:, :3].transpose(1, 2) * axis_flip
    T = T.squeeze(2)[:, :3] * axis_flip

    p3d_cameras = P3DCameras(device=device, R=R, T=T, K=K, znear=0.0001)

    return p3d_cameras
def convert_camera_from_pytorch3d_to_gs(
    p3d_cameras: P3DCameras,
    height: float,
    width: float,
    device='cuda',
):
    """From a pytorch3d-compatible camera object and its camera matrices R, T, K, and width, height,
    outputs Gaussian Splatting camera parameters.

    The focal lengths are read from the first camera's K only and shared by all
    returned cameras.

    Args:
        p3d_cameras (P3DCameras): R matrices should have shape (N, 3, 3),
            T vectors should have shape (N, 3),
            K matrices should have shape (N, 3, 3) or (N, 4, 4); only the
            entries [0, 0] and [1, 1] of the first matrix are used.
        height (float): Image height, as a number or a one-element tensor.
            Converted to int.
        width (float): Image width, as a number or a one-element tensor.
            Converted to int.
        device (str or torch.device, optional): Device used for the intermediate
            computations. If None, the device of p3d_cameras is used.
            Defaults to 'cuda'.

    Returns:
        List of GSCamera: One camera per entry of p3d_cameras, created without
            GT image and named 'image_<index>'.
    """
    N = p3d_cameras.R.shape[0]
    if device is None:
        device = p3d_cameras.device

    # Height and width are handled independently so that a tensor height with a
    # plain-number width (or the reverse) is accepted.
    height = int(height.item()) if isinstance(height, torch.Tensor) else int(height)
    width = int(width.item()) if isinstance(width, torch.Tensor) else int(width)

    # Inverse extrinsics
    # The camera tensors are moved to `device` first so that cameras living on
    # another device do not trigger a device-mismatch error.
    axis_flip = torch.Tensor([-1.0, 1.0, -1]).to(device)
    R_inv = (p3d_cameras.R.to(device) * axis_flip).transpose(-1, -2)
    T_inv = (p3d_cameras.T.to(device) * axis_flip).unsqueeze(-1)
    world2cam_inv = torch.cat([R_inv, T_inv], dim=-1)
    line = torch.Tensor([[0.0, 0.0, 0.0, 1.0]]).to(device).expand(N, -1, -1)
    world2cam_inv = torch.cat([world2cam_inv, line], dim=-2)
    cam2world_inv = world2cam_inv.inverse()
    camera_to_worlds_inv = cam2world_inv[:, :3]

    # Inverse intrinsics: undo the normalisation by half of the shortest side.
    image_size = torch.Tensor(
        [width, height],
    )[None].to(device)
    scale = image_size.min(dim=1, keepdim=True)[0] / 2.0
    K_inv = p3d_cameras.K[0].to(device) * scale
    fx = K_inv[0, 0].item()
    fy = K_inv[1, 1].item()
    # The fields of view are identical for every camera, so compute them once.
    fovx = focal2fov(fx, width)
    fovy = focal2fov(fy, height)

    # NeRF 'transform_matrix' is a camera-to-world transform.
    # All matrices are transferred to the host and inverted in one batch.
    c2w_all = torch.cat([camera_to_worlds_inv, line], dim=-2).cpu().numpy()
    # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
    c2w_all[:, :3, 1:3] *= -1
    # get the world-to-camera transforms
    w2c_all = np.linalg.inv(c2w_all)

    gs_cameras = []
    for cam_idx in range(N):
        w2c = w2c_all[cam_idx]
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        camera = GSCamera(
            colmap_id=cam_idx, image=None, gt_alpha_mask=None,
            R=R, T=T, FoVx=fovx, FoVy=fovy,
            image_name='image_' + str(cam_idx), uid=cam_idx,
            image_height=height,
            image_width=width,
        )
        gs_cameras.append(camera)

    return gs_cameras
class CamerasWrapper:
    """Class to wrap Gaussian Splatting camera parameters
    and facilitates both usage and integration with PyTorch3D.
    """

    def __init__(
        self,
        gs_cameras,
        p3d_cameras=None,
        p3d_cameras_computed=False,
    ) -> None:
        """
        Args:
            gs_cameras (List of GSCamera): Gaussian Splatting cameras to wrap.
                Must not be empty; the device of the first camera is used for
                all tensors created here.
            p3d_cameras (optional): Already computed pytorch3d-compatible camera
                object matching gs_cameras. Defaults to None.
            p3d_cameras_computed (bool, optional): Whether p3d_cameras is valid.
                If False, it is computed lazily on first access of the
                `p3d_cameras` property. Defaults to False.
        """
        self.gs_cameras = gs_cameras
        self._p3d_cameras = p3d_cameras
        self._p3d_cameras_computed = p3d_cameras_computed

        device = gs_cameras[0].device
        N = len(gs_cameras)
        R = torch.Tensor(np.array([gs_camera.R for gs_camera in gs_cameras])).to(device)
        T = torch.Tensor(np.array([gs_camera.T for gs_camera in gs_cameras])).to(device)
        self.fx = torch.Tensor(np.array([fov2focal(gs_camera.FoVx, gs_camera.image_width) for gs_camera in gs_cameras])).to(device)
        self.fy = torch.Tensor(np.array([fov2focal(gs_camera.FoVy, gs_camera.image_height) for gs_camera in gs_cameras])).to(device)
        self.height = torch.tensor(np.array([gs_camera.image_height for gs_camera in gs_cameras]), dtype=torch.int).to(device)
        self.width = torch.tensor(np.array([gs_camera.image_width for gs_camera in gs_cameras]), dtype=torch.int).to(device)
        # The principal point is placed at the image center.
        self.cx = self.width / 2.
        self.cy = self.height / 2.

        # GSCamera stores R transposed, hence the transpose when assembling the
        # world-to-camera matrices.
        w2c = torch.zeros(N, 4, 4).to(device)
        w2c[:, :3, :3] = R.transpose(-1, -2)
        w2c[:, :3, 3] = T
        w2c[:, 3, 3] = 1

        # Camera-to-world matrices (N, 3, 4) with the second and third rotation
        # columns negated.
        c2w = w2c.inverse()
        c2w[:, :3, 1:3] *= -1
        c2w = c2w[:, :3, :]
        self.camera_to_worlds = c2w

    @classmethod
    def from_p3d_cameras(
        cls,
        p3d_cameras,
        width: float,
        height: float,
    ) -> "CamerasWrapper":
        """Initializes CamerasWrapper from pytorch3d-compatible camera object.

        Args:
            p3d_cameras: pytorch3d-compatible camera object.
            width (float): Image width.
            height (float): Image height.

        Returns:
            CamerasWrapper: Wrapper holding both the converted Gaussian Splatting
                cameras and the given p3d_cameras.
        """
        # The cameras are stored on the new instance only (through __init__);
        # nothing is written to the class itself, so no reference to
        # `p3d_cameras` outlives the returned wrapper.
        gs_cameras = convert_camera_from_pytorch3d_to_gs(
            p3d_cameras,
            height=height,
            width=width,
        )
        return cls(
            gs_cameras=gs_cameras,
            p3d_cameras=p3d_cameras,
            p3d_cameras_computed=True,
        )

    @property
    def device(self):
        """Device on which the camera-to-world matrices currently live."""
        return self.camera_to_worlds.device

    @property
    def p3d_cameras(self):
        """pytorch3d-compatible cameras, computed on first access and cached."""
        if not self._p3d_cameras_computed:
            # Build the cameras on the wrapper's own device so that they stay
            # consistent with the other tensors after a call to `to`.
            self._p3d_cameras = convert_camera_from_gs_to_pytorch3d(
                self.gs_cameras,
                device=self.device,
            )
            self._p3d_cameras_computed = True

        return self._p3d_cameras

    def __len__(self):
        return len(self.gs_cameras)

    def to(self, device):
        """Moves all tensors and wrapped cameras to `device` and returns self."""
        self.camera_to_worlds = self.camera_to_worlds.to(device)
        self.fx = self.fx.to(device)
        self.fy = self.fy.to(device)
        self.cx = self.cx.to(device)
        self.cy = self.cy.to(device)
        self.width = self.width.to(device)
        self.height = self.height.to(device)

        for gs_camera in self.gs_cameras:
            gs_camera.to(device)

        if self._p3d_cameras_computed:
            self._p3d_cameras = self._p3d_cameras.to(device)

        return self

    def get_spatial_extent(self):
        """Returns the spatial extent of the cameras, computed as 1.1 times
        the largest distance between a camera center and the mean of all
        camera centers.

        Returns:
            (float): Spatial extent of the cameras.
        """
        camera_centers = self.p3d_cameras.get_camera_center()
        avg_camera_center = camera_centers.mean(dim=0, keepdim=True)
        max_distance = torch.norm(camera_centers - avg_camera_center, dim=-1).max().item()
        radius = 1.1 * max_distance
        return radius