Spaces:
Sleeping
Sleeping
File size: 6,875 Bytes
40e68f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# Adapted from https://github.com/MichalGeyer/pnp-diffusers/blob/main/pnp_utils.py
import torch
import os
import random
import numpy as np
def seed_everything(seed):
    """Seed every random number generator this module relies on.

    Seeds, in order: torch's global generator, torch's CUDA generator for the
    current device, Python's ``random`` module, and NumPy's global RNG.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
def register_time(model, t):
    """Store the current timestep ``t`` as attribute ``t`` on the hooked modules.

    The patched forwards installed by ``register_attention_control_efficient``
    and ``register_conv_control_efficient`` read ``self.t`` and compare it
    against their ``injection_schedule``.

    The attribute is written (in this order) on:
      * ``up_blocks[1..3].resnets[0..2]``;
      * ``up_blocks[1..3].attentions[0..2]`` first transformer block ``attn1``;
      * ``down_blocks[0..2].attentions[0..1]`` first transformer block ``attn1``;
      * ``mid_block.attentions[0]`` first transformer block ``attn1``.

    Args:
        model: Object exposing a ``unet`` attribute with the block layout above.
        t: Timestep value to record.
    """

    def first_self_attention(block, index):
        return block.attentions[index].transformer_blocks[0].attn1

    def targets():
        # Lazily yielded so each setattr happens before the next lookup,
        # exactly as in a sequence of nested loops.
        for level in (1, 2, 3):
            for index in (0, 1, 2):
                yield model.unet.up_blocks[level].resnets[index]
        for level in (1, 2, 3):
            for index in (0, 1, 2):
                yield first_self_attention(model.unet.up_blocks[level], index)
        for level in (0, 1, 2):
            for index in (0, 1):
                yield first_self_attention(model.unet.down_blocks[level], index)
        yield first_self_attention(model.unet.mid_block, 0)

    for target in targets():
        setattr(target, 't', t)
def load_source_latents_t(t, latents_path):
    """Load the saved noisy source latents for timestep ``t``.

    Args:
        t: Timestep used to build the file name ``noisy_latents_{t}.pt``.
        latents_path: Directory containing the saved latents.

    Returns:
        Whatever ``torch.load`` deserializes from that file.

    Raises:
        FileNotFoundError: If the expected file does not exist.
    """
    latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which previously made the failure mode depend on interpreter flags.
    if not os.path.exists(latents_t_path):
        raise FileNotFoundError(f'Missing latents at t {t} path {latents_t_path}')
    return torch.load(latents_t_path)
def register_attention_control_efficient(model, injection_schedule):
    """Patch decoder self-attention (``attn1``) for query/key injection.

    For ``up_blocks[1].attentions[1..2]`` and ``up_blocks[2..3].attentions[0..2]``
    (first transformer block), ``forward`` is replaced and the module gets an
    ``injection_schedule`` attribute.

    Injection rule:
      * Applies only to self-attention calls, i.e. when no
        ``encoder_hidden_states`` is passed.
      * Triggers when ``injection_schedule`` is not None and the module's ``t``
        attribute (set by ``register_time``) is in the schedule or equals 1000.
      * The batch is treated as three equal thirds. The queries and keys of the
        first third overwrite those of the second ("unconditional") and third
        ("conditional") thirds. Values are never injected.

    Args:
        model: Object exposing ``unet.up_blocks[...]`` with the layout above.
        injection_schedule: Container of timesteps supporting ``in``, or None to
            disable injection.
    """

    def sa_forward(self):
        # When to_out is a ModuleList, only its first entry is applied.
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.ModuleList) else self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            # Unpacking (rather than x.shape[0]) keeps the rank-3 requirement on x.
            batch_size, _, _ = x.shape
            h = self.heads
            is_cross = encoder_hidden_states is not None
            context = encoder_hidden_states if is_cross else x

            q = self.to_q(x)
            k = self.to_k(context)
            if not is_cross and self.injection_schedule is not None and (
                    self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = q.shape[0] // 3
                # inject unconditional
                q[source_batch_size:2 * source_batch_size] = q[:source_batch_size]
                k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
                # inject conditional
                q[2 * source_batch_size:] = q[:source_batch_size]
                k[2 * source_batch_size:] = k[:source_batch_size]
            q = self.head_to_batch_dim(q)
            k = self.head_to_batch_dim(k)
            v = self.head_to_batch_dim(self.to_v(context))

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                # NOTE: diffusers' Attention.head_to_batch_dim folds heads
                # sample-major (row = sample * heads + head), so each sample's
                # mask must be repeated consecutively, not tiled as .repeat() would.
                attention_mask = attention_mask[:, None, :].repeat_interleave(h, dim=0)
                sim.masked_fill_(~attention_mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)
            return to_out(out)

        return forward

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}  # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
def register_conv_control_efficient(model, injection_schedule):
    """Patch decoder ResNet blocks so their residual-branch features can be injected.

    For ``up_blocks[1].resnets[1]`` and ``up_blocks[1].resnets[2]``, ``forward``
    is replaced and an ``injection_schedule`` attribute is attached.

    The replacement forward runs the block's own norm / nonlinearity /
    up-or-downsample / conv / time-embedding / dropout / shortcut submodules.
    NOTE(review): it appears to mirror diffusers' ``ResnetBlock2D.forward``;
    keep it in sync with the installed diffusers version.

    Injection rule: when ``injection_schedule`` is not None and the module's
    ``t`` (set by ``register_time``) is in the schedule or equals 1000, the
    output of ``conv2`` is overwritten in place. Its second and third batch
    thirds are replaced by the first third, before the shortcut is added.

    Args:
        model: Object exposing ``unet.up_blocks[1].resnets[...]``.
        injection_schedule: Container of timesteps supporting ``in``, or None to
            disable injection.
    """

    def _should_inject(block):
        # Same predicate as the attention patch: scheduled step, or t == 1000.
        return block.injection_schedule is not None and (
            block.t in block.injection_schedule or block.t == 1000)

    def _broadcast_source(tensor):
        # In place: the first batch third overwrites the second ("unconditional")
        # and then the third ("conditional") thirds.
        third = int(tensor.shape[0] // 3)
        tensor[third:2 * third] = tensor[:third]
        tensor[2 * third:] = tensor[:third]

    def conv_forward(self):
        def forward(input_tensor, temb):
            residual = input_tensor
            features = self.nonlinearity(self.norm1(input_tensor))

            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if features.shape[0] >= 64:
                    residual = residual.contiguous()
                    features = features.contiguous()
                residual = self.upsample(residual)
                features = self.upsample(features)
            elif self.downsample is not None:
                residual = self.downsample(residual)
                features = self.downsample(features)

            features = self.conv1(features)

            if temb is not None:
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
                if self.time_embedding_norm == "default":
                    features = features + temb

            features = self.norm2(features)

            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                features = features * (1 + scale) + shift

            features = self.conv2(self.dropout(self.nonlinearity(features)))

            if _should_inject(self):
                _broadcast_source(features)

            if self.conv_shortcut is not None:
                residual = self.conv_shortcut(residual)

            return (residual + features) / self.output_scale_factor

        return forward

    for level, indices in {1: [1, 2]}.items():
        for index in indices:
            resnet = model.unet.up_blocks[level].resnets[index]
            resnet.forward = conv_forward(resnet)
            setattr(resnet, 'injection_schedule', injection_schedule)
|