import numpy as np
import torch
import torchvision.transforms as tvtf

from tools.StableDiffusion.Hack_SD_stepwise import Hack_SDPipe_Stepwise
'''
Multi-view stepwise denoising helpers.

Input: multi-view images with added noise.
Supports denoising all the way to x0, or denoising from step t1 to step t2.
'''
class HackSD_MCS:
    '''
    Stepwise multi-view denoiser built on a hacked Stable Diffusion pipeline.

    Intended loop (caller-driven):
        1. transform images to ``self.latents`` and add noise
           (:meth:`_encode_mv_init_images`)
        2. predict step noise --> solve x0 (:meth:`_sd_forward` /
           :meth:`_denoise_to_x0`)
        3. multi-view RGB-D warp produces a rectified target image (external)
        4. encode the target, recover its implied noise, blend it with the
           predicted noise (noise rectification), and step the scheduler
           forward (:meth:`_step_denoise`)
        5. decode the final latents back to images (:meth:`_decode_mv_imgs`)
    '''

    def __init__(self, device='cpu', use_lcm=True, denoise_steps=20,
                 sd_ckpt='tools/StableDiffusion/ckpt',
                 lcm_ckpt='latent-consistency/lcm-lora-sdv1-5') -> None:
        '''
        Args:
            device: torch device the pipeline is moved to.
            use_lcm: whether to enable the LCM-LoRA fast sampler.
            denoise_steps: number of scheduler steps used when adding noise
                and when denoising (indexes into ``self.timesteps`` from the
                end).
            sd_ckpt: local Stable Diffusion checkpoint path.
            lcm_ckpt: LCM-LoRA checkpoint id or path.

        Note:
            Reference RGB inputs elsewhere in this class are expected as
            -1~1 tensors of shape B*3*H*W.
        '''
        self.device = device
        self.target_type = np.float32
        self.use_lcm = use_lcm
        self.sd_ckpt = sd_ckpt
        self.lcm_ckpt = lcm_ckpt
        self._load_model()
        # define the step to add noise at and the number of steps to denoise
        self.denoise_steps = denoise_steps
        self.timesteps = self.model.timesteps

    def _load_model(self):
        '''Load the stepwise SD pipeline, optionally attach LCM-LoRA, and
        move it to ``self.device``.'''
        self.model = Hack_SDPipe_Stepwise.from_pretrained(self.sd_ckpt)
        self.model._use_lcm(self.use_lcm, self.lcm_ckpt)
        self.model.re_init(num_inference_steps=50)
        try:
            self.model.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is optional; fall back to the default attention impl
            pass
        self.model = self.model.to(self.device)

    def to(self, device):
        '''Move the wrapped pipeline (and future tensors) to *device*.'''
        self.device = device
        self.model.to(device)

    def _add_noise_to_latent(self, latents):
        '''
        Add scheduler noise to *latents* at the configured timestep.

        Returns:
            (noisy_latent, timestep, target) where *target* is the
            supervision signal matching the scheduler's prediction type.

        Raises:
            ValueError: if the scheduler's ``prediction_type`` is neither
                ``"epsilon"`` nor ``"v_prediction"`` (previously this fell
                through and crashed with UnboundLocalError).
        '''
        bsz = latents.shape[0]
        # Stable Diffusion uses a 1000-step schedule for adding noise and
        # denoising; pick the timestep `denoise_steps` from the end.
        timestep = self.timesteps[-self.denoise_steps]
        timestep = timestep.repeat(bsz).to(self.device)
        # target noise
        noise = torch.randn_like(latents)
        # add noise
        noisy_latent = self.model.scheduler.add_noise(latents, noise, timestep)
        # -------------------- noise for supervision -----------------
        if self.model.scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.model.scheduler.config.prediction_type == "v_prediction":
            target = self.model.scheduler.get_velocity(latents, noise, timestep)
        else:
            raise ValueError(
                f"Unsupported prediction_type: "
                f"{self.model.scheduler.config.prediction_type!r}")
        return noisy_latent, timestep, target

    def _encode_mv_init_images(self, images):
        '''
        Encode multi-view images to latents and add initial noise.

        Args:
            images: B3HW tensor in 0~1; mapped to -1~1 before encoding.
        '''
        images = images * 2 - 1
        self.latents = self.model._encode(images)
        self.latents, _, _ = self._add_noise_to_latent(self.latents)

    def _sd_forward(self, denoise_step, prompt_latent: torch.Tensor):
        '''
        Predict noise at the given denoise step and solve the implied x0.

        Args:
            denoise_step: 0-based index within the denoising run.
            prompt_latent: prompt embedding, broadcast over the batch.

        Returns:
            (t, noise_pred, x0) with x0 rescaled to 0~1.
        '''
        # temp noise prediction (double bracket keeps the timestep batched)
        t = self.timesteps[[-self.denoise_steps + denoise_step]].to(self.device)
        noise_pred = self.model._step_noise(
            self.latents, t, prompt_latent.repeat(len(self.latents), 1, 1))
        # solve image
        _, x0 = self.model._solve_x0(self.latents, noise_pred, t)
        x0 = (x0 + 1) / 2  # in 0-1
        return t, noise_pred, x0

    def _denoise_to_x0(self, timestep_in_1000, prompt_latent: torch.Tensor):
        '''
        Predict noise at an explicit (0-1000 scale) timestep and solve x0.

        Returns:
            (noise_pred, x0) with x0 rescaled to 0~1.
        '''
        # temp noise prediction
        noise_pred = self.model._step_noise(
            self.latents, timestep_in_1000,
            prompt_latent.repeat(len(self.latents), 1, 1))
        # solve image
        _, x0 = self.model._solve_x0(self.latents, noise_pred, timestep_in_1000)
        x0 = (x0 + 1) / 2  # in 0-1
        return noise_pred, x0

    def _step_denoise(self, t, pred_noise, rect_x0, rect_w=0.7):
        '''
        Noise rectification + one scheduler step, updating ``self.latents``.

        Args:
            t: current timestep tensor.
            pred_noise: predicted noise, B4(H//8)(W//8).
            rect_x0: rectified target image in 0~1, B3HW.
            rect_w: blend weight of the rectified noise (0 disables
                rectification, 1 uses only the rectified noise).
        '''
        # encode rect_x0 (mapped to -1~1) to latent space
        rect_x0 = rect_x0 * 2 - 1
        rect_latent = self.model._encode(rect_x0)
        # noise that would have produced rect_latent from the current latents
        rect_noise = self.model._solve_noise_given_x0_latent(
            self.latents, rect_latent, t)
        # noise rectification: rescale per-sample std to match the prediction
        rect_noise = rect_noise / rect_noise.std(
            dim=list(range(1, rect_noise.ndim)), keepdim=True) \
            * pred_noise.std(dim=list(range(1, pred_noise.ndim)), keepdim=True)
        pred_noise = pred_noise * (1. - rect_w) + rect_noise * rect_w
        # step forward
        self.latents = self.model._step_denoise(self.latents, pred_noise, t)

    def _decode_mv_imgs(self):
        '''Decode ``self.latents`` to images rescaled to 0~1.'''
        imgs = self.model._decode(self.latents)
        imgs = (imgs + 1) / 2
        return imgs