Spaces:
Paused
Paused
import torch | |
from tqdm.notebook import tqdm | |
from . import scheduler | |
from . import share | |
from src.utils.iimage import IImage | |
class DDIM:
    """Text-conditioned latent sampler built from a VAE, a prompt encoder and a UNet.

    Sampling applies the deterministic DDIM update (no noise is drawn inside
    the loop) with classifier-free guidance, then decodes the final ``z0``
    estimate through the VAE.

    NOTE(review): tensors are moved with ``.cuda()`` unconditionally, so a
    CUDA device is required.
    """

    def __init__(self, config, vae, encoder, unet):
        """Store the model components and build the 1000-step noise schedule.

        Args:
            config: Object exposing ``linear_start`` / ``linear_end`` (passed
                to ``scheduler.linear``) and ``scale_factor`` (latent scaling
                applied around VAE encode/decode).
            vae: Autoencoder providing ``encode``/``decode`` and an
                ``encoder.conv_in`` layer whose weight dtype is reused.
            encoder: Prompt encoder; ``encode(list_of_prompts)`` returns the
                context passed to the UNet.
            unet: Noise-prediction network called as
                ``unet(x, timesteps=..., context=...)``.
        """
        self.vae = vae
        self.encoder = encoder
        self.unet = unet
        self.config = config
        self.schedule = scheduler.linear(1000, config.linear_start, config.linear_end)

    def __call__(
        self, prompt='', dt=50, shape=(1, 4, 64, 64), seed=None, negative_prompt='',
        unet_condition=None, context=None, verbose=True, guidance_scale=7.5,
    ):
        """Sample a latent with DDIM + classifier-free guidance and decode it.

        Args:
            prompt: Positive text prompt (unused when ``context`` is given).
            dt: Stride between visited timesteps. Timesteps 999, 999 - dt, ...
                are visited down to (but excluding) 0. Must be positive.
            shape: Shape of the initial noise when ``unet_condition`` is None.
            seed: If not None, seeds torch's global RNG before drawing noise.
            negative_prompt: Prompt used for the unconditional guidance branch.
            unet_condition: Optional tensor concatenated to the latent along
                dim 1 at every step. When given, the noise is drawn with shape
                ``(1, 4) + unet_condition.shape[2:]`` and ``shape`` is ignored.
                Presumably the output of ``get_inpainting_condition`` — confirm.
            context: Precomputed ``encoder.encode([negative_prompt, prompt])``.
                When None, it is computed from the two prompts.
            verbose: Wrap the timestep range in a tqdm progress bar.
            guidance_scale: Classifier-free guidance weight. The default 7.5
                matches the value that used to be hard-coded.

        Returns:
            IImage wrapping the VAE decoding of the last ``z0`` estimate.

        Raises:
            ValueError: If ``dt`` is not positive. Previously ``dt == 0``
                failed inside ``range`` and ``dt < 0`` left ``z0`` unbound.
        """
        if dt <= 0:
            raise ValueError(f"dt must be a positive step size, got {dt}")
        if seed is not None:
            torch.manual_seed(seed)

        # Starting noise is drawn with the default (CPU) generator, then moved to CUDA.
        if unet_condition is not None:
            zT = torch.randn((1, 4) + unet_condition.shape[2:]).cuda()
        else:
            zT = torch.randn(shape).cuda()

        with torch.autocast('cuda'), torch.no_grad():
            if context is None:
                # Row 0 feeds the unconditional branch, row 1 the prompt branch.
                context = self.encoder.encode([negative_prompt, prompt])
            zt = zT
            timesteps = range(999, 0, -dt)
            pbar = tqdm(timesteps) if verbose else timesteps
            for timestep in share.DDIMIterator(pbar):
                _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
                # Evaluate both guidance branches in one batch of two; chunk(2)
                # splits the output back in the same order as ``context``.
                eps_uncond, eps = self.unet(
                    torch.cat([_zt, _zt]),
                    timesteps=torch.tensor([timestep, timestep]).cuda(),
                    context=context,
                ).chunk(2)
                eps = eps_uncond + guidance_scale * (eps - eps_uncond)

                # Estimate the clean latent from the current noisy latent.
                # NOTE(review): assumes schedule.sqrt_alphas holds sqrt of the
                # cumulative alpha product — confirm in scheduler.linear.
                z0 = (
                    zt - self.schedule.sqrt_one_minus_alphas[timestep] * eps
                ) / self.schedule.sqrt_alphas[timestep]

                # Re-noise z0 to the previous timestep. On the final step
                # timestep - dt can be negative. Indexing with it would wrap to
                # the end of the schedule and build a zt that is never used,
                # so that step is skipped.
                prev = timestep - dt
                if prev >= 0:
                    zt = (
                        self.schedule.sqrt_alphas[prev] * z0
                        + self.schedule.sqrt_one_minus_alphas[prev] * eps
                    )

        # dt > 0 guarantees the loop ran at least once (timestep 999), so z0 is bound.
        return IImage(self.vae.decode(z0 / self.config.scale_factor))

    def get_inpainting_condition(self, image, mask):
        """Build an inpainting condition tensor ``cat([mask_small, masked_latent], 1)``.

        Args:
            image: IImage to be inpainted.
            mask: IImage mask. Pixels where the mask is nonzero are zeroed out
                of ``image`` before VAE encoding (presumably the region to fill).

        Returns:
            Tensor holding the mask resized to 1/8 of the image size, followed
            along dim 1 by the scaled VAE latent (``.mean`` of the encoding) of
            the masked image. Both parts are cast to the VAE encoder's weight
            dtype.
        """
        # The latent grid is 8x smaller than the image in each spatial dimension.
        # NOTE(review): the size is reversed before ``resize`` below, presumably
        # because IImage.size is (W, H) while resize expects (H, W) — confirm.
        latent_size = [x // 8 for x in image.size]
        dtype = self.vae.encoder.conv_in.weight.dtype
        with torch.no_grad():
            # Keep only pixels where the mask is zero.
            masked_image = image.torch().cuda() * ~mask.torch(0).bool().cuda()
            masked_image = masked_image.to(dtype)
            condition_x0 = self.vae.encode(masked_image).mean * self.config.scale_factor
            condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().to(dtype)
        return torch.cat([condition_mask, condition_x0], 1)

    # Alias: both names refer to the same method.
    inpainting_condition = get_inpainting_condition