oxkitsune committed on
Commit
4d29a77
·
1 Parent(s): 846566d

move image to correct device

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -5,6 +5,7 @@ import subprocess
5
 
6
  import torch
7
  import cv2
 
8
  import os
9
  from pathlib import Path
10
  import gradio as gr
@@ -24,9 +25,30 @@ model, transform = depth_pro.create_model_and_transforms()
24
  model = model.to(device)
25
  model.eval()
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- @rr.thread_local_stream("rerun_example_ml_depth_pro")
29
  @spaces.GPU(duration=20)
 
 
 
 
 
 
 
 
 
30
  def run_ml_depth_pro(frame):
31
  stream = rr.binary_stream()
32
 
@@ -51,16 +73,15 @@ def run_ml_depth_pro(frame):
51
  rr.set_time_sequence("frame", 0)
52
  rr.log("world/camera/image", rr.Image(frame))
53
 
54
- image = transform(frame)
55
- prediction = model.infer(image)
56
- depth = prediction["depth"].squeeze().detach().cpu().numpy()
57
 
58
  rr.log(
59
  "world/camera",
60
  rr.Pinhole(
61
  width=frame.shape[1],
62
  height=frame.shape[0],
63
- focal_length=prediction["focallength_px"].item(),
64
  principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
65
  image_plane_distance=depth.max(),
66
  ),
 
5
 
6
  import torch
7
  import cv2
8
+ import numpy as np
9
  import os
10
  from pathlib import Path
11
  import gradio as gr
 
25
  model = model.to(device)
26
  model.eval()
27
 
28
def resize_image(image_path, max_size=1536):
    """Rescale an image so its longer side equals ``max_size`` and save it as PNG.

    The aspect ratio is preserved. The scale factor is always applied, so
    images whose longer side is below ``max_size`` are enlarged as well.

    Args:
        image_path: Path (or file object) accepted by ``Image.open``.
        max_size: Target length, in pixels, of the longer image side.

    Returns:
        The filesystem path of a newly created temporary ``.png`` file holding
        the resized image. The file is not deleted automatically; the caller
        is responsible for removing it.
    """
    with Image.open(image_path) as img:
        # Scale factor that maps the longer side onto max_size.
        ratio = max_size / max(img.size)

        # int() truncates, so a very elongated image could end up with a
        # 0-pixel side, which resize() rejects; keep every side >= 1 px.
        new_size = tuple(max(1, int(side * ratio)) for side in img.size)

        resized = img.resize(new_size, Image.LANCZOS)

        # delete=False keeps the file on disk after the handle is closed so
        # the returned path stays usable by the caller.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            resized.save(temp_file, format="PNG")
            return temp_file.name
41
 
 
42
  @spaces.GPU(duration=20)
43
def predict(frame):
    """Run the depth model on one frame and return CPU-side results.

    The frame is passed through the module-level ``transform``, moved to the
    module-level ``device`` (the one the model was moved to), and fed to
    ``model.infer``.

    Args:
        frame: Input image in whatever form ``transform`` accepts
            (presumably an RGB array; confirm against the caller).

    Returns:
        A ``(depth, focal_length)`` tuple where ``depth`` is a NumPy array
        built from the squeezed ``prediction["depth"]`` tensor and
        ``focal_length`` is the Python scalar taken from
        ``prediction["focallength_px"]``.
    """
    # The transformed tensor must live on the same device as the model.
    image = transform(frame).to(device)
    prediction = model.infer(image)

    # Convert once: after .numpy() the result is an ndarray, which has no
    # .cpu() method, so it must be returned as-is.
    depth = prediction["depth"].squeeze().detach().cpu().numpy()
    focal_length = prediction["focallength_px"].item()
    return depth, focal_length
49
+
50
+
51
+ @rr.thread_local_stream("rerun_example_ml_depth_pro")
52
  def run_ml_depth_pro(frame):
53
  stream = rr.binary_stream()
54
 
 
73
  rr.set_time_sequence("frame", 0)
74
  rr.log("world/camera/image", rr.Image(frame))
75
 
76
+ depth, focal_length = predict(frame)
77
+
 
78
 
79
  rr.log(
80
  "world/camera",
81
  rr.Pinhole(
82
  width=frame.shape[1],
83
  height=frame.shape[0],
84
+ focal_length=focal_length,
85
  principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
86
  image_plane_distance=depth.max(),
87
  ),