# VistaDream/ops/mcs.py
import torch
import numpy as np
import torchvision.transforms as tvtf
from tools.StableDiffusion.Hack_SD_stepwise import Hack_SDPipe_Stepwise
'''
Multi-view consistency sampling (MCS) helpers.

Input: multi-view images with added noise.
The wrapper denoises them toward the clean image x0, stepping the
denoising process from scheduler step t1 to step t2.
'''
class HackSD_MCS():
    '''
    Multi-view consistency sampling (MCS) wrapper around a step-wise hacked
    Stable Diffusion pipeline.

    Intended workflow (one step per method):
      1. encode multi-view images to self.latents and add noise
         (`_encode_mv_init_images`)
      2. predict the per-step noise and solve the clean image x0
         (`_sd_forward` / `_denoise_to_x0`)
      3. the caller warps x0 across views (RGB-D warp) to get a target image
      4. encode the target, derive its implied noise, blend ("rectify") it
         with the predicted noise, and advance one denoising step
         (`_step_denoise`)
      5. decode the final latents back to images (`_decode_mv_imgs`)
    '''
    def __init__(self, device='cpu', use_lcm=True, denoise_steps=20,
                 sd_ckpt='tools/StableDiffusion/ckpt',
                 lcm_ckpt='latent-consistency/lcm-lora-sdv1-5') -> None:
        '''
        Args:
            device: device the SD pipeline lives on ('cpu', 'cuda', ...).
            use_lcm: if True, load LCM-LoRA weights for few-step sampling.
            denoise_steps: number of scheduler timesteps used for adding
                noise / denoising, counted from the end of the schedule.
            sd_ckpt: path to the Stable Diffusion checkpoint directory.
            lcm_ckpt: hub id (or path) of the LCM-LoRA weights.

        NOTE(review): image tensors are expected as B*3*H*W;
        `_encode_mv_init_images` maps them from 0~1 to -1~1 internally.
        '''
        self.device = device
        self.target_type = np.float32
        self.use_lcm = use_lcm
        self.sd_ckpt = sd_ckpt
        self.lcm_ckpt = lcm_ckpt
        self._load_model()
        # how many scheduler steps to add noise over / denoise back
        self.denoise_steps = denoise_steps
        self.timesteps = self.model.timesteps

    def _load_model(self):
        '''Load the step-wise SD pipeline (optionally with LCM-LoRA) on self.device.'''
        self.model = Hack_SDPipe_Stepwise.from_pretrained(self.sd_ckpt)
        self.model._use_lcm(self.use_lcm, self.lcm_ckpt)
        self.model.re_init(num_inference_steps=50)
        try:
            # memory-efficient attention is an optional speedup; fall back
            # gracefully when xformers is not installed
            self.model.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # run without xformers
        self.model = self.model.to(self.device)

    def to(self, device):
        '''Move the wrapped pipeline to `device` and remember the device.'''
        self.device = device
        self.model.to(device)

    @torch.no_grad()
    def _add_noise_to_latent(self, latents):
        '''
        Add scheduler noise to `latents` at the timestep that lies
        `self.denoise_steps` steps before the end of the schedule.

        Args:
            latents: latent batch B*4*(H/8)*(W/8).
        Returns:
            (noisy_latent, timestep, target) where `target` is the
            supervision signal matching the scheduler's prediction type.
        Raises:
            ValueError: if the scheduler's prediction type is unsupported.
        '''
        bsz = latents.shape[0]
        # SD trains with 1000 timesteps; index from the tail so the remaining
        # denoising takes exactly self.denoise_steps inference steps
        timestep = self.timesteps[-self.denoise_steps]
        timestep = timestep.repeat(bsz).to(self.device)
        # target noise
        noise = torch.randn_like(latents)
        # add noise
        noisy_latent = self.model.scheduler.add_noise(latents, noise, timestep)
        # -------------------- noise for supervision -----------------
        prediction_type = self.model.scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = self.model.scheduler.get_velocity(latents, noise, timestep)
        else:
            # previously fell through and raised a confusing NameError on
            # `target`; fail loudly with a clear message instead
            raise ValueError(f"unsupported scheduler prediction type: {prediction_type}")
        return noisy_latent, timestep, target

    @torch.no_grad()
    def _encode_mv_init_images(self, images):
        '''
        Encode multi-view images to latents and add the initial noise;
        the noisy latents are stored in self.latents.

        Args:
            images: B*3*H*W tensor, values in 0~1.
        '''
        images = images * 2 - 1  # VAE expects inputs in -1~1
        self.latents = self.model._encode(images)
        self.latents, _, _ = self._add_noise_to_latent(self.latents)

    @torch.no_grad()
    def _sd_forward(self, denoise_step, prompt_latent: torch.Tensor):
        '''
        Predict the noise at a given denoise step and solve the clean image.

        Args:
            denoise_step: 0-based index into the last `self.denoise_steps`
                timesteps of the schedule.
            prompt_latent: prompt embedding, repeated over the batch.
        Returns:
            (t, noise_pred, x0) with x0 mapped to 0~1.
        '''
        # list-index keeps `t` 1-D instead of collapsing to a scalar tensor
        t = self.timesteps[[-self.denoise_steps + denoise_step]].to(self.device)
        noise_pred = self.model._step_noise(
            self.latents, t, prompt_latent.repeat(len(self.latents), 1, 1))
        # solve image
        _, x0 = self.model._solve_x0(self.latents, noise_pred, t)
        x0 = (x0 + 1) / 2  # in 0-1
        return t, noise_pred, x0

    @torch.no_grad()
    def _denoise_to_x0(self, timestep_in_1000, prompt_latent: torch.Tensor):
        '''
        Like `_sd_forward`, but takes the raw scheduler timestep directly.

        Returns:
            (noise_pred, x0) with x0 mapped to 0~1.
        '''
        noise_pred = self.model._step_noise(
            self.latents, timestep_in_1000, prompt_latent.repeat(len(self.latents), 1, 1))
        # solve image
        _, x0 = self.model._solve_x0(self.latents, noise_pred, timestep_in_1000)
        x0 = (x0 + 1) / 2  # in 0-1
        return noise_pred, x0

    @torch.no_grad()
    def _step_denoise(self, t, pred_noise, rect_x0, rect_w=0.7):
        '''
        Rectify the predicted noise toward the warped target image and
        advance one denoising step; updates self.latents in place.

        Args:
            t: current scheduler timestep.
            pred_noise: predicted noise, B*4*(H/8)*(W/8).
            rect_x0: rectified (warped) image in 0~1, B*3*H*W.
            rect_w: blend weight of the rectified noise (0 disables rectification).
        '''
        # encode rect_x0 to latent (VAE expects -1~1)
        rect_x0 = rect_x0 * 2 - 1
        rect_latent = self.model._encode(rect_x0)
        # noise that would map self.latents to the rectified latent at step t
        rect_noise = self.model._solve_noise_given_x0_latent(self.latents, rect_latent, t)
        # noise rectification: match the rectified noise's per-sample std to
        # the predicted noise's std before blending
        rect_noise = rect_noise / rect_noise.std(dim=list(range(1, rect_noise.ndim)), keepdim=True) \
                     * pred_noise.std(dim=list(range(1, pred_noise.ndim)), keepdim=True)
        pred_noise = pred_noise * (1. - rect_w) + rect_noise * rect_w
        # step forward
        self.latents = self.model._step_denoise(self.latents, pred_noise, t)

    @torch.no_grad()
    def _decode_mv_imgs(self):
        '''Decode self.latents back to images mapped to 0~1.'''
        imgs = self.model._decode(self.latents)
        imgs = (imgs + 1) / 2
        return imgs