nev commited on
Commit
3fce28b
·
1 Parent(s): a336793

Ininitial commit

Browse files
Files changed (4) hide show
  1. app.py +24 -0
  2. depth.py +60 -0
  3. horse.jpg +0 -0
  4. serve_modal.py +9 -0
app.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from depth import MidasDepth
2
+ import gradio as gr
3
+ import numpy as np
4
+ import cv2
5
+
6
+
7
+ depth_estimator = MidasDepth()
8
+
9
+
10
+ def get_depth(rgb):
11
+ depth = depth_estimator.get_depth(rgb)
12
+
13
+ return rgb, (depth.clip(0, 64) * 1024).astype("uint16")
14
+
15
+
16
+ starter = gr.Interface(fn=get_depth, inputs=[
17
+ gr.components.Image(label="rgb", type="pil"),
18
+ ], outputs=[
19
+ gr.components.Image(type="pil", label="image"),
20
+ gr.components.Image(type="numpy", label="depth"),
21
+
22
+ ])
23
+
24
+ gr.Interface(get_depth).launch(share=True)
depth.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+ from tqdm.auto import trange
3
+ from PIL import ImageOps
4
+ from PIL import Image
5
+ from torch import nn
6
+ import numpy as np
7
+ import torch
8
+ import cv2
9
+
10
+
11
+ class MidasDepth(nn.Module):
12
+ def __init__(self, model_type="DPT_Large",
13
+ device=torch.device(
14
+ "cuda" if torch.cuda.is_available() else "cpu"),
15
+ is_inpainting=False):
16
+ super().__init__()
17
+ self.device = device
18
+ if self.device.type == "mps":
19
+ self.device = torch.device("cpu")
20
+ self.model = torch.hub.load(
21
+ "intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
22
+ self.transform = torch.hub.load(
23
+ "intel-isl/MiDaS", "transforms").dpt_transform
24
+
25
+ @torch.no_grad()
26
+ def forward(self, image):
27
+ if torch.is_tensor(image):
28
+ image = image.cpu().detach()
29
+ if not isinstance(image, np.ndarray):
30
+ image = np.asarray(image)
31
+ image = image.squeeze()
32
+ batch = self.transform(image).to(self.device)
33
+ prediction = self.model(batch)
34
+ prediction = torch.nn.functional.interpolate(
35
+ prediction.unsqueeze(1),
36
+ size=image.shape[-3:-1],
37
+ mode="bicubic",
38
+ align_corners=False,
39
+ )[:, 0]
40
+ # prediction = prediction - prediction.min() + 1.5
41
+ # prediction = 20 / prediction
42
+ return prediction # .squeeze()
43
+
44
+ @torch.no_grad()
45
+ def get_depth(self, img):
46
+ im = torch.from_numpy(np.asarray(img)).float().to(self.device) / 255.
47
+ og_depth = self(im.unsqueeze(0) * 255.)[0]
48
+ d = og_depth
49
+ d = (d - d.min()) / (d.max() - d.min()) * (10 - 3) + 3
50
+ d = 30 / d
51
+ # d = d.max() - d
52
+ # d = d / d.max() * 15
53
+ # d = d + 1.5
54
+ return d.detach().cpu().numpy()
55
+
56
+
57
+ if __name__ == "__main__":
58
+ from matplotlib import pyplot as plt
59
+ plt.imshow(MidasDepth().get_depth(Image.open("horse.jpg")))
60
+ plt.show()
horse.jpg ADDED
serve_modal.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import modal
2
+
3
+
4
+ stub = modal.Stub()
5
+
6
+
7
+ @stub.function()
8
+ def estimate_depth(image):
9
+ pass