from pathlib import Path
from typing import List, Optional

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from torchvision.datasets.utils import download_url


class LightningMixin:
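    """Mixin that adds ByteDance SDXL-Lightning 8-step LoRA support to an SDXL pipeline.

    configure_sdxl_lightning() downloads and registers the LoRA once;
    enable_sdxl_lightning() and disable_sdxl_lightning() switch the scheduler
    and active adapters on and off for 8-step inference.
    """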

    # Distilled 8-step LoRA weights published by ByteDance for SDXL-Lightning.
    LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors"

    __scheduler_old = None
    __pipe: Optional[StableDiffusionXLPipeline] = None
    __scheduler = None

    def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline):
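        """Download the 8-step LoRA, register it on the pipeline, and set up the Lightning scheduler."""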
        lora_path = Path.home() / ".cache" / "lora_8_step.safetensors"

        # download_url skips the download if the file is already present.
        download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name)

        # Register the LoRA but leave it inactive until enable_sdxl_lightning() is called.
        pipe.load_lora_weights(str(lora_path), adapter_name="8step_lora")
        pipe.set_adapters([])

        # SDXL-Lightning expects an Euler scheduler with trailing timestep spacing.
        self.__scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.__scheduler_old = pipe.scheduler
        self.__pipe = pipe

    def enable_sdxl_lightning(self):
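        """Activate the 8-step LoRA and scheduler; return the pipeline call overrides to use."""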
        pipe = self.__pipe
        pipe.scheduler = self.__scheduler

        # Add the Lightning LoRA on top of whatever adapters are already active.
        current = pipe.get_active_adapters()
        current.extend(["8step_lora"])

        # Re-apply the existing adapters' weights so set_adapters does not reset them to 1.0.
        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)

        # SDXL-Lightning is distilled for CFG-free sampling in 8 steps.
        return {"guidance_scale": 0, "num_inference_steps": 8}

    def disable_sdxl_lightning(self):
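        """Restore the original scheduler and deactivate the 8-step LoRA."""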
        pipe = self.__pipe
        pipe.scheduler = self.__scheduler_old

        # Drop the Lightning LoRA but keep every other active adapter.
        current = pipe.get_active_adapters()
        current = [adapter for adapter in current if adapter != "8step_lora"]

        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)

    def __find_adapter_weights(self, names: List[str]):
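        """Recover each adapter's current effective weight from the UNet's PEFT layer scalings."""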
        pipe = self.__pipe
        model = pipe.unet

        # Imported lazily so peft is only needed once adapters are actually queried.
        from peft.tuners.tuners_utils import BaseTunerLayer

        weights = []
        for adapter_name in names:
            weight = 1.0
            for module in model.modules():
                if isinstance(module, BaseTunerLayer):
                    if adapter_name in module.scaling:
                        # PEFT stores scaling = weight * lora_alpha / r, so invert it
                        # to recover the weight the adapter was last set to.
                        weight = (
                            module.scaling[adapter_name]
                            * module.r[adapter_name]
                            / module.lora_alpha[adapter_name]
                        )

            weights.append(weight)

        return weights
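

# Minimal usage sketch (not part of the original module): it assumes a thin
# wrapper class that owns a StableDiffusionXLPipeline and mixes in
# LightningMixin. The model id, device, prompt, and file name below are
# illustrative only.
if __name__ == "__main__":
    import torch

    class _LightningSDXL(LightningMixin):
        def __init__(self):
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                torch_dtype=torch.float16,
                variant="fp16",
            ).to("cuda")
            self.configure_sdxl_lightning(self.pipe)

    wrapper = _LightningSDXL()

    # enable_sdxl_lightning() returns the call overrides (guidance_scale,
    # num_inference_steps) that the Lightning LoRA was distilled for.
    overrides = wrapper.enable_sdxl_lightning()
    image = wrapper.pipe("a photo of an astronaut riding a horse", **overrides).images[0]
    image.save("lightning_8step.png")

    # Switch back to the pipeline's original scheduler and adapters.
    wrapper.disable_sdxl_lightning()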