Infinity / app.py
MohamedRashad's picture
Refactor app.py to improve UI layout and rename weight download function; update import path for AutoEncoder in vae.py
f7f1ca1
raw
history blame
20.1 kB
import subprocess
import sys
import os

# Install flash-attn at start-up. FLASH_ATTENTION_SKIP_CUDA_BUILD is read by
# flash-attn's setup script to skip compiling its CUDA extension locally.
# The current environment is extended rather than replaced, so PATH,
# virtualenv and CUDA variables stay visible to pip. pip is run through
# `sys.executable -m pip` so the package lands in the interpreter running this
# app. A failed install does not abort start-up (no check=True).
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation'],
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    check=False,
)
# Disable the HF tokenizers library's internal parallelism (presumably to
# avoid its fork-after-parallelism warning in worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import os.path as osp
import time
import hashlib
import argparse
import shutil
import re
import random
from pathlib import Path
from typing import List
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageEnhance
import PIL.Image as PImage
from torchvision.transforms.functional import to_tensor
from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
from huggingface_hub import hf_hub_download
import gradio as gr
import spaces
from models.infinity import Infinity
from models.basic import *
from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates

# Raise torch._dynamo's per-function recompilation cache limit.
torch._dynamo.config.cache_size_limit = 64
# Define a function to download weights if not present
def download_infinity_weights(weights_path):
    """Best-effort download of the Infinity-2B transformer and VAE checkpoints.

    Each file is fetched from the ``FoundationVision/Infinity`` Hub repository
    into *weights_path*, but only if it is not already present. A failed
    download is printed and otherwise ignored; loading will presumably fail
    later for the missing file. Each file is attempted independently, so one
    failure does not prevent the other file from being tried.

    Args:
        weights_path: target directory, as a ``pathlib.Path`` or ``str``.
    """
    weights_path = Path(weights_path)
    for filename in ('infinity_2b_reg.pth', 'infinity_vae_d32reg.pth'):
        if (weights_path / filename).exists():
            continue
        try:
            hf_hub_download(repo_id="FoundationVision/Infinity", filename=filename, local_dir=str(weights_path))
        except Exception as e:
            print(f"Error downloading weights: {e}")
def extract_key_val(text):
    """Parse every ``<key:value>`` tag in *text* into a dict.

    Keys are taken verbatim and values have their leading whitespace stripped.
    When a key repeats, the last occurrence wins.
    """
    return {key: value.lstrip() for key, value in re.findall(r'<(.+?):(.+?)>', text)}
def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
    """Encode one prompt into the packed text condition consumed by Infinity.

    The prompt is optionally augmented with ``aug_with_positive_prompt``. It is
    then tokenized to a fixed 512-token, padded batch of one, moved to CUDA, and
    run through *text_encoder*.

    Returns:
        ``(kv_compact, lens, cu_seqlens_k, Ltext)`` where
        ``kv_compact`` holds the float features of the non-padding tokens of
        every caption concatenated along dim 0, ``lens`` lists the non-padding
        token count per caption, ``cu_seqlens_k`` is the int32 running total of
        those counts with a leading 0, and ``Ltext`` is the largest count.
    """
    if enable_positive_prompt:
        print(f'before positive_prompt aug: {prompt}')
        prompt = aug_with_positive_prompt(prompt)
        print(f'after positive_prompt aug: {prompt}')
    print(f'prompt={prompt}')
    batch = text_tokenizer(text=[prompt], max_length=512, padding='max_length', truncation=True, return_tensors='pt')
    ids = batch.input_ids.cuda(non_blocking=True)
    attn_mask = batch.attention_mask.cuda(non_blocking=True)
    hidden = text_encoder(input_ids=ids, attention_mask=attn_mask)['last_hidden_state'].float()
    token_counts = attn_mask.sum(dim=-1)
    lens: List[int] = token_counts.tolist()
    # Offsets of each caption inside the packed sequence: [0, n0, n0+n1, ...].
    cu_seqlens_k = F.pad(token_counts.to(dtype=torch.int32).cumsum_(0), (1, 0))
    # Drop padding positions and pack all captions into one (sum(lens), C) tensor.
    kv_compact = torch.cat([feat[:n] for n, feat in zip(lens, hidden.unbind(0))], dim=0)
    return kv_compact, lens, cu_seqlens_k, max(lens)
