afeng committed on
Commit
f4f90db
·
1 Parent(s): 56db6e4
Files changed (2) hide show
  1. app.py +31 -13
  2. pipeline_dedit_sdxl.py +875 -0
app.py CHANGED
@@ -296,19 +296,37 @@ with gr.Blocks() as demo:
296
  gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
297
 
298
  add_button = gr.Button("Run optimization")
299
-
300
- run_optimization = partial(
301
- run_main,
302
- num_tokens=int(num_tokens.value),
303
- embedding_learning_rate = float(embedding_learning_rate.value),
304
- max_emb_train_steps = int(max_emb_train_steps.value),
305
- diffusion_model_learning_rate= float(diffusion_model_learning_rate.value),
306
- max_diffusion_train_steps = int(max_diffusion_train_steps.value),
307
- train_batch_size=int(train_batch_size.value),
308
- gradient_accumulation_steps=int(gradient_accumulation_steps.value)
309
- )
310
- add_button.click(run_optimization,
311
- inputs = [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  outputs = []
313
  )
314
 
 
296
  gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
297
 
298
  add_button = gr.Button("Run optimization")
299
def run_optimization_wrapper(
    num_tokens,
    embedding_learning_rate,
    max_emb_train_steps,
    diffusion_model_learning_rate,
    max_diffusion_train_steps,
    train_batch_size,
    gradient_accumulation_steps,
):
    """Click handler: coerce the current widget values and start ``run_main``.

    The arguments arrive in the order of the ``inputs`` list given to
    ``add_button.click``. Counts are converted with ``int`` and learning
    rates with ``float`` before being forwarded as keyword arguments.
    Nothing is returned.
    """
    run_main(
        num_tokens=int(num_tokens),
        embedding_learning_rate=float(embedding_learning_rate),
        max_emb_train_steps=int(max_emb_train_steps),
        diffusion_model_learning_rate=float(diffusion_model_learning_rate),
        max_diffusion_train_steps=int(max_diffusion_train_steps),
        train_batch_size=int(train_batch_size),
        gradient_accumulation_steps=int(gradient_accumulation_steps),
    )
319
+
320
+ add_button.click(run_optimization_wrapper,
321
+ inputs = [
322
+ num_tokens,
323
+ embedding_learning_rate ,
324
+ max_emb_train_steps ,
325
+ diffusion_model_learning_rate ,
326
+ max_diffusion_train_steps,
327
+ train_batch_size,
328
+ gradient_accumulation_steps
329
+ ],
330
  outputs = []
331
  )
332
 
pipeline_dedit_sdxl.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from utils import import_model_class_from_model_name_or_path
3
+ from transformers import AutoTokenizer
4
+ from diffusers import (
5
+ AutoencoderKL,
6
+ DDPMScheduler,
7
+ StableDiffusionXLPipeline,
8
+ UNet2DConditionModel,
9
+ )
10
+ from accelerate import Accelerator
11
+ from tqdm.auto import tqdm
12
+ from utils import sdxl_prepare_input_decom, save_images
13
+ import torch.nn.functional as F
14
+ import itertools
15
+ from peft import LoraConfig
16
+ from controller import GroupedCAController, register_attention_disentangled_control, DummyController
17
+ from utils import image2latent, latent2image
18
+ import matplotlib.pyplot as plt
19
+ from utils_mask import check_mask_overlap_torch
20
+
21
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Value passed as ``length`` to sdxl_prepare_input_decom by the training methods.
max_length = 40
23
+ class DEditSDXLPipeline:
24
def __init__(
    self,
    mask_list,
    mask_label_list,
    mask_list_2 = None,
    mask_label_list_2 = None,
    resolution = 1024,
    num_tokens = 1
):
    """Load the SDXL components and register one placeholder token per mask.

    Args:
        mask_list: masks of the first image; stored and later handed to
            ``GroupedCAController``.
        mask_label_list: one label string per mask. The last
            space-separated word of each label names its placeholder token.
        mask_list_2: optional masks of a second image.
        mask_label_list_2: labels for ``mask_list_2``; required whenever
            ``mask_list_2`` is given.
        resolution: stored on the instance as ``self.resolution``.
        num_tokens: number of embedding vectors created per placeholder.

    Raises:
        ValueError: if ``mask_list_2`` is given without
            ``mask_label_list_2``, or if a placeholder token already exists
            in the tokenizer.
    """
    super().__init__()
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    self.model_id = model_id
    self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
    self.tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", use_fast=False)
    text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder")
    text_encoder_cls_two = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder_2")
    self.text_encoder = text_encoder_cls_one.from_pretrained(model_id, subfolder="text_encoder").to(device)
    self.text_encoder_2 = text_encoder_cls_two.from_pretrained(model_id, subfolder="text_encoder_2").to(device)
    self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    self.unet.ca_dim = 2048
    self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
    self.scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

    self.mixed_precision = "fp16"
    self.resolution = resolution
    self.num_tokens = num_tokens

    self.mask_list = mask_list
    self.mask_label_list = mask_label_list
    notation_token_list = [phrase.split(" ")[-1] for phrase in mask_label_list]
    # First-image placeholders are prefixed with '#', suffixed with the mask index.
    placeholder_token_list = ["#" + word + "{}".format(widx) for widx, word in enumerate(notation_token_list)]
    self.set_string_list, placeholder_token_ids = self.add_tokens(placeholder_token_list)
    self.min_added_id = min(placeholder_token_ids)
    self.max_added_id = max(placeholder_token_ids)

    # Always define the second-image attributes so that later access yields
    # None instead of an AttributeError when only one image is used.
    self.mask_list_2 = mask_list_2
    self.mask_label_list_2 = mask_label_list_2
    self.set_string_list_2 = None

    if mask_list_2 is not None:
        if mask_label_list_2 is None:
            raise ValueError("`mask_label_list_2` must be provided together with `mask_list_2`.")
        notation_token_list_2 = [phrase.split(" ")[-1] for phrase in mask_label_list_2]
        # Second-image placeholders use '$' so they never collide with the first set.
        placeholder_token_list_2 = ["$" + word + "{}".format(widx) for widx, word in enumerate(notation_token_list_2)]
        self.set_string_list_2, placeholder_token_ids_2 = self.add_tokens(placeholder_token_list_2)
        # Second-image tokens are added after the first set, so only the
        # upper bound of the trainable id range moves.
        self.max_added_id = max(placeholder_token_ids_2)
67
+
68
+ def add_tokens_text_encoder_random_init(self, placeholder_token, num_tokens=1):
69
+ # Add the placeholder token in tokenizer
70
+ placeholder_tokens = [placeholder_token]
71
+ # add dummy tokens for multi-vector
72
+ additional_tokens = []
73
+ for i in range(1, num_tokens):
74
+ additional_tokens.append(f"{placeholder_token}_{i}")
75
+ placeholder_tokens += additional_tokens
76
+ num_added_tokens = self.tokenizer.add_tokens(placeholder_tokens) # 49408
77
+ num_added_tokens = self.tokenizer_2.add_tokens(placeholder_tokens) # 49408
78
+
79
+ if num_added_tokens != num_tokens:
80
+ raise ValueError(
81
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
82
+ " `placeholder_token` that is not already in the tokenizer."
83
+ )
84
+ placeholder_token_ids = self.tokenizer.convert_tokens_to_ids(placeholder_tokens)
85
+ placeholder_token_ids_2 = self.tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
86
+ assert placeholder_token_ids == placeholder_token_ids_2, "Two text encoders are expected to have same vocabs"
87
+
88
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
89
+ token_embeds = self.text_encoder.get_input_embeddings().weight.data
90
+ std, mean = torch.std_mean(token_embeds)
91
+ with torch.no_grad():
92
+ for token_id in placeholder_token_ids:
93
+ token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
94
+
95
+ self.text_encoder_2.resize_token_embeddings(len(self.tokenizer))
96
+ token_embeds = self.text_encoder_2.get_input_embeddings().weight.data
97
+ std, mean = torch.std_mean(token_embeds)
98
+ with torch.no_grad():
99
+ for token_id in placeholder_token_ids:
100
+ token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
101
+
102
+ set_string = " ".join(self.tokenizer.convert_ids_to_tokens(placeholder_token_ids))
103
+
104
+ return set_string, placeholder_token_ids
105
+
106
+ def add_tokens(self, placeholder_token_list):
107
+ set_string_list = []
108
+ placeholder_token_ids_list = []
109
+ for str_idx in range(len(placeholder_token_list)):
110
+ placeholder_token = placeholder_token_list[str_idx]
111
+ set_string, placeholder_token_ids = self.add_tokens_text_encoder_random_init(placeholder_token, num_tokens=self.num_tokens)
112
+ set_string_list.append(set_string)
113
+ placeholder_token_ids_list.append(placeholder_token_ids)
114
+ placeholder_token_ids = list(itertools.chain(*placeholder_token_ids_list))
115
+ return set_string_list, placeholder_token_ids
116
+
117
def train_emb(
    self,
    image_gt,
    set_string_list,
    gradient_accumulation_steps = 5,
    embedding_learning_rate = 1e-4,
    max_emb_train_steps = 100,
    train_batch_size = 1,
    train_full_lora = False
):
    """Optimise the placeholder-token embeddings against a single image.

    The UNet and VAE stay frozen. Only the input-embedding tables of the
    two text encoders are given to the optimiser, and after every step all
    rows outside ``[self.min_added_id, self.max_added_id]`` are restored to
    their original values, so only the added tokens change.

    Args:
        image_gt: image handed to ``image2latent`` to obtain target latents.
        set_string_list: placeholder strings handed to
            ``sdxl_prepare_input_decom``.
        gradient_accumulation_steps: accumulation factor given to
            ``Accelerator``.
        embedding_learning_rate: AdamW learning rate.
        max_emb_train_steps: number of forward/backward passes.
        train_batch_size: number of times the latent is repeated per pass.
        train_full_lora: accepted but not used by this method.

    Returns:
        None. ``self.text_encoder`` and ``self.text_encoder_2`` are replaced
        by their unwrapped versions cast to the working dtype.
    """
    decom_controller = GroupedCAController(mask_list = self.mask_list)
    register_attention_disentangled_control(self.unet, decom_controller)

    accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
    self.vae.requires_grad_(False)
    self.unet.requires_grad_(False)

    # Enable gradients on the text encoders, then freeze everything except
    # what remains of the embeddings module.
    for text_encoder in (self.text_encoder, self.text_encoder_2):
        text_encoder.requires_grad_(True)
        text_encoder.text_model.encoder.requires_grad_(False)
        text_encoder.text_model.final_layer_norm.requires_grad_(False)
        text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    self.unet.to(device, dtype=weight_dtype)
    self.vae.to(device, dtype=weight_dtype)

    trainable_embmat_list_1 = list(self.text_encoder.get_input_embeddings().parameters())
    trainable_embmat_list_2 = list(self.text_encoder_2.get_input_embeddings().parameters())

    optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)

    self.text_encoder, self.text_encoder_2, optimizer = accelerator.prepare(self.text_encoder, self.text_encoder_2, optimizer)

    # Snapshots used to undo updates to every pre-existing token.
    orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()
    orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()

    # Loop-invariant masks: True for rows that must stay frozen.
    frozen_rows_1 = torch.ones((len(self.tokenizer),), dtype=torch.bool)
    frozen_rows_1[self.min_added_id : self.max_added_id + 1] = False
    frozen_rows_2 = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
    frozen_rows_2[self.min_added_id : self.max_added_id + 1] = False

    self.text_encoder.train()
    self.text_encoder_2.train()

    # Number of real optimiser updates; used for the progress bar total.
    effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps

    if accelerator.is_main_process:
        accelerator.init_trackers("DEdit EmbSteps", config={
            "embedding_learning_rate": embedding_learning_rate,
            "text_embedding_optimization_steps": effective_emb_train_steps,
        })
    global_step = 0
    noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
    progress_bar = tqdm(range(0, effective_emb_train_steps), initial = global_step, desc="EmbSteps")
    latents0 = image2latent(image_gt, vae = self.vae, dtype=weight_dtype)
    latents0 = latents0.repeat(train_batch_size, 1, 1, 1)

    for _ in range(max_emb_train_steps):
        with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
            latents = latents0.clone().detach()
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            # Re-encode every step because the embeddings are being trained.
            encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
                set_string_list,
                self.tokenizer,
                self.tokenizer_2,
                self.text_encoder,
                self.text_encoder_2,
                length = max_length,
                bsz = train_batch_size,
                weight_dtype = weight_dtype
            )

            model_pred = self.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states = encoder_hidden_states_list,
                cross_attention_kwargs = None,
                added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids},
                return_dict=False
            )[0]
            loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            # Restore every row that is not one of the added placeholder tokens.
            with torch.no_grad():
                accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
                    frozen_rows_1] = orig_embeds_params_1[frozen_rows_1]
                accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
                    frozen_rows_2] = orig_embeds_params_2[frozen_rows_2]

            logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
        # The former `global_step >= max_emb_train_steps` early exit was
        # dropped: global_step grows at most once per iteration, so it could
        # only trigger on the final iteration.
    accelerator.wait_for_everyone()
    accelerator.end_training()
    self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype = weight_dtype)
    self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype = weight_dtype)
