ghostsInTheMachine
commited on
Commit
•
5c6dcd2
1
Parent(s):
173c7e2
Update infer.py
Browse files
infer.py
CHANGED
@@ -56,7 +56,7 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
|
|
56 |
batch_size = test_images.shape[0]
|
57 |
task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
|
58 |
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
|
59 |
-
task_emb = task_emb.
|
60 |
|
61 |
# Run inference
|
62 |
preds = pipe(
|
|
|
56 |
batch_size = test_images.shape[0]
|
57 |
task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
|
58 |
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
|
59 |
+
task_emb = task_emb.expand(batch_size, -1)
|
60 |
|
61 |
# Run inference
|
62 |
preds = pipe(
|