import inspect
import re
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import diffusers
import numpy as np
import PIL
import torch
from accelerate import init_empty_weights
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    EulerDiscreteScheduler,
    LCMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionXLPipeline,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.utils.deprecation_utils import deprecate
from einops import rearrange
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

import external.llite.library.model_util as model_util
import external.llite.library.sdxl_model_util as sdxl_model_util
import external.llite.library.sdxl_original_unet as sdxl_original_unet
import external.llite.library.sdxl_train_util as sdxl_train_util
import external.llite.library.train_util as train_util
from external.llite.library.original_unet import FlashAttentionFunction
from external.llite.library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from external.llite.networks.control_net_lllite import ControlNetLLLite
from external.llite.networks.lora import LoRANetwork
from internals.pipelines.commons import AbstractPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_file

class PipelineLike: |
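    """Minimal SDXL pipeline used for ControlNet-LLLite img2img.

    Holds the VAE, the two SDXL text encoders/tokenizers, the kohya
    InferSdxlUNet2DConditionModel and a scheduler, and implements weighted
    prompt encoding plus the denoising loop directly instead of going
    through diffusers' StableDiffusionXLPipeline.
    """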
|
def __init__( |
|
self, |
|
device, |
|
vae: AutoencoderKL, |
|
text_encoders: List[CLIPTextModel], |
|
tokenizers: List[CLIPTokenizer], |
|
unet: InferSdxlUNet2DConditionModel, |
|
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], |
|
clip_skip: int, |
|
): |
|
super().__init__() |
|
self.device = device |
|
self.clip_skip = clip_skip |
|
|
|
if ( |
|
hasattr(scheduler.config, "steps_offset") |
|
and scheduler.config.steps_offset != 1 |
|
): |
|
deprecation_message = ( |
|
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" |
|
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " |
|
"to update the config accordingly as leaving `steps_offset` might led to incorrect results" |
|
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," |
|
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" |
|
" file" |
|
) |
|
deprecate( |
|
"steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False |
|
) |
|
new_config = dict(scheduler.config) |
|
new_config["steps_offset"] = 1 |
|
scheduler._internal_dict = FrozenDict(new_config) |
|
|
|
if ( |
|
hasattr(scheduler.config, "clip_sample") |
|
and scheduler.config.clip_sample is True |
|
): |
|
deprecation_message = ( |
|
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." |
|
" `clip_sample` should be set to False in the configuration file. Please make sure to update the" |
|
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" |
|
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" |
|
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" |
|
) |
|
deprecate( |
|
"clip_sample not set", "1.0.0", deprecation_message, standard_warn=False |
|
) |
|
new_config = dict(scheduler.config) |
|
new_config["clip_sample"] = False |
|
scheduler._internal_dict = FrozenDict(new_config) |
|
|
|
self.vae = vae |
|
self.text_encoders = text_encoders |
|
self.tokenizers = tokenizers |
|
self.unet: InferSdxlUNet2DConditionModel = unet |
|
self.scheduler = scheduler |
|
self.safety_checker = None |
|
|
|
self.clip_vision_model: CLIPVisionModelWithProjection = None |
|
self.clip_vision_processor: CLIPImageProcessor = None |
|
self.clip_vision_strength = 0.0 |
|
|
|
|
|
self.token_replacements_list = [] |
|
for _ in range(len(self.text_encoders)): |
|
self.token_replacements_list.append({}) |
|
|
|
|
|
self.control_nets: List[ControlNetLLLite] = [] |
|
self.control_net_enabled = True |
|
|
|
|
|
def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): |
|
self.token_replacements_list[text_encoder_index][ |
|
target_token_id |
|
] = rep_token_ids |
|
|
|
def set_enable_control_net(self, en: bool): |
|
self.control_net_enabled = en |
|
|
|
def preprocess_image(self, image): |
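        # Resize to a multiple of 32, convert HWC uint8 -> NCHW float, and scale to [-1, 1].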
|
w, h = image.size |
|
|
|
w, h = map(lambda x: x - x % 32, (w, h)) |
|
image = image.resize((w, h), resample=PIL.Image.LANCZOS) |
|
image = np.array(image).astype(np.float32) / 255.0 |
|
image = image[None].transpose(0, 3, 1, 2) |
|
image = torch.from_numpy(image) |
|
return 2.0 * image - 1.0 |
|
|
|
def get_unweighted_text_embeddings( |
|
self, |
|
text_encoder: CLIPTextModel, |
|
text_input: torch.Tensor, |
|
chunk_length: int, |
|
clip_skip: int, |
|
eos: int, |
|
pad: int, |
|
no_boseos_middle: Optional[bool] = True, |
|
): |
|
""" |
|
When the length of tokens is a multiple of the capacity of the text encoder, |
|
it should be split into chunks and sent to the text encoder individually. |
|
""" |
|
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) |
|
if max_embeddings_multiples > 1: |
|
text_embeddings = [] |
|
pool = None |
|
for i in range(max_embeddings_multiples): |
|
|
|
text_input_chunk = text_input[ |
|
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 |
|
].clone() |
|
|
|
|
|
text_input_chunk[:, 0] = text_input[0, 0] |
|
if pad == eos: |
|
text_input_chunk[:, -1] = text_input[0, -1] |
|
else: |
|
for j in range(len(text_input_chunk)): |
|
|
|
if ( |
|
text_input_chunk[j, -1] != eos |
|
and text_input_chunk[j, -1] != pad |
|
): |
|
text_input_chunk[j, -1] = eos |
|
if text_input_chunk[j, 1] == pad: |
|
text_input_chunk[j, 1] = eos |
|
|
|
|
|
enc_out = text_encoder( |
|
text_input_chunk, output_hidden_states=True, return_dict=True |
|
) |
|
text_embedding = enc_out["hidden_states"][-2] |
|
if pool is None: |
|
|
|
pool = enc_out.get("text_embeds", None) |
|
if pool is not None: |
|
pool = train_util.pool_workaround( |
|
text_encoder, |
|
enc_out["last_hidden_state"], |
|
text_input_chunk, |
|
eos, |
|
) |
|
|
|
if no_boseos_middle: |
|
if i == 0: |
|
|
|
text_embedding = text_embedding[:, :-1] |
|
elif i == max_embeddings_multiples - 1: |
|
|
|
text_embedding = text_embedding[:, 1:] |
|
else: |
|
|
|
text_embedding = text_embedding[:, 1:-1] |
|
|
|
text_embeddings.append(text_embedding) |
|
text_embeddings = torch.concat(text_embeddings, axis=1) |
|
else: |
|
enc_out = text_encoder( |
|
text_input, output_hidden_states=True, return_dict=True |
|
) |
|
text_embeddings = enc_out["hidden_states"][-2] |
|
|
|
pool = enc_out.get("text_embeds", None) |
|
if pool is not None: |
|
pool = train_util.pool_workaround( |
|
text_encoder, enc_out["last_hidden_state"], text_input, eos |
|
) |
|
return text_embeddings, pool |
|
|
|
def preprocess_mask(self, mask): |
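        # Convert a PIL mask to a 4-channel latent-resolution (1/8) tensor and invert it,
        # so white (255) pixels mark the region that will be repainted.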
|
mask = mask.convert("L") |
|
w, h = mask.size |
|
|
|
w, h = map(lambda x: x - x % 32, (w, h)) |
|
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) |
|
mask = np.array(mask).astype(np.float32) / 255.0 |
|
mask = np.tile(mask, (4, 1, 1)) |
|
mask = mask[None].transpose(0, 1, 2, 3) |
|
mask = 1 - mask |
|
mask = torch.from_numpy(mask) |
|
return mask |
|
|
|
def get_prompts_with_weights( |
|
self, |
|
tokenizer: CLIPTokenizer, |
|
token_replacer, |
|
prompt: List[str], |
|
max_length: int, |
|
): |
|
r""" |
|
Tokenize a list of prompts and return its tokens with weights of each token. |
|
No padding, starting or ending token is included. |
|
""" |
|
tokens = [] |
|
weights = [] |
|
truncated = False |
|
|
|
def parse_prompt_attention(text): |
|
""" |
|
Parses a string with attention tokens and returns a list of pairs: text and its associated weight. |
|
Accepted tokens are: |
|
(abc) - increases attention to abc by a multiplier of 1.1 |
|
(abc:3.12) - increases attention to abc by a multiplier of 3.12 |
|
[abc] - decreases attention to abc by a multiplier of 1.1 |
|
\( - literal character '(' |
|
\[ - literal character '[' |
|
\) - literal character ')' |
|
\] - literal character ']' |
|
\\ - literal character '\' |
|
anything else - just text |
|
>>> parse_prompt_attention('normal text') |
|
[['normal text', 1.0]] |
|
>>> parse_prompt_attention('an (important) word') |
|
[['an ', 1.0], ['important', 1.1], [' word', 1.0]] |
|
>>> parse_prompt_attention('(unbalanced') |
|
[['unbalanced', 1.1]] |
|
>>> parse_prompt_attention('\(literal\]') |
|
[['(literal]', 1.0]] |
|
>>> parse_prompt_attention('(unnecessary)(parens)') |
|
[['unnecessaryparens', 1.1]] |
|
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') |
|
[['a ', 1.0], |
|
['house', 1.5730000000000004], |
|
[' ', 1.1], |
|
['on', 1.0], |
|
[' a ', 1.1], |
|
['hill', 0.55], |
|
[', sun, ', 1.1], |
|
['sky', 1.4641000000000006], |
|
['.', 1.1]] |
|
""" |
|
|
|
res = [] |
|
round_brackets = [] |
|
square_brackets = [] |
|
|
|
round_bracket_multiplier = 1.1 |
|
square_bracket_multiplier = 1 / 1.1 |
|
|
|
def multiply_range(start_position, multiplier): |
|
for p in range(start_position, len(res)): |
|
res[p][1] *= multiplier |
|
|
|
|
|
text = text.replace("BREAK", "\\BREAK\\") |
|
re_attention = re.compile( |
|
r""" |
|
\\\(| |
|
\\\)| |
|
\\\[| |
|
\\]| |
|
\\\\| |
|
\\| |
|
\(| |
|
\[| |
|
:([+-]?[.\d]+)\)| |
|
\)| |
|
]| |
|
[^\\()\[\]:]+| |
|
: |
|
""", |
|
re.X, |
|
) |
|
for m in re_attention.finditer(text): |
|
text = m.group(0) |
|
weight = m.group(1) |
|
|
|
if text.startswith("\\"): |
|
res.append([text[1:], 1.0]) |
|
elif text == "(": |
|
round_brackets.append(len(res)) |
|
elif text == "[": |
|
square_brackets.append(len(res)) |
|
elif weight is not None and len(round_brackets) > 0: |
|
multiply_range(round_brackets.pop(), float(weight)) |
|
elif text == ")" and len(round_brackets) > 0: |
|
multiply_range(round_brackets.pop(), round_bracket_multiplier) |
|
elif text == "]" and len(square_brackets) > 0: |
|
multiply_range(square_brackets.pop(), square_bracket_multiplier) |
|
else: |
|
res.append([text, 1.0]) |
|
|
|
for pos in round_brackets: |
|
multiply_range(pos, round_bracket_multiplier) |
|
|
|
for pos in square_brackets: |
|
multiply_range(pos, square_bracket_multiplier) |
|
|
|
if len(res) == 0: |
|
res = [["", 1.0]] |
|
|
|
|
|
i = 0 |
|
while i + 1 < len(res): |
|
if ( |
|
res[i][1] == res[i + 1][1] |
|
and res[i][0].strip() != "BREAK" |
|
and res[i + 1][0].strip() != "BREAK" |
|
): |
|
res[i][0] += res[i + 1][0] |
|
res.pop(i + 1) |
|
else: |
|
i += 1 |
|
|
|
return res |
|
|
|
for text in prompt: |
|
texts_and_weights = parse_prompt_attention(text) |
|
text_token = [] |
|
text_weight = [] |
|
for word, weight in texts_and_weights: |
|
if word.strip() == "BREAK": |
|
|
|
pad_len = tokenizer.model_max_length - ( |
|
len(text_token) % tokenizer.model_max_length |
|
) |
|
print(f"BREAK pad_len: {pad_len}") |
|
for i in range(pad_len): |
|
|
|
|
|
|
|
|
|
text_token.append(tokenizer.pad_token_id) |
|
text_weight.append(1.0) |
|
continue |
|
|
|
|
|
token = tokenizer(word).input_ids[1:-1] |
|
|
|
token = token_replacer(token) |
|
|
|
text_token += token |
|
|
|
text_weight += [weight] * len(token) |
|
|
|
if len(text_token) > max_length: |
|
truncated = True |
|
break |
|
|
|
if len(text_token) > max_length: |
|
truncated = True |
|
text_token = text_token[:max_length] |
|
text_weight = text_weight[:max_length] |
|
tokens.append(text_token) |
|
weights.append(text_weight) |
|
if truncated: |
|
print( |
|
"warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples" |
|
) |
|
return tokens, weights |
|
|
|
def pad_tokens_and_weights( |
|
self, |
|
tokens, |
|
weights, |
|
max_length, |
|
bos, |
|
eos, |
|
pad, |
|
no_boseos_middle=True, |
|
chunk_length=77, |
|
): |
|
r""" |
|
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. |
|
""" |
|
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) |
|
weights_length = ( |
|
max_length if no_boseos_middle else max_embeddings_multiples * chunk_length |
|
) |
|
for i in range(len(tokens)): |
|
tokens[i] = ( |
|
[bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) |
|
) |
|
if no_boseos_middle: |
|
weights[i] = ( |
|
[1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) |
|
) |
|
else: |
|
w = [] |
|
if len(weights[i]) == 0: |
|
w = [1.0] * weights_length |
|
else: |
|
for j in range(max_embeddings_multiples): |
|
|
|
w.append(1.0) |
|
w += weights[i][ |
|
j |
|
* (chunk_length - 2) : min( |
|
len(weights[i]), (j + 1) * (chunk_length - 2) |
|
) |
|
] |
|
w.append(1.0) |
|
w += [1.0] * (weights_length - len(w)) |
|
weights[i] = w[:] |
|
|
|
return tokens, weights |
|
|
def get_weighted_text_embeddings( |
|
self, |
|
tokenizer: CLIPTokenizer, |
|
text_encoder: CLIPTextModel, |
|
prompt: Union[str, List[str]], |
|
uncond_prompt: Optional[Union[str, List[str]]] = None, |
|
max_embeddings_multiples: Optional[int] = 1, |
|
no_boseos_middle: Optional[bool] = False, |
|
skip_parsing: Optional[bool] = False, |
|
skip_weighting: Optional[bool] = False, |
|
clip_skip=None, |
|
token_replacer=None, |
|
device=None, |
|
**kwargs, |
|
): |
|
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 |
|
if isinstance(prompt, str): |
|
prompt = [prompt] |
|
|
|
|
|
new_prompts = [] |
|
for p in prompt: |
|
new_prompts.extend(p.split(" AND ")) |
|
prompt = new_prompts |
|
|
|
if not skip_parsing: |
|
prompt_tokens, prompt_weights = self.get_prompts_with_weights( |
|
tokenizer, token_replacer, prompt, max_length - 2 |
|
) |
|
if uncond_prompt is not None: |
|
if isinstance(uncond_prompt, str): |
|
uncond_prompt = [uncond_prompt] |
|
uncond_tokens, uncond_weights = self.get_prompts_with_weights( |
|
tokenizer, token_replacer, uncond_prompt, max_length - 2 |
|
) |
|
else: |
|
prompt_tokens = [ |
|
token[1:-1] |
|
for token in tokenizer( |
|
prompt, max_length=max_length, truncation=True |
|
).input_ids |
|
] |
|
prompt_weights = [[1.0] * len(token) for token in prompt_tokens] |
|
if uncond_prompt is not None: |
|
if isinstance(uncond_prompt, str): |
|
uncond_prompt = [uncond_prompt] |
|
uncond_tokens = [ |
|
token[1:-1] |
|
for token in tokenizer( |
|
uncond_prompt, max_length=max_length, truncation=True |
|
).input_ids |
|
] |
|
uncond_weights = [[1.0] * len(token) for token in uncond_tokens] |
|
|
|
|
|
max_length = max([len(token) for token in prompt_tokens]) |
|
if uncond_prompt is not None: |
|
max_length = max(max_length, max([len(token) for token in uncond_tokens])) |
|
|
|
max_embeddings_multiples = min( |
|
max_embeddings_multiples, |
|
(max_length - 1) // (tokenizer.model_max_length - 2) + 1, |
|
) |
|
max_embeddings_multiples = max(1, max_embeddings_multiples) |
|
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 |
|
|
|
|
|
bos = tokenizer.bos_token_id |
|
eos = tokenizer.eos_token_id |
|
pad = tokenizer.pad_token_id |
|
prompt_tokens, prompt_weights = self.pad_tokens_and_weights( |
|
prompt_tokens, |
|
prompt_weights, |
|
max_length, |
|
bos, |
|
eos, |
|
pad, |
|
no_boseos_middle=no_boseos_middle, |
|
chunk_length=tokenizer.model_max_length, |
|
) |
|
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) |
|
if uncond_prompt is not None: |
|
uncond_tokens, uncond_weights = self.pad_tokens_and_weights( |
|
uncond_tokens, |
|
uncond_weights, |
|
max_length, |
|
bos, |
|
eos, |
|
pad, |
|
no_boseos_middle=no_boseos_middle, |
|
chunk_length=tokenizer.model_max_length, |
|
) |
|
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) |
|
|
|
|
|
text_embeddings, text_pool = self.get_unweighted_text_embeddings( |
|
text_encoder, |
|
prompt_tokens, |
|
tokenizer.model_max_length, |
|
clip_skip, |
|
eos, |
|
pad, |
|
no_boseos_middle=no_boseos_middle, |
|
) |
|
prompt_weights = torch.tensor( |
|
prompt_weights, dtype=text_embeddings.dtype, device=device |
|
) |
|
if uncond_prompt is not None: |
|
uncond_embeddings, uncond_pool = self.get_unweighted_text_embeddings( |
|
text_encoder, |
|
uncond_tokens, |
|
tokenizer.model_max_length, |
|
clip_skip, |
|
eos, |
|
pad, |
|
no_boseos_middle=no_boseos_middle, |
|
) |
|
uncond_weights = torch.tensor( |
|
uncond_weights, dtype=uncond_embeddings.dtype, device=device |
|
) |
|
|
|
|
|
|
|
|
|
if (not skip_parsing) and (not skip_weighting): |
|
previous_mean = ( |
|
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) |
|
) |
|
text_embeddings *= prompt_weights.unsqueeze(-1) |
|
current_mean = ( |
|
text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) |
|
) |
|
text_embeddings *= ( |
|
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) |
|
) |
|
if uncond_prompt is not None: |
|
previous_mean = ( |
|
uncond_embeddings.float() |
|
.mean(axis=[-2, -1]) |
|
.to(uncond_embeddings.dtype) |
|
) |
|
uncond_embeddings *= uncond_weights.unsqueeze(-1) |
|
current_mean = ( |
|
uncond_embeddings.float() |
|
.mean(axis=[-2, -1]) |
|
.to(uncond_embeddings.dtype) |
|
) |
|
uncond_embeddings *= ( |
|
(previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) |
|
) |
|
|
|
if uncond_prompt is not None: |
|
return ( |
|
text_embeddings, |
|
text_pool, |
|
uncond_embeddings, |
|
uncond_pool, |
|
prompt_tokens, |
|
) |
|
return text_embeddings, text_pool, None, None, prompt_tokens |
|
|
|
def get_token_replacer(self, tokenizer): |
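        # Returns a closure that expands registered token ids (added via
        # add_token_replacement, e.g. for textual-inversion style embeddings)
        # inside a token sequence for the given tokenizer.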
|
tokenizer_index = self.tokenizers.index(tokenizer) |
|
token_replacements = self.token_replacements_list[tokenizer_index] |
|
|
|
def replace_tokens(tokens): |
|
|
|
if isinstance(tokens, torch.Tensor): |
|
tokens = tokens.tolist() |
|
|
|
new_tokens = [] |
|
for token in tokens: |
|
if token in token_replacements: |
|
replacement = token_replacements[token] |
|
new_tokens.extend(replacement) |
|
else: |
|
new_tokens.append(token) |
|
return new_tokens |
|
|
|
return replace_tokens |
|
|
|
def set_control_nets(self, ctrl_nets): |
|
self.control_nets = ctrl_nets |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
prompt: Union[str, List[str]], |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
init_image: Union[ |
|
torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image] |
|
] = None, |
|
mask_image: Union[ |
|
torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image] |
|
] = None, |
|
height: int = 1024, |
|
width: int = 1024, |
|
original_height: int = None, |
|
original_width: int = None, |
|
original_height_negative: int = None, |
|
original_width_negative: int = None, |
|
crop_top: int = 0, |
|
crop_left: int = 0, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
negative_scale: float = None, |
|
strength: float = 0.8, |
|
|
|
eta: float = 0.0, |
|
generator: Optional[torch.Generator] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
max_embeddings_multiples: Optional[int] = 3, |
|
output_type: Optional[str] = "pil", |
|
vae_batch_size: float = None, |
|
return_latents: bool = False, |
|
|
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
|
is_cancelled_callback: Optional[Callable[[], bool]] = None, |
|
callback_steps: Optional[int] = 1, |
|
img2img_noise=None, |
|
clip_guide_images=None, |
|
**kwargs, |
|
): |
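        """Run txt2img / img2img / inpainting for SDXL with optional ControlNet-LLLite guidance.

        `clip_guide_images` carries the ControlNet conditioning images; `img2img_noise`
        is the pre-generated noise used when `init_image` is given.
        """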
|
|
|
num_images_per_prompt = 1 |
|
|
|
if isinstance(prompt, str): |
|
batch_size = 1 |
|
prompt = [prompt] |
|
elif isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
raise ValueError( |
|
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" |
|
) |
|
        regional_network = " AND " in prompt[0]
|
|
|
vae_batch_size = ( |
|
batch_size |
|
if vae_batch_size is None |
|
else ( |
|
int(vae_batch_size) |
|
if vae_batch_size >= 1 |
|
else max(1, int(batch_size * vae_batch_size)) |
|
) |
|
) |
|
|
|
if strength < 0 or strength > 1: |
|
raise ValueError( |
|
f"The value of strength should in [0.0, 1.0] but is {strength}" |
|
) |
|
|
|
if height % 8 != 0 or width % 8 != 0: |
|
raise ValueError( |
|
f"`height` and `width` have to be divisible by 8 but are {height} and {width}." |
|
) |
|
|
|
if (callback_steps is None) or ( |
|
callback_steps is not None |
|
and (not isinstance(callback_steps, int) or callback_steps <= 0) |
|
): |
|
raise ValueError( |
|
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
|
f" {type(callback_steps)}." |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
        if not do_classifier_free_guidance and negative_scale is not None:
            print("negative_scale is ignored if guidance scale <= 1.0")
            negative_scale = None
|
|
|
|
|
if negative_prompt is None: |
|
negative_prompt = [""] * batch_size |
|
elif isinstance(negative_prompt, str): |
|
negative_prompt = [negative_prompt] * batch_size |
|
if batch_size != len(negative_prompt): |
|
raise ValueError( |
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" |
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" |
|
" the batch size of `prompt`." |
|
) |
|
|
|
tes_text_embs = [] |
|
tes_uncond_embs = [] |
|
tes_real_uncond_embs = [] |
|
|
|
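        # Encode the prompt with each SDXL text encoder; the hidden states are later
        # concatenated along the feature dimension, and the pooled output is taken
        # from the last text encoder.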
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): |
|
token_replacer = self.get_token_replacer(tokenizer) |
|
|
|
|
|
( |
|
text_embeddings, |
|
text_pool, |
|
uncond_embeddings, |
|
uncond_pool, |
|
_, |
|
) = self.get_weighted_text_embeddings( |
|
tokenizer, |
|
text_encoder, |
|
prompt=prompt, |
|
uncond_prompt=negative_prompt if do_classifier_free_guidance else None, |
|
max_embeddings_multiples=max_embeddings_multiples, |
|
clip_skip=self.clip_skip, |
|
token_replacer=token_replacer, |
|
device=self.device, |
|
**kwargs, |
|
) |
|
tes_text_embs.append(text_embeddings) |
|
tes_uncond_embs.append(uncond_embeddings) |
|
|
|
            if negative_scale is not None:
                # encode an empty negative prompt to obtain the "real" unconditional embeddings
                _, _, real_uncond_embeddings, _, _ = self.get_weighted_text_embeddings(
                    tokenizer,
                    text_encoder,
                    prompt=prompt,
                    uncond_prompt=[""] * batch_size,
                    max_embeddings_multiples=max_embeddings_multiples,
                    clip_skip=self.clip_skip,
                    token_replacer=token_replacer,
                    device=self.device,
                    **kwargs,
                )
                tes_real_uncond_embs.append(real_uncond_embeddings)
|
|
|
|
|
        text_embeddings = tes_text_embs[0]
        uncond_embeddings = tes_uncond_embs[0]
        real_uncond_embeddings = (
            tes_real_uncond_embs[0] if negative_scale is not None else None
        )
        for i in range(1, len(tes_text_embs)):
            text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2)
            if do_classifier_free_guidance:
                uncond_embeddings = torch.cat(
                    [uncond_embeddings, tes_uncond_embs[i]], dim=2
                )
                if negative_scale is not None:
                    real_uncond_embeddings = torch.cat(
                        [real_uncond_embeddings, tes_real_uncond_embs[i]], dim=2
                    )
|
|
|
if do_classifier_free_guidance: |
|
if negative_scale is None: |
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) |
|
else: |
|
text_embeddings = torch.cat( |
|
[uncond_embeddings, text_embeddings, real_uncond_embeddings] |
|
) |
|
|
|
if self.control_nets: |
|
|
|
if isinstance(clip_guide_images, PIL.Image.Image): |
|
clip_guide_images = [clip_guide_images] |
|
if isinstance(clip_guide_images[0], PIL.Image.Image): |
|
clip_guide_images = [ |
|
self.preprocess_image(im) for im in clip_guide_images |
|
] |
|
clip_guide_images = torch.cat(clip_guide_images) |
|
if isinstance(clip_guide_images, list): |
|
clip_guide_images = torch.stack(clip_guide_images) |
|
|
|
clip_guide_images = clip_guide_images.to( |
|
self.device, dtype=text_embeddings.dtype |
|
) |
|
|
|
|
|
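        # SDXL micro-conditioning: timestep-style embeddings of (original size, crop, target size)
        # are concatenated with the pooled text embedding to form the U-Net's vector input.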
if original_height is None: |
|
original_height = height |
|
if original_width is None: |
|
original_width = width |
|
if original_height_negative is None: |
|
original_height_negative = original_height |
|
if original_width_negative is None: |
|
original_width_negative = original_width |
|
if crop_top is None: |
|
crop_top = 0 |
|
if crop_left is None: |
|
crop_left = 0 |
|
emb1 = sdxl_train_util.get_timestep_embedding( |
|
torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256 |
|
) |
|
uc_emb1 = sdxl_train_util.get_timestep_embedding( |
|
torch.FloatTensor( |
|
[original_height_negative, original_width_negative] |
|
).unsqueeze(0), |
|
256, |
|
) |
|
emb2 = sdxl_train_util.get_timestep_embedding( |
|
torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256 |
|
) |
|
emb3 = sdxl_train_util.get_timestep_embedding( |
|
torch.FloatTensor([height, width]).unsqueeze(0), 256 |
|
) |
|
c_vector = ( |
|
torch.cat([emb1, emb2, emb3], dim=1) |
|
.to(self.device, dtype=text_embeddings.dtype) |
|
.repeat(batch_size, 1) |
|
) |
|
uc_vector = ( |
|
torch.cat([uc_emb1, emb2, emb3], dim=1) |
|
.to(self.device, dtype=text_embeddings.dtype) |
|
.repeat(batch_size, 1) |
|
) |
|
|
|
        if regional_network:
|
|
|
num_sub_prompts = len(text_pool) // batch_size |
|
text_pool = text_pool[ |
|
num_sub_prompts - 1 :: num_sub_prompts |
|
] |
|
|
|
if init_image is not None and self.clip_vision_model is not None: |
|
print( |
|
f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}" |
|
) |
|
vision_input = self.clip_vision_processor( |
|
init_image, return_tensors="pt", device=self.device |
|
) |
|
pixel_values = vision_input["pixel_values"].to( |
|
self.device, dtype=text_embeddings.dtype |
|
) |
|
|
|
clip_vision_embeddings = self.clip_vision_model( |
|
pixel_values=pixel_values, output_hidden_states=True, return_dict=True |
|
) |
|
clip_vision_embeddings = clip_vision_embeddings.image_embeds |
|
|
|
if len(clip_vision_embeddings) == 1 and batch_size > 1: |
|
clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) |
|
|
|
clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength |
|
assert ( |
|
clip_vision_embeddings.shape == text_pool.shape |
|
), f"{clip_vision_embeddings.shape} != {text_pool.shape}" |
|
text_pool = clip_vision_embeddings |
|
|
|
c_vector = torch.cat([text_pool, c_vector], dim=1) |
|
if do_classifier_free_guidance: |
|
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) |
|
vector_embeddings = torch.cat([uc_vector, c_vector]) |
|
else: |
|
vector_embeddings = c_vector |
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, self.device) |
|
|
|
latents_dtype = text_embeddings.dtype |
|
init_latents_orig = None |
|
mask = None |
|
|
|
if init_image is None: |
|
|
|
|
|
|
|
|
|
|
|
latents_shape = ( |
|
batch_size * num_images_per_prompt, |
|
self.unet.in_channels, |
|
height // 8, |
|
width // 8, |
|
) |
|
|
|
if latents is None: |
|
if self.device.type == "mps": |
|
|
|
latents = torch.randn( |
|
latents_shape, |
|
generator=generator, |
|
device="cpu", |
|
dtype=latents_dtype, |
|
).to(self.device) |
|
else: |
|
latents = torch.randn( |
|
latents_shape, |
|
generator=generator, |
|
device=self.device, |
|
dtype=latents_dtype, |
|
) |
|
else: |
|
if latents.shape != latents_shape: |
|
raise ValueError( |
|
f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}" |
|
) |
|
latents = latents.to(self.device) |
|
|
|
timesteps = self.scheduler.timesteps.to(self.device) |
|
|
|
|
|
latents = latents * self.scheduler.init_noise_sigma |
|
else: |
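            # img2img / inpainting path: preprocess init_image (and mask), encode it to
            # latents with the VAE, then add noise according to the chosen strength.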
|
|
|
if isinstance(init_image, PIL.Image.Image): |
|
init_image = [init_image] |
|
if isinstance(init_image[0], PIL.Image.Image): |
|
init_image = [self.preprocess_image(im) for im in init_image] |
|
init_image = torch.cat(init_image) |
|
if isinstance(init_image, list): |
|
init_image = torch.stack(init_image) |
|
|
|
|
|
if mask_image is not None: |
|
if isinstance(mask_image, PIL.Image.Image): |
|
mask_image = [mask_image] |
|
if isinstance(mask_image[0], PIL.Image.Image): |
|
mask_image = torch.cat( |
|
[self.preprocess_mask(im) for im in mask_image] |
|
) |
|
|
|
|
|
init_image = init_image.to(device=self.device, dtype=latents_dtype) |
|
if init_image.size()[-2:] == (height // 8, width // 8): |
|
init_latents = init_image |
|
else: |
|
if vae_batch_size >= batch_size: |
|
init_latent_dist = self.vae.encode( |
|
init_image.to(self.vae.dtype) |
|
).latent_dist |
|
init_latents = init_latent_dist.sample(generator=generator) |
|
else: |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
init_latents = [] |
|
for i in tqdm( |
|
range(0, min(batch_size, len(init_image)), vae_batch_size) |
|
): |
|
init_latent_dist = self.vae.encode( |
|
( |
|
init_image[i : i + vae_batch_size] |
|
if vae_batch_size > 1 |
|
else init_image[i].unsqueeze(0) |
|
).to(self.vae.dtype) |
|
).latent_dist |
|
init_latents.append( |
|
init_latent_dist.sample(generator=generator) |
|
) |
|
init_latents = torch.cat(init_latents) |
|
|
|
init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents |
|
|
|
if len(init_latents) == 1: |
|
init_latents = init_latents.repeat((batch_size, 1, 1, 1)) |
|
init_latents_orig = init_latents |
|
|
|
|
|
if mask_image is not None: |
|
mask = mask_image.to(device=self.device, dtype=latents_dtype) |
|
if len(mask) == 1: |
|
mask = mask.repeat((batch_size, 1, 1, 1)) |
|
|
|
|
|
if not mask.shape == init_latents.shape: |
|
raise ValueError("The mask and init_image should be the same size!") |
|
|
|
|
|
offset = self.scheduler.config.get("steps_offset", 0) |
|
init_timestep = int(num_inference_steps * strength) + offset |
|
init_timestep = min(init_timestep, num_inference_steps) |
|
|
|
timesteps = self.scheduler.timesteps[-init_timestep] |
|
timesteps = torch.tensor( |
|
[timesteps] * batch_size * num_images_per_prompt, device=self.device |
|
) |
|
|
|
|
|
latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) |
|
|
|
t_start = max(num_inference_steps - init_timestep + offset, 0) |
|
timesteps = self.scheduler.timesteps[t_start:].to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
accepts_eta = "eta" in set( |
|
inspect.signature(self.scheduler.step).parameters.keys() |
|
) |
|
extra_step_kwargs = {} |
|
if accepts_eta: |
|
extra_step_kwargs["eta"] = eta |
|
|
|
num_latent_input = ( |
|
(3 if negative_scale is not None else 2) |
|
if do_classifier_free_guidance |
|
else 1 |
|
) |
|
|
|
if self.control_nets: |
|
|
|
if self.control_net_enabled: |
|
for control_net, _ in self.control_nets: |
|
with torch.no_grad(): |
|
control_net.set_cond_image(clip_guide_images) |
|
else: |
|
for control_net, _ in self.control_nets: |
|
control_net.set_cond_image(None) |
|
|
|
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) |
|
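        # Denoising loop: duplicate the latents for classifier-free guidance, predict noise
        # with the U-Net, combine the predictions, and step the scheduler.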
for i, t in enumerate(tqdm(timesteps)): |
|
|
|
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) |
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
if self.control_nets and self.control_net_enabled: |
|
for j, ((control_net, ratio), enabled) in enumerate( |
|
zip(self.control_nets, each_control_net_enabled) |
|
): |
|
if not enabled or ratio >= 1.0: |
|
continue |
|
if ratio < i / len(timesteps): |
|
print( |
|
f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})" |
|
) |
|
control_net.set_cond_image(None) |
|
each_control_net_enabled[j] = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
noise_pred = self.unet( |
|
latent_model_input, t, text_embeddings, vector_embeddings |
|
) |
|
|
|
|
|
if do_classifier_free_guidance: |
|
if negative_scale is None: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk( |
|
num_latent_input |
|
) |
|
noise_pred = noise_pred_uncond + guidance_scale * ( |
|
noise_pred_text - noise_pred_uncond |
|
) |
|
else: |
|
( |
|
noise_pred_negative, |
|
noise_pred_text, |
|
noise_pred_uncond, |
|
) = noise_pred.chunk( |
|
num_latent_input |
|
) |
|
noise_pred = ( |
|
noise_pred_uncond |
|
+ guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
- negative_scale * (noise_pred_negative - noise_pred_uncond) |
|
) |
|
|
|
|
|
latents = self.scheduler.step( |
|
noise_pred, t, latents, **extra_step_kwargs |
|
).prev_sample |
|
|
|
if mask is not None: |
|
|
|
init_latents_proper = self.scheduler.add_noise( |
|
init_latents_orig, img2img_noise, torch.tensor([t]) |
|
) |
|
latents = (init_latents_proper * mask) + (latents * (1 - mask)) |
|
|
|
|
|
if i % callback_steps == 0: |
|
if callback is not None: |
|
callback(i, t, latents) |
|
if is_cancelled_callback is not None and is_cancelled_callback(): |
|
return None |
|
|
|
if return_latents: |
|
return latents |
|
|
|
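        # Decode the latents back to image space (optionally in VAE-sized batches) and convert to PIL.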
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents |
|
if vae_batch_size >= batch_size: |
|
image = self.vae.decode(latents.to(self.vae.dtype)).sample |
|
else: |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
images = [] |
|
for i in tqdm(range(0, batch_size, vae_batch_size)): |
|
images.append( |
|
self.vae.decode( |
|
( |
|
latents[i : i + vae_batch_size] |
|
if vae_batch_size > 1 |
|
else latents[i].unsqueeze(0) |
|
).to(self.vae.dtype) |
|
).sample |
|
) |
|
image = torch.cat(images) |
|
|
|
image = (image / 2 + 0.5).clamp(0, 1) |
|
|
|
|
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
if output_type == "pil": |
|
|
|
image = (image * 255).round().astype("uint8") |
|
image = [Image.fromarray(im) for im in image] |
|
|
|
return image |
|
|
|
|
|
class SDXLLLiteImg2ImgPipeline: |
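    """Wraps a diffusers SDXL pipeline into the kohya-style PipelineLike, with
    LCM-LoRA fused into the U-Net and ControlNet-LLLite conditioning attached."""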
|
def __init__(self): |
|
self.SCHEDULER_LINEAR_START = 0.00085 |
|
self.SCHEDULER_LINEAR_END = 0.0120 |
|
self.SCHEDULER_TIMESTEPS = 1000 |
|
        self.SCHEDULER_SCHEDULE = "scaled_linear"
|
self.LATENT_CHANNELS = 4 |
|
self.DOWNSAMPLING_FACTOR = 8 |
|
|
|
def replace_unet_modules( |
|
self, |
|
unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, |
|
mem_eff_attn, |
|
xformers, |
|
sdpa, |
|
): |
|
if mem_eff_attn: |
|
print("Enable memory efficient attention for U-Net") |
|
|
|
|
|
unet.set_use_memory_efficient_attention(False, True) |
|
elif xformers: |
|
print("Enable xformers for U-Net") |
|
try: |
|
import xformers.ops |
|
except ImportError: |
|
                raise ImportError("No xformers / xformers does not appear to be installed")
|
|
|
unet.set_use_memory_efficient_attention(True, False) |
|
elif sdpa: |
|
print("Enable SDPA for U-Net") |
|
unet.set_use_memory_efficient_attention(False, False) |
|
unet.set_use_sdpa(True) |
|
|
|
|
|
def replace_vae_modules( |
|
self, vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa |
|
): |
|
if mem_eff_attn: |
|
self.replace_vae_attn_to_memory_efficient() |
|
elif xformers: |
|
|
|
vae.set_use_memory_efficient_attention_xformers(True) |
|
elif sdpa: |
|
self.replace_vae_attn_to_sdpa() |
|
|
|
def replace_vae_attn_to_memory_efficient(self): |
|
        print("VAE: Attention.forward has been replaced with FlashAttention (not xformers)")
|
flash_func = FlashAttentionFunction |
|
|
|
def forward_flash_attn(self, hidden_states, **kwargs): |
|
q_bucket_size = 512 |
|
k_bucket_size = 1024 |
|
|
|
residual = hidden_states |
|
batch, channel, height, width = hidden_states.shape |
|
|
|
|
|
hidden_states = self.group_norm(hidden_states) |
|
|
|
hidden_states = hidden_states.view( |
|
batch, channel, height * width |
|
).transpose(1, 2) |
|
|
|
|
|
query_proj = self.to_q(hidden_states) |
|
key_proj = self.to_k(hidden_states) |
|
value_proj = self.to_v(hidden_states) |
|
|
|
query_proj, key_proj, value_proj = map( |
|
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), |
|
(query_proj, key_proj, value_proj), |
|
) |
|
|
|
out = flash_func.apply( |
|
query_proj, |
|
key_proj, |
|
value_proj, |
|
None, |
|
False, |
|
q_bucket_size, |
|
k_bucket_size, |
|
) |
|
|
|
out = rearrange(out, "b h n d -> b n (h d)") |
|
|
|
|
|
|
|
            # project the attention output (not the pre-attention hidden states)
            hidden_states = self.to_out[0](out)
|
|
|
hidden_states = self.to_out[1](hidden_states) |
|
|
|
hidden_states = hidden_states.transpose(-1, -2).reshape( |
|
batch, channel, height, width |
|
) |
|
|
|
|
|
hidden_states = (hidden_states + residual) / self.rescale_output_factor |
|
return hidden_states |
|
|
|
def forward_flash_attn_0_14(self, hidden_states, **kwargs): |
|
if not hasattr(self, "to_q"): |
|
self.to_q = self.query |
|
self.to_k = self.key |
|
self.to_v = self.value |
|
self.to_out = [self.proj_attn, torch.nn.Identity()] |
|
self.heads = self.num_heads |
|
return forward_flash_attn(self, hidden_states, **kwargs) |
|
|
|
if diffusers.__version__ < "0.15.0": |
|
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 |
|
else: |
|
diffusers.models.attention_processor.Attention.forward = forward_flash_attn |
|
|
|
def replace_vae_attn_to_xformers(self): |
|
print("VAE: Attention.forward has been replaced to xformers") |
|
import xformers.ops |
|
|
|
def forward_xformers(self, hidden_states, **kwargs): |
|
residual = hidden_states |
|
batch, channel, height, width = hidden_states.shape |
|
|
|
|
|
hidden_states = self.group_norm(hidden_states) |
|
|
|
hidden_states = hidden_states.view( |
|
batch, channel, height * width |
|
).transpose(1, 2) |
|
|
|
|
|
query_proj = self.to_q(hidden_states) |
|
key_proj = self.to_k(hidden_states) |
|
value_proj = self.to_v(hidden_states) |
|
|
|
query_proj, key_proj, value_proj = map( |
|
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), |
|
(query_proj, key_proj, value_proj), |
|
) |
|
|
|
query_proj = query_proj.contiguous() |
|
key_proj = key_proj.contiguous() |
|
value_proj = value_proj.contiguous() |
|
out = xformers.ops.memory_efficient_attention( |
|
query_proj, key_proj, value_proj, attn_bias=None |
|
) |
|
|
|
out = rearrange(out, "b h n d -> b n (h d)") |
|
|
|
|
|
|
|
            # project the attention output (not the pre-attention hidden states)
            hidden_states = self.to_out[0](out)
|
|
|
hidden_states = self.to_out[1](hidden_states) |
|
|
|
hidden_states = hidden_states.transpose(-1, -2).reshape( |
|
batch, channel, height, width |
|
) |
|
|
|
|
|
hidden_states = (hidden_states + residual) / self.rescale_output_factor |
|
return hidden_states |
|
|
|
def forward_xformers_0_14(self, hidden_states, **kwargs): |
|
if not hasattr(self, "to_q"): |
|
self.to_q = self.query |
|
self.to_k = self.key |
|
self.to_v = self.value |
|
self.to_out = [self.proj_attn, torch.nn.Identity()] |
|
self.heads = self.num_heads |
|
return forward_xformers(self, hidden_states, **kwargs) |
|
|
|
if diffusers.__version__ < "0.15.0": |
|
diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 |
|
else: |
|
diffusers.models.attention_processor.Attention.forward = forward_xformers |
|
|
|
    def replace_vae_attn_to_sdpa(self):
        print("VAE: Attention.forward has been replaced with SDPA")
|
|
|
def forward_sdpa(self, hidden_states, **kwargs): |
|
residual = hidden_states |
|
batch, channel, height, width = hidden_states.shape |
|
|
|
|
|
hidden_states = self.group_norm(hidden_states) |
|
|
|
hidden_states = hidden_states.view( |
|
batch, channel, height * width |
|
).transpose(1, 2) |
|
|
|
|
|
query_proj = self.to_q(hidden_states) |
|
key_proj = self.to_k(hidden_states) |
|
value_proj = self.to_v(hidden_states) |
|
|
|
query_proj, key_proj, value_proj = map( |
|
lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), |
|
(query_proj, key_proj, value_proj), |
|
) |
|
|
|
out = torch.nn.functional.scaled_dot_product_attention( |
|
query_proj, |
|
key_proj, |
|
value_proj, |
|
attn_mask=None, |
|
dropout_p=0.0, |
|
is_causal=False, |
|
) |
|
|
|
out = rearrange(out, "b n h d -> b n (h d)") |
|
|
|
|
|
|
|
            # project the attention output (not the pre-attention hidden states)
            hidden_states = self.to_out[0](out)
|
|
|
hidden_states = self.to_out[1](hidden_states) |
|
|
|
hidden_states = hidden_states.transpose(-1, -2).reshape( |
|
batch, channel, height, width |
|
) |
|
|
|
|
|
hidden_states = (hidden_states + residual) / self.rescale_output_factor |
|
return hidden_states |
|
|
|
def forward_sdpa_0_14(self, hidden_states, **kwargs): |
|
if not hasattr(self, "to_q"): |
|
self.to_q = self.query |
|
self.to_k = self.key |
|
self.to_v = self.value |
|
self.to_out = [self.proj_attn, torch.nn.Identity()] |
|
self.heads = self.num_heads |
|
return forward_sdpa(self, hidden_states, **kwargs) |
|
|
|
if diffusers.__version__ < "0.15.0": |
|
diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 |
|
else: |
|
diffusers.models.attention_processor.Attention.forward = forward_sdpa |
|
|
|
def load(self, pipeline: AbstractPipeline, controlnet_urls: Optional[List[str]]): |
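        # Fuse LCM-LoRA into the diffusers U-Net, convert its weights to the kohya SDXL U-Net,
        # set up an LCM scheduler, download and attach the ControlNet-LLLite weights, and
        # build the PipelineLike instance. The LoRA is unfused from the source pipeline afterwards.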
|
pipeline.pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") |
|
pipeline.pipe.fuse_lora() |
|
|
|
self.dtype = pipeline.pipe.dtype |
|
self.device = pipeline.pipe.device |
|
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl( |
|
pipeline.pipe.unet.state_dict() |
|
) |
|
with init_empty_weights(): |
|
            original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
sdxl_model_util._load_state_dict_on_device( |
|
original_unet, |
|
state_dict, |
|
device=pipeline.pipe.device, |
|
dtype=pipeline.pipe.dtype, |
|
) |
|
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel( |
|
original_unet |
|
) |
|
sched_init_args = {} |
|
has_steps_offset = True |
|
has_clip_sample = True |
|
scheduler_num_noises_per_step = 1 |
|
|
|
        mem_eff = False  # xformers attention is enabled below, so memory-efficient (flash) attention stays off
|
self.replace_unet_modules(unet, mem_eff, True, False) |
|
self.replace_vae_modules(pipeline.pipe.vae, mem_eff, True, False) |
|
|
|
scheduler_cls = LCMScheduler |
|
scheduler_module = diffusers.schedulers.scheduling_ddim |
|
|
|
if has_steps_offset: |
|
sched_init_args["steps_offset"] = 1 |
|
if has_clip_sample: |
|
sched_init_args["clip_sample"] = False |
|
|
|
class NoiseManager: |
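            """Serves pre-generated sampler noises in order (to make img2img noise
            reproducible); falls back to torch.randn when no matching noise is available."""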
|
def __init__(self): |
|
self.sampler_noises = None |
|
self.sampler_noise_index = 0 |
|
|
|
def reset_sampler_noises(self, noises): |
|
self.sampler_noise_index = 0 |
|
self.sampler_noises = noises |
|
|
|
def randn( |
|
self, shape, device=None, dtype=None, layout=None, generator=None |
|
): |
|
|
|
if self.sampler_noises is not None and self.sampler_noise_index < len( |
|
self.sampler_noises |
|
): |
|
noise = self.sampler_noises[self.sampler_noise_index] |
|
if shape != noise.shape: |
|
noise = None |
|
else: |
|
noise = None |
|
|
|
                if noise is None:
|
print( |
|
f"unexpected noise request: {self.sampler_noise_index}, {shape}" |
|
) |
|
noise = torch.randn( |
|
shape, dtype=dtype, device=device, generator=generator |
|
) |
|
|
|
self.sampler_noise_index += 1 |
|
return noise |
|
|
|
class TorchRandReplacer: |
|
def __init__(self, noise_manager): |
|
self.noise_manager = noise_manager |
|
|
|
def __getattr__(self, item): |
|
if item == "randn": |
|
return self.noise_manager.randn |
|
if hasattr(torch, item): |
|
return getattr(torch, item) |
|
raise AttributeError( |
|
"'{}' object has no attribute '{}'".format( |
|
type(self).__name__, item |
|
) |
|
) |
|
|
|
noise_manager = NoiseManager() |
|
if scheduler_module is not None: |
|
scheduler_module.torch = TorchRandReplacer(noise_manager) |
|
|
|
scheduler = scheduler_cls( |
|
num_train_timesteps=self.SCHEDULER_TIMESTEPS, |
|
beta_start=self.SCHEDULER_LINEAR_START, |
|
beta_end=self.SCHEDULER_LINEAR_END, |
|
            beta_schedule=self.SCHEDULER_SCHEDULE,
|
**sched_init_args, |
|
) |
|
device = torch.device( |
|
pipeline.pipe.device if torch.cuda.is_available() else "cpu" |
|
) |
|
|
|
|
|
|
|
|
|
print(pipeline.pipe.dtype) |
|
unet.to(pipeline.pipe.dtype).to(pipeline.pipe.device) |
|
|
|
|
|
unet.eval() |
|
control_nets: List[Tuple[ControlNetLLLite, float]] = [] |
|
for link in controlnet_urls: |
|
net_path = Path.home() / ".cache" / link.split("/")[-1] |
|
download_file(link, net_path) |
|
print(f"loading controlnet {net_path}") |
|
state_dict = load_file(net_path) |
|
mlp_dim = None |
|
cond_emb_dim = None |
|
for key, value in state_dict.items(): |
|
if mlp_dim is None and "down.0.weight" in key: |
|
mlp_dim = value.shape[0] |
|
elif cond_emb_dim is None and "conditioning1.0" in key: |
|
cond_emb_dim = value.shape[0] * 2 |
|
if mlp_dim is not None and cond_emb_dim is not None: |
|
break |
|
assert ( |
|
mlp_dim is not None and cond_emb_dim is not None |
|
), f"invalid control net: {link}" |
|
|
|
multiplier = 0.2 |
|
|
|
ratio = 1.0 |
|
|
|
control_net = ControlNetLLLite( |
|
unet, cond_emb_dim, mlp_dim, multiplier=multiplier |
|
) |
|
control_net.apply_to() |
|
control_net.load_state_dict(state_dict) |
|
control_net.to(pipeline.pipe.dtype).to(device) |
|
control_net.set_batch_cond_only(False, False) |
|
control_nets.append((control_net, ratio)) |
|
|
|
networks = [] |
|
self.pipe = PipelineLike( |
|
device, |
|
pipeline.pipe.vae, |
|
[pipeline.pipe.text_encoder, pipeline.pipe.text_encoder_2], |
|
[pipeline.pipe.tokenizer, pipeline.pipe.tokenizer_2], |
|
unet, |
|
scheduler, |
|
2, |
|
) |
|
self.pipe.set_control_nets(control_nets) |
|
|
|
clear_cuda_and_gc() |
|
|
|
pipeline.pipe.unload_lora_weights() |
|
pipeline.pipe.unfuse_lora() |
|
|
|
clear_cuda_and_gc() |
|
|
|
def __call__( |
|
self, |
|
prompt: str, |
|
negative_prompt: str, |
|
seed: int, |
|
image: Image.Image, |
|
condition_image: Union[Image.Image, List[Image.Image]], |
|
height: int = 1024, |
|
width: int = 1024, |
|
num_inference_steps: int = 24, |
|
guidance_scale=1.0, |
|
): |
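        # Generate one img2img sample: pre-generate the img2img noise, then run the
        # PipelineLike call with the condition image(s) driving ControlNet-LLLite.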
|
noise_shape = ( |
|
self.LATENT_CHANNELS, |
|
height // self.DOWNSAMPLING_FACTOR, |
|
width // self.DOWNSAMPLING_FACTOR, |
|
) |
|
i2i_noises = torch.zeros( |
|
(1, *noise_shape), device=self.device, dtype=self.dtype |
|
) |
|
i2i_noises[0] = torch.randn(noise_shape, device=self.device, dtype=self.dtype) |
|
images = self.pipe( |
|
prompt=prompt, |
|
negative_prompt=negative_prompt, |
|
seed=seed, |
|
init_image=image, |
|
height=height, |
|
width=width, |
|
strength=1.0, |
|
num_inference_steps=num_inference_steps, |
|
guidance_scale=guidance_scale, |
|
clip_guide_images=condition_image, |
|
img2img_noise=i2i_noises, |
|
) |
|
return images |
|
|