# Person-describing nouns that trigger the positive-prompt augmentation.
_PERSON_KEYWORDS = (
    'man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
    'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter',
)
# Whole-word, case-insensitive match that also accepts the plural suffixes
# "s" / "ren" (girls, workers, children). A plain substring test would also fire
# inside unrelated words such as "Germany", "many", "reason" or "personal".
_PERSON_KEYWORD_RE = re.compile(
    r'\b(?:' + '|'.join(map(re.escape, _PERSON_KEYWORDS)) + r')(?:s|ren)?\b',
    re.IGNORECASE,
)
_POSITIVE_FACE_SUFFIX = '. very smooth faces, good looking faces, face to the camera, perfect facial features'


def aug_with_positive_prompt(prompt):
    """Append a face-quality suffix to *prompt* when it mentions a person.

    Args:
        prompt: the user prompt.

    Returns:
        ``prompt`` with the suffix appended once if any person keyword appears
        as a whole word (case-insensitive, singular or plural); otherwise
        ``prompt`` unchanged.
    """
    if _PERSON_KEYWORD_RE.search(prompt):
        prompt = prompt + _POSITIVE_FACE_SUFFIX
    return prompt
def enhance_image(image, contrast=1.05, saturation=1.05):
    """Return a copy of *image* with contrast boosted, then saturation boosted.

    ``ImageEnhance`` enhancers return new images, so the input is not modified
    and no explicit copies are needed.

    Args:
        image: a ``PIL.Image``.
        contrast: contrast factor (1.0 leaves contrast unchanged).
        saturation: color/saturation factor (1.0 leaves color unchanged).

    Returns:
        The enhanced ``PIL.Image``.
    """
    contrasted = ImageEnhance.Contrast(image).enhance(contrast)  # enhance contrast
    return ImageEnhance.Color(contrasted).enhance(saturation)  # enhance saturation
