hysts (HF staff) committed on
Commit a8d124d • 1 Parent(s): 2d03dfb

Use Gradio 4.x so it can work with ZeroGPU

Files changed (2)
  1. README.md +1 -1
  2. app.py +19 -12
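
ZeroGPU Spaces attach a GPU only while a spaces.GPU-decorated function is running, and they require Gradio 4.x. The sketch below is a minimal, hypothetical example of that pattern (placeholder model and function names, not the app itself); the actual change to app.py follows in the diff.

# Minimal sketch of the ZeroGPU pattern this commit adopts (hypothetical
# placeholder model; the Space itself wraps its DPT_DINOv2 depth model).
import spaces
import torch

model = torch.nn.Linear(4, 1).eval()  # stand-in for the real model

@spaces.GPU            # ZeroGPU grants a GPU only for the duration of this call
@torch.no_grad()       # inference only, no gradients tracked
def predict(x: torch.Tensor) -> torch.Tensor:
    return model(x)

if __name__ == '__main__':
    print(predict(torch.zeros(1, 4)))

Outside a ZeroGPU Space the decorator is effectively a no-op, so the same code also runs locally.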
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌖
 colorFrom: red
 colorTo: indigo
 sdk: gradio
-sdk_version: 3.50.2
+sdk_version: 4.14.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import cv2
 import numpy as np
 from PIL import Image
+import spaces
 import torch
 import torch.nn.functional as F
 from torchvision.transforms import Compose
@@ -20,14 +21,14 @@ css = """
 #img-display-output {
     max-height: 160vh;
 }
-
+
 """
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 model = DPT_DINOv2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).to(DEVICE).eval()
 model.load_state_dict(torch.load('checkpoints/depth_anything_vitl14.pth'))
 
 title = "# Depth Anything"
-description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data**.
+description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data**.
 
 Please refer to our [paper](), [project page](https://depth-anything.github.io), or [github](https://github.com/LiheYoung/Depth-Anything) for more details."""
 
@@ -45,38 +46,44 @@ transform = Compose([
     PrepareForNet(),
 ])
 
+
+@spaces.GPU
+@torch.no_grad()
+def predict_depth(model, image):
+    return model(image)
+
+
 with gr.Blocks(css=css) as demo:
     gr.Markdown(title)
     gr.Markdown(description)
     gr.Markdown("### Depth Prediction demo")
-
+
     with gr.Row():
-        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input').style(height="auto")
+        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
         depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
     raw_file = gr.File(label="16-bit raw depth (can be considered as disparity)")
     submit = gr.Button("Submit")
 
     def on_submit(image):
         h, w = image.shape[:2]
-
+
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
         image = transform({'image': image})['image']
         image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
-
-        with torch.no_grad():
-            depth = model(image)
+
+        depth = predict_depth(model, image)
         depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
-
+
         raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
         tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
         raw_depth.save(tmp.name)
-
+
         depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
        depth = depth.cpu().numpy().astype(np.uint8)
         colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
-
+
         return [colored_depth, tmp.name]
-
+
     submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
     examples = gr.Examples(examples=["examples/flower.png", "examples/roller_coaster.png", "examples/hall.png", "examples/car.png", "examples/person.png"],
                            inputs=[input_image])
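
Two Gradio 4.x adjustments are visible above: the .style(height="auto") call is removed, since the .style() helper no longer exists in 4.x (its options became constructor arguments), and the forward pass moves into the spaces.GPU-decorated predict_depth so ZeroGPU can attach a GPU per call. If a single call ever needed more than the default GPU time slot, the decorator also accepts a duration argument; a hypothetical variant, not part of this commit:

import spaces
import torch

# Hypothetical variant: request a longer ZeroGPU slot (in seconds) for
# inference on large images; otherwise identical to the committed function.
@spaces.GPU(duration=120)
@torch.no_grad()
def predict_depth(model, image):
    return model(image)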