237
+
238
def train_model(
    self,
    image_gt,
    set_string_list,
    gradient_accumulation_steps = 5,
    max_diffusion_train_steps = 100,
    diffusion_model_learning_rate = 1e-5,
    train_batch_size = 1,
    train_full_lora = False,
    lora_rank = 4,
    lora_alpha = 4
):
    """Fine-tune the UNet on a single image with the text encoders frozen.

    A fresh UNet is loaded from ``self.model_id``. By default only the
    ``to_k`` / ``to_v`` / ``to_q`` projections of attention modules whose
    ``to_k`` has 2048 input features are trained. With
    ``train_full_lora=True`` a LoRA adapter is attached to the attention
    projections instead and its parameters are trained.

    Args:
        image_gt: image handed to ``image2latent`` to obtain target latents.
        set_string_list: placeholder strings handed to
            ``sdxl_prepare_input_decom``.
        gradient_accumulation_steps: accumulation factor given to
            ``Accelerator``.
        max_diffusion_train_steps: number of forward/backward passes.
        diffusion_model_learning_rate: AdamW learning rate.
        train_batch_size: number of times the latent is repeated per pass.
        train_full_lora: train a LoRA adapter instead of the raw projections.
        lora_rank: LoRA rank ``r``.
        lora_alpha: LoRA ``lora_alpha``.

    Returns:
        None. ``self.unet`` is replaced by the trained, unwrapped UNet cast
        to the working dtype.
    """
    self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
    self.unet.ca_dim = 2048
    decom_controller = GroupedCAController(mask_list = self.mask_list)
    register_attention_disentangled_control(self.unet, decom_controller)

    # Same precision setting as the embedding-training methods.
    accelerator = Accelerator(gradient_accumulation_steps = gradient_accumulation_steps, mixed_precision = self.mixed_precision)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    self.vae.requires_grad_(False)
    self.vae.to(device, dtype=weight_dtype)

    self.unet.requires_grad_(False)
    self.unet.train()

    self.text_encoder.requires_grad_(False)
    self.text_encoder_2.requires_grad_(False)

    if not train_full_lora:
        trainable_params_list = []
        for _, module in self.unet.named_modules():
            if type(module).__name__ != "Attention":
                continue
            if module.to_k.in_features != 2048:  # not cross attention
                continue
            # Unfreeze key, value and query projections, in that order.
            for projection in (module.to_k, module.to_v, module.to_q):
                projection.weight.requires_grad = True
                trainable_params_list.append(projection.weight)
                if projection.bias is not None:
                    projection.bias.requires_grad = True
                    trainable_params_list.append(projection.bias)
    else:
        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        self.unet.add_adapter(unet_lora_config)
        print("training full parameters using lora!")
        trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))

    self.text_encoder.to(device, dtype=weight_dtype)
    self.text_encoder_2.to(device, dtype=weight_dtype)
    optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
    self.unet, optimizer = accelerator.prepare(self.unet, optimizer)

    # Number of real optimiser updates; used for the progress bar total.
    effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config={
            "diffusion_model_learning_rate": diffusion_model_learning_rate,
            "diffusion_model_optimization_steps": effective_diffusion_train_steps,
        })

    global_step = 0
    progress_bar = tqdm(range(0, effective_diffusion_train_steps), initial=global_step, desc="ModelSteps")

    noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")

    latents0 = image2latent(image_gt, vae = self.vae, dtype=weight_dtype)
    latents0 = latents0.repeat(train_batch_size, 1, 1, 1)

    # The text encoders are frozen here, so the conditioning is computed once.
    with torch.no_grad():
        encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
            set_string_list,
            self.tokenizer,
            self.tokenizer_2,
            self.text_encoder,
            self.text_encoder_2,
            length = max_length,
            bsz = train_batch_size,
            weight_dtype = weight_dtype
        )

    for _ in range(max_diffusion_train_steps):
        with accelerator.accumulate(self.unet):
            latents = latents0.clone().detach()
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            model_pred = self.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states_list,
                cross_attention_kwargs=None, return_dict=False,
                added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            )[0]
            loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
        # The former `global_step >= max_diffusion_train_steps` early exit
        # was dropped: it could only trigger on the final iteration.
    accelerator.wait_for_everyone()
    accelerator.end_training()
    self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
