import os
import random
import warnings

import gradio as gr
import numpy as np
import spaces
import torch
from einops import rearrange
from huggingface_hub import login
from peft import LoraConfig
from PIL import Image

from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline
from pipelines.flux_pipeline.transformer import FluxTransformer2DModelWithMasking

# Authenticate with the Hugging Face Hub before downloading FLUX.1-dev below.
# Only call login() when a token is actually provided: login(token=None) falls
# back to an interactive prompt, which breaks headless/local runs that rely on
# credentials cached by `huggingface-cli login`.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

torch_dtype = torch.bfloat16

# FLUX.1-dev with its transformer swapped for the SynCD masking-aware subclass.
transformer = FluxTransformer2DModelWithMasking.from_pretrained(
    'black-forest-labs/FLUX.1-dev', subfolder='transformer', torch_dtype=torch_dtype
)
pipeline = SynCDFluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev', transformer=transformer, torch_dtype=torch_dtype
)

# Tag every attention processor with its module path.
# NOTE(review): presumably read by the custom processors in pipelines/ — confirm.
for name, attn_proc in pipeline.transformer.attn_processors.items():
    attn_proc.name = name

# LoRA layout; it has to match the fine-tuned checkpoint so its keys line up.
target_modules = [
    "to_k",
    "to_q",
    "to_v",
    "add_k_proj",
    "add_q_proj",
    "add_v_proj",
    "to_out.0",
    "to_add_out",
    "ff.net.0.proj",
    "ff.net.2",
    "ff_context.net.0.proj",
    "ff_context.net.2",
    "proj_mlp",
    "proj_out",
]
lora_rank = 32
lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
pipeline.transformer.add_adapter(lora_config)

# Load the fine-tuned weights. Only keys containing the PEFT wrapper prefix are
# kept, with the prefix stripped so they match the transformer's own keys.
# NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
_CKPT_PATH = 'models/pytorch_model.bin'
_CKPT_PREFIX = 'transformer.base_model.model.'
finetuned_state_dict = torch.load(_CKPT_PATH, map_location='cpu')
transformer_dict = {
    key.replace(_CKPT_PREFIX, ''): value
    for key, value in finetuned_state_dict.items()
    if _CKPT_PREFIX in key
}
if not transformer_dict:
    # Without this check load_state_dict(strict=False) would load nothing and
    # the app would silently run without the fine-tuned weights.
    raise RuntimeError(
        f"No '{_CKPT_PREFIX}*' weights found in {_CKPT_PATH}; "
        "nothing would be loaded into the transformer."
    )
# strict=False: the checkpoint need not cover every transformer key, but keys
# that match nothing are reported so a layout mismatch is visible.
_load_result = pipeline.transformer.load_state_dict(transformer_dict, strict=False)
if _load_result.unexpected_keys:
    warnings.warn(
        f"{len(_load_result.unexpected_keys)} checkpoint keys did not match the "
        f"transformer and were ignored, e.g. {_load_result.unexpected_keys[:5]}"
    )
# The weights have been copied into the model; drop the CPU copies.
del finetuned_state_dict, transformer_dict, _load_result

pipeline.to('cuda')
pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()

# Sequential CPU offload is applied to the shared pipeline at most once per
# process; re-applying it would tear down and re-attach all offload hooks.
_cpu_offload_enabled = False


@torch.no_grad()
def decode(latents, pipeline):
    """Decode VAE latents with `pipeline.vae`, undoing the scaling factor.

    Inverse of `encode_target_images`. Not called anywhere else in this file.
    """
    latents = latents / pipeline.vae.config.scaling_factor
    image = pipeline.vae.decode(latents, return_dict=False)[0]
    return image


@torch.no_grad()
def encode_target_images(images, pipeline, generator=None):
    """Encode an image batch into scaled VAE latents.

    Args:
        images: batch passed to `pipeline.vae.encode`; generate_image passes
            (N, 3, 512, 512) tensors in [-1, 1].
        pipeline: pipeline whose VAE is used.
        generator: optional torch.Generator used to sample the latent
            distribution; pass a seeded one for reproducible latents. When None
            the global RNG is used.

    Returns:
        The sampled latents multiplied by `pipeline.vae.config.scaling_factor`.
    """
    # NOTE(review): only scaling_factor is applied (no VAE shift_factor);
    # presumably this matches how references were encoded in training.
    latents = pipeline.vae.encode(images).latent_dist.sample(generator=generator)
    latents = latents * pipeline.vae.config.scaling_factor
    return latents


