|
from transformers import PretrainedConfig |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
import PIL |
|
import os |
|
from tqdm.auto import tqdm |
|
from diffusers.models.attention_processor import ( |
|
AttnProcessor2_0, |
|
LoRAAttnProcessor2_0, |
|
LoRAXFormersAttnProcessor, |
|
XFormersAttnProcessor, |
|
) |
|
|
|
# Module-wide default device: the first CUDA GPU when CUDA is available,
# otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
|
|
def _shift_bounds(delta, size):
    """Return (src_start, src_stop, dst_start, dst_stop) for a zero-filled shift.

    A positive ``delta`` moves content towards higher indices. Shifts larger
    than ``size`` are clamped so both slices come out empty instead of
    wrapping around through negative indices.
    """
    delta = max(-size, min(size, delta))
    if delta >= 0:
        return 0, size - delta, delta, size
    return -delta, size, 0, size + delta


def myroll2d(a, delta_x, delta_y):
    """Shift a 2-D mask without wrap-around, filling vacated cells with zeros.

    Args:
        a: ``np.ndarray`` or ``torch.Tensor``; only its first two dimensions
            are used to size the output.
        delta_x: shift applied along axis 0 (positive moves content towards
            higher indices).
        delta_y: shift applied along axis 1 (positive moves content towards
            higher indices).

    Returns:
        A new uint8 array/tensor of shape ``(a.shape[0], a.shape[1])``. For
        tensors the result is allocated with ``torch.zeros`` defaults.

    Raises:
        TypeError: if ``a`` is neither a NumPy array nor a torch tensor.
    """
    size_0, size_1 = a.shape[0], a.shape[1]
    if isinstance(a, np.ndarray):
        b = np.zeros((size_0, size_1), dtype=np.uint8)
    elif isinstance(a, torch.Tensor):
        b = torch.zeros((size_0, size_1), dtype=torch.uint8)
    else:
        raise TypeError(
            f"myroll2d expects a numpy array or torch tensor, got {type(a).__name__}"
        )

    # Each axis is bounded by its OWN length; bounding axis 0 by the axis-1
    # length only happens to work for square inputs.
    src0_lo, src0_hi, dst0_lo, dst0_hi = _shift_bounds(delta_x, size_0)
    src1_lo, src1_hi, dst1_lo, dst1_hi = _shift_bounds(delta_y, size_1)

    b[dst0_lo:dst0_hi, dst1_lo:dst1_hi] = a[src0_lo:src0_hi, src1_lo:src1_hi]
    return b
