import sys

# NOTE(review): `spaces` is kept ahead of `torch`, as in the original ordering;
# ZeroGPU presumably needs it imported before CUDA is initialised - confirm.
import spaces

# Extend the import path with the local "flash3d" directory; presumably this is
# what lets `networks` and `util` below resolve - verify against the repo layout.
sys.path.append("flash3d")

from omegaconf import OmegaConf
import gradio as gr
import torch
import torchvision.transforms as TT
import torchvision.transforms.functional as TTF
from huggingface_hub import hf_hub_download
import numpy as np

from networks.gaussian_predictor import GaussianPredictor
from util.vis3d import save_ply


def main():
    """Build and launch the Flash3D Gradio demo.

    Steps performed, in order:

    1. Pick ``cuda:0`` when CUDA is available, otherwise ``cpu``.
    2. Download the model config and weights from the
       ``einsafutdinov/flash3d`` Hugging Face repository.
    3. Instantiate ``GaussianPredictor`` from the config, move it to the
       selected device and load the weights.
    4. Define the UI callbacks (input check, preprocessing, reconstruction).
    5. Assemble the Gradio Blocks UI, enable the queue and launch the demo.

    Raises:
        Exception: whatever ``torch.device`` / ``model.to`` raise is logged
            and re-raised unchanged.
    """
    print("[INFO] Starting main function...")

    if torch.cuda.is_available():
        device = "cuda:0"
        print("[INFO] CUDA is available. Using GPU device.")
    else:
        device = "cpu"
        print("[INFO] CUDA is not available. Using CPU device.")

    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                     filename="config_re10k_v1.yaml")
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                 filename="model_re10k_v1.pth")

    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)

    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    try:
        device = torch.device(device)
        model.to(device)
    except Exception as e:
        # Log for the console, then let the original exception propagate.
        print(f"[ERROR] Failed to set device: {e}")
        raise

    print("[INFO] Loading model weights...")
    model.load_model(model_path)
    # NOTE(review): the model is never switched to eval mode here; confirm
    # whether GaussianPredictor expects that for inference.

    to_tensor = TT.ToTensor()

    # Every reconstruction is written to this one path (and overwritten by
    # the next request); the queue below is limited to a single pending job.
    ply_out_path = './mesh.ply'

    def check_input_image(input_image):
        """Abort the event chain with a ``gr.Error`` when no image was uploaded."""
        print("[DEBUG] Checking input image...")
        if input_image is None:
            print("[ERROR] No image uploaded!")
            raise gr.Error("No image uploaded!")
        print("[INFO] Input image is valid.")

    def preprocess(image, padding_value, resize_height, resize_width):
        """Resize the uploaded image and pad its border.

        Args:
            image: the uploaded image (the input widget uses ``type="pil"``).
            padding_value: border padding applied to both axes, in pixels.
            resize_height: target height before padding.
            resize_width: target width before padding.

        Returns:
            The resized (bicubic) and padded image.
        """
        print("[DEBUG] Preprocessing image...")

        # Slider values can reach the callback as floats; sizes and padding
        # must be integers for torchvision.
        target_size = (int(resize_height), int(resize_width))
        border = int(padding_value)

        image = TTF.resize(
            image, target_size,
            interpolation=TT.InterpolationMode.BICUBIC
        )

        pad_border_fn = TT.Pad((border, border))
        image = pad_border_fn(image)
        print("[INFO] Image preprocessing complete.")
        return image

    @spaces.GPU(duration=120)
    def reconstruct_and_export(image, num_gauss, scale_factor):
        """Run the model on a preprocessed image and export the result.

        Args:
            image: the preprocessed image taken from the "Processed Image"
                component.
            num_gauss: value of the "Number of Gaussians per Pixel" slider,
                forwarded to ``save_ply`` as an integer.
            scale_factor: value of the "Scale Factor for Model Size" slider,
                forwarded to ``save_ply``.

        Returns:
            Path of the exported ``.ply`` file.
        """
        print("[DEBUG] Starting reconstruction and export...")

        image = to_tensor(image).to(device).unsqueeze(0)
        inputs = {
            ("color_aug", 0, 0): image,
        }

        print("[INFO] Passing image through the model...")
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            outputs = model(inputs)

        print(f"[INFO] Saving output to {ply_out_path} with scale factor {scale_factor}...")
        save_ply(outputs, ply_out_path, num_gauss=int(num_gauss), scale_factor=scale_factor)
        print("[INFO] Reconstruction and export complete.")

        return ply_out_path

    css = """
        h1 {
            text-align: center;
            display:block;
        }
        """

    # NOTE(review): the source had lost its indentation, so the nesting of the
    # rows/columns below is reconstructed from statement order - verify layout.
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Flash3D
            """
        )
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(
                        label="Input Image",
                        image_mode="RGBA",
                        sources="upload",
                        type="pil",
                        elem_id="content_image",
                    )
                with gr.Row():
                    num_gauss = gr.Slider(
                        minimum=1, maximum=20, step=1,
                        label="Number of Gaussians per Pixel", value=10,
                    )
                    scale_factor = gr.Slider(
                        minimum=0.5, maximum=5.0, step=0.1,
                        label="Scale Factor for Model Size", value=1.5,
                        info="Test this range for stability, as extreme values may cause visual distortions or unexpected outputs.",
                    )
                    padding_value = gr.Slider(
                        minimum=0, maximum=128, step=8,
                        label="Padding Amount for Output Processing", value=32,
                    )
                    resize_height = gr.Slider(
                        minimum=256, maximum=1024, step=64,
                        label="Resize Height for Image", value=cfg.dataset.height,
                    )
                    resize_width = gr.Slider(
                        minimum=256, maximum=1024, step=64,
                        label="Resize Width for Image", value=cfg.dataset.width,
                    )
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")

                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )

                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)

            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(
                            height=512,
                            label="Output Model",
                            interactive=False
                        )

        # Chain: validate upload -> preprocess -> reconstruct; each stage runs
        # only if the previous one succeeded.
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image, padding_value, resize_height, resize_width],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image, num_gauss, scale_factor],
            outputs=[output_model],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)


# Script entry point: only start the demo when executed directly.
if __name__ == "__main__":
    print("[INFO] Running application...")
    main()