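"""LoRA loading for pipelines that keep a separate "warmup" UNet.

`LoraLoaderWithWarmup` extends diffusers' `LoraLoaderMixin` so that LoRA
weights are also loaded into, fused into, and unfused from the pipeline's
warmup UNet, in addition to the text encoder and the streaming UNet handled
by the base class. The host pipeline is expected to expose the warmup UNet
under the attribute named by `unet_warmup_name`.
"""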
from typing import Dict, List, Optional, Union

import torch

from diffusers.loaders.lora import LoraLoaderMixin
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.utils import USE_PEFT_BACKEND

class LoraLoaderWithWarmup(LoraLoaderMixin):
    # Attribute name under which the host pipeline stores its warmup unet.
    unet_warmup_name = "unet_warmup"
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name=None,
        **kwargs,
    ):
        # Load LoRA into the text encoder and the streaming unet.
        super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)

        # Load the same LoRA into the warmup unet. Pop `low_cpu_mem_usage`
        # before forwarding the remaining kwargs to `lora_state_dict`.
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=getattr(self, self.unet_warmup_name),
            low_cpu_mem_usage=low_cpu_mem_usage,
            adapter_name=adapter_name,
            _pipeline=self,
        )
    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        # Fuse LoRA into the text encoder and the streaming unet.
        super().fuse_lora(fuse_unet, fuse_text_encoder, lora_scale, safe_fusing, adapter_names)

        # Fuse LoRA into the warmup unet.
        if fuse_unet:
            unet_warmup = getattr(self, self.unet_warmup_name)
            unet_warmup.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
        # Unfuse LoRA from the text encoder and the streaming unet.
        super().unfuse_lora(unfuse_unet, unfuse_text_encoder)

        # Unfuse LoRA from the warmup unet.
        if unfuse_unet:
            unet_warmup = getattr(self, self.unet_warmup_name)
            if not USE_PEFT_BACKEND:
                unet_warmup.unfuse_lora()
            else:
                # With the PEFT backend, unmerge the LoRA layers directly.
                from peft.tuners.tuners_utils import BaseTunerLayer

                for module in unet_warmup.modules():
                    if isinstance(module, BaseTunerLayer):
                        module.unmerge()
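
# Usage sketch (illustrative, not part of this module): any pipeline that
# mixes in LoraLoaderWithWarmup just needs to expose a second UNet under the
# attribute named by `unet_warmup_name`. `MyStreamingPipeline` and the model
# ids below are placeholders, not real identifiers.
#
#   class MyStreamingPipeline(StableDiffusionPipeline, LoraLoaderWithWarmup):
#       ...  # construct/assign `self.unet_warmup` alongside `self.unet`
#
#   pipe = MyStreamingPipeline.from_pretrained("<base-model-id>")
#   pipe.load_lora_weights("<lora-repo-or-file>")  # text encoder + both unets
#   pipe.fuse_lora(lora_scale=1.0)                 # fuse into both unets
#   pipe.unfuse_lora()                             # restore original weights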