|
|
|
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision = None, subfolder: str = "text_encoder"
):
    """Resolve the text-encoder class named by a checkpoint's config.

    Loads the config found under ``subfolder`` of the checkpoint, reads the
    first entry of its ``architectures`` list and returns the matching
    ``transformers`` class.

    Raises:
        ValueError: if the architecture is neither ``CLIPTextModel`` nor
            ``CLIPTextModelWithProjection``.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    architecture = config.architectures[0]

    # The model classes are imported lazily, only for the branch that needs them.
    if architecture == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    raise ValueError(f"{architecture} is not supported.")
|
|
|
@torch.no_grad()
def image2latent(image, vae = None, dtype=None):
    """Encode an image into scaled VAE latents.

    Args:
        image: a PIL image, an ``(H, W, C)`` array with values in [0, 255]
            (layout implied by the ``permute(2, 0, 1)`` below), or a 4-D
            tensor, which is treated as already being latents.
        vae: autoencoder providing ``encode`` and ``config.scaling_factor``;
            not used when ``image`` is a 4-D tensor.
        dtype: dtype the pixel tensor is cast to before encoding.

    Returns:
        The 4-D input tensor unchanged, or a sample from the VAE posterior
        multiplied by ``vae.config.scaling_factor``.
    """
    # A 4-D tensor is passed straight through (no encoding, no dtype cast).
    if isinstance(image, torch.Tensor) and image.dim() == 4:
        return image

    # ``Image`` is the PIL module; the class to test against is ``Image.Image``,
    # which every concrete image type (PNG, JPEG, ...) derives from.
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Map [0, 255] to [-1, 1], then HWC -> 1xCxHxW on the module-level device.
    pixels = torch.from_numpy(image).float() / 127.5 - 1
    pixels = pixels.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
    latents = vae.encode(pixels).latent_dist.sample()
    return latents * vae.config.scaling_factor
|
|
|
@torch.no_grad()
def latent2image(latents, return_type = 'np', vae = None):
    """Decode VAE latents into images.

    The VAE is temporarily upcast via ``upcast_vae`` when it is not already
    float32 and is restored to its ORIGINAL dtype afterwards (it used to be
    forced to float16 regardless of what the caller passed in).

    Args:
        latents: scaled latents; divided by ``vae.config.scaling_factor``
            before decoding.
        return_type: ``'np'`` converts the result to a uint8 array in
            ``(batch, H, W, C)`` order with values in [0, 255]; any other
            value returns the raw decoder output tensor.
        vae: autoencoder used for decoding.

    Returns:
        A uint8 ``np.ndarray`` when ``return_type == 'np'``, otherwise the
        decoded tensor.
    """
    original_dtype = vae.dtype
    # Upcasting a float32 VAE is a no-op, so it is skipped in that case.
    needs_upcasting = original_dtype != torch.float32
    if needs_upcasting:
        upcast_vae(vae)

    try:
        # Match the latents to whatever dtype the first decoding layer ended up in.
        latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
        image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()
            image = (image * 255).astype(np.uint8)
    finally:
        # Always hand the VAE back in the dtype it arrived in, even on failure.
        if needs_upcasting:
            vae.to(dtype=original_dtype)
    return image
|
|
|
def upcast_vae(vae):
    """Cast ``vae`` to float32 in place, keeping selected sub-modules in the original dtype.

    When the first attention layer of the decoder's mid block uses one of the
    torch-2.0 / xformers attention processors, ``post_quant_conv``,
    ``decoder.conv_in`` and ``decoder.mid_block`` are cast back to the dtype
    the VAE had on entry.
    """
    original_dtype = vae.dtype
    vae.to(dtype=torch.float32)

    efficient_processor_types = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        LoRAXFormersAttnProcessor,
        LoRAAttnProcessor2_0,
    )
    mid_processor = vae.decoder.mid_block.attentions[0].processor
    if not isinstance(mid_processor, efficient_processor_types):
        return

    for module in (vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block):
        module.to(original_dtype)
|
|
|
def prompt_to_emb_length_sdxl(prompt, tokenizer, text_encoder, length = None):
    """Embed one prompt, padded/truncated to ``length`` tokens.

    Returns:
        A dict with
        ``"prompt_embeds"``: the penultimate hidden state of ``text_encoder``,
        viewed as ``(batch, seq_len, -1)``;
        ``"pooled_prompt_embeds"``: the first element of the encoder output,
        flattened to ``(batch, -1)``.
    """
    tokens = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    # Token ids are moved to the module-level ``device`` before encoding.
    encoder_out = text_encoder(tokens.input_ids.to(device), output_hidden_states=True)

    pooled = encoder_out[0]
    hidden = encoder_out.hidden_states[-2]
    batch, n_tokens, _ = hidden.shape

    return {
        "prompt_embeds": hidden.view(batch, n_tokens, -1),
        "pooled_prompt_embeds": pooled.view(batch, -1),
    }
|
|
|
|
|
|
|
|
|
def prompt_to_emb_length_sd(prompt, tokenizer, text_encoder, length = None):
    """Embed one prompt, padded/truncated to ``length`` tokens.

    Returns the first element of the ``text_encoder`` output for the
    tokenized prompt; the token ids are moved to the module-level ``device``.
    """
    tokens = tokenizer(
        [prompt],
        padding="max_length",
        max_length=length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokens.input_ids.to(device)
    return text_encoder(input_ids)[0]
|
|
|
def sdxl_prepare_input_decom(
    set_string_list,
    tokenizer,
    tokenizer_2,
    text_encoder_1,
    text_encoder_2,
    length = 20,
    bsz = 1,
    weight_dtype = torch.float32,
    resolution = 1024,
    normal_token_id_list = None
):
    """Build the SDXL conditioning inputs for a list of prompts.

    A prompt containing ``#`` or ``$`` whose index is not listed in
    ``normal_token_id_list`` is tokenized to ``length`` tokens; every other
    prompt is tokenized to 77 tokens and echoed to stdout once per encoder.

    Args:
        set_string_list: prompts, one per layer/mask.
        tokenizer, tokenizer_2: tokenizers paired with the two text encoders.
        text_encoder_1, text_encoder_2: the two SDXL text encoders.
        length: token length used for the ``#`` / ``$`` prompts.
        bsz: batch size the per-prompt embeddings are repeated to.
        weight_dtype: dtype the embeddings are cast to.
        resolution: used for original size and target size in ``add_time_ids``.
        normal_token_id_list: indices that always use the 77-token path;
            ``None`` means no such indices.

    Returns:
        ``(encoder_hidden_states_list, add_text_embeds, add_time_ids)`` where
        each hidden state concatenates both encoders on the last dimension,
        ``add_text_embeds`` is the mean over prompts of the second encoder's
        pooled output, and ``add_time_ids`` is
        ``[[resolution, resolution, 0, 0, resolution, resolution]]``.

    Raises:
        ValueError: if ``set_string_list`` is empty.
    """
    if normal_token_id_list is None:
        normal_token_id_list = []
    if len(set_string_list) == 0:
        # Without any prompt there is no tensor to take dtype/device from.
        raise ValueError("set_string_list must contain at least one prompt")

    encoders = ((tokenizer, text_encoder_1), (tokenizer_2, text_encoder_2))
    encoder_hidden_states_list = []
    pooled_prompt_embeds = 0

    for m_idx, prompt in enumerate(set_string_list):
        uses_short_length = ("#" in prompt or "$" in prompt) and m_idx not in normal_token_id_list
        token_length = length if uses_short_length else 77

        prompt_embeds_list = []
        for current_tokenizer, current_encoder in encoders:
            out = prompt_to_emb_length_sdxl(
                prompt, current_tokenizer, current_encoder, length = token_length
            )
            if not uses_short_length:
                print(m_idx, prompt)
            prompt_embeds = out["prompt_embeds"].to(dtype=weight_dtype)
            prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
            prompt_embeds_list.append(prompt_embeds)

        # ``out`` now holds the second encoder's result; only its pooled
        # output is accumulated.
        pooled_prompt_embeds += out["pooled_prompt_embeds"].to(dtype=weight_dtype)
        encoder_hidden_states_list.append(torch.concat(prompt_embeds_list, dim=-1))

    add_text_embeds = pooled_prompt_embeds / len(set_string_list)

    original_size = (resolution, resolution)
    crops_coords_top_left = (0, 0)
    target_size = (resolution, resolution)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor(
        [add_time_ids], dtype=prompt_embeds.dtype, device=pooled_prompt_embeds.device
    )
    return encoder_hidden_states_list, add_text_embeds, add_time_ids
|
|
|
def sd_prepare_input_decom(
    set_string_list,
    tokenizer,
    text_encoder_1,
    length = 20,
    bsz = 1,
    weight_dtype = torch.float32,
    normal_token_id_list = None
):
    """Build the Stable Diffusion text conditioning for a list of prompts.

    A prompt containing ``#`` or ``$`` whose index is not listed in
    ``normal_token_id_list`` is tokenized to ``length`` tokens; every other
    prompt is tokenized to 77 tokens and echoed to stdout.

    Args:
        set_string_list: prompts, one per layer/mask.
        tokenizer: tokenizer paired with ``text_encoder_1``.
        text_encoder_1: text encoder producing the embeddings.
        length: token length used for the ``#`` / ``$`` prompts.
        bsz: batch size each embedding is repeated to.
        weight_dtype: dtype the embeddings are cast to.
        normal_token_id_list: indices that always use the 77-token path;
            ``None`` means no such indices.

    Returns:
        A list with one embedding tensor per prompt (empty for no prompts).
    """
    if normal_token_id_list is None:
        normal_token_id_list = []

    encoder_hidden_states_list = []
    for m_idx, prompt in enumerate(set_string_list):
        uses_short_length = ("#" in prompt or "$" in prompt) and m_idx not in normal_token_id_list
        if uses_short_length:
            encoder_hidden_states = prompt_to_emb_length_sd(
                prompt, tokenizer, text_encoder_1, length = length
            )
        else:
            encoder_hidden_states = prompt_to_emb_length_sd(
                prompt, tokenizer, text_encoder_1, length = 77
            )
            print(m_idx, prompt)
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        encoder_hidden_states_list.append(encoder_hidden_states.to(dtype=weight_dtype))
    return encoder_hidden_states_list
|
|
|
|
|
def load_mask(input_folder):
    """Load the original (non-edited) layer masks stored in ``input_folder``.

    Picks every file whose name contains ``mask``, ``.npy`` and ``_`` but not
    ``Edited``, orders them by the integer that follows the 4-character
    prefix of the part before the first ``_`` (e.g. ``mask3_sky.npy`` -> 3),
    and takes the label from the part after that ``_`` with its last four
    characters removed.

    If the masks do not sum to exactly 1 at every position, a notice is
    printed and the masks are still returned. The function no longer drops
    into ``pdb``, which blocked non-interactive runs.

    Returns:
        ``(mask_list, mask_label_list)``: uint8 tensors and their labels, in
        matching order.
    """
    np_mask_dtype = 'uint8'
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name
        and "_" in file_name and "Edited" not in file_name
    ]
    files = sorted(files, key=lambda x: int(x.split("_")[0][4:]))

    mask_list = []
    mask_label_list = []
    for file_name in files:
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_list.append(torch.from_numpy(mask_np))
        mask_label_list.append(file_name.split("_")[1][:-4])

    try:
        covers_once = bool(mask_list) and bool(torch.all(sum(mask_list) == 1))
    except (RuntimeError, TypeError):
        # e.g. masks whose shapes cannot be added together
        covers_once = False
    if not covers_once:
        print("please check mask")

    return mask_list, mask_label_list
|
|
|
def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
    """Load an image, trim its borders, centre-crop to a square and resize.

    Args:
        image_path: a filesystem path (``str`` or path-like) or an already
            loaded ``(H, W, C)`` array. When a path is given, only the first
            three channels of the decoded image are kept.
        left, right, top, bottom: number of pixels to trim from each side;
            each value is clamped so at least one row/column remains.
        size: edge length of the square output.

    Returns:
        An array holding the ``size`` x ``size`` resized image.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # Materialise the pixels inside the ``with`` so the file handle is released.
        with Image.open(image_path) as opened:
            image = np.array(opened)[:, :, :3]
    else:
        image = image_path

    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # The vertical clamp depends on the height only (it used to subtract ``left``).
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Centre-crop the longer side so the image is square before resizing.
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]

    return np.array(Image.fromarray(image).resize((size, size)))
