|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from .model import gaussian_diffusion as gd |
|
from .model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper |
|
|
|
|
|
def DPMS(
    model,
    condition,
    uncondition,
    cfg_scale,
    pag_scale=1.0,
    pag_applied_layers=None,
    model_type="noise",
    noise_schedule="linear",
    guidance_type="classifier-free",
    model_kwargs=None,
    diffusion_steps=1000,
    schedule="VP",
    interval_guidance=None,
):
    """Build a DPM-Solver++ sampler that wraps ``model`` with guidance.

    Args:
        model: The denoising network, passed to ``model_wrapper`` unchanged.
        condition: Conditional input forwarded to ``model_wrapper`` as
            ``condition``.
        uncondition: Unconditional input forwarded to ``model_wrapper`` as
            ``unconditional_condition``.
        cfg_scale: Guidance strength, forwarded as ``guidance_scale``.
        pag_scale: Forwarded to ``model_wrapper`` as ``pag_scale``.
        pag_applied_layers: Forwarded as ``pag_applied_layers``. ``None``
            becomes a fresh empty list.
        model_type: What the model predicts, forwarded as ``model_type``
            (default ``"noise"``).
        noise_schedule: Name of the beta schedule passed to
            ``gd.get_named_beta_schedule``. Only used when
            ``schedule == "VP"``.
        guidance_type: Forwarded as ``guidance_type``
            (default ``"classifier-free"``).
        model_kwargs: Extra model keyword arguments. ``None`` becomes a
            fresh empty dict.
        diffusion_steps: Number of discrete betas to generate. Only used
            when ``schedule == "VP"``.
        schedule: ``"VP"`` builds a discrete ``NoiseScheduleVP`` from the
            named betas. ``"FLOW"`` builds
            ``NoiseScheduleFlow(schedule="discrete_flow")``.
        interval_guidance: Forwarded as ``interval_guidance``. ``None``
            becomes ``[0, 1.0]``. Presumably the portion of the trajectory
            where guidance is applied; confirm against ``model_wrapper``.

    Returns:
        A ``DPM_Solver`` constructed with ``algorithm_type="dpmsolver++"``.

    Raises:
        ValueError: If ``schedule`` is neither ``"VP"`` nor ``"FLOW"``.
    """
    # Fail fast. Previously an unknown schedule fell through both branches
    # and the beta-schedule *name string* was handed to the solver as if it
    # were a noise-schedule object.
    if schedule not in ("VP", "FLOW"):
        raise ValueError(
            f"Unsupported schedule {schedule!r}; expected 'VP' or 'FLOW'."
        )

    if pag_applied_layers is None:
        pag_applied_layers = []
    if model_kwargs is None:
        model_kwargs = {}
    if interval_guidance is None:
        interval_guidance = [0, 1.0]

    # Use a distinct local name so the `noise_schedule` argument (a schedule
    # *name*) is never confused with the schedule *object*.
    if schedule == "VP":
        # Betas are only meaningful for the discrete VP schedule.
        betas = torch.tensor(
            gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
        )
        schedule_obj = NoiseScheduleVP(schedule="discrete", betas=betas)
    else:  # schedule == "FLOW" (validated above)
        schedule_obj = NoiseScheduleFlow(schedule="discrete_flow")

    model_fn = model_wrapper(
        model,
        schedule_obj,
        model_type=model_type,
        model_kwargs=model_kwargs,
        guidance_type=guidance_type,
        pag_scale=pag_scale,
        pag_applied_layers=pag_applied_layers,
        condition=condition,
        unconditional_condition=uncondition,
        guidance_scale=cfg_scale,
        interval_guidance=interval_guidance,
    )

    return DPM_Solver(model_fn, schedule_obj, algorithm_type="dpmsolver++")
|
|