def _expand_per_scale(value, scale_schedule, name):
    """Return *value* as a per-scale list: sequences are copied, scalars repeated once per scale.

    ``None`` yields an empty list, matching the former ``[]`` default.

    Raises:
        ValueError: if *value* is a scalar but *scale_schedule* is ``None``.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    if scale_schedule is None:
        raise ValueError(f'{name} is a scalar, so scale_schedule is needed to expand it per scale')
    return [value] * len(scale_schedule)


def gen_one_img(
    infinity_test,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=None,
    tau_list=None,
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
):
    """Generate one image for *prompt* with Infinity's autoregressive CFG sampler.

    Args:
        infinity_test: the loaded Infinity model.
        vae: visual tokenizer passed through to the sampler.
        text_tokenizer, text_encoder: used by ``encode_prompt`` for the prompt
            and the optional negative prompt.
        prompt: text prompt.
        cfg_list: guidance scale per scale, or one scalar applied to every scale.
        tau_list: sampling temperature per scale, or one scalar applied to every scale.
        negative_prompt: optional negative prompt; an empty string disables it.
        scale_schedule: per-scale entries (e.g. ``(1, h, w)``); required when
            ``cfg_list`` or ``tau_list`` is a scalar.
        enable_positive_prompt: if true, the prompt (not the negative prompt) is
            augmented by ``encode_prompt``.
        Remaining keyword arguments are forwarded unchanged to
        ``infinity_test.autoregressive_infer_cfg``.

    Returns:
        The first entry of the image list returned by the sampler (batch size 1).

    Raises:
        ValueError: if a scalar ``cfg_list``/``tau_list`` is given without
        ``scale_schedule``.
    """
    sstt = time.time()
    cfg_list = _expand_per_scale(cfg_list, scale_schedule, 'cfg_list')
    tau_list = _expand_per_scale(tau_list, scale_schedule, 'tau_list')
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None
    print(f'cfg: {cfg_list}, tau: {tau_list}')
    # torch.autocast('cuda', ...) is the non-deprecated form of torch.cuda.amp.autocast(...).
    with torch.autocast('cuda', enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        stt = time.time()
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
            ret_img=True, trunk_scale=1000,
            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
            sampling_per_bits=sampling_per_bits,
        )
    print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
    return img_list[0]
def get_prompt_id(prompt):
    """Return a deterministic identifier for *prompt*: the hex MD5 digest of its UTF-8 bytes."""
    return hashlib.md5(prompt.encode('utf-8')).hexdigest()
def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
    """Extract the model weights from a full training checkpoint and save them alone.

    Args:
        infinity_model_path: path of the full checkpoint (loaded with ``torch.load``).
        save_file: destination; when falsy, ``<checkpoint stem>-slim.pth`` next to
            the source is used.
        device: ``map_location`` for loading.
        key: entry under ``checkpoint['trainer']`` that holds the weights to keep.

    Returns:
        The path the slim checkpoint was written to.
    """
    print('[Save slim model]')
    checkpoint = torch.load(infinity_model_path, map_location=device)
    slim_state = checkpoint['trainer'][key]
    target = save_file or osp.splitext(infinity_model_path)[0] + '-slim.pth'
    print(f'Save to {target}')
    torch.save(slim_state, target)
    print('[Save slim model] done')
    return target
def load_tokenizer(t5_path=''):
    """Load the T5 tokenizer and a frozen fp16 T5 encoder from *t5_path*.

    The tokenizer's ``model_max_length`` is set to 512. The encoder is moved to
    CUDA, put in eval mode, and has gradients disabled.

    Args:
        t5_path: Hub model id or local directory accepted by ``from_pretrained``.

    Returns:
        ``(text_tokenizer, text_encoder)``.
    """
    print('[Loading tokenizer and text encoder]')
    tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
    tokenizer.model_max_length = 512
    encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
    encoder.to('cuda').eval().requires_grad_(False)
    return tokenizer, encoder
def load_infinity(
    rope2d_each_sa_layer,
    rope2d_normalized_by_hw,
    use_scale_schedule_embedding,
    pn,
    use_bit_label,
    add_lvl_embeding_only_first_block,
    model_path='',
    scale_schedule=None,
    vae=None,
    device='cuda',
    model_kwargs=None,
    text_channels=2048,
    apply_spatial_patchify=0,
    use_flex_attn=False,
    bf16=False,
):
    """Build an Infinity transformer, load its weights, and return it frozen in eval mode.

    Construction and weight loading run under bf16 autocast and ``torch.no_grad()``.

    Args:
        rope2d_each_sa_layer, rope2d_normalized_by_hw, pn, use_bit_label,
        add_lvl_embeding_only_first_block, apply_spatial_patchify, use_flex_attn:
            forwarded unchanged to the ``Infinity`` constructor.
        use_scale_schedule_embedding: accepted for interface compatibility; unused here.
        model_path: state-dict file loaded with ``torch.load``.
        scale_schedule: forwarded as ``raw_scale_schedule``.
        vae: visual tokenizer, forwarded as ``vae_local``.
        device: device for the model, the loaded tensors, and the attached ``rng``.
        model_kwargs: extra architecture kwargs (depth, embed_dim, ...); ``None`` means none.
        text_channels: forwarded as ``text_channels`` (presumably the text-feature width).
        bf16: if true, cast every block in ``unregistered_blocks`` to bfloat16.

    Returns:
        The ``Infinity`` model with gradients disabled and a ``torch.Generator``
        on *device* attached as ``rng``.
    """
    print(f'[Loading Infinity]')
    # A None default used to crash on `**model_kwargs`.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    text_maxlen = 512
    with torch.autocast('cuda', enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
        infinity_test: Infinity = Infinity(
            vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
            shared_aln=True, raw_scale_schedule=scale_schedule,
            checkpointing='full-block',
            customized_flash_attn=False,
            fused_norm=True,
            pad_to_multiplier=128,
            use_flex_attn=use_flex_attn,
            add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
            use_bit_label=use_bit_label,
            rope2d_each_sa_layer=rope2d_each_sa_layer,
            rope2d_normalized_by_hw=rope2d_normalized_by_hw,
            pn=pn,
            apply_spatial_patchify=apply_spatial_patchify,
            inference_mode=True,
            train_h_div_w_list=[1.0],
            **model_kwargs,
        ).to(device=device)
        print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
        if bf16:
            for block in infinity_test.unregistered_blocks:
                block.bfloat16()
        infinity_test.eval()
        infinity_test.requires_grad_(False)
        # The model is already on `device`. An extra .cuda() here would override
        # a non-default device (e.g. 'cpu' or 'cuda:1').
        torch.cuda.empty_cache()
        print(f'[Load Infinity weights]')
        state_dict = torch.load(model_path, map_location=device)
        print(infinity_test.load_state_dict(state_dict))
        infinity_test.rng = torch.Generator(device=device)
    return infinity_test
def transform(pil_img, tgt_h, tgt_w):
    """Resize *pil_img* to cover ``tgt_h x tgt_w``, center-crop it, and scale it to [-1, 1].

    The image is resized with Lanczos filtering so that one side matches the
    target exactly and the other side overflows. The centered
    ``tgt_h x tgt_w`` window is then cut out.

    Args:
        pil_img: input ``PIL.Image`` (``.size`` is ``(width, height)``).
        tgt_h, tgt_w: integer target height and width in pixels.

    Returns:
        Tensor of shape ``(C, tgt_h, tgt_w)`` with values in [-1, 1].
    """
    width, height = pil_img.size
    # Compare aspect ratios and size the covering image with exact integer
    # arithmetic. Float division followed by int() truncation can, through
    # rounding, come out one pixel below the target, which would make the
    # crop below smaller than requested.
    if width * tgt_h <= tgt_w * height:
        # Relatively taller (or equal-ratio) image: match width, overflow height.
        resized_width = tgt_w
        resized_height = tgt_w * height // width
    else:
        # Relatively wider image: match height, overflow width.
        resized_height = tgt_h
        resized_width = width * tgt_h // height
    pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
    # crop the center out
    arr = np.array(pil_img)
    crop_y = (arr.shape[0] - tgt_h) // 2
    crop_x = (arr.shape[1] - tgt_w) // 2
    im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
    # to_tensor yields values in [0, 1]; x + x - 1 maps them to [-1, 1].
    return im.add(im).add_(-1)
def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
    """Round-trip an image file through the VAE.

    Args:
        vae: visual tokenizer exposing ``encode(x, scale_schedule=...)``, which
            returns a 6-tuple, and ``decode(z)``.
        image_path: image file to load (converted to RGB).
        scale_schedule: per-scale entries; only the first three fields of each are used.
        device: device the input tensor is moved to.
        tgt_h, tgt_w: size the image is resized and center-cropped to (via ``transform``).

    Returns:
        ``(gt_img, recons_img, all_bit_indices)``: the preprocessed input and the
        reconstruction as ``H x W x C`` uint8 arrays, plus the bit indices returned
        by ``vae.encode``.
    """
    with Image.open(image_path) as src:  # release the file handle once decoded
        pil_image = src.convert('RGB')
    inp = transform(pil_image, tgt_h, tgt_w).unsqueeze(0).to(device)
    scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
    # Inference only. No autograd graph is needed, and tensors that require
    # grad could not be converted with .numpy() below.
    with torch.no_grad():
        t1 = time.time()
        _, z, _, all_bit_indices, _, _ = vae.encode(inp, scale_schedule=scale_schedule)
        t2 = time.time()
        recons_img = vae.decode(z)[0]
    if len(recons_img.shape) == 4:
        recons_img = recons_img.squeeze(1)
    print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
    t3 = time.time()
    print(f'vae encode takes {t2-t1:.2f}s, decode takes {t3-t2:.2f}s')
    # Clamp before the uint8 cast. Decoder values outside [-1, 1] would
    # otherwise produce wrapped or undefined pixels instead of saturating.
    recons_img = ((recons_img + 1) / 2).clamp_(0, 1)
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
    gt_img = (inp[0] + 1) / 2
    gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
    print(recons_img.shape, gt_img.shape)
    return gt_img, recons_img, all_bit_indices
def load_visual_tokenizer(args):
    """Build the BSQ VAE described by *args* and move it to CUDA when available, else CPU.

    Reads ``args.vae_type`` (used as the codebook dimension),
    ``args.apply_spatial_patchify`` (patch size 8 with 4 channel stages, else
    16 with 5 stages) and ``args.vae_path``.

    Raises:
        ValueError: if ``args.vae_type`` is not one of 16, 18, 20, 24, 32, 64.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.vae_type not in [16, 18, 20, 24, 32, 64]:
        raise ValueError(f'vae_type={args.vae_type} not supported')
    from models.bsq_vae.vae import vae_model
    codebook_dim = args.vae_type
    if args.apply_spatial_patchify:
        patch_size, ch_mult = 8, [1, 2, 4, 4]
    else:
        patch_size, ch_mult = 16, [1, 2, 4, 4, 4]
    return vae_model(
        args.vae_path, "dynamic", codebook_dim, 2 ** codebook_dim,
        patch_size=patch_size,
        encoder_ch_mult=ch_mult, decoder_ch_mult=list(ch_mult),
        test_mode=True,
    ).to(device)
