ghostsInTheMachine commited on
Commit
e4ce1e7
1 Parent(s): 3508f0c

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +21 -20
infer.py CHANGED
@@ -10,7 +10,13 @@ import spaces # Import the spaces module for ZeroGPU
10
 
11
  check_min_version('0.28.0.dev0')
12
 
 
 
 
 
 
13
  def load_models(task_name, device):
 
14
  if task_name == 'depth':
15
  model_g = 'jingheya/lotus-depth-g-v1-0'
16
  model_d = 'jingheya/lotus-depth-d-v1-1'
@@ -19,10 +25,20 @@ def load_models(task_name, device):
19
  model_d = 'jingheya/lotus-normal-d-v1-0'
20
 
21
  dtype = torch.float16
22
- # Models will be loaded inside the GPU-decorated function
23
- return model_g, model_d, dtype
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- @spaces.GPU
26
  def infer_pipe(pipe, image, task_name, seed, device):
27
  if seed is None:
28
  generator = None
@@ -67,22 +83,7 @@ def infer_pipe(pipe, image, task_name, seed, device):
67
 
68
  return output_color
69
 
70
- @spaces.GPU
71
- def lotus(image, task_name, seed, device, model_g, model_d, dtype):
72
- # Load models inside the GPU-decorated function
73
- pipe_g = LotusGPipeline.from_pretrained(
74
- model_g,
75
- torch_dtype=dtype,
76
- )
77
- pipe_d = LotusDPipeline.from_pretrained(
78
- model_d,
79
- torch_dtype=dtype,
80
- )
81
- pipe_g.to(device)
82
- pipe_d.to(device)
83
- pipe_g.set_progress_bar_config(disable=True)
84
- pipe_d.set_progress_bar_config(disable=True)
85
- logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
86
-
87
  output_d = infer_pipe(pipe_d, image, task_name, seed, device)
88
  return output_d # Only returning depth outputs for this application
 
10
 
11
  check_min_version('0.28.0.dev0')
12
 
13
+ # Global variables to store the models
14
+ pipe_g = None
15
+ pipe_d = None
16
+
17
+ @spaces.GPU
18
  def load_models(task_name, device):
19
+ global pipe_g, pipe_d # Use global variables to store the models
20
  if task_name == 'depth':
21
  model_g = 'jingheya/lotus-depth-g-v1-0'
22
  model_d = 'jingheya/lotus-depth-d-v1-1'
 
25
  model_d = 'jingheya/lotus-normal-d-v1-0'
26
 
27
  dtype = torch.float16
28
+ pipe_g = LotusGPipeline.from_pretrained(
29
+ model_g,
30
+ torch_dtype=dtype,
31
+ )
32
+ pipe_d = LotusDPipeline.from_pretrained(
33
+ model_d,
34
+ torch_dtype=dtype,
35
+ )
36
+ pipe_g.to(device)
37
+ pipe_d.to(device)
38
+ pipe_g.set_progress_bar_config(disable=True)
39
+ pipe_d.set_progress_bar_config(disable=True)
40
+ logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
41
 
 
42
  def infer_pipe(pipe, image, task_name, seed, device):
43
  if seed is None:
44
  generator = None
 
83
 
84
  return output_color
85
 
86
+ def lotus(image, task_name, seed, device):
87
+ global pipe_g, pipe_d # Access the global models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  output_d = infer_pipe(pipe_d, image, task_name, seed, device)
89
  return output_d # Only returning depth outputs for this application