ghostsInTheMachine committed
Commit 7376db6
1 Parent(s): a9377e4
Update infer.py

infer.py CHANGED
@@ -6,6 +6,7 @@ from diffusers.utils import check_min_version
 from pipeline import LotusGPipeline, LotusDPipeline
 from utils.image_utils import colorize_depth_map
 from contextlib import nullcontext
+import spaces  # Import the spaces module for ZeroGPU
 
 check_min_version('0.28.0.dev0')
 
@@ -18,21 +19,10 @@ def load_models(task_name, device):
     model_d = 'jingheya/lotus-normal-d-v1-0'
 
     dtype = torch.float16
-    pipe_g = LotusGPipeline.from_pretrained(
-        model_g,
-        torch_dtype=dtype,
-    )
-    pipe_d = LotusDPipeline.from_pretrained(
-        model_d,
-        torch_dtype=dtype,
-    )
-    pipe_g.to(device)
-    pipe_d.to(device)
-    pipe_g.set_progress_bar_config(disable=True)
-    pipe_d.set_progress_bar_config(disable=True)
-    logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
-    return pipe_g, pipe_d
+    # Models will be loaded inside the GPU-decorated function
+    return model_g, model_d, dtype
 
+@spaces.GPU
 def infer_pipe(pipe, image, task_name, seed, device):
     if seed is None:
         generator = None
@@ -42,7 +32,7 @@ def infer_pipe(pipe, image, task_name, seed, device):
     if torch.backends.mps.is_available():
         autocast_ctx = nullcontext()
     else:
-        autocast_ctx = torch.autocast(
+        autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)
 
     with torch.no_grad():
         with autocast_ctx:
@@ -77,6 +67,22 @@ def infer_pipe(pipe, image, task_name, seed, device):
 
     return output_color
 
-
+@spaces.GPU
+def lotus(image, task_name, seed, device, model_g, model_d, dtype):
+    # Load models inside the GPU-decorated function
+    pipe_g = LotusGPipeline.from_pretrained(
+        model_g,
+        torch_dtype=dtype,
+    )
+    pipe_d = LotusDPipeline.from_pretrained(
+        model_d,
+        torch_dtype=dtype,
+    )
+    pipe_g.to(device)
+    pipe_d.to(device)
+    pipe_g.set_progress_bar_config(disable=True)
+    pipe_d.set_progress_bar_config(disable=True)
+    logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
+
     output_d = infer_pipe(pipe_d, image, task_name, seed, device)
     return output_d  # Only returning depth outputs for this application
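For context, this refactor follows the ZeroGPU pattern: a ZeroGPU Space starts without a GPU attached, and CUDA becomes available only while a @spaces.GPU-decorated function is running, so the commit moves pipeline construction and inference inside decorated functions and has load_models return only model ids and dtype. Below is a minimal sketch of how a Gradio front end might call the new load_models/lotus pair; the app.py wiring, the 'depth' task id, and the UI layout are assumptions for illustration, not part of this commit.

# Hypothetical app.py wiring around the refactored infer.py (sketch only).
import gradio as gr
from infer import load_models, lotus

# Runs on CPU at startup: resolves model ids and dtype, no CUDA work yet.
model_g, model_d, dtype = load_models('depth', 'cuda')

def predict(image, seed):
    # lotus() is @spaces.GPU-decorated, so ZeroGPU attaches a GPU for the
    # duration of this call and releases it when the function returns.
    return lotus(image, 'depth', int(seed), 'cuda', model_g, model_d, dtype)

demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type='pil'), gr.Number(value=0, label='Seed')],
    outputs=gr.Image(label='Depth'),
)
demo.launch()

Note the trade-off in the committed design: because from_pretrained now runs inside lotus, the checkpoints are re-loaded on every decorated call rather than once at startup. That keeps all CUDA work inside the GPU window ZeroGPU grants, at the cost of per-request load time.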