Usage of InstantID-XS:

1. Download the models.

```bash
# InstantID-XS
huggingface-cli download --resume-download RED-AIGC/InstantID-XS --local-dir ./checkpoints/RED-AIGC/InstantID-XS
# vae: madebyollin/sdxl-vae-fp16-fix
huggingface-cli download --resume-download madebyollin/sdxl-vae-fp16-fix --local-dir ./checkpoints/madebyollin/sdxl-vae-fp16-fix
# base model: RealVisXL V4.0
huggingface-cli download --resume-download frankjoshua/realvisxlV40_v40Bakedvae --local-dir ./checkpoints/frankjoshua/realvisxlV40_v40Bakedvae
```
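
If you prefer to fetch the weights from Python, huggingface_hub's `snapshot_download` does the same job. The snippet below is a minimal sketch; it mirrors the `./checkpoints/<org>/<repo>` layout that the code in the next step expects.

```python
from huggingface_hub import snapshot_download

# Mirror the <org>/<repo> layout assumed by the pipeline code below.
for repo_id in [
    "RED-AIGC/InstantID-XS",
    "madebyollin/sdxl-vae-fp16-fix",
    "frankjoshua/realvisxlV40_v40Bakedvae",
]:
    snapshot_download(repo_id=repo_id, local_dir=f"./checkpoints/{repo_id}")
```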

2. Get the pipeline.

Note: in ControlNetXS, the control branch receives the same encoder_hidden_states input as the UNet by default, namely the prompt embeddings. We decouple the two inputs: the UNet still takes the prompt embeddings as encoder_hidden_states, while the control branch takes the face embeddings instead.
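
Schematically (hypothetical helper and argument names, not the pipeline's actual forward signatures), the routing difference looks like this:

```python
def forward_branches(base_unet, ctrl_branch, latents, t, prompt_embeds, face_embeds):
    # Stock ControlNet-XS would pass prompt_embeds to both branches;
    # InstantID-XS routes the face embeddings into the control branch instead.
    base_out = base_unet(latents, t, encoder_hidden_states=prompt_embeds)
    ctrl_out = ctrl_branch(latents, t, encoder_hidden_states=face_embeds)
    return base_out, ctrl_out
```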

```python
import os

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler

from controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
# The pipeline class ships with this repository; the module path below is a
# guess, adjust it to wherever you placed the pipeline file.
from pipeline_stable_diffusion_xl_instantid_xs import StableDiffusionXLInstantIDXSPipeline

device = 'cuda'
weight_dtype = torch.float16


base_model = './checkpoints/frankjoshua/realvisxlV40_v40Bakedvae'
vae_path = './checkpoints/madebyollin/sdxl-vae-fp16-fix'
ckpt = './checkpoints/RED-AIGC/InstantID-XS'

image_proj_path = os.path.join(ckpt, "image_proj.bin")
cnxs_path = os.path.join(ckpt, "controlnetxs.bin")
cross_attn_path = os.path.join(ckpt, "cross_attn.bin")

# Get ControlNetXS:
unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet").to(device, dtype=weight_dtype)
controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio=0.125, learn_time_embedding=True)
state_dict = torch.load(cnxs_path, map_location="cpu", weights_only=True)
# Remap checkpoint keys to the adapter's naming: drop the 'ctrl_' prefix
# (keeping the cross-branch 'ctrl_to_base' keys intact) and rename
# 'up_blocks' to 'up_connections'.
ctrl_state_dict = {}
for key, value in state_dict.items():
    if 'ctrl_' in key and 'ctrl_to_base' not in key:
        key = key.replace('ctrl_', '')
    if 'up_blocks' in key:
        key = key.replace('up_blocks', 'up_connections')
    ctrl_state_dict[key] = value
controlnet.load_state_dict(ctrl_state_dict, strict=True)
controlnet.to(device, dtype=weight_dtype)
ControlNetXS = UNetControlNetXSModel.from_unet(unet, controlnet).to(device, dtype=weight_dtype)
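
# Optional sanity check: size_ratio=0.125 gives the control branch an eighth
# of the base UNet's channel width, so it should have far fewer parameters.
print(f"ctrl params: {sum(p.numel() for p in controlnet.parameters()) / 1e6:.0f}M, "
      f"base params: {sum(p.numel() for p in unet.parameters()) / 1e6:.0f}M")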


# Get pipeline
vae = AutoencoderKL.from_pretrained(vae_path)

pipe = StableDiffusionXLInstantIDXSPipeline.from_pretrained(
    base_model,
    vae=vae,
    unet=ControlNetXS,
    controlnet=None,
    torch_dtype=weight_dtype,
)

pipe.cuda(device=device, dtype=weight_dtype, use_xformers=True)
pipe.load_ip_adapter(image_proj_path, cross_attn_path)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.unet.config.ctrl_learn_time_embedding = True
pipe = pipe.to(device)
```

3. Run inference.

```python
import cv2
import numpy as np
import torch
from insightface.app import FaceAnalysis
from PIL import Image

# insightface resolves the detector under <root>/models/<name>, so the
# antelopev2 files must live in ./models/antelopev2.
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))


img_path = './image.jpg'
image = cv2.imread(img_path)
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
image = resize_img(image)  # helper from this repo; a sketch is given after this block

face_infos = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
# Keep the largest detected face.
face_info = sorted(face_infos, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
face_emb = torch.from_numpy(face_info.normed_embedding)
face_kps = draw_kps_pil(image, face_info['kps'])  # helper from this repo; see the sketch below

prompt = 'a woman, (looking at the viewer), portrait, daily wear, 8K texture, realistic, symmetrical hyperdetailed texture, masterpiece, enhanced details, (eye highlight:2), perfect composition, natural lighting, best quality, authentic, natural posture'
n_prompt = '(worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands, normal quality, long neck, hunchback, narrow shoulder, wall, (blurry), vague, indistinct, (shiny face:2), (buffing:2), (face highlight:2), pale skin'

seed = 0
image = pipe(
    prompt=prompt,
    negative_prompt=n_prompt,
    image=face_kps,
    face_emb=face_emb,
    num_images_per_prompt=1,
    num_inference_steps=20,
    generator=torch.Generator(device=device).manual_seed(seed),
    ip_adapter_scale=0.8,
    guidance_scale=4.0,
    controlnet_conditioning_scale=0.8,
).images[0]
```
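
resize_img and draw_kps_pil are helpers from this repository. If you don't have them at hand, the sketch below approximates them after the corresponding InstantID utilities (resize to multiples of 64 within a size budget; render the five keypoints as a colored stick figure on a black canvas). Treat it as an assumption, not the exact implementation.

```python
import math

import cv2
import numpy as np
from PIL import Image


def resize_img(image, max_side=1280, min_side=1024, base=64):
    # Scale so the short side reaches min_side, cap the long side at max_side,
    # then snap both dimensions to multiples of `base` (the SDXL pixel grid).
    w, h = image.size
    ratio = min_side / min(w, h)
    w, h = round(ratio * w), round(ratio * h)
    ratio = min(1.0, max_side / max(w, h))
    w, h = round(ratio * w) // base * base, round(ratio * h) // base * base
    return image.resize((w, h), Image.BILINEAR)


def draw_kps_pil(image, kps,
                 colors=((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255))):
    # Draw the five facial keypoints as filled dots, each connected to the
    # nose, on a black canvas of the same size as the input image.
    w, h = image.size
    canvas = np.zeros((h, w, 3), dtype=np.uint8)
    kps = np.array(kps)
    for a, b in ((0, 2), (1, 2), (3, 2), (4, 2)):  # eyes / mouth corners -> nose
        x, y = kps[[a, b], 0], kps[[a, b], 1]
        length = float(np.hypot(x[0] - x[1], y[0] - y[1]))
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        poly = cv2.ellipse2Poly((int(x.mean()), int(y.mean())), (int(length / 2), 4),
                                int(angle), 0, 360, 1)
        cv2.fillConvexPoly(canvas, poly, colors[a])
    canvas = (canvas * 0.6).astype(np.uint8)
    for i, (x, y) in enumerate(kps):
        cv2.circle(canvas, (int(x), int(y)), 10, colors[i], -1)
    return Image.fromarray(canvas)
```

The generated image can then be saved with image.save('result.jpg').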