ghostsInTheMachine commited on
Commit
8c25de0
1 Parent(s): c71b96e

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +32 -31
infer.py CHANGED
@@ -44,40 +44,41 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
44
  else:
45
  autocast_ctx = torch.autocast(pipe.device.type)
46
 
47
- with autocast_ctx:
48
- # Convert list of images to tensor
49
- images = [np.array(img.convert('RGB')).astype(np.float16) for img in images_batch]
50
- test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
51
- test_images = test_images / 127.5 - 1.0
52
- test_images = test_images.to(device)
 
53
 
54
- task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
55
- task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
56
- task_emb = task_emb.repeat(len(test_images), 1)
57
 
58
- # Run inference
59
- preds = pipe(
60
- rgb_in=test_images,
61
- prompt='',
62
- num_inference_steps=1,
63
- generator=generator,
64
- output_type='np',
65
- timesteps=[999],
66
- task_emb=task_emb,
67
- ).images
68
 
69
- # Post-process predictions
70
- outputs = []
71
- if task_name == 'depth':
72
- for p in preds:
73
- output_npy = p.mean(axis=-1)
74
- output_color = colorize_depth_map(output_npy)
75
- outputs.append(output_color)
76
- else:
77
- for p in preds:
78
- output_npy = p
79
- output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
80
- outputs.append(output_color)
81
 
82
  return outputs
83
 
 
44
  else:
45
  autocast_ctx = torch.autocast(pipe.device.type)
46
 
47
+ with torch.no_grad():
48
+ with autocast_ctx:
49
+ # Convert list of images to tensor
50
+ images = [np.array(img.convert('RGB')).astype(np.float16) for img in images_batch]
51
+ test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
52
+ test_images = test_images / 127.5 - 1.0
53
+ test_images = test_images.to(device)
54
 
55
+ task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
56
+ task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
57
+ task_emb = task_emb.repeat(len(test_images), 1)
58
 
59
+ # Run inference
60
+ preds = pipe(
61
+ rgb_in=test_images,
62
+ prompt='',
63
+ num_inference_steps=1,
64
+ generator=generator,
65
+ output_type='np',
66
+ timesteps=[999],
67
+ task_emb=task_emb,
68
+ ).images
69
 
70
+ # Post-process predictions
71
+ outputs = []
72
+ if task_name == 'depth':
73
+ for p in preds:
74
+ output_npy = p.mean(axis=-1)
75
+ output_color = colorize_depth_map(output_npy)
76
+ outputs.append(output_color)
77
+ else:
78
+ for p in preds:
79
+ output_npy = p
80
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
81
+ outputs.append(output_color)
82
 
83
  return outputs
84