HuiZhang committed
Commit 6aed6eb · verified · 1 Parent(s): 9f98f9a

Update src/pipeline/pipeline_CreatiLayout.py

src/pipeline/pipeline_CreatiLayout.py CHANGED
@@ -420,21 +420,21 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
             clip_skip=clip_skip,
             clip_model_index=1,
         )
-        clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)  # torch.Size([B, 77, 768]) + torch.Size([B, 77, 1280]) -> torch.Size([B, 77, 2048])
+        clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
 
         t5_prompt_embed = self._get_t5_prompt_embeds(
             prompt=prompt_3,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             device=device,
-        )  # [B, 256, 4096]
+        )
 
         clip_prompt_embeds = torch.nn.functional.pad(
             clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
-        )  # [B, 77, 4096]
+        )
 
-        prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)  # torch.Size([B, 333 (256+77), 4096])
-        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)  # [B, 2048]
+        prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+        pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
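
The comments deleted in this hunk documented the tensor shapes at each step. As a sanity check, here is a minimal sketch with dummy tensors, assuming the shapes recorded in the deleted comments (CLIP-L 768-d, CLIP-G 1280-d, T5 4096-d at max_sequence_length=256); it reproduces the same concatenate-pad-concatenate arithmetic:

import torch

B = 2  # hypothetical batch size
prompt_embed = torch.randn(B, 77, 768)       # CLIP-L sequence features (per deleted comment)
prompt_2_embed = torch.randn(B, 77, 1280)    # CLIP-G sequence features
t5_prompt_embed = torch.randn(B, 256, 4096)  # T5 features at max_sequence_length=256

clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)    # [B, 77, 2048]
# Zero-pad the channel dim so CLIP features can sit alongside T5 features.
clip_prompt_embeds = torch.nn.functional.pad(
    clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)                                                                         # [B, 77, 4096]
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)  # [B, 333, 4096]
assert prompt_embeds.shape == (B, 77 + 256, 4096)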
@@ -867,15 +867,9 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
         # 5.5 layout
         max_objs = 10
         if len(bbox_raw) > max_objs:
-
             print(f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.")
-
             bbox_phrases = bbox_phrases[:max_objs]
             bbox_raw = bbox_raw[:max_objs]
-        # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
-        # Get tokens for phrases from pre-trained CLIPTokenizer
-        # from IPython.core.debugger import set_trace
-        # set_trace()
         tokenizer_inputs = self.tokenizer(
             bbox_phrases,
             padding="max_length",
@@ -883,8 +877,6 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
             truncation=True,
             return_tensors="pt",
         ).input_ids.to(device)
-        # For the token, we use the same pre-trained text encoder
-        # to obtain its text feature
 
         text_embeddings_1 = self.text_encoder(tokenizer_inputs.to(device), output_hidden_states=True)[0]
 
@@ -896,9 +888,7 @@ class CreatiLayoutSD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleF
             truncation=True,
             return_tensors="pt",
         ).input_ids.to(device)
-        # For the token, we use the same pre-trained text encoder
-        # to obtain its text feature
-
+
         text_embeddings_2 = self.text_encoder_2(tokenizer_inputs_2.to(device), output_hidden_states=True)[0]
 
         clip_text_embeddings = torch.cat([text_embeddings_1, text_embeddings_2], dim=-1)
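
The retained lines run the same phrase tokens through both CLIP text encoders and concatenate the results. A shape sketch, assuming (as in the stock SD3 pipeline) both encoders are CLIPTextModelWithProjection, whose first return value is the pooled, projected embedding (768-d for CLIP-L, 1280-d for CLIP-G):

import torch

num_objs = 10  # hypothetical number of layout phrases
text_embeddings_1 = torch.randn(num_objs, 768)   # pooled CLIP-L phrase embeddings
text_embeddings_2 = torch.randn(num_objs, 1280)  # pooled CLIP-G phrase embeddings
clip_text_embeddings = torch.cat([text_embeddings_1, text_embeddings_2], dim=-1)
assert clip_text_embeddings.shape == (num_objs, 2048)  # one 2048-d vector per phrase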
 