import torch
from omegaconf import OmegaConf

from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.sampling import *

# SD-XL base resolutions keyed by aspect ratio (width / height, as a string),
# each mapping to (width, height) in pixels.
SD_XL_BASE_RATIOS = {
    "0.5": (704, 1408),
    "0.52": (704, 1344),
    "0.57": (768, 1344),
    "0.6": (768, 1280),
    "0.68": (832, 1216),
    "0.72": (832, 1152),
    "0.78": (896, 1152),
    "0.82": (896, 1088),
    "0.88": (960, 1088),
    "0.94": (960, 1024),
    "1.0": (1024, 1024),
    "1.07": (1024, 960),
    "1.13": (1088, 960),
    "1.21": (1088, 896),
    "1.29": (1152, 896),
    "1.38": (1152, 832),
    "1.46": (1216, 832),
    "1.67": (1280, 768),
    "1.75": (1344, 768),
    "1.91": (1344, 704),
    "2.0": (1408, 704),
    "2.09": (1472, 704),
    "2.4": (1536, 640),
    "2.5": (1600, 640),
    "2.89": (1664, 576),
    "3.0": (1728, 576),
}


def init_model(cfgs):
    """Instantiate the model from its OmegaConf config and load the checkpoint."""
    model_cfg = OmegaConf.load(cfgs.model_cfg_path)
    ckpt = cfgs.load_ckpt_path

    model = instantiate_from_config(model_cfg.model)
    model.init_from_ckpt(ckpt)

    if cfgs.type == "train":
        model.train()
    else:
        # Inference: move to the requested GPU, switch to eval mode, and freeze the weights.
        if cfgs.use_gpu:
            model.to(torch.device("cuda", index=cfgs.gpu))
        model.eval()
        model.freeze()

    return model


def init_sampling(cfgs):
    """Build an Euler EDM sampler with either a dual or a vanilla CFG guider."""
    discretization_config = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    }

    if cfgs.dual_conditioner:
        # Dual classifier-free guidance: expects a pair of guidance scales.
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.DualCFG",
            "params": {"scale": cfgs.scale},
        }
        sampler = EulerEDMDualSampler(
            num_steps=cfgs.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=0.0,
            s_tmin=0.0,
            s_tmax=999.0,
            s_noise=1.0,
            verbose=True,
            device=torch.device("cuda", index=cfgs.gpu),
        )
    else:
        # Standard classifier-free guidance with a single scale.
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": cfgs.scale[0]},
        }
        sampler = EulerEDMSampler(
            num_steps=cfgs.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=0.0,
            s_tmin=0.0,
            s_tmax=999.0,
            s_noise=1.0,
            verbose=True,
            device=torch.device("cuda", index=cfgs.gpu),
        )

    return sampler


def deep_copy(batch):
    """Copy a batch dict, cloning tensors and shallow-copying sequences so the
    unconditional batches can be edited without touching the original."""
    c_batch = {}
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            c_batch[key] = torch.clone(batch[key])
        elif isinstance(batch[key], (tuple, list)):
            # tuple has no .copy(); rebuild the sequence with its own type
            c_batch[key] = type(batch[key])(batch[key])
        else:
            c_batch[key] = batch[key]
    return c_batch


def prepare_batch(cfgs, batch):
    """Move the batch to the GPU and build the unconditional batch(es) for CFG."""
    for key in batch:
        if isinstance(batch[key], torch.Tensor) and cfgs.use_gpu:
            batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))

    if not cfgs.dual_conditioner:
        # Single unconditional batch: replace the prompt with the negative
        # prompt "ntxt" if provided (empty strings otherwise) and blank labels.
        batch_uc = deep_copy(batch)
        if "ntxt" in batch:
            batch_uc["txt"] = batch["ntxt"]
        else:
            batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))]
        if "label" in batch:
            batch_uc["label"] = ["" for _ in range(len(batch["label"]))]
        return batch, batch_uc, None
    else:
        # Dual CFG: both unconditional batches zero the reference image, and
        # the first one additionally blanks the labels.
        batch_uc_1 = deep_copy(batch)
        batch_uc_2 = deep_copy(batch)
        batch_uc_1["ref"] = torch.zeros_like(batch["ref"])
        batch_uc_2["ref"] = torch.zeros_like(batch["ref"])
        batch_uc_1["label"] = ["" for _ in range(len(batch["label"]))]
        return batch, batch_uc_1, batch_uc_2
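

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's original
# code): it assumes a config object carrying the fields read by the helpers
# above (model_cfg_path, load_ckpt_path, type, use_gpu, gpu, dual_conditioner,
# scale, steps). The paths and scale/step values below are placeholders.
if __name__ == "__main__":
    cfgs = OmegaConf.create(
        {
            "model_cfg_path": "configs/model.yaml",      # placeholder path
            "load_ckpt_path": "checkpoints/model.ckpt",  # placeholder path
            "type": "inference",
            "use_gpu": True,
            "gpu": 0,
            "dual_conditioner": False,
            "scale": [5.0],
            "steps": 40,
        }
    )
    model = init_model(cfgs)
    sampler = init_sampling(cfgs)
    print(type(model).__name__, type(sampler).__name__)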