ELITE / inference_local.py
csyxwei's picture
elite code init
0792c6b
raw
history blame
8.75 kB
import os
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from train_local import Mapper, th2image, MapperLocal
from train_local import inj_forward_text, inj_forward_crossattention, validation
import torch.nn as nn
from datasets import CustomDatasetWithBG
def _pil_from_latents(vae, latents):
    """Decode ``latents`` with ``vae`` and return the result as PIL images.

    Args:
        vae: Object whose ``decode(x)`` returns something with a ``.sample``
            tensor attribute.
        latents: Latent tensor; it is cloned, so the caller's tensor is not
            modified.

    Returns:
        A list with one ``PIL.Image`` per entry along the first dimension of
        the decoded output.
    """
    # Divide a copy of the latents by the 0.18215 factor before decoding.
    rescaled = (1 / 0.18215) * latents.clone()
    decoded = vae.decode(rescaled).sample

    # Shift/scale by (x / 2 + 0.5) and clip to the [0, 1] range.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)

    # Move dimension 1 to the end (presumably channels-first -> channels-last,
    # which is what Image.fromarray needs) and hand the data over to NumPy.
    as_numpy = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    as_uint8 = (as_numpy * 255).round().astype("uint8")

    pil_images = []
    for frame in as_uint8:
        pil_images.append(Image.fromarray(frame))
    return pil_images
