qubvel-hf HF staff commited on
Commit
1837eda
1 Parent(s): e186d7e
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -94,7 +94,7 @@ def predict_depth_v1(image, model_name):
94
  def predict_depth_v2(image, model_name):
95
  if model_name not in depth_anything_v2_models:
96
  depth_anything_v2_models[model_name] = get_v2_model(model_name)
97
- model = depth_anything_v2_models[model_name]
98
  return model.infer_image(image)
99
 
100
 
@@ -115,6 +115,7 @@ def compute_depth_map_v1(image, model_select):
115
 
116
 
117
  @spaces.GPU
 
118
  def on_submit(image, model_v1_select, model_v2_select):
119
  logger.info(f"Computing depth for V1 model: {model_v1_select} and V2 model: {model_v2_select}")
120
  colored_depth_v1 = compute_depth_map_v1(image, model_v1_select)
@@ -142,10 +143,7 @@ with gr.Blocks(css=css) as demo:
142
  example_files = os.listdir('assets/examples')
143
  example_files.sort()
144
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
145
- examples = gr.Examples(
146
- examples=example_files, inputs=[input_image, model_select_v1, model_select_v2],
147
- outputs=[depth_image_slider], fn=on_submit, cache_examples="lazy",
148
- )
149
 
150
 
151
  if __name__ == '__main__':
 
94
  def predict_depth_v2(image, model_name):
95
  if model_name not in depth_anything_v2_models:
96
  depth_anything_v2_models[model_name] = get_v2_model(model_name)
97
+ model = depth_anything_v2_models[model_name].cuda()
98
  return model.infer_image(image)
99
 
100
 
 
115
 
116
 
117
  @spaces.GPU
118
+ @torch.no_grad()
119
  def on_submit(image, model_v1_select, model_v2_select):
120
  logger.info(f"Computing depth for V1 model: {model_v1_select} and V2 model: {model_v2_select}")
121
  colored_depth_v1 = compute_depth_map_v1(image, model_v1_select)
 
143
  example_files = os.listdir('assets/examples')
144
  example_files.sort()
145
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
146
+ examples = gr.Examples(examples=example_files, inputs=[input_image])
 
 
 
147
 
148
 
149
  if __name__ == '__main__':