|
import torch |
|
from omegaconf import OmegaConf |
|
from sgm.util import instantiate_from_config |
|
from sgm.modules.diffusionmodules.sampling import * |
|
|
|
|
|
def init_model(cfgs):
    """Instantiate the model described by the config file and load its weights.

    Args:
        cfgs: run configuration object. This function reads
            ``cfgs.model_cfg_path`` (path of an OmegaConf YAML whose ``model``
            node is passed to ``instantiate_from_config``),
            ``cfgs.load_ckpt_path`` (checkpoint passed to
            ``model.init_from_ckpt``), ``cfgs.type`` and, when ``cfgs.type`` is
            not ``"train"``, ``cfgs.gpu`` (CUDA device index).

    Returns:
        The model. For ``cfgs.type == "train"`` only ``model.train()`` is
        called (no device move happens here -- presumably the training loop
        places the model; TODO confirm). For any other type the model is moved
        to ``cuda:<cfgs.gpu>``, put in eval mode and ``model.freeze()`` is called.
    """
    model_config = OmegaConf.load(cfgs.model_cfg_path)
    checkpoint_path = cfgs.load_ckpt_path

    model = instantiate_from_config(model_config.model)
    model.init_from_ckpt(checkpoint_path)

    if cfgs.type != "train":
        # Non-training runs: put the model on the chosen GPU and switch it to
        # inference mode before handing it back.
        model.to(torch.device("cuda", index=cfgs.gpu))
        model.eval()
        model.freeze()
        return model

    model.train()
    return model
|
|
|
def init_sampling(cfgs):
    """Construct the ``EulerEDMSampler`` used to draw samples from the model.

    The sampler is configured with the ``LegacyDDPMDiscretization``
    discretizer and a ``VanillaCFG`` guider, with churn disabled
    (``s_churn=0.0``) and verbose progress output enabled.

    Args:
        cfgs: run configuration object. Reads ``cfgs.scale[0]`` (guider scale),
            ``cfgs.steps`` (number of sampling steps) and ``cfgs.gpu``
            (CUDA device index the sampler is given).

    Returns:
        The configured ``EulerEDMSampler`` instance.
    """
    discretizer_spec = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    }
    # Only the first entry of ``cfgs.scale`` is used as the guidance scale.
    guider_spec = {
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": cfgs.scale[0]},
    }

    return EulerEDMSampler(
        num_steps=cfgs.steps,
        discretization_config=discretizer_spec,
        guider_config=guider_spec,
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )
|
|
|
def deep_copy(batch):
    """Return a new dict holding copies of the values in ``batch``.

    Despite its name this is not ``copy.deepcopy``; copying is selective:

    * ``torch.Tensor`` values are cloned with ``torch.clone`` (new storage).
    * ``list`` / ``tuple`` values are shallow-copied with ``copy.copy``. Their
      elements are shared with the original. Tuples, including namedtuples,
      are immutable, so ``copy.copy`` returns them unchanged.
    * Every other value is stored by reference, without copying.

    Args:
        batch: mapping from keys to batch entries.

    Returns:
        A new ``dict`` with the same keys as ``batch``.
    """
    # Local import keeps the module namespace (which also receives a star
    # import) untouched.
    import copy

    copied = {}
    for key in batch:
        value = batch[key]
        if isinstance(value, torch.Tensor):
            copied[key] = torch.clone(value)
        elif isinstance(value, (tuple, list)):
            # ``tuple`` has no ``.copy()`` method; calling ``value.copy()``
            # raised AttributeError for tuple values. ``copy.copy`` handles
            # lists, tuples and their subclasses uniformly.
            copied[key] = copy.copy(value)
        else:
            copied[key] = value

    return copied
|
|
|
def prepare_batch(cfgs, batch):
    """Move every tensor in ``batch`` to the configured CUDA device.

    The dict is modified in place: each ``torch.Tensor`` value is replaced by
    its ``.to(cuda:<cfgs.gpu>)`` result, and non-tensor values are left as-is.
    ``cfgs.gpu`` is only read when at least one tensor is present.

    Args:
        cfgs: run configuration object; reads ``cfgs.gpu``.
        batch: mutable mapping of batch entries.

    Returns:
        ``(batch, batch)``. The second element (named ``batch_uc`` by callers,
        presumably the unconditional batch) is the *same* dict object, not a
        copy, so mutating either one affects both.
    """
    tensor_keys = [name for name in batch if isinstance(batch[name], torch.Tensor)]

    if tensor_keys:
        target_device = torch.device("cuda", index=cfgs.gpu)
        for name in tensor_keys:
            batch[name] = batch[name].to(target_device)

    # Deliberately aliased: no copy is made for the second return value.
    return batch, batch