from pathlib import Path
from re import S
from typing import List, Union

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from diffusers.loaders.lora import StableDiffusionXLLoraLoaderMixin
from torchvision.datasets.utils import download_url


class LightningMixin:
    """Mixin that toggles the SDXL-Lightning 8-step LoRA on an SDXL pipeline.

    Usage: call :meth:`configure_sdxl_lightning` once with the pipeline, then
    switch between :meth:`enable_sdxl_lightning` (Lightning LoRA plus a
    trailing-timestep Euler scheduler) and :meth:`disable_sdxl_lightning`
    (original scheduler, Lightning LoRA deactivated).
    """

    LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors"
    # Adapter name under which the Lightning LoRA is registered on the pipeline.
    LORA_8_STEP_ADAPTER = "8step_lora"

    # Scheduler the pipeline had before configuration; restored on disable.
    __scheduler_old = None
    # Pipeline passed to configure_sdxl_lightning (None until configured).
    __pipe: StableDiffusionXLPipeline = None
    # Euler scheduler with trailing timestep spacing, installed on enable.
    __scheduler = None

    def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline):
        """Register the Lightning LoRA on *pipe* and prepare its scheduler.

        The LoRA weights are downloaded to ``~/.cache/lora_8_step.safetensors``
        via torchvision's ``download_url``, which skips the download when the
        file already exists. After loading, ``set_adapters([])`` leaves no
        adapter active. An ``EulerDiscreteScheduler`` with
        ``timestep_spacing="trailing"`` is built from the pipeline's scheduler
        config, and the current scheduler is remembered for restoration.

        Args:
            pipe: The SDXL pipeline to attach the Lightning LoRA to.
        """
        lora_path = Path.home() / ".cache" / "lora_8_step.safetensors"
        download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name)
        pipe.load_lora_weights(str(lora_path), adapter_name=self.LORA_8_STEP_ADAPTER)
        # Keep the freshly loaded LoRA inactive until enable_sdxl_lightning().
        pipe.set_adapters([])
        self.__scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.__scheduler_old = pipe.scheduler
        self.__pipe = pipe

    def enable_sdxl_lightning(self):
        """Activate the Lightning LoRA and its scheduler.

        Adapters that are already active stay active with their current
        weights. Calling this repeatedly does not add the Lightning adapter
        twice.

        Returns:
            dict: Generation kwargs for the 8-step LoRA
            (``guidance_scale=0``, ``num_inference_steps=8``).

        Raises:
            RuntimeError: If :meth:`configure_sdxl_lightning` was not called.
        """
        pipe = self.__require_pipe()
        pipe.scheduler = self.__scheduler
        # Copy: get_active_adapters() may hand back a LoRA layer's own
        # active-adapter list, which must not be mutated in place.
        current = list(pipe.get_active_adapters())
        if self.LORA_8_STEP_ADAPTER not in current:
            current.append(self.LORA_8_STEP_ADAPTER)
        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)
        return {"guidance_scale": 0, "num_inference_steps": 8}

    def disable_sdxl_lightning(self):
        """Restore the original scheduler and deactivate the Lightning LoRA.

        Other active adapters remain active with their current weights.

        Raises:
            RuntimeError: If :meth:`configure_sdxl_lightning` was not called.
        """
        pipe = self.__require_pipe()
        pipe.scheduler = self.__scheduler_old
        current = [
            adapter
            for adapter in pipe.get_active_adapters()
            if adapter != self.LORA_8_STEP_ADAPTER
        ]
        weights = self.__find_adapter_weights(current)
        pipe.set_adapters(current, adapter_weights=weights)

    def __require_pipe(self) -> StableDiffusionXLPipeline:
        """Return the configured pipeline or raise if configuration is missing."""
        pipe = self.__pipe
        if pipe is None:
            raise RuntimeError(
                "configure_sdxl_lightning() must be called before enabling "
                "or disabling SDXL-Lightning"
            )
        return pipe

    def __find_adapter_weights(self, names: List[str]) -> List[float]:
        """Recover the weight currently applied to each adapter in *names*.

        For every PEFT tuner layer in the UNet that carries a ``scaling``
        entry for an adapter, the weight is computed as
        ``scaling * r / lora_alpha``. When several layers carry the adapter,
        the value from the last one encountered is used. Adapters found in no
        layer default to 1.0.

        NOTE(review): this inversion assumes PEFT's standard LoRA scaling
        (``weight * lora_alpha / r``); it would be wrong for adapters loaded
        with rsLoRA (``lora_alpha / sqrt(r)``) — confirm if that can occur.

        Args:
            names: Adapter names, in the order expected by ``set_adapters``.

        Returns:
            One weight per entry of *names*, in the same order.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        wanted = set(names)
        found = {}
        # Single pass over the UNet instead of one full walk per adapter.
        for module in self.__pipe.unet.modules():
            if not isinstance(module, BaseTunerLayer):
                continue
            # Non-LoRA tuner layers may not define `scaling`; skip them.
            scaling = getattr(module, "scaling", None)
            if not scaling:
                continue
            for adapter_name in wanted.intersection(scaling):
                found[adapter_name] = (
                    scaling[adapter_name]
                    * module.r[adapter_name]
                    / module.lora_alpha[adapter_name]
                )
        return [found.get(adapter_name, 1.0) for adapter_name in names]