ghostsInTheMachine
commited on
Commit
•
8c25de0
1
Parent(s):
c71b96e
Update infer.py
Browse files
infer.py
CHANGED
@@ -44,40 +44,41 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
|
|
44 |
else:
|
45 |
autocast_ctx = torch.autocast(pipe.device.type)
|
46 |
|
47 |
-
with
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
|
82 |
return outputs
|
83 |
|
|
|
44 |
else:
|
45 |
autocast_ctx = torch.autocast(pipe.device.type)
|
46 |
|
47 |
+
with torch.no_grad():
|
48 |
+
with autocast_ctx:
|
49 |
+
# Convert list of images to tensor
|
50 |
+
images = [np.array(img.convert('RGB')).astype(np.float16) for img in images_batch]
|
51 |
+
test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
|
52 |
+
test_images = test_images / 127.5 - 1.0
|
53 |
+
test_images = test_images.to(device)
|
54 |
|
55 |
+
task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
|
56 |
+
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
|
57 |
+
task_emb = task_emb.repeat(len(test_images), 1)
|
58 |
|
59 |
+
# Run inference
|
60 |
+
preds = pipe(
|
61 |
+
rgb_in=test_images,
|
62 |
+
prompt='',
|
63 |
+
num_inference_steps=1,
|
64 |
+
generator=generator,
|
65 |
+
output_type='np',
|
66 |
+
timesteps=[999],
|
67 |
+
task_emb=task_emb,
|
68 |
+
).images
|
69 |
|
70 |
+
# Post-process predictions
|
71 |
+
outputs = []
|
72 |
+
if task_name == 'depth':
|
73 |
+
for p in preds:
|
74 |
+
output_npy = p.mean(axis=-1)
|
75 |
+
output_color = colorize_depth_map(output_npy)
|
76 |
+
outputs.append(output_color)
|
77 |
+
else:
|
78 |
+
for p in preds:
|
79 |
+
output_npy = p
|
80 |
+
output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
|
81 |
+
outputs.append(output_color)
|
82 |
|
83 |
return outputs
|
84 |
|