import os
import random
import gradio as gr
import numpy as np
import spaces
import torch
from einops import rearrange
from huggingface_hub import login
from peft import LoraConfig
from PIL import Image
from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline
from pipelines.flux_pipeline.transformer import FluxTransformer2DModelWithMasking
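# Authenticate with the Hugging Face Hub; FLUX.1-dev is a gated model and requires a valid token.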
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)
torch_dtype = torch.bfloat16
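# Load the FLUX.1-dev transformer variant that supports attention masking and build the SynCD pipeline around it.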
transformer = FluxTransformer2DModelWithMasking.from_pretrained(
'black-forest-labs/FLUX.1-dev',
subfolder='transformer',
torch_dtype=torch_dtype
)
pipeline = SynCDFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', transformer=transformer, torch_dtype=torch_dtype)
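# Expose each attention processor's module name on the processor itself.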
for name, attn_proc in pipeline.transformer.attn_processors.items():
attn_proc.name = name
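# Transformer sub-modules that receive LoRA adapters (attention and feed-forward projections).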
target_modules = [
"to_k",
"to_q",
"to_v",
"add_k_proj",
"add_q_proj",
"add_v_proj",
"to_out.0",
"to_add_out",
"ff.net.0.proj",
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
"proj_mlp",
"proj_out",
]
lora_rank = 32
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights="gaussian",
target_modules=target_modules,
)
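# Attach the LoRA adapter and load the fine-tuned SynCD weights into it.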
pipeline.transformer.add_adapter(lora_config)
finetuned_state_dict = torch.load('models/pytorch_model.bin', map_location='cpu')
transformer_dict = {}
for key, value in finetuned_state_dict.items():
    # Keep only transformer weights and strip the PEFT wrapper prefix.
    if 'transformer.base_model.model.' in key:
        transformer_dict[key.replace('transformer.base_model.model.', '')] = value
pipeline.transformer.load_state_dict(transformer_dict, strict=False)
pipeline.to('cuda')
pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()
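# Decode latents back to image space with the VAE.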
@torch.no_grad()
def decode(latents, pipeline):
latents = latents / pipeline.vae.config.scaling_factor
image = pipeline.vae.decode(latents, return_dict=False)[0]
return image
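# Encode reference images into scaled VAE latents.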
@torch.no_grad()
def encode_target_images(images, pipeline):
latents = pipeline.vae.encode(images).latent_dist.sample()
latents = latents * pipeline.vae.config.scaling_factor
return latents
@spaces.GPU(duration=120)
def generate_image(text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload=False):
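    """Generate a new composition of the reference object described by `text`.

    Up to three reference images are VAE-encoded and tiled along the width next to an
    empty target slot, which the model fills in using masked shared attention.
    """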
if enable_cpu_offload:
pipeline.enable_sequential_cpu_offload()
input_images = [img1, img2, img3]
    # Drop empty image slots.
    input_images = [img for img in input_images if img is not None]
    if len(input_images) == 0:
        raise gr.Error("Please upload at least one image.")
numref = len(input_images) + 1
    # Load, resize to 512x512, and normalize the reference images to [-1, 1].
    images = torch.cat([
        2. * torch.from_numpy(np.array(Image.open(img).convert('RGB').resize((512, 512)))).permute(2, 0, 1).unsqueeze(0).to(torch_dtype) / 255. - 1.
        for img in input_images
    ])
images = images.to(pipeline.device)
latents = encode_target_images(images, pipeline)
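    # Prepend an empty latent slot for the image to be generated and mark it in the mask.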
latents = torch.cat([torch.zeros_like(latents[:1]), latents], dim=0)
masklatent = torch.zeros_like(latents)
masklatent[:1] = 1.
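    # Tile target and reference latents side by side along the width dimension.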
latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
B, C, H, W = latents.shape
latents = pipeline._pack_latents(latents, B, C, H, W)
    masklatent = pipeline._pack_latents(masklatent.expand(-1, C, -1, -1), B, C, H, W)
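    # Run the SynCD pipeline with shared attention across the tiled views.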
output = pipeline(
text,
latents_ref=latents,
latents_mask=masklatent,
guidance_scale=guidance_scale,
num_inference_steps=inference_steps,
height=512,
width=numref * 512,
        generator=torch.Generator(device="cpu").manual_seed(seed),
joint_attention_kwargs={'shared_attn': True, 'num': numref},
return_dict=False,
)[0][0]
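    # Un-tile the output and keep only the first (generated) view.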
output = rearrange(output, "b c h (n w) -> (b n) c h w", n=numref)[::numref]
    # Map the generated view from [-1, 1] back to an 8-bit RGB image.
    output = torch.clip(output[0].float(), -1., 1.).permute(1, 2, 0).cpu().numpy()
    img = Image.fromarray(((output * 0.5 + 0.5) * 255).astype(np.uint8))
return img
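# Example prompts and reference images for the demo gallery.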
def get_example():
case = [
[
"An action figure on top of a mountain. Sunset in the background. Realistic shot.",
"./imgs/test_cases/action_figure/0.jpg",
"./imgs/test_cases/action_figure/1.jpg",
"./imgs/test_cases/action_figure/2.jpg",
3.5,
42,
True,
],
[
"A penguin plushie wearing pink sunglasses is lounging on a beach. Realistic shot.",
"./imgs/test_cases/penguin/0.jpg",
"./imgs/test_cases/penguin/1.jpg",
"./imgs/test_cases/penguin/2.jpg",
3.5,
42,
True,
],
[
"A toy on a beach. Waves in the background. Realistic shot.",
"./imgs/test_cases/rc_car/02.jpg",
"./imgs/test_cases/rc_car/03.jpg",
"./imgs/test_cases/rc_car/04.jpg",
3.5,
42,
True,
],
]
return case
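# Wrapper used by gr.Examples; fixes the number of inference steps to 30.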
def run_for_examples(text, img1, img2, img3, guidance_scale, seed, rigid_object, enable_cpu_offload=False):
inference_steps = 30
return generate_image(
text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload
)
description = """
Synthetic Customization Dataset (SynCD) consists of multiple images of the same object in different contexts. We build it by promoting consistent object identity, using either explicit 3D object assets or, more implicitly, masked shared attention across the different views during image generation. On this training data, we train a new encoder-based customization model that can generate novel compositions of a reference object from text prompts. You can download our dataset [here](https://huggingface.co/datasets/nupurkmr9/syncd).

Our model supports multiple input images of the same object as references. You can upload up to 3 images; results are generally better with 3 reference images than with 1.

**HF Spaces often run into quota limitations, so we recommend running the demo locally.**
"""
article = """
---
**Citation**
If you find this repository useful, please consider giving it a star ⭐ and citing our work:
```
@article{kumari2025syncd,
title={Generating Multi-Image Synthetic Data for Text-to-Image Customization},
author={Kumari, Nupur and Yin, Xi and Zhu, Jun-Yan and Misra, Ishan and Azadi, Samaneh},
journal={ArXiv},
year={2025}
}
```
**Contact**

If you have any questions, please feel free to open an issue or reach out to us directly via email.

**Acknowledgement**

This space was adapted from the [OmniGen](https://huggingface.co/spaces/Shitao/OmniGen) space.
"""
# Gradio
with gr.Blocks() as demo:
gr.Markdown("# SynCD: Generating Multi-Image Synthetic Data for Text-to-Image Customization [[paper](https://arxiv.org/abs/2502.01720)] [[code](https://github.com/nupurkmr9/syncd)]")
gr.Markdown(description)
with gr.Row():
with gr.Column():
# text prompt
prompt_input = gr.Textbox(
label="Enter your prompt, more descriptive prompt will lead to better results", placeholder="Type your prompt here..."
)
with gr.Row(equal_height=True):
# input images
image_input_1 = gr.Image(label="img1", type="filepath")
image_input_2 = gr.Image(label="img2", type="filepath")
image_input_3 = gr.Image(label="img3", type="filepath")
guidance_scale_input = gr.Slider(
label="Guidance Scale", minimum=1.0, maximum=5.0, value=3.5, step=0.1
)
num_inference_steps = gr.Slider(
label="Inference Steps", minimum=1, maximum=100, value=30, step=1
)
seed_input = gr.Slider(
label="Seed", minimum=0, maximum=2147483647, value=42, step=1
)
rigid_object = gr.Checkbox(
label="rigid_object", info="Whether its a rigid object or a deformable object like pet animals, wearable etc.", value=True,
)
enable_cpu_offload = gr.Checkbox(
label="Enable CPU Offload", info="Enable CPU Offload to avoid memory issues", value=False,
)
# generate
generate_button = gr.Button("Generate Image")
with gr.Column():
# output image
output_image = gr.Image(label="Output Image")
# click
generate_button.click(
generate_image,
inputs=[
prompt_input,
image_input_1,
image_input_2,
image_input_3,
guidance_scale_input,
num_inference_steps,
seed_input,
rigid_object,
enable_cpu_offload,
],
outputs=output_image,
)
gr.Examples(
examples=get_example(),
fn=run_for_examples,
inputs=[
prompt_input,
image_input_1,
image_input_2,
image_input_3,
guidance_scale_input,
seed_input,
rigid_object,
],
outputs=output_image,
)
gr.Markdown(article)
# launch
demo.launch()