zxl
first commit
07c6a04
raw
history blame
10.5 kB
# Adapted from OpenSora
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------
import torch
import torch.distributed as dist
from einops import rearrange
from torch.distributions import LogisticNormal
from tqdm import tqdm
from videosys.core.pab_mgr import get_diffusion_skip, get_diffusion_skip_timestep, skip_diffusion_timestep
from videosys.diffusion.gaussian_diffusion import _extract_into_tensor
def mean_flat(tensor: torch.Tensor, mask=None):
    """
    Reduce ``tensor`` to one mean value per batch element.

    Without ``mask`` the mean is taken over every non-batch dimension.

    With ``mask``:
    - ``tensor`` must be 5-D, laid out as ``(b, c, t, h, w)``.
    - ``mask`` has shape ``(b, t)`` and selects positions along dimension 2.
    - Only positions where ``mask`` is non-zero contribute to the sum.
    - Each sample is divided by its own ``mask.sum() * c * h * w``.
    - NOTE: a sample whose mask is entirely zero evaluates 0 / 0 (NaN).

    Returns:
        Tensor with one entry per batch element.
    """
    if mask is None:
        non_batch_dims = list(range(1, tensor.dim()))
        return tensor.mean(dim=non_batch_dims)

    assert tensor.dim() == 5
    assert tensor.shape[2] == mask.shape[1]

    b, c, t, h, w = tensor.shape
    elems_per_step = c * h * w
    # Same element order as einops "b c t h w -> b t (c h w)".
    flat = tensor.permute(0, 2, 1, 3, 4).reshape(b, t, elems_per_step)

    masked_total = (flat * mask.unsqueeze(2)).sum(dim=1).sum(dim=1)
    valid_count = mask.sum(dim=1) * elems_per_step
    return masked_total / valid_count
def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    """
    Shift timesteps according to input resolution and length.

    Let ``s = t / num_timesteps`` and
    ``ratio = sqrt(H * W / base_resolution) * sqrt(frames / base_num_frames) * scale``.
    The function returns ``num_timesteps * ratio * s / (1 + (ratio - 1) * s)``.

    A ratio of 1 leaves ``t`` unchanged. A ratio above 1 moves timesteps towards
    ``num_timesteps``, which is the pure-noise end for ``RFlowScheduler.add_noise``.

    Args:
        t: timestep tensor on the ``[0, num_timesteps]`` scale.
        model_kwargs: mapping with tensor entries ``"height"``, ``"width"`` and
            ``"num_frames"``. These are assumed per-sample and broadcastable
            against ``t`` (TODO confirm with callers).
        base_resolution: pixel count that maps to a spatial ratio of 1.
        base_num_frames: (latent) frame count that maps to a temporal ratio of 1.
        scale: extra multiplier applied to the ratio.
        num_timesteps: length of the timestep scale ``t`` is expressed on.

    Returns:
        Transformed timesteps on the same scale as ``t``.
    """
    t = t / num_timesteps
    resolution = model_kwargs["height"] * model_kwargs["width"]
    ratio_space = (resolution / base_resolution).sqrt()

    # NOTE: currently, we do not take fps into account
    # NOTE: temporal_reduction is hardcoded, this should be equal to the temporal reduction factor of the vae
    # Approximate frame count after temporal reduction (17 input frames -> 5).
    # The element-wise clamp to >= 1 has two effects:
    # - Single images (1 frame) get ratio_time == 1, as before.
    # - Clips shorter than 17 frames no longer give 0. A value of 0 used to
    #   collapse every timestep to 0, and to NaN at t == num_timesteps.
    # Being element-wise, it also handles batches mixing images and videos.
    num_frames = (model_kwargs["num_frames"] // 17 * 5).clamp(min=1)
    ratio_time = (num_frames / base_num_frames).sqrt()

    ratio = ratio_space * ratio_time * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)

    new_t = new_t * num_timesteps
    return new_t
