HuiZhang committed
Update src/pipeline/pipeline_CreatiLayout.py
src/pipeline/pipeline_CreatiLayout.py CHANGED
@@ -420,21 +420,21 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
             clip_skip=clip_skip,
             clip_model_index=1,
         )
-        clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+        clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

         t5_prompt_embed = self._get_t5_prompt_embeds(
             prompt=prompt_3,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             device=device,
-        )
+        )

         clip_prompt_embeds = torch.nn.functional.pad(
             clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
-        )
+        )

-        prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
-        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+        prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
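In this first hunk the removed and re-added lines read identically once leading whitespace is lost in extraction, so the change is presumably indentation-only. For context, the touched lines compose the SD3-style joint prompt embedding: the two CLIP embeddings are concatenated along the feature axis, zero-padded up to the T5 hidden width, then stacked with the T5 embedding along the sequence axis. A minimal sketch of that shape arithmetic with dummy tensors; the widths (768 for CLIP-L, 1280 for CLIP-G, 4096 for T5) and sequence lengths are the standard SD3 values, assumed here rather than read from this file:

import torch

batch, clip_len, t5_len = 1, 77, 256
prompt_embed = torch.randn(batch, clip_len, 768)     # stand-in for the CLIP-L output
prompt_2_embed = torch.randn(batch, clip_len, 1280)  # stand-in for the CLIP-G output
t5_prompt_embed = torch.randn(batch, t5_len, 4096)   # stand-in for the T5 output

# Concatenate the two CLIP embeddings along the feature axis: 768 + 1280 = 2048.
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

# Zero-pad the feature axis up to the T5 width (4096) so both streams match.
clip_prompt_embeds = torch.nn.functional.pad(
    clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)

# Stack CLIP tokens and T5 tokens along the sequence axis.
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
print(prompt_embeds.shape)  # torch.Size([1, 333, 4096])

The zero padding is what makes the final torch.cat along dim=-2 legal: both streams must share the 4096-wide feature axis before their token sequences can be joined.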
@@ -867,15 +867,9 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
         # 5.5 layout
         max_objs = 10
         if len(bbox_raw) > max_objs:
-
             print(f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.")
-
             bbox_phrases = bbox_phrases[:max_objs]
             bbox_raw = bbox_raw[:max_objs]
-        # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
-        # Get tokens for phrases from pre-trained CLIPTokenizer
-        # from IPython.core.debugger import set_trace
-        # set_trace()
         tokenizer_inputs = self.tokenizer(
             bbox_phrases,
             padding="max_length",
@@ -883,8 +877,6 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
             truncation=True,
             return_tensors="pt",
         ).input_ids.to(device)
-        # For the token, we use the same pre-trained text encoder
-        # to obtain its text feature

         text_embeddings_1 = self.text_encoder(tokenizer_inputs.to(device), output_hidden_states=True)[0]

@@ -896,9 +888,7 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
             truncation=True,
             return_tensors="pt",
         ).input_ids.to(device)
-
-        # to obtain its text feature
-
+
         text_embeddings_2 = self.text_encoder_2(tokenizer_inputs_2.to(device), output_hidden_states=True)[0]

         clip_text_embeddings = torch.cat([text_embeddings_1, text_embeddings_2], dim=-1)
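The three hunks above are pure cleanup (stale comments, blank lines, and a leftover IPython set_trace) around the layout branch, which embeds each bounding-box phrase with both CLIP text encoders and fuses the per-phrase features; one deleted comment notes they feed a GLIGENTextBoundingboxProjection-style module. A standalone sketch of that step, assuming a stock CLIP-L checkpoint (the checkpoint name is illustrative; the pipeline uses whatever tokenizer/encoder pair it was loaded with):

import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

# Assumed checkpoint for the sketch, not taken from this file.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

bbox_phrases = ["a red apple", "a wooden table"]

# Pad every phrase to the encoder's fixed length, mirroring the pipeline's call.
tokenizer_inputs = tokenizer(
    bbox_phrases,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    # Index [0] on a CLIPTextModelWithProjection output selects `text_embeds`:
    # one projected, pooled vector per phrase.
    text_embeddings_1 = text_encoder(tokenizer_inputs, output_hidden_states=True)[0]

print(text_embeddings_1.shape)  # torch.Size([2, 768]) for CLIP-L
# The pipeline repeats this with its second encoder and fuses the features:
# clip_text_embeddings = torch.cat([text_embeddings_1, text_embeddings_2], dim=-1)

In the standard SD3 setup the second encoder is CLIP-G (1280-wide), so the concatenated per-phrase feature would be 2048-dimensional.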