ghostsInTheMachine commited on
Commit
73b0806
1 Parent(s): d5d8098

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +6 -4
infer.py CHANGED
@@ -47,14 +47,16 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
47
  with torch.no_grad():
48
  with autocast_ctx:
49
  # Convert list of images to tensor
50
- images = [np.array(img.convert('RGB')).astype(np.float16) for img in images_batch]
51
  test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
52
  test_images = test_images / 127.5 - 1.0
53
- test_images = test_images.to(device)
54
 
55
- task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
 
 
56
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
57
- task_emb = task_emb.repeat(len(test_images), 1)
58
 
59
  # Run inference
60
  preds = pipe(
 
47
  with torch.no_grad():
48
  with autocast_ctx:
49
  # Convert list of images to tensor
50
+ images = [np.array(img.convert('RGB')).astype(np.float32) for img in images_batch]
51
  test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
52
  test_images = test_images / 127.5 - 1.0
53
+ test_images = test_images.to(device).type(torch.float16)
54
 
55
+ # Ensure task_emb matches expected dimensions
56
+ batch_size = test_images.shape[0]
57
+ task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
58
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
59
+ task_emb = task_emb.repeat(batch_size, 1)
60
 
61
  # Run inference
62
  preds = pipe(