|
import math |
|
import os |
|
from typing import List, Union |
|
|
|
import numpy as np |
|
import streamlit as st |
|
import torch |
|
from einops import rearrange, repeat |
|
from imwatermark import WatermarkEncoder |
|
from omegaconf import ListConfig, OmegaConf |
|
from PIL import Image |
|
from safetensors.torch import load_file as load_safetensors |
|
from torch import autocast |
|
from torchvision import transforms |
|
from torchvision.utils import make_grid |
|
|
|
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering |
|
from sgm.modules.diffusionmodules.sampling import ( |
|
DPMPP2MSampler, |
|
DPMPP2SAncestralSampler, |
|
EulerAncestralSampler, |
|
EulerEDMSampler, |
|
HeunEDMSampler, |
|
LinearMultistepSampler, |
|
) |
|
from sgm.util import append_dims, instantiate_from_config |
|
|
|
|
|
class WatermarkEmbedder:
    """Embed a fixed bit-string watermark into image tensors using ``imwatermark``."""

    def __init__(self, watermark):
        """
        Args:
            watermark: sequence of bits (0/1 ints) handed to the encoder as a
                "bits" watermark.
        """
        self.watermark = watermark
        # Derive the length from the watermark actually supplied rather than
        # from the module-level WATERMARK_BITS constant, so instances created
        # with a different watermark report the correct bit count.
        self.num_bits = len(watermark)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor):
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, C, H, W) in range [0, 1]

        Returns:
            same as input but watermarked
        """
        # Promote a 4-D batch to 5-D so both layouts share one code path.
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # Flatten to a stack of channel-last arrays scaled to [0, 255] and
        # reverse the channel axis (presumably RGB -> BGR for the encoder;
        # NOTE(review): confirm the channel order imwatermark expects).
        image_np = rearrange(
            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
        ).numpy()[:, :, :, ::-1]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # Undo the channel reversal and restore the (n, b, c, h, w) layout on
        # the input's device.
        image = torch.from_numpy(
            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
        ).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        return image
|
|
|
|
|
|
|
|
|
# Fixed 48-bit watermark payload.
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110

# The payload as a list of 0/1 ints, most-significant bit first.
WATERMARK_BITS = [int(digit) for digit in format(WATERMARK_MESSAGE, "b")]

# Shared embedder instance. The misspelled name is kept because it is
# referenced elsewhere in this module.
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
|
|
|
|
|
@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
    """Build the model state for one model version; cached by Streamlit.

    Args:
        version_dict: mapping with a "config" entry (path passed to
            ``OmegaConf.load``) and a "ckpt" entry (checkpoint path).
        load_ckpt: when False the model is instantiated without loading the
            checkpoint and ``state["ckpt"]`` is None.
        load_filter: when True a ``DeepFloydDataFiltering`` instance is stored
            under ``state["filter"]``.

    Returns:
        dict with keys "msg", "model", "ckpt", "config" and, if requested,
        "filter".
    """
    # Re-use across reruns comes from st.cache_resource; each uncached call
    # starts from a fresh dict, so there is nothing to check for here.
    state = dict()
    config = version_dict["config"]
    ckpt = version_dict["ckpt"]

    config = OmegaConf.load(config)
    model, msg = load_model_from_config(config, ckpt if load_ckpt else None)

    state["msg"] = msg
    state["model"] = model
    state["ckpt"] = ckpt if load_ckpt else None
    state["config"] = config
    if load_filter:
        state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state
|
|
|
|
|
def load_model(model):
    """Move ``model`` onto the GPU by calling its ``.cuda()`` method.

    Unlike ``initial_model_load``/``unload_model`` this does not consult
    ``lowvram_mode``; callers pair it with ``unload_model`` around each use.
    """
    model.cuda()
|
|
|
|
|
# Module-level switch read by initial_model_load and unload_model. When True,
# sub-models are not moved to the GPU on creation and are moved back to the
# CPU after each use.
lowvram_mode = False


def set_lowvram_mode(mode):
    """Set the module-level ``lowvram_mode`` flag.

    Args:
        mode: truthy to enable low-VRAM behaviour in ``initial_model_load``
            and ``unload_model``.
    """
    global lowvram_mode
    lowvram_mode = mode
|
|
|
|
|
def initial_model_load(model):
    """Prepare a freshly instantiated model for inference and return it.

    Outside low-VRAM mode the whole model is moved to the GPU. In low-VRAM
    mode only ``model.model`` is cast to half precision, and nothing is moved
    here.
    """
    if not lowvram_mode:
        model.cuda()
    else:
        model.model.half()
    return model
|
|
|
|
|
def unload_model(model):
    """Move ``model`` back to the CPU and release cached CUDA memory.

    This is a no-op unless low-VRAM mode is enabled.
    """
    if not lowvram_mode:
        return
    model.cpu()
    torch.cuda.empty_cache()
|
|
|
|
|
def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate ``config.model`` and optionally load checkpoint weights.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: optional path ending in "ckpt" (torch pickle containing a
            "state_dict" entry) or "safetensors". When None, the freshly
            instantiated weights are kept.
        verbose: print missing / unexpected state-dict keys.

    Returns:
        ``(model, msg)``. ``model`` has been through ``initial_model_load`` and
        is in eval mode; ``msg`` is always None.

    Raises:
        NotImplementedError: if ``ckpt`` has an unsupported extension.
    """
    model = instantiate_from_config(config.model)
    msg = None

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            # NOTE(security): torch.load unpickles arbitrary Python objects;
            # only load checkpoints from trusted sources.
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                global_step = pl_sd["global_step"]
                st.info(f"loaded ckpt from global step {global_step}")
                print(f"Global Step: {global_step}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
        else:
            raise NotImplementedError(
                f"unsupported checkpoint format: {ckpt!r} "
                "(expected a path ending in 'ckpt' or 'safetensors')"
            )

        # strict=False: report rather than fail on key mismatches.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose and len(missing) > 0:
            print("missing keys:")
            print(missing)
        if verbose and len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    model = initial_model_load(model)
    model.eval()
    return model, msg
|
|
|
|
|
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    The result follows set iteration order, not the order of the embedders.
    """
    return list({embedder.input_key for embedder in conditioner.embedders})
|
|
|
|
|
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Collect conditioning values for each embedder key.

    Streamlit widgets are created for values that are not supplied.

    Args:
        keys: embedder input keys to gather values for.
        init_dict: defaults. Must contain "orig_width"/"orig_height" when
            "original_size_as_tuple" is requested, and "target_width"/
            "target_height" when "target_size_as_tuple" is.
        prompt: prompt text for "txt". When None it is read from a text input.
        negative_prompt: negative prompt for "txt". When None it is read from
            a text input.

    Returns:
        dict of values consumed by ``get_batch``.
    """
    values = {}
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")
            values["prompt"] = prompt
            values["negative_prompt"] = negative_prompt
        elif key == "original_size_as_tuple":
            width = st.number_input(
                "orig_width",
                value=init_dict["orig_width"],
                min_value=16,
            )
            height = st.number_input(
                "orig_height",
                value=init_dict["orig_height"],
                min_value=16,
            )
            values["orig_width"] = width
            values["orig_height"] = height
        elif key == "crop_coords_top_left":
            top = st.number_input("crop_coords_top", value=0, min_value=0)
            left = st.number_input("crop_coords_left", value=0, min_value=0)
            values["crop_coords_top"] = top
            values["crop_coords_left"] = left
        elif key == "aesthetic_score":
            # Fixed scores for the positive and negative conditioning.
            values["aesthetic_score"] = 6.0
            values["negative_aesthetic_score"] = 2.5
        elif key == "target_size_as_tuple":
            values["target_width"] = init_dict["target_width"]
            values["target_height"] = init_dict["target_height"]
    return values
|
|
|
|
|
def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and write each one as a numbered PNG into ``save_path``.

    File numbering starts at the number of entries already in the directory
    and uses nine zero-padded digits.
    """
    os.makedirs(save_path, exist_ok=True)
    first_index = len(os.listdir(save_path))
    watermarked = embed_watemark(samples)
    for offset, sample in enumerate(watermarked):
        pixels = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        target = os.path.join(save_path, f"{first_index + offset:09}.png")
        Image.fromarray(pixels.astype(np.uint8)).save(target)
|
|
|
|
|
def init_save_locally(_dir, init_value: bool = False):
    """Show a sidebar toggle for saving samples and, if enabled, a path input.

    Returns:
        ``(save_locally, save_path)``. ``save_path`` is None when saving is off
        and defaults to ``<_dir>/samples`` otherwise.
    """
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    save_path = (
        st.text_input("Save path", value=os.path.join(_dir, "samples"))
        if save_locally
        else None
    )
    return save_locally, save_path
|
|
|
|
|
class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and keeps only the tail of its sigma schedule.

    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all
            sigmas are returned); otherwise the last
            ``max(int(strength * len(sigmas)), 1)`` sigmas are kept.
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Compute the cut-off once from the un-pruned length so the logged
        # value is the one actually applied. Recomputing it after slicing
        # would report a different number.
        prune_index = max(int(self.strength * len(sigmas)), 1)
        # Flip so the tail comes first, keep `prune_index` entries, flip back.
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas
|
|
|
|
|
class Txt2NoisyDiscretizationWrapper:
    """
    Wraps a discretizer and trims sigmas from the tail of its schedule.

    params:
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all
            sigmas are returned).
        original_steps: if given, the step count used to scale ``strength``
            is ``original_steps + 1`` instead of the schedule length.
    """

    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        schedule = self.discretization(*args, **kwargs)
        print(f"sigmas after discretization, before pruning img2img: ", schedule)
        reversed_schedule = torch.flip(schedule, (0,))

        if self.original_steps is not None:
            total = self.original_steps + 1
        else:
            total = len(reversed_schedule)

        # Number of tail entries to drop, clamped to [0, total - 1].
        cut = max(0, min(int(self.strength * total) - 1, total - 1))
        reversed_schedule = reversed_schedule[cut:]
        print("prune index:", cut)

        trimmed = torch.flip(reversed_schedule, (0,))
        print(f"sigmas after pruning: ", trimmed)
        return trimmed
|
|
|
|
|
def get_guider(key):
    """Let the user pick a guidance setup and return its config.

    Args:
        key: suffix appended to widget labels so that several samplers get
            distinct widgets.

    Returns:
        ``{"target": ..., ["params": ...]}`` dict, passed to the sampler as
        ``guider_config``.

    Raises:
        NotImplementedError: for an unsupported guider or thresholder choice.
    """
    # Labelled "Guider": the discretization selectbox created in init_sampling
    # already uses "Discretization #{key}", and reusing that label here
    # mislabels this widget.
    guider = st.sidebar.selectbox(
        f"Guider #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
        ],
    )

    if guider == "IdentityGuider":
        return {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

    if guider == "VanillaCFG":
        scale = st.number_input(
            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
        )

        thresholder = st.sidebar.selectbox(
            f"Thresholder #{key}",
            [
                "None",
            ],
        )

        if thresholder == "None":
            dyn_thresh_config = {
                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
            }
        else:
            raise NotImplementedError

        return {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
        }

    raise NotImplementedError
|
|
|
|
|
def init_sampling(
    key=1,
    img2img_strength=1.0,
    specify_num_samples=True,
    stage2strength=None,
):
    """Build a sampler from Streamlit widgets.

    Args:
        key: suffix for widget labels.
        img2img_strength: below 1.0, the sampler's discretization is wrapped
            in ``Img2ImgDiscretizationWrapper``.
        specify_num_samples: show a "num cols" input; otherwise one column.
        stage2strength: if not None, the discretization is additionally
            wrapped in ``Txt2NoisyDiscretizationWrapper``.

    Returns:
        ``(sampler, num_rows, num_cols)``. ``num_rows`` is always 1.
    """
    num_rows = 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )
    else:
        num_cols = 1

    steps = st.sidebar.number_input(
        f"steps #{key}", value=40, min_value=1, max_value=1000
    )
    sampler_choices = [
        "EulerEDMSampler",
        "HeunEDMSampler",
        "EulerAncestralSampler",
        "DPMPP2SAncestralSampler",
        "DPMPP2MSampler",
        "LinearMultistepSampler",
    ]
    sampler_name = st.sidebar.selectbox(f"Sampler #{key}", sampler_choices, 0)
    discretization_name = st.sidebar.selectbox(
        f"Discretization #{key}",
        ["LegacyDDPMDiscretization", "EDMDiscretization"],
    )

    discretization_config = get_discretization(discretization_name, key=key)
    guider_config = get_guider(key=key)
    sampler = get_sampler(
        sampler_name, steps, discretization_config, guider_config, key=key
    )

    if img2img_strength < 1.0:
        st.warning(
            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
        )
        sampler.discretization = Img2ImgDiscretizationWrapper(
            sampler.discretization, strength=img2img_strength
        )
    if stage2strength is not None:
        sampler.discretization = Txt2NoisyDiscretizationWrapper(
            sampler.discretization, strength=stage2strength, original_steps=steps
        )
    return sampler, num_rows, num_cols
|
|
|
|
|
def get_discretization(discretization, key=1):
    """Return the config dict for the named discretization.

    For "EDMDiscretization" the sigma range and rho are read from Streamlit
    number inputs labelled with ``key``.

    Args:
        discretization: "LegacyDDPMDiscretization" or "EDMDiscretization".
        key: suffix for widget labels.

    Returns:
        ``{"target": ..., ["params": ...]}`` dict.

    Raises:
        ValueError: for an unknown discretization name. Previously this case
            reached ``return`` with the variable unbound.
    """
    if discretization == "LegacyDDPMDiscretization":
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    if discretization == "EDMDiscretization":
        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)
        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)
        rho = st.number_input(f"rho #{key}", value=3.0)
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": sigma_min,
                "sigma_max": sigma_max,
                "rho": rho,
            },
        }
    raise ValueError(f"unknown discretization {discretization}!")
|
|
|
|
|
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
    """Instantiate the named sampler, reading its hyper-parameters from sidebar widgets.

    Every widget label carries the ``#{key}`` suffix. Without it, two calls
    with different keys that pick the same sampler would create identical
    widgets, which Streamlit rejects as duplicates.

    Args:
        sampler_name: one of the sampler class names offered in init_sampling.
        steps: number of sampling steps.
        discretization_config: config from get_discretization.
        guider_config: config from get_guider.
        key: suffix for widget labels.

    Returns:
        The sampler instance (created with ``verbose=True``).

    Raises:
        ValueError: for an unknown ``sampler_name``.
    """
    if sampler_name in ("EulerEDMSampler", "HeunEDMSampler"):
        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)

        edm_cls = (
            EulerEDMSampler if sampler_name == "EulerEDMSampler" else HeunEDMSampler
        )
        return edm_cls(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=s_churn,
            s_tmin=s_tmin,
            s_tmax=s_tmax,
            s_noise=s_noise,
            verbose=True,
        )

    if sampler_name in ("EulerAncestralSampler", "DPMPP2SAncestralSampler"):
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
        eta = st.sidebar.number_input(f"eta #{key}", value=1.0, min_value=0.0)

        ancestral_cls = (
            EulerAncestralSampler
            if sampler_name == "EulerAncestralSampler"
            else DPMPP2SAncestralSampler
        )
        return ancestral_cls(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            eta=eta,
            s_noise=s_noise,
            verbose=True,
        )

    if sampler_name == "DPMPP2MSampler":
        return DPMPP2MSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            verbose=True,
        )

    if sampler_name == "LinearMultistepSampler":
        order = st.sidebar.number_input(f"order #{key}", value=4, min_value=1)
        return LinearMultistepSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            order=order,
            verbose=True,
        )

    raise ValueError(f"unknown sampler {sampler_name}!")
|
|
|
|
|
def get_interactive_image(key=None) -> Image.Image:
    """Show a file uploader and return the upload as an RGB PIL image.

    Returns None while nothing has been uploaded.
    """
    uploaded = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if uploaded is None:
        return None
    picture = Image.open(uploaded)
    return picture if picture.mode == "RGB" else picture.convert("RGB")
|
|
|
|
|
def load_img(display=True, key=None):
    """Read an uploaded image and convert it to a tensor scaled to [-1, 1].

    Args:
        display: also render the image in the Streamlit page.
        key: widget key forwarded to the uploader.

    Returns:
        Tensor of shape (1, C, H, W), or None if nothing is uploaded.
    """
    picture = get_interactive_image(key=key)
    if picture is None:
        return None
    if display:
        st.image(picture)
    width, height = picture.size
    print(f"loaded input image of size ({width}, {height})")

    # ToTensor yields values in [0, 1]; rescale to [-1, 1].
    to_model_range = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t * 2.0 - 1.0),
        ]
    )
    tensor = to_model_range(picture).unsqueeze(0)
    st.text(
        f"input min/max/mean: {tensor.min():.3f}/{tensor.max():.3f}/{tensor.mean():.3f}"
    )
    return tensor
|
|
|
|
|
def get_init_img(batch_size=1, key=None):
    """Load an uploaded image, move it to the GPU and tile it along the batch axis.

    Args:
        batch_size: number of copies along the leading axis.
        key: widget key forwarded to the uploader.

    Returns:
        Tensor of shape (batch_size, C, H, W) on CUDA, or None if no image has
        been uploaded yet. This mirrors ``load_img`` instead of raising
        AttributeError on ``None.cuda()``.
    """
    init_image = load_img(key=key)
    if init_image is None:
        return None
    init_image = init_image.cuda()
    return repeat(init_image, "1 ... -> b ...", b=batch_size)
|
|
|
|
|
def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: List = None,
    batch2model_input: List = None,
    return_latents=False,
    filter=None,
):
    """Run text-to-image sampling and show the resulting grid in Streamlit.

    Args:
        model: diffusion model exposing ``conditioner``, ``denoiser``,
            ``model``, ``first_stage_model``, ``ema_scope`` and
            ``decode_first_stage``.
        sampler: callable ``sampler(denoiser, x, cond=..., uc=...)`` returning
            latents.
        value_dict: conditioning values passed to ``get_batch``.
        num_samples: number of images to generate.
        H: output image height in pixels.
        W: output image width in pixels.
        C: number of latent channels.
        F: factor between pixel and latent resolution (latents are H//F x W//F).
        force_uc_zero_embeddings: forwarded to
            ``get_unconditional_conditioning``; defaults to [].
        batch2model_input: batch keys passed to the denoiser as extra keyword
            arguments; defaults to [].
        return_latents: also return the sampled latents.
        filter: optional callable applied to the decoded samples.

    Returns:
        ``samples`` (decoded, clamped to [0, 1] before ``filter``), or
        ``(samples, samples_z)`` when ``return_latents`` is True.
    """
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []

    st.text("Sampling")

    outputs = st.empty()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                # get_batch expects a list of batch dimensions.
                num_samples = [num_samples]
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                )
                # Debug print of what each batch entry looks like.
                for key in batch:
                    if isinstance(batch[key], torch.Tensor):
                        print(key, batch[key].shape)
                    elif isinstance(batch[key], list):
                        print(key, [len(l) for l in batch[key]])
                    else:
                        print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )
                unload_model(model.conditioner)

                # Trim every non-"crossattn" conditioning entry to the batch
                # size and move it to CUDA; "crossattn" is left as returned.
                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                        )

                additional_model_inputs = {}
                for k in batch2model_input:
                    additional_model_inputs[k] = batch[k]

                # Initial latent noise of shape (batch, C, H // F, W // F).
                shape = (math.prod(num_samples), C, H // F, W // F)
                randn = torch.randn(shape).to("cuda")

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                # Each sub-model is loaded just before use and unloaded after
                # (unloading only takes effect in low-VRAM mode).
                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                samples_x = model.decode_first_stage(samples_z)
                # Map decoder output from [-1, 1] to [0, 1].
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                unload_model(model.first_stage_model)

                if filter is not None:
                    samples = filter(samples)

                # NOTE(review): unlike do_img2img, this displayed grid is not
                # passed through embed_watemark; confirm whether intended.
                grid = torch.stack([samples])
                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                outputs.image(grid.cpu().numpy())

                if return_latents:
                    return samples, samples_z
                return samples
|
|
|
|
|
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    """Assemble conditional and unconditional batches for the conditioner.

    Args:
        keys: embedder input keys to populate.
        value_dict: values from ``init_embedder_options``. Unknown keys are
            copied into the batch unchanged.
        N: batch dimensions. Text is tiled to nested lists of shape ``N`` and
            numeric conditioning to tensors of shape ``(*N, k)``.
        device: device for the created tensors.

    Returns:
        ``(batch, batch_uc)``. ``batch_uc`` gets the negative prompt and the
        negative aesthetic score. Every other tensor entry of ``batch`` is
        cloned into it; non-tensor entries are not.
    """

    def tile_text(text):
        count = math.prod(N)
        return np.repeat([text], repeats=count).reshape(N).tolist()

    def tile_values(values):
        return torch.tensor(values).to(device).repeat(*N, 1)

    cond, uncond = {}, {}
    for key in keys:
        if key == "txt":
            cond["txt"] = tile_text(value_dict["prompt"])
            uncond["txt"] = tile_text(value_dict["negative_prompt"])
        elif key == "original_size_as_tuple":
            cond["original_size_as_tuple"] = tile_values(
                [value_dict["orig_height"], value_dict["orig_width"]]
            )
        elif key == "crop_coords_top_left":
            cond["crop_coords_top_left"] = tile_values(
                [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
            )
        elif key == "aesthetic_score":
            cond["aesthetic_score"] = tile_values([value_dict["aesthetic_score"]])
            uncond["aesthetic_score"] = tile_values(
                [value_dict["negative_aesthetic_score"]]
            )
        elif key == "target_size_as_tuple":
            cond["target_size_as_tuple"] = tile_values(
                [value_dict["target_height"], value_dict["target_width"]]
            )
        else:
            cond[key] = value_dict[key]

    for key, value in cond.items():
        if isinstance(value, torch.Tensor) and key not in uncond:
            uncond[key] = torch.clone(value)
    return cond, uncond
|
|
|
|
|
@torch.no_grad()
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings=None,
    additional_kwargs=None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    add_noise=True,
):
    """Run image-to-image sampling and show the watermarked grid in Streamlit.

    Args:
        img: input image batch, or latents if ``skip_encode`` is True.
        model: diffusion model exposing ``conditioner``, ``denoiser``,
            ``model``, ``first_stage_model``, ``ema_scope``,
            ``encode_first_stage`` and ``decode_first_stage``.
        sampler: sampler with ``discretization`` and ``num_steps`` attributes,
            callable as ``sampler(denoiser, x, cond=..., uc=...)``.
        value_dict: conditioning values passed to ``get_batch``.
        num_samples: batch size.
        force_uc_zero_embeddings: forwarded to
            ``get_unconditional_conditioning``; defaults to [].
        additional_kwargs: extra entries written into both ``c`` and ``uc``;
            defaults to {}.
        offset_noise_level: if > 0, adds a per-sample constant offset of this
            scale to the noise.
        return_latents: also return the sampled latents.
        skip_encode: treat ``img`` as latents and skip the first-stage encoder.
        filter: optional callable applied to the decoded samples.
        add_noise: noise the latents at the first sigma before sampling.

    Returns:
        ``samples`` (decoded, clamped to [0, 1] before ``filter``), or
        ``(samples, samples_z)`` when ``return_latents`` is True.
    """
    # None defaults avoid sharing one mutable object across calls.
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if additional_kwargs is None:
        additional_kwargs = {}

    st.text("Sampling")

    outputs = st.empty()
    precision_scope = autocast
    # Gradients are already disabled by the @torch.no_grad() decorator.
    with precision_scope("cuda"):
        with model.ema_scope():
            load_model(model.conditioner)
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [num_samples],
            )
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )
            unload_model(model.conditioner)
            for k in c:
                c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))

            for k in additional_kwargs:
                c[k] = uc[k] = additional_kwargs[k]
            if skip_encode:
                z = img
            else:
                load_model(model.first_stage_model)
                z = model.encode_first_stage(img)
                unload_model(model.first_stage_model)

            noise = torch.randn_like(z)

            sigmas = sampler.discretization(sampler.num_steps).cuda()
            sigma = sigmas[0]

            st.info(f"all sigmas: {sigmas}")
            st.info(f"noising sigma: {sigma}")
            if offset_noise_level > 0.0:
                noise = noise + offset_noise_level * append_dims(
                    torch.randn(z.shape[0], device=z.device), z.ndim
                )
            # Both branches divide by sqrt(1 + sigma_0^2).
            # NOTE(review): presumably the input scaling the sampler expects
            # at its first step — confirm against the sampler implementation.
            input_scale = torch.sqrt(1.0 + sigma**2.0)
            if add_noise:
                noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
                noised_z = noised_z / input_scale
            else:
                noised_z = z / input_scale

            def denoiser(x, sigma, c):
                return model.denoiser(model.model, x, sigma, c)

            load_model(model.denoiser)
            load_model(model.model)
            samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
            unload_model(model.model)
            unload_model(model.denoiser)

            load_model(model.first_stage_model)
            samples_x = model.decode_first_stage(samples_z)
            unload_model(model.first_stage_model)
            # Map decoder output from [-1, 1] to [0, 1].
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            if filter is not None:
                samples = filter(samples)

            grid = embed_watemark(torch.stack([samples]))
            grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
            outputs.image(grid.cpu().numpy())
            if return_latents:
                return samples, samples_z
            return samples
|
|