import os
import argparse

from tqdm.auto import tqdm
from packaging import version

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms

from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
    PNDMScheduler,
    AmusedInpaintPipeline,
    AmusedScheduler,
    VQModel,
    UVit2DModel,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils import load_image
from transformers import AutoTokenizer, CLIPFeatureExtractor, PretrainedConfig
from PIL import Image

from utils.mclip import *


def parse_args():
    parser = argparse.ArgumentParser(description="Edit images with M3Face.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
        help="The input text prompt for image generation."
    )
    parser.add_argument(
        "--condition",
        type=str,
        default="mask",
        choices=["mask", "landmark"],
        help="Use segmentation mask or facial landmarks for image generation."
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default=None,
        help="Path to the input image."
    )
    parser.add_argument(
        "--condition_path",
        type=str,
        default=None,
        help="Path to the original mask/landmark image."
    )
    parser.add_argument(
        "--edit_condition_path",
        type=str,
        default=None,
        help="Path to the target mask/landmark image."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default='output/',
        help="The output directory where the results will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers."
    )
    parser.add_argument("--edit_condition", action="store_true")
    parser.add_argument("--load_unet_from_local", action="store_true")
    parser.add_argument("--save_unet", action="store_true")
    parser.add_argument("--unet_local_path", type=str, default=None)
    parser.add_argument("--load_finetune_from_local", action="store_true")
    parser.add_argument("--finetune_path", type=str, default=None)
    parser.add_argument("--use_english", action="store_true", help="Use the English models.")
    parser.add_argument("--embedding_optimize_it", type=int, default=500)
    parser.add_argument("--model_finetune_it", type=int, default=1000)
    parser.add_argument("--alpha", nargs="+", type=float, default=[0.8, 0.9, 1, 1.1])
    parser.add_argument("--num_inference_steps", nargs="+", type=int, default=[20, 40, 50])
    parser.add_argument("--unet_layer", type=str, default="2and3", help="Which UNet layers in the SD to finetune.")
    args = parser.parse_args()
    return args


def get_muse(args):
    """Load the aMUSEd inpainting pipeline used to edit the mask/landmark condition image."""
    muse_model_name = 'm3face/FaceConditioning'
    if args.condition == 'mask':
        muse_revision = 'segmentation'
    elif args.condition == 'landmark':
        muse_revision = 'landmark'
    scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
    vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
    uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
    text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
    tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
    pipeline = AmusedInpaintPipeline(
        vqvae=vqvae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=uvit2,
        scheduler=scheduler
    ).to("cuda")
    return pipeline


