Spaces:
Sleeping
Sleeping
import math | |
import numpy as np | |
from omegaconf import OmegaConf | |
from pathlib import Path | |
import cv2 | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.cuda.amp import custom_bwd, custom_fwd | |
from torchvision.utils import save_image | |
from torchvision.ops import masks_to_boxes | |
from torchvision.transforms import Resize | |
from diffusers import DDIMScheduler, DDPMScheduler | |
from einops import rearrange, repeat | |
from tqdm import tqdm | |
import sys | |
from os import path | |
sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) | |
sys.path.append("./models/") | |
from loguru import logger | |
from ldm.util import instantiate_from_config | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.modules.diffusionmodules.util import extract_into_tensor | |
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=True):
    """Instantiate the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config; ``config.model`` is passed to ``instantiate_from_config``.
        ckpt: Path to a checkpoint whose loaded dict has a ``'state_dict'`` entry and
            optionally a ``'global_step'`` entry.
        device: Device the model is moved to.
        vram_O: If True, delete ``model.first_stage_model.decoder`` and empty the CUDA
            cache to save memory. NOTE(review): ``MateralDiffusion.decode_latents`` calls
            ``decode_first_stage``, which presumably needs the decoder — confirm before
            enabling this together with decoding.
        verbose: If True, log the checkpoint's global step when it is present.

    Returns:
        The model in eval mode on ``device``. The first-stage model's parameters have
        ``requires_grad`` set to True.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location='cpu')
    if 'global_step' in pl_sd and verbose:
        logger.info(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    # loguru formats the message with str.format(*args). Without a "{}" placeholder,
    # the key list passed as an extra argument would be silently dropped.
    if missing:
        logger.warning('missing keys: \n{}', missing)
    if unexpected:
        logger.warning('unexpected keys: \n{}', unexpected)
    # Copy the EMA weights into the live model, then drop the EMA copy to save GPU memory.
    if model.use_ema:
        logger.debug('loading EMA...')
        model.model_ema.copy_to(model.model)
        del model.model_ema
    if vram_O:
        # The decoder is deleted to save memory (see the vram_O note above).
        del model.first_stage_model.decoder
    torch.cuda.empty_cache()
    model.eval().to(device)
    # Gradients must flow through the autoencoder (e.g. DPS guidance differentiates
    # through the decoder), so its parameters are left trainable.
    for param in model.first_stage_model.parameters():
        param.requires_grad = True
    return model
class MateralDiffusion(nn.Module):
    """Image-conditioned latent-diffusion sampler with optional DPS guidance.

    Wraps a model loaded by ``load_model_from_config`` and a diffusers DDIM or DDPM
    scheduler built from the model config's ``timesteps``, ``linear_start`` and
    ``linear_end``. Calling the instance samples an image conditioned on
    ``embeddings["cond_img"]``. When ``dps_scale > 0``, sampling is additionally
    steered towards ``pred_rgb`` through ``operator``.
    """

    def __init__(self, device, fp16,
                 config=None,
                 ckpt=None, vram_O=False, t_range=(0.02, 0.98), opt=None, use_ddim=True):
        """Load the model and build the noise scheduler.

        Args:
            device: Device for the model and derived tensors.
            fp16: Stored as ``self.fp16``; not otherwise used here.
            config: Path to an OmegaConf YAML file describing the model.
            ckpt: Checkpoint path passed to ``load_model_from_config``.
            vram_O: Passed to ``load_model_from_config``.
            t_range: ``(low, high)`` fractions of the training timesteps, used to
                derive ``self.min_step`` and ``self.max_step``.
            opt: Stored as ``self.opt``.
            use_ddim: Use ``DDIMScheduler`` if True, otherwise ``DDPMScheduler``.
        """
        super().__init__()
        self.device = device
        self.fp16 = fp16
        self.vram_O = vram_O
        # Copy so the instance never aliases the caller's sequence (previously a
        # shared mutable default list).
        self.t_range = list(t_range)
        self.opt = opt
        self.config = OmegaConf.load(config)
        # TODO: seems it cannot load into fp16...
        self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O, verbose=True)
        # Timesteps: diffusers schedulers are used for convenience, with betas taken
        # from the ldm config.
        params = self.config.model.params
        self.num_train_timesteps = params.timesteps
        self.use_ddim = use_ddim
        if self.use_ddim:
            self.scheduler = DDIMScheduler(
                self.num_train_timesteps,
                params.linear_start,
                params.linear_end,
                beta_schedule='scaled_linear',
                clip_sample=False,
                set_alpha_to_one=False,
                steps_offset=1,
            )
            print("Using DDIM...")
        else:
            self.scheduler = DDPMScheduler(
                self.num_train_timesteps,
                params.linear_start,
                params.linear_end,
                beta_schedule='scaled_linear',
                clip_sample=False,
            )
            print("Using DDPM...")
        self.min_step = int(self.num_train_timesteps * self.t_range[0])
        self.max_step = int(self.num_train_timesteps * self.t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

    def get_input(self, x):
        """Convert a channels-last batch into a contiguous float ``(B, C, H, W)`` tensor.

        A 3-D input ``(B, H, W)`` first gets a trailing singleton channel.
        """
        if x.dim() == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        return x.to(memory_format=torch.contiguous_format).float()

    def center_crop(self, img, mask, return_uv=False, mask_ratio=.8, image_size=256):
        """Crop the mask's bounding box and paste it, resized, into a square canvas.

        Pixels where ``mask[..., -1] <= 0.5`` are first filled with ones. The box is
        resized to ``image_size - 2 * margin`` per side, where
        ``margin = round((1 - mask_ratio) * image_size)``. It is then centred on a
        ones-filled ``image_size`` x ``image_size`` canvas.

        Args:
            img: Channels-last tensor ``(B, H, W, C)``.
            mask: Channels-last tensor ``(B, H, W, K)``; only the last channel is used.
                Each sample needs at least one pixel above 0.5 for ``masks_to_boxes``.
            return_uv: Also return the crop parameters consumed by ``restore_crop``.
            mask_ratio: Fraction of the canvas side occupied by the crop.
            image_size: Side length of the output canvas.

        Returns:
            ``(B, image_size, image_size, C)`` tensor. If ``return_uv`` is True, also a
            float ``(B, 7)`` tensor ``[H, W, min_row, min_col, max_row, max_col, margin]``
            whose max bounds are exclusive.
        """
        margin = int(np.round((1 - mask_ratio) * image_size))
        inner_size = int(image_size - 2 * margin)
        resizer = Resize([inner_size, inner_size])
        # masks_to_boxes yields (x1, y1, x2, y2). Reorder to (row, col) and make the
        # max bound exclusive.
        boxes = masks_to_boxes(mask[..., -1] > 0.5)
        min_uv = boxes[..., [1, 0]].long()
        max_uv = (boxes[..., [3, 2]] + 1).long()
        # Fill the background with ones.
        img = (img + (mask[..., -1:] <= 0.5)).clamp(0, 1)
        img = rearrange(img, 'b h w c -> b c h w')
        batch_size, channels = img.shape[0], img.shape[1]
        ori_size = torch.tensor(img.shape[-2:]).to(boxes.device).reshape(1, 2).expand(batch_size, -1)
        cropped_imgs = []
        for b in range(batch_size):
            img_crop = img[b][:, min_uv[b, 0]:max_uv[b, 0], min_uv[b, 1]:max_uv[b, 1]]
            img_out = torch.ones(channels, image_size, image_size, device=img.device)
            img_out[:, margin:image_size - margin, margin:image_size - margin] = resizer(img_crop)
            cropped_imgs.append(img_out)
        img_new = rearrange(torch.stack(cropped_imgs, dim=0), 'b c h w -> b h w c')
        crop_uv = torch.stack([ori_size[:, 0], ori_size[:, 1], min_uv[:, 0], min_uv[:, 1],
                               max_uv[:, 0], max_uv[:, 1], torch.full_like(max_uv[:, 1], margin)],
                              dim=-1).float()
        if return_uv:
            return img_new, crop_uv
        return img_new

    def center_crop_aspect_ratio(self, img, mask, return_uv=False, mask_ratio=.8, image_size=256):
        """Square-crop around the mask centre using ``grid_sample`` (aspect ratio kept).

        The square has side ``2 * half_size + 1``, where ``half_size`` is the larger
        half-extent of the mask's bounding box. It is scaled so that it covers
        ``mask_ratio`` of the ``image_size`` output.

        Args:
            img: Channels-last tensor ``(B, H, W, C)``.
            mask: Channels-last tensor ``(B, H, W, K)``; only the last channel is used.
            return_uv: Also return the parameters used by ``restore_crop_aspect_ratio``.
            mask_ratio: Fraction of the output covered by the square crop.
            image_size: Output side length.

        Returns:
            ``(B, image_size, image_size, C)`` tensor. If ``return_uv`` is True, also a
            float ``(B, 6)`` tensor ``[H, W, mask_ratio, half_size, center_row, center_col]``.
        """
        # Zero the mask on the image border so the background fill below makes the
        # border pixels ones. grid_sample only supports 'border'/'zeros' padding, so
        # 'border' padding then replicates those ones outside the image.
        border_mask = torch.zeros_like(mask)
        border_mask[:, 1:-1, 1:-1] = 1
        mask = mask * border_mask
        min_max_uv = masks_to_boxes(mask[..., -1] > 0.5)
        min_uv, max_uv = min_max_uv[..., [1, 0]], min_max_uv[..., [3, 2]]
        # Fill the background with ones.
        img = (img + (mask[..., -1:] <= 0.5)).clamp(0, 1)
        img = rearrange(img, 'b h w c -> b c h w')
        ori_size = torch.tensor(img.shape[-2:]).to(min_max_uv.device).reshape(1, 2).expand(img.shape[0], -1)
        crop_length = torch.div((max_uv - min_uv), 2, rounding_mode='floor')
        half_size = torch.max(crop_length, dim=-1, keepdim=True)[0]
        center_uv = min_uv + crop_length
        # Build the sampling grid over the output image.
        target_size = image_size
        grid_x, grid_y = torch.meshgrid(torch.arange(0, target_size, 1, device=min_max_uv.device),
                                        torch.arange(0, target_size, 1, device=min_max_uv.device),
                                        indexing='ij')
        normalized_xy = torch.stack([grid_x / (target_size - 1), grid_y / (target_size - 1)], dim=-1)  # [0,1]
        # Shrink the crop to occupy mask_ratio of the output, around the centre.
        normalized_xy = (normalized_xy - 0.5) / mask_ratio + 0.5
        normalized_xy = normalized_xy[None].expand(img.shape[0], -1, -1, -1)
        # Map crop-normalized coordinates into source-image-normalized coordinates.
        ori_crop_size = 2 * half_size + 1
        xy_scale = (ori_crop_size - 1) / (ori_size - 1)
        normalized_xy = normalized_xy * xy_scale.reshape(-1, 1, 1, 2)
        xy_shift = (center_uv - half_size) / (ori_size - 1)
        normalized_xy = normalized_xy + xy_shift.reshape(-1, 1, 1, 2)
        normalized_xy = normalized_xy * 2 - 1  # [-1,1]
        # grid_sample expects (x=col, y=row), hence the [1, 0] swap.
        img_new = F.grid_sample(img, normalized_xy[..., [1, 0]], padding_mode='border', align_corners=True)
        crop_uv = torch.stack([ori_size[:, 0], ori_size[:, 1], half_size[..., 0] * 0.0 + mask_ratio,
                               half_size[..., 0], center_uv[:, 0], center_uv[:, 1]], dim=-1).float()
        img_new = rearrange(img_new, 'b c h w -> b h w c')
        if return_uv:
            return img_new, crop_uv
        return img_new

    def restore_crop(self, img, img_ori, crop_idx):
        """Invert ``center_crop``: paste each canvas's inner region back at its box.

        Args:
            img: Channels-first tensor ``(B, C, S, S)`` in the ``center_crop`` canvas layout.
            img_ori: Unused; kept for interface compatibility.
            crop_idx: ``(B, 7)`` tensor as returned by ``center_crop(..., return_uv=True)``.
                The margin is read from the first row only.

        Returns:
            Channels-last tensor ``(B, H, W, C)`` of the original size. Pixels outside
            the box are ones.
        """
        ori_size = crop_idx[:, :2].long()
        min_uv = crop_idx[:, 2:4].long()
        max_uv = crop_idx[:, 4:6].long()
        margin = crop_idx[0, 6].long().item()
        channels = img.shape[1]
        all_images = []
        for b in range(img.shape[0]):
            height, width = ori_size[b].tolist()
            crop_h, crop_w = (max_uv[b] - min_uv[b]).tolist()
            net_size = img[b].shape[-1]
            inner = img[b][:, margin:net_size - margin, margin:net_size - margin]
            img_out = torch.ones(channels, height, width, device=img.device)
            img_out[:, min_uv[b, 0]:max_uv[b, 0], min_uv[b, 1]:max_uv[b, 1]] = Resize([crop_h, crop_w])(inner)
            all_images.append(img_out)
        return rearrange(torch.stack(all_images, dim=0), 'b c h w -> b h w c')

    def restore_crop_aspect_ratio(self, img, img_ori, crop_idx):
        """Invert ``center_crop_aspect_ratio`` by resampling back to the original size.

        Args:
            img: Channels-first tensor ``(B, C, S, S)``. It is not modified; the border
                is set to ones on a copy so that 'border' padding fills with ones.
            img_ori: Unused; kept for interface compatibility.
            crop_idx: ``(B, 6)`` tensor from ``center_crop_aspect_ratio(..., return_uv=True)``.
                The output size is read from the first row only.

        Returns:
            Channels-last tensor ``(B, H, W, C)``.
        """
        ori_size = crop_idx[:, :2].long()
        mask_ratio = crop_idx[:, 2:3]
        half_size = crop_idx[:, 3:4].long()
        center_uv = crop_idx[:, 4:].long()
        # Work on a copy: writing the border in place would clobber the caller's tensor.
        img = img.clone()
        img[:, :, 0, :] = 1
        img[:, :, -1, :] = 1
        img[:, :, :, 0] = 1
        img[:, :, :, -1] = 1
        ori_crop_size = 2 * half_size + 1
        grid_x, grid_y = torch.meshgrid(torch.arange(0, ori_size[0, 0].item(), 1, device=img.device),
                                        torch.arange(0, ori_size[0, 1].item(), 1, device=img.device),
                                        indexing='ij')
        # Source pixel -> position inside the square crop -> network-image coords in [-1, 1].
        normalized_xy = (torch.stack([grid_x, grid_y], dim=-1)[None].expand(img.shape[0], -1, -1, -1)
                         - (center_uv - half_size).reshape(-1, 1, 1, 2))
        normalized_xy = normalized_xy / (ori_crop_size - 1).reshape(-1, 1, 1, 1)
        normalized_xy = (2 * normalized_xy - 1) * mask_ratio.reshape(-1, 1, 1, 1)
        # grid_sample expects (x=col, y=row), hence the [1, 0] swap.
        img_out = F.grid_sample(img, normalized_xy[..., [1, 0]], padding_mode='border', align_corners=True)
        return rearrange(img_out, 'b c h w -> b h w c')

    def _image2diffusion(self, embeddings, pred_rgb, mask, image_size=256):
        """Crop the condition image and ``pred_rgb`` to the network input layout.

        Args:
            embeddings: Dict with a channels-last ``"cond_img"`` tensor ``(B, H, W, C)``.
            pred_rgb: Channels-last tensor ``(B, H, W, C)``; presumably in [0, 1].
            mask: Channels-last mask ``(B, H, W, K)``; the last channel selects the crop.
            image_size: Side length of the square network input.

        Returns:
            ``(pred_rgb_256, crop_idx_all, xc)``: the cropped ``pred_rgb`` as
            ``(B, C, S, S)``, the ``center_crop`` parameters for ``restore_crop``, and
            the cropped condition image as ``(B, C, S, S)``.
        """
        assert len(pred_rgb.shape) == 4, f"except 4 dim tensor, got: {pred_rgb.shape}"
        cond_img = self.center_crop(embeddings["cond_img"], mask, mask_ratio=1.0, image_size=image_size)
        pred_rgb_256, crop_idx_all = self.center_crop(pred_rgb, mask, return_uv=True,
                                                      mask_ratio=1.0, image_size=image_size)
        xc = self.get_input(cond_img)
        pred_rgb_256 = self.get_input(pred_rgb_256)
        return pred_rgb_256, crop_idx_all, xc

    def _get_condition(self, xc, with_uncondition=False):
        """Build the ``c_crossattn`` / ``c_concat`` conditioning dict for ``apply_model``.

        Args:
            xc: Condition image ``(B, C, H, W)``; presumably in [0, 1] (it is mapped to
                [-1, 1] here).
            with_uncondition: If True, prepend an all-zeros (unconditional) copy of each
                condition along the batch dimension, for classifier-free guidance.

        Returns:
            ``{'c_crossattn': [clip_emb], 'c_concat': [vae_latent]}``. Both entries are
            detached.
        """
        # Move once so both the CLIP and the VAE branch see a tensor on the model device.
        xc = (xc * 2 - 1).to(self.device)
        clip_input = xc if self.model.use_clip_embdding else [""]
        clip_emb = self.model.get_learned_conditioning(clip_input).detach()
        c_concat = self.model.encode_first_stage(xc).mode().detach()
        if with_uncondition:
            clip_emb = torch.cat([torch.zeros_like(clip_emb), clip_emb], dim=0)
            c_concat = torch.cat([torch.zeros_like(c_concat), c_concat], dim=0)
        return {'c_crossattn': [clip_emb], 'c_concat': [c_concat]}

    def _dps_guidance(self, latents, t, noise_pred, measurement, operator=None):
        """Compute one DPS guidance gradient.

        The gradient is ``d sum_b ||measurement_b - A(x0_hat_b)||_2 / d latents``, where
        ``x0_hat`` is decoded from ``model.predict_start_from_noise`` and ``A`` is
        ``operator.forward`` (identity when ``operator`` is None). ``latents`` is
        switched to ``requires_grad=True`` in place.

        Returns:
            ``(guidance_score, x_hat)``, where ``x_hat`` is ``A(x0_hat)`` detached.
        """
        with torch.enable_grad():
            t_batch = torch.full((latents.shape[0],), int(t), dtype=torch.long, device=self.device)
            x_hat_latents = self.model.predict_start_from_noise(latents.requires_grad_(True), t_batch, noise_pred)
            x_hat = self.decode_latents(x_hat_latents)
            if operator is not None:
                x_hat = operator.forward(x_hat)
            residual = (measurement - x_hat).reshape(measurement.shape[0], -1)
            norm = torch.linalg.norm(residual, dim=-1)
            # Single backward pass; the graph is not needed afterwards.
            guidance_score = torch.autograd.grad(norm.sum(), latents)[0]
        logger.debug(f"Guidance loss: {norm}")
        return guidance_score, x_hat.detach()

    def __call__(self, embeddings, pred_rgb, mask, guidance_scale=3, dps_scale=0.2, as_latent=False, grad_scale=1,
                 save_guidance_path: Path = None, ddim_steps=200, ddim_eta=1, operator=None):
        """Sample an image with classifier-free guidance and optional DPS guidance.

        Args:
            embeddings: Dict holding the channels-last condition image under ``"cond_img"``.
            pred_rgb: Channels-last ``(B, H, W, C)`` image. When ``dps_scale > 0`` it is
                the DPS measurement; it is also shown in the saved visualizations.
            mask: Channels-last mask; its last channel defines the crop region.
            guidance_scale: Classifier-free guidance weight.
            dps_scale: Weight of the DPS gradient added to the noise prediction; 0 disables DPS.
            as_latent, grad_scale: Unused; kept for interface compatibility.
            save_guidance_path: If given, visualizations are saved there and at
                ``str(path) + "all.jpg"`` / ``str(path) + "_out.jpg"``.
            ddim_steps: Number of inference steps when using DDIM.
            ddim_eta: DDIM ``eta``.
            operator: Object with ``.forward(x)`` applied to the decoded x0 estimate
                before comparing it to ``pred_rgb``. None means identity.

        Returns:
            Channels-last ``(B, H, W, 3)`` result, pasted back at the original resolution.
        """
        # TODO: the upscale factor is currently hard-coded.
        upscale = 1
        pred_rgb_256, crop_idx_all, xc = self._image2diffusion(embeddings, pred_rgb, mask, image_size=256 * upscale)
        cond = self._get_condition(xc, with_uncondition=True)
        assert pred_rgb_256.shape[-1] == pred_rgb_256.shape[-2], f"Expect image of square size, get {pred_rgb.shape}"
        # Encode only to obtain the latent shape/device/dtype for the initial noise.
        latents = torch.randn_like(self.encode_imgs(pred_rgb_256))
        if self.use_ddim:
            self.scheduler.set_timesteps(ddim_steps)
        else:
            self.scheduler.set_timesteps(self.num_train_timesteps)
        timesteps = self.scheduler.timesteps
        # Record about 20 intermediate visualizations; max() avoids a modulo by zero
        # when there are fewer than 20 steps.
        viz_every = max(1, len(timesteps) // 20)
        intermediates = []
        for i, t in enumerate(tqdm(timesteps)):
            # Classifier-free guidance: conditions are ordered [uncond, cond] (see _get_condition).
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1).expand(latents.shape[0])] * 2).to(self.device)
            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            if dps_scale > 0:
                guidance_score, x_hat = self._dps_guidance(latents, t, noise_pred, pred_rgb_256, operator)
                if save_guidance_path is not None and i % viz_every == 0:
                    x_t = self.decode_latents(latents.detach())
                    intermediates.append(
                        torch.cat([x_hat, x_t, pred_rgb_256, pred_rgb_256 - x_hat], dim=-2).detach().cpu())
                noise_pred = noise_pred + dps_scale * guidance_score
                # Free per-step tensors before the next UNet forward.
                del guidance_score, x_hat
            if self.use_ddim:
                latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']
            else:
                latents = self.scheduler.step(noise_pred.clone().detach(), t, latents)['prev_sample']
            del noise_pred
        imgs = self.decode_latents(latents)
        if save_guidance_path is not None:
            # Suffixed names are built from str() so that pathlib.Path inputs also work.
            prefix = str(save_guidance_path)
            save_image(torch.cat([pred_rgb_256, imgs], dim=-1)[:1], save_guidance_path)
            if intermediates:
                save_image(torch.cat(intermediates, dim=-1)[:1], prefix + "all.jpg")
        # Transform back to the original image layout.
        img_ori_size = self.restore_crop(imgs, pred_rgb, crop_idx_all)
        if save_guidance_path is not None:
            img_ori_size_save = rearrange(img_ori_size, 'b h w c -> b c h w')[:1]
            save_image(img_ori_size_save, prefix + "_out.jpg")
        return img_ori_size

    def decode_latents(self, latents):
        """Decode latents with the first-stage model and map the output from [-1, 1] to [0, 1].

        Presumably ``(B, 4, 32, 32)`` latents in and ``(B, 3, 256, 256)`` images out —
        TODO confirm against the autoencoder config.
        """
        imgs = self.model.decode_first_stage(latents)
        return (imgs / 2 + 0.5).clamp(0, 1)

    def encode_imgs(self, imgs):
        """Encode [0, 1] images into scaled first-stage latents.

        Inputs are mapped to [-1, 1] first. Presumably ``(B, 3, 256, 256)`` in and
        ``(B, 4, 32, 32)`` out — TODO confirm.
        """
        imgs = imgs * 2 - 1
        return self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs))