import os
import json
import numpy as np
import torch
from PIL import Image
from pytorch3d.renderer import FoVPerspectiveCameras as P3DCameras
from pytorch3d.renderer.cameras import _get_sfm_calibration_matrix

from sugar.sugar_utils.graphics_utils import focal2fov, fov2focal, getWorld2View2, getProjectionMatrix
from sugar.sugar_utils.general_utils import PILtoTorch


def load_gs_cameras(source_path, gs_output_path, image_resolution=1,
                    load_gt_images=True, max_img_size=1920):
    """Loads Gaussian Splatting camera parameters from a COLMAP reconstruction.

    Args:
        source_path (str): Path to the source data.
        gs_output_path (str): Path to the Gaussian Splatting output.
        image_resolution (int, optional): Factor by which to downscale the images. Defaults to 1.
        load_gt_images (bool, optional): If True, loads the ground truth images. Defaults to True.
        max_img_size (int, optional): Maximum side length of the images, in pixels. Defaults to 1920.

    Returns:
        List[GSCamera]: List of Gaussian Splatting cameras.
    """
    image_dir = os.path.join(source_path, 'images')
    with open(os.path.join(gs_output_path, 'cameras.json')) as f:
        unsorted_camera_transforms = json.load(f)
    camera_transforms = sorted(unsorted_camera_transforms, key=lambda x: x['img_name'])

    cam_list = []
    extension = '.' + os.listdir(image_dir)[0].split('.')[-1]
    if extension not in ['.jpg', '.png', '.JPG', '.PNG']:
        print(f"Warning: image extension {extension} not supported.")
    else:
        print(f"Found image extension {extension}")

    for cam_idx in range(len(camera_transforms)):
        camera_transform = camera_transforms[cam_idx]

        # Extrinsics
        rot = np.array(camera_transform['rotation'])
        pos = np.array(camera_transform['position'])
        W2C = np.zeros((4, 4))
        W2C[:3, :3] = rot
        W2C[:3, 3] = pos
        W2C[3, 3] = 1
        Rt = np.linalg.inv(W2C)
        T = Rt[:3, 3]
        R = Rt[:3, :3].transpose()

        # Intrinsics
        width = camera_transform['width']
        height = camera_transform['height']
        fy = camera_transform['fy']
        fx = camera_transform['fx']
        fov_y = focal2fov(fy, height)
        fov_x = focal2fov(fx, width)

        # GT data
        id = camera_transform['id']
        name = camera_transform['img_name']
        image_path = os.path.join(image_dir, name + extension)

        if load_gt_images:
            image = Image.open(image_path)
            orig_w, orig_h = image.size
            downscale_factor = 1
            if image_resolution in [1, 2, 4, 8]:
                downscale_factor = image_resolution
            if max(orig_h, orig_w) > max_img_size:
                additional_downscale_factor = max(orig_h, orig_w) / max_img_size
                downscale_factor = additional_downscale_factor * downscale_factor
            resolution = round(orig_w / downscale_factor), round(orig_h / downscale_factor)

            resized_image_rgb = PILtoTorch(image, resolution)
            gt_image = resized_image_rgb[:3, ...]
            image_height, image_width = None, None
        else:
            gt_image = None
            # The downscale factor must be initialized here as well; otherwise it
            # would be undefined when image_resolution is not in [1, 2, 4, 8].
            downscale_factor = 1
            if image_resolution in [1, 2, 4, 8]:
                downscale_factor = image_resolution
            if max(height, width) > max_img_size:
                additional_downscale_factor = max(height, width) / max_img_size
                downscale_factor = additional_downscale_factor * downscale_factor
            image_height, image_width = round(height / downscale_factor), round(width / downscale_factor)

        gs_camera = GSCamera(
            colmap_id=id, image=gt_image, gt_alpha_mask=None,
            R=R, T=T, FoVx=fov_x, FoVy=fov_y,
            image_name=name, uid=id,
            image_height=image_height, image_width=image_width,
        )
        cam_list.append(gs_camera)

    return cam_list
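

# Usage sketch (the paths below are hypothetical): load_gs_cameras expects the
# standard Gaussian Splatting layout, i.e. an 'images/' folder under source_path
# and a 'cameras.json' file in the Gaussian Splatting output directory.
#
#   gs_cameras = load_gs_cameras(
#       source_path='./data/my_scene',
#       gs_output_path='./output/my_scene',
#       image_resolution=2,  # downscale GT images by a factor of 2
#   )
#   print(len(gs_cameras), gs_cameras[0].image_name)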


class GSCamera(torch.nn.Module):
    """Class to store Gaussian Splatting camera parameters.
    """
    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                 image_name, uid,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda",
                 image_height=None, image_width=None,
                 ):
        """
        Args:
            colmap_id (int): ID of the camera in the COLMAP reconstruction.
            R (np.array): Rotation matrix.
            T (np.array): Translation vector.
            FoVx (float): Field of view in the x direction.
            FoVy (float): Field of view in the y direction.
            image (torch.Tensor): GT image, with shape (3, height, width).
            gt_alpha_mask (torch.Tensor, optional): Alpha mask to apply to the GT image.
            image_name (str): Name of the image.
            uid (int): Unique ID of the camera.
            trans (np.array, optional): Translation applied to the camera center. Defaults to np.array([0.0, 0.0, 0.0]).
            scale (float, optional): Scale applied to the camera center. Defaults to 1.0.
            data_device (str, optional): Device on which to store the image data. Defaults to "cuda".
            image_height (int, optional): Image height; required if image is None. Defaults to None.
            image_width (int, optional): Image width; required if image is None. Defaults to None.

        Raises:
            ValueError: If image is None and image_height or image_width is not provided.
        """
        super(GSCamera, self).__init__()

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.image_name = image_name

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, falling back to default cuda device")
            self.data_device = torch.device("cuda")

        if image is None:
            if image_height is None or image_width is None:
                raise ValueError("Either image or image_height and image_width must be specified")
            else:
                self.image_height = image_height
                self.image_width = image_width
        else:
            self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
            self.image_width = self.original_image.shape[2]
            self.image_height = self.original_image.shape[1]

            if gt_alpha_mask is not None:
                self.original_image *= gt_alpha_mask.to(self.data_device)
            else:
                self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1).cuda()
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]

    # device must be a property (not a method), since CamerasWrapper reads it
    # as an attribute, e.g. gs_cameras[0].device.
    @property
    def device(self):
        return self.world_view_transform.device

    def to(self, device):
        self.world_view_transform = self.world_view_transform.to(device)
        self.projection_matrix = self.projection_matrix.to(device)
        self.full_proj_transform = self.full_proj_transform.to(device)
        self.camera_center = self.camera_center.to(device)
        return self
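

# Minimal construction sketch (illustrative values; requires a CUDA device since
# the transforms are moved to GPU in __init__): a GSCamera can be built without
# a GT image as long as the image size is given explicitly.
#
#   cam = GSCamera(
#       colmap_id=0, R=np.eye(3), T=np.zeros(3),
#       FoVx=1.2, FoVy=0.9, image=None, gt_alpha_mask=None,
#       image_name='image_0', uid=0,
#       image_height=1080, image_width=1920,
#   )
#   print(cam.device, cam.camera_center)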


def create_p3d_cameras(R=None, T=None, K=None, znear=0.0001):
    """Creates pytorch3d-compatible camera object from R, T, K matrices.

    Args:
        R (torch.Tensor, optional): Rotation matrix. Defaults to Identity.
        T (torch.Tensor, optional): Translation vector. Defaults to Zero.
        K (torch.Tensor, optional): Camera intrinsics. Defaults to None.
        znear (float, optional): Near clipping plane. Defaults to 0.0001.

    Returns:
        pytorch3d.renderer.cameras.FoVPerspectiveCameras: pytorch3d-compatible camera object.
    """
    if R is None:
        R = torch.eye(3)[None]
    if T is None:
        T = torch.zeros(3)[None]

    if K is not None:
        p3d_cameras = P3DCameras(R=R, T=T, K=K, znear=znear)
    else:
        p3d_cameras = P3DCameras(R=R, T=T, znear=znear)
        p3d_cameras.K = p3d_cameras.get_projection_transform().get_matrix().transpose(-1, -2)
    return p3d_cameras
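

# Example (illustrative): build a single default camera at the origin; when K is
# not provided, the intrinsics are recovered from the FoV projection transform.
#
#   cameras = create_p3d_cameras()  # identity R, zero T
#   print(cameras.K.shape)  # torch.Size([1, 4, 4])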


def convert_camera_from_gs_to_pytorch3d(gs_cameras, device='cuda'):
    """From Gaussian Splatting camera parameters,
    computes R, T, K matrices and outputs pytorch3d-compatible camera object.

    Args:
        gs_cameras (List[GSCamera]): List of Gaussian Splatting cameras.
        device (str, optional): Device on which to create the cameras. Defaults to 'cuda'.

    Returns:
        p3d_cameras: pytorch3d-compatible camera object.
    """
    N = len(gs_cameras)

    R = torch.Tensor(np.array([gs_camera.R for gs_camera in gs_cameras])).to(device)
    T = torch.Tensor(np.array([gs_camera.T for gs_camera in gs_cameras])).to(device)
    fx = torch.Tensor(np.array([fov2focal(gs_camera.FoVx, gs_camera.image_width) for gs_camera in gs_cameras])).to(device)
    fy = torch.Tensor(np.array([fov2focal(gs_camera.FoVy, gs_camera.image_height) for gs_camera in gs_cameras])).to(device)
    image_height = torch.tensor(np.array([gs_camera.image_height for gs_camera in gs_cameras]), dtype=torch.int).to(device)
    image_width = torch.tensor(np.array([gs_camera.image_width for gs_camera in gs_cameras]), dtype=torch.int).to(device)
    cx = image_width / 2.  # torch.zeros_like(fx).to(device)
    cy = image_height / 2.  # torch.zeros_like(fy).to(device)

    w2c = torch.zeros(N, 4, 4).to(device)
    w2c[:, :3, :3] = R.transpose(-1, -2)
    w2c[:, :3, 3] = T
    w2c[:, 3, 3] = 1

    c2w = w2c.inverse()
    c2w[:, :3, 1:3] *= -1
    c2w = c2w[:, :3, :]

    distortion_params = torch.zeros(N, 6).to(device)
    camera_type = torch.ones(N, 1, dtype=torch.int32).to(device)

    # Pytorch3d-compatible camera matrices
    # Intrinsics
    image_size = torch.Tensor(
        [image_width[0], image_height[0]],
    )[
        None
    ].to(device)
    scale = image_size.min(dim=1, keepdim=True)[0] / 2.0
    c0 = image_size / 2.0
    p0_pytorch3d = (
        -(
            torch.Tensor(
                (cx[0], cy[0]),
            )[
                None
            ].to(device)
            - c0
        )
        / scale
    )
    focal_pytorch3d = (
        torch.Tensor([fx[0], fy[0]])[None].to(device) / scale
    )
    K = _get_sfm_calibration_matrix(
        1, "cpu", focal_pytorch3d, p0_pytorch3d, orthographic=False
    )
    K = K.expand(N, -1, -1)

    # Extrinsics
    line = torch.Tensor([[0.0, 0.0, 0.0, 1.0]]).to(device).expand(N, -1, -1)
    cam2world = torch.cat([c2w, line], dim=1)
    world2cam = cam2world.inverse()
    R, T = world2cam.split([3, 1], dim=-1)
    R = R[:, :3].transpose(1, 2) * torch.Tensor([-1.0, 1.0, -1.0]).to(device)
    T = T.squeeze(2)[:, :3] * torch.Tensor([-1.0, 1.0, -1.0]).to(device)

    p3d_cameras = P3DCameras(device=device, R=R, T=T, K=K, znear=0.0001)

    return p3d_cameras
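

# Conversion sketch (assumes gs_cameras was produced by load_gs_cameras above).
# Note that only the first camera's intrinsics and image size are used and then
# broadcast to the whole batch, so all cameras are assumed to share them.
#
#   p3d_cameras = convert_camera_from_gs_to_pytorch3d(gs_cameras)
#   centers = p3d_cameras.get_camera_center()  # (N, 3) camera centers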


def convert_camera_from_pytorch3d_to_gs(
    p3d_cameras: P3DCameras,
    height: float,
    width: float,
    device='cuda',
):
    """From a pytorch3d-compatible camera object and its camera matrices R, T, K, and width, height,
    outputs Gaussian Splatting camera parameters.

    Args:
        p3d_cameras (P3DCameras): R matrices should have shape (N, 3, 3),
            T matrices should have shape (N, 3, 1),
            K matrices should have shape (N, 3, 3).
        height (float): Height of the images, in pixels.
        width (float): Width of the images, in pixels.
        device (str, optional): Device on which to create the cameras. Defaults to 'cuda'.
    """
    N = p3d_cameras.R.shape[0]
    if device is None:
        device = p3d_cameras.device

    if isinstance(height, torch.Tensor):
        height = int(height.item())
        width = int(width.item())
    else:
        height = int(height)
        width = int(width)

    # Inverse extrinsics
    R_inv = (p3d_cameras.R * torch.Tensor([-1.0, 1.0, -1.0]).to(device)).transpose(-1, -2)
    T_inv = (p3d_cameras.T * torch.Tensor([-1.0, 1.0, -1.0]).to(device)).unsqueeze(-1)
    world2cam_inv = torch.cat([R_inv, T_inv], dim=-1)
    line = torch.Tensor([[0.0, 0.0, 0.0, 1.0]]).to(device).expand(N, -1, -1)
    world2cam_inv = torch.cat([world2cam_inv, line], dim=-2)
    cam2world_inv = world2cam_inv.inverse()
    camera_to_worlds_inv = cam2world_inv[:, :3]

    # Inverse intrinsics
    image_size = torch.Tensor(
        [width, height],
    )[
        None
    ].to(device)
    scale = image_size.min(dim=1, keepdim=True)[0] / 2.0
    c0 = image_size / 2.0
    K_inv = p3d_cameras.K[0] * scale
    fx_inv, fy_inv = K_inv[0, 0], K_inv[1, 1]
    cx_inv, cy_inv = c0[0, 0] - K_inv[0, 2], c0[0, 1] - K_inv[1, 2]

    gs_cameras = []

    for cam_idx in range(N):
        # NeRF 'transform_matrix' is a camera-to-world transform
        c2w = camera_to_worlds_inv[cam_idx]
        c2w = torch.cat([c2w, torch.Tensor([[0, 0, 0, 1]]).to(device)], dim=0).cpu().numpy()

        # Change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
        c2w[:3, 1:3] *= -1

        # Get the world-to-camera transform and set R, T
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        image_height = height
        image_width = width

        fx = fx_inv.item()
        fy = fy_inv.item()
        fov_x = focal2fov(fx, image_width)
        fov_y = focal2fov(fy, image_height)

        name = 'image_' + str(cam_idx)

        camera = GSCamera(
            colmap_id=cam_idx, image=None, gt_alpha_mask=None,
            R=R, T=T, FoVx=fov_x, FoVy=fov_y,
            image_name=name, uid=cam_idx,
            image_height=image_height,
            image_width=image_width,
        )
        gs_cameras.append(camera)

    return gs_cameras
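

# Round-trip sketch (illustrative; height and width must match the image size
# used to build the pytorch3d cameras):
#
#   gs_cameras_back = convert_camera_from_pytorch3d_to_gs(
#       p3d_cameras, height=1080, width=1920,
#   )
#   assert len(gs_cameras_back) == p3d_cameras.R.shape[0]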


class CamerasWrapper:
    """Class to wrap Gaussian Splatting camera parameters
    and facilitate both usage and integration with PyTorch3D.
    """
    def __init__(
        self,
        gs_cameras,
        p3d_cameras=None,
        p3d_cameras_computed=False,
    ) -> None:
        """
        Args:
            gs_cameras (List[GSCamera]): List of Gaussian Splatting cameras to wrap.
            p3d_cameras (P3DCameras, optional): Corresponding pytorch3d-compatible cameras,
                if already computed. Defaults to None.
            p3d_cameras_computed (bool, optional): Whether p3d_cameras has already been
                computed. Defaults to False.
        """
        self.gs_cameras = gs_cameras
        self._p3d_cameras = p3d_cameras
        self._p3d_cameras_computed = p3d_cameras_computed

        device = gs_cameras[0].device
        N = len(gs_cameras)

        R = torch.Tensor(np.array([gs_camera.R for gs_camera in gs_cameras])).to(device)
        T = torch.Tensor(np.array([gs_camera.T for gs_camera in gs_cameras])).to(device)
        self.fx = torch.Tensor(np.array([fov2focal(gs_camera.FoVx, gs_camera.image_width) for gs_camera in gs_cameras])).to(device)
        self.fy = torch.Tensor(np.array([fov2focal(gs_camera.FoVy, gs_camera.image_height) for gs_camera in gs_cameras])).to(device)
        self.height = torch.tensor(np.array([gs_camera.image_height for gs_camera in gs_cameras]), dtype=torch.int).to(device)
        self.width = torch.tensor(np.array([gs_camera.image_width for gs_camera in gs_cameras]), dtype=torch.int).to(device)
        self.cx = self.width / 2.  # torch.zeros_like(fx).to(device)
        self.cy = self.height / 2.  # torch.zeros_like(fy).to(device)

        w2c = torch.zeros(N, 4, 4).to(device)
        w2c[:, :3, :3] = R.transpose(-1, -2)
        w2c[:, :3, 3] = T
        w2c[:, 3, 3] = 1

        c2w = w2c.inverse()
        c2w[:, :3, 1:3] *= -1
        c2w = c2w[:, :3, :]

        self.camera_to_worlds = c2w

    @classmethod
    def from_p3d_cameras(
        cls,
        p3d_cameras,
        width: float,
        height: float,
    ):
        """Initializes CamerasWrapper from pytorch3d-compatible camera object.

        Args:
            p3d_cameras (P3DCameras): pytorch3d-compatible camera object.
            width (float): Width of the images, in pixels.
            height (float): Height of the images, in pixels.

        Returns:
            CamerasWrapper: Wrapper built from the pytorch3d cameras.
        """
        gs_cameras = convert_camera_from_pytorch3d_to_gs(
            p3d_cameras,
            height=height,
            width=width,
        )
        return cls(
            gs_cameras=gs_cameras,
            p3d_cameras=p3d_cameras,
            p3d_cameras_computed=True,
        )

    @property
    def device(self):
        return self.camera_to_worlds.device

    # p3d_cameras must be a property: get_spatial_extent accesses it without
    # parentheses, and the pytorch3d cameras are computed lazily on first access.
    @property
    def p3d_cameras(self):
        if not self._p3d_cameras_computed:
            self._p3d_cameras = convert_camera_from_gs_to_pytorch3d(
                self.gs_cameras,
            )
            self._p3d_cameras_computed = True
        return self._p3d_cameras

    def __len__(self):
        return len(self.gs_cameras)

    def to(self, device):
        self.camera_to_worlds = self.camera_to_worlds.to(device)
        self.fx = self.fx.to(device)
        self.fy = self.fy.to(device)
        self.cx = self.cx.to(device)
        self.cy = self.cy.to(device)
        self.width = self.width.to(device)
        self.height = self.height.to(device)

        for gs_camera in self.gs_cameras:
            gs_camera.to(device)

        if self._p3d_cameras_computed:
            self._p3d_cameras = self._p3d_cameras.to(device)

        return self

    def get_spatial_extent(self):
        """Returns the spatial extent of the cameras, computed as
        1.1 times the maximum distance between a camera center and
        the average of all camera centers.

        Returns:
            (float): Spatial extent of the cameras.
        """
        camera_centers = self.p3d_cameras.get_camera_center()
        avg_camera_center = camera_centers.mean(dim=0, keepdim=True)
        half_diagonal = torch.norm(camera_centers - avg_camera_center, dim=-1).max().item()

        radius = 1.1 * half_diagonal
        return radius
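

# Putting it together (hypothetical paths): wrap the loaded cameras to get both
# the Gaussian Splatting and the pytorch3d views of the same camera set.
#
#   gs_cameras = load_gs_cameras('./data/my_scene', './output/my_scene')
#   cameras = CamerasWrapper(gs_cameras)
#   print(len(cameras), cameras.device)
#   scene_radius = cameras.get_spatial_extent()  # ~radius of the camera rig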