ghostsInTheMachine
commited on
Commit
•
693892f
1
Parent(s):
6a4fc34
Update infer.py
Browse files
infer.py
CHANGED
@@ -33,7 +33,7 @@ def load_models(task_name, device):
|
|
33 |
logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
|
34 |
return pipe_g, pipe_d
|
35 |
|
36 |
-
def infer_pipe(pipe,
|
37 |
if seed is None:
|
38 |
generator = None
|
39 |
else:
|
@@ -46,44 +46,37 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
|
|
46 |
|
47 |
with torch.no_grad():
|
48 |
with autocast_ctx:
|
49 |
-
# Convert
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
|
55 |
-
#
|
56 |
-
batch_size = test_images.shape[0]
|
57 |
task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
|
58 |
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
|
59 |
-
task_emb = task_emb.expand(batch_size, -1)
|
60 |
|
61 |
# Run inference
|
62 |
-
|
63 |
-
rgb_in=
|
64 |
prompt='',
|
65 |
num_inference_steps=1,
|
66 |
generator=generator,
|
67 |
output_type='np',
|
68 |
timesteps=[999],
|
69 |
task_emb=task_emb,
|
70 |
-
).images
|
71 |
|
72 |
-
# Post-process
|
73 |
-
outputs = []
|
74 |
if task_name == 'depth':
|
75 |
-
|
76 |
-
|
77 |
-
output_color = colorize_depth_map(output_npy)
|
78 |
-
outputs.append(output_color)
|
79 |
else:
|
80 |
-
|
81 |
-
|
82 |
-
output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
|
83 |
-
outputs.append(output_color)
|
84 |
|
85 |
-
return
|
86 |
|
87 |
-
def lotus(
|
88 |
-
output_d = infer_pipe(pipe_d,
|
89 |
return output_d # Only returning depth outputs for this application
|
|
|
33 |
logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
|
34 |
return pipe_g, pipe_d
|
35 |
|
36 |
+
def infer_pipe(pipe, image, task_name, seed, device):
|
37 |
if seed is None:
|
38 |
generator = None
|
39 |
else:
|
|
|
46 |
|
47 |
with torch.no_grad():
|
48 |
with autocast_ctx:
|
49 |
+
# Convert image to tensor
|
50 |
+
img = np.array(image.convert('RGB')).astype(np.float32)
|
51 |
+
test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
|
52 |
+
test_image = test_image / 127.5 - 1.0
|
53 |
+
test_image = test_image.to(device).type(torch.float16)
|
54 |
|
55 |
+
# Create task_emb
|
|
|
56 |
task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
|
57 |
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
|
|
|
58 |
|
59 |
# Run inference
|
60 |
+
pred = pipe(
|
61 |
+
rgb_in=test_image,
|
62 |
prompt='',
|
63 |
num_inference_steps=1,
|
64 |
generator=generator,
|
65 |
output_type='np',
|
66 |
timesteps=[999],
|
67 |
task_emb=task_emb,
|
68 |
+
).images[0]
|
69 |
|
70 |
+
# Post-process prediction
|
|
|
71 |
if task_name == 'depth':
|
72 |
+
output_npy = pred.mean(axis=-1)
|
73 |
+
output_color = colorize_depth_map(output_npy)
|
|
|
|
|
74 |
else:
|
75 |
+
output_npy = pred
|
76 |
+
output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
|
|
|
|
|
77 |
|
78 |
+
return output_color
|
79 |
|
80 |
+
def lotus(image, task_name, seed, device, pipe_g, pipe_d):
|
81 |
+
output_d = infer_pipe(pipe_d, image, task_name, seed, device)
|
82 |
return output_d # Only returning depth outputs for this application
|