---
license: openrail++
---
# FLAX Latent Consistency Model (LCM) LoRA: SDXL - UNet
SDXL UNet with the LCM LoRA weights merged in (lora_scale=0.7) and converted to work with FLAX.
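For reference, a checkpoint like this can be produced by fusing the LCM LoRA into the SDXL UNet in PyTorch and re-saving the result as Flax weights. The sketch below is an assumption, not the documented conversion path for this repo; the LoRA id `latent-consistency/lcm-lora-sdxl` and the local paths are illustrative.
```python
from diffusers import DiffusionPipeline, FlaxUNet2DConditionModel

# Fuse the LCM LoRA into the SDXL UNet at the scale used by this checkpoint
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")  # assumed LoRA source
pipe.fuse_lora(lora_scale=0.7)  # bake the LoRA into the UNet weights
pipe.unet.save_pretrained("./lcm_unet_pt")

# Reload the fused PyTorch UNet as Flax params and save it
flax_unet, flax_params = FlaxUNet2DConditionModel.from_pretrained(
    "./lcm_unet_pt", from_pt=True
)
flax_unet.save_pretrained("./flax_lcm_unet", params=flax_params)
```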
## Setup
To use on TPUs:
```bash
git clone https://github.com/entrpn/diffusers
cd diffusers
git checkout lcm_flax
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install transformers flax torch torchvision
pip install .
```
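Optionally, verify that JAX can see the TPU before running. The script below generates one image per device, each with a different RNG key:
```python
import jax
print(jax.device_count())  # e.g. 8 on a TPU v3-8
print(jax.devices())
```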
## Run
```python
import os
import time

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from jax.experimental.compilation_cache import compilation_cache as cc

from diffusers import (
    FlaxStableDiffusionXLPipeline,
    FlaxUNet2DConditionModel,
    FlaxLCMScheduler,
)

# Persist compiled XLA programs on disk so later runs skip recompilation
cc.initialize_cache(os.path.expanduser("~/jax_cache"))
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
weight_dtype = jnp.bfloat16
revision = "refs/pr/95"  # PR revision that carries the Flax weights for SDXL base
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    base_model, revision=revision, dtype=weight_dtype
)
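# Drop the base SDXL UNet params; they are replaced below by the LCM-merged UNet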
del params["unet"]
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "jffacevedo/flax_lcm_unet",
    dtype=weight_dtype,
)
scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    base_model,
    subfolder="scheduler",
    revision=revision,
    dtype=jnp.float32,
)
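# Swap the LCM UNet and scheduler into the pipeline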
params["unet"] = unet_params
pipeline.unet = unet
pipeline.scheduler = scheduler
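# Cast all model params to bfloat16; the scheduler state is attached afterwards, so it stays float32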
params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
params["scheduler"] = scheduler_state
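# LCM defaults: a handful of steps, guidance_scale 1.0 (classifier-free guidance is disabled in generate())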
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 42
default_guidance_scale = 1.0
default_num_steps = 4
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids
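# Replicate the params once across all local devices; prompts and RNG keys are replicated per call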
NUM_DEVICES = jax.device_count()
p_params = replicate(params)
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)  # one RNG key per device
    return p_prompt_ids, p_neg_prompt_ids, rng
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=False,  # LCM does not use classifier-free guidance
        jit=True,
    ).images
    print("images.shape: ", images.shape)
    # flatten the leading (devices, batch) dims and convert to PIL
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))
# The first call compiles the pmapped pipeline; subsequent calls reuse the compiled program
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start:.2f}s")
dts = []
i = 0
for _ in range(2):
    start = time.time()
    prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
    neg_prompt = ""
    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.time() - start
    print(f"Inference in {t:.2f}s")
    dts.append(t)
    for img in images:
        img.save(f"{i:06d}.jpg")
        i += 1
mean = np.mean(dts)
stdev = np.std(dts)
print(f"batches: {len(dts)}, mean {mean:.2f} sec/batch ± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95% CI)")
```