ghostsInTheMachine commited on
Commit
7376db6
1 Parent(s): a9377e4

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +22 -16
infer.py CHANGED
@@ -6,6 +6,7 @@ from diffusers.utils import check_min_version
6
  from pipeline import LotusGPipeline, LotusDPipeline
7
  from utils.image_utils import colorize_depth_map
8
  from contextlib import nullcontext
 
9
 
10
  check_min_version('0.28.0.dev0')
11
 
@@ -18,21 +19,10 @@ def load_models(task_name, device):
18
  model_d = 'jingheya/lotus-normal-d-v1-0'
19
 
20
  dtype = torch.float16
21
- pipe_g = LotusGPipeline.from_pretrained(
22
- model_g,
23
- torch_dtype=dtype,
24
- )
25
- pipe_d = LotusDPipeline.from_pretrained(
26
- model_d,
27
- torch_dtype=dtype,
28
- )
29
- pipe_g.to(device)
30
- pipe_d.to(device)
31
- pipe_g.set_progress_bar_config(disable=True)
32
- pipe_d.set_progress_bar_config(disable=True)
33
- logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
34
- return pipe_g, pipe_d
35
 
 
36
  def infer_pipe(pipe, image, task_name, seed, device):
37
  if seed is None:
38
  generator = None
@@ -42,7 +32,7 @@ def infer_pipe(pipe, image, task_name, seed, device):
42
  if torch.backends.mps.is_available():
43
  autocast_ctx = nullcontext()
44
  else:
45
- autocast_ctx = torch.autocast(pipe.device.type)
46
 
47
  with torch.no_grad():
48
  with autocast_ctx:
@@ -77,6 +67,22 @@ def infer_pipe(pipe, image, task_name, seed, device):
77
 
78
  return output_color
79
 
80
- def lotus(image, task_name, seed, device, pipe_g, pipe_d):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  output_d = infer_pipe(pipe_d, image, task_name, seed, device)
82
  return output_d # Only returning depth outputs for this application
 
6
  from pipeline import LotusGPipeline, LotusDPipeline
7
  from utils.image_utils import colorize_depth_map
8
  from contextlib import nullcontext
9
+ import spaces # Import the spaces module for ZeroGPU
10
 
11
  check_min_version('0.28.0.dev0')
12
 
 
19
  model_d = 'jingheya/lotus-normal-d-v1-0'
20
 
21
  dtype = torch.float16
22
+ # Models will be loaded inside the GPU-decorated function
23
+ return model_g, model_d, dtype
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ @spaces.GPU
26
  def infer_pipe(pipe, image, task_name, seed, device):
27
  if seed is None:
28
  generator = None
 
32
  if torch.backends.mps.is_available():
33
  autocast_ctx = nullcontext()
34
  else:
35
+ autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)
36
 
37
  with torch.no_grad():
38
  with autocast_ctx:
 
67
 
68
  return output_color
69
 
70
+ @spaces.GPU
71
+ def lotus(image, task_name, seed, device, model_g, model_d, dtype):
72
+ # Load models inside the GPU-decorated function
73
+ pipe_g = LotusGPipeline.from_pretrained(
74
+ model_g,
75
+ torch_dtype=dtype,
76
+ )
77
+ pipe_d = LotusDPipeline.from_pretrained(
78
+ model_d,
79
+ torch_dtype=dtype,
80
+ )
81
+ pipe_g.to(device)
82
+ pipe_d.to(device)
83
+ pipe_g.set_progress_bar_config(disable=True)
84
+ pipe_d.set_progress_bar_config(disable=True)
85
+ logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
86
+
87
  output_d = infer_pipe(pipe_d, image, task_name, seed, device)
88
  return output_d # Only returning depth outputs for this application