metadata
license: apache-2.0
library_name: diffusers
tags:
  - stable-diffusion-xl
  - stable-diffusion-xl-diffusers
  - text-to-image
  - diffusers
  - controlnet
  - diffusers-training

SDXL ControlNet: DWPose

Here are ControlNet weights trained on stabilityai/stable-diffusion-xl-base-1.0 with DWPose conditioning.

Using in 🧨 diffusers

First, install all the libraries:

pip install -q easy-dwpose transformers accelerate
pip install -q git+https://github.com/huggingface/diffusers
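Before running the examples, you can check that the packages from the commands above are installed. The sketch below is minimal and uses only the standard library's `importlib.metadata`. `installed_versions` is a hypothetical helper. `torch` is included because the examples import it; it is installed as a dependency of `accelerate`.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names from the pip commands above, plus torch (imported by the examples).
REQUIRED = ["easy-dwpose", "transformers", "accelerate", "diffusers", "torch"]


def installed_versions(packages: Iterable[str] = REQUIRED) -> Dict[str, Optional[str]]:
    """Return each package's installed version, or None if it cannot be found."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

Any package reported as `NOT INSTALLED` should be installed before running the examples.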

Example 1

To generate a realistic image of a DJ, with the following image driving the pose:

Pose Image 1

Run the following code:

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import torch
from diffusers.utils import load_image

from easy_dwpose import DWposeDetector


pose_image = load_image("./pose_image_1.png")

# Load detector
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute DWpose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize ControlNet pipeline.
# float16 inference is generally only practical on GPU; fall back to float32 on CPU.
dtype = torch.float16 if device.startswith("cuda") else torch.float32
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=dtype,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
).to(device)

# Infer.
prompt = "DJ in a party, shallow depth of field, highly detailed, high budget, gorgeous"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    guidance_scale=5,
    image=skeleton,
    generator=torch.manual_seed(97),
    ).images[0]

# Save the generated image.
image.save("example_1.png")

The generated pose is:

Pose 1

The image generated by SDXL is:

Pose 1

Example 2

To generate an anime version of a woman sitting on a bench, with the following image driving the pose:

Pose Image 2

Run the following code:

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import torch
from diffusers.utils import load_image

from easy_dwpose import DWposeDetector


pose_image = load_image("./pose_image_2.png")

# Load detector
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dwpose = DWposeDetector(device=device)

# Compute DWpose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize ControlNet pipeline.
# float16 inference is generally only practical on GPU; fall back to float32 on CPU.
dtype = torch.float16 if device.startswith("cuda") else torch.float32
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=dtype,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
).to(device)

# Infer.
prompt = "Anime girl sitting on a bench, highly detailed, noon, ambiant light"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    guidance_scale=18,
    image=skeleton,
    generator=torch.manual_seed(79),
    ).images[0]

# Save the generated image.
image.save("example_2.png")

The generated pose is:

Pose 2

The image generated by SDXL is:

Pose 2
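Both examples load the DWPose detector and the SDXL pipeline from scratch. If you run them in the same Python session, you can build `dwpose` and `pipe` once and keep only the per-example sampling settings separate. The sketch below assumes `pipe` and `skeleton` were created as in the examples. `EXAMPLE_SETTINGS`, its keys, and `generate` are hypothetical names. The prompts, seeds, and parameters are copied verbatim from the two examples, including their original spelling, so the results should match.

```python
# Negative prompt shared by both examples (copied verbatim).
NEGATIVE_PROMPT = (
    "bad quality, blur, anime, cartoon, graphic, text, painting, "
    "crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
)

# Per-example settings from Example 1 ("dj") and Example 2 ("bench").
EXAMPLE_SETTINGS = {
    "dj": {
        "prompt": "DJ in a party, shallow depth of field, highly detailed, high budget, gorgeous",
        "num_inference_steps": 50,
        "guidance_scale": 5,
        "seed": 97,
    },
    "bench": {
        "prompt": "Anime girl sitting on a bench, highly detailed, noon, ambiant light",
        "num_inference_steps": 25,
        "guidance_scale": 18,
        "seed": 79,
    },
}


def generate(pipe, skeleton, name):
    """Run an already-built StableDiffusionXLControlNetPipeline with one example's settings."""
    import torch  # imported here so the settings table can be inspected without torch

    cfg = EXAMPLE_SETTINGS[name]
    return pipe(
        cfg["prompt"],
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=cfg["num_inference_steps"],
        guidance_scale=cfg["guidance_scale"],
        image=skeleton,
        generator=torch.manual_seed(cfg["seed"]),
    ).images[0]
```

For example, `generate(pipe, skeleton, "dj").save("example_1.png")` reproduces Example 1, where `skeleton` is the DWPose image computed from the first pose image.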

Training

The ControlNet training script provided by Hugging Face 🤗 in diffusers was used.

Training data

This checkpoint was trained for 15,000 steps on the dimitribarbot/dw_pose_controlnet dataset at a resolution of 1024.

Compute

One machine with a single A40 GPU (1xA40), for 48 hours.

Batch size

A single GPU with a batch size of 2 and 8 gradient accumulation steps, giving an effective batch size of 16.
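The figures above can be combined into a few derived quantities. This is a back-of-the-envelope sketch. It assumes the 15,000 steps count optimizer updates, not micro-batches, and that the 48 hours were spent entirely on training. Neither assumption is stated in the card.

```python
# Values reported in this card.
PER_GPU_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
NUM_GPUS = 1
TRAIN_STEPS = 15_000
WALL_CLOCK_HOURS = 48


def effective_batch_size(per_gpu: int, accum: int, gpus: int) -> int:
    """Number of samples contributing to each optimizer update."""
    return per_gpu * accum * gpus


ebs = effective_batch_size(PER_GPU_BATCH_SIZE, GRAD_ACCUM_STEPS, NUM_GPUS)
samples_seen = ebs * TRAIN_STEPS  # assumes TRAIN_STEPS counts optimizer updates
sec_per_step = WALL_CLOCK_HOURS * 3600 / TRAIN_STEPS  # assumes 48 h of pure training

print(ebs)           # 16
print(samples_seen)  # 240000
print(sec_per_step)  # 11.52
```

Gradient accumulation lets a single A40 reach an effective batch size of 16 while holding only 2 samples in memory at a time.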

Hyperparameters

Constant learning rate of 8e-5

Mixed precision

fp16
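The card does not include the exact training command. The sketch below maps the values above onto command-line flags, assuming the script was diffusers' `examples/controlnet/train_controlnet_sdxl.py` launched with `accelerate`. The output directory is hypothetical. Depending on the dataset's schema, the column-name flags `--image_column`, `--conditioning_image_column`, and `--caption_column` may also be needed; they are omitted here because the card does not state them.

```python
from typing import Dict, List

# Values from this card; output_dir is a hypothetical placeholder.
CARD_HPARAMS: Dict[str, object] = {
    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
    "dataset_name": "dimitribarbot/dw_pose_controlnet",
    "resolution": 1024,
    "max_train_steps": 15000,
    "train_batch_size": 2,
    "gradient_accumulation_steps": 8,
    "learning_rate": 8e-5,
    "lr_scheduler": "constant",
    "mixed_precision": "fp16",
    "output_dir": "controlnet-dwpose-sdxl",  # hypothetical
}


def build_launch_command(hparams: Dict[str, object]) -> List[str]:
    """Turn a hyperparameter dict into an `accelerate launch` argument list."""
    cmd = ["accelerate", "launch", "train_controlnet_sdxl.py"]
    for key, value in hparams.items():
        cmd.append(f"--{key}={value}")
    return cmd


if __name__ == "__main__":
    print(" ".join(build_launch_command(CARD_HPARAMS)))
```

Keeping the values in one dictionary makes it easy to compare a reproduction run against the numbers reported in this card.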