# Architecture hyper-parameters for each supported `args.model_type`.
_INFINITY_MODEL_KWARGS = {
    'infinity_2b': dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8),  # 2b model
    'infinity_layer12': dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
    'infinity_layer16': dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
    'infinity_layer24': dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
    'infinity_layer32': dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
    'infinity_layer40': dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
    'infinity_layer48': dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
}


def _resolve_torch_checkpoint(args, device):
    """Return the checkpoint path to load, slimming and caching it locally when enabled.

    Without ``args.enable_model_cache`` this is just ``args.model_path``.
    Otherwise it copies the large model locally, saves a slim copy locally,
    copies the slim copy back next to the original, and returns the local slim
    path.
    """
    model_path = args.model_path
    if not args.enable_model_cache:
        return model_path
    if osp.exists(args.cache_dir):
        local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
    else:
        local_model_path = model_path
    slim_model_path = model_path.replace('ar-', 'slim-')
    local_slim_model_path = local_model_path.replace('ar-', 'slim-')
    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    os.makedirs(osp.dirname(local_slim_model_path) or '.', exist_ok=True)
    print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
    print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
    if not osp.exists(local_slim_model_path):
        if osp.exists(slim_model_path):
            print(f'copy {slim_model_path} to {local_slim_model_path}')
            shutil.copyfile(slim_model_path, local_slim_model_path)
        else:
            if not osp.exists(local_model_path):
                print(f'copy {model_path} to {local_model_path}')
                shutil.copyfile(model_path, local_model_path)
            save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
            print(f'copy {local_slim_model_path} to {slim_model_path}')
            if not osp.exists(slim_model_path):
                shutil.copyfile(local_slim_model_path, slim_model_path)
            # NOTE(review): this deletes the original full checkpoint at model_path,
            # not only the local copy. Kept as-is, but guarded so the same file is
            # not removed twice (which raised FileNotFoundError when no cache_dir
            # exists and local_model_path == model_path).
            os.remove(local_model_path)
            if local_model_path != model_path:
                os.remove(model_path)
    return local_slim_model_path


