from typing import Any, Callable, Dict, List, Optional, Union
import importlib
import inspect
import math
from pathlib import Path
import re
from collections import defaultdict
import cv2
import time
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from torch.autograd.function import Function

from diffusers import DiffusionPipeline


# Support for mapping each prompt phrase to the image region it should attend to
def encode_region_map_sp(
    state,
    tokenizer,
    unet,
    width,
    height,
    scale_ratio=8,
    text_ids=None,
    do_classifier_free_guidance=True,
):
    if text_ids is None:
        return torch.Tensor(0)
    uncond, cond = text_ids[0], text_ids[1]

    # Flatten the token-id arrays/tensors to plain Python lists so token spans can be compared below
    cond = cond.reshape(-1).tolist() if isinstance(cond, (np.ndarray, torch.Tensor)) else None
    uncond = uncond.reshape(-1).tolist() if isinstance(uncond, (np.ndarray, torch.Tensor)) else None
    if cond is None:
        return torch.Tensor(0)

    w_tensors = dict()
    for _ in unet.down_blocks:
        # One weight map per down-block resolution, keyed by the number of spatial positions
        c = len(cond)
        w_r, h_r = int(math.ceil(width / scale_ratio)), int(math.ceil(height / scale_ratio))
        ret_cond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
        ret_uncond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
        if state is not None:
            for k, v in state.items():
                if v["map"] is None:
                    continue
                is_in = 0
                k_as_tokens = tokenizer(
                    k,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    add_special_tokens=False,
                ).input_ids
                # Binarize the region map (pixels < 255 belong to the region), resize it to the
                # current resolution, then apply the region weight and the outside-mask penalty
                region_map_resize = np.array(v["map"] < 255, dtype=np.uint8)
                region_map_resize = cv2.resize(region_map_resize, (w_r, h_r), interpolation=cv2.INTER_CUBIC)
                region_map_resize = (region_map_resize == np.max(region_map_resize)).astype(float)
                region_map_resize = region_map_resize * float(v["weight"])
                region_map_resize[region_map_resize == 0] = -1 * float(v["mask_outsides"])
                ret = torch.from_numpy(region_map_resize)
                ret = ret.reshape(-1, 1).repeat(1, len(k_as_tokens))
                # Add the map at every position where the phrase's token span occurs in the prompt
                for idx in range(len(cond)):
                    if cond[idx : idx + len(k_as_tokens)] == k_as_tokens:
                        is_in = 1
                        ret_cond_tensor[0, :, idx : idx + len(k_as_tokens)] += ret
                if uncond is not None:
                    for idx in range(len(uncond)):
                        if uncond[idx : idx + len(k_as_tokens)] == k_as_tokens:
                            is_in = 1
                            ret_uncond_tensor[0, :, idx : idx + len(k_as_tokens)] += ret
                if not is_in:
                    print(f"tokens {k_as_tokens} not found in text")
        w_tensors[w_r * h_r] = (
            torch.cat([ret_uncond_tensor, ret_cond_tensor]) if do_classifier_free_guidance else ret_cond_tensor
        )
        scale_ratio *= 2
    return w_tensors
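
# The `state` argument expected by `encode_region_map_sp` / `encode_region_map` maps a prompt
# phrase to a dict with a "map" (an array where pixel values < 255 mark the region the phrase
# should attend to), a "weight", and a "mask_outsides" value. The helper below is a minimal
# sketch of building such a state; `build_example_state`, the phrase, and the 512x512 canvas
# are illustrative assumptions, not part of the original API.
def build_example_state(width=512, height=512):
    # White (255) canvas means "outside the region"; zero out the left half to mark the
    # region that the phrase "a red ball" should be bound to.
    ball_map = np.full((height, width), 255, dtype=np.uint8)
    ball_map[:, : width // 2] = 0
    return {
        "a red ball": {"map": ball_map, "weight": 1.0, "mask_outsides": 0.0},
    }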

def encode_region_map(
    pipe: DiffusionPipeline,
    state,
    width,
    height,
    num_images_per_prompt,
    text_ids=None,
):
    negative_prompt_tokens_id, prompt_tokens_id = text_ids[0], text_ids[1]
    if prompt_tokens_id is None:
        return torch.Tensor(0)
    prompt_tokens_id = np.array(prompt_tokens_id)
    negative_prompt_tokens_id = (
        np.array(negative_prompt_tokens_id) if negative_prompt_tokens_id is not None else None
    )
    # Split the token ids per prompt
    number_prompt = prompt_tokens_id.shape[0]
    prompt_tokens_id = np.split(prompt_tokens_id, number_prompt)
    negative_prompt_tokens_id = (
        np.split(negative_prompt_tokens_id, number_prompt) if negative_prompt_tokens_id is not None else None
    )
    lst_prompt_map = []
    if not isinstance(state, list):
        state = [state]
    if len(state) < number_prompt:
        # Pad with None so every prompt has a (possibly empty) region state
        state = state + [None] * (number_prompt - len(state))
    for i in range(number_prompt):
        text_ids = (
            [negative_prompt_tokens_id[i], prompt_tokens_id[i]]
            if negative_prompt_tokens_id is not None
            else [None, prompt_tokens_id[i]]
        )
        region_map = encode_region_map_sp(
            state[i],
            pipe.tokenizer,
            pipe.unet,
            width,
            height,
            scale_ratio=pipe.vae_scale_factor,
            text_ids=text_ids,
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        )
        lst_prompt_map.append(region_map)
    # Merge the per-prompt maps into one dict keyed by resolution
    region_state_sp = {}
    for d in lst_prompt_map:
        for key, tensor in d.items():
            if key in region_state_sp:
                # Key already exists: concatenate along the batch dimension
                region_state_sp[key] = torch.cat((region_state_sp[key], tensor))
            else:
                # Key does not exist yet: add it
                region_state_sp[key] = tensor
    # Account for num_images_per_prompt
    region_state = {}
    for key, tensor in region_state_sp.items():
        # Repeat along axis 0 (batch dimension)
        region_state[key] = tensor.repeat(num_images_per_prompt, 1, 1)
    return region_state
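
# Usage sketch (assumptions: `pipe` is a diffusers pipeline exposing `tokenizer`, `unet`,
# `vae_scale_factor`, and `do_classifier_free_guidance`, and the token-id arrays come from the
# pipeline's own prompt encoding; variable names below are illustrative, not part of the source):
#
#   state = build_example_state(width=512, height=512)
#   region_state = encode_region_map(
#       pipe,
#       state,
#       width=512,
#       height=512,
#       num_images_per_prompt=1,
#       text_ids=[negative_prompt_tokens_id, prompt_tokens_id],
#   )
#   # region_state maps each down-block resolution (w_r * h_r) to a tensor of per-token weights
#   # that can later be injected into the UNet's cross-attention.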