# Install FlashAttention at startup without compiling its CUDA kernels
# (the usual pattern for ZeroGPU Spaces).
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import os.path as osp
import time
import hashlib
import argparse
import shutil
import re
import random
from pathlib import Path
from typing import List

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageEnhance
import PIL.Image as PImage
from torchvision.transforms.functional import to_tensor
from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
from huggingface_hub import hf_hub_download
import gradio as gr
import spaces

from models.infinity import Infinity
from models.basic import *
from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates

torch._dynamo.config.cache_size_limit = 64

# Define a function to download weights if not present
def download_infinity_weights(weights_path):
    try:
        model_file = weights_path / 'infinity_2b_reg.pth'
        if not model_file.exists():
            hf_hub_download(repo_id="FoundationVision/Infinity", filename="infinity_2b_reg.pth", local_dir=str(weights_path))
        vae_file = weights_path / 'infinity_vae_d32reg.pth'
        if not vae_file.exists():
            hf_hub_download(repo_id="FoundationVision/Infinity", filename="infinity_vae_d32reg.pth", local_dir=str(weights_path))
    except Exception as e:
        print(f"Error downloading weights: {e}")
def extract_key_val(text):
    pattern = r'<(.+?):(.+?)>'
    matches = re.findall(pattern, text)
    key_val = {}
    for match in matches:
        key_val[match[0]] = match[1].lstrip()
    return key_val
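
# Illustrative example (not part of the original app): extract_key_val parses
# inline "<key:value>" tags out of a prompt string, e.g.
#   extract_key_val("a cat <cfg:4> <tau: 0.7>")  ->  {'cfg': '4', 'tau': '0.7'}
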
def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
    if enable_positive_prompt:
        print(f'before positive_prompt aug: {prompt}')
        prompt = aug_with_positive_prompt(prompt)
        print(f'after positive_prompt aug: {prompt}')
    print(f'prompt={prompt}')
    captions = [prompt]
    tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt')  # todo: put this into dataset
    input_ids = tokens.input_ids.cuda(non_blocking=True)
    mask = tokens.attention_mask.cuda(non_blocking=True)
    text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
    lens: List[int] = mask.sum(dim=-1).tolist()
    cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
    Ltext = max(lens)
    kv_compact = []
    for len_i, feat_i in zip(lens, text_features.unbind(0)):
        kv_compact.append(feat_i[:len_i])
    kv_compact = torch.cat(kv_compact, dim=0)
    text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
    return text_cond_tuple
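
# encode_prompt returns a packed, padding-free text-conditioning tuple:
#   kv_compact   - un-padded T5 token features of all captions, concatenated
#   lens         - number of valid tokens per caption
#   cu_seqlens_k - int32 cumulative sequence lengths with a leading zero
#                  (the layout used by varlen / flash-attention style kernels)
#   Ltext        - longest caption length in the batch
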
def aug_with_positive_prompt(prompt):
    for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
                'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
        if key in prompt:
            prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
            break
    return prompt
def enhance_image(image):
    contrast_image = image.copy()
    contrast_enhancer = ImageEnhance.Contrast(contrast_image)
    contrast_image = contrast_enhancer.enhance(1.05)  # slightly boost contrast
    color_image = contrast_image.copy()
    color_enhancer = ImageEnhance.Color(color_image)
    color_image = color_enhancer.enhance(1.05)  # slightly boost color saturation
    return color_image
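
# Note: enhance_image is a small optional post-processing helper (mild contrast
# and saturation boost); it is not wired into the Gradio pipeline below.
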
def gen_one_img(
    infinity_test,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
):
    sstt = time.time()
    if not isinstance(cfg_list, list):
        cfg_list = [cfg_list] * len(scale_schedule)
    if not isinstance(tau_list, list):
        tau_list = [tau_list] * len(scale_schedule)
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None
    print(f'cfg: {cfg_list}, tau: {tau_list}')
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        stt = time.time()
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
            ret_img=True, trunk_scale=1000,
            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
            sampling_per_bits=sampling_per_bits,
        )
    print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
    img = img_list[0]
    return img
def get_prompt_id(prompt):
    md5 = hashlib.md5()
    md5.update(prompt.encode('utf-8'))
    prompt_id = md5.hexdigest()
    return prompt_id
def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
    print('[Save slim model]')
    full_ckpt = torch.load(infinity_model_path, map_location=device)
    infinity_slim = full_ckpt['trainer'][key]
    # ema_state_dict = cpu_d['trainer'].get('gpt_ema_fsdp', state_dict)
    if not save_file:
        save_file = osp.splitext(infinity_model_path)[0] + '-slim.pth'
    print(f'Save to {save_file}')
    torch.save(infinity_slim, save_file)
    print('[Save slim model] done')
    return save_file
def load_tokenizer(t5_path=''):
    print(f'[Loading tokenizer and text encoder]')
    text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
    text_tokenizer.model_max_length = 512
    text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
    text_encoder.to('cuda')
    text_encoder.eval()
    text_encoder.requires_grad_(False)
    return text_tokenizer, text_encoder
def load_infinity(
    rope2d_each_sa_layer,
    rope2d_normalized_by_hw,
    use_scale_schedule_embedding,
    pn,
    use_bit_label,
    add_lvl_embeding_only_first_block,
    model_path='',
    scale_schedule=None,
    vae=None,
    device='cuda',
    model_kwargs=None,
    text_channels=2048,
    apply_spatial_patchify=0,
    use_flex_attn=False,
    bf16=False,
):
    print(f'[Loading Infinity]')
    text_maxlen = 512
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
        infinity_test: Infinity = Infinity(
            vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
            shared_aln=True, raw_scale_schedule=scale_schedule,
            checkpointing='full-block',
            customized_flash_attn=False,
            fused_norm=True,
            pad_to_multiplier=128,
            use_flex_attn=use_flex_attn,
            add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
            use_bit_label=use_bit_label,
            rope2d_each_sa_layer=rope2d_each_sa_layer,
            rope2d_normalized_by_hw=rope2d_normalized_by_hw,
            pn=pn,
            apply_spatial_patchify=apply_spatial_patchify,
            inference_mode=True,
            train_h_div_w_list=[1.0],
            **model_kwargs,
        ).to(device=device)
        print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
        if bf16:
            for block in infinity_test.unregistered_blocks:
                block.bfloat16()
        infinity_test.eval()
        infinity_test.requires_grad_(False)
        infinity_test.cuda()
        torch.cuda.empty_cache()

    print(f'[Load Infinity weights]')
    state_dict = torch.load(model_path, map_location=device)
    print(infinity_test.load_state_dict(state_dict))
    infinity_test.rng = torch.Generator(device=device)
    return infinity_test
def transform(pil_img, tgt_h, tgt_w):
    # resize so the target box is fully covered while preserving aspect ratio
    width, height = pil_img.size
    if width / height <= tgt_w / tgt_h:
        resized_width = tgt_w
        resized_height = int(tgt_w / (width / height))
    else:
        resized_height = tgt_h
        resized_width = int((width / height) * tgt_h)
    pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
    # crop the center out
    arr = np.array(pil_img)
    crop_y = (arr.shape[0] - tgt_h) // 2
    crop_x = (arr.shape[1] - tgt_w) // 2
    im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
    return im.add(im).add_(-1)  # im + im - 1: rescale from [0, 1] to [-1, 1]
def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
    pil_image = Image.open(image_path).convert('RGB')
    inp = transform(pil_image, tgt_h, tgt_w)
    inp = inp.unsqueeze(0).to(device)
    scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
    t1 = time.time()
    h, z, _, all_bit_indices, _, infinity_input = vae.encode(inp, scale_schedule=scale_schedule)
    t2 = time.time()
    recons_img = vae.decode(z)[0]
    if len(recons_img.shape) == 4:
        recons_img = recons_img.squeeze(1)
    print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
    t3 = time.time()
    print(f'vae encode takes {t2-t1:.2f}s, decode takes {t3-t2:.2f}s')
    recons_img = (recons_img + 1) / 2
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
    gt_img = (inp[0] + 1) / 2
    gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
    print(recons_img.shape, gt_img.shape)
    return gt_img, recons_img, all_bit_indices
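
# joint_vi_vae_encode_decode is a VAE round-trip (encode then decode) debugging
# helper; it is not called by the Gradio demo below.
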
def load_visual_tokenizer(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load vae
    if args.vae_type in [16, 18, 20, 24, 32, 64]:
        from models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae = vae_model(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
                        encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
    else:
        raise ValueError(f'vae_type={args.vae_type} not supported')
    return vae
def load_transformer(vae, args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = args.model_path
    if args.checkpoint_type == 'torch':
        # copy large model to local; save slim to local; and copy slim to nas; load local slim model
        if osp.exists(args.cache_dir):
            local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
        else:
            local_model_path = model_path
        if args.enable_model_cache:
            slim_model_path = model_path.replace('ar-', 'slim-')
            local_slim_model_path = local_model_path.replace('ar-', 'slim-')
            os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
            print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
            print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
            if not osp.exists(local_slim_model_path):
                if osp.exists(slim_model_path):
                    print(f'copy {slim_model_path} to {local_slim_model_path}')
                    shutil.copyfile(slim_model_path, local_slim_model_path)
                else:
                    if not osp.exists(local_model_path):
                        print(f'copy {model_path} to {local_model_path}')
                        shutil.copyfile(model_path, local_model_path)
                    save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
                    print(f'copy {local_slim_model_path} to {slim_model_path}')
                    if not osp.exists(slim_model_path):
                        shutil.copyfile(local_slim_model_path, slim_model_path)
                os.remove(local_model_path)
                os.remove(model_path)
            slim_model_path = local_slim_model_path
        else:
            slim_model_path = model_path
        print(f'load checkpoint from {slim_model_path}')
    if args.model_type == 'infinity_2b':
        kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8)  # 2b model
    elif args.model_type == 'infinity_layer12':
        kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer16':
        kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer24':
        kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer32':
        kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer40':
        kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    elif args.model_type == 'infinity_layer48':
        kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
    else:
        # guard against kwargs_model being undefined for an unknown model_type
        raise ValueError(f'model_type={args.model_type} not supported')
    infinity = load_infinity(
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        use_scale_schedule_embedding=args.use_scale_schedule_embedding,
        pn=args.pn,
        use_bit_label=args.use_bit_label,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        model_path=slim_model_path,
        scale_schedule=None,
        vae=vae,
        device=device,
        model_kwargs=kwargs_model,
        text_channels=args.text_channels,
        apply_spatial_patchify=args.apply_spatial_patchify,
        use_flex_attn=args.use_flex_attn,
        bf16=args.bf16,
    )
    return infinity
# Set up paths
weights_path = Path(__file__).parent / 'weights'
weights_path.mkdir(exist_ok=True)
download_infinity_weights(weights_path)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
# Define args
args = argparse.Namespace(
    pn='1M',
    model_path=str(weights_path / 'infinity_2b_reg.pth'),
    cfg_insertion_layer=0,
    vae_type=32,
    vae_path=str(weights_path / 'infinity_vae_d32reg.pth'),
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    model_type='infinity_2b',
    rope2d_each_sa_layer=1,
    rope2d_normalized_by_hw=2,
    use_scale_schedule_embedding=0,
    sampling_per_bits=1,
    text_encoder_ckpt=str(weights_path / 'flan-t5-xl'),
    text_channels=2048,
    apply_spatial_patchify=0,
    h_div_w_template=1.000,
    use_flex_attn=0,
    cache_dir='/dev/shm',
    checkpoint_type='torch',
    seed=0,
    bf16=1 if dtype == torch.bfloat16 else 0,
    save_file='tmp.jpg',
    enable_model_cache=False,
)
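
# Note: vae_type=32 and vae_path point at the same checkpoint family (the d32
# VAE downloaded above), and pn is the key used to look up scale schedules in
# dynamic_resolution_h_w inside generate_image.
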
# Load models
text_tokenizer, text_encoder = load_tokenizer(t5_path="google/flan-t5-xl")
vae = load_visual_tokenizer(args)
infinity = load_transformer(vae, args)
# Define the image generation function.
# On ZeroGPU Spaces a GPU is attached per call; spaces.GPU marks this callback
# for that (the `spaces` import above is otherwise unused).
@spaces.GPU
def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt):
    try:
        seed = int(seed)  # gr.Number yields floats; ensure an integer seed
        args.prompt = prompt
        args.cfg = cfg
        args.tau = tau
        args.h_div_w = h_div_w
        args.seed = seed
        args.enable_positive_prompt = enable_positive_prompt
        # Find the closest h_div_w_template
        h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
        # Get scale_schedule based on h_div_w_template_
        scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
        scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
        # Generate the image
        generated_image = gen_one_img(
            infinity,
            vae,
            text_tokenizer,
            text_encoder,
            prompt,
            g_seed=seed,
            gt_leak=0,
            gt_ls_Bl=None,
            cfg_list=cfg,
            tau_list=tau,
            scale_schedule=scale_schedule,
            cfg_insertion_layer=[args.cfg_insertion_layer],
            vae_type=args.vae_type,
            sampling_per_bits=args.sampling_per_bits,
            enable_positive_prompt=enable_positive_prompt,
        )
        # Convert the generated BGR tensor to an RGB uint8 array
        image = generated_image.cpu().numpy()
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.uint8(image)
        return image
    except Exception as e:
        print(f"Error generating image: {e}")
        return None
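
# Illustrative programmatic use (not executed as part of the app; assumes the
# models above are loaded and a GPU is available):
#   img = generate_image("alien spaceship enterprise", cfg=3, tau=0.5,
#                        h_div_w=1.0, seed=0, enable_positive_prompt=False)
#   if img is not None:
#       Image.fromarray(img).save("sample.jpg")
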
# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
    with gr.Row():
        with gr.Column():
            # Prompt Settings
            gr.Markdown("### Prompt Settings")
            prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise", placeholder="Enter your prompt here...")
            enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False, info="Enhance prompts with positive attributes for faces.")
            # Image Settings
            gr.Markdown("### Image Settings")
            with gr.Row():
                cfg = gr.Slider(label="CFG (Classifier-Free Guidance)", minimum=1, maximum=10, step=0.5, value=3, info="Controls the strength of the prompt.")
                tau = gr.Slider(label="Tau (Temperature)", minimum=0.1, maximum=1.0, step=0.1, value=0.5, info="Controls the randomness of the output.")
            with gr.Row():
                h_div_w = gr.Slider(label="Aspect Ratio (Height/Width)", minimum=0.5, maximum=2.0, step=0.1, value=1.0, info="Set the aspect ratio of the generated image.")
                seed = gr.Number(label="Seed", value=random.randint(0, 10000), info="Set a seed for reproducibility.")
            # Generate Button
            generate_button = gr.Button("Generate Image", variant="primary")
        with gr.Column():
            # Output Section
            gr.Markdown("### Generated Image")
            output_image = gr.Image(label="Generated Image", type="pil")
            gr.Markdown("**Tip:** Right-click the image to save it.")
    # Error Handling
    error_message = gr.Textbox(label="Error Message", visible=False)
    # Link the generate button to the image generation function
    generate_button.click(
        generate_image,
        inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
        outputs=output_image
    )

# Launch the Gradio app
demo.launch()
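# When running locally rather than on Spaces, launch options such as
# demo.launch(share=True) or demo.launch(server_name="0.0.0.0") can be used
# to expose the interface beyond localhost.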