File size: 2,208 Bytes
b19928f
 
a394b1d
b19928f
 
d6c2352
 
2e549d0
5a3dc03
e9d914e
d6c2352
b19928f
 
 
 
 
d6c2352
b19928f
2e549d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b19928f
 
 
 
 
2e549d0
b19928f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch

# Fetch the pretrained model checkpoints before loading the model.
# check=True surfaces a failed download immediately instead of letting
# create_model_and_transforms() fail later with a confusing error.
subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Load model and preprocessing transform (weights come from the
# checkpoints fetched above).
model, transform = depth_pro.create_model_and_transforms()
model.eval()  # inference mode: disables dropout, freezes batch-norm stats

@spaces.GPU(duration=120)
def predict_depth(input_image):
    """Predict a depth map for *input_image* and render it as a color plot.

    Args:
        input_image: Path to the input image file (the Gradio input is
            configured with type="filepath").

    Returns:
        Tuple of (path to the saved depth-map PNG, status string with the
        estimated focal length), or (None, error message) on failure.
    """
    try:
        # Preprocess the image. load_rgb returns a tuple whose first item
        # is the image and whose last item is the focal length in pixels
        # (presumably derived from EXIF; may be None — TODO confirm).
        result = depth_pro.load_rgb(input_image)
        image = result[0]
        f_px = result[-1]
        image = transform(image)

        # Run inference without tracking gradients: this is pure inference,
        # so autograd bookkeeping only wastes GPU memory and time.
        with torch.no_grad():
            prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert to plain numpy / Python scalars for plotting and formatting.
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()
        if isinstance(focallength_px, torch.Tensor):
            focallength_px = focallength_px.item()

        # Ensure depth is a 2D numpy array (drop any batch/channel dims).
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Normalize depth to [0, 1] for visualization. Guard against a
        # constant depth map, which would otherwise divide by zero and
        # produce an all-NaN image.
        depth_min = np.min(depth)
        depth_max = np.max(depth)
        depth_range = depth_max - depth_min
        if depth_range > 0:
            depth_normalized = (depth - depth_min) / depth_range
        else:
            depth_normalized = np.zeros_like(depth)

        # Render the normalized depth map with a color bar.
        plt.figure(figsize=(10, 10))
        plt.imshow(depth_normalized, cmap='viridis')
        plt.colorbar(label='Depth')
        plt.title('Predicted Depth Map')
        plt.axis('off')

        # Save the plot to a file and release the figure.
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        return output_path, f"Focal length: {focallength_px:.2f} pixels"
    except Exception as e:
        # Surface the failure to the UI text box instead of crashing the demo.
        return None, f"An error occurred: {str(e)}"

# Assemble the Gradio UI: one image input, two outputs (the rendered
# depth map and a text box for the focal length or an error message).
demo_outputs = [
    gr.Image(type="filepath", label="Depth Map"),
    gr.Textbox(label="Focal Length or Error Message"),
]
iface = gr.Interface(
    predict_depth,
    gr.Image(type="filepath"),
    demo_outputs,
    title="Depth Prediction Demo",
    description="Upload an image to predict its depth map and focal length.",
)

# Start the web server.
iface.launch()