import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time  # Used to timestamp exported model filenames

# Run the script to download the pretrained model weights
subprocess.run(["bash", "get_pretrained_models.sh"], check=True)  # check=True stops the app early if the download fails

# Set the device to GPU if available, else CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the depth prediction model and its preprocessing transforms
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)  # Move the model to the selected device
model.eval()  # Set the model to evaluation mode


def resize_image(image_path, max_size=1024):
    """
    Resize the input image so that its largest dimension does not exceed max_size.
    Maintains the aspect ratio and saves the resized image as a temporary PNG file.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.
    """
    with Image.open(image_path) as img:
        # Compute the scaling ratio that preserves the aspect ratio
        ratio = max_size / max(img.size)
        new_size = tuple(int(x * ratio) for x in img.size)

        # Resize with the LANCZOS filter for high-quality downsampling
        img = img.resize(new_size, Image.LANCZOS)

        # Save the resized image to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name
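
# Example usage (hypothetical path, for illustration only):
#     small_png = resize_image("examples/room.jpg", max_size=1024)
#     # -> path to a temporary PNG whose largest side is at most 1024 px;
#     #    the caller is responsible for deleting the file when done.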


def generate_3d_model(depth, image_path, focallength_px):
    """
    Generate a textured 3D mesh from the depth map and the original image.

    Args:
        depth (np.ndarray): 2D array representing depth in meters.
        image_path (str): Path to the resized RGB image.
        focallength_px (float): Focal length in pixels.

    Returns:
        tuple: Paths to the exported 3D model files for viewing and downloading.
    """
    # Load the RGB image and convert it to a NumPy array
    image = np.array(Image.open(image_path))
    height, width = depth.shape

    # Camera intrinsics: assume square pixels (fx = fy) with the principal
    # point at the image center
    fx = fy = focallength_px
    cx, cy = width / 2, height / 2

    # Create a grid of (u, v) pixel coordinates
    u = np.arange(width)
    v = np.arange(height)
    uu, vv = np.meshgrid(u, v)

    # Back-project pixels to 3D camera coordinates with the pinhole model:
    # X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy

    # Stack the coordinates to form vertices (X, Y, Z)
    vertices = np.vstack((X, Y, Z)).T
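    # Worked example of the back-projection above (illustrative numbers):
    # with fx = 1000 px and cx = 512, a pixel at u = 712 with depth Z = 2.0 m
    # lands at X = (712 - 512) * 2.0 / 1000 = 0.4 m to the right of the
    # optical axis; Y follows the same formula with v, cy, and fy.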

    # Normalize RGB colors to [0, 1] for vertex coloring
    colors = image.reshape(-1, 3) / 255.0

    # Generate faces by connecting adjacent vertices into two triangles per pixel quad
    faces = []
    for i in range(height - 1):
        for j in range(width - 1):
            idx = i * width + j
            faces.append([idx, idx + width, idx + 1])              # Triangle 1
            faces.append([idx + 1, idx + width, idx + width + 1])  # Triangle 2
    faces = np.array(faces)
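    # Note: this Python loop is O(height * width) and can be slow for large
    # meshes. A vectorized sketch with the same triangle layout (untested,
    # shown for reference only):
    #     idx = (np.arange(height - 1)[:, None] * width + np.arange(width - 1)).ravel()
    #     t1 = np.stack([idx, idx + width, idx + 1], axis=1)
    #     t2 = np.stack([idx + 1, idx + width, idx + width + 1], axis=1)
    #     faces = np.concatenate([t1, t2], axis=0)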

    # Create the mesh with per-vertex colors using trimesh
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

    # Export the mesh to OBJ files with unique, timestamped filenames
    timestamp = int(time.time())
    view_model_path = f'view_model_{timestamp}.obj'
    download_model_path = f'download_model_{timestamp}.obj'
    mesh.export(view_model_path)
    mesh.export(download_model_path)
    return view_model_path, download_model_path


@spaces.GPU  # Request a ZeroGPU device for the duration of each call; this Space runs on ZeroGPU
def predict_depth(input_image):
    """
    Predict the depth map from the input image, generate visualizations and a 3D model.

    Args:
        input_image (str): Path to the input image file.

    Returns:
        tuple:
            - str: Path to the depth map image.
            - str: Focal length in pixels or an error message.
            - str: Path to the raw depth data CSV file.
            - str: Path to the generated 3D model file for viewing.
            - str: Path to the downloadable 3D model file.
    """
    temp_file = None
    try:
        # Resize the input image to a manageable size
        temp_file = resize_image(input_image)

        # Load the RGB image and the EXIF focal length, if present
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # Focal length in pixels

        image = transform(image)  # Apply preprocessing transforms
        image = image.to(device)  # Move the image tensor to the selected device

        # Run the depth prediction model
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth map in meters
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert the depth map from a torch tensor to a NumPy array if necessary
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Ensure the depth map is a 2D array
        if depth.ndim != 2:
            depth = depth.squeeze()
        # Downsample the depth map and image to speed up mesh generation
        downscale_factor = 2  # Factor by which to downscale (e.g., 2 halves each dimension)
        depth = depth[::downscale_factor, ::downscale_factor]

        # Move the image tensor to the CPU and convert to NumPy for slicing
        image_np = image.cpu().detach().numpy()[0].transpose(1, 2, 0)
        image_ds = image_np[::downscale_factor, ::downscale_factor, :]

        # Scale the focal length to match the downsampled resolution
        focallength_px = focallength_px / downscale_factor
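        # Worked example: with downscale_factor = 2, a 1536 x 1536 depth map
        # becomes 768 x 768 and a focal length of 1000 px becomes 500 px,
        # so the pinhole geometry is preserved at the new resolution.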
        # Save the downscaled image back to the temporary file so the 3D model
        # is textured at the same resolution as the depth map
        downscaled_image = Image.fromarray((image_ds * 255).astype(np.uint8))
        downscaled_image.save(temp_file)
        # The depth map stays in meters; no normalization is needed
        depth_min = np.min(depth)
        depth_max = np.max(depth)

        # Render the depth map with a color bar using matplotlib
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
        plt.axis('off')  # Hide the axes for a cleaner image

        # Save the depth map visualization to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()
        # Save the raw depth data to a CSV file for download
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')
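        # The CSV is plain text and can be reloaded later, e.g.:
        #     depth = np.loadtxt("raw_depth_map.csv", delimiter=',')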
        # Generate the 3D model from the depth map and the downscaled image
        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path
    except Exception as e:
        # Surface the error message in the UI instead of crashing
        return None, f"An error occurred: {str(e)}", None, None, None
    finally:
        # Clean up by removing the temporary resized image file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
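
# The pipeline can also be exercised without the UI (hypothetical path and
# variable names, for illustration only):
#     depth_png, focal_msg, csv_path, view_obj, dl_obj = predict_depth("photo.jpg")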

# Create the Gradio interface with appropriate input and output components
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Depth Map"),       # Displays the depth map image
        gr.Textbox(label="Focal Length or Error Message"),  # Shows the focal length or an error message
        gr.File(label="Download Raw Depth Map (CSV)"),      # Allows downloading the raw depth data
        gr.Model3D(label="View 3D Model"),                  # Interactive viewer for the 3D model
        gr.File(label="Download 3D Model (OBJ)")            # Allows downloading the 3D model
    ],
    title="DepthPro Demo with 3D Visualization",
    description=(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "**Instructions:**\n"
        "1. Upload an image.\n"
        "2. The app will predict the depth map, display it, and provide the focal length.\n"
        "3. Download the raw depth data as a CSV file.\n"
        "4. View the generated 3D model textured with the original image.\n"
        "5. Download the 3D model as an OBJ file if desired."
    ),
)

# Launch the Gradio interface. share=True creates a public link when running
# locally; Hugging Face Spaces ignores this flag.
iface.launch(share=True)