369
+
370
def train_emb_2imgs(
    self,
    image_gt_1,
    image_gt_2,
    set_string_list_1,
    set_string_list_2,
    gradient_accumulation_steps = 5,
    embedding_learning_rate = 1e-4,
    max_emb_train_steps = 100,
    train_batch_size = 1,
    train_full_lora = False
):
    """Optimise the placeholder-token embeddings jointly on two images.

    Each pass runs the frozen UNet once per image, registering that image's
    attention controller first, and averages the two MSE losses. As in the
    single-image case, all embedding rows outside
    ``[self.min_added_id, self.max_added_id]`` are restored after every step.

    Args:
        image_gt_1: first image, handed to ``image2latent``.
        image_gt_2: second image, handed to ``image2latent``.
        set_string_list_1: placeholder strings for the first image.
        set_string_list_2: placeholder strings for the second image.
        gradient_accumulation_steps: accumulation factor given to
            ``Accelerator``.
        embedding_learning_rate: AdamW learning rate.
        max_emb_train_steps: number of forward/backward passes.
        train_batch_size: number of times each latent is repeated per pass.
        train_full_lora: accepted but not used by this method.

    Returns:
        None. ``self.text_encoder`` and ``self.text_encoder_2`` are replaced
        by their unwrapped versions cast to the working dtype.

    Raises:
        ValueError: if the pipeline has no second-image masks.
    """
    mask_list_2 = getattr(self, "mask_list_2", None)
    if mask_list_2 is None:
        raise ValueError("train_emb_2imgs requires the pipeline to be constructed with `mask_list_2`.")

    decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
    decom_controller_2 = GroupedCAController(mask_list = mask_list_2)
    accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
    self.vae.requires_grad_(False)
    self.unet.requires_grad_(False)

    # Enable gradients on the text encoders, then freeze everything except
    # what remains of the embeddings module.
    for text_encoder in (self.text_encoder, self.text_encoder_2):
        text_encoder.requires_grad_(True)
        text_encoder.text_model.encoder.requires_grad_(False)
        text_encoder.text_model.final_layer_norm.requires_grad_(False)
        text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    self.unet.to(device, dtype=weight_dtype)
    self.vae.to(device, dtype=weight_dtype)

    trainable_embmat_list_1 = list(self.text_encoder.get_input_embeddings().parameters())
    trainable_embmat_list_2 = list(self.text_encoder_2.get_input_embeddings().parameters())

    optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)
    self.text_encoder, self.text_encoder_2, optimizer = accelerator.prepare(self.text_encoder, self.text_encoder_2, optimizer)

    # Snapshots used to undo updates to every pre-existing token.
    orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight.data.clone()
    orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()

    # Loop-invariant masks: True for rows that must stay frozen.
    frozen_rows_1 = torch.ones((len(self.tokenizer),), dtype=torch.bool)
    frozen_rows_1[self.min_added_id : self.max_added_id + 1] = False
    frozen_rows_2 = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
    frozen_rows_2[self.min_added_id : self.max_added_id + 1] = False

    self.text_encoder.train()
    self.text_encoder_2.train()

    # Number of real optimiser updates; used for the progress bar total.
    effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps

    if accelerator.is_main_process:
        accelerator.init_trackers("EmbFt", config={
            "embedding_learning_rate": embedding_learning_rate,
            "text_embedding_optimization_steps": effective_emb_train_steps,
        })

    global_step = 0

    noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
    progress_bar = tqdm(range(0, effective_emb_train_steps), initial=global_step, desc="EmbSteps")
    latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
    latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)

    latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
    latents0_2 = latents0_2.repeat(train_batch_size, 1, 1, 1)

    for _ in range(max_emb_train_steps):
        with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
            latents_1 = latents0_1.clone().detach()
            noise_1 = torch.randn_like(latents_1)

            latents_2 = latents0_2.clone().detach()
            noise_2 = torch.randn_like(latents_2)

            bsz = latents_1.shape[0]

            timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
            timesteps_1 = timesteps_1.long()
            noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)

            timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
            timesteps_2 = timesteps_2.long()
            noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)

            # First image: register its controller before the forward pass.
            register_attention_disentangled_control(self.unet, decom_controller_1)
            encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
                set_string_list_1,
                self.tokenizer,
                self.tokenizer_2,
                self.text_encoder,
                self.text_encoder_2,
                length = max_length,
                bsz = train_batch_size,
                weight_dtype = weight_dtype
            )

            model_pred_1 = self.unet(
                noisy_latents_1,
                timesteps_1,
                encoder_hidden_states=encoder_hidden_states_list_1,
                cross_attention_kwargs=None,
                added_cond_kwargs={"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1},
                return_dict=False
            )[0]

            # Second image: swap in its controller.
            register_attention_disentangled_control(self.unet, decom_controller_2)
            encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
                set_string_list_2,
                self.tokenizer,
                self.tokenizer_2,
                self.text_encoder,
                self.text_encoder_2,
                length = max_length,
                bsz = train_batch_size,
                weight_dtype = weight_dtype
            )

            model_pred_2 = self.unet(
                noisy_latents_2,
                timesteps_2,
                encoder_hidden_states = encoder_hidden_states_list_2,
                cross_attention_kwargs=None,
                added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2},
                return_dict=False
            )[0]

            # Average of the two per-image losses.
            loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean") / 2
            loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean") / 2
            loss = loss_1 + loss_2
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            # Restore every row that is not one of the added placeholder tokens.
            with torch.no_grad():
                accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
                    frozen_rows_1] = orig_embeds_params_1[frozen_rows_1]
                accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
                    frozen_rows_2] = orig_embeds_params_2[frozen_rows_2]

            logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
        # The former `global_step >= max_emb_train_steps` early exit was
        # dropped: it could only trigger on the final iteration.
    accelerator.wait_for_everyone()
    accelerator.end_training()
    self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype = weight_dtype)
    self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype = weight_dtype)
