Upload matryoshka.py
Browse files- scheduler/matryoshka.py +35 -6
scheduler/matryoshka.py
CHANGED
@@ -20,6 +20,7 @@
|
|
20 |
|
21 |
|
22 |
import inspect
|
|
|
23 |
import math
|
24 |
from dataclasses import dataclass
|
25 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
@@ -3753,7 +3754,7 @@ class MatryoshkaPipeline(
|
|
3753 |
"""
|
3754 |
|
3755 |
model_cpu_offload_seq = "text_encoder->image_encoder->unet"
|
3756 |
-
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
3757 |
_exclude_from_cpu_offload = ["safety_checker"]
|
3758 |
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
3759 |
|
@@ -3830,6 +3831,9 @@ class MatryoshkaPipeline(
|
|
3830 |
new_config["sample_size"] = 64
|
3831 |
unet._internal_dict = FrozenDict(new_config)
|
3832 |
|
|
|
|
|
|
|
3833 |
self.register_modules(
|
3834 |
text_encoder=text_encoder,
|
3835 |
tokenizer=tokenizer,
|
@@ -3838,10 +3842,32 @@ class MatryoshkaPipeline(
|
|
3838 |
feature_extractor=feature_extractor,
|
3839 |
image_encoder=image_encoder,
|
3840 |
)
|
3841 |
-
|
3842 |
-
scheduler.scales = unet.nest_ratio + [1]
|
3843 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3844 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3845 |
def encode_prompt(
|
3846 |
self,
|
3847 |
prompt,
|
@@ -4623,9 +4649,12 @@ class MatryoshkaPipeline(
|
|
4623 |
image = latents
|
4624 |
|
4625 |
if self.scheduler.scales is not None:
|
4626 |
-
|
4627 |
-
image[i]
|
4628 |
-
|
|
|
|
|
|
|
4629 |
else:
|
4630 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
4631 |
|
|
|
20 |
|
21 |
|
22 |
import inspect
|
23 |
+
import gc
|
24 |
import math
|
25 |
from dataclasses import dataclass
|
26 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
3754 |
"""
|
3755 |
|
3756 |
model_cpu_offload_seq = "text_encoder->image_encoder->unet"
|
3757 |
+
_optional_components = ["unet", "safety_checker", "feature_extractor", "image_encoder"]
|
3758 |
_exclude_from_cpu_offload = ["safety_checker"]
|
3759 |
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
3760 |
|
|
|
3831 |
new_config["sample_size"] = 64
|
3832 |
unet._internal_dict = FrozenDict(new_config)
|
3833 |
|
3834 |
+
if hasattr(unet, "nest_ratio"):
|
3835 |
+
scheduler.scales = unet.nest_ratio + [1]
|
3836 |
+
|
3837 |
self.register_modules(
|
3838 |
text_encoder=text_encoder,
|
3839 |
tokenizer=tokenizer,
|
|
|
3842 |
feature_extractor=feature_extractor,
|
3843 |
image_encoder=image_encoder,
|
3844 |
)
|
3845 |
+
self.register_to_config(nesting_level=nesting_level)
|
|
|
3846 |
self.image_processor = VaeImageProcessor(do_resize=False)
|
3847 |
|
3848 |
+
def change_nesting_level(self, nesting_level: int):
    """Swap the pipeline's UNet for the checkpoint of another nesting level.

    Loads the UNet stored under ``unet/nesting_level_<N>`` in the
    ``tolgacangoz/matryoshka-diffusion-models`` repository, moves it to the
    pipeline's current device, and then updates ``self.unet``,
    ``self.scheduler.scales`` and the recorded ``nesting_level`` config entry.

    Args:
        nesting_level: Checkpoint to load. ``0`` loads a
            ``MatryoshkaUNet2DConditionModel``; ``1`` and ``2`` load a
            ``NestedUNet2DConditionModel``.

    Raises:
        ValueError: If ``nesting_level`` is not 0, 1, or 2. Nothing on the
            pipeline is modified in that case.
    """
    # Validate before touching any pipeline state.
    if nesting_level not in (0, 1, 2):
        raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
    level = int(nesting_level)

    unet_cls = MatryoshkaUNet2DConditionModel if level == 0 else NestedUNet2DConditionModel

    # Load the replacement first: if the download/load fails, the pipeline keeps
    # its previous UNet, scheduler scales and config instead of a half-updated mix.
    new_unet = unet_cls.from_pretrained(
        "tolgacangoz/matryoshka-diffusion-models",
        subfolder=f"unet/nesting_level_{level}",
    ).to(self.device)

    if level == 0:
        # Only clear the scales when the UNet being replaced was a nested one.
        if hasattr(self.unet, "nest_ratio"):
            self.scheduler.scales = None
    else:
        self.scheduler.scales = new_unet.nest_ratio + [1]

    self.unet = new_unet

    # Record the level through register_to_config, the same call this class uses
    # when the pipeline is built. NOTE(review): plain attribute assignment on
    # self.config presumably does not update the stored config mapping - confirm
    # against ConfigMixin.
    self.register_to_config(nesting_level=level)

    # Drop the now-unreferenced previous UNet and release cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
|
3870 |
+
|
3871 |
def encode_prompt(
|
3872 |
self,
|
3873 |
prompt,
|
|
|
4649 |
image = latents
|
4650 |
|
4651 |
if self.scheduler.scales is not None:
|
4652 |
+
scales = [
|
4653 |
+
image[i].size(-1) / image[-1].size(-1)
|
4654 |
+
for i in range(len(image))
|
4655 |
+
]
|
4656 |
+
for i, (img, scale) in enumerate(zip(image, scales)):
|
4657 |
+
image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
|
4658 |
else:
|
4659 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
4660 |
|