ghostsInTheMachine committed
Commit c71b96e
1 Parent(s): afe7cc3

Update infer.py

Files changed (1)
  1. infer.py +45 -309
infer.py CHANGED
@@ -1,71 +1,18 @@
-# from utils.args import parse_args
 import logging
-import os
-import argparse
-from pathlib import Path
-from PIL import Image
-
-import numpy as np
 import torch
-from tqdm.auto import tqdm
+import numpy as np
+from PIL import Image
 from diffusers.utils import check_min_version
-
 from pipeline import LotusGPipeline, LotusDPipeline
 from utils.image_utils import colorize_depth_map
-from utils.seed_all import seed_all
-
 from contextlib import nullcontext
-import cv2
 
 check_min_version('0.28.0.dev0')
 
-def infer_pipe(pipe, image_input, task_name, seed, device):
-    if seed is None:
-        generator = None
-    else:
-        generator = torch.Generator(device=device).manual_seed(seed)
-
-    if torch.backends.mps.is_available():
-        autocast_ctx = nullcontext()
-    else:
-        autocast_ctx = torch.autocast(pipe.device.type)
-    with autocast_ctx:
-
-        test_image = Image.open(image_input).convert('RGB')
-        test_image = np.array(test_image).astype(np.float16)
-        test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
-        test_image = test_image / 127.5 - 1.0
-        test_image = test_image.to(device)
-
-        task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
-        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
-
-        # Run
-        pred = pipe(
-            rgb_in=test_image,
-            prompt='',
-            num_inference_steps=1,
-            generator=generator,
-            # guidance_scale=0,
-            output_type='np',
-            timesteps=[999],
-            task_emb=task_emb,
-        ).images[0]
-
-        # Post-process the prediction
-        if task_name == 'depth':
-            output_npy = pred.mean(axis=-1)
-            output_color = colorize_depth_map(output_npy)
-        else:
-            output_npy = pred
-            output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
-
-        return output_color
-
-def lotus_video(input_video, task_name, seed, device):
+def load_models(task_name, device):
     if task_name == 'depth':
         model_g = 'jingheya/lotus-depth-g-v1-0'
-        model_d = 'jingheya/lotus-depth-d-v1-0'
+        model_d = 'jingheya/lotus-depth-d-v1-1'
     else:
         model_g = 'jingheya/lotus-normal-g-v1-0'
         model_d = 'jingheya/lotus-normal-d-v1-0'
@@ -83,268 +30,57 @@ def lotus_video(input_video, task_name, seed, device):
     pipe_d.to(device)
     pipe_g.set_progress_bar_config(disable=True)
     pipe_d.set_progress_bar_config(disable=True)
-    logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
-
-    # load the video and split it into frames
-    cap = cv2.VideoCapture(input_video)
-    frames = []
-    while True:
-        ret, frame = cap.read()
-        if not ret:
-            break
-        frames.append(frame)
-    cap.release()
-    logging.info(f"There are {len(frames)} frames in the video.")
+    logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
+    return pipe_g, pipe_d
 
+def infer_pipe(pipe, images_batch, task_name, seed, device):
     if seed is None:
         generator = None
     else:
         generator = torch.Generator(device=device).manual_seed(seed)
 
-    task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
-    task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
-
-    output_g = []
-    output_d = []
-    for frame in frames:
-        if torch.backends.mps.is_available():
-            autocast_ctx = nullcontext()
-        else:
-            autocast_ctx = torch.autocast(pipe_g.device.type)
-        with autocast_ctx:
-            test_image = frame
-            test_image = np.array(test_image).astype(np.float16)
-            test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
-            test_image = test_image / 127.5 - 1.0
-            test_image = test_image.to(device)
-
-            # Run
-            pred_g = pipe_g(
-                rgb_in=test_image,
-                prompt='',
-                num_inference_steps=1,
-                generator=generator,
-                # guidance_scale=0,
-                output_type='np',
-                timesteps=[999],
-                task_emb=task_emb,
-            ).images[0]
-            pred_d = pipe_d(
-                rgb_in=test_image,
-                prompt='',
-                num_inference_steps=1,
-                generator=generator,
-                # guidance_scale=0,
-                output_type='np',
-                timesteps=[999],
-                task_emb=task_emb,
-            ).images[0]
-
-            # Post-process the prediction
-            if task_name == 'depth':
-                output_npy_g = pred_g.mean(axis=-1)
-                output_color_g = colorize_depth_map(output_npy_g)
-                output_npy_d = pred_d.mean(axis=-1)
-                output_color_d = colorize_depth_map(output_npy_d)
-            else:
-                output_npy_g = pred_g
-                output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
-                output_npy_d = pred_d
-                output_color_d = Image.fromarray((output_npy_d * 255).astype(np.uint8))
-
-            output_g.append(output_color_g)
-            output_d.append(output_color_d)
-
-    return output_g, output_d
-
-def lotus(image_input, task_name, seed, device):
-    if task_name == 'depth':
-        model_g = 'jingheya/lotus-depth-g-v1-0'
-        model_d = 'jingheya/lotus-depth-d-v1-1'
-    else:
-        model_g = 'jingheya/lotus-normal-g-v1-0'
-        model_d = 'jingheya/lotus-normal-d-v1-0'
-
-    dtype = torch.float16
-    pipe_g = LotusGPipeline.from_pretrained(
-        model_g,
-        torch_dtype=dtype,
-    )
-    pipe_d = LotusDPipeline.from_pretrained(
-        model_d,
-        torch_dtype=dtype,
-    )
-    pipe_g.to(device)
-    pipe_d.to(device)
-    pipe_g.set_progress_bar_config(disable=True)
-    pipe_d.set_progress_bar_config(disable=True)
-    logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
-    output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
-    output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
-    return output_g, output_d
-
-def parse_args():
-    '''Set the Args'''
-    parser = argparse.ArgumentParser(
-        description="Run Lotus..."
-    )
-    # model settings
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        default=None,
-        help="pretrained model path from hugging face or local dir",
-    )
-    parser.add_argument(
-        "--prediction_type",
-        type=str,
-        default="sample",
-        help="The used prediction_type. ",
-    )
-    parser.add_argument(
-        "--timestep",
-        type=int,
-        default=999,
-    )
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="regression", # "generation"
-        help="Whether to use the generation or regression pipeline."
-    )
-    parser.add_argument(
-        "--task_name",
-        type=str,
-        default="depth", # "normal"
-    )
-    parser.add_argument(
-        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
-    )
-
-    # inference settings
-    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
-    parser.add_argument(
-        "--output_dir", type=str, required=True, help="Output directory."
-    )
-    parser.add_argument(
-        "--input_dir", type=str, required=True, help="Input directory."
-    )
-    parser.add_argument(
-        "--half_precision",
-        action="store_true",
-        help="Run with half-precision (16-bit float), might lead to suboptimal result.",
-    )
-
-    args = parser.parse_args()
-
-    return args
-
-def main():
-    logging.basicConfig(level=logging.INFO)
-    logging.info(f"Run inference...")
-
-    args = parse_args()
-
-    # -------------------- Preparation --------------------
-    # Random seed
-    if args.seed is not None:
-        seed_all(args.seed)
-
-    # Output directories
-    os.makedirs(args.output_dir, exist_ok=True)
-    logging.info(f"Output dir = {args.output_dir}")
-
-    output_dir_color = os.path.join(args.output_dir, f'{args.task_name}_vis')
-    output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}')
-    if not os.path.exists(output_dir_color): os.makedirs(output_dir_color)
-    if not os.path.exists(output_dir_npy): os.makedirs(output_dir_npy)
-
-    # half_precision
-    if args.half_precision:
-        dtype = torch.float16
-        logging.info(f"Running with half precision ({dtype}).")
-    else:
-        dtype = torch.float16
-
-    # -------------------- Device --------------------
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-        logging.warning("CUDA is not available. Running on CPU will be slow.")
-    logging.info(f"Device = {device}")
-
-    # -------------------- Data --------------------
-    root_dir = Path(args.input_dir)
-    test_images = list(root_dir.rglob('*.png')) + list(root_dir.rglob('*.jpg'))
-    test_images = sorted(test_images)
-    print('==> There are', len(test_images), 'images for validation.')
-    # -------------------- Model --------------------
-
-    if args.mode == 'generation':
-        pipeline = LotusGPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            torch_dtype=dtype,
-        )
-    elif args.mode == 'regression':
-        pipeline = LotusDPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            torch_dtype=dtype,
-        )
-    else:
-        raise ValueError(f'Invalid mode: {args.mode}')
-    logging.info(f"Successfully loading pipeline from {args.pretrained_model_name_or_path}.")
-
-    pipeline = pipeline.to(device)
-    pipeline.set_progress_bar_config(disable=True)
-
-    if args.enable_xformers_memory_efficient_attention:
-        pipeline.enable_xformers_memory_efficient_attention()
-
-
-    if args.seed is None:
-        generator = None
-    else:
-        generator = torch.Generator(device=device).manual_seed(args.seed)
-
-    # -------------------- Inference and saving --------------------
-    with torch.no_grad():
-        for i in tqdm(range(len(test_images))):
-            # Preprocess validation image
-            test_image = Image.open(test_images[i]).convert('RGB')
-            test_image = np.array(test_image).astype(np.float16)
-            test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
-            test_image = test_image / 127.5 - 1.0
-            test_image = test_image.to(device)
-
-            task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
-            task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
-
-            # Run
-            pred = pipeline(
-                rgb_in=test_image,
-                prompt='',
-                num_inference_steps=1,
-                generator=generator,
-                # guidance_scale=0,
-                output_type='np',
-                timesteps=[args.timestep],
-                task_emb=task_emb,
-            ).images[0]
-
-            # Post-process the prediction
-            save_file_name = os.path.basename(test_images[i])[:-4]
-            if args.task_name == 'depth':
-                output_npy = pred.mean(axis=-1)
-                output_color = colorize_depth_map(output_npy)
-            else:
-                output_npy = pred
-                output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
-
-            output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
-            np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)
-
-    print('==> Inference is done. \n==> Results saved to:', args.output_dir)
-
-if __name__ == '__main__':
-    main()
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(pipe.device.type)
+
+    with autocast_ctx:
+        # Convert list of images to tensor
+        images = [np.array(img.convert('RGB')).astype(np.float16) for img in images_batch]
+        test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
+        test_images = test_images / 127.5 - 1.0
+        test_images = test_images.to(device)
+
+        task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
+        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
+        task_emb = task_emb.repeat(len(test_images), 1)
+
+        # Run inference
+        preds = pipe(
+            rgb_in=test_images,
+            prompt='',
+            num_inference_steps=1,
+            generator=generator,
+            output_type='np',
+            timesteps=[999],
+            task_emb=task_emb,
+        ).images
+
+    # Post-process predictions
+    outputs = []
+    if task_name == 'depth':
+        for p in preds:
+            output_npy = p.mean(axis=-1)
+            output_color = colorize_depth_map(output_npy)
+            outputs.append(output_color)
+    else:
+        for p in preds:
+            output_npy = p
+            output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
+            outputs.append(output_color)
+
+    return outputs
+
+def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
+    output_d = infer_pipe(pipe_d, images_batch, task_name, seed, device)
+    return output_d  # Only returning depth outputs for this application
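
The commit strips infer.py from a standalone CLI (argparse, directory I/O, tqdm, cv2 video decoding) down to three library functions — load_models(), infer_pipe(), and lotus() — meant to be called by a host application, with infer_pipe() now accepting a list of PIL images and stacking them into a single batched tensor. The task embedding is built as before ([1, 0] through sin/cos gives approximately [0.841, 0.0, 0.540, 1.0]), but is now tiled to shape (batch_size, 4) via repeat(len(test_images), 1) instead of the old no-op repeat(1, 1).

A minimal usage sketch of the new entry points — this is hypothetical caller code, not part of the commit; the image paths are illustrative, and all inputs must share one resolution since infer_pipe() combines them with torch.stack:

    import torch
    from PIL import Image
    from infer import load_models, lotus

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load both pipelines once up front; the removed lotus() reloaded them on every call.
    pipe_g, pipe_d = load_models('depth', device)

    # A batch of equally sized PIL images (paths are placeholders).
    images = [Image.open(p) for p in ['frame_000.png', 'frame_001.png']]

    # Runs only the regression (d) pipeline; returns one colorized map per input.
    depth_maps = lotus(images, 'depth', seed=0, device=device, pipe_g=pipe_g, pipe_d=pipe_d)
    for i, vis in enumerate(depth_maps):
        vis.save(f'depth_{i:03d}.png')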