from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import (
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.pipelines.flux.pipeline_flux_fill import (
    calculate_shift,
    retrieve_latents,
    retrieve_timesteps,
)
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor

from model.flux.transformer_flux import FluxTransformer2DModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Modified from `diffusers.pipelines.flux.pipeline_flux_fill.FluxFillPipeline`
class FluxTryOnPipeline(
    DiffusionPipeline,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):
    r"""
    Text-free try-on pipeline derived from the Flux Fill pipeline.

    The masked input image and the condition (garment) image are concatenated side by side along the width
    axis to form one canvas. The canvas is VAE-encoded, packed into 2x2 patches and fed to the transformer
    together with a packed mask. The condition half of the canvas always receives an all-zero mask. After
    denoising, only the left half of the latent canvas is decoded into the output image.

    Args:
        vae: Autoencoder used to encode the masked canvas and to decode the final latents.
        scheduler: Flow-matching scheduler that drives the denoising loop.
        transformer: Flux transformer. Its text layers are removed in `__init__`.
    """

    model_cpu_offload_seq = "transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKL,
        scheduler: FlowMatchEulerDiscreteScheduler,
        transformer: FluxTransformer2DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            scheduler=scheduler,
            transformer=transformer,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be
        # divisible by the patch size, so the VAE scale factor is multiplied by the patch size to account for this.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor * 2,
            vae_latent_channels=self.vae.config.latent_channels,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        self.default_sample_size = 128
        # TryOnEdit: the transformer is called with no text inputs, so its text layers are removed.
        self.transformer.remove_text_layers()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
        r"""
        Build the pipeline from a checkpoint directory or hub id.

        The transformer, VAE and scheduler are loaded from the `transformer`, `vae` and `scheduler` subfolders
        of `pretrained_model_name_or_path`.

        Note:
            `subfolder` and any extra `kwargs` are accepted for signature compatibility but are currently ignored.

        Returns:
            An instance of `cls`.
        """
        transformer = FluxTransformer2DModel.from_pretrained(pretrained_model_name_or_path, subfolder="transformer")
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
        # `__init__` already calls `transformer.remove_text_layers()`, so it is not repeated here.
        # Using `cls` (not the hard-coded class name) lets subclasses get an instance of their own type.
        return cls(vae=vae, scheduler=scheduler, transformer=transformer)

    def prepare_mask_latents(
        self,
        mask,
        masked_image,
        batch_size,
        num_channels_latents,
        num_images_per_prompt,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        r"""
        Encode the masked canvas and pack it, together with the mask, into transformer token sequences.

        Args:
            mask: Pixel-space mask. Only channel 0 is used. The `view` below requires its spatial size to equal
                the latent size times `vae_scale_factor`.
            masked_image: Pixel-space masked canvas. If its channel dimension already equals
                `num_channels_latents`, it is treated as latents and is not VAE-encoded.
            batch_size: Base batch size. It is multiplied by `num_images_per_prompt` to get the effective batch.
            num_channels_latents: Number of VAE latent channels.
            num_images_per_prompt: Number of images generated per input.
            height: Pixel-space canvas height. It is converted to latent height below.
            width: Pixel-space canvas width. It is converted to latent width below.
            dtype: Target dtype for the returned tensors.
            device: Target device for the returned tensors.
            generator: Random generator(s) passed to `retrieve_latents`.

        Returns:
            A `(mask, masked_image_latents)` tuple, both packed by `_pack_latents`.
                - `mask` has `vae_scale_factor ** 2 * 4` channels per token.
                - `masked_image_latents` has `num_channels_latents * 4` channels per token.

        Raises:
            ValueError: If the effective batch size is not a multiple of the mask or image batch size.
        """
        # 1. calculate the height and width of the latents
        # The VAE downsamples by `vae_scale_factor`, and packing additionally requires the latent height and
        # width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        # 2. encode the masked image
        if masked_image.shape[1] == num_channels_latents:
            masked_image_latents = masked_image
        else:
            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        batch_size = batch_size * num_images_per_prompt
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        # 4. pack the masked_image_latents
        # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4
        masked_image_latents = self._pack_latents(
            masked_image_latents,
            batch_size,
            num_channels_latents,
            height,
            width,
        )

        # 5. resize the mask to the latent grid so it can be concatenated with the latents
        mask = mask[:, 0, :, :]  # batch_size, 8 * height, 8 * width (mask has not been 8x compressed)
        mask = mask.view(
            batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor
        )  # batch_size, height, 8, width, 8
        mask = mask.permute(0, 2, 4, 1, 3)  # batch_size, 8, 8, height, width
        mask = mask.reshape(
            batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width
        )  # batch_size, 8*8, height, width

        # 6. pack the mask:
        # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2
        mask = self._pack_latents(
            mask,
            batch_size,
            self.vae_scale_factor * self.vae_scale_factor,
            height,
            width,
        )
        mask = mask.to(device=device, dtype=dtype)

        return mask, masked_image_latents

    def check_inputs(
        self,
        height,
        width,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
        image=None,
        mask_image=None,
        condition_image=None,
        masked_image_latents=None,
    ):
        r"""
        Validate `__call__` arguments.

        Non-divisible `height`/`width` only log a warning. Every other problem raises an error.

        Raises:
            ValueError: In any of these cases:
                - an unknown callback tensor name is passed;
                - `max_sequence_length` is greater than 512;
                - both `image` and `masked_image_latents` are given, or neither is;
                - `image` is given without `mask_image`;
                - `condition_image` is missing.
        """
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

        if image is not None and masked_image_latents is not None:
            raise ValueError(
                "Please provide either `image` or `masked_image_latents`, `masked_image_latents` should not be passed."
            )

        # Without either input, `__call__` would otherwise fail later inside `preprocess(None)`.
        if image is None and masked_image_latents is None:
            raise ValueError("Please provide either `image` or `masked_image_latents`.")

        if image is not None and mask_image is None:
            raise ValueError("Please provide `mask_image` when passing `image`.")

        if condition_image is None:
            raise ValueError("Please provide `condition_image`.")

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        """Return `(height * width, 3)` position ids whose channels 1 and 2 hold the row and column index."""
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        """Pack `(B, C, H, W)` into `(B, (H//2)*(W//2), C*4)` by folding each 2x2 patch into the channel dim."""
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Invert `_pack_latents`. `height`/`width` are pixel sizes, converted to latent sizes here."""
        batch_size, num_patches, channels = latents.shape

        # The VAE downsamples by `vae_scale_factor`, and packing additionally requires the latent height and
        # width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        r"""
        Return `(packed_latents, latent_image_ids)`.

        When `latents` is given, it is only moved to `device`/`dtype` and returned. Otherwise Gaussian noise is
        sampled and packed. `height`/`width` are pixel sizes.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        # The VAE downsamples by `vae_scale_factor`, and packing additionally requires the latent height and
        # width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

        return latents, latent_image_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[torch.FloatTensor] = None,
        condition_image: Optional[torch.FloatTensor] = None,  # TryOnEdit: condition image (garment)
        mask_image: Optional[torch.FloatTensor] = None,
        masked_image_latents: Optional[torch.FloatTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 30.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        max_sequence_length: int = 512,
    ):
        r"""
        Run the try-on generation.

        Args:
            image: Input image to edit. It is preprocessed and resized to `(height, width)`.
                Presumably the person image. Mutually exclusive with `masked_image_latents`.
            condition_image: Condition (garment) image. It is resized to `(height, width)` and placed to the
                right of `image`. Always required by `check_inputs`.
            mask_image: Mask for `image`. It is binarized, and pixels where it is 1 are zeroed in the canvas
                and regenerated. Required when `image` is given.
            masked_image_latents: Precomputed conditioning tokens, used as-is. When given, `image`, `mask_image`
                and `condition_image` are not preprocessed. Presumably they must match the internal layout
                (packed latents concatenated with the packed mask) — confirm before relying on it.
            height: Height of each half of the canvas. Defaults to `default_sample_size * vae_scale_factor`.
                When `image` is given, it is replaced by the preprocessed image's height.
            width: Width of each half of the canvas. Defaults to `default_sample_size * vae_scale_factor`.
                When `image` is given, it is replaced by the preprocessed image's width.
            num_inference_steps: Number of denoising steps.
            sigmas: Custom sigma schedule. Defaults to `np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)`.
            guidance_scale: Fed to the transformer as a guidance embedding, only when
                `transformer.config.guidance_embeds` is set.
            num_images_per_prompt: Number of images to generate.
            generator: Random generator(s) used for noise and VAE sampling.
            latents: Pre-generated packed noise latents to start from.
            output_type: If `"latent"`, the packed latents of the full side-by-side canvas are returned.
                Otherwise, the left half is decoded and passed to `image_processor.postprocess`.
            return_dict: Whether to return a `FluxPipelineOutput` instead of a tuple.
            joint_attention_kwargs: Forwarded to the transformer.
            callback_on_step_end: Called after every step as `callback(self, i, t, callback_kwargs)`. If the
                returned dict has a `"latents"` entry, it replaces the current latents.
            callback_on_step_end_tensor_inputs: Names of the tensors passed to the callback. Defaults to
                `["latents"]` and must be a subset of `_callback_tensor_inputs`.
            max_sequence_length: Only validated (must be <= 512). Unused by this text-free pipeline.

        Returns:
            A `FluxPipelineOutput`, or `(image,)` when `return_dict` is False.
        """
        # A `None` default avoids sharing one mutable list across calls. The effective default is unchanged.
        if callback_on_step_end_tensor_inputs is None:
            callback_on_step_end_tensor_inputs = ["latents"]

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            height,
            width,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
            image=image,
            mask_image=mask_image,
            condition_image=condition_image,
            masked_image_latents=masked_image_latents,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        dtype = self.transformer.dtype

        # 3. No prompt embeddings: the transformer is called without text inputs.

        # 4. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width * 2,  # TryOnEdit: width * 2 (input and condition images side by side)
            dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare mask and masked image latents
        if masked_image_latents is not None:
            masked_image_latents = masked_image_latents.to(latents.device)
        else:
            image = self.image_processor.preprocess(image, height=height, width=width)
            condition_image = self.image_processor.preprocess(condition_image, height=height, width=width)
            mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

            masked_image = image * (1 - mask_image)
            masked_image = masked_image.to(device=device, dtype=dtype)

            # TryOnEdit: concat the condition image to the right of the masked image. The condition half gets
            # an all-zero mask.
            condition_image = condition_image.to(device=device, dtype=dtype)
            masked_image = torch.cat((masked_image, condition_image), dim=-1)
            mask_image = torch.cat((mask_image, torch.zeros_like(mask_image)), dim=-1)

            height, width = image.shape[-2:]
            mask, masked_image_latents = self.prepare_mask_latents(
                mask_image,
                masked_image,
                batch_size,
                num_channels_latents,
                num_images_per_prompt,
                height,
                width * 2,  # TryOnEdit: width * 2
                dtype,
                device,
                generator,
            )
            masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

        # 6. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 7. Denoising loop
        # TryOnEdit: pooled prompt embeddings are not used for now, so zeros are passed.
        pooled_prompt_embeds = torch.zeros([latents.shape[0], 768], device=device, dtype=dtype)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                noise_pred = self.transformer(
                    hidden_states=torch.cat((latents, masked_image_latents), dim=2),
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=None,
                    txt_ids=None,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # Only `latents` can be overridden. There are no prompt embeddings in this pipeline, and the
                    # previous `prompt_embeds` lookup raised UnboundLocalError.
                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # 8. Post-process the image
        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width * 2, self.vae_scale_factor)  # TryOnEdit: width * 2
            # TryOnEdit: keep only the left half of the canvas (the edited image). Drop the condition half.
            latents = latents.split(latents.shape[-1] // 2, dim=-1)[0]
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)