|
---
license: apache-2.0
library_name: diffusers
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- controlnet
- diffusers-training
---
|
|
|
# SDXL ControlNet: DWPose |
|
|
|
These are ControlNet weights trained on [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) with [DWPose](https://github.com/IDEA-Research/DWPose) conditioning.
|
|
|
### Using in 🧨 diffusers |
|
|
|
First, install the required libraries:
|
|
|
```bash
pip install -q easy-dwpose transformers accelerate
pip install -q git+https://github.com/huggingface/diffusers
```
|
|
|
#### Example 1 |
|
|
|
To generate a realistic DJ, with the following image driving the pose:
|
|
|
![Pose Image 1](./images/pose_image_1.png) |
|
|
|
Run the following code: |
|
|
|
```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image
from easy_dwpose import DWposeDetector

pose_image = load_image("./pose_image_1.png")

# Load the DWPose detector.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute the DWPose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize the ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
).to(device)

# Infer.
prompt = "DJ in a party, shallow depth of field, highly detailed, high budget, gorgeous"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    guidance_scale=5,
    image=skeleton,
    generator=torch.manual_seed(97),
).images[0]
```
|
|
|
The generated pose image is:
|
|
|
![Pose 1](./images/dwpose_1.png) |
|
|
|
The image generated by SDXL is:
|
|
|
![Generated Image 1](./images/dwpose_image_1.png)
|
|
|
#### Example 2 |
|
|
|
To generate an anime version of a woman sitting on a bench, with the following image driving the pose:
|
|
|
![Pose Image 2](./images/pose_image_2.png) |
|
|
|
Run the following code: |
|
|
|
```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image
from easy_dwpose import DWposeDetector

pose_image = load_image("./pose_image_2.png")

# Load the DWPose detector.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute the DWPose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize the ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
).to(device)

# Infer.
prompt = "Anime girl sitting on a bench, highly detailed, noon, ambiant light"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    guidance_scale=18,
    image=skeleton,
    generator=torch.manual_seed(79),
).images[0]
```
|
|
|
The generated pose image is:
|
|
|
![Pose 2](./images/dwpose_2.png) |
|
|
|
The image generated by SDXL is:
|
|
|
![Generated Image 2](./images/dwpose_image_2.png)
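
Both examples keep the full pipeline on the GPU. On cards with limited VRAM, one option (a minimal sketch using diffusers' built-in CPU offloading, which lowers peak memory at the cost of some inference speed) is to replace the `.to(device)` call with:

```python
# Optional: offload idle pipeline components to the CPU instead of keeping
# the whole pipeline on the GPU (slower, but uses much less VRAM).
# Requires accelerate, which was installed above.
pipe.enable_model_cpu_offload()
```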
|
|
|
### Training |
|
|
|
Training used the official Hugging Face 🤗 [ControlNet SDXL training script](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md); a sketch of the corresponding launch command is given at the end of this section.
|
|
|
#### Training data |
|
This checkpoint was trained for 15,000 steps on the [dimitribarbot/dw_pose_controlnet](https://huggingface.co/datasets/dimitribarbot/dw_pose_controlnet) dataset at a resolution of 1024×1024.
|
|
|
#### Compute |
|
One machine with a single NVIDIA A40 GPU, for about 48 hours.
|
|
|
#### Batch size |
|
Data parallel with a single-GPU batch size of 2 and 8 gradient accumulation steps, for an effective batch size of 16.
|
|
|
#### Hyper Parameters |
|
Constant learning rate of 8e-5 |
|
|
|
#### Mixed precision |
|
fp16 |
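
Putting the settings above together, the launch command was presumably along the lines of the following sketch (the output directory is a placeholder, and dataset-specific options such as image/caption column names are omitted; see the script's README for the full list of flags):

```bash
accelerate launch train_controlnet_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --dataset_name="dimitribarbot/dw_pose_controlnet" \
  --output_dir="controlnet-dwpose-sdxl-1.0" \
  --resolution=1024 \
  --learning_rate=8e-5 \
  --lr_scheduler="constant" \
  --train_batch_size=2 \
  --gradient_accumulation_steps=8 \
  --max_train_steps=15000 \
  --mixed_precision="fp16"
```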