530
+
531
+ def train_model_2imgs(
532
+ self,
533
+ image_gt_1,
534
+ image_gt_2,
535
+ set_string_list_1,
536
+ set_string_list_2,
537
+ gradient_accumulation_steps = 5,
538
+ max_diffusion_train_steps = 100,
539
+ diffusion_model_learning_rate = 1e-5,
540
+ train_batch_size = 1,
541
+ train_full_lora = False,
542
+ lora_rank = 4,
543
+ lora_alpha = 4
544
+ ):
545
+ self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
546
+ self.unet.ca_dim = 2048
547
+ decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
548
+ decom_controller_2 = GroupedCAController(mask_list = self.mask_list_2)
549
+
550
+ mixed_precision = "fp16"
551
+ accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,mixed_precision=mixed_precision)
552
+
553
+ weight_dtype = torch.float32
554
+ if accelerator.mixed_precision == "fp16":
555
+ weight_dtype = torch.float16
556
+ elif accelerator.mixed_precision == "bf16":
557
+ weight_dtype = torch.bfloat16
558
+
559
+
560
+ self.vae.requires_grad_(False)
561
+ self.vae.to(device, dtype=weight_dtype)
562
+ self.unet.requires_grad_(False)
563
+ self.unet.train()
564
+
565
+ self.text_encoder.requires_grad_(False)
566
+ self.text_encoder_2.requires_grad_(False)
567
+ if not train_full_lora:
568
+ trainable_params_list = []
569
+ for name, module in self.unet.named_modules():
570
+ module_name = type(module).__name__
571
+ if module_name == "Attention":
572
+ if module.to_k.in_features == 2048: # this is cross attention:
573
+ module.to_k.weight.requires_grad = True
574
+ trainable_params_list.append(module.to_k.weight)
575
+ if module.to_k.bias is not None:
576
+ module.to_k.bias.requires_grad = True
577
+ trainable_params_list.append(module.to_k.bias)
578
+
579
+ module.to_v.weight.requires_grad = True
580
+ trainable_params_list.append(module.to_v.weight)
581
+ if module.to_v.bias is not None:
582
+ module.to_v.bias.requires_grad = True
583
+ trainable_params_list.append(module.to_v.bias)
584
+ module.to_q.weight.requires_grad = True
585
+ trainable_params_list.append(module.to_q.weight)
586
+ if module.to_q.bias is not None:
587
+ module.to_q.bias.requires_grad = True
588
+ trainable_params_list.append(module.to_q.bias)
589
+ else:
590
+ unet_lora_config = LoraConfig(
591
+ r = lora_rank,
592
+ lora_alpha = lora_alpha,
593
+ init_lora_weights="gaussian",
594
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
595
+ )
596
+ self.unet.add_adapter(unet_lora_config)
597
+ print("training full parameters using lora!")
598
+ trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))
599
+
600
+ self.text_encoder.to(device, dtype=weight_dtype)
601
+ self.text_encoder_2.to(device, dtype=weight_dtype)
602
+ optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
603
+ self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
604
+ psum2 = sum(p.numel() for p in trainable_params_list)
605
+
606
+ effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
607
+ if accelerator.is_main_process:
608
+ accelerator.init_trackers("ModelFt", config={
609
+ "diffusion_model_learning_rate": diffusion_model_learning_rate,
610
+ "diffusion_model_optimization_steps": effective_diffusion_train_steps,
611
+ })
612
+
613
+ global_step = 0
614
+ progress_bar = tqdm(range(0, effective_diffusion_train_steps),initial=global_step, desc="ModelSteps")
615
+ noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
616
+
617
+ latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
618
+ latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)
619
+
620
+ latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
621
+ latents0_2 = latents0_2.repeat(train_batch_size,1, 1, 1)
622
+
623
+ with torch.no_grad():
624
+ encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
625
+ set_string_list_1,
626
+ self.tokenizer,
627
+ self.tokenizer_2,
628
+ self.text_encoder,
629
+ self.text_encoder_2,
630
+ length = max_length,
631
+ bsz = train_batch_size,
632
+ weight_dtype = weight_dtype
633
+ )
634
+ encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
635
+ set_string_list_2,
636
+ self.tokenizer,
637
+ self.tokenizer_2,
638
+ self.text_encoder,
639
+ self.text_encoder_2,
640
+ length = max_length,
641
+ bsz = train_batch_size,
642
+ weight_dtype = weight_dtype
643
+ )
644
+
645
+ for _ in range(max_diffusion_train_steps):
646
+ with accelerator.accumulate(self.unet):
647
+ latents_1 = latents0_1.clone().detach()
648
+ noise_1 = torch.randn_like(latents_1)
649
+ bsz = latents_1.shape[0]
650
+ timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
651
+ timesteps_1 = timesteps_1.long()
652
+ noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)
653
+
654
+ latents_2 = latents0_2.clone().detach()
655
+ noise_2 = torch.randn_like(latents_2)
656
+ bsz = latents_2.shape[0]
657
+ timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
658
+ timesteps_2 = timesteps_2.long()
659
+ noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)
660
+
661
+ register_attention_disentangled_control(self.unet, decom_controller_1)
662
+ model_pred_1 = self.unet(
663
+ noisy_latents_1,
664
+ timesteps_1,
665
+ encoder_hidden_states = encoder_hidden_states_list_1,
666
+ cross_attention_kwargs = None,
667
+ return_dict = False,
668
+ added_cond_kwargs = {"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1}
669
+ )[0]
670
+
671
+ register_attention_disentangled_control(self.unet, decom_controller_2)
672
+ model_pred_2 = self.unet(
673
+ noisy_latents_2,
674
+ timesteps_2,
675
+ encoder_hidden_states = encoder_hidden_states_list_2,
676
+ cross_attention_kwargs = None,
677
+ return_dict=False,
678
+ added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2}
679
+ )[0]
680
+
681
+ loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean")
682
+ loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean")
683
+ loss = loss_1 + loss_2
684
+ accelerator.backward(loss)
685
+ optimizer.step()
686
+ optimizer.zero_grad()
687
+
688
+
689
+ logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
690
+ progress_bar.set_postfix(**logs)
691
+ accelerator.log(logs, step=global_step)
692
+ if accelerator.sync_gradients:
693
+ progress_bar.update(1)
694
+ global_step += 1
695
+
696
+ if global_step >=max_diffusion_train_steps:
697
+ break
698
+ accelerator.wait_for_everyone()
699
+ accelerator.end_training()
700
+ self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
701
+
702
@torch.no_grad()
def backward_zT_to_z0_euler_decom(
    self,
    zT,
    cond_emb_list,
    cond_add_text_embeds,
    add_time_ids,
    uncond_emb=None,
    guidance_scale=1,
    num_sampling_steps=20,
    cond_controller=None,
    uncond_controller=None,
    mask_hard=None,
    mask_soft=None,
    orig_image=None,
    return_intermediate=False,
    strength=1,
):
    """Denoise ``zT`` with classifier-free guidance, optionally keeping masked regions.

    For each of the first ``num_sampling_steps`` entries of
    ``self.scheduler.timesteps`` the UNet is evaluated twice (unconditional
    and conditional), the two predictions are combined with
    ``guidance_scale`` and ``self.scheduler.step`` advances the latent.
    When masks are given, the masked area of the latent is overwritten
    after every step with the original image latent noised to the current
    timestep, so that area follows the original image.

    Args:
        zT: Starting latent; its first dimension is used as the batch size.
        cond_emb_list: Passed as ``encoder_hidden_states`` for the
            conditional UNet pass.
        cond_add_text_embeds: ``text_embeds`` entry of ``added_cond_kwargs``
            for the conditional pass.
        add_time_ids: ``time_ids`` entry of ``added_cond_kwargs`` for both
            passes.
        uncond_emb: ``encoder_hidden_states`` for the unconditional pass.
            Defaults to zeros of shape ``(batch, 77, 2048)``.
        guidance_scale: Classifier-free guidance weight.
        num_sampling_steps: Number of scheduler timesteps to run. The
            scheduler must already have its timesteps set.
        cond_controller: Attention controller registered on the UNet before
            the conditional pass.
        uncond_controller: Attention controller registered on the UNet
            before the unconditional pass.
        mask_hard: Optional 2-D mask applied at every step.
        mask_soft: Optional 2-D mask applied only while
            ``i <= strength * num_sampling_steps``.
        orig_image: Image encoded with ``image2latent`` to provide the
            content kept under the masks. Required when a mask is given.
        return_intermediate: If true, also return the list of latents
            recorded after each scheduler step (before mask blending),
            preceded by the initial latent.
        strength: Fraction of the steps during which ``mask_soft`` is active.

    Returns:
        The final latent, or ``(final_latent, intermediate_list)`` when
        ``return_intermediate`` is true.

    Raises:
        ValueError: If a mask is supplied without ``orig_image``.
    """
    batch_size = zT.shape[0]
    latent_cur = zT

    if uncond_emb is None:
        uncond_emb = torch.zeros(batch_size, 77, 2048, dtype=zT.dtype, device=zT.device)
    # Always defined (even when the caller supplies ``uncond_emb``) and sized
    # to the batch so it lines up with ``uncond_emb``.
    # NOTE(review): assumes ``add_time_ids`` has one row per batch element - confirm.
    uncond_add_text_embeds = torch.zeros(batch_size, 1280, dtype=zT.dtype, device=zT.device)

    if mask_soft is not None or mask_hard is not None:
        if orig_image is None:
            raise ValueError("orig_image is required when mask_soft or mask_hard is given")
        # Encode the reference image and draw its noise once; both masks share them.
        init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
        noise = torch.randn_like(init_latents_orig)
        latent_hw = tuple(init_latents_orig.shape[-2:])

        def _resize_to_latent(mask):
            # (H, W) -> (1, 1, h, w) at the latent's spatial resolution.
            resized = torch.nn.functional.interpolate(
                mask.float().unsqueeze(0).unsqueeze(0), latent_hw
            )
            return resized.to(self.vae.dtype)

        if mask_soft is not None:
            mask_soft = _resize_to_latent(mask_soft)
        if mask_hard is not None:
            mask_hard = _resize_to_latent(mask_hard)

    soft_cutoff = strength * num_sampling_steps
    intermediate_list = [latent_cur.detach()]

    for i in tqdm(range(num_sampling_steps)):
        t = self.scheduler.timesteps[i]
        latent_input = self.scheduler.scale_model_input(latent_cur, t)

        register_attention_disentangled_control(self.unet, uncond_controller)
        noise_pred_uncond = self.unet(
            latent_input,
            t,
            encoder_hidden_states=uncond_emb,
            added_cond_kwargs={"text_embeds": uncond_add_text_embeds, "time_ids": add_time_ids},
            return_dict=False,
        )[0]

        register_attention_disentangled_control(self.unet, cond_controller)
        noise_pred_cond = self.unet(
            latent_input,
            t,
            encoder_hidden_states=cond_emb_list,
            added_cond_kwargs={"text_embeds": cond_add_text_embeds, "time_ids": add_time_ids},
            return_dict=False,
        )[0]

        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        latent_cur = self.scheduler.step(noise_pred, t, latent_cur, generator=None, return_dict=False)[0]
        if return_intermediate:
            intermediate_list.append(latent_cur)

        # Pick the mask for this step: the hard mask is always active, the
        # soft mask only during the first ``strength`` fraction of the steps.
        soft_active = mask_soft is not None and i <= soft_cutoff
        mask = None
        if mask_hard is not None:
            mask = mask_hard.to(latent_cur.device, latent_cur.dtype)
            if soft_active:
                mask = mask_soft.to(latent_cur.device, latent_cur.dtype) + mask
        elif soft_active:
            mask = mask_soft.to(latent_cur.device, latent_cur.dtype)

        if mask is not None:
            # Replace the masked area with the original latent noised to timestep t.
            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
            latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))

    if return_intermediate:
        return latent_cur, intermediate_list
    return latent_cur