def pww_load_tools(
    device: str = "cuda:0",
    scheduler_type=LMSDiscreteScheduler,
    mapper_model_path: Optional[str] = None,
    mapper_local_model_path: Optional[str] = None,
    diffusion_model_path: Optional[str] = None,
    model_token: Optional[str] = None,
) -> Tuple[
    AutoencoderKL,
    UNet2DConditionModel,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModel,
    Mapper,
    MapperLocal,
    LMSDiscreteScheduler,
]:
    """Load all models used for inference and attach the mappers to the UNet.

    The VAE and UNet are loaded in float16 from ``diffusion_model_path``; the
    CLIP tokenizer, text encoder and vision encoder are loaded from
    ``openai/clip-vit-large-patch14``. The ``__call__`` of the
    ``CLIPTextTransformer`` class is replaced by ``inj_forward_text`` and the
    ``__call__`` of the ``CrossAttention`` class by
    ``inj_forward_crossattention``. For every ``CrossAttention`` module of the
    UNet whose name does not contain ``'attn1'``, bias-free linear projections
    are registered on both mappers, filled from the checkpoints, and then
    attached to that module as ``to_k_global`` / ``to_v_global`` /
    ``to_k_local`` / ``to_v_local``.

    Args:
        device: Device every model is moved to.
        scheduler_type: Scheduler class; instantiated with the fixed beta
            schedule below.
        mapper_model_path: Checkpoint (state dict) of the global mapper.
        mapper_local_model_path: Checkpoint (state dict) of the local mapper.
        diffusion_model_path: Location of the diffusion model containing the
            ``vae`` and ``unet`` subfolders.
        model_token: Forwarded as ``use_auth_token`` when loading VAE and UNet.

    Returns:
        ``(vae, unet, text_encoder, tokenizer, image_encoder, mapper,
        mapper_local, scheduler)``; all modules are in eval mode.

    Raises:
        ValueError: If either mapper checkpoint path is ``None``.
    """
    # Fail fast: without this check a missing path is only discovered by
    # torch.load, after all the large models have already been loaded.
    missing = [
        arg_name
        for arg_name, arg_value in (
            ("mapper_model_path", mapper_model_path),
            ("mapper_local_model_path", mapper_local_model_path),
        )
        if arg_value is None
    ]
    if missing:
        raise ValueError(f"Missing required argument(s): {', '.join(missing)}")

    # 'CompVis/stable-diffusion-v1-4'
    # NOTE(review): local_files_only is True whenever a path/identifier is
    # supplied, so a hub identifier presumably only works when it is already
    # in the local cache -- confirm this is intended.
    local_path_only = diffusion_model_path is not None
    vae = AutoencoderKL.from_pretrained(
        diffusion_model_path,
        subfolder="vae",
        use_auth_token=model_token,
        torch_dtype=torch.float16,
        local_files_only=local_path_only,
    )

    clip_name = "openai/clip-vit-large-patch14"
    # A tokenizer holds no tensors, so no torch_dtype is passed here.
    tokenizer = CLIPTokenizer.from_pretrained(clip_name)
    text_encoder = CLIPTextModel.from_pretrained(clip_name, torch_dtype=torch.float16)
    image_encoder = CLIPVisionModel.from_pretrained(clip_name, torch_dtype=torch.float16)

    # Load models and create wrapper for stable diffusion
    for _module in text_encoder.modules():
        if _module.__class__.__name__ == "CLIPTextTransformer":
            # Patched on the class, so it affects every instance of it.
            _module.__class__.__call__ = inj_forward_text

    unet = UNet2DConditionModel.from_pretrained(
        diffusion_model_path,
        subfolder="unet",
        use_auth_token=model_token,
        torch_dtype=torch.float16,
        local_files_only=local_path_only,
    )

    mapper = Mapper(input_dim=1024, output_dim=768)
    mapper_local = MapperLocal(input_dim=1024, output_dim=768)

    # (sanitised name, module) of every attention layer that gets projections;
    # remembered so the UNet does not have to be traversed a second time.
    patched_layers = []
    for _name, _module in unet.named_modules():
        if _module.__class__.__name__ != "CrossAttention" or 'attn1' in _name:
            continue
        _module.__class__.__call__ = inj_forward_crossattention

        prefix = _name.replace(".", "_")
        # Size each projection from its own reference layer (key from to_k,
        # value from to_v) instead of reusing one shape for both.
        k_out, k_in = _module.to_k.weight.shape
        v_out, v_in = _module.to_v.weight.shape

        mapper.add_module(f'{prefix}_to_k', nn.Linear(k_in, k_out, bias=False))
        mapper.add_module(f'{prefix}_to_v', nn.Linear(v_in, v_out, bias=False))
        mapper_local.add_module(f'{prefix}_to_v', nn.Linear(v_in, v_out, bias=False))
        mapper_local.add_module(f'{prefix}_to_k', nn.Linear(k_in, k_out, bias=False))
        patched_layers.append((prefix, _module))

    # NOTE: torch.load unpickles the file -- only load checkpoints from a
    # trusted source.
    mapper.load_state_dict(torch.load(mapper_model_path, map_location='cpu'))
    mapper.half()
    mapper_local.load_state_dict(torch.load(mapper_local_model_path, map_location='cpu'))
    mapper_local.half()

    # Share the loaded projection modules with the attention layers.
    for prefix, attention in patched_layers:
        attention.add_module('to_k_global', getattr(mapper, f'{prefix}_to_k'))
        attention.add_module('to_v_global', getattr(mapper, f'{prefix}_to_v'))
        attention.add_module('to_v_local', getattr(mapper_local, f'{prefix}_to_v'))
        attention.add_module('to_k_local', getattr(mapper_local, f'{prefix}_to_k'))

    for model in (vae, unet, text_encoder, image_encoder, mapper, mapper_local):
        model.to(device)
        model.eval()

    scheduler = scheduler_type(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

    return vae, unet, text_encoder, tokenizer, image_encoder, mapper, mapper_local, scheduler
def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process arguments, which is the
            behaviour of calling ``parse_args()`` without arguments.

    Returns:
        ``argparse.Namespace`` with the attributes ``global_mapper_path``,
        ``local_mapper_path``, ``output_dir``, ``placeholder_token``,
        ``template``, ``test_data_dir``, ``pretrained_model_name_or_path``,
        ``suffix``, ``selected_data`` (int) and ``llambda`` (kept as a string
        because the caller also uses its text form).

    Raises:
        SystemExit: Raised by argparse when a required option is missing or a
            value is invalid (including a ``--llambda`` that is not a number).
    """
    import argparse

    def float_string(value):
        """Return ``value`` unchanged after checking that ``float`` accepts it."""
        float(value)  # ValueError here is turned into a usage error by argparse
        return value

    parser = argparse.ArgumentParser(
        description="Inference script for the pretrained global and local mapping networks."
    )
    parser.add_argument(
        "--global_mapper_path",
        type=str,
        required=True,
        help="Path to pretrained global mapping network.",
    )
    parser.add_argument(
        "--local_mapper_path",
        type=str,
        required=True,
        help="Path to pretrained local mapping network.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default='outputs',
        help="The output directory where the model predictions will be written.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default="S",
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--template",
        type=str,
        default="a photo of a {}",
        help="Text template for customized generation.",
    )
    parser.add_argument(
        "--test_data_dir",
        type=str,
        required=True,
        help="A folder containing the testing data.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default="object",
        help="Suffix of save directory.",
    )
    parser.add_argument(
        "--selected_data",
        type=int,
        default=-1,
        help="Data index. -1 for all.",
    )
    parser.add_argument(
        "--llambda",
        type=float_string,
        default="0.8",
        help="Lambda for fusing the global and local features.",
    )
    return parser.parse_args(argv)
def main():
    """Generate images for the test data and save them next to the inputs.

    Parses the command line, loads the models with ``pww_load_tools``, then
    for every selected sample and every seed calls ``validation`` and writes
    the first synthesised image, concatenated along axis 1 with the sample's
    ``pixel_values`` image, to ``<output_dir>/<suffix>_l<llambda>/`` as
    ``<step>_<seed>.jpg`` (both zero-padded to five digits).
    """
    args = parse_args()
    device = "cuda:0"

    # Convert once, before any model is loaded, so a malformed value fails
    # immediately instead of after the expensive setup.
    llambda = float(args.llambda)

    save_dir = os.path.join(args.output_dir, f'{args.suffix}_l{args.llambda.replace(".", "p")}')
    os.makedirs(save_dir, exist_ok=True)

    vae, unet, text_encoder, tokenizer, image_encoder, mapper, mapper_local, scheduler = pww_load_tools(
        device,
        LMSDiscreteScheduler,
        diffusion_model_path=args.pretrained_model_name_or_path,
        mapper_model_path=args.global_mapper_path,
        mapper_local_model_path=args.local_mapper_path,
    )

    test_dataset = CustomDatasetWithBG(
        data_root=args.test_data_dir,
        tokenizer=tokenizer,
        size=512,
        placeholder_token=args.placeholder_token,
        template=args.template,
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # The seed list does not depend on the batch, so build and sort it once.
    seeds = sorted([
        0, 42, 10086, 777, 555, 222, 111, 999, 327, 283, 190, 218, 2371, 9329, 2938, 2073, 27367, 293,
        8269, 87367, 29379, 4658, 39, 598,
    ])

    # A value above -1 selects a single sample; anything else means "all".
    single_sample = args.selected_data > -1

    for step, batch in enumerate(test_dataloader):
        if single_sample and step != args.selected_data:
            continue

        batch["pixel_values"] = batch["pixel_values"].to(device)
        for key in ("pixel_values_clip", "pixel_values_obj", "pixel_values_seg"):
            batch[key] = batch[key].to(device).half()
        batch["input_ids"] = batch["input_ids"].to(device)
        batch["index"] = batch["index"].to(device).long()
        print(step, batch['text'])

        for seed in seeds:
            # NOTE(review): the positional 5 is passed through unchanged; its
            # meaning is defined by train_local.validation.
            syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae,
                                    batch["pixel_values_clip"].device, 5,
                                    seed=seed, llambda=llambda)
            concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
            Image.fromarray(concat).save(os.path.join(save_dir, f'{step:05d}_{seed:05d}.jpg'))

        if single_sample:
            # The requested sample is done; no need to load the rest.
            break


if __name__ == "__main__":
    main()