Ashoka74 committed on
Commit
7b61338
1 Parent(s): c15f1a2

Upload inference_i2mv_sdxl.py

Files changed (1)
  1. inference_i2mv_sdxl.py +428 -0
inference_i2mv_sdxl.py ADDED
@@ -0,0 +1,428 @@
import argparse

import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation

import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')


from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from mvadapter.utils import (
    get_orthogonal_camera,
    get_plucker_embeds_from_cameras_ortho,
    make_image_grid,
)


def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    # Load vae and unet if provided
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    # Prepare pipeline
    pipe: MVAdapterI2MVSDXLPipeline
    pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Load scheduler if provided
    scheduler_class = None
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        scheduler_class = LCMScheduler

    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # load lora if provided
    if lora_model is not None:
        model_, name_ = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(model_, weight_name=name_)

    # vae slicing for lower memory usage
    pipe.enable_vae_slicing()

    return pipe

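# Usage sketch (illustrative): load the pipeline with the defaults exposed by the CLI
# at the bottom of this file; the model IDs are the argparse defaults and fp16 matches
# the __main__ entry point.
#
#   pipe = prepare_pipeline(
#       base_model="stabilityai/stable-diffusion-xl-base-1.0",
#       vae_model="madebyollin/sdxl-vae-fp16-fix",
#       unet_model=None,
#       lora_model=None,
#       adapter_path="huanngzh/mv-adapter",
#       scheduler=None,
#       num_views=6,
#       device="cuda",
#       dtype=torch.float16,
#   )
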
def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = None):
    """
    Applies a pre-existing mask to an image to make the background transparent.

    NOTE: this PIL-mask variant is superseded by the np.ndarray variant defined
    further below; the later definition wins when the module is loaded.

    Args:
        image (PIL.Image.Image): The input image.
        net: Pre-trained neural network (not used but kept for compatibility).
        transform: Image transformation object (not used but kept for compatibility).
        device: Device used for inference (not used but kept for compatibility).
        mask (PIL.Image.Image, optional): The mask to use. Should be the same size
            as the input image, with values between 0 and 255 (or 0-1).
            If None, the image is returned unchanged.

    Returns:
        PIL.Image.Image: The modified image with transparent background.
    """
    if mask is None:
        return image

    image_size = image.size
    if mask.size != image_size:
        mask = mask.resize(image_size)  # Resize the mask if it does not match the image size

    image.putalpha(mask)
    return image


# def remove_bg(image, net, transform, device):
#     image_size = image.size
#     input_images = transform(image).unsqueeze(0).to(device)
#     with torch.no_grad():
#         preds = net(input_images)[0].sigmoid().cpu()
#         #preds = net(input_images)[-1] if isinstance(net(input_images), list) else net(input_images)
#     pred = preds[0].squeeze()
#     pred_pil = transforms.ToPILImage()(pred)
#     mask = pred_pil.resize(image_size)
#     image.putalpha(mask)
#     return image


# def remove_bg(image: Image.Image, net, transform, device):
#     """
#     Applies a pre-existing mask to an image to make the background transparent.
#     Args:
#         image (PIL.Image.Image): The input image.
#         net: Pre-trained neural network (not used but kept for compatibility).
#         transform: Image transformation object (not used but kept for compatibility).
#         device: Device used for inference (not used but kept for compatibility).
#     Returns:
#         PIL.Image.Image: The modified image with transparent background.
#     """
#     image_size = image.size
#     input_images = transform(image).unsqueeze(0).to(device)

#     with torch.no_grad():
#         preds = net(input_images)[-1].sigmoid().cpu()

#     pred = preds[0].squeeze()
#     pred_pil = transforms.ToPILImage()(pred)

#     # Resize the mask to match the original image size
#     mask = pred_pil.resize(image_size, Image.LANCZOS)

#     # Create a new image with the same size and mode as the original
#     output_image = Image.new("RGBA", image_size)

#     # Apply the mask to the original image
#     image.putalpha(mask)

#     # Composite the original image with the mask
#     output_image.paste(image, (0, 0), image)

#     return output_image


def remove_bg(image: Image.Image, net, transform, device, mask: np.ndarray = None):
    """
    Applies a pre-existing mask to an image to make the background transparent.

    This is the active implementation (it overrides the PIL-mask variant above).

    Args:
        image (PIL.Image.Image): The input image.
        net: Pre-trained neural network (not used but kept for compatibility).
        transform: Image transformation object (not used but kept for compatibility).
        device: Device used for inference (not used but kept for compatibility).
        mask (np.ndarray, optional): The mask to use. May be binary (0/1) or
            grayscale (0-255) and should cover the same region as the input image.
            If None, the image is returned unchanged.

    Returns:
        PIL.Image.Image: The modified image with transparent background.
    """
    if mask is None:
        return image

    # Ensure the mask is in the correct format
    if mask.ndim == 2:  # If mask is 2D (H, W)
        mask = mask.astype(np.uint8)  # Ensure mask is uint8
        mask = np.expand_dims(mask, axis=-1)  # Add channel dimension

    # Convert the mask to a PIL image; scale binary 0/1 masks up to 0-255 so the
    # alpha channel is usable (multiplying a 0-255 mask by 255 would overflow uint8)
    mask = mask.squeeze(2)
    if mask.max() <= 1:
        mask = mask * 255
    mask_pil = Image.fromarray(mask.astype(np.uint8))

    # Resize the mask to match the original image size
    mask_pil = mask_pil.resize(image.size, Image.LANCZOS)

    # Create a new image with the same size and mode as the original
    output_image = Image.new("RGBA", image.size)

    # Apply the mask to the original image
    image.putalpha(mask_pil)

    # Composite the original image with the mask
    output_image.paste(image, (0, 0), image)

    return output_image


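# Usage sketch (illustrative; the file names are hypothetical): applying a precomputed
# binary mask with the np.ndarray variant above.
#
#   ref = Image.open("input.png").convert("RGB")
#   mask = (np.array(Image.open("mask.png").convert("L")) > 127).astype(np.uint8)
#   cutout = remove_bg(ref, net=None, transform=None, device="cpu", mask=mask)
#   cutout.save("cutout.png")
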
# def preprocess_image(image: Image.Image, height, width):

#     alpha = image[..., 3] > 0
#     # alpha = image

#     #if image.mode in ("RGBA", "LA"):
#     #    image = np.array(image)
#     #    alpha = image[..., 3]  # Extract the alpha channel
#     #elif image.mode in ("RGB"):
#     #    image = np.array(image)
#     # Create default alpha for non-alpha images
#     #    alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255  # Create
#     H, W = alpha.shape
#     # get the bounding box of alpha
#     y, x = np.where(alpha)
#     y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
#     x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
#     image_center = image[y0:y1, x0:x1]
#     # resize the longer side to H * 0.9
#     H, W, _ = image_center.shape
#     if H > W:
#         W = int(W * (height * 0.9) / H)
#         H = int(height * 0.9)
#     else:
#         H = int(H * (width * 0.9) / W)
#         W = int(width * 0.9)
#     image_center = np.array(Image.fromarray(image_center).resize((W, H)))
#     # pad to H, W
#     start_h = (height - H) // 2
#     start_w = (width - W) // 2
#     image = np.zeros((height, width, 4), dtype=np.uint8)
#     image[start_h : start_h + H, start_w : start_w + W] = image_center
#     image = image.astype(np.float32) / 255.0
#     image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
#     image = (image * 255).clip(0, 255).astype(np.uint8)
#     image = Image.fromarray(image)

#     return image

def preprocess_image(image: Image.Image, height, width):
    # Convert image to numpy array
    image_np = np.array(image)

    # Extract the alpha channel if present
    if image_np.shape[-1] == 4:
        alpha = image_np[..., 3] > 0  # Create a binary mask from the alpha channel
    else:
        alpha = np.ones(image_np[..., 0].shape, dtype=bool)  # Default to all true for RGB images

    H, W = alpha.shape
    # Get the bounding box of the alpha
    y, x = np.where(alpha)
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    image_center = image_np[y0:y1, x0:x1]

    # Resize the longer side to 90% of the target size
    H, W, _ = image_center.shape
    if H > W:
        W = int(W * (height * 0.9) / H)
        H = int(height * 0.9)
    else:
        H = int(H * (width * 0.9) / W)
        W = int(width * 0.9)

    image_center = np.array(Image.fromarray(image_center).resize((W, H)))

    # RGB inputs have no alpha channel; add a fully opaque one so the 4-channel paste below works
    if image_center.shape[-1] == 3:
        image_center = np.concatenate(
            [image_center, np.full((H, W, 1), 255, dtype=np.uint8)], axis=-1
        )

    # Pad to (height, width)
    start_h = (height - H) // 2
    start_w = (width - W) // 2
    padded_image = np.zeros((height, width, 4), dtype=np.uint8)
    padded_image[start_h:start_h + H, start_w:start_w + W] = image_center

    # Convert back to PIL Image
    return Image.fromarray(padded_image)


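# Usage sketch (illustrative; the file name is hypothetical): recentre and pad an
# RGBA cut-out to the 768x768 working resolution used by the CLI entry point below.
#
#   ref = Image.open("cutout.png")
#   ref = preprocess_image(ref, 768, 768)  # bbox-crop, resize to 90%, pad to 768x768
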
def run_pipeline(
    pipe,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
):
    # Prepare cameras (note: the elevation/azimuth lists hard-code six poses, so this
    # setup assumes num_views == 6)
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 0, 0],
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
        device=device,
    )

    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    # Prepare image
    # reference_image = Image.open(image) if isinstance(image, str) else image
    # if remove_bg_fn is not None:
    #     reference_image = remove_bg_fn(reference_image)
    #     reference_image = preprocess_image(reference_image, height, width)
    # elif reference_image.mode == "RGBA":
    #     reference_image = preprocess_image(reference_image, height, width)
    reference_image = Image.open(image) if isinstance(image, str) else image
    logging.info(f"Initial reference_image mode: {reference_image.mode}")

    if remove_bg_fn is not None:
        logging.info("Using remove_bg_fn")
        reference_image = remove_bg_fn(reference_image)
        reference_image = preprocess_image(reference_image, height, width)
    elif reference_image.mode == "RGBA":
        logging.info("Image is RGBA, preprocessing directly")
        reference_image = preprocess_image(reference_image, height, width)

    logging.info(f"Final reference_image mode: {reference_image.mode}")

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images, reference_image


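# Usage sketch (illustrative; "input.png" is a hypothetical path): programmatic use
# mirroring the CLI entry point below, with the argparse defaults.
#
#   images, ref = run_pipeline(
#       pipe,
#       num_views=6,
#       text="high quality",
#       image="input.png",
#       height=768,
#       width=768,
#       num_inference_steps=50,
#       guidance_scale=3.0,
#       seed=-1,
#   )
#   make_image_grid(images, rows=1).save("output.png")
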
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--text", type=str, default="high quality")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--output", type=str, default="output.png")
    # Extra
    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    if args.remove_bg:
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        birefnet.to(args.device)
        transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        # Note: the active remove_bg only applies a mask that is passed in explicitly,
        # so without a mask argument this call returns the image unchanged.
        remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
    else:
        remove_bg_fn = None

    images, reference_image = run_pipeline(
        pipe,
        num_views=args.num_views,
        text=args.text,
        image=args.image,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        lora_scale=args.lora_scale,
        reference_conditioning_scale=args.reference_conditioning_scale,
        negative_prompt=args.negative_prompt,
        device=args.device,
        remove_bg_fn=remove_bg_fn,
    )
    make_image_grid(images, rows=1).save(args.output)
    reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")
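The script is invoked as: python inference_i2mv_sdxl.py --image <path> [--remove_bg]; only --image is required. Note that, as uploaded, the --remove_bg branch loads BiRefNet but never computes a mask, because the active remove_bg only applies a mask that is passed in explicitly. A minimal sketch of how the segmentation output could be wired back in, adapted from the commented-out BiRefNet variant earlier in the file (the predict_mask helper is hypothetical, not part of the upload):

    def predict_mask(image, net, transform, device):
        # Run BiRefNet and return a binary (0/1) uint8 mask at the image's resolution
        input_images = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            pred = net(input_images)[-1].sigmoid().cpu()[0].squeeze()
        mask_img = Image.fromarray((pred.numpy() > 0.5).astype(np.uint8) * 255).resize(image.size)
        return (np.array(mask_img) > 127).astype(np.uint8)

    remove_bg_fn = lambda x: remove_bg(
        x, birefnet, transform_image, args.device,
        mask=predict_mask(x, birefnet, transform_image, args.device),
    )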