def load_transformer(vae, args):
    """Build the Infinity transformer selected by *args* and load its checkpoint.

    The architecture is looked up from ``args.model_type`` before any file is
    touched, so an invalid configuration fails fast.

    Raises:
        ValueError: for an unsupported ``args.model_type`` or ``args.checkpoint_type``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kwargs_model = _INFINITY_MODEL_KWARGS.get(args.model_type)
    if kwargs_model is None:
        raise ValueError(f'model_type={args.model_type} not supported')
    if args.checkpoint_type != 'torch':
        raise ValueError(f'checkpoint_type={args.checkpoint_type} not supported')
    slim_model_path = _resolve_torch_checkpoint(args, device)
    print(f'load checkpoint from {slim_model_path}')
    infinity = load_infinity(
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        use_scale_schedule_embedding=args.use_scale_schedule_embedding,
        pn=args.pn,
        use_bit_label=args.use_bit_label,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        model_path=slim_model_path,
        scale_schedule=None,
        vae=vae,
        device=device,
        model_kwargs=dict(kwargs_model),  # copy so the shared table cannot be mutated
        text_channels=args.text_channels,
        apply_spatial_patchify=args.apply_spatial_patchify,
        use_flex_attn=args.use_flex_attn,
        bf16=args.bf16,
    )
    return infinity
# Set up paths
# Checkpoints live in a `weights/` directory next to this file. It is created
# if missing, and the Infinity-2B and VAE checkpoints are fetched into it at
# import time when absent.
weights_path = Path(__file__).parent / 'weights'
weights_path.mkdir(exist_ok=True)
download_infinity_weights(weights_path)
# Device setup
# NOTE(review): `device` is not passed to the loaders below; load_visual_tokenizer
# and load_transformer each pick their own device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# bfloat16 only when CUDA is available and reports bf16 support, else float32.
# Used below only to derive args.bf16.
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
# Define args
# Static configuration in argparse.Namespace form. Read by the loader functions
# and by generate_image, which also writes per-request fields onto it.
args = argparse.Namespace(
    pn='1M',  # resolution preset key looked up in dynamic_resolution_h_w by generate_image
    model_path=str(weights_path / 'infinity_2b_reg.pth'),
    cfg_insertion_layer=0,  # generate_image passes it to gen_one_img wrapped in a list
    vae_type=32,  # codebook dimension; load_visual_tokenizer accepts 16/18/20/24/32/64
    vae_path=str(weights_path / 'infinity_vae_d32reg.pth'),
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    model_type='infinity_2b',  # selects the architecture kwargs in load_transformer
    rope2d_each_sa_layer=1,
    rope2d_normalized_by_hw=2,
    use_scale_schedule_embedding=0,
    sampling_per_bits=1,
    # NOTE(review): not used here; load_tokenizer below is called with "google/flan-t5-xl".
    text_encoder_ckpt=str(weights_path / 'flan-t5-xl'),
    text_channels=2048,
    apply_spatial_patchify=0,
    h_div_w_template=1.000,  # generate_image picks the template from the request's h_div_w instead
    use_flex_attn=0,
    cache_dir='/dev/shm',  # only consulted by load_transformer when enable_model_cache is true
    checkpoint_type='torch',
    seed=0,  # overwritten per request by generate_image
    bf16=1 if dtype == torch.bfloat16 else 0,
    save_file='tmp.jpg',  # NOTE(review): not read anywhere in this file as shown
    enable_model_cache=False,  # load model_path directly, with no slim/cache copying
)
# Load models
text_tokenizer, text_encoder = load_tokenizer(t5_path="google/flan-t5-xl")
vae = load_visual_tokenizer(args)
infinity = load_transformer(vae, args)
# Define the image generation function
@spaces.GPU
def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt):
    """Gradio callback: generate one image for *prompt* and return it as an RGB uint8 array.

    Args:
        prompt: text prompt.
        cfg: classifier-free guidance scale, applied to every scale.
        tau: sampling temperature, applied to every scale.
        h_div_w: requested height/width ratio; the closest template in
            ``h_div_w_templates`` is used.
        seed: sampling seed. May arrive as a float from ``gr.Number`` and is
            converted to ``int``; ``None`` is passed through.
        enable_positive_prompt: augment person prompts with a face-quality suffix.

    Returns:
        ``H x W x 3`` uint8 numpy array in RGB channel order.

    Raises:
        gradio.Error: on any failure, so the message is shown in the UI instead
        of an empty output.
    """
    try:
        # gr.Number yields floats unless precision=0. The seed presumably ends up
        # in torch.Generator.manual_seed, which requires an int.
        if seed is not None:
            seed = int(seed)
        args.prompt = prompt
        args.cfg = cfg
        args.tau = tau
        args.h_div_w = h_div_w
        args.seed = seed
        args.enable_positive_prompt = enable_positive_prompt
        # Find the closest h_div_w_template
        h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
        # Get scale_schedule based on h_div_w_template_
        scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
        scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
        # Generate the image
        generated_image = gen_one_img(
            infinity,
            vae,
            text_tokenizer,
            text_encoder,
            prompt,
            g_seed=seed,
            gt_leak=0,
            gt_ls_Bl=None,
            cfg_list=cfg,
            tau_list=tau,
            scale_schedule=scale_schedule,
            cfg_insertion_layer=[args.cfg_insertion_layer],
            vae_type=args.vae_type,
            sampling_per_bits=args.sampling_per_bits,
            enable_positive_prompt=enable_positive_prompt,
        )
        # The model output is treated as BGR; convert to RGB uint8 for display.
        image = generated_image.cpu().numpy()
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return np.uint8(image)
    except Exception as e:
        print(f"Error generating image: {e}")
        raise gr.Error(f"Error generating image: {e}") from e
# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
    with gr.Row():
        with gr.Column():
            # Prompt Settings
            gr.Markdown("### Prompt Settings")
            prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise", placeholder="Enter your prompt here...")
            enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False, info="Enhance prompts with positive attributes for faces.")
            # Image Settings
            gr.Markdown("### Image Settings")
            with gr.Row():
                cfg = gr.Slider(label="CFG (Classifier-Free Guidance)", minimum=1, maximum=10, step=0.5, value=3, info="Controls the strength of the prompt.")
                tau = gr.Slider(label="Tau (Temperature)", minimum=0.1, maximum=1.0, step=0.1, value=0.5, info="Controls the randomness of the output.")
            with gr.Row():
                h_div_w = gr.Slider(label="Aspect Ratio (Height/Width)", minimum=0.5, maximum=2.0, step=0.1, value=1.0, info="Set the aspect ratio of the generated image.")
                # precision=0 makes Gradio deliver the seed as an int rather than a float.
                seed = gr.Number(label="Seed", value=random.randint(0, 10000), precision=0, info="Set a seed for reproducibility.")
            # Generate Button
            generate_button = gr.Button("Generate Image", variant="primary")
        with gr.Column():
            # Output Section
            gr.Markdown("### Generated Image")
            output_image = gr.Image(label="Generated Image", type="pil")
            gr.Markdown("**Tip:** Right-click the image to save it.")
    # Link the generate button to the image generation function.
    # (The former hidden "Error Message" textbox was never wired to any event
    # and has been dropped.)
    generate_button.click(
        generate_image,
        inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
        outputs=output_image
    )
# Launch the Gradio app
demo.launch()