ghostsInTheMachine committed on
Commit
693892f
1 Parent(s): 6a4fc34

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +18 -25
infer.py CHANGED
@@ -33,7 +33,7 @@ def load_models(task_name, device):
33
  logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
34
  return pipe_g, pipe_d
35
 
36
- def infer_pipe(pipe, images_batch, task_name, seed, device):
37
  if seed is None:
38
  generator = None
39
  else:
@@ -46,44 +46,37 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
46
 
47
  with torch.no_grad():
48
  with autocast_ctx:
49
- # Convert list of images to tensor
50
- images = [np.array(img.convert('RGB')).astype(np.float32) for img in images_batch]
51
- test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
52
- test_images = test_images / 127.5 - 1.0
53
- test_images = test_images.to(device).type(torch.float16)
54
 
55
- # Ensure task_emb matches expected dimensions
56
- batch_size = test_images.shape[0]
57
  task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
58
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
59
- task_emb = task_emb.expand(batch_size, -1)
60
 
61
  # Run inference
62
- preds = pipe(
63
- rgb_in=test_images,
64
  prompt='',
65
  num_inference_steps=1,
66
  generator=generator,
67
  output_type='np',
68
  timesteps=[999],
69
  task_emb=task_emb,
70
- ).images
71
 
72
- # Post-process predictions
73
- outputs = []
74
  if task_name == 'depth':
75
- for p in preds:
76
- output_npy = p.mean(axis=-1)
77
- output_color = colorize_depth_map(output_npy)
78
- outputs.append(output_color)
79
  else:
80
- for p in preds:
81
- output_npy = p
82
- output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
83
- outputs.append(output_color)
84
 
85
- return outputs
86
 
87
- def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
88
- output_d = infer_pipe(pipe_d, images_batch, task_name, seed, device)
89
  return output_d # Only returning depth outputs for this application
 
33
  logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
34
  return pipe_g, pipe_d
35
 
36
+ def infer_pipe(pipe, image, task_name, seed, device):
37
  if seed is None:
38
  generator = None
39
  else:
 
46
 
47
  with torch.no_grad():
48
  with autocast_ctx:
49
+ # Convert image to tensor
50
+ img = np.array(image.convert('RGB')).astype(np.float32)
51
+ test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
52
+ test_image = test_image / 127.5 - 1.0
53
+ test_image = test_image.to(device).type(torch.float16)
54
 
55
+ # Create task_emb
 
56
  task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
57
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
 
58
 
59
  # Run inference
60
+ pred = pipe(
61
+ rgb_in=test_image,
62
  prompt='',
63
  num_inference_steps=1,
64
  generator=generator,
65
  output_type='np',
66
  timesteps=[999],
67
  task_emb=task_emb,
68
+ ).images[0]
69
 
70
+ # Post-process prediction
 
71
  if task_name == 'depth':
72
+ output_npy = pred.mean(axis=-1)
73
+ output_color = colorize_depth_map(output_npy)
 
 
74
  else:
75
+ output_npy = pred
76
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
 
 
77
 
78
+ return output_color
79
 
80
+ def lotus(image, task_name, seed, device, pipe_g, pipe_d):
81
+ output_d = infer_pipe(pipe_d, image, task_name, seed, device)
82
  return output_d # Only returning depth outputs for this application