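"""Mask-guided image editing with a guided-diffusion model.

ImageEditor steers a pretrained guided-diffusion model with several guidance
terms: CLIP similarity to a text prompt (or to a guide image), an optional
ImageNet classifier, a high-frequency "background complexity" loss, and
LPIPS/L2 background-preservation losses. The editable region is given by a
mask; outside it, the background from a second init image is blended back in
at every denoising step.
"""
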
import os
import pdb
from pathlib import Path

import cv2
import lpips
import numpy as np
import torch
import torchvision
from numpy import random
from PIL import Image
from torch.nn.functional import mse_loss
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.transforms import functional as TF

from CLIP import clip
from guided_diffusion.guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    create_classifier,
    classifier_defaults,
)

from optimization.augmentations import ImageAugmentations
from optimization.constants import ASSETS_DIR_NAME, RANKED_RESULTS_DIR
from optimization.losses import range_loss, d_clip_loss
from utils.change_place import change_place, find_bbox
from utils.fft_pytorch import HighFrequencyLoss
from utils.metrics_accumulator import MetricsAccumulator
from utils.video import save_video
from utils.visualization import show_tensor_image, show_editied_masked_image


def create_classifier_ours():
    # Surrogate recognition model: an ImageNet ResNet-50 restored from the DRA
    # checkpoint, with inputs first upsampled to 256x256.
    model = torchvision.models.resnet50()
    ckpt = torch.load('checkpoints/DRA_resnet50.pth')['model_state_dict']
    # The checkpoint was saved from a DataParallel model with a renamed head,
    # so strip the 'module.' prefix and map 'last_linear' back to 'fc'.
    model.load_state_dict(
        {k.replace('module.', '').replace('last_linear', 'fc'): v for k, v in ckpt.items()}
    )
    model = torch.nn.Sequential(torch.nn.Upsample(size=(256, 256)), model)
    return model
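

# create_classifier_ours() is not called elsewhere in this file. A minimal
# usage sketch, assuming the DRA_resnet50 checkpoint expects inputs preprocessed
# like a standard torchvision ResNet-50 (an assumption, not something this file
# specifies):
#
#     surrogate = create_classifier_ours().eval().to(device)
#     with torch.no_grad():
#         logits = surrogate(image_batch)  # image_batch: (N, 3, H, W) float tensor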


class ImageEditor:
    def __init__(self, args) -> None:
        self.args = args
        os.makedirs(self.args.output_path, exist_ok=True)

        self.ranked_results_path = Path(os.path.join(self.args.output_path, RANKED_RESULTS_DIR))
        os.makedirs(self.ranked_results_path, exist_ok=True)

        if self.args.export_assets:
            self.assets_path = Path(os.path.join(self.args.output_path, ASSETS_DIR_NAME))
            os.makedirs(self.assets_path, exist_ok=True)
        if self.args.seed is not None:
            torch.manual_seed(self.args.seed)
            np.random.seed(self.args.seed)
            random.seed(self.args.seed)

        # Guided-diffusion configuration: 256x256 unconditional by default,
        # 512x512 class-conditional when model_output_size == 512.
        self.model_config = model_and_diffusion_defaults()
        self.model_config.update(
            {
                "attention_resolutions": "32, 16, 8",
                "class_cond": self.args.model_output_size == 512,
                "diffusion_steps": 1000,
                "rescale_timesteps": True,
                "timestep_respacing": self.args.timestep_respacing,
                "image_size": self.args.model_output_size,
                "learn_sigma": True,
                "noise_schedule": "linear",
                "num_channels": 256,
                "num_head_channels": 64,
                "num_res_blocks": 2,
                "resblock_updown": True,
                "use_fp16": True,
                "use_scale_shift_norm": True,
            }
        )

        self.classifier_config = classifier_defaults()
        self.classifier_config.update(
            {
                "image_size": self.args.model_output_size,
            }
        )

        self.device = torch.device(
            f"cuda:{self.args.gpu_id}" if torch.cuda.is_available() else "cpu"
        )
        print("Using device:", self.device)

        self.model, self.diffusion = create_model_and_diffusion(**self.model_config)
        self.model.load_state_dict(
            torch.load(
                "checkpoints/256x256_diffusion_uncond.pt"
                if self.args.model_output_size == 256
                else "checkpoints/512x512_diffusion.pt",
                map_location="cpu",
            )
        )

        self.model.eval().to(self.device)
        # Allow gradients through the attention / normalization / projection
        # parameters of the diffusion model.
        for name, param in self.model.named_parameters():
            if "qkv" in name or "norm" in name or "proj" in name:
                param.requires_grad_()
        if self.model_config["use_fp16"]:
            self.model.convert_to_fp16()

        # Noise-aware ImageNet classifier released with guided-diffusion.
        self.classifier = create_classifier(**self.classifier_config)
        self.classifier.load_state_dict(
            torch.load("checkpoints/256x256_classifier.pt", map_location="cpu")
        )

        self.classifier.eval().to(self.device)
        if self.classifier_config["classifier_use_fp16"]:
            self.classifier.convert_to_fp16()

        # CLIP model used for the text / image guidance loss.
        self.clip_model = (
            clip.load("ViT-B/16", device=self.device, jit=False)[0].eval().requires_grad_(False)
        )
        self.clip_size = self.clip_model.visual.input_resolution
        self.clip_normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )
        self.to_tensor = transforms.ToTensor()
        self.lpips_model = lpips.LPIPS(net="vgg").to(self.device)

        self.image_augmentations = ImageAugmentations(self.clip_size, self.args.aug_num)
        self.metrics_accumulator = MetricsAccumulator()

        self.hf_loss = HighFrequencyLoss()

    def unscale_timestep(self, t):
        # The sampler calls the model with timesteps rescaled to [0, 1000]
        # (rescale_timesteps=True); map them back to indices into the respaced
        # diffusion schedule.
        unscaled_timestep = (t * (self.diffusion.num_timesteps / 1000)).long()

        return unscaled_timestep
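    # Worked example (assuming timestep_respacing="100", so
    # self.diffusion.num_timesteps == 100): a model-space timestep of 500 maps
    # to 500 * (100 / 1000) == 50 in the respaced schedule.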

    def clip_loss(self, x_in, text_embed):
        clip_loss = torch.tensor(0)

        if self.mask is not None:
            masked_input = x_in * self.mask
        else:
            masked_input = x_in
        # Map from [-1, 1] to [0, 1] after augmenting, then apply CLIP normalization.
        augmented_input = self.image_augmentations(masked_input).add(1).div(2)
        clip_in = self.clip_normalize(augmented_input)

        image_embeds = self.clip_model.encode_image(clip_in).float()
        dists = d_clip_loss(image_embeds, text_embed)

        # Average the CLIP distance over the augmented copies of each batch element.
        for i in range(self.args.batch_size):
            clip_loss = clip_loss + dists[i :: self.args.batch_size].mean()

        return clip_loss

    def unaugmented_clip_distance(self, x, text_embed):
        x = F.resize(x, [self.clip_size, self.clip_size])
        image_embeds = self.clip_model.encode_image(x).float()
        dists = d_clip_loss(image_embeds, text_embed)

        return dists.item()

    def model_fn(self, x, t, y=None):
        return self.model(x, t, y if self.args.class_cond else None)

    def edit_image_by_prompt(self):
        if self.args.image_guide:
            # In image-guidance mode, args.prompt holds the path of a guide image
            # whose CLIP embedding replaces the text embedding.
            img_guidance = Image.open(self.args.prompt).convert('RGB')
            img_guidance = img_guidance.resize((224, 224), Image.LANCZOS)
            img_guidance = self.clip_normalize(self.to_tensor(img_guidance).unsqueeze(0)).to(self.device)
            text_embed = self.clip_model.encode_image(img_guidance).float()
        else:
            text_embed = self.clip_model.encode_text(
                clip.tokenize(self.args.prompt).to(self.device)
            ).float()

        self.image_size = (self.model_config["image_size"], self.model_config["image_size"])
        self.init_image_pil = Image.open(self.args.init_image).convert("RGB")
        self.init_image_pil = self.init_image_pil.resize(self.image_size, Image.LANCZOS)
        self.init_image = (
            TF.to_tensor(self.init_image_pil).to(self.device).unsqueeze(0).mul(2).sub(1)
        )
        # Second init image: provides the background that is blended back in.
        self.init_image_pil_2 = Image.open(self.args.init_image_2).convert("RGB")
        if self.args.rotate_obj:
            angle = self.args.angle
            self.init_image_pil_2 = self.init_image_pil_2.rotate(angle)
        self.init_image_pil_2 = self.init_image_pil_2.resize(self.image_size, Image.LANCZOS)
        self.init_image_2 = (
            TF.to_tensor(self.init_image_pil_2).to(self.device).unsqueeze(0).mul(2).sub(1)
        )

        # (Disabled) Initialize from a previously inpainted result instead:
        # self.init_image_pil_ = Image.open(
        #     'output/ImageNet-S_val/bad_case_RN50/ILSVRC2012_val_00013212/ranked/08480_output_i_0_b_0.png'
        # ).convert("RGB")
        # self.init_image_pil_ = self.init_image_pil_.resize(self.image_size, Image.LANCZOS)
        # self.init_image_ = (
        #     TF.to_tensor(self.init_image_pil_).to(self.device).unsqueeze(0).mul(2).sub(1)
        # )

        if self.args.export_assets:
            img_path = self.assets_path / Path(self.args.output_file)
            self.init_image_pil.save(img_path, quality=100)

        self.mask = torch.ones_like(self.init_image, device=self.device)
        self.mask_pil = None
        if self.args.mask is not None:
            self.mask_pil = Image.open(self.args.mask).convert("RGB")
            if self.args.rotate_obj:
                self.mask_pil = self.mask_pil.rotate(angle)
            if self.mask_pil.size != self.image_size:
                self.mask_pil = self.mask_pil.resize(self.image_size, Image.NEAREST)
            if self.args.random_position:
                bbox = find_bbox(np.array(self.mask_pil))
                print(bbox)

            image_mask_pil_binarized = ((np.array(self.mask_pil) > 0.5) * 255).astype(np.uint8)

            if self.args.invert_mask:
                image_mask_pil_binarized = 255 - image_mask_pil_binarized
                self.mask_pil = TF.to_pil_image(image_mask_pil_binarized)
            self.mask = TF.to_tensor(Image.fromarray(image_mask_pil_binarized))
            self.mask = self.mask[0, ...].unsqueeze(0).unsqueeze(0).to(self.device)

            if self.args.random_position:
                # Move the object (and its mask) to a random location inside bbox.
                self.init_image_2, self.mask = change_place(
                    self.init_image_2, self.mask, bbox, self.args.invert_mask
                )

            if self.args.export_assets:
                mask_path = self.assets_path / Path(
                    self.args.output_file.replace(".png", "_mask.png")
                )
                self.mask_pil.save(mask_path, quality=100)

        def class_guided(x, y, t):
            assert y is not None
            with torch.enable_grad():
                x_in = x.detach().requires_grad_(True)

                logits = self.classifier(x_in)
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                selected = log_probs[range(len(logits)), y.view(-1)]
                loss = selected.sum()

                return -torch.autograd.grad(loss, x_in)[0] * self.args.classifier_scale
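        # Note the sign above: class_guided returns
        # -classifier_scale * d(log p(y|x)) / dx, which is added unmodified to
        # cond_fn's output below, i.e. it pushes the sample away from the
        # labelled class y rather than toward it.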

        def cond_fn(x, t, y=None):
            if self.args.prompt == "":
                return torch.zeros_like(x)

            with torch.enable_grad():
                x = x.detach().requires_grad_()

                t_unscale = self.unscale_timestep(t)

                # out = self.diffusion.p_mean_variance(
                #     self.model, x, t, clip_denoised=False, model_kwargs={"y": y}
                # )
                out = self.diffusion.p_mean_variance(
                    self.model, x, t_unscale, clip_denoised=False, model_kwargs={"y": None}
                )

                fac = self.diffusion.sqrt_one_minus_alphas_cumprod[t_unscale[0].item()]

                # Guidance losses are computed on the model's current estimate of x_0.
                x_in = out["pred_xstart"]

                loss = torch.tensor(0)
                if self.args.classifier_scale != 0 and y is not None:
                    gradient_class_guided = class_guided(x_in, y, t)

                # High-frequency loss on the predicted image: maximized (subtracted)
                # when a "hard" complex background is requested, minimized otherwise.
                if self.args.background_complex != 0:
                    if self.args.hard:
                        loss = loss - self.args.background_complex * self.hf_loss((x_in + 1.) / 2.)
                    else:
                        loss = loss + self.args.background_complex * self.hf_loss((x_in + 1.) / 2.)

                if self.args.clip_guidance_lambda != 0:
                    clip_loss = self.clip_loss(x_in, text_embed) * self.args.clip_guidance_lambda
                    loss = loss + clip_loss
                    self.metrics_accumulator.update_metric("clip_loss", clip_loss.item())

                if self.args.range_lambda != 0:
                    r_loss = range_loss(out["pred_xstart"]).sum() * self.args.range_lambda
                    loss = loss + r_loss
                    self.metrics_accumulator.update_metric("range_loss", r_loss.item())

                if self.args.background_preservation_loss:
                    # Blend the x_0 prediction with the noisy sample before comparing
                    # against the (masked) input image.
                    x_in = out["pred_xstart"] * fac + x * (1 - fac)
                    if self.mask is not None:
                        masked_background = x_in * self.mask
                    else:
                        masked_background = x_in

                    if self.args.lpips_sim_lambda:
                        # loss = (
                        #     loss
                        #     + self.lpips_model(masked_background, self.init_image).sum()
                        #     * self.args.lpips_sim_lambda
                        # )
                        loss = (
                            loss
                            + self.lpips_model(masked_background, self.init_image * self.mask).sum()
                            * self.args.lpips_sim_lambda
                        )
                    if self.args.l2_sim_lambda:
                        # loss = (
                        #     loss
                        #     + mse_loss(masked_background, self.init_image) * self.args.l2_sim_lambda
                        # )
                        loss = (
                            loss
                            + mse_loss(masked_background, self.init_image * self.mask) * self.args.l2_sim_lambda
                        )

                if self.args.classifier_scale != 0 and y is not None:
                    return -torch.autograd.grad(loss, x)[0] + gradient_class_guided
                else:
                    return -torch.autograd.grad(loss, x)[0]

        @torch.no_grad()
        def postprocess_fn(out, t):
            if self.args.coarse_to_fine:
                # Erode the mask at early (large-t) steps so only the center of the
                # edited region is free, then release the full mask later.
                if t[0] > 50:
                    kernel = 51
                elif t[0] > 35:
                    kernel = 31
                else:
                    kernel = 0
                if kernel > 0:
                    max_pool = torch.nn.MaxPool2d(kernel_size=kernel, stride=1, padding=(kernel - 1) // 2)
                    self.mask_d = 1 - self.mask
                    self.mask_d = max_pool(self.mask_d)
                    self.mask_d = 1 - self.mask_d
                else:
                    self.mask_d = self.mask
            else:
                self.mask_d = self.mask

            if self.mask is not None:
                # Replace the background with the (appropriately noised) second init
                # image so only the masked region is actually edited at this step.
                background_stage_t = self.diffusion.q_sample(self.init_image_2, t[0])
                background_stage_t = torch.tile(
                    background_stage_t, dims=(self.args.batch_size, 1, 1, 1)
                )
                out["sample"] = out["sample"] * self.mask_d + background_stage_t * (1 - self.mask_d)

            return out
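        # The blending in postprocess_fn keeps foreground and background at the
        # same noise level: q_sample(self.init_image_2, t) renoises the second
        # init image to step t, so only the masked region evolves freely while
        # the rest tracks the noised background, in the style of Blended
        # Diffusion-type editing.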

        save_image_interval = self.diffusion.num_timesteps // 5
        for iteration_number in range(self.args.iterations_num):
            print(f"Start iteration {iteration_number}")

            sample_func = (
                self.diffusion.ddim_sample_loop_progressive
                if self.args.ddim
                else self.diffusion.p_sample_loop_progressive
            )
            samples = sample_func(
                self.model_fn,
                (
                    self.args.batch_size,
                    3,
                    self.model_config["image_size"],
                    self.model_config["image_size"],
                ),
                clip_denoised=False,
                model_kwargs={}
                if self.args.classifier_scale == 0
                else {"y": self.args.y * torch.ones([self.args.batch_size], device=self.device, dtype=torch.long)},
                cond_fn=cond_fn,
                device=self.device,
                progress=True,
                skip_timesteps=self.args.skip_timesteps,
                init_image=self.init_image,
                postprocess_fn=None if self.args.local_clip_guided_diffusion else postprocess_fn,
                randomize_class=self.args.classifier_scale == 0,
            )

            intermediate_samples = [[] for i in range(self.args.batch_size)]
            total_steps = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
            for j, sample in enumerate(samples):
                should_save_image = j % save_image_interval == 0 or j == total_steps
                if should_save_image or self.args.save_video:
                    self.metrics_accumulator.print_average_metric()

                    for b in range(self.args.batch_size):
                        pred_image = sample["pred_xstart"][b]
                        visualization_path = Path(
                            os.path.join(self.args.output_path, self.args.output_file)
                        )
                        visualization_path = visualization_path.with_stem(
                            f"{visualization_path.stem}_i_{iteration_number}_b_{b}"
                        )
                        if (
                            self.mask is not None
                            and self.args.enforce_background
                            and j == total_steps
                            and not self.args.local_clip_guided_diffusion
                        ):
                            # Final step: paste the unedited background back in.
                            pred_image = (
                                self.init_image_2[0] * (1 - self.mask[0]) + pred_image * self.mask[0]
                            )
                        # if j == total_steps:
                        #     pdb.set_trace()
                        #     pred_image = (
                        #         self.init_image_2[0] * (1 - self.mask[0]) + pred_image * self.mask[0]
                        #     )
                        pred_image = pred_image.add(1).div(2).clamp(0, 1)
                        pred_image_pil = TF.to_pil_image(pred_image)
                        masked_pred_image = self.mask * pred_image.unsqueeze(0)
                        final_distance = self.unaugmented_clip_distance(
                            masked_pred_image, text_embed
                        )
                        formatted_distance = f"{final_distance:.4f}"

                        if self.args.export_assets:
                            pred_path = self.assets_path / visualization_path.name
                            pred_image_pil.save(pred_path, quality=100)

                        if j == total_steps:
                            # Rank final results by their CLIP distance, encoded in the filename.
                            path_friendly_distance = formatted_distance.replace(".", "")
                            ranked_pred_path = self.ranked_results_path / (
                                path_friendly_distance + "_" + visualization_path.name
                            )
                            pred_image_pil.save(ranked_pred_path, quality=100)

                        intermediate_samples[b].append(pred_image_pil)
                        if should_save_image:
                            show_editied_masked_image(
                                title=self.args.prompt,
                                source_image=self.init_image_pil,
                                edited_image=pred_image_pil,
                                mask=self.mask_pil,
                                path=visualization_path,
                                distance=formatted_distance,
                            )

            if self.args.save_video:
                for b in range(self.args.batch_size):
                    video_name = self.args.output_file.replace(
                        ".png", f"_i_{iteration_number}_b_{b}.avi"
                    )
                    video_path = os.path.join(self.args.output_path, video_name)
                    save_video(intermediate_samples[b], video_path)

        # Collect the original image and mask plus up to eight ranked results;
        # only the edited results are concatenated into the final strip.
        visualize_size = (256, 256)
        img_ori = cv2.imread(self.args.init_image_2)
        img_ori = cv2.resize(img_ori, visualize_size)
        mask = cv2.imread(self.args.mask)
        mask = cv2.resize(mask, visualize_size)
        imgs = [img_ori, mask]
        # sorted() keeps the best (lowest-distance) results first, since the
        # ranked filenames are prefixed with the CLIP distance.
        for ii, img_name in enumerate(sorted(os.listdir(os.path.join(self.args.output_path, 'ranked')))):
            img_path = os.path.join(self.args.output_path, 'ranked', img_name)
            img = cv2.imread(img_path)
            img = cv2.resize(img, visualize_size)
            imgs.append(img)
            if ii >= 7:
                break

        img_whole = cv2.hconcat(imgs[2:])
        # img_name = self.args.output_path.split('/')[-2] + '/'
        # if self.args.coarse_to_fine:
        #     if self.args.clip_guidance_lambda == 0:
        #         prompt = 'coarse_to_fine_no_clip'
        #     else:
        #         prompt = 'coarse_to_fine'
        # elif self.args.image_guide:
        #     prompt = 'image_guide'
        # elif self.args.clip_guidance_lambda == 0:
        #     prompt = 'no_clip_guide'
        # else:
        #     prompt = 'text_guide'

        cv2.imwrite(
            os.path.join(self.args.final_save_root, 'edited.png'),
            img_whole,
            [int(cv2.IMWRITE_PNG_COMPRESSION), 0],
        )

    def reconstruct_image(self):
        # Reconstruct the init image by renoising it and running the bare
        # sampler, without any guidance.
        image_size = (self.model_config["image_size"], self.model_config["image_size"])
        init = Image.open(self.args.init_image).convert("RGB")
        init = init.resize(
            image_size,
            Image.LANCZOS,
        )
        init = TF.to_tensor(init).to(self.device).unsqueeze(0).mul(2).sub(1)

        samples = self.diffusion.p_sample_loop_progressive(
            self.model,
            (1, 3, self.model_config["image_size"], self.model_config["image_size"]),
            clip_denoised=False,
            model_kwargs={}
            if self.args.model_output_size == 256
            else {"y": torch.zeros([self.args.batch_size], device=self.device, dtype=torch.long)},
            cond_fn=None,
            progress=True,
            skip_timesteps=self.args.skip_timesteps,
            init_image=init,
            randomize_class=True,
        )
        save_image_interval = self.diffusion.num_timesteps // 5
        max_iterations = self.diffusion.num_timesteps - self.args.skip_timesteps - 1

        for j, sample in enumerate(samples):
            if j % save_image_interval == 0 or j == max_iterations:
                print()
                filename = os.path.join(self.args.output_path, self.args.output_file)
                TF.to_pil_image(sample["pred_xstart"][0].add(1).div(2).clamp(0, 1)).save(filename)
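

# Illustrative usage (an assumption about the surrounding repo: the real
# argument parsing lives elsewhere, and `args` must provide every attribute
# referenced above, e.g. output_path, prompt, init_image, init_image_2, mask,
# seed, gpu_id, timestep_respacing, ...):
#
#     editor = ImageEditor(args)
#     editor.edit_image_by_prompt()   # masked, guided editing
#     # editor.reconstruct_image()    # plain reconstruction of args.init_image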