leoxing1996
add demo
d16b52d
from typing import Dict, List, Optional, Union
import torch
from diffusers.loaders.lora import LoraLoaderMixin
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.utils import USE_PEFT_BACKEND
class LoraLoaderWithWarmup(LoraLoaderMixin):
    """``LoraLoaderMixin`` that also applies LoRA operations to a warm-up UNet.

    The base mixin handles the text encoder and the streaming UNet. Each
    public method here first delegates to it and then repeats the same
    operation on the warm-up UNet, which is looked up via ``unet_warmup_name``.
    """

    # Attribute name under which the pipeline stores its warm-up UNet.
    unet_warmup_name = "unet_warmup"

    def _get_unet_warmup(self):
        """Return the warm-up UNet module.

        Prefers a ``unet_warmup`` attribute when one exists; otherwise falls
        back to the attribute named by ``unet_warmup_name``.
        """
        if hasattr(self, "unet_warmup"):
            return self.unet_warmup
        return getattr(self, self.unet_warmup_name)

    @staticmethod
    def _copy_if_dict(pretrained_model_name_or_path_or_dict):
        """Return a shallow copy of a dict argument; pass anything else through.

        Gives each ``lora_state_dict`` parse its own top-level mapping, so
        in-place key rewrites during parsing cannot reach the caller's dict or
        the other parse. Tensor values are shared, not copied.
        """
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            return pretrained_model_name_or_path_or_dict.copy()
        return pretrained_model_name_or_path_or_dict

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name=None,
        **kwargs,
    ):
        """Load LoRA weights into the text encoder, streaming UNet and warm-up UNet.

        Args:
            pretrained_model_name_or_path_or_dict: A model id / path, or an
                already-loaded LoRA state dict. When a dict is given, the
                caller's dict object itself is not modified; each parse works
                on its own shallow copy.
            adapter_name: Adapter name forwarded to every loader call.
            **kwargs: Forwarded to the base-class loader and to
                ``lora_state_dict``. ``low_cpu_mem_usage`` is also read here
                for the warm-up UNet (default ``_LOW_CPU_MEM_USAGE_DEFAULT``).
        """
        # Text encoder + streaming UNet, via the base mixin.
        super().load_lora_weights(
            self._copy_if_dict(pretrained_model_name_or_path_or_dict),
            adapter_name=adapter_name,
            **kwargs,
        )

        # Warm-up UNet: re-parse the checkpoint from a fresh copy.
        # ``lora_state_dict`` may rewrite a dict argument in place (e.g. while
        # converting non-diffusers key layouts). Without the copy, that would
        # mutate the caller's dict and could make this pass see different
        # keys/alphas than the streaming-UNet pass above.
        state_dict, network_alphas = self.lora_state_dict(
            self._copy_if_dict(pretrained_model_name_or_path_or_dict), **kwargs
        )
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=self._get_unet_warmup(),
            low_cpu_mem_usage=low_cpu_mem_usage,
            adapter_name=adapter_name,
            _pipeline=self,
        )

    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        """Fuse LoRA weights, including into the warm-up UNet.

        Args:
            fuse_unet: Fuse the streaming UNet (base class) and the warm-up UNet.
            fuse_text_encoder: Fuse the text encoder (handled by the base class).
            lora_scale: LoRA scale forwarded to both fuse calls.
            safe_fusing: Forwarded to both fuse calls.
            adapter_names: Adapters to fuse; forwarded unchanged to both calls.
        """
        # Keywords rather than positionals, so that a reordered base-class
        # signature cannot silently bind these values to the wrong parameters.
        super().fuse_lora(
            fuse_unet=fuse_unet,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
        )
        if fuse_unet:
            self._get_unet_warmup().fuse_lora(
                lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
            )

    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
        """Unfuse LoRA weights, including from the warm-up UNet.

        Args:
            unfuse_unet: Unfuse the streaming UNet (base class) and the warm-up UNet.
            unfuse_text_encoder: Unfuse the text encoder (handled by the base class).
        """
        super().unfuse_lora(unfuse_unet=unfuse_unet, unfuse_text_encoder=unfuse_text_encoder)
        if unfuse_unet:
            unet_warmup = self._get_unet_warmup()
            if not USE_PEFT_BACKEND:
                unet_warmup.unfuse_lora()
            else:
                # PEFT backend: unmerge every PEFT tuner layer in the warm-up UNet.
                from peft.tuners.tuners_utils import BaseTunerLayer

                for module in unet_warmup.modules():
                    if isinstance(module, BaseTunerLayer):
                        module.unmerge()