tolgacangoz committed on
Commit
9228d10
·
verified ·
1 Parent(s): ca01bef

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. scheduler/matryoshka.py +35 -6
scheduler/matryoshka.py CHANGED
@@ -20,6 +20,7 @@
20
 
21
 
22
  import inspect
 
23
  import math
24
  from dataclasses import dataclass
25
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -3753,7 +3754,7 @@ class MatryoshkaPipeline(
3753
  """
3754
 
3755
  model_cpu_offload_seq = "text_encoder->image_encoder->unet"
3756
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
3757
  _exclude_from_cpu_offload = ["safety_checker"]
3758
  _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
3759
 
@@ -3830,6 +3831,9 @@ class MatryoshkaPipeline(
3830
  new_config["sample_size"] = 64
3831
  unet._internal_dict = FrozenDict(new_config)
3832
 
 
 
 
3833
  self.register_modules(
3834
  text_encoder=text_encoder,
3835
  tokenizer=tokenizer,
@@ -3838,10 +3842,32 @@ class MatryoshkaPipeline(
3838
  feature_extractor=feature_extractor,
3839
  image_encoder=image_encoder,
3840
  )
3841
- if hasattr(unet, "nest_ratio"):
3842
- scheduler.scales = unet.nest_ratio + [1]
3843
  self.image_processor = VaeImageProcessor(do_resize=False)
3844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3845
  def encode_prompt(
3846
  self,
3847
  prompt,
@@ -4623,9 +4649,12 @@ class MatryoshkaPipeline(
4623
  image = latents
4624
 
4625
  if self.scheduler.scales is not None:
4626
- for i in range(len(image)):
4627
- image[i] = image[i] * self.scheduler.scales[i]
4628
- image[i] = self.image_processor.postprocess(image[i], output_type=output_type)
 
 
 
4629
  else:
4630
  image = self.image_processor.postprocess(image, output_type=output_type)
4631
 
 
20
 
21
 
22
  import inspect
23
+ import gc
24
  import math
25
  from dataclasses import dataclass
26
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
3754
  """
3755
 
3756
  model_cpu_offload_seq = "text_encoder->image_encoder->unet"
3757
+ _optional_components = ["unet", "safety_checker", "feature_extractor", "image_encoder"]
3758
  _exclude_from_cpu_offload = ["safety_checker"]
3759
  _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
3760
 
 
3831
  new_config["sample_size"] = 64
3832
  unet._internal_dict = FrozenDict(new_config)
3833
 
3834
+ if hasattr(unet, "nest_ratio"):
3835
+ scheduler.scales = unet.nest_ratio + [1]
3836
+
3837
  self.register_modules(
3838
  text_encoder=text_encoder,
3839
  tokenizer=tokenizer,
 
3842
  feature_extractor=feature_extractor,
3843
  image_encoder=image_encoder,
3844
  )
3845
+ self.register_to_config(nesting_level=nesting_level)
 
3846
  self.image_processor = VaeImageProcessor(do_resize=False)
3847
 
3848
+ def change_nesting_level(self, nesting_level: int):
3849
+ if nesting_level == 0:
3850
+ if hasattr(self.unet, "nest_ratio"):
3851
+ self.scheduler.scales = None
3852
+ self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3853
+ subfolder="unet/nesting_level_0").to(self.device)
3854
+ self.config.nesting_level = 0
3855
+ elif nesting_level == 1:
3856
+ self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3857
+ subfolder="unet/nesting_level_1").to(self.device)
3858
+ self.config.nesting_level = 1
3859
+ self.scheduler.scales = self.unet.nest_ratio + [1]
3860
+ elif nesting_level == 2:
3861
+ self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3862
+ subfolder="unet/nesting_level_2").to(self.device)
3863
+ self.config.nesting_level = 2
3864
+ self.scheduler.scales = self.unet.nest_ratio + [1]
3865
+ else:
3866
+ raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3867
+
3868
+ gc.collect()
3869
+ torch.cuda.empty_cache()
3870
+
3871
  def encode_prompt(
3872
  self,
3873
  prompt,
 
4649
  image = latents
4650
 
4651
  if self.scheduler.scales is not None:
4652
+ scales = [
4653
+ image[i].size(-1) / image[-1].size(-1)
4654
+ for i in range(len(image))
4655
+ ]
4656
+ for i, (img, scale) in enumerate(zip(image, scales)):
4657
+ image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
4658
  else:
4659
  image = self.image_processor.postprocess(image, output_type=output_type)
4660