# CM2000112/internals/util/sdxl_lightning.py
# Author: jayparmr
# Commit 35575bb (verified): "update: inference"
from pathlib import Path
from re import S
from typing import List, Union
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from diffusers.loaders.lora import StableDiffusionXLLoraLoaderMixin
from torchvision.datasets.utils import download_url
class LightningMixin:
    """Mixin that toggles SDXL-Lightning (8-step LoRA) on an SDXL pipeline.

    Call :meth:`configure_sdxl_lightning` once with the pipeline. Then switch
    between the Lightning setup and the original scheduler/adapters with
    :meth:`enable_sdxl_lightning` and :meth:`disable_sdxl_lightning`.
    """

    LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors"
    # Adapter name under which the Lightning LoRA is registered on the pipeline.
    LORA_8_STEP_ADAPTER = "8step_lora"

    # Scheduler the pipeline had at configure time; restored by disable.
    __scheduler_old = None
    # Pipeline given to configure_sdxl_lightning; None until configured.
    __pipe: Union[StableDiffusionXLPipeline, None] = None
    # Euler scheduler with "trailing" timestep spacing, used while enabled.
    __scheduler = None

    def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline):
        """Prepare ``pipe`` for Lightning without turning it on.

        Steps:
          1. Download the 8-step LoRA to ``~/.cache/lora_8_step.safetensors``.
          2. Load it into ``pipe`` as ``LORA_8_STEP_ADAPTER``.
          3. Deactivate all adapters.
          4. Build the Lightning scheduler from the pipeline's current
             scheduler config.

        Must be called before enable/disable. Until then the stored pipeline
        is ``None`` and those methods fail.

        Args:
            pipe: The SDXL pipeline to configure. It is kept on the instance.
        """
        lora_path = Path.home() / ".cache" / "lora_8_step.safetensors"
        # torchvision's download_url returns early when the target file
        # already exists and no md5 is given, so this is a one-time download.
        download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name)
        pipe.load_lora_weights(str(lora_path), adapter_name=self.LORA_8_STEP_ADAPTER)
        # Activate no adapters, so the freshly loaded LoRA stays off until
        # enable_sdxl_lightning is called.
        pipe.set_adapters([])
        self.__scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.__scheduler_old = pipe.scheduler
        self.__pipe = pipe

    def enable_sdxl_lightning(self):
        """Switch to the Lightning scheduler and activate the 8-step LoRA.

        Adapters that are already active stay active with their current
        weights. Calling this repeatedly does not add the Lightning adapter
        more than once.

        Returns:
            dict: Pipeline call kwargs for Lightning inference,
            ``{"guidance_scale": 0, "num_inference_steps": 8}``.
        """
        pipe = self.__pipe
        pipe.scheduler = self.__scheduler
        # Copy: get_active_adapters() may return the list object held by the
        # LoRA layers themselves, which must not be mutated in place.
        active = list(pipe.get_active_adapters())
        if self.LORA_8_STEP_ADAPTER not in active:
            active.append(self.LORA_8_STEP_ADAPTER)
        pipe.set_adapters(active, adapter_weights=self.__find_adapter_weights(active))
        return {"guidance_scale": 0, "num_inference_steps": 8}

    def disable_sdxl_lightning(self):
        """Restore the configure-time scheduler and deactivate the Lightning LoRA.

        Every other active adapter stays active with its current weight.
        """
        pipe = self.__pipe
        pipe.scheduler = self.__scheduler_old
        remaining = [
            name
            for name in pipe.get_active_adapters()
            if name != self.LORA_8_STEP_ADAPTER
        ]
        pipe.set_adapters(
            remaining, adapter_weights=self.__find_adapter_weights(remaining)
        )

    def __find_adapter_weights(self, names: List[str]) -> List[float]:
        """Return the current user-facing weight of each adapter in ``names``.

        The weight is read back from the UNet's peft tuner layers as
        ``scaling * r / lora_alpha``. This inverts peft's LoRA scale, which
        is presumably ``weight * lora_alpha / r``. An adapter found on no
        tuner layer defaults to 1.0.

        Layers are scanned once. For each adapter, the first layer that
        carries it is used. NOTE(review): if per-block weights are in use,
        layers may hold different scales and only the first is reported.

        Args:
            names: Adapter names, in the order passed to ``set_adapters``.

        Returns:
            One weight per name, in the same order.
        """
        # Local import: peft is only needed once LoRA adapters are in play.
        from peft.tuners.tuners_utils import BaseTunerLayer

        tuner_layers = [
            module
            for module in self.__pipe.unet.modules()
            if isinstance(module, BaseTunerLayer)
        ]
        weights = []
        for name in names:
            layer = next((m for m in tuner_layers if name in m.scaling), None)
            if layer is None:
                weights.append(1.0)
            else:
                weights.append(
                    layer.scaling[name] * layer.r[name] / layer.lora_alpha[name]
                )
        return weights