kimjy0411 committed
Commit 7a68b01 · verified · 1 parent: f6236c5

Upload src/pipelines/pipeline_pose2img.py with huggingface_hub

Files changed (1):
  1. src/pipelines/pipeline_pose2img.py +374 -0
src/pipelines/pipeline_pose2img.py ADDED
@@ -0,0 +1,374 @@
import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers import DiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import BaseOutput, is_accelerate_available
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from tqdm import tqdm
from transformers import CLIPImageProcessor

from src.models.mutual_self_attention import ReferenceAttentionControl


@dataclass
class Pose2ImagePipelineOutput(BaseOutput):
    images: Union[torch.Tensor, np.ndarray]


class Pose2ImagePipeline(DiffusionPipeline):
    _optional_components = []

    def __init__(
        self,
        vae,
        image_encoder,
        reference_unet,
        denoising_unet,
        pose_guider,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.clip_image_processor = CLIPImageProcessor()
        self.ref_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.cond_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=True,
        )

    def enable_vae_slicing(self):
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [
            self.reference_unet,
            self.denoising_unet,
            self.image_encoder,
            self.vae,
        ]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        if self.device != torch.device("meta") or not hasattr(
            self.denoising_unet, "_hf_hook"
        ):
            return self.device
        for module in self.denoising_unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def decode_latents(self, latents):
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        # video = self.vae.decode(latents).sample
        video = []
        for frame_idx in tqdm(range(latents.shape[0])):
            video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
        video = torch.cat(video)
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        width,
        height,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_condition(
        self,
        cond_image,
        width,
        height,
        device,
        dtype,
        do_classifier_free_guidance=False,
    ):
        image = self.cond_image_processor.preprocess(
            cond_image, height=height, width=width
        ).to(dtype=torch.float32)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    @torch.no_grad()
    def __call__(
        self,
        ref_image,
        pose_image,
        ref_pose_image,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        num_images_per_prompt=1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # Default height and width to unet
        height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
        width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        batch_size = 1

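        # The reference image's CLIP image embedding stands in for text conditioning; an
        # all-zeros embedding serves as the unconditional branch for classifier-free guidance.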
        # Prepare clip image embeds
        clip_image = self.clip_image_processor.preprocess(
            ref_image.resize((224, 224)), return_tensors="pt"
        ).pixel_values
        clip_image_embeds = self.image_encoder(
            clip_image.to(device, dtype=self.image_encoder.dtype)
        ).image_embeds
        image_prompt_embeds = clip_image_embeds.unsqueeze(1)

        uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)

        if do_classifier_free_guidance:
            image_prompt_embeds = torch.cat(
                [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
            )

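        # Reference-attention control: the "write" instance hooks self.reference_unet so its
        # self-attention features are recorded, and the "read" instance hooks self.denoising_unet
        # so those features can be injected during denoising. The reference UNet is run only once
        # (at the first denoising step) and its recorded features are reused for every step.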
        reference_control_writer = ReferenceAttentionControl(
            self.reference_unet,
            do_classifier_free_guidance=do_classifier_free_guidance,
            mode="write",
            batch_size=batch_size,
            fusion_blocks="full",
        )
        reference_control_reader = ReferenceAttentionControl(
            self.denoising_unet,
            do_classifier_free_guidance=do_classifier_free_guidance,
            mode="read",
            batch_size=batch_size,
            fusion_blocks="full",
        )

        num_channels_latents = self.denoising_unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            width,
            height,
            clip_image_embeds.dtype,
            device,
            generator,
        )
        latents = latents.unsqueeze(2)  # (bs, c, 1, h', w')
        latents_dtype = latents.dtype

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Prepare ref image latents
        ref_image_tensor = self.ref_image_processor.preprocess(
            ref_image, height=height, width=width
        )  # (bs, c, height, width)
        ref_image_tensor = ref_image_tensor.to(
            dtype=self.vae.dtype, device=self.vae.device
        )
        ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
        ref_image_latents = ref_image_latents * 0.18215  # (b, 4, h, w)

        # Prepare pose condition image
        pose_cond_tensor = self.cond_image_processor.preprocess(
            pose_image, height=height, width=width
        )

        pose_cond_tensor = pose_cond_tensor.unsqueeze(2)  # (bs, c, 1, h, w)
        pose_cond_tensor = pose_cond_tensor.to(
            device=device, dtype=self.pose_guider.dtype
        )

        ref_pose_tensor = self.cond_image_processor.preprocess(
            ref_pose_image, height=height, width=width
        )
        ref_pose_tensor = ref_pose_tensor.to(
            device=device, dtype=self.pose_guider.dtype
        )

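        # The pose guider encodes the target-pose image (and the reference-pose image) into a
        # list of feature maps that is passed to the denoising UNet as pose_cond_fea; under
        # classifier-free guidance each feature map is duplicated to match the doubled batch.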
        pose_fea = self.pose_guider(pose_cond_tensor, ref_pose_tensor)
        if do_classifier_free_guidance:
            for idxx in range(len(pose_fea)):
                pose_fea[idxx] = torch.cat([pose_fea[idxx]] * 2)

        # denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # 1. Forward reference image
                if i == 0:
                    self.reference_unet(
                        ref_image_latents.repeat(
                            (2 if do_classifier_free_guidance else 1), 1, 1, 1
                        ),
                        torch.zeros_like(t),
                        encoder_hidden_states=image_prompt_embeds,
                        return_dict=False,
                    )

                # 2. Update reference unet feature into denoising net
                reference_control_reader.update(reference_control_writer)

                # 3.1 expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                noise_pred = self.denoising_unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_prompt_embeds,
                    pose_cond_fea=pose_fea,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)
            reference_control_reader.clear()
            reference_control_writer.clear()

        # Post-processing
        image = self.decode_latents(latents)  # (b, c, 1, h, w)

        # Convert to tensor
        if output_type == "tensor":
            image = torch.from_numpy(image)

        if not return_dict:
            return image

        return Pose2ImagePipelineOutput(images=image)
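
For context, a minimal usage sketch of the uploaded pipeline follows. Only the constructor arguments and the __call__ signature are taken from the file above; the model construction is left abstract because the VAE, image encoder, UNets and pose guider are defined elsewhere in the repo, and the file names, resolution, step count, guidance scale and seed below are illustrative placeholders, not values from this commit.

import torch
from PIL import Image

from src.pipelines.pipeline_pose2img import Pose2ImagePipeline

# vae, image_encoder, reference_unet, denoising_unet, pose_guider and scheduler must be
# built from the repo's model definitions and checkpoints (not part of this file).
pipe = Pose2ImagePipeline(
    vae=vae,
    image_encoder=image_encoder,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
)

ref_image = Image.open("ref.png").convert("RGB")             # appearance reference
pose_image = Image.open("pose.png").convert("RGB")           # target pose
ref_pose_image = Image.open("ref_pose.png").convert("RGB")   # pose of the reference image

out = pipe(
    ref_image,
    pose_image,
    ref_pose_image,
    width=512,
    height=512,
    num_inference_steps=25,
    guidance_scale=3.5,
    generator=torch.Generator(device="cuda").manual_seed(42),
)
image_tensor = out.images  # (b, c, 1, h, w) tensor with values in [0, 1]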