# NOTE(review): the original import order is preserved (`spaces` first);
# confirm before regrouping in case the hosting runtime depends on it.
import spaces
from pathlib import Path
import torch
import gradio as gr
from PIL import Image, ExifTags
import numpy as np
from torch import Tensor
from einops import rearrange
import uuid
import os

from src.flux.modules.layers import (
    SingleStreamBlockProcessor,
    DoubleStreamBlockLoraProcessor,
    IPDoubleStreamBlockProcessor,
    ImageProjModel,
)
from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
from src.flux.util import (
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    load_controlnet,
    load_flow_model_quintized,
    Annotator,
    get_lora_rank,
    load_checkpoint
)

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor


class XFluxPipeline:
    """Text-to-image pipeline with optional LoRA, ControlNet and IP-Adapter.

    The constructor loads the text encoders, the autoencoder and the flow
    model. Adapters are attached afterwards through ``set_lora``,
    ``set_controlnet`` and ``set_ip``. Calling the instance (or ``forward``)
    returns a ``PIL.Image.Image``.

    NOTE(review): several statements in this class were truncated in the
    reviewed source and have been reconstructed; each one is marked with a
    ``NOTE(review)`` comment and should be confirmed against the project
    helpers in ``src.flux``.
    """

    def __init__(self, model_type, device, offload: bool = False):
        """Load all base models.

        Args:
            model_type: Model identifier passed to the loaders; when it
                contains ``"fp8"`` the quantized flow-model loader is used.
            device: Anything accepted by ``torch.device``.
            offload: When true, the autoencoder and flow model are loaded on
                the CPU and moved to ``device`` only while they are needed.
        """
        self.device = torch.device(device)
        self.offload = offload
        self.model_type = model_type

        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        # NOTE(review): assignment target was missing in the source;
        # `self.ae` is reconstructed from the `load_ae` helper name.
        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
        else:
            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)

        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }
        self.controlnet_loaded = False
        self.ip_loaded = False

    def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
        """Load IP-Adapter weights and install its attention processors.

        Args:
            local_path: Local checkpoint path, forwarded to ``load_checkpoint``.
            repo_id: Hub repository id, forwarded to ``load_checkpoint``.
            name: Checkpoint file name, forwarded to ``load_checkpoint``.
        """
        checkpoint = load_checkpoint(local_path, repo_id, name)

        # Weights of the image-projection model, with their prefix stripped.
        proj_prefix = "ip_adapter_proj_model."
        proj = {
            key[len(proj_prefix):]: value
            for key, value in checkpoint.items()
            if key.startswith("ip_adapter_proj_model")
        }

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()

        # setup image embedding projection model
        self.improj = ImageProjModel(4096, 768, 4)
        self.improj.load_state_dict(proj)
        # NOTE(review): the call head was truncated in the source
        # (`self.improj =, dtype=torch.bfloat16)`); reconstructed as `.to(...)`.
        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)

        ip_attn_procs = {}
        for name, current_proc in self.model.attn_processors.items():
            ip_state_dict = {}
            for k in checkpoint.keys():
                if name in k:
                    ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k]
            if ip_state_dict:
                ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
                ip_attn_procs[name].load_state_dict(ip_state_dict)
                ip_attn_procs[name].to(self.device, dtype=torch.bfloat16)
            else:
                # No IP weights for this processor: keep the one already set.
                ip_attn_procs[name] = current_proc

        self.model.set_attn_processor(ip_attn_procs)
        self.ip_loaded = True

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        """Load a LoRA checkpoint and apply it scaled by ``lora_weight``."""
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        """Load a named LoRA from ``self.hf_lora_collection``.

        Raises:
            KeyError: If ``lora_type`` is not in ``self.lora_types_to_names``.
        """
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        """Replace the model's attention processors with LoRA-enabled ones.

        Processors whose name starts with ``"single_blocks"`` get a plain
        ``SingleStreamBlockProcessor``; all others get a
        ``DoubleStreamBlockLoraProcessor`` loaded from the checkpoint entries
        whose key contains the processor name, each multiplied by
        ``lora_weight``.
        """
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}

        for name, _ in self.model.attn_processors.items():
            if name.startswith("single_blocks"):
                lora_attn_procs[name] = SingleStreamBlockProcessor()
                continue
            lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
            lora_state_dict = {}
            for k in checkpoint.keys():
                if name in k:
                    # Drop "<name>." from the front of the key.
                    lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
            lora_attn_procs[name].load_state_dict(lora_state_dict)
            lora_attn_procs[name].to(self.device)

        self.model.set_attn_processor(lora_attn_procs)

    def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
        """Load a ControlNet and the annotator for ``control_type``."""
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)

        self.annotator = Annotator(control_type, self.device)
        self.controlnet_loaded = True
        self.control_type = control_type

    def get_image_proj(
        self,
        image_prompt: Tensor,
    ):
        """Encode an image prompt and project it with ``self.improj``.

        Requires ``set_ip`` to have been called, since it uses the image
        processor, image encoder and projection model created there.
        """
        # encode image-prompt embeds
        image_prompt = self.clip_image_processor(
            images=image_prompt,
            return_tensors="pt"
        ).pixel_values
        # NOTE(review): both right-hand sides below were truncated in the
        # source; reconstructed as a device move followed by reading the
        # encoder's `image_embeds` output — confirm.
        image_prompt = image_prompt.to(self.image_encoder.device)
        image_prompt_embeds = self.image_encoder(
            image_prompt
        ).image_embeds.to(
            device=self.device, dtype=torch.bfloat16,
        )
        # encode image
        image_proj = self.improj(image_prompt_embeds)
        return image_proj

    def __call__(self,
                 prompt: str,
                 image_prompt: Image.Image = None,
                 controlnet_image: Image.Image = None,
                 width: int = 512,
                 height: int = 512,
                 guidance: float = 4,
                 num_steps: int = 50,
                 seed: int = 123456789,
                 true_gs: float = 3,
                 control_weight: float = 0.9,
                 ip_scale: float = 1.0,
                 neg_ip_scale: float = 1.0,
                 neg_prompt: str = '',
                 neg_image_prompt: Image.Image = None,
                 timestep_to_start_cfg: int = 0,
                 ):
        """Prepare optional conditioning inputs and run ``forward``.

        ``width`` and ``height`` are rounded down to multiples of 16. If only
        one of the two image prompts is given, the other is replaced by an
        all-zero image.

        Raises:
            AssertionError: If an image prompt is supplied before ``set_ip``
                has been called.
        """
        width = 16 * (width // 16)
        height = 16 * (height // 16)
        image_proj = None
        neg_image_proj = None
        if not (image_prompt is None and neg_image_prompt is None):
            assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input'

            # Image arrays are (height, width, channels).
            if image_prompt is None:
                image_prompt = np.zeros((height, width, 3), dtype=np.uint8)
            if neg_image_prompt is None:
                neg_image_prompt = np.zeros((height, width, 3), dtype=np.uint8)

            image_proj = self.get_image_proj(image_prompt)
            neg_image_proj = self.get_image_proj(neg_image_prompt)

        if self.controlnet_loaded:
            controlnet_image = self.annotator(controlnet_image, width, height)
            # Map pixel values from [0, 255] to [-1, 1].
            controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
            controlnet_image = controlnet_image.permute(
                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            controlnet_image,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            control_weight=control_weight,
            neg_prompt=neg_prompt,
            image_proj=image_proj,
            neg_image_proj=neg_image_proj,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
        )

    @torch.inference_mode()
    @spaces.GPU()
    def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
                        num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                        neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                        lora_weight, local_path, lora_local_path, ip_local_path):
        """Gradio callback: load requested adapters, generate, save a JPEG.

        Returns:
            A ``(image, filename)`` tuple where ``filename`` is the saved
            JPEG under ``output/gradio/``.

        NOTE(review): statement nesting was lost in the reviewed source. It
        is reconstructed so that a ControlNet is (re)loaded only when a
        control image is supplied, and the IP-Adapter is loaded when either
        image prompt is supplied — confirm against the intended behaviour.
        """
        if controlnet_image is not None:
            controlnet_image = Image.fromarray(controlnet_image)
            # Load on first use, or reload when the control type changed.
            if not self.controlnet_loaded or control_type != self.control_type:
                if local_path is not None:
                    self.set_controlnet(control_type, local_path=local_path)
                else:
                    self.set_controlnet(control_type, local_path=None,
                                        repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3",
                                        name=f"flux-{control_type}-controlnet-v3.safetensors")

        if lora_local_path is not None:
            self.set_lora(local_path=lora_local_path, lora_weight=lora_weight)

        if image_prompt is not None:
            image_prompt = Image.fromarray(image_prompt)
        if neg_image_prompt is not None:
            neg_image_prompt = Image.fromarray(neg_image_prompt)
        if (image_prompt is not None or neg_image_prompt is not None) and not self.ip_loaded:
            if ip_local_path is not None:
                self.set_ip(local_path=ip_local_path)
            else:
                self.set_ip(repo_id="xlabs-ai/flux-ip-adapter",
                            name="flux-ip-adapter.safetensors")

        seed = int(seed)
        if seed == -1:
            # -1 requests a fresh random seed.
            seed = torch.Generator(device="cpu").seed()

        img = self(prompt, image_prompt, controlnet_image, width, height, guidance,
                   num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt,
                   neg_image_prompt, timestep_to_start_cfg)

        filename = f"output/gradio/{uuid.uuid4()}.jpg"
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "XLabs AI"
        exif_data[ExifTags.Base.Model] = self.model_type
        # NOTE(review): the call head was truncated in the source; the
        # remaining keyword arguments match `PIL.Image.Image.save`.
        img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
        return img, filename

    def forward(
        self,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        controlnet_image = None,
        timestep_to_start_cfg = 0,
        true_gs = 3.5,
        control_weight = 0.9,
        neg_prompt="",
        image_proj=None,
        neg_image_proj=None,
        ip_scale=1.0,
        neg_ip_scale=1.0,
    ):
        """Run the denoising loop and decode the result to a PIL image.

        Uses ``denoise_controlnet`` when a ControlNet is loaded and
        ``denoise`` otherwise. With ``self.offload`` set, each sub-model is
        moved to the CPU once its stage is done.
        """
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        torch.manual_seed(seed)
        with torch.no_grad():
            if self.offload:
                # NOTE(review): right-hand side was truncated in the source;
                # reconstructed as moving both text encoders to the device.
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
                # NOTE(review): right-hand side was truncated in the source.
                self.model = self.model.to(self.device)
            if self.controlnet_loaded:
                x = denoise_controlnet(
                    self.model,
                    **inp_cond,
                    controlnet=self.controlnet,
                    timesteps=timesteps,
                    guidance=guidance,
                    controlnet_cond=controlnet_image,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    controlnet_gs=control_weight,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                )
            else:
                x = denoise(
                    self.model,
                    **inp_cond,
                    timesteps=timesteps,
                    guidance=guidance,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                )

            if self.offload:
                self.offload_model_to_cpu(self.model)
                # NOTE(review): the autoencoder statements below were lost in
                # the source; `decoder` / `decode` are assumed attribute
                # names of the object returned by `load_ae` — confirm.
                self.ae.decoder.to(x.device)
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            self.offload_model_to_cpu(self.ae.decoder)

        # Map [-1, 1] to [0, 255] and convert the last batch item to HWC.
        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        """Move ``models`` to the CPU; does nothing unless ``self.offload``."""
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()

def create_demo(
        model_type: str,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        offload: bool = False,
        ckpt_dir: str = "",
):
    """Build the Gradio UI around a freshly constructed ``XFluxPipeline``.

    Args:
        model_type: Model identifier passed to ``XFluxPipeline``.
        device: Device string passed to ``XFluxPipeline``.
        offload: Passed to ``XFluxPipeline``.
        ckpt_dir: Directory scanned (non-recursively) for ``*.safetensors``
            files; the matches populate the three checkpoint dropdowns.

    Returns:
        The ``gr.Blocks`` demo, not yet launched.
    """
    xflux_pipeline = XFluxPipeline(model_type, device, offload)
    checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                with gr.Accordion("Generation Options", open=False):
                    with gr.Row():
                        width = gr.Slider(512, 2048, 1024, step=16, label="Width")
                        height = gr.Slider(512, 2048, 1024, step=16, label="Height")
                    neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
                    with gr.Row():
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                        true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
                with gr.Accordion("ControlNet Options", open=False):
                    control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
                    control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
                    local_path = gr.Dropdown(checkpoints, label="Controlnet Checkpoint",
                                             info="Local Path to Controlnet weights (if no, it will be downloaded from HF)"
                                             )
                    controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)
                with gr.Accordion("LoRA Options", open=False):
                    lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
                    lora_local_path = gr.Dropdown(
                        checkpoints, label="LoRA Checkpoint",
                        info="Local Path to Lora weights"
                    )
                with gr.Accordion("IP Adapter Options", open=False):
                    image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
                    ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
                    neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
                    neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
                    ip_local_path = gr.Dropdown(
                        checkpoints, label="IP Adapter Checkpoint",
                        info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)"
                    )
                generate_btn = gr.Button("Generate")
            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                download_btn = gr.File(label="Download full-resolution")

        # Order must match the positional parameters of
        # XFluxPipeline.gradio_generate.
        inputs = [
            prompt, image_prompt, controlnet_image, width, height, guidance,
            num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
            neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
            lora_weight, local_path, lora_local_path, ip_local_path,
        ]
        # Wire the button to the pipeline callback (the call head was missing
        # in the reviewed source, leaving the keyword arguments orphaned).
        generate_btn.click(
            fn=xflux_pipeline.gradio_generate,
            inputs=inputs,
            outputs=[output_image, download_btn],
        )

    return demo


if __name__ == "__main__":
    # Command-line entry point: parse options, build the demo, launch it.
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use")
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
    args = parser.parse_args()

    # `--name` supplies create_demo's first parameter (model_type); the
    # argument was missing from the call in the reviewed source.
    demo = create_demo(args.name, args.device, args.offload, args.ckpt_dir)
    demo.launch(share=args.share)