---
license: apache-2.0
library_name: diffusers
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- controlnet
- diffusers-training
---

# SDXL ControlNet: DWPose

These are ControlNet weights trained on stabilityai/stable-diffusion-xl-base-1.0 with [DWPose](https://github.com/IDEA-Research/DWPose) conditioning.

### Using in 🧨 diffusers

First, install the required libraries:

```bash
pip install -q easy-dwpose transformers accelerate
pip install -q git+https://github.com/huggingface/diffusers
```

#### Example 1

To generate a realistic DJ with the following image driving the pose:

![Pose Image 1](./images/pose_image_1.png)

Run the following code:

```python
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import torch
from diffusers.utils import load_image
from easy_dwpose import DWposeDetector

pose_image = load_image("./images/pose_image_1.png")

# Load detector.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# fp16 is meant for GPU inference; fall back to fp32 on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
dwpose = DWposeDetector(device=device)

# Compute DWpose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=dtype,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
).to(device)

# Infer.
prompt = "DJ in a party, shallow depth of field, highly detailed, high budget, gorgeous"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    guidance_scale=5,
    image=skeleton,
    generator=torch.manual_seed(97),
).images[0]
```

The generated pose is:

![Pose 1](./images/dwpose_1.png)

The image generated by SDXL is:

![Image 1](./images/dwpose_image_1.png)

#### Example 2

To generate an anime version of a woman sitting on a bench, with the following image driving the pose:

![Pose Image 2](./images/pose_image_2.png)

Run the following code:

```python
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import torch
from diffusers.utils import load_image
from easy_dwpose import DWposeDetector

pose_image = load_image("./images/pose_image_2.png")

# Load detector.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# fp16 is meant for GPU inference; fall back to fp32 on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
dwpose = DWposeDetector(device=device)

# Compute DWpose conditioning image.
skeleton = dwpose(
    pose_image,
    detect_resolution=pose_image.width,
    output_type="pil",
    include_hands=True,
    include_face=True,
)

# Initialize ControlNet pipeline.
controlnet = ControlNetModel.from_pretrained(
    "dimitribarbot/controlnet-dwpose-sdxl-1.0",
    torch_dtype=dtype,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
).to(device)

# Infer.
prompt = "Anime girl sitting on a bench, highly detailed, noon, ambiant light"
negative_prompt = "bad quality, blur, anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    guidance_scale=18,
    image=skeleton,
    generator=torch.manual_seed(79),
).images[0]
```

The generated pose is:

![Pose 2](./images/dwpose_2.png)

The image generated by SDXL is:

![Image 2](./images/dwpose_image_2.png)

### Training

The SDXL ControlNet [training script](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md) from 🤗 diffusers was used.

#### Training data

This checkpoint was trained for 15,000 steps on the [dimitribarbot/dw_pose_controlnet](https://huggingface.co/datasets/dimitribarbot/dw_pose_controlnet) dataset at a resolution of 1024.

#### Compute

One machine with a single A40 GPU (1xA40), for 48 hours.

#### Batch size

Data parallel with a per-GPU batch size of 2 and 8 gradient accumulation steps.

#### Hyperparameters

Constant learning rate of 8e-5.

#### Mixed precision

fp16
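As a rough sanity check, the Compute, Batch size and Training data sections together imply an effective batch size and a total number of training samples processed. The minimal Python sketch below does that arithmetic. It assumes the reported 15,000 steps are optimizer steps (each accumulating 8 micro-batches), which is how the diffusers training script counts `max_train_steps`; the variable names are illustrative only.

```python
# Settings as reported in this card (single A40 GPU).
num_gpus = 1
per_gpu_batch_size = 2
gradient_accumulation_steps = 8
# Assumption: these are optimizer steps, not micro-batch steps.
max_train_steps = 15_000

# Each optimizer step accumulates gradients over this many samples.
effective_batch_size = num_gpus * per_gpu_batch_size * gradient_accumulation_steps

# Total training samples processed (repeats across epochs included).
samples_seen = effective_batch_size * max_train_steps

print(f"effective batch size: {effective_batch_size}")  # 16
print(f"samples seen: {samples_seen:,}")  # 240,000
```

If the 15,000 steps were instead counted per micro-batch, the number of samples seen would be 8 times smaller (30,000).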