"""Gradio demo that predicts a depth map and focal length with Depth Pro."""

import os
import subprocess
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image  # noqa: F401  (kept from the original import list)

import src.depth_pro as depth_pro


def _download_pretrained_models(script="get_pretrained_models.sh"):
    """Run the pretrained-model setup script with bash (best effort).

    The original ``os.system("source ...")`` executed under ``/bin/sh``;
    ``source`` is a bash builtin that POSIX shells such as dash lack, so the
    script could silently never run. Failures are reported, not raised,
    matching the original fire-and-forget behaviour.
    """
    try:
        completed = subprocess.run(["bash", script], check=False)
    except OSError as err:  # e.g. bash not installed
        print(f"Warning: could not run {script}: {err}")
        return
    if completed.returncode != 0:
        print(f"Warning: {script} exited with status {completed.returncode}")


_download_pretrained_models()

# Load model and preprocessing transform
model, transform = depth_pro.create_model_and_transforms()
model.eval()


def _to_numpy(value):
    """Return *value* as a float64 NumPy array.

    Objects exposing ``.detach()`` (presumably torch tensors, possibly on an
    accelerator) are moved to host memory first, because NumPy reductions do
    not reliably accept them directly.
    """
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float64)


def _normalize_depth(depth):
    """Linearly rescale *depth* into [0, 1]; a constant map becomes all zeros."""
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range == 0:
        # Avoid 0/0 -> NaN image when every pixel has the same depth.
        return np.zeros_like(depth)
    return (depth - depth_min) / depth_range


def _render_depth_map(depth_normalized):
    """Plot *depth_normalized* with a colorbar and save it to a unique PNG path."""
    # Explicit figure/axes instead of pyplot's global "current figure", which
    # is shared state across concurrent Gradio request threads.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        im = ax.imshow(depth_normalized, cmap='viridis')
        fig.colorbar(im, ax=ax, label='Depth')
        ax.set_title('Predicted Depth Map')
        ax.axis('off')

        # One file per request so concurrent users never overwrite each other.
        fd, output_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    return output_path


def predict_depth(input_image):
    """Predict a depth map and focal length for one uploaded image.

    Args:
        input_image: Path of the uploaded image. ``gr.Image(type="filepath")``
            passes a plain ``str``; objects exposing a ``.name`` path
            attribute are also accepted.

    Returns:
        Tuple of (path to the rendered depth-map PNG, focal-length text).

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    # A str has no ``.name``; fall back to the value itself in that case.
    image_path = getattr(input_image, "name", input_image)

    # Preprocess the image
    result = depth_pro.load_rgb(image_path)
    image = result[0]
    f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
    image = transform(image)

    # Run inference
    prediction = model.infer(image, f_px=f_px)
    depth = _to_numpy(prediction["depth"])  # Depth in [m]
    focallength_px = float(prediction["focallength_px"])  # Focal length in pixels

    output_path = _render_depth_map(_normalize_depth(depth))
    return output_path, f"Focal length: {focallength_px:.2f} pixels"


# Create Gradio interface
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Depth Map"),
        gr.Textbox(label="Focal Length"),
    ],
    title="Depth Prediction Demo",
    description="Upload an image to predict its depth map and focal length.",
)

if __name__ == "__main__":
    # Launch the interface
    iface.launch()