ghostsInTheMachine commited on
Commit
5c6dcd2
1 Parent(s): 173c7e2

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +1 -1
infer.py CHANGED
@@ -56,7 +56,7 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
56
  batch_size = test_images.shape[0]
57
  task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
58
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
59
- task_emb = task_emb.repeat(batch_size, 1)
60
 
61
  # Run inference
62
  preds = pipe(
 
56
  batch_size = test_images.shape[0]
57
  task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
58
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
59
+ task_emb = task_emb.expand(batch_size, -1)
60
 
61
  # Run inference
62
  preds = pipe(