import os
import random
from pathlib import Path

import cv2
import lpips
import numpy as np
import torch
import torchvision
from PIL import Image
from torch.nn.functional import mse_loss
from torchvision import transforms
from torchvision.transforms import functional as TF

from CLIP import clip
from guided_diffusion.guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    create_classifier,
    classifier_defaults,
)
from optimization.augmentations import ImageAugmentations
from optimization.constants import ASSETS_DIR_NAME, RANKED_RESULTS_DIR
from optimization.losses import range_loss, d_clip_loss
from utils.change_place import change_place, find_bbox
from utils.fft_pytorch import HighFrequencyLoss
from utils.metrics_accumulator import MetricsAccumulator
from utils.video import save_video
from utils.visualization import show_editied_masked_image


def create_classifier_ours():
    # Build a ResNet-50 surrogate classifier from a checkpoint saved under
    # DataParallel ("module." prefix) whose head was named "last_linear".
    model = torchvision.models.resnet50()
    ckpt = torch.load('checkpoints/DRA_resnet50.pth')['model_state_dict']
    model.load_state_dict(
        {k.replace('module.', '').replace('last_linear', 'fc'): v for k, v in ckpt.items()}
    )
    # Upsample inputs to 256x256 before the backbone.
    model = torch.nn.Sequential(torch.nn.Upsample(size=(256, 256)), model)
    return model
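
# A minimal usage sketch for the surrogate classifier above. It is not called
# anywhere in this file and exists only for illustration; it assumes the
# DRA_resnet50.pth checkpoint is present and that `images` is a float tensor of
# shape [N, 3, H, W] in whatever value range the checkpoint was trained on.
def _example_classify(images):
    classifier = create_classifier_ours().eval()
    with torch.no_grad():
        logits = classifier(images)  # images are upsampled to 256x256 internally
    return logits.argmax(dim=-1)  # predicted class indices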
class ImageEditor:
    def __init__(self, args) -> None:
        self.args = args
        os.makedirs(self.args.output_path, exist_ok=True)

        self.ranked_results_path = Path(os.path.join(self.args.output_path, RANKED_RESULTS_DIR))
        os.makedirs(self.ranked_results_path, exist_ok=True)

        if self.args.export_assets:
            self.assets_path = Path(os.path.join(self.args.output_path, ASSETS_DIR_NAME))
            os.makedirs(self.assets_path, exist_ok=True)
        if self.args.seed is not None:
            torch.manual_seed(self.args.seed)
            np.random.seed(self.args.seed)
            random.seed(self.args.seed)

        self.model_config = model_and_diffusion_defaults()
        self.model_config.update(
            {
                "attention_resolutions": "32, 16, 8",
                "class_cond": self.args.model_output_size == 512,
                "diffusion_steps": 1000,
                "rescale_timesteps": True,
                "timestep_respacing": self.args.timestep_respacing,
                "image_size": self.args.model_output_size,
                "learn_sigma": True,
                "noise_schedule": "linear",
                "num_channels": 256,
                "num_head_channels": 64,
                "num_res_blocks": 2,
                "resblock_updown": True,
                "use_fp16": True,
                "use_scale_shift_norm": True,
            }
        )
        self.classifier_config = classifier_defaults()
        self.classifier_config.update({"image_size": self.args.model_output_size})

        # Load models
        self.device = torch.device(
            f"cuda:{self.args.gpu_id}" if torch.cuda.is_available() else "cpu"
        )
        print("Using device:", self.device)

        self.model, self.diffusion = create_model_and_diffusion(**self.model_config)
        self.model.load_state_dict(
            torch.load(
                "checkpoints/256x256_diffusion_uncond.pt"
                if self.args.model_output_size == 256
                else "checkpoints/512x512_diffusion.pt",
                map_location="cpu",
            )
        )
        self.model.eval().to(self.device)
        # Only the attention/normalization/projection parameters keep gradients;
        # they are what the guidance backward pass needs.
        for name, param in self.model.named_parameters():
            if "qkv" in name or "norm" in name or "proj" in name:
                param.requires_grad_()
        if self.model_config["use_fp16"]:
            self.model.convert_to_fp16()

        self.classifier = create_classifier(**self.classifier_config)
        self.classifier.load_state_dict(
            torch.load("checkpoints/256x256_classifier.pt", map_location="cpu")
        )
        # self.classifier = create_classifier_ours()  # optional surrogate classifier
        self.classifier.eval().to(self.device)
        if self.classifier_config["classifier_use_fp16"]:
            self.classifier.convert_to_fp16()

        self.clip_model = (
            clip.load("ViT-B/16", device=self.device, jit=False)[0].eval().requires_grad_(False)
        )
        self.clip_size = self.clip_model.visual.input_resolution
        self.clip_normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )
        self.to_tensor = transforms.ToTensor()

        self.lpips_model = lpips.LPIPS(net="vgg").to(self.device)
        self.image_augmentations = ImageAugmentations(self.clip_size, self.args.aug_num)
        self.metrics_accumulator = MetricsAccumulator()
        self.hf_loss = HighFrequencyLoss()

    def unscale_timestep(self, t):
        # Map timesteps from the model's rescaled 0-1000 range back to indices
        # in the respaced sampling schedule.
        unscaled_timestep = (t * (self.diffusion.num_timesteps / 1000)).long()
        return unscaled_timestep

    def clip_loss(self, x_in, text_embed):
        clip_loss = torch.tensor(0)

        if self.mask is not None:
            masked_input = x_in * self.mask
        else:
            masked_input = x_in
        # Augmented crops rescaled from [-1, 1] to [0, 1]; shape [N, C, H, W].
        augmented_input = self.image_augmentations(masked_input).add(1).div(2)
        clip_in = self.clip_normalize(augmented_input)
        image_embeds = self.clip_model.encode_image(clip_in).float()
        dists = d_clip_loss(image_embeds, text_embed)

        # Sum the per-sample means: dists is laid out augmentation-major, so every
        # batch_size-th entry belongs to the same source image.
        for i in range(self.args.batch_size):
            clip_loss = clip_loss + dists[i :: self.args.batch_size].mean()

        return clip_loss

    def unaugmented_clip_distance(self, x, text_embed):
        x = TF.resize(x, [self.clip_size, self.clip_size])
        image_embeds = self.clip_model.encode_image(x).float()
        dists = d_clip_loss(image_embeds, text_embed)
        return dists.item()
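
    # For reference: `d_clip_loss` comes from optimization.losses. In the Blended
    # Diffusion codebase this file builds on, it is the squared spherical distance
    # between L2-normalized embeddings -- a sketch, assuming that convention holds:
    #
    #   x = torch.nn.functional.normalize(x, dim=-1)
    #   y = torch.nn.functional.normalize(y, dim=-1)
    #   dist = (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
    #
    # Smaller values mean the image embedding sits closer to the text embedding.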
    def model_fn(self, x, t, y=None):
        return self.model(x, t, y if self.args.class_cond else None)

    def edit_image_by_prompt(self):
        if self.args.image_guide:
            # Image-guided mode: args.prompt is a path to a guide image, embedded
            # with CLIP's image encoder instead of its text encoder.
            img_guidance = Image.open(self.args.prompt).convert('RGB')
            img_guidance = img_guidance.resize((224, 224), Image.LANCZOS)  # type: ignore
            img_guidance = self.clip_normalize(self.to_tensor(img_guidance).unsqueeze(0)).to(self.device)
            text_embed = self.clip_model.encode_image(img_guidance).float()
        else:
            text_embed = self.clip_model.encode_text(
                clip.tokenize(self.args.prompt).to(self.device)
            ).float()

        self.image_size = (self.model_config["image_size"], self.model_config["image_size"])
        self.init_image_pil = Image.open(self.args.init_image).convert("RGB")
        self.init_image_pil = self.init_image_pil.resize(self.image_size, Image.LANCZOS)  # type: ignore
        self.init_image = (
            TF.to_tensor(self.init_image_pil).to(self.device).unsqueeze(0).mul(2).sub(1)
        )

        self.init_image_pil_2 = Image.open(self.args.init_image_2).convert("RGB")
        if self.args.rotate_obj:
            angle = self.args.angle
            self.init_image_pil_2 = self.init_image_pil_2.rotate(angle)
        self.init_image_pil_2 = self.init_image_pil_2.resize(self.image_size, Image.LANCZOS)  # type: ignore
        self.init_image_2 = (
            TF.to_tensor(self.init_image_pil_2).to(self.device).unsqueeze(0).mul(2).sub(1)
        )

        if self.args.export_assets:
            img_path = self.assets_path / Path(self.args.output_file)
            self.init_image_pil.save(img_path, quality=100)

        self.mask = torch.ones_like(self.init_image, device=self.device)
        self.mask_pil = None
        if self.args.mask is not None:
            self.mask_pil = Image.open(self.args.mask).convert("RGB")
            if self.args.rotate_obj:
                self.mask_pil = self.mask_pil.rotate(angle)
            if self.mask_pil.size != self.image_size:
                self.mask_pil = self.mask_pil.resize(self.image_size, Image.NEAREST)  # type: ignore
            if self.args.random_position:
                bbox = find_bbox(np.array(self.mask_pil))
                print(bbox)
            image_mask_pil_binarized = ((np.array(self.mask_pil) > 0.5) * 255).astype(np.uint8)
            if self.args.invert_mask:
                image_mask_pil_binarized = 255 - image_mask_pil_binarized
            self.mask_pil = TF.to_pil_image(image_mask_pil_binarized)
            self.mask = TF.to_tensor(Image.fromarray(image_mask_pil_binarized))
            self.mask = self.mask[0, ...].unsqueeze(0).unsqueeze(0).to(self.device)

            if self.args.random_position:
                # Move the object and its mask to a random location inside bbox.
                self.init_image_2, self.mask = change_place(
                    self.init_image_2, self.mask, bbox, self.args.invert_mask
                )

            if self.args.export_assets:
                mask_path = self.assets_path / Path(
                    self.args.output_file.replace(".png", "_mask.png")
                )
                self.mask_pil.save(mask_path, quality=100)

        def class_guided(x, y, t):
            assert y is not None
            with torch.enable_grad():
                x_in = x.detach().requires_grad_(True)
                logits = self.classifier(x_in)
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                selected = log_probs[range(len(logits)), y.view(-1)]
                loss = selected.sum()
                return -torch.autograd.grad(loss, x_in)[0] * self.args.classifier_scale
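
        # class_guided mirrors classifier guidance (Dhariwal & Nichol, 2021), with
        # two deviations from the standard recipe: the classifier is applied to x
        # directly rather than being conditioned on the timestep t, and the returned
        # term is -classifier_scale * grad_x log p(y | x), i.e. the sign is flipped,
        # so the gradient pushes samples *away* from class y.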
        def cond_fn(x, t, y=None):
            if self.args.prompt == "":
                return torch.zeros_like(x)

            with torch.enable_grad():
                x = x.detach().requires_grad_()
                t_unscale = self.unscale_timestep(t)

                out = self.diffusion.p_mean_variance(
                    self.model, x, t_unscale, clip_denoised=False, model_kwargs={"y": None}
                )

                fac = self.diffusion.sqrt_one_minus_alphas_cumprod[t_unscale[0].item()]
                # Guide on the predicted clean image rather than on the
                # fac-weighted blend of prediction and noisy sample.
                x_in = out["pred_xstart"]

                loss = torch.tensor(0)
                if self.args.classifier_scale != 0 and y is not None:
                    gradient_class_guided = class_guided(x_in, y, t)

                if self.args.background_complex != 0:
                    # Reward (or, in the "hard" setting, penalize) high-frequency
                    # content in the predicted image.
                    if self.args.hard:
                        loss = loss - self.args.background_complex * self.hf_loss((x_in + 1.0) / 2.0)
                    else:
                        loss = loss + self.args.background_complex * self.hf_loss((x_in + 1.0) / 2.0)

                if self.args.clip_guidance_lambda != 0:
                    clip_loss = self.clip_loss(x_in, text_embed) * self.args.clip_guidance_lambda
                    loss = loss + clip_loss
                    self.metrics_accumulator.update_metric("clip_loss", clip_loss.item())

                if self.args.range_lambda != 0:
                    r_loss = range_loss(out["pred_xstart"]).sum() * self.args.range_lambda
                    loss = loss + r_loss
                    self.metrics_accumulator.update_metric("range_loss", r_loss.item())

                if self.args.background_preservation_loss:
                    x_in = out["pred_xstart"] * fac + x * (1 - fac)
                    if self.mask is not None:
                        masked_background = x_in * self.mask
                    else:
                        masked_background = x_in

                    if self.args.lpips_sim_lambda:
                        # Compare against the equally masked init image.
                        loss = (
                            loss
                            + self.lpips_model(masked_background, self.init_image * self.mask).sum()
                            * self.args.lpips_sim_lambda
                        )
                    if self.args.l2_sim_lambda:
                        loss = (
                            loss
                            + mse_loss(masked_background, self.init_image * self.mask)
                            * self.args.l2_sim_lambda
                        )

                if self.args.classifier_scale != 0 and y is not None:
                    return -torch.autograd.grad(loss, x)[0] + gradient_class_guided
                else:
                    return -torch.autograd.grad(loss, x)[0]

        @torch.no_grad()
        def postprocess_fn(out, t):
            if self.args.coarse_to_fine:
                # Shrink the synthesis mask at high noise levels via a min-pool
                # (written as the complement of a max-pool), relaxing to the full
                # mask near the end of sampling.
                if t > 50:
                    kernel = 51
                elif t > 35:
                    kernel = 31
                else:
                    kernel = 0
                if kernel > 0:
                    max_pool = torch.nn.MaxPool2d(
                        kernel_size=kernel, stride=1, padding=(kernel - 1) // 2
                    )
                    self.mask_d = 1 - max_pool(1 - self.mask)
                else:
                    self.mask_d = self.mask
            else:
                self.mask_d = self.mask

            if self.mask is not None:
                background_stage_t = self.diffusion.q_sample(self.init_image_2, t[0])
                background_stage_t = torch.tile(
                    background_stage_t, dims=(self.args.batch_size, 1, 1, 1)
                )
                out["sample"] = out["sample"] * self.mask_d + background_stage_t * (1 - self.mask_d)

            return out
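
        # postprocess_fn applies the Blended Diffusion background trick: after each
        # sampling step, everything outside self.mask_d is overwritten with the
        # source image noised to the current level via q_sample,
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
        # so only the masked region is actually synthesized while the background
        # stays consistent with noise level t.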
        save_image_interval = self.diffusion.num_timesteps // 5
        for iteration_number in range(self.args.iterations_num):
            print(f"Start iteration {iteration_number}")

            sample_func = (
                self.diffusion.ddim_sample_loop_progressive
                if self.args.ddim
                else self.diffusion.p_sample_loop_progressive
            )
            samples = sample_func(
                self.model_fn,
                (
                    self.args.batch_size,
                    3,
                    self.model_config["image_size"],
                    self.model_config["image_size"],
                ),
                clip_denoised=False,
                model_kwargs={}
                if self.args.classifier_scale == 0
                else {
                    "y": self.args.y
                    * torch.ones([self.args.batch_size], device=self.device, dtype=torch.long)
                },
                cond_fn=cond_fn,
                device=self.device,
                progress=True,
                skip_timesteps=self.args.skip_timesteps,
                init_image=self.init_image,
                postprocess_fn=None if self.args.local_clip_guided_diffusion else postprocess_fn,
                randomize_class=self.args.classifier_scale == 0,
            )

            intermediate_samples = [[] for _ in range(self.args.batch_size)]
            total_steps = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
            for j, sample in enumerate(samples):
                should_save_image = j % save_image_interval == 0 or j == total_steps
                if should_save_image or self.args.save_video:
                    self.metrics_accumulator.print_average_metric()

                    for b in range(self.args.batch_size):
                        pred_image = sample["pred_xstart"][b]
                        visualization_path = Path(
                            os.path.join(self.args.output_path, self.args.output_file)
                        )
                        visualization_path = visualization_path.with_stem(
                            f"{visualization_path.stem}_i_{iteration_number}_b_{b}"
                        )
                        if (
                            self.mask is not None
                            and self.args.enforce_background
                            and j == total_steps
                            and not self.args.local_clip_guided_diffusion
                        ):
                            # Paste the exact source background back on the final step.
                            pred_image = (
                                self.init_image_2[0] * (1 - self.mask[0]) + pred_image * self.mask[0]
                            )
                        pred_image = pred_image.add(1).div(2).clamp(0, 1)
                        pred_image_pil = TF.to_pil_image(pred_image)
                        masked_pred_image = self.mask * pred_image.unsqueeze(0)
                        final_distance = self.unaugmented_clip_distance(
                            masked_pred_image, text_embed
                        )
                        formatted_distance = f"{final_distance:.4f}"

                        if self.args.export_assets:
                            pred_path = self.assets_path / visualization_path.name
                            pred_image_pil.save(pred_path, quality=100)

                        if j == total_steps:
                            # Prefix the file name with the CLIP distance so a
                            # lexicographic sort ranks results best-first.
                            path_friendly_distance = formatted_distance.replace(".", "")
                            ranked_pred_path = self.ranked_results_path / (
                                path_friendly_distance + "_" + visualization_path.name
                            )
                            pred_image_pil.save(ranked_pred_path, quality=100)

                        intermediate_samples[b].append(pred_image_pil)
                        if should_save_image:
                            show_editied_masked_image(
                                title=self.args.prompt,
                                source_image=self.init_image_pil,
                                edited_image=pred_image_pil,
                                mask=self.mask_pil,
                                path=visualization_path,
                                distance=formatted_distance,
                            )

            if self.args.save_video:
                for b in range(self.args.batch_size):
                    video_name = self.args.output_file.replace(
                        ".png", f"_i_{iteration_number}_b_{b}.avi"
                    )
                    video_path = os.path.join(self.args.output_path, video_name)
                    save_video(intermediate_samples[b], video_path)

        # Build a side-by-side strip of up to eight ranked outputs for inspection.
        visualize_size = (256, 256)
        img_ori = cv2.imread(self.args.init_image_2)
        img_ori = cv2.resize(img_ori, visualize_size)
        mask = cv2.imread(self.args.mask)
        mask = cv2.resize(mask, visualize_size)
        imgs = [img_ori, mask]
        for ii, img_name in enumerate(sorted(os.listdir(os.path.join(self.args.output_path, 'ranked')))):
            img_path = os.path.join(self.args.output_path, 'ranked', img_name)
            img = cv2.imread(img_path)
            img = cv2.resize(img, visualize_size)
            imgs.append(img)
            if ii >= 7:
                break
        img_whole = cv2.hconcat(imgs[2:])
        cv2.imwrite(
            os.path.join(self.args.final_save_root, 'edited.png'),
            img_whole,
            [int(cv2.IMWRITE_PNG_COMPRESSION), 0],
        )
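
    # Note: reconstruct_image below reuses self.image_size, which is only assigned
    # inside edit_image_by_prompt, so it assumes that method has been called first.
    # With skip_timesteps > 0 the sampler starts from a partially noised copy of
    # init_image rather than pure noise, so larger skips preserve more of the input.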
    def reconstruct_image(self):
        init = Image.open(self.args.init_image).convert("RGB")
        init = init.resize(self.image_size, Image.LANCZOS)  # type: ignore
        init = TF.to_tensor(init).to(self.device).unsqueeze(0).mul(2).sub(1)

        samples = self.diffusion.p_sample_loop_progressive(
            self.model,
            (1, 3, self.model_config["image_size"], self.model_config["image_size"]),
            clip_denoised=False,
            model_kwargs={}
            if self.args.model_output_size == 256
            else {"y": torch.zeros([self.args.batch_size], device=self.device, dtype=torch.long)},
            cond_fn=None,
            progress=True,
            skip_timesteps=self.args.skip_timesteps,
            init_image=init,
            randomize_class=True,
        )

        save_image_interval = self.diffusion.num_timesteps // 5
        max_iterations = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
        for j, sample in enumerate(samples):
            if j % save_image_interval == 0 or j == max_iterations:
                print()
                filename = os.path.join(self.args.output_path, self.args.output_file)
                TF.to_pil_image(sample["pred_xstart"][0].add(1).div(2).clamp(0, 1)).save(filename)
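
# Hypothetical end-to-end usage (the real entry point builds `args` with argparse
# elsewhere in the repo; the field names below mirror attributes this class reads,
# but the values are illustrative only and the list is not exhaustive):
#
#   from argparse import Namespace
#   args = Namespace(
#       output_path="output/demo", export_assets=False, seed=0, gpu_id=0,
#       model_output_size=256, timestep_respacing="100",
#       prompt="a photo of a dog", image_guide=False,
#       init_image="inputs/scene.png", init_image_2="inputs/scene.png",
#       mask="inputs/mask.png", invert_mask=False, rotate_obj=False,
#       random_position=False, ...,  # plus the guidance/loss/sampling flags above
#   )
#   ImageEditor(args).edit_image_by_prompt()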