import math
import os
from typing import List, Union

import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid

from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.diffusionmodules.sampling import (
    DPMPP2MSampler,
    DPMPP2SAncestralSampler,
    EulerAncestralSampler,
    EulerEDMSampler,
    HeunEDMSampler,
    LinearMultistepSampler,
)
from sgm.util import append_dims, instantiate_from_config


class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(self.watermark)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor):
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, C, H, W) in range [0, 1]

        Returns:
            same as input but watermarked
        """
        # The watermarking library expects input in cv2 BGR format
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # torch (n, b, c, h, w) in [0, 1] -> numpy ((n b), h, w, c) in [0, 255], RGB -> BGR
        image_np = rearrange(
            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
        ).numpy()[:, :, :, ::-1]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # Flip BGR back to RGB and restore the (n, b, c, h, w) layout
        image = torch.from_numpy(
            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
        ).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        return image


# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)


@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
    state = dict()
    if "model" not in state:
        config = version_dict["config"]
        ckpt = version_dict["ckpt"]

        config = OmegaConf.load(config)
        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)

        state["msg"] = msg
        state["model"] = model
        state["ckpt"] = ckpt if load_ckpt else None
        state["config"] = config
        if load_filter:
            state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state


def load_model(model):
    model.cuda()


lowvram_mode = False


def set_lowvram_mode(mode):
    global lowvram_mode
    lowvram_mode = mode


def initial_model_load(model):
    global lowvram_mode
    if lowvram_mode:
        model.model.half()
    else:
        model.cuda()
    return model


def unload_model(model):
    global lowvram_mode
    if lowvram_mode:
        model.cpu()
        torch.cuda.empty_cache()


def load_model_from_config(config, ckpt=None, verbose=True):
    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                global_step = pl_sd["global_step"]
                st.info(f"loaded ckpt from global step {global_step}")
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
        else:
            raise NotImplementedError

        msg = None

        m, u = model.load_state_dict(sd, strict=False)

        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    else:
        msg = None

    model = initial_model_load(model)
    model.eval()
    return model, msg


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    # Hardcoded demo settings; might undergo some changes in the future

    value_dict = {}
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")

            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = negative_prompt

        if key == "original_size_as_tuple":
            orig_width = st.number_input(
                "orig_width",
                value=init_dict["orig_width"],
                min_value=16,
            )
            orig_height = st.number_input(
                "orig_height",
                value=init_dict["orig_height"],
                min_value=16,
            )

            value_dict["orig_width"] = orig_width
            value_dict["orig_height"] = orig_height

        if key == "crop_coords_top_left":
            crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
            crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)

            value_dict["crop_coords_top"] = crop_coord_top
            value_dict["crop_coords_left"] = crop_coord_left

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict


def perform_save_locally(save_path, samples):
    os.makedirs(save_path, exist_ok=True)
    base_count = len(os.listdir(save_path))
    samples = embed_watemark(samples)
    for sample in samples:
        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(sample.astype(np.uint8)).save(
            os.path.join(save_path, f"{base_count:09}.png")
        )
        base_count += 1


def init_save_locally(_dir, init_value: bool = False):
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    if save_locally:
        save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
    else:
        save_path = None

    return save_locally, save_path


class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas.

    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large and then decrease
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        # Compute the prune index before slicing so the logged value is the one used
        prune_index = max(int(self.strength * len(sigmas)), 1)
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas


class Txt2NoisyDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas.

    params:
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large and then decrease
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        if self.original_steps is None:
            steps = len(sigmas)
        else:
            steps = self.original_steps + 1
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
        sigmas = sigmas[prune_index:]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas


def get_guider(key):
    guider = st.sidebar.selectbox(
        f"Guider #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
        ],
    )

    if guider == "IdentityGuider":
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif guider == "VanillaCFG":
        scale = st.number_input(
            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
        )

        thresholder = st.sidebar.selectbox(
            f"Thresholder #{key}",
            [
                "None",
            ],
        )

        if thresholder == "None":
            dyn_thresh_config = {
                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
            }
        else:
            raise NotImplementedError

        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
        }
    else:
        raise NotImplementedError
    return guider_config


def init_sampling(
    key=1,
    img2img_strength=1.0,
    specify_num_samples=True,
    stage2strength=None,
):
    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )

    steps = st.sidebar.number_input(
        f"steps #{key}", value=40, min_value=1, max_value=1000
    )
    sampler = st.sidebar.selectbox(
        f"Sampler #{key}",
        [
            "EulerEDMSampler",
            "HeunEDMSampler",
            "EulerAncestralSampler",
            "DPMPP2SAncestralSampler",
            "DPMPP2MSampler",
            "LinearMultistepSampler",
        ],
        0,
    )
    discretization = st.sidebar.selectbox(
        f"Discretization #{key}",
        [
            "LegacyDDPMDiscretization",
            "EDMDiscretization",
        ],
    )

    discretization_config = get_discretization(discretization, key=key)

    guider_config = get_guider(key=key)

    sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
    if img2img_strength < 1.0:
        st.warning(
            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
        )
        sampler.discretization = Img2ImgDiscretizationWrapper(
            sampler.discretization, strength=img2img_strength
        )
    if stage2strength is not None:
        sampler.discretization = Txt2NoisyDiscretizationWrapper(
            sampler.discretization, strength=stage2strength, original_steps=steps
        )
    return sampler, num_rows, num_cols


def get_discretization(discretization, key=1):
    if discretization == "LegacyDDPMDiscretization":
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    elif discretization == "EDMDiscretization":
        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
        rho = st.number_input(f"rho #{key}", value=3.0)
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": sigma_min,
                "sigma_max": sigma_max,
                "rho": rho,
            },
        }
    else:
        raise NotImplementedError(f"unknown discretization {discretization}")

    return discretization_config


def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
    if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)

        if sampler_name == "EulerEDMSampler":
            sampler = EulerEDMSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                s_churn=s_churn,
                s_tmin=s_tmin,
                s_tmax=s_tmax,
                s_noise=s_noise,
                verbose=True,
            )
        elif sampler_name == "HeunEDMSampler":
            sampler = HeunEDMSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                s_churn=s_churn,
                s_tmin=s_tmin,
                s_tmax=s_tmax,
                s_noise=s_noise,
                verbose=True,
            )
    elif (
        sampler_name == "EulerAncestralSampler"
        or sampler_name == "DPMPP2SAncestralSampler"
    ):
        # The #{key} suffix keeps widget IDs unique when several samplers are configured
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
        eta = st.sidebar.number_input(f"eta #{key}", value=1.0, min_value=0.0)

        if sampler_name == "EulerAncestralSampler":
            sampler = EulerAncestralSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                eta=eta,
                s_noise=s_noise,
                verbose=True,
            )
        elif sampler_name == "DPMPP2SAncestralSampler":
            sampler = DPMPP2SAncestralSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                eta=eta,
                s_noise=s_noise,
                verbose=True,
            )
    elif sampler_name == "DPMPP2MSampler":
        sampler = DPMPP2MSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            verbose=True,
        )
    elif sampler_name == "LinearMultistepSampler":
        order = st.sidebar.number_input(f"order #{key}", value=4, min_value=1)
        sampler = LinearMultistepSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            order=order,
            verbose=True,
        )
    else:
        raise ValueError(f"unknown sampler {sampler_name}!")

    return sampler


def get_interactive_image(key=None) -> Image.Image:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if image is not None:
        image = Image.open(image)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        return image


def load_img(display=True, key=None):
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")

    # PIL image -> float tensor in [0, 1] -> rescaled to [-1, 1]
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 2.0 - 1.0),
        ]
    )
    img = transform(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img


def get_init_img(batch_size=1, key=None):
    init_image = load_img(key=key).cuda()
    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    return init_image


def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: List = None,
    batch2model_input: List = None,
    return_latents=False,
    filter=None,
):
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []

    st.text("Sampling")

    outputs = st.empty()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                num_samples = [num_samples]
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                )
                for key in batch:
                    if isinstance(batch[key], torch.Tensor):
                        print(key, batch[key].shape)
                    elif isinstance(batch[key], list):
                        print(key, [len(l) for l in batch[key]])
                    else:
                        print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )
                unload_model(model.conditioner)

                for k in c:
                    if k != "crossattn":
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                        )

                additional_model_inputs = {}
                for k in batch2model_input:
                    additional_model_inputs[k] = batch[k]

                shape = (math.prod(num_samples), C, H // F, W // F)
                randn = torch.randn(shape).to("cuda")

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                unload_model(model.first_stage_model)

                if filter is not None:
                    samples = filter(samples)

                grid = torch.stack([samples])
                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                outputs.image(grid.cpu().numpy())

                if return_latents:
                    return samples, samples_z
                return samples


def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    # Hardcoded demo setups; might undergo some changes in the future

    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor(
                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
                )
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        else:
            batch[key] = value_dict[key]

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


@torch.no_grad()
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings=None,
    additional_kwargs=None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    add_noise=True,
):
    # Avoid mutable default arguments
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if additional_kwargs is None:
        additional_kwargs = {}

    st.text("Sampling")

    outputs = st.empty()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [num_samples],
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )
                unload_model(model.conditioner)
                for k in c:
                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))

                for k in additional_kwargs:
                    c[k] = uc[k] = additional_kwargs[k]
                if skip_encode:
                    z = img
                else:
                    load_model(model.first_stage_model)
                    z = model.encode_first_stage(img)
                    unload_model(model.first_stage_model)

                noise = torch.randn_like(z)

                sigmas = sampler.discretization(sampler.num_steps).cuda()
                sigma = sigmas[0]

                st.info(f"all sigmas: {sigmas}")
                st.info(f"noising sigma: {sigma}")
                if offset_noise_level > 0.0:
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                # Note: hardcoded to DDPM-like scaling; needs to be generalized later.
                if add_noise:
                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
                    noised_z = noised_z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
                else:
                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                samples_x = model.decode_first_stage(samples_z)
                unload_model(model.first_stage_model)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                grid = embed_watemark(torch.stack([samples]))
                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                outputs.image(grid.cpu().numpy())
                if return_latents:
                    return samples, samples_z
                return samples
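

# --- Illustrative sketch (hypothetical helpers, not used by the demo above) ---
# WATERMARK_BITS is built with bin(x)[2:], which drops leading zeros. It yields
# exactly 48 bits here only because WATERMARK_MESSAGE happens to start with a 1.
# The two helpers below are a minimal sketch that assumes the same
# most-significant-bit-first order as WATERMARK_BITS:
# - message_to_bits zero-pads to a fixed width.
# - bits_to_message folds a 0/1 list back into an integer, e.g. to compare bits
#   recovered from a watermarked image against WATERMARK_MESSAGE.
# The names message_to_bits and bits_to_message are hypothetical.


def message_to_bits(message: int, num_bits: int = 48) -> list:
    # format(..., "048b") zero-pads on the left, unlike bin(message)[2:]
    return [int(bit) for bit in format(message, f"0{num_bits}b")]


def bits_to_message(bits) -> int:
    # Shift in one bit at a time, most significant bit first
    value = 0
    for bit in bits:
        value = (value << 1) | int(bit)
    return value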
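

# --- Illustrative sketch (hypothetical helpers, not used by the demo above) ---
# Img2ImgDiscretizationWrapper and Txt2NoisyDiscretizationWrapper keep
# different ends of the (decreasing) sigma schedule. This sketch reproduces
# their index arithmetic on plain Python lists, so the effect of `strength`
# can be checked without a model. It assumes, as the wrappers' own comments
# state, that the wrapped discretization returns sigmas largest-first.
# Reversing the list twice mirrors the two torch.flip calls.
# The helper names are hypothetical.


def img2img_kept_sigmas(sigmas: list, strength: float) -> list:
    # Keep the last max(int(strength * len), 1) sigmas, so sampling starts
    # from an intermediate noise level instead of the largest one
    keep = max(int(strength * len(sigmas)), 1)
    return sigmas[::-1][:keep][::-1]


def txt2noisy_kept_sigmas(sigmas: list, strength: float, original_steps=None) -> list:
    # Drop the `prune_index` smallest sigmas, so sampling stops early at a
    # nonzero noise level (presumably for the stage configured via stage2strength)
    steps = len(sigmas) if original_steps is None else original_steps + 1
    prune_index = max(min(int(strength * steps) - 1, steps - 1), 0)
    return sigmas[::-1][prune_index:][::-1]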
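

# --- Illustrative sketch (hypothetical helper, not used by the demo above) ---
# do_img2img noises the encoded latent with the largest remaining sigma and
# then divides by sqrt(1 + sigma^2). The code itself notes that this scaling
# is hardcoded to be DDPM-like.
# This NumPy version is a minimal sketch of that step, assuming a
# (B, C, H, W) latent and a NumPy Generator `rng`. The optional offset noise
# adds one extra Gaussian scalar per batch element, broadcast over C, H and W,
# as append_dims does in the torch code.
# numpy is already imported at the top of this module; the import is repeated
# so the sketch stands alone. The name img2img_noised_latent is hypothetical.
import numpy as np


def img2img_noised_latent(z, sigma, rng, offset_noise_level=0.0, add_noise=True):
    noise = rng.standard_normal(z.shape)
    if offset_noise_level > 0.0:
        # Shape (B, 1, 1, 1) so it broadcasts against the full latent
        per_sample = rng.standard_normal((z.shape[0],) + (1,) * (z.ndim - 1))
        noise = noise + offset_noise_level * per_sample
    noised = z + noise * sigma if add_noise else z
    return noised / np.sqrt(1.0 + sigma**2)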