def _load_reference_image(path):
    """Load one image file as a (1, 3, 512, 512) tensor in [-1, 1] with dtype torch_dtype."""
    # Context manager closes the underlying file handle.
    with Image.open(path) as im:
        pixels = np.array(im.convert('RGB').resize((512, 512)))
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    # Normalize in float32 and cast once, instead of doing the arithmetic in bf16.
    return (tensor.float() / 127.5 - 1.).to(torch_dtype)


@spaces.GPU(duration=120)
def generate_image(text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload=False):
    """Generate a new image of the reference object following the prompt.

    Args:
        text: text prompt passed to the pipeline.
        img1, img2, img3: reference image file paths (gr.Image type="filepath");
            any of them may be None.
        guidance_scale: guidance scale forwarded to the pipeline.
        inference_steps: number of inference steps (cast to int).
        seed: seed (cast to int) for both the VAE latent sampling and the
            pipeline's generator.
        rigid_object: NOTE(review): accepted from the UI but not used in this
            function — confirm whether it should be forwarded to the pipeline.
        enable_cpu_offload: if True, switch the shared pipeline to sequential
            CPU offload. Once enabled it stays enabled for later calls in the
            same process.

    Returns:
        PIL.Image built from the first (generated) view of the pipeline output.

    Raises:
        gr.Error: if no reference image was provided.
    """
    global _cpu_offload_enabled
    if enable_cpu_offload and not _cpu_offload_enabled:
        pipeline.enable_sequential_cpu_offload()
        _cpu_offload_enabled = True

    input_images = [img for img in (img1, img2, img3) if img is not None]
    if not input_images:
        # Raise so Gradio shows the message; a returned string would be treated
        # by the gr.Image output as an image file path.
        raise gr.Error("Please upload at least one image")

    # Sliders may deliver floats; manual_seed and the step count need ints.
    seed = int(seed)
    inference_steps = int(inference_steps)

    # Slot 0 is the view to generate; slots 1..N hold the reference images.
    numref = len(input_images) + 1

    # _execution_device is the real compute device even when sequential CPU
    # offload has moved the module parameters off the GPU.
    device = pipeline._execution_device
    images = torch.cat([_load_reference_image(path) for path in input_images]).to(device)
    latents = encode_target_images(
        images, pipeline, generator=torch.Generator(device="cpu").manual_seed(seed)
    )
    latents = torch.cat([torch.zeros_like(latents[:1]), latents], dim=0)
    # Mask is 1 on the slot to generate and 0 on the reference slots.
    masklatent = torch.zeros_like(latents)
    masklatent[:1] = 1.

    # Tile all views side by side along the width: (N, C, H, W) -> (1, C, H, N*W).
    latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
    masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
    B, C, H, W = latents.shape
    latents = pipeline._pack_latents(latents, B, C, H, W)
    masklatent = pipeline._pack_latents(masklatent, B, C, H, W)

    output = pipeline(
        text,
        latents_ref=latents,
        latents_mask=masklatent,
        guidance_scale=guidance_scale,
        num_inference_steps=inference_steps,
        height=512,
        width=numref * 512,
        generator=torch.Generator(device="cpu").manual_seed(seed),
        joint_attention_kwargs={'shared_attn': True, 'num': numref},
        return_dict=False,
    )[0][0]
    # NOTE(review): assumes the pipeline returns a (b, c, h, numref*w) tensor in
    # [-1, 1] — confirm in SynCDFluxPipeline. Undo the tiling and keep slot 0.
    output = rearrange(output, "b c h (n w) -> (b n) c h w", n=numref)[::numref]
    img = Image.fromarray(
        ((torch.clip(output[0].float(), -1., 1.).permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
    )
    return img


def get_example():
    """Return the gr.Examples rows: prompt, three reference paths, guidance scale, seed, rigid flag."""
    case = [
        [
            "An action figure on top of a mountain. Sunset in the background. Realistic shot.",
            "./imgs/test_cases/action_figure/0.jpg",
            "./imgs/test_cases/action_figure/1.jpg",
            "./imgs/test_cases/action_figure/2.jpg",
            3.5,
            42,
            True,
        ],
        [
            "A penguin plushie wearing pink sunglasses is lounging on a beach. Realistic shot.",
            "./imgs/test_cases/penguin/0.jpg",
            "./imgs/test_cases/penguin/1.jpg",
            "./imgs/test_cases/penguin/2.jpg",
            3.5,
            42,
            True,
        ],
        [
            "A toy on a beach. Waves in the background. Realistic shot.",
            "./imgs/test_cases/rc_car/02.jpg",
            "./imgs/test_cases/rc_car/03.jpg",
            "./imgs/test_cases/rc_car/04.jpg",
            3.5,
            42,
            True,
        ],
    ]
    return case


def run_for_examples(text, img1, img2, img3, guidance_scale, seed, rigid_object, enable_cpu_offload=False):
    """Example-row entry point: run generate_image with a fixed 30 inference steps."""
    inference_steps = 30

    return generate_image(
        text,
        img1,
        img2,
        img3,
        guidance_scale,
        inference_steps,
        seed,
        rigid_object,
        enable_cpu_offload,
    )


description = """
Synthetic Customization Dataset (SynCD) consists of multiple images of the same object in different contexts.
We achieve it by promoting similar object identity using either explicit 3D object assets or, more implicitly, using masked shared attention across different views while generating images.
Given this training data, we train a new encoder-based model for the task, which can successfully generate new compositions of a reference object using text prompts.
You can download our dataset [here](https://huggingface.co/datasets/nupurkmr9/syncd).

Our model supports multiple input images of the same object as references. You can upload up to 3 images, with better results on 3 images vs 1 image.

**HF Spaces often encounter errors due to quota limitations, so recommend to run it locally.**
"""

article = """
---
**Citation**

If you find this repository useful, please consider giving a star ⭐ and a citation
```
@article{kumari2025syncd,
  title={Generating Multi-Image Synthetic Data for Text-to-Image Customization},
  author={Kumari, Nupur and Yin, Xi and Zhu, Jun-Yan and Misra, Ishan and Azadi, Samaneh},
  journal={ArXiv},
  year={2025}
}
```
**Contact**

If you have any questions, please feel free to open an issue or directly reach us out via email.

**Acknowledgement**

This space was modified from [OmniGen](https://huggingface.co/spaces/Shitao/OmniGen) space.
"""


# Gradio
with gr.Blocks() as demo:
    gr.Markdown("# SynCD: Generating Multi-Image Synthetic Data for Text-to-Image Customization [[paper](https://arxiv.org/abs/2502.01720)] [[code](https://github.com/nupurkmr9/syncd)]")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # text prompt
            prompt_input = gr.Textbox(
                label="Enter your prompt, more descriptive prompt will lead to better results", placeholder="Type your prompt here..."
            )

            with gr.Row(equal_height=True):
                # input images
                image_input_1 = gr.Image(label="img1", type="filepath")
                image_input_2 = gr.Image(label="img2", type="filepath")
                image_input_3 = gr.Image(label="img3", type="filepath")

            guidance_scale_input = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=5.0, value=3.5, step=0.1
            )

            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=1, maximum=100, value=30, step=1
            )

            seed_input = gr.Slider(
                label="Seed", minimum=0, maximum=2147483647, value=42, step=1
            )

            rigid_object = gr.Checkbox(
                label="rigid_object", info="Whether its a rigid object or a deformable object like pet animals, wearable etc.", value=True,
            )
            enable_cpu_offload = gr.Checkbox(
                label="Enable CPU Offload", info="Enable CPU Offload to avoid memory issues", value=False,
            )

            # generate
            generate_button = gr.Button("Generate Image")

        with gr.Column():
            # output image
            output_image = gr.Image(label="Output Image")

    # click
    generate_button.click(
        generate_image,
        inputs=[
            prompt_input,
            image_input_1,
            image_input_2,
            image_input_3,
            guidance_scale_input,
            num_inference_steps,
            seed_input,
            rigid_object,
            enable_cpu_offload,
        ],
        outputs=output_image,
    )

    gr.Examples(
        examples=get_example(),
        fn=run_for_examples,
        inputs=[
            prompt_input,
            image_input_1,
            image_input_2,
            image_input_3,
            guidance_scale_input,
            seed_input,
            rigid_object,
        ],
        outputs=output_image,
    )

    gr.Markdown(article)

# launch
demo.launch()