class RFlowScheduler:
    """
    Rectified-flow scheduler used for training and for noising samples.

    Timesteps live on a ``[0, num_timesteps]`` scale. ``add_noise`` maps timestep
    ``t`` to ``(1 - t / num_timesteps) * x0 + (t / num_timesteps) * noise``.
    So ``t = 0`` is the clean sample and ``t = num_timesteps`` is pure noise.
    """

    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        use_discrete_timesteps=False,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
        transform_scale=1.0,
    ):
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps
        self.use_discrete_timesteps = use_discrete_timesteps

        # sample method
        assert sample_method in ["uniform", "logit-normal"]
        assert (
            sample_method == "uniform" or not use_discrete_timesteps
        ), "Only uniform sampling is supported for discrete timesteps"
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            # With a 1-element loc, LogisticNormal draws 2-component simplex vectors.
            # Component 0 equals sigmoid(z) with z ~ Normal(loc, scale): a logit-normal
            # value in (0, 1), one per batch element of x.
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        """
        Compute training losses for a single timestep.

        Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses.

        Args:
            model: called as ``model(x_t, t, **model_kwargs)``. The first half of
                its output along dim 1 is taken as the velocity prediction.
            x_start: clean samples.
            model_kwargs: extra model arguments. They are also read by
                ``timestep_transform`` when ``use_timestep_transform`` is set.
            noise: optional noise with the same shape as ``x_start``.
            mask: optional boolean mask indexing dim 2 of ``x_start``.
                Positions where it is False are fed to the model clean (t = 0)
                and are excluded from the loss.
            weights: optional per-timestep weights looked up with ``t``.
            t: optional timesteps on the ``[0, num_timesteps]`` scale.
                - If None, timesteps are sampled. They are int when
                  ``use_discrete_timesteps`` is set, float otherwise.
                - They are then shifted by ``timestep_transform`` if enabled.
                - Caller-supplied ``t`` is used as-is.

        Returns:
            dict with key ``"loss"`` holding the per-sample loss.
        """
        if model_kwargs is None:
            model_kwargs = {}

        if t is None:
            batch = x_start.shape[0]
            if self.use_discrete_timesteps:
                t = torch.randint(0, self.num_timesteps, (batch,), device=x_start.device)
            elif self.sample_method == "uniform":
                t = torch.rand((batch,), device=x_start.device) * self.num_timesteps
            elif self.sample_method == "logit-normal":
                t = self.sample_t(x_start) * self.num_timesteps

            if self.use_timestep_transform:
                t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps)

        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        x_t = self.add_noise(x_start, noise, t)
        if mask is not None:
            # Unmasked positions stay clean (t = 0) and act as conditioning.
            t0 = torch.zeros_like(t)
            x_t0 = self.add_noise(x_start, noise, t0)
            x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)

        terms = {}
        model_output = model(x_t, t, **model_kwargs)
        velocity_pred = model_output.chunk(2, dim=1)[0]
        # Rectified-flow regression target: the velocity pointing from noise to data.
        target = x_start - noise
        if weights is None:
            loss = mean_flat((velocity_pred - target).pow(2), mask=mask)
        else:
            # NOTE(review): _extract_into_tensor presumably indexes `weights` by t,
            # which requires integer (discrete) timesteps — confirm before using
            # weights with continuous sampling.
            weight = _extract_into_tensor(weights, t, x_start.shape)
            loss = mean_flat(weight * (velocity_pred - target).pow(2), mask=mask)

        terms["loss"] = loss
        return terms

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Interpolate between clean samples and noise (compatible with diffusers add_noise()).

        Args:
            original_samples: clean samples with the batch on dim 0.
            noise: noise with the same shape as ``original_samples``.
            timesteps: per-sample timesteps of shape ``(bsz,)``, or a scalar,
                on the ``[0, num_timesteps]`` scale.

        Returns:
            ``w * original_samples + (1 - w) * noise`` with
            ``w = 1 - timesteps / num_timesteps`` broadcast per sample.
        """
        timepoints = timesteps.float() / self.num_timesteps
        # Weight of the clean sample: 1 at t = 0, 0 at t = num_timesteps.
        timepoints = 1 - timepoints
        # Reshape (bsz,) -> (bsz, 1, ..., 1) so it broadcasts over any sample rank.
        # This avoids materialising a full noise-sized copy of the weights.
        timepoints = timepoints.reshape(-1, *([1] * (noise.dim() - 1)))
        return timepoints * original_samples + (1 - timepoints) * noise
class RFLOW:
    """
    Rectified-flow sampler with classifier-free guidance and optional
    diffusion-step skipping (configured via ``videosys.core.pab_mgr``).

    Noising and training losses are delegated to an ``RFlowScheduler`` built from
    the same settings plus any extra ``**kwargs``.
    """

    def __init__(
        self,
        num_sampling_steps=10,
        num_timesteps=1000,
        cfg_scale=4.0,
        use_discrete_timesteps=False,
        use_timestep_transform=False,
        **kwargs,
    ):
        self.num_sampling_steps = num_sampling_steps
        self.num_timesteps = num_timesteps
        self.cfg_scale = cfg_scale
        self.use_discrete_timesteps = use_discrete_timesteps
        self.use_timestep_transform = use_timestep_transform

        self.scheduler = RFlowScheduler(
            num_timesteps=num_timesteps,
            num_sampling_steps=num_sampling_steps,
            use_discrete_timesteps=use_discrete_timesteps,
            use_timestep_transform=use_timestep_transform,
            **kwargs,
        )

    def sample(
        self,
        model,
        text_encoder,
        z,
        prompts,
        device,
        additional_args=None,
        mask=None,
        guidance_scale=None,
        progress=True,
        verbose=False,
    ):
        """
        Integrate the rectified-flow ODE from the initial latent ``z`` using Euler steps.

        Args:
            model: called as ``model(z_in, t, all_timesteps, **model_args)``.
                The first half of its output along dim 1 is the velocity.
                ``model.x_embedder.proj.weight.dtype`` must exist.
            text_encoder: provides ``encode(prompts)``, which returns a dict with
                key ``"y"``, and ``null(n)``, which gives the unconditional embedding.
            z: initial latent. Its batch size must equal ``len(prompts)``.
                Assumed to be 5-D with the mask axis on dim 2 (TODO confirm).
            prompts: list of text prompts.
            device: device on which timestep tensors are created.
            additional_args: extra model kwargs. They are also read by
                ``timestep_transform`` when enabled.
            mask: optional per-position values on z's dim 2.
                - 0: the position is kept at its input value.
                - 1: the position is generated from the start.
                - Fractional value m: the position is re-noised once t drops
                  to ``m * num_timesteps`` or below.
            guidance_scale: CFG scale; defaults to ``self.cfg_scale``.
            progress: show a tqdm bar on the main process.
            verbose: print step-skipping details on the main process.

        Returns:
            The final latent ``z``.
        """
        # if no specific guidance scale is provided, use the default scale when initializing the scheduler
        if guidance_scale is None:
            guidance_scale = self.cfg_scale
        n = len(prompts)

        # Only the main process prints or shows progress. Treat an uninitialised
        # torch.distributed as single-process, so dist.get_rank() is not called
        # without a process group (it raises in that case).
        is_main_process = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0

        # text encoding: conditional embeddings first, unconditional second. This
        # matches the [cond, uncond] split of the model output below.
        model_args = text_encoder.encode(prompts)
        y_null = text_encoder.null(n)
        model_args["y"] = torch.cat([model_args["y"], y_null], 0)
        if additional_args is not None:
            model_args.update(additional_args)

        # prepare timesteps: evenly spaced from num_timesteps down to (excluding) 0
        timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)]
        if self.use_discrete_timesteps:
            timesteps = [int(round(t)) for t in timesteps]
        timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps]
        if self.use_timestep_transform:
            # NOTE(review): training_losses passes scale=self.scheduler.transform_scale
            # to timestep_transform, but sampling uses the default scale (1.0).
            timesteps = [timestep_transform(t, additional_args, num_timesteps=self.num_timesteps) for t in timesteps]

        if get_diffusion_skip() and get_diffusion_skip_timestep() is not None:
            original_timesteps = timesteps
            diffusion_skip_timestep = get_diffusion_skip_timestep()
            timesteps = skip_diffusion_timestep(timesteps, diffusion_skip_timestep)

            if verbose and is_main_process:
                print("============================")
                print("skip diffusion steps!!!")
                print("============================")
                print(f"orignal sample timesteps: {original_timesteps}")
                print(f"orignal diffusion steps: {len(original_timesteps)}")
                print("============================")
                print(f"skip diffusion steps: {get_diffusion_skip_timestep()}")
                print(f"sample timesteps: {timesteps}")
                print(f"num_inference_steps: {len(timesteps)}")
                print("============================")

        if mask is not None:
            # Positions with mask == 1 start out "already noised" and are never re-noised.
            noise_added = torch.zeros_like(mask, dtype=torch.bool)
            noise_added = noise_added | (mask == 1)
            # Loop-invariant: per-position noising threshold on the timestep scale.
            mask_t = mask * self.num_timesteps

        progress_wrap = tqdm if progress and is_main_process else (lambda x: x)

        dtype = model.x_embedder.proj.weight.dtype
        # One scalar timestep per step for the model. Each timestep tensor holds
        # one value per batch element, so .item() on the whole tensor fails for
        # batch sizes > 1; use the first element instead.
        # NOTE(review): the elements only differ if timestep_transform sees
        # per-sample sizes that differ.
        all_timesteps = [int(t.reshape(-1)[0].to(dtype).item()) for t in timesteps]

        for i, t in progress_wrap(list(enumerate(timesteps))):
            # mask for adding noise
            if mask is not None:
                x0 = z.clone()
                x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t)
                mask_t_upper = mask_t >= t.unsqueeze(1)
                model_args["x_mask"] = mask_t_upper.repeat(2, 1)
                mask_add_noise = mask_t_upper & ~noise_added
                z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0)
                noise_added = mask_t_upper

            # classifier-free guidance: run cond and uncond halves in one batch
            z_in = torch.cat([z, z], 0)
            t = torch.cat([t, t], 0)
            output = model(z_in, t, all_timesteps, **model_args)
            pred = output.chunk(2, dim=1)[0]
            pred_cond, pred_uncond = pred.chunk(2, dim=0)
            v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # Euler update; the last step integrates the remaining time down to 0.
            dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
            dt = dt / self.num_timesteps
            z = z + v_pred * dt[:, None, None, None, None]

            if mask is not None:
                # Positions not yet due for generation are restored to their input.
                z = torch.where(mask_t_upper[:, None, :, None, None], z, x0)

        return z

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        """Delegate to ``RFlowScheduler.training_losses`` with the same arguments."""
        return self.scheduler.training_losses(model, x_start, model_kwargs, noise, mask, weights, t)