import os
import random
from pathlib import Path

import cv2
import lpips
import numpy as np
import torch
import torchvision
from PIL import Image
from torch.nn.functional import mse_loss
from torchvision import transforms
from torchvision.transforms import functional as TF

from CLIP import clip
from guided_diffusion.guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    create_classifier,
    classifier_defaults,
)
from optimization.augmentations import ImageAugmentations
from optimization.constants import ASSETS_DIR_NAME, RANKED_RESULTS_DIR
from optimization.losses import range_loss, d_clip_loss
from utils.change_place import change_place, find_bbox
from utils.fft_pytorch import HighFrequencyLoss
from utils.metrics_accumulator import MetricsAccumulator
from utils.video import save_video
from utils.visualization import show_tensor_image, show_editied_masked_image
def create_classifier_ours():
    """Alternative classifier: a torchvision ResNet-50 loaded from a DRA checkpoint.

    Kept for the commented-out swap in ImageEditor.__init__. The key renaming maps
    DataParallel ('module.') and 'last_linear' checkpoint keys onto the torchvision
    layout, and inputs are upsampled to the 256x256 resolution the weights expect.
    """
    model = torchvision.models.resnet50()
    ckpt = torch.load('checkpoints/DRA_resnet50.pth', map_location='cpu')['model_state_dict']
    model.load_state_dict(
        {k.replace('module.', '').replace('last_linear', 'fc'): v for k, v in ckpt.items()}
    )
    model = torch.nn.Sequential(torch.nn.Upsample(size=(256, 256)), model)
    return model
class ImageEditor:
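    """Mask-based image editor built on guided diffusion, combining CLIP,
    classifier, and high-frequency guidance with background-preservation losses."""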
def __init__(self, args) -> None:
self.args = args
os.makedirs(self.args.output_path, exist_ok=True)
self.ranked_results_path = Path(os.path.join(self.args.output_path, RANKED_RESULTS_DIR))
os.makedirs(self.ranked_results_path, exist_ok=True)
if self.args.export_assets:
self.assets_path = Path(os.path.join(self.args.output_path, ASSETS_DIR_NAME))
os.makedirs(self.assets_path, exist_ok=True)
if self.args.seed is not None:
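            # Seed torch, NumPy, and Python RNGs for reproducibility.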
torch.manual_seed(self.args.seed)
np.random.seed(self.args.seed)
random.seed(self.args.seed)
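        # Hyperparameters matching the released guided-diffusion checkpoints.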
self.model_config = model_and_diffusion_defaults()
self.model_config.update(
{
"attention_resolutions": "32, 16, 8",
"class_cond": self.args.model_output_size == 512,
"diffusion_steps": 1000,
"rescale_timesteps": True,
"timestep_respacing": self.args.timestep_respacing,
"image_size": self.args.model_output_size,
"learn_sigma": True,
"noise_schedule": "linear",
"num_channels": 256,
"num_head_channels": 64,
"num_res_blocks": 2,
"resblock_updown": True,
"use_fp16": True,
"use_scale_shift_norm": True,
}
)
self.classifier_config = classifier_defaults()
self.classifier_config.update(
{
"image_size": self.args.model_output_size,
}
)
# Load models
self.device = torch.device(
f"cuda:{self.args.gpu_id}" if torch.cuda.is_available() else "cpu"
)
print("Using device:", self.device)
self.model, self.diffusion = create_model_and_diffusion(**self.model_config)
self.model.load_state_dict(
torch.load(
"checkpoints/256x256_diffusion_uncond.pt"
if self.args.model_output_size == 256
else "checkpoints/512x512_diffusion.pt",
map_location="cpu",
)
)
self.model.eval().to(self.device)
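        # Keep the model in eval mode, but enable gradients only on the attention
        # (qkv/proj) and normalization parameters.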
for name, param in self.model.named_parameters():
if "qkv" in name or "norm" in name or "proj" in name:
param.requires_grad_()
if self.model_config["use_fp16"]:
self.model.convert_to_fp16()
self.classifier = create_classifier(**self.classifier_config)
self.classifier.load_state_dict(
torch.load("checkpoints/256x256_classifier.pt", map_location="cpu")
)
        # Optionally swap in the DRA ResNet-50 classifier instead:
        # self.classifier = create_classifier_ours()
self.classifier.eval().to(self.device)
if self.classifier_config["classifier_use_fp16"]:
self.classifier.convert_to_fp16()
self.clip_model = (
clip.load("ViT-B/16", device=self.device, jit=False)[0].eval().requires_grad_(False)
)
self.clip_size = self.clip_model.visual.input_resolution
self.clip_normalize = transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
)
self.to_tensor = transforms.ToTensor()
self.lpips_model = lpips.LPIPS(net="vgg").to(self.device)
self.image_augmentations = ImageAugmentations(self.clip_size, self.args.aug_num)
self.metrics_accumulator = MetricsAccumulator()
self.hf_loss = HighFrequencyLoss()
def unscale_timestep(self, t):
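        """Map a rescaled timestep (0..1000 scale) back to an index in the respaced schedule."""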
unscaled_timestep = (t * (self.diffusion.num_timesteps / 1000)).long()
return unscaled_timestep
def clip_loss(self, x_in, text_embed):
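        """Average CLIP-space distance between augmented crops of x_in (masked if a
        mask is set) and the target embedding, summed over the batch."""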
clip_loss = torch.tensor(0)
if self.mask is not None:
masked_input = x_in * self.mask
else:
masked_input = x_in
augmented_input = self.image_augmentations(masked_input).add(1).div(2) # shape: [N,C,H,W], range: [0,1]
clip_in = self.clip_normalize(augmented_input)
image_embeds = self.clip_model.encode_image(clip_in).float()
dists = d_clip_loss(image_embeds, text_embed)
# We want to sum over the averages
for i in range(self.args.batch_size):
# We want to average at the "augmentations level"
clip_loss = clip_loss + dists[i :: self.args.batch_size].mean()
return clip_loss
def unaugmented_clip_distance(self, x, text_embed):
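        """CLIP distance for a single image without augmentations (used for ranking outputs)."""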
        x = TF.resize(x, [self.clip_size, self.clip_size])
image_embeds = self.clip_model.encode_image(x).float()
dists = d_clip_loss(image_embeds, text_embed)
return dists.item()
    def model_fn(self, x, t, y=None):
        # Forward the class label only when running the class-conditional model.
        return self.model(x, t, y if self.args.class_cond else None)
def edit_image_by_prompt(self):
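        """Run guided-diffusion editing driven by a text prompt or a guide image."""
        # In image-guidance mode, args.prompt holds the path to a guide image whose
        # CLIP image embedding stands in for the text embedding.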
if self.args.image_guide:
img_guidance = Image.open(self.args.prompt).convert('RGB')
img_guidance = img_guidance.resize((224,224), Image.LANCZOS) # type: ignore
img_guidance = self.clip_normalize(self.to_tensor(img_guidance).unsqueeze(0)).to(self.device)
text_embed = self.clip_model.encode_image(img_guidance).float()
else:
text_embed = self.clip_model.encode_text(
clip.tokenize(self.args.prompt).to(self.device)
).float()
self.image_size = (self.model_config["image_size"], self.model_config["image_size"])
self.init_image_pil = Image.open(self.args.init_image).convert("RGB")
self.init_image_pil = self.init_image_pil.resize(self.image_size, Image.LANCZOS) # type: ignore
self.init_image = (
TF.to_tensor(self.init_image_pil).to(self.device).unsqueeze(0).mul(2).sub(1)
)
self.init_image_pil_2 = Image.open(self.args.init_image_2).convert("RGB")
if self.args.rotate_obj:
# angle = random.randint(-45,45)
angle = self.args.angle
self.init_image_pil_2 = self.init_image_pil_2.rotate(angle)
self.init_image_pil_2 = self.init_image_pil_2.resize(self.image_size, Image.LANCZOS) # type: ignore
self.init_image_2 = (
TF.to_tensor(self.init_image_pil_2).to(self.device).unsqueeze(0).mul(2).sub(1)
)
if self.args.export_assets:
img_path = self.assets_path / Path(self.args.output_file)
self.init_image_pil.save(img_path, quality=100)
self.mask = torch.ones_like(self.init_image, device=self.device)
self.mask_pil = None
if self.args.mask is not None:
self.mask_pil = Image.open(self.args.mask).convert("RGB")
if self.args.rotate_obj:
self.mask_pil = self.mask_pil.rotate(angle)
if self.mask_pil.size != self.image_size:
self.mask_pil = self.mask_pil.resize(self.image_size, Image.NEAREST) # type: ignore
if self.args.random_position:
bbox = find_bbox(np.array(self.mask_pil))
print(bbox)
image_mask_pil_binarized = ((np.array(self.mask_pil) > 0.5) * 255).astype(np.uint8)
# image_mask_pil_binarized = cv2.dilate(image_mask_pil_binarized, np.ones((50,50), np.uint8), iterations=1)
if self.args.invert_mask:
image_mask_pil_binarized = 255 - image_mask_pil_binarized
self.mask_pil = TF.to_pil_image(image_mask_pil_binarized)
self.mask = TF.to_tensor(Image.fromarray(image_mask_pil_binarized))
self.mask = self.mask[0, ...].unsqueeze(0).unsqueeze(0).to(self.device)
if self.args.random_position:
self.init_image_2, self.mask = change_place(self.init_image_2, self.mask, bbox, self.args.invert_mask)
if self.args.export_assets:
mask_path = self.assets_path / Path(
self.args.output_file.replace(".png", "_mask.png")
)
self.mask_pil.save(mask_path, quality=100)
        def class_guided(x, y, t):
            # Gradient of the class log-probability w.r.t. x, scaled by
            # classifier_scale. The guided_diffusion classifier is
            # timestep-conditioned, so t is passed alongside x (the plain ResNet
            # from create_classifier_ours would take x alone).
            assert y is not None
            with torch.enable_grad():
                x_in = x.detach().requires_grad_(True)
                logits = self.classifier(x_in, t)
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                selected = log_probs[range(len(logits)), y.view(-1)]
                loss = selected.sum()
                return -torch.autograd.grad(loss, x_in)[0] * self.args.classifier_scale
def cond_fn(x, t, y=None):
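            # Returns the guidance gradient: the negative gradient of the summed
            # losses w.r.t. the noisy sample x, plus the classifier-guidance term
            # when enabled.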
if self.args.prompt == "":
return torch.zeros_like(x)
with torch.enable_grad():
x = x.detach().requires_grad_()
t_unscale = self.unscale_timestep(t)
out = self.diffusion.p_mean_variance(
self.model, x, t_unscale, clip_denoised=False, model_kwargs={"y": None}
)
                fac = self.diffusion.sqrt_one_minus_alphas_cumprod[t_unscale[0].item()]
                # Revised 2022.07.14: guide on the clean prediction directly; the
                # fac-blended image is still used for background preservation below.
                x_in = out["pred_xstart"]
loss = torch.tensor(0)
if self.args.classifier_scale != 0 and y is not None:
gradient_class_guided = class_guided(x_in, y, t)
                    if self.args.background_complex != 0:
                        # Descent on `loss` maximizes the high-frequency energy of
                        # the prediction when args.hard is set, and minimizes it
                        # otherwise.
                        if self.args.hard:
                            loss = loss - self.args.background_complex * self.hf_loss((x_in + 1.) / 2.)
                        else:
                            loss = loss + self.args.background_complex * self.hf_loss((x_in + 1.) / 2.)
if self.args.clip_guidance_lambda != 0:
clip_loss = self.clip_loss(x_in, text_embed) * self.args.clip_guidance_lambda
loss = loss + clip_loss
self.metrics_accumulator.update_metric("clip_loss", clip_loss.item())
if self.args.range_lambda != 0:
r_loss = range_loss(out["pred_xstart"]).sum() * self.args.range_lambda
loss = loss + r_loss
self.metrics_accumulator.update_metric("range_loss", r_loss.item())
if self.args.background_preservation_loss:
x_in = out["pred_xstart"] * fac + x * (1 - fac)
if self.mask is not None:
# masked_background = x_in * (1 - self.mask)
masked_background = x_in * self.mask # 2022.07.19
else:
masked_background = x_in
if self.args.lpips_sim_lambda:
                            # Perceptual similarity to the masked source (2022.07.19).
                            loss = (
                                loss
                                + self.lpips_model(masked_background, self.init_image * self.mask).sum()
                                * self.args.lpips_sim_lambda
                            )
if self.args.l2_sim_lambda:
                            # Pixel-space similarity to the masked source (2022.07.19).
                            loss = (
                                loss
                                + mse_loss(masked_background, self.init_image * self.mask)
                                * self.args.l2_sim_lambda
                            )
if self.args.classifier_scale != 0 and y is not None:
return -torch.autograd.grad(loss, x)[0] + gradient_class_guided
else:
return -torch.autograd.grad(loss, x)[0]
@torch.no_grad()
def postprocess_fn(out, t):
if self.args.coarse_to_fine:
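                # Coarse-to-fine blending: at early steps (large t), erode the
                # editable mask by max-pooling its complement, enforcing the
                # background over a wider band; the band narrows as t decreases.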
if t > 50:
kernel = 51
elif t > 35:
kernel = 31
else:
kernel = 0
                    if kernel > 0:
                        max_pool = torch.nn.MaxPool2d(kernel_size=kernel, stride=1, padding=(kernel - 1) // 2)
                        # Max-pooling 1 - mask dilates the background region, i.e.
                        # erodes the region taken from the sample.
                        self.mask_d = 1 - max_pool(1 - self.mask)
else:
self.mask_d = self.mask
else:
self.mask_d = self.mask
if self.mask is not None:
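                # Re-inject the background image, noised to the current timestep so
                # it matches the sample statistics; only the masked region is edited.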
background_stage_t = self.diffusion.q_sample(self.init_image_2, t[0])
background_stage_t = torch.tile(
background_stage_t, dims=(self.args.batch_size, 1, 1, 1)
)
out["sample"] = out["sample"] * self.mask_d + background_stage_t * (1 - self.mask_d)
return out
save_image_interval = self.diffusion.num_timesteps // 5
for iteration_number in range(self.args.iterations_num):
print(f"Start iterations {iteration_number}")
sample_func = (
self.diffusion.ddim_sample_loop_progressive
if self.args.ddim
else self.diffusion.p_sample_loop_progressive
)
samples = sample_func(
self.model_fn,
(
self.args.batch_size,
3,
self.model_config["image_size"],
self.model_config["image_size"],
),
clip_denoised=False,
                model_kwargs={}
                if self.args.classifier_scale == 0
                else {
                    "y": self.args.y
                    * torch.ones([self.args.batch_size], device=self.device, dtype=torch.long)
                },
cond_fn=cond_fn,
device=self.device,
progress=True,
skip_timesteps=self.args.skip_timesteps,
init_image=self.init_image,
postprocess_fn=None if self.args.local_clip_guided_diffusion else postprocess_fn,
                randomize_class=self.args.classifier_scale == 0,
)
intermediate_samples = [[] for i in range(self.args.batch_size)]
total_steps = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
for j, sample in enumerate(samples):
should_save_image = j % save_image_interval == 0 or j == total_steps
if should_save_image or self.args.save_video:
self.metrics_accumulator.print_average_metric()
for b in range(self.args.batch_size):
pred_image = sample["pred_xstart"][b]
visualization_path = Path(
os.path.join(self.args.output_path, self.args.output_file)
)
visualization_path = visualization_path.with_stem(
f"{visualization_path.stem}_i_{iteration_number}_b_{b}"
)
if (
self.mask is not None
and self.args.enforce_background
and j == total_steps
and not self.args.local_clip_guided_diffusion
):
pred_image = (
self.init_image_2[0] * (1 - self.mask[0]) + pred_image * self.mask[0]
)
pred_image = pred_image.add(1).div(2).clamp(0, 1)
pred_image_pil = TF.to_pil_image(pred_image)
masked_pred_image = self.mask * pred_image.unsqueeze(0)
final_distance = self.unaugmented_clip_distance(
masked_pred_image, text_embed
)
formatted_distance = f"{final_distance:.4f}"
if self.args.export_assets:
pred_path = self.assets_path / visualization_path.name
pred_image_pil.save(pred_path, quality=100)
if j == total_steps:
path_friendly_distance = formatted_distance.replace(".", "")
ranked_pred_path = self.ranked_results_path / (
path_friendly_distance + "_" + visualization_path.name
)
pred_image_pil.save(ranked_pred_path, quality=100)
intermediate_samples[b].append(pred_image_pil)
if should_save_image:
show_editied_masked_image(
title=self.args.prompt,
source_image=self.init_image_pil,
edited_image=pred_image_pil,
mask=self.mask_pil,
path=visualization_path,
distance=formatted_distance,
)
if self.args.save_video:
for b in range(self.args.batch_size):
video_name = self.args.output_file.replace(
".png", f"_i_{iteration_number}_b_{b}.avi"
)
video_path = os.path.join(self.args.output_path, video_name)
save_video(intermediate_samples[b], video_path)
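        # After all iterations, assemble a horizontal montage of up to eight ranked
        # results; the source image and mask are loaded for reference but excluded
        # from the concatenation (imgs[2:]).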
visualize_size = (256,256)
img_ori = cv2.imread(self.args.init_image_2)
img_ori = cv2.resize(img_ori, visualize_size)
mask = cv2.imread(self.args.mask)
mask = cv2.resize(mask, visualize_size)
imgs = [img_ori, mask]
for ii, img_name in enumerate(os.listdir(os.path.join(self.args.output_path, 'ranked'))):
img_path = os.path.join(self.args.output_path, 'ranked', img_name)
img = cv2.imread(img_path)
img = cv2.resize(img, visualize_size)
imgs.append(img)
if ii >= 7:
break
img_whole = cv2.hconcat(imgs[2:])
        cv2.imwrite(
            os.path.join(self.args.final_save_root, 'edited.png'),
            img_whole,
            [int(cv2.IMWRITE_PNG_COMPRESSION), 0],
        )
def reconstruct_image(self):
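        """Reconstruct args.init_image by resampling from skip_timesteps, without guidance.

        Note: relies on self.image_size, which is set in edit_image_by_prompt.
        """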
init = Image.open(self.args.init_image).convert("RGB")
init = init.resize(
self.image_size, # type: ignore
Image.LANCZOS,
)
init = TF.to_tensor(init).to(self.device).unsqueeze(0).mul(2).sub(1)
samples = self.diffusion.p_sample_loop_progressive(
self.model,
(1, 3, self.model_config["image_size"], self.model_config["image_size"],),
clip_denoised=False,
            model_kwargs={}
            if self.args.model_output_size == 256
            # The sampled batch here is 1, so the class-label tensor must have one entry.
            else {"y": torch.zeros([1], device=self.device, dtype=torch.long)},
cond_fn=None,
progress=True,
skip_timesteps=self.args.skip_timesteps,
init_image=init,
randomize_class=True,
)
save_image_interval = self.diffusion.num_timesteps // 5
max_iterations = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
for j, sample in enumerate(samples):
if j % save_image_interval == 0 or j == max_iterations:
print()
filename = os.path.join(self.args.output_path, self.args.output_file)
TF.to_pil_image(sample["pred_xstart"][0].add(1).div(2).clamp(0, 1)).save(filename)