|
import math |
|
import os |
|
from dataclasses import dataclass, field |
|
from typing import List, Union |
|
|
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
import torch.nn.functional as F |
|
import trimesh |
|
from einops import rearrange |
|
from huggingface_hub import hf_hub_download |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
|
|
from .utils import ( |
|
BaseModule, |
|
ImagePreprocessor, |
|
find_class, |
|
get_spherical_cameras, |
|
scale_tensor, |
|
) |
|
|
|
|
|
class TSR(BaseModule):
    """TripoSR: single-image to 3D reconstruction model.

    Pipeline (see ``forward``): the conditioning image is tokenized, a set of
    learned latent tokens cross-attends to those image tokens in a transformer
    backbone, and the result is post-processed into per-scene "scene codes".
    The scene codes are consumed by ``render`` (novel-view synthesis) and
    ``extract_mesh`` (marching-cubes mesh extraction).

    All sub-modules are instantiated dynamically from dotted class paths in
    ``Config`` via ``find_class`` (see ``configure``).
    """

    @dataclass
    class Config(BaseModule.Config):
        # Side length the conditioning image is resized to before tokenization.
        cond_image_size: int

        # For each component, *_cls is a dotted class path resolved with
        # find_class(...), and the matching dict is that component's config.
        image_tokenizer_cls: str
        image_tokenizer: dict

        tokenizer_cls: str
        tokenizer: dict

        backbone_cls: str
        backbone: dict

        post_processor_cls: str
        post_processor: dict

        decoder_cls: str
        decoder: dict

        renderer_cls: str
        renderer: dict

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        """Build a model from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: local directory holding the config
                and weights, or a Hub repo id to download them from.
            config_name: filename of the OmegaConf YAML config.
            weight_name: filename of the checkpoint (a ``state_dict``).

        Returns:
            A ``TSR`` instance with weights loaded (tensors mapped to CPU).
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from trusted sources.
        ckpt = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(ckpt)
        return model

    def configure(self):
        """Instantiate all sub-modules from the class paths/configs in ``cfg``."""
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
        self.image_processor = ImagePreprocessor()
        # Created lazily elsewhere when isosurface extraction is first needed.
        self.isosurface_helper = None

    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
        """Encode input image(s) into scene codes.

        Args:
            image: a single image or a list of images (PIL, numpy array, or
                float tensor); preprocessing/resizing is handled by
                ``self.image_processor``.
            device: device the conditioning tensor is moved to.

        Returns:
            Scene codes, one per input image (batched along dim 0).
        """
        # [:, None] inserts a singleton "number of views" axis (Nv=1).
        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
            device
        )
        batch_size = rgb_cond.shape[0]

        input_image_tokens: torch.Tensor = self.image_tokenizer(
            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
        )
        # Fold the view axis into the token axis for cross-attention.
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        # Learned latent tokens, cross-attending to the image tokens.
        tokens: torch.Tensor = self.tokenizer(batch_size)
        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )

        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
        return scene_codes

    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
        """Render each scene code from ``n_views`` spherical camera poses.

        Args:
            scene_codes: batched scene codes produced by ``forward``.
            n_views: number of views distributed around the object.
            elevation_deg: camera elevation in degrees.
            camera_distance: distance of the camera from the origin.
            fovy_deg: vertical field of view in degrees.
            height, width: output image size in pixels.
            return_type: "pt" (tensor), "np" (numpy array, values in [0, 1]),
                or "pil" (PIL image).

        Returns:
            A list (one entry per scene code) of lists of ``n_views`` images.

        Raises:
            NotImplementedError: if ``return_type`` is not one of the above.
        """
        # Fail fast on a bad return_type instead of raising only after the
        # first (expensive) view has already been rendered.
        if return_type not in ("pt", "np", "pil"):
            raise NotImplementedError

        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)

        def process_output(image: torch.FloatTensor):
            if return_type == "pt":
                return image
            if return_type == "np":
                return image.detach().cpu().numpy()
            # return_type == "pil": scale [0, 1] floats to uint8.
            return Image.fromarray(
                (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
            )

        images = []
        # Rendering is inference-only here: disable autograd for the whole
        # loop instead of re-entering no_grad once per view.
        with torch.no_grad():
            for scene_code in scene_codes:
                images_ = []
                for i in range(n_views):
                    image = self.renderer(
                        self.decoder, scene_code, rays_o[i], rays_d[i]
                    )
                    images_.append(process_output(image))
                images.append(images_)

        return images

    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
        """Extract a vertex-colored triangle mesh for each scene code.

        Args:
            scene_codes: batched scene codes produced by ``forward``.
            resolution: marching-cubes grid resolution.
            threshold: density threshold for the isosurface.

        Returns:
            A list of ``trimesh.Trimesh`` meshes, one per scene code.
        """
        # Move the decoder to the target device once, instead of calling
        # .to(...) twice per scene code inside the loop.
        decoder = self.decoder.to(scene_codes.device)
        meshes = []
        for scene_code in scene_codes:
            with torch.no_grad():
                v_pos, t_pos_idx = self.renderer.block_based_marchingcube(
                    decoder,
                    scene_code,
                    resolution,
                    threshold,
                )
                color = self.renderer.query_triplane(
                    decoder, v_pos.to(scene_codes.device), scene_code, False
                )["color"]
                # Marching cubes runs in the normalized [-1, 1]^3 cube;
                # rescale vertices to the renderer's world radius.
                v_pos = scale_tensor(
                    v_pos,
                    (-1.0, 1.0),
                    (-self.renderer.cfg.radius, self.renderer.cfg.radius),
                )
                mesh = trimesh.Trimesh(
                    vertices=v_pos.cpu().numpy(),
                    faces=t_pos_idx.cpu().numpy(),
                    vertex_colors=color.cpu().numpy(),
                )
                meshes.append(mesh)
        return meshes
|
|