786
+
787
@torch.no_grad()
def sampling(
    self,
    set_string_list,
    cond_controller=None,
    uncond_controller=None,
    guidance_scale=7,
    num_sampling_steps=20,
    mask_hard=None,
    mask_soft=None,
    orig_image=None,
    strength=1.,
    num_imgs=1,
    normal_token_id_list=None,
    seed=1,
):
    """Sample ``num_imgs`` images from seeded Gaussian noise for the given prompts.

    The UNet, VAE and both text encoders are moved to the module-level
    ``device`` in float16, the torch RNG is seeded, an initial latent is
    drawn and scaled by the scheduler's ``init_noise_sigma``, the prompts
    are encoded with ``sdxl_prepare_input_decom`` and the latent is denoised
    by ``backward_zT_to_z0_euler_decom`` before being decoded with
    ``latent2image``.

    Args:
        set_string_list: Prompt strings forwarded to ``sdxl_prepare_input_decom``.
        cond_controller: Attention controller for the conditional UNet pass.
        uncond_controller: Attention controller for the unconditional pass.
        guidance_scale: Classifier-free guidance weight.
        num_sampling_steps: Number of scheduler steps.
        mask_hard: Optional mask forwarded to the denoising loop.
        mask_soft: Optional mask forwarded to the denoising loop.
        orig_image: Reference image forwarded to the denoising loop.
        strength: Fraction of steps during which ``mask_soft`` is applied.
        num_imgs: Batch size of the generated latent.
        normal_token_id_list: Forwarded to ``sdxl_prepare_input_decom``.
            ``None`` means an empty list; a fresh list is created per call
            so no state is shared between calls.
        seed: Seed passed to ``torch.manual_seed`` / ``torch.cuda.manual_seed``.

    Returns:
        Whatever ``latent2image`` returns for the final latent.
    """
    if normal_token_id_list is None:
        normal_token_id_list = []

    weight_dtype = torch.float16
    self.scheduler.set_timesteps(num_sampling_steps)
    for module in (self.unet, self.vae, self.text_encoder, self.text_encoder_2):
        module.to(device, dtype=weight_dtype)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # The VAE downsamples by 2 for every block after the first.
    vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    latent_size = self.resolution // vae_scale_factor
    zT = torch.randn(num_imgs, 4, latent_size, latent_size).to(device, dtype=weight_dtype)
    zT = zT * self.scheduler.init_noise_sigma

    cond_emb_list, cond_add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
        set_string_list,
        self.tokenizer,
        self.tokenizer_2,
        self.text_encoder,
        self.text_encoder_2,
        length=max_length,
        bsz=num_imgs,
        weight_dtype=weight_dtype,
        normal_token_id_list=normal_token_id_list,
    )

    z0 = self.backward_zT_to_z0_euler_decom(
        zT,
        cond_emb_list,
        cond_add_text_embeds,
        add_time_ids,
        guidance_scale=guidance_scale,
        num_sampling_steps=num_sampling_steps,
        cond_controller=cond_controller,
        uncond_controller=uncond_controller,
        mask_hard=mask_hard,
        mask_soft=mask_soft,
        orig_image=orig_image,
        strength=strength,
    )
    x0 = latent2image(z0, vae=self.vae)
    return x0
