pablovela5620 commited on
Commit
5da60f1
·
verified ·
1 Parent(s): 062a096

Upload gradio_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gradio_app.py +9 -5
gradio_app.py CHANGED
@@ -11,9 +11,7 @@ except ImportError:
11
  import torch
12
 
13
  from monopriors.relative_depth_models import (
14
- DepthAnythingV2Predictor,
15
  RelativeDepthPrediction,
16
- UniDepthRelativePredictor,
17
  get_relative_predictor,
18
  RELATIVE_PREDICTORS,
19
  )
@@ -37,9 +35,10 @@ model_load_status: str = "Models loaded and ready to use!"
37
  DEVICE: Literal["cuda"] | Literal["cpu"] = (
38
  "cuda" if torch.cuda.is_available() else "cpu"
39
  )
 
40
  if gr.NO_RELOAD:
41
- MODEL_1 = DepthAnythingV2Predictor(device=DEVICE)
42
- MODEL_2 = UniDepthRelativePredictor(device=DEVICE)
43
 
44
 
45
  def predict_depth(
@@ -53,6 +52,7 @@ def predict_depth(
53
  if IN_SPACES:
54
  predict_depth = spaces.GPU(predict_depth)
55
  # remove any model that fails on zerogpu spaces
 
56
 
57
 
58
  def load_models(
@@ -72,11 +72,15 @@ def load_models(
72
  progress(0, desc="Loading Models please wait...")
73
 
74
  models: list[int] = [model_1, model_2]
 
 
 
 
 
75
  loaded_models = []
76
 
77
  for model in models:
78
  loaded_models.append(get_relative_predictor(model)(device=DEVICE))
79
-
80
  progress(0.5, desc=f"Loaded {model}")
81
 
82
  progress(1, desc="Models Loaded")
 
11
  import torch
12
 
13
  from monopriors.relative_depth_models import (
 
14
  RelativeDepthPrediction,
 
15
  get_relative_predictor,
16
  RELATIVE_PREDICTORS,
17
  )
 
35
  DEVICE: Literal["cuda"] | Literal["cpu"] = (
36
  "cuda" if torch.cuda.is_available() else "cpu"
37
  )
38
+ MODELS_TO_SKIP: list[str] = []
39
  if gr.NO_RELOAD:
40
+ MODEL_1 = get_relative_predictor("DepthAnythingV2Predictor")(device=DEVICE)
41
+ MODEL_2 = get_relative_predictor("Metric3DRelativePredictor")(device=DEVICE)
42
 
43
 
44
  def predict_depth(
 
52
  if IN_SPACES:
53
  predict_depth = spaces.GPU(predict_depth)
54
  # remove any model that fails on zerogpu spaces
55
+ MODELS_TO_SKIP.extend(["Metric3DRelativePredictor"])
56
 
57
 
58
  def load_models(
 
72
  progress(0, desc="Loading Models please wait...")
73
 
74
  models: list[int] = [model_1, model_2]
75
+ # check if the models are in the list of models to skip
76
+ if any(model in MODELS_TO_SKIP for model in models):
77
+ raise gr.Error(
78
+ f"Model not supported on ZeroGPU, please try another model: {MODELS_TO_SKIP}"
79
+ )
80
  loaded_models = []
81
 
82
  for model in models:
83
  loaded_models.append(get_relative_predictor(model)(device=DEVICE))
 
84
  progress(0.5, desc=f"Loaded {model}")
85
 
86
  progress(1, desc="Models Loaded")