def import_model_class_from_model_name(sd_model_name):
    """Resolve the text encoder class from the model config (CLIP for SD, Roberta for AltDiffusion)."""
    text_encoder_config = PretrainedConfig.from_pretrained(
        sd_model_name,
        subfolder="text_encoder",
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


def preprocess(image, condition, prompt, tokenizer):
    """Resize/normalize the input image and condition image, and tokenize the prompt."""
    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    condition_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]
    )
    image = image_transforms(image)
    condition = condition_transforms(condition)
    inputs = tokenizer(
        [prompt],
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    return image, condition, inputs.input_ids, inputs.attention_mask


def main(args):
    if args.use_english:
        sd_model_name = 'runwayml/stable-diffusion-v1-5'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-english'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-english'
    else:
        sd_model_name = 'BAAI/AltDiffusion-m18'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-mlin'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-mlin'

    # ========== set up models ==========
    vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
    tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
    text_encoder_cls = import_model_class_from_model_name(sd_model_name)
    text_encoder = text_encoder_cls.from_pretrained(sd_model_name, subfolder="text_encoder")
    noise_scheduler = DDPMScheduler.from_pretrained(sd_model_name, subfolder="scheduler")
    if args.load_unet_from_local:
        unet = UNet2DConditionModel.from_pretrained(args.unet_local_path)
    else:
        unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
    controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
    if args.edit_condition:
        muse = get_muse(args)

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    controlnet.requires_grad_(False)
    unet.requires_grad_(False)
    vae.eval()
    text_encoder.eval()
    controlnet.eval()
    unet.eval()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during "
                    "training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # ========== select params to optimize ==========
    # Only parameters of the UNet decoder (up_blocks) are candidates for finetuning.
    params = []
    for name, param in unet.named_parameters():
        if name.startswith('up_blocks'):
            params.append(param)
    if args.unet_layer == 'only1':  # 116 layers
        params_to_optimize = [
            {'params': params[38:154]},
        ]
    elif args.unet_layer == 'only2':  # 116 layers
        params_to_optimize = [
            {'params': params[154:270]},
        ]
    elif args.unet_layer == 'only3':  # 114 layers
        params_to_optimize = [
            {'params': params[270:]},
        ]
    elif args.unet_layer == '1and2':  # 232 layers
        params_to_optimize = [
            {'params': params[38:270]},
        ]
    elif args.unet_layer == '2and3':  # 230 layers
        params_to_optimize = [
            {'params': params[154:]},
        ]
    elif args.unet_layer == 'all':  # all layers
        params_to_optimize = [
            {'params': params},
        ]
    else:
        raise ValueError(f"Unknown --unet_layer value: {args.unet_layer}")

    image = Image.open(args.image_path).convert('RGB')
    condition = Image.open(args.condition_path).convert('RGB')
    image, condition, input_ids, attention_mask = preprocess(image, condition, args.prompt, tokenizer)

    # Move to device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vae.to(device, dtype=torch.float32)
    unet.to(device, dtype=torch.float32)
    text_encoder.to(device, dtype=torch.float32)
    controlnet.to(device)
    image = image.to(device).unsqueeze(0)
    condition = condition.to(device).unsqueeze(0)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # ========== imagic ==========
    os.makedirs(args.output_dir, exist_ok=True)
    if args.load_finetune_from_local:
        print('Loading embeddings from local ...')
        orig_emb = torch.load(os.path.join(args.finetune_path, 'orig_emb.pt'))
        emb = torch.load(os.path.join(args.finetune_path, 'emb.pt'))
    else:
        init_latent = vae.encode(image.to(dtype=torch.float32)).latent_dist.sample()
        init_latent = init_latent * vae.config.scaling_factor
        if not args.use_english:
            orig_emb = text_encoder(input_ids, attention_mask=attention_mask)[0]
        else:
            orig_emb = text_encoder(input_ids)[0]
        emb = orig_emb.clone()
        torch.save(orig_emb, os.path.join(args.output_dir, 'orig_emb.pt'))
        torch.save(emb, os.path.join(args.output_dir, 'emb.pt'))

        # 1. Optimize the embedding
        print('1. Optimize the embedding')
        unet.eval()
        emb.requires_grad = True
        lr = 0.001
        it = args.embedding_optimize_it  # 500
        opt = torch.optim.Adam([emb], lr=lr)
        history = []
        pbar = tqdm(
            range(it),
            initial=0,
            desc="Optimize Steps",
        )
        global_step = 0
        for i in pbar:
            opt.zero_grad()
            noise = torch.randn_like(init_latent)
            bsz = init_latent.shape[0]
            t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
            t_enc = t_enc.long()
            z = noise_scheduler.add_noise(init_latent, noise, t_enc)
            controlnet_image = condition.to(dtype=torch.float32)
            down_block_res_samples, mid_block_res_sample = controlnet(
                z,
                t_enc,
                encoder_hidden_states=emb,
                controlnet_cond=controlnet_image,
                return_dict=False,
            )

            # Predict the noise residual
            pred_noise = unet(
                z,
                t_enc,
                encoder_hidden_states=emb,
                down_block_additional_residuals=[
                    sample.to(dtype=torch.float32) for sample in down_block_res_samples
                ],
                mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
            loss.backward()
            global_step += 1
            pbar.set_postfix({"loss": loss.item()})
            history.append(loss.item())
            opt.step()
            opt.zero_grad()

        # 2. Finetune the model
        print('2. Finetune the model')
        emb.requires_grad = False
        unet.requires_grad_(True)
        unet.train()
        lr = 5e-5
        it = args.model_finetune_it  # 1000
        opt = torch.optim.Adam(params_to_optimize, lr=lr)
        history = []
        pbar = tqdm(
            range(it),
            initial=0,
            desc="Finetune Steps",
        )
        global_step = 0
        for i in pbar:
            opt.zero_grad()
            noise = torch.randn_like(init_latent)
            bsz = init_latent.shape[0]
            t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
            t_enc = t_enc.long()
            z = noise_scheduler.add_noise(init_latent, noise, t_enc)
            controlnet_image = condition.to(dtype=torch.float32)
            down_block_res_samples, mid_block_res_sample = controlnet(
                z,
                t_enc,
                encoder_hidden_states=emb,
                controlnet_cond=controlnet_image,
                return_dict=False,
            )

            # Predict the noise residual
            pred_noise = unet(
                z,
                t_enc,
                encoder_hidden_states=emb,
                down_block_additional_residuals=[
                    sample.to(dtype=torch.float32) for sample in down_block_res_samples
                ],
                mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
            loss.backward()
            global_step += 1
            pbar.set_postfix({"loss": loss.item()})
            history.append(loss.item())
            opt.step()
            opt.zero_grad()

    # 3. Generate Images
    print("3. Generating images...")
    unet.eval()
    controlnet.eval()

    if args.edit_condition_path is None:
        edit_condition = load_image(args.condition_path)
    else:
        edit_condition = load_image(args.edit_condition_path)

    if args.edit_condition:
        # Inpainting mask: white (255) pixels mark the region the aMUSEd model may change.
        edit_mask = Image.new("L", (256, 256), 0)
        for i in range(256):
            for j in range(256):
                if 40 < i < 220 and 20 < j < 256:
                    edit_mask.putpixel((i, j), 255)
        if args.condition == 'mask':
            condition_type = 'segmentation'
        elif args.condition == 'landmark':
            condition_type = 'landmark'
        edit_prompt = f"Generate face {condition_type} | " + args.prompt
        input_image = edit_condition.resize((256, 256)).convert("RGB")
        edit_condition = muse(edit_prompt, input_image, edit_mask, num_inference_steps=30).images[0].resize((512, 512))
        edit_condition.save(f'{args.output_dir}/edited_condition.png')
        # remove muse and empty cache
        del muse
        torch.cuda.empty_cache()

    if sd_model_name.startswith('BAAI'):
        scheduler = PNDMScheduler.from_pretrained(
            sd_model_name,
            subfolder='scheduler',
        )
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            sd_model_name,
            subfolder='feature_extractor',
        )
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor
        )
    else:
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            safety_checker=None,
        )
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)
    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    # Reconstruct the input image from the optimized embedding.
    with torch.autocast("cuda"):
        image = pipeline(
            image=edit_condition,
            prompt_embeds=emb,
            num_inference_steps=20,
            generator=generator
        ).images[0]
    image.save(f'{args.output_dir}/reconstruct.png')

    # Interpolate the embedding: blend the original text embedding with the optimized one.
    for num_inference_steps in args.num_inference_steps:
        for alpha in args.alpha:
            new_emb = alpha * orig_emb + (1 - alpha) * emb
            with torch.autocast("cuda"):
                image = pipeline(
                    image=edit_condition,
                    prompt_embeds=new_emb,
                    num_inference_steps=num_inference_steps,
                    generator=generator
                ).images[0]
            image.save(f'{args.output_dir}/image_{num_inference_steps}_{alpha}.png')

    if args.save_unet:
        print('Saving the unet model...')
        unet.save_pretrained(f'{args.output_dir}/unet')


if __name__ == '__main__':
    args = parse_args()
    main(args)