835
+
836
@torch.no_grad()
def inference_with_mask(
    self,
    save_path,
    guidance_scale=3,
    num_sampling_steps=50,
    strength=1,
    mask_soft=None,
    mask_hard=None,
    orig_image=None,
    mask_list=None,
    num_imgs=1,
    seed=1,
    set_string_list=None,
):
    """Run masked sampling with grouped cross-attention and save the result.

    Args:
        save_path: Destination handed to ``save_images``.
        guidance_scale: Classifier-free guidance weight.
        num_sampling_steps: Number of scheduler steps.
        strength: Fraction of steps during which ``mask_soft`` is applied.
        mask_soft: Optional soft mask forwarded to ``sampling``.
        mask_hard: Optional hard mask forwarded to ``sampling``.
        orig_image: Reference image forwarded to ``sampling``.
        mask_list: Region masks for the ``GroupedCAController``; each is
            moved to the module-level ``device``. When omitted,
            ``self.mask_list`` is used as is.
        num_imgs: Number of images to sample.
        seed: RNG seed forwarded to ``sampling``.
        set_string_list: If given, replaces ``self.set_string_list`` (the
            assignment persists on the instance) before sampling.
    """
    if mask_list is None:
        region_masks = self.mask_list
    else:
        region_masks = [region.to(device) for region in mask_list]

    if set_string_list is not None:
        self.set_string_list = set_string_list

    # The overlap check only applies when both masks are supplied.
    if not (mask_hard is None or mask_soft is None):
        check_mask_overlap_torch(mask_hard, mask_soft)

    uncond_ctrl = DummyController()
    cond_ctrl = GroupedCAController(mask_list=region_masks)

    images = self.sampling(
        self.set_string_list,
        guidance_scale=guidance_scale,
        num_sampling_steps=num_sampling_steps,
        strength=strength,
        cond_controller=cond_ctrl,
        uncond_controller=uncond_ctrl,
        mask_soft=mask_soft,
        mask_hard=mask_hard,
        orig_image=orig_image,
        num_imgs=num_imgs,
        seed=seed,
    )
    save_images(images, save_path)