---
license: openrail++
---

This checkpoint is compiled from [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) for AWS Inferentia2 (Inf2) instances.

## Compilation

Download the 4-step UNet checkpoint from [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) and use it to replace the UNet weights in a local copy of the original SDXL checkpoint:

```python
from huggingface_hub import hf_hub_download

repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"
hf_hub_download(repo, ckpt)
```

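Replace the UNet weights (the snapshot directory, shown here as `xxxxxx`, is named after the downloaded revision's commit hash):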

```bash
cp /home/ubuntu/.cache/huggingface/hub/models--ByteDance--SDXL-Lightning/snapshots/xxxxxx/sdxl_lightning_4step_unet.safetensors stable-diffusion-xl-lightning/unet/diffusion_pytorch_model.safetensors
```
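Because the snapshot directory name varies, it can be less error-prone to copy from the path that `hf_hub_download` returns (the local path of the cached file) instead of typing the cache path by hand. The helper below is a minimal sketch of that copy step; `replace_unet_weights` is a hypothetical name, and it assumes the target is a diffusers-format SDXL folder with a `unet/` subdirectory, as in the command above.

```python
import shutil
from pathlib import Path


def replace_unet_weights(downloaded_ckpt, sdxl_dir):
    """Copy a downloaded UNet .safetensors file over the UNet weights of a
    local diffusers-format SDXL checkpoint and return the destination path.

    Hypothetical helper: assumes `sdxl_dir` contains a `unet/` subfolder.
    """
    src = Path(downloaded_ckpt)
    if not src.is_file():
        raise FileNotFoundError(f"downloaded checkpoint not found: {src}")

    unet_dir = Path(sdxl_dir) / "unet"
    if not unet_dir.is_dir():
        raise FileNotFoundError(f"no unet/ folder under {sdxl_dir}")

    # Same destination file name as the `cp` command above.
    dst = unet_dir / "diffusion_pytorch_model.safetensors"
    # copyfile follows symlinks by default, so the real file behind a cache link is copied.
    shutil.copyfile(src, dst)
    return dst
```

For example, `replace_unet_weights(hf_hub_download(repo, ckpt), "stable-diffusion-xl-lightning")` would combine the download and copy steps.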

Compile the whole pipeline:

```python
from optimum.neuron import NeuronStableDiffusionXLPipeline

model_id = "stable-diffusion-xl-lightning"
num_images_per_prompt = 1
input_shapes = {"batch_size": 1, "height": 1024, "width": 1024, "num_images_per_prompt": num_images_per_prompt}
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}

stable_diffusion = NeuronStableDiffusionXLPipeline.from_pretrained(
    model_id, export=True, **compiler_args, **input_shapes
)
save_directory = "sdxl_lightning_4_steps_neuronx/"
stable_diffusion.save_pretrained(save_directory)
# Optionally, push the compiled artifacts in `save_directory` to the Hugging Face Hub.
```
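Neuron compilation traces the model with static input shapes, so the values in `input_shapes` are baked into the compiled artifacts and inference requests generally need to match them. The sketch below (both helper names are hypothetical) checks a request against those shapes and derives the UNet latent shape, assuming the standard SDXL latent space of 4 channels at 1/8 of the image resolution. It deliberately ignores any extra batch dimension that classifier-free guidance may add.

```python
def latent_shape(input_shapes):
    """Latent tensor shape (N, C, H/8, W/8) implied by the compiled input shapes.

    Assumes SDXL's VAE: 4 latent channels and an 8x spatial downsampling factor.
    Classifier-free guidance batching is not accounted for.
    """
    batch = input_shapes["batch_size"] * input_shapes["num_images_per_prompt"]
    return (batch, 4, input_shapes["height"] // 8, input_shapes["width"] // 8)


def check_request(input_shapes, height, width, batch_size=1, num_images_per_prompt=1):
    """Raise ValueError if a generation request differs from the compiled shapes."""
    requested = {
        "batch_size": batch_size,
        "height": height,
        "width": width,
        "num_images_per_prompt": num_images_per_prompt,
    }
    # Collect (requested, compiled) pairs for every field that does not match.
    mismatched = {
        key: (value, input_shapes[key])
        for key, value in requested.items()
        if input_shapes[key] != value
    }
    if mismatched:
        raise ValueError(f"request does not match compiled shapes: {mismatched}")
```

With the `input_shapes` dictionary from the compilation snippet, `latent_shape` gives `(1, 4, 128, 128)`, and `check_request(input_shapes, 768, 1024)` raises because the height differs from the compiled 1024.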

## Inference

```python
from optimum.neuron import NeuronStableDiffusionXLPipeline
from diffusers import EulerDiscreteScheduler

pipe = NeuronStableDiffusionXLPipeline.from_pretrained("aws-neuron/SDXL-Lightning-4steps-neuronx")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")

```
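The scheduler override matters because it decides which noise levels the four steps visit. The sketch below re-implements in plain Python how, as we understand diffusers' `EulerDiscreteScheduler`, "trailing" and "leading" spacing choose timesteps out of 1,000 training steps. The function names are hypothetical, and the `steps_offset=1` default for "leading" is an assumption taken from the SDXL base scheduler config. With trailing spacing the first step starts at the noisiest timestep (999), whereas leading spacing would start at 751.

```python
def trailing_timesteps(num_inference_steps, num_train_timesteps=1000):
    """'trailing' spacing: steps are end-aligned, so sampling starts at T - 1."""
    step = num_train_timesteps / num_inference_steps
    return [round(num_train_timesteps - i * step) - 1 for i in range(num_inference_steps)]


def leading_timesteps(num_inference_steps, num_train_timesteps=1000, steps_offset=1):
    """'leading' spacing: multiples of T // n counted up from 0, plus an offset."""
    ratio = num_train_timesteps // num_inference_steps
    return [i * ratio + steps_offset for i in reversed(range(num_inference_steps))]
```

For four steps, `trailing_timesteps(4)` returns `[999, 749, 499, 249]` and `leading_timesteps(4)` returns `[751, 501, 251, 1]`.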