ghostsInTheMachine commited on
Commit
12b0993
1 Parent(s): c2e328f

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +2 -2
infer.py CHANGED
@@ -47,10 +47,10 @@ def infer_pipe(pipe, image, task_name, seed, device):
47
  with torch.no_grad():
48
  with autocast_ctx:
49
  # Convert image to tensor
50
- img = np.array(image.convert('RGB')).astype(np.float32)
51
  test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
52
  test_image = test_image / 127.5 - 1.0
53
- test_image = test_image.to(device).type(torch.float16)
54
 
55
  # Create task_emb
56
  task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
 
47
  with torch.no_grad():
48
  with autocast_ctx:
49
  # Convert image to tensor
50
+ img = np.array(image.convert('RGB')).astype(np.float16)
51
  test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
52
  test_image = test_image / 127.5 - 1.0
53
+ test_image = test_image.to(device)
54
 
55
  # Create task_emb
56
  task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)