JUGGHM committed on
Commit
0c39f3c
1 Parent(s): 69543a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -10
app.py CHANGED
@@ -33,19 +33,34 @@ from mono.utils.transform import gray_to_colormap
33
  from mono.utils.visualization import vis_surface_normal
34
  import gradio as gr
35
 
36
- cfg = Config.fromfile('./mono/configs/HourglassDecoder/vit.raft5.large.py')
37
-
38
  torch.hub.download_url_to_file('https://images.unsplash.com/photo-1437622368342-7a3d73a34c8f', 'turtle.jpg')
39
  torch.hub.download_url_to_file('https://images.unsplash.com/photo-1519066629447-267fffa62d4b', 'lions.jpg')
40
 
41
- model = get_configured_monodepth_model(cfg, )
42
- model, _, _, _ = load_ckpt('./weight/metric_depth_vit_large_800k.pth', model, strict_match=False)
43
- model.eval()
 
44
 
45
- device = "cpu"
46
- model.to(device)
 
 
47
 
48
- def depth_normal(img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  cv_image = np.array(img)
50
  img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
51
  intrinsic = [1000.0, 1000.0, img.shape[1]/2, img.shape[0]/2]
@@ -86,7 +101,7 @@ def depth_normal(img):
86
  #normal = gr.outputs.Image(type="pil",label="Output Normal")
87
 
88
  title = "Metric3D"
89
- description = "Gradio demo for Metric3D (vit-large) running on CPU which takes in a single image for computing metric depth and surface normal. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
90
  article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2307.10984.pdf'>Metric3D: Towards Zero-shot Metric 3D Prediction from A Single Image</a> | <a href='https://github.com/YvanYin/Metric3D'>Github Repo</a></p>"
91
 
92
  examples = [
@@ -96,6 +111,6 @@ examples = [
96
 
97
  gr.Interface(
98
  depth_normal,
99
- inputs=[gr.Image(type='pil', label="Original Image")],
100
  outputs=[gr.Image(type="pil",label="Output Depth"), gr.Image(type="pil",label="Output Normal")],
101
  title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()
 
33
  from mono.utils.visualization import vis_surface_normal
34
  import gradio as gr
35
 
 
 
36
  torch.hub.download_url_to_file('https://images.unsplash.com/photo-1437622368342-7a3d73a34c8f', 'turtle.jpg')
37
  torch.hub.download_url_to_file('https://images.unsplash.com/photo-1519066629447-267fffa62d4b', 'lions.jpg')
38
 
39
+ cfg_large = Config.fromfile('./mono/configs/HourglassDecoder/vit.raft5.large.py')
40
+ model_large = get_configured_monodepth_model(cfg_large, )
41
+ model_large, _, _, _ = load_ckpt('./weight/metric_depth_vit_large_800k.pth', model_large, strict_match=False)
42
+ model_large.eval()
43
 
44
+ cfg_small = Config.fromfile('./mono/configs/HourglassDecoder/vit.raft5.small.py')
45
+ model_small = get_configured_monodepth_model(cfg_small, )
46
+ model_small, _, _, _ = load_ckpt('./weight/metric_depth_vit_small_800k.pth', model_small, strict_match=False)
47
+ model_small.eval()
48
 
49
+ device = "cpu"
50
+ model_large.to(device)
51
+ model_small.to(device)
52
+
53
+ def depth_normal(img, model_selection="vit-small"):
54
+ if model_selection == "vit-small":
55
+ model = model_small
56
+ cfg = cfg_small
57
+ elif model_selection == "vit-large":
58
+ model = model_large
59
+ cfg = cfg_large
60
+
61
+ else:
62
+ raise NotImplementedError
63
+
64
  cv_image = np.array(img)
65
  img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
66
  intrinsic = [1000.0, 1000.0, img.shape[1]/2, img.shape[0]/2]
 
101
  #normal = gr.outputs.Image(type="pil",label="Output Normal")
102
 
103
  title = "Metric3D"
104
+ description = "Gradio demo for Metric3D (v2, more diverse models) running on CPU which takes in a single image for computing metric depth and surface normal. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
105
  article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2307.10984.pdf'>Metric3D: Towards Zero-shot Metric 3D Prediction from A Single Image</a> | <a href='https://github.com/YvanYin/Metric3D'>Github Repo</a></p>"
106
 
107
  examples = [
 
111
 
112
  gr.Interface(
113
  depth_normal,
114
+ inputs=[gr.Image(type='pil', label="Original Image"), gr.Dropdown(["vit-small", "vit-large"], label="Model", info="Will support more models later!")],
115
  outputs=[gr.Image(type="pil",label="Output Depth"), gr.Image(type="pil",label="Output Normal")],
116
  title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()