|
|
|
def mask_union_torch(*masks):
    """Return the element-wise union of the given masks.

    Every mask is converted to float and the masks are added together; the
    result is true wherever that sum is greater than zero.
    """
    as_float = tuple(mask.to(torch.float) for mask in masks)
    total = 0
    for mask in as_float:
        total = total + mask
    return total > 0
|
|
|
def load_mask_edit(input_folder):
    """Load the edited layer masks stored in ``input_folder``.

    Picks every file whose name contains ``mask``, ``.npy``, ``_`` and
    ``Edited`` but not ``-1``, orders them by the integer that follows the
    10-character prefix of the part before the first ``_`` (e.g.
    ``maskEdited3_sky.npy`` -> 3), and takes the label from the part after
    that ``_`` with its last four characters removed.

    If the masks do not sum to exactly 1 at every position, a notice is
    printed and the masks are still returned. The function no longer drops
    into ``pdb``, which blocked non-interactive runs.

    Returns:
        ``(mask_list, mask_label_list)``: uint8 tensors and their labels, in
        matching order.
    """
    np_mask_dtype = 'uint8'
    files = [
        file_name for file_name in os.listdir(input_folder)
        if "mask" in file_name and ".npy" in file_name and "_" in file_name
        and "Edited" in file_name and "-1" not in file_name
    ]
    files = sorted(files, key=lambda x: int(x.split("_")[0][10:]))

    mask_list = []
    mask_label_list = []
    for file_name in files:
        mask_np = np.load(os.path.join(input_folder, file_name)).astype(np_mask_dtype)
        mask_list.append(torch.from_numpy(mask_np))
        mask_label_list.append(file_name.split("_")[1][:-4])

    try:
        covers_once = bool(mask_list) and bool(torch.all(sum(mask_list) == 1))
    except (RuntimeError, TypeError):
        # e.g. masks whose shapes cannot be added together
        covers_once = False
    if not covers_once:
        print("Make sure maskEdited is in the folder, if not, generate using the UI")

    return mask_list, mask_label_list
|
|
|
def save_images(images,filename, num_rows=1, offset_ratio=0.02):
    """Save each image to its own file, named ``<stem>_<index><ext>``.

    Args:
        images: a list/tuple of image arrays, a 4-D array (iterated along its
            first axis), or a single image array.
        filename: template path; the index is inserted before the extension
            and the files are written next to it.
        num_rows: for list or 4-D input, ``len(images) % num_rows`` blank
            white images are appended (and saved) after the real ones.
        offset_ratio: unused by this function.

    Returns:
        None. Nothing is written when ``images`` is empty.
    """
    if isinstance(images, (list, tuple)):
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    if len(images) == 0:
        # No first image to size the padding from, and nothing to save.
        return

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty

    folder = os.path.dirname(filename)
    if folder:
        os.makedirs(folder, exist_ok=True)
    # splitext keeps dotted stems intact ("a.b.png" -> "a.b_0.png") and
    # tolerates a missing extension.
    stem, ext = os.path.splitext(os.path.basename(filename))

    for i, image in enumerate(images):
        target = os.path.join(folder, "{}_{}{}".format(stem, i, ext))
        Image.fromarray(image).save(target)
        print("saved to ", target)