import argparse
import gc
import os
import random
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional, Tuple, Union

import gradio as gr
import numpy as np
import pyrallis
import torch
from gradio.components import Image, Textbox
from torchvision.utils import _log_api_usage_once, make_grid, save_image

warnings.filterwarnings("ignore")

from asset.examples import examples
from diffusion import DPMS, FlowEuler, SASolverSampler
from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig
from diffusion.utils.dist_utils import flush
from tools.download import find_model

MAX_SEED = np.iinfo(np.int32).max
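# Upper bound for user-facing seeds; randomize_seed_fn draws uniformly from [0, MAX_SEED].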


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="config path")
    return parser.parse_known_args()[0]


@dataclass
class SanaInference(SanaConfig):
    config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml"
    model_path: str = field(
        default="output/Sana_1600M/SANA.pth", metadata={"help": "Path to the model file (positional)"}
    )
    output: str = "./output"
    bs: int = 1
    image_size: int = 1024
    cfg_scale: float = 5.0
    pag_scale: float = 2.0
    seed: int = 42
    step: int = -1
    port: int = 7788
    custom_image_size: Optional[int] = None
    shield_model_path: str = field(
        default="google/shieldgemma-2b",
        metadata={"help": "Path to the shield model; ShieldGemma-2B is used by default."},
    )
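
# Example launch (illustrative invocation; script name and paths depend on your checkout):
#   python app_sana.py --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
#       --model_path=output/Sana_1600M/SANA.pth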


@torch.no_grad()
def ndarr_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    **kwargs,
) -> np.ndarray:
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(save_image)
    grid = make_grid(tensor, **kwargs)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return ndarr
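
# ndarr_image follows torchvision.utils.save_image's grid/uint8 conversion but returns the numpy
# array (for the Gradio Image(type="numpy") output) instead of writing a file.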


def set_env(seed=0):
    torch.manual_seed(seed)
    torch.set_grad_enabled(False)
    # Advance the RNG state with a fixed number of dummy draws after seeding.
    for _ in range(30):
        torch.randn(1, 4, args.image_size, args.image_size)


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
    """Returns binned height and width."""
    ar = float(height / width)
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
    default_hw = ratios[closest_ratio]
    return int(default_hw[0]), int(default_hw[1])
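
# Note: the requested height/width is snapped to the nearest predefined aspect-ratio bucket here;
# generate_img later resizes/crops the decoded image back to the originally requested size.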


@torch.inference_mode()
def generate_img(
    prompt,
    sampler,
    sample_steps,
    scale,
    pag_scale=1.0,
    guidance_type="classifier-free",
    seed=0,
    randomize_seed=False,
    base_size=1024,
    height=1024,
    width=1024,
):
    flush()
    gc.collect()
    torch.cuda.empty_cache()

    seed = int(randomize_seed_fn(seed, randomize_seed))
    set_env(seed)
    base_ratios = eval(f"ASPECT_RATIO_{base_size}_TEST")

    # Log every tested prompt (and its seed) to a dated text file.
    os.makedirs("output/demo/online_demo_prompts/", exist_ok=True)
    save_prompt_path = f"output/demo/online_demo_prompts/tested_prompts{datetime.now().date()}.txt"
    with open(save_prompt_path, "a") as f:
        f.write(f"{seed}: {prompt}" + "\n")
    print(f"{seed}: {prompt}")
    prompt_clean, prompt_show, _, _, _ = prepare_prompt_ar(prompt, base_ratios, device=device)
    orig_height, orig_width = height, width
    height, width = classify_height_width_bin(height, width, ratios=base_ratios)

    prompt_show += (
        f"\n Sample steps: {sample_steps}, CFG Scale: {scale}, PAG Scale: {pag_scale}, flow_shift: {flow_shift}"
    )
    prompt_clean = prompt_clean.strip()
    if isinstance(prompt_clean, str):
        prompts = [prompt_clean]

    # Optionally prepend the instruction prefix (chi_prompt) from the config to every prompt.
    if not config.text_encoder.chi_prompt:
        max_length_all = max_sequence_length
        prompts_all = prompts
    else:
        chi_prompt = "\n".join(config.text_encoder.chi_prompt)
        prompts_all = [chi_prompt + prompt for prompt in prompts]
        num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
        max_length_all = num_chi_prompt_tokens + max_sequence_length - 2

    caption_token = tokenizer(
        prompts_all, max_length=max_length_all, padding="max_length", truncation=True, return_tensors="pt"
    ).to(device)
    # Keep the first token plus the last (max_sequence_length - 1) positions, i.e. the user prompt after the prefix.
    select_index = [0] + list(range(-max_sequence_length + 1, 0))
    caption_embs = text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][:, :, select_index]
    emb_masks = caption_token.attention_mask[:, select_index]
    null_y = null_caption_embs.repeat(len(prompts), 1, 1)[:, None]

    n = len(prompts)
    latent_size_h, latent_size_w = height // config.vae.vae_downsample_rate, width // config.vae.vae_downsample_rate
    z = torch.randn(n, config.vae.vae_latent_dim, latent_size_h, latent_size_w, device=device)
    model_kwargs = dict(data_info={"img_hw": (latent_size_h, latent_size_w), "aspect_ratio": 1.0}, mask=emb_masks)
    print(f"Latent Size: {z.shape}")
    if sampler == "dpm-solver":
        dpm_solver = DPMS(
            model.forward_with_dpmsolver,
            condition=caption_embs,
            uncondition=null_y,
            cfg_scale=scale,
            model_kwargs=model_kwargs,
        )
        samples = dpm_solver.sample(
            z,
            steps=sample_steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
        )
    elif sampler == "sa-solver":
        sa_solver = SASolverSampler(model.forward_with_dpmsolver, device=device)
        samples = sa_solver.sample(
            S=sample_steps,
            batch_size=n,
            shape=(4, latent_size_h, latent_size_w),
            eta=1,
            conditioning=caption_embs,
            unconditional_conditioning=null_y,
            unconditional_guidance_scale=scale,
            model_kwargs=model_kwargs,
        )[0]
    elif sampler == "flow_euler":
        flow_solver = FlowEuler(
            model, condition=caption_embs, uncondition=null_y, cfg_scale=scale, model_kwargs=model_kwargs
        )
        samples = flow_solver.sample(
            z,
            steps=sample_steps,
        )
    elif sampler == "flow_dpm-solver":
        # PAG-style guidance is kept only when pag_scale > 1 and the model uses linear attention;
        # otherwise fall back to plain classifier-free guidance.
        if not (pag_scale > 1.0 and config.model.attn_type == "linear"):
            guidance_type = "classifier-free"
        dpm_solver = DPMS(
            model,
            condition=caption_embs,
            uncondition=null_y,
            guidance_type=guidance_type,
            cfg_scale=scale,
            pag_scale=pag_scale,
            pag_applied_layers=pag_applied_layers,
            model_type="flow",
            model_kwargs=model_kwargs,
            schedule="FLOW",
        )
        samples = dpm_solver.sample(
            z,
            steps=sample_steps,
            order=2,
            skip_type="time_uniform_flow",
            method="multistep",
            flow_shift=flow_shift,
        )
    else:
        raise ValueError(f"{sampler} is not defined")
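
    # Decode the latents to pixel space and restore the user-requested resolution.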
    samples = samples.to(weight_dtype)
    samples = vae_decode(config.vae.vae_type, vae, samples)
    samples = resize_and_crop_tensor(samples, orig_width, orig_height)
    display_model_info = (
        f"Model path: {args.model_path},\nBase image size: {args.image_size}, \nSampling Algo: {sampler}"
    )
    return ndarr_image(samples, normalize=True, value_range=(-1, 1)), prompt_show, display_model_info, seed
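
# The four return values feed the Gradio outputs in order: image (numpy), cleaned/annotated prompt,
# a short model-info string, and the seed that was actually used.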


if __name__ == "__main__":
    from diffusion.utils.logger import get_root_logger

    args = get_args()
    # `config` and `args` alias the same parsed SanaInference instance (YAML config plus any CLI overrides).
    config = args = pyrallis.parse(config_class=SanaInference, config_path=args.config)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger = get_root_logger()

    args.image_size = config.model.image_size
    assert args.image_size in [
        256,
        512,
        1024,
        2048,
        4096,
    ], "We only provide pre-trained models for 256x256, 512x512, 1024x1024, 2048x2048 and 4096x4096 resolutions."

    latent_size = config.model.image_size // config.vae.vae_downsample_rate
    max_sequence_length = config.text_encoder.model_max_length
    pe_interpolation = config.model.pe_interpolation
    micro_condition = config.model.micro_condition
    pag_applied_layers = config.model.pag_applied_layers
    flow_shift = config.scheduler.flow_shift

    if config.model.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif config.model.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    elif config.model.mixed_precision == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(f"weight precision {config.model.mixed_precision} is not defined")
    logger.info(f"Inference with {weight_dtype}")

    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
    tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)

    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
    model_kwargs = {
        "input_size": latent_size,
        "pe_interpolation": config.model.pe_interpolation,
        "config": config,
        "model_max_length": config.text_encoder.model_max_length,
        "qk_norm": config.model.qk_norm,
        "micro_condition": config.model.micro_condition,
        "caption_channels": text_encoder.config.hidden_size,
        "y_norm": config.text_encoder.y_norm,
        "attn_type": config.model.attn_type,
        "ffn_type": config.model.ffn_type,
        "mlp_ratio": config.model.mlp_ratio,
        "mlp_acts": list(config.model.mlp_acts),
        "in_channels": config.vae.vae_latent_dim,
        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
        "use_pe": config.model.use_pe,
        "linear_head_dim": config.model.linear_head_dim,
        "pred_sigma": pred_sigma,
        "learn_sigma": learn_sigma,
    }
    model = build_model(
        config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
    ).to(device)

    logger.info(
        f"{model.__class__.__name__}:{config.model.model}, Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
    )
    logger.info("Generating sample from ckpt: %s" % args.model_path)
    state_dict = find_model(args.model_path)
    # If the checkpoint stores a positional embedding, drop it and keep the model's own (loaded non-strictly below).
    if "pos_embed" in state_dict["state_dict"]:
        del state_dict["state_dict"]["pos_embed"]

    missing, unexpected = model.load_state_dict(state_dict["state_dict"], strict=False)
    logger.warning(f"Missing keys: {missing}")
    logger.warning(f"Unexpected keys: {unexpected}")
    model.eval().to(weight_dtype)
    base_ratios = eval(f"ASPECT_RATIO_{args.image_size}_TEST")

    null_caption_token = tokenizer(
        "", max_length=max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
    ).to(device)
    null_caption_embs = text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0]
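    # Embedding of the empty prompt, reused per request as the unconditional branch (null_y) for classifier-free guidance.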

    # Infer the displayed model size from the checkpoint filename.
    model_size = "1.6" if "D20" in args.model_path else "0.6"
    title = f"""
        <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
            <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
        </div>
    """
    DESCRIPTION = f"""
            <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
            <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
            <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github(coming soon)]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
            <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
    """
    if model_size == "0.6":
        DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
    if not torch.cuda.is_available():
        DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

    demo = gr.Interface(
        fn=generate_img,
        inputs=[
            Textbox(
                label="Note: If you want to specify an aspect ratio or a custom height and width, "
                "use --ar h:w (or --aspect_ratio h:w) or --hw h:w. If neither is given, the default settings are used.",
                placeholder="Please enter your prompt. \n",
            ),
            gr.Radio(
                choices=["dpm-solver", "sa-solver", "flow_dpm-solver", "flow_euler"],
                label="Sampler",
                interactive=True,
                value="flow_dpm-solver",
            ),
            gr.Slider(label="Sample Steps", minimum=1, maximum=100, value=20, step=1),
            gr.Slider(label="Guidance Scale", minimum=1.0, maximum=30.0, value=5.0, step=0.1),
            gr.Slider(label="PAG Scale", minimum=1.0, maximum=10.0, value=2.5, step=0.5),
            gr.Radio(
                choices=["classifier-free", "classifier-free_PAG", "classifier-free_PAG_seq"],
                label="Guidance Type",
                interactive=True,
                value="classifier-free_PAG_seq",
            ),
            gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            ),
            gr.Checkbox(label="Randomize seed", value=True),
            gr.Radio(
                choices=[256, 512, 1024, 2048, 4096],
                label="Base Size",
                interactive=True,
                value=args.image_size,
            ),
            gr.Slider(
                label="Height",
                minimum=256,
                maximum=6000,
                step=32,
                value=args.image_size,
            ),
            gr.Slider(
                label="Width",
                minimum=256,
                maximum=6000,
                step=32,
                value=args.image_size,
            ),
        ],
        outputs=[
            Image(type="numpy", label="Img"),
            Textbox(label="clean prompt"),
            Textbox(label="model info"),
            gr.Slider(label="seed"),
        ],
        title=title,
        description=DESCRIPTION,
        examples=examples,
    )
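
    # server_name="0.0.0.0" exposes the demo on all network interfaces at the configured port;
    # share=True additionally requests a temporary public Gradio share link.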
    demo.launch(server_name="0.0.0.0", server_port=args.port, debug=True, share=True)