# Hugging Face Spaces: Running on Zero
import os | |
import random | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
from einops import rearrange | |
from huggingface_hub import login | |
from peft import LoraConfig | |
from PIL import Image | |
from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline | |
from pipelines.flux_pipeline.transformer import FluxTransformer2DModelWithMasking | |
# Hugging Face access token (needed for the gated FLUX.1-dev weights) read from the environment.
HF_TOKEN = os.getenv('HF_TOKEN')
# Only authenticate when a token is configured: huggingface_hub.login(token=None) falls back to an
# interactive prompt, which a headless deployment cannot answer. Without a token, downloads rely on
# any credentials already cached locally (e.g. via `huggingface-cli login`).
if HF_TOKEN:
    login(token=HF_TOKEN)
# All model weights and reference tensors in this app use bfloat16.
torch_dtype = torch.bfloat16
_FLUX_REPO = 'black-forest-labs/FLUX.1-dev'

# Load the transformer on its own as the project's FluxTransformer2DModelWithMasking class, then
# pass it into the SynCD pipeline so every other component comes from the same FLUX.1-dev repo.
transformer = FluxTransformer2DModelWithMasking.from_pretrained(
    _FLUX_REPO,
    subfolder='transformer',
    torch_dtype=torch_dtype,
)
pipeline = SynCDFluxPipeline.from_pretrained(
    _FLUX_REPO,
    transformer=transformer,
    torch_dtype=torch_dtype,
)

# Store each attention processor's registry key on the processor itself as `.name`.
# NOTE(review): presumably consumed by the custom processors in pipelines.flux_pipeline — confirm.
for name, attn_proc in pipeline.transformer.attn_processors.items():
    attn_proc.name = name
# Sub-module names that receive LoRA layers. NOTE(review): these presumably must match the adapter
# layout stored in the fine-tuned checkpoint that is loaded into this transformer — confirm.
_qkv_and_out_targets = [
    "to_k", "to_q", "to_v",
    "add_k_proj", "add_q_proj", "add_v_proj",
    "to_out.0", "to_add_out",
]
_feedforward_targets = [
    "ff.net.0.proj", "ff.net.2",
    "ff_context.net.0.proj", "ff_context.net.2",
]
_extra_proj_targets = ["proj_mlp", "proj_out"]
target_modules = _qkv_and_out_targets + _feedforward_targets + _extra_proj_targets

lora_rank = 32
# alpha equals the rank, so PEFT's standard LoRA scaling (lora_alpha / r) is exactly 1.0.
lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank,
    target_modules=target_modules,
    init_lora_weights="gaussian",
)
# Inject the (freshly initialised) adapter; its weights are expected to be overwritten by the
# fine-tuned checkpoint afterwards.
pipeline.transformer.add_adapter(lora_config)
_FINETUNED_CHECKPOINT = 'models/pytorch_model.bin'
# Keys of the fine-tuned transformer (including its LoRA adapter) carry this prefix in the checkpoint.
_TRANSFORMER_PREFIX = 'transformer.base_model.model.'

# NOTE(review): torch.load unpickles arbitrary objects; acceptable for this bundled, trusted file,
# but pass weights_only=True if the checkpoint is known to contain only tensors.
finetuned_state_dict = torch.load(_FINETUNED_CHECKPOINT, map_location='cpu')
# Keep only transformer entries and strip the prefix so names line up with pipeline.transformer.
transformer_dict = {
    key.replace(_TRANSFORMER_PREFIX, ''): value
    for key, value in finetuned_state_dict.items()
    if _TRANSFORMER_PREFIX in key
}
if not transformer_dict:
    # With strict=False an empty dict would "load" nothing and silently run the base model.
    raise RuntimeError(
        f"No keys with prefix {_TRANSFORMER_PREFIX!r} found in {_FINETUNED_CHECKPOINT}"
    )
# strict=False: the checkpoint need not cover every transformer parameter.
load_result = pipeline.transformer.load_state_dict(transformer_dict, strict=False)
if load_result.unexpected_keys:
    # Surface name mismatches instead of ignoring them: these weights were NOT loaded.
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint keys did not match the transformer, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
# The weights have been copied into the model; release the CPU copies so they don't stay resident
# as module globals for the lifetime of the app.
del finetuned_state_dict, transformer_dict

pipeline.to('cuda')
pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()
def decode(latents, pipeline):
    """Decode scaled latents back to image space with ``pipeline.vae``.

    The latents are divided by ``vae.config.scaling_factor`` (the inverse of the
    scaling applied in ``encode_target_images``) before decoding. The first element
    of the VAE's tuple output (``return_dict=False``) is returned.
    """
    vae = pipeline.vae
    unscaled = latents / vae.config.scaling_factor
    decoded = vae.decode(unscaled, return_dict=False)
    return decoded[0]
def encode_target_images(images, pipeline):
    """Encode ``images`` with ``pipeline.vae`` and return scaled latents.

    A sample is drawn from the VAE's latent distribution (so the result is
    stochastic for a non-degenerate posterior) and multiplied by
    ``vae.config.scaling_factor``.
    """
    vae = pipeline.vae
    posterior = vae.encode(images).latent_dist
    return posterior.sample() * vae.config.scaling_factor
def _load_reference_image(path, dtype):
    """Load an image file as a (1, 3, 512, 512) tensor scaled to [-1, 1] and cast to ``dtype``."""
    # The context manager closes the file handle once the pixels have been copied out.
    with Image.open(path) as im:
        pixels = np.array(im.convert('RGB').resize((512, 512)))
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).float()
    # Normalise in float32 and cast once, rather than doing the arithmetic in bfloat16.
    return (2.0 * tensor / 255.0 - 1.0).to(dtype)


def _build_reference_latents(input_images):
    """Encode the reference images and pack them, with an empty target slot and its mask, for the pipeline.

    Returns ``(latents, masklatent, numref)`` where ``numref`` counts the target slot
    plus the references.
    """
    numref = len(input_images) + 1
    images = torch.cat([_load_reference_image(path, torch_dtype) for path in input_images])
    # `_execution_device` stays correct after enable_sequential_cpu_offload(), whereas
    # `pipeline.device` then reports where the offloaded weights live rather than the GPU.
    images = images.to(pipeline._execution_device)
    latents = encode_target_images(images, pipeline)
    # Slot 0 is the (all-zero) target to be generated; the mask marks only that slot.
    latents = torch.cat([torch.zeros_like(latents[:1]), latents], dim=0)
    masklatent = torch.zeros_like(latents)
    masklatent[:1] = 1.
    # Lay the numref views side by side along the width axis.
    latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
    masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
    B, C, H, W = latents.shape
    latents = pipeline._pack_latents(latents, B, C, H, W)
    masklatent = pipeline._pack_latents(masklatent.expand(-1, C, -1, -1), B, C, H, W)
    return latents, masklatent, numref


def _to_pil(output, numref):
    """Cut the generated target view (the first of ``numref`` side-by-side views) out of ``output`` as a PIL image."""
    # NOTE(review): assumes `output` is a (b, c, h, numref*w) tensor in [-1, 1] — confirm against the pipeline.
    views = rearrange(output, "b c h (n w) -> (b n) c h w", n=numref)[::numref]
    # Convert to float32 before .numpy(): NumPy has no bfloat16 dtype.
    hwc = torch.clip(views[0].float(), -1., 1.).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(((hwc * 0.5 + 0.5) * 255).astype(np.uint8))


@spaces.GPU  # required for GPU access on ZeroGPU Spaces; a no-op elsewhere
def generate_image(text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload=False):
    """Generate an image of the reference object following the prompt ``text``.

    Args:
        text: Prompt passed to the pipeline.
        img1, img2, img3: File paths of reference images; ``None`` entries are skipped.
        guidance_scale: Forwarded to the pipeline's ``guidance_scale``.
        inference_steps: Number of denoising steps (cast to ``int``).
        seed: Seed for a CPU ``torch.Generator`` (cast to ``int``).
        rigid_object: Accepted for UI compatibility; not currently used by this function.
        enable_cpu_offload: If true, enable sequential CPU offload on the shared module-level
            pipeline. Nothing here turns it off again, so it persists for all later calls.

    Returns:
        A ``PIL.Image.Image`` holding the generated target view.

    Raises:
        gr.Error: If no reference image was provided (shown to the user by Gradio).
    """
    if enable_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    input_images = [img for img in (img1, img2, img3) if img is not None]
    if not input_images:
        # The output component is a gr.Image, which cannot display a plain string message.
        raise gr.Error("Please upload at least one image")
    latents, masklatent, numref = _build_reference_latents(input_images)
    output = pipeline(
        text,
        latents_ref=latents,
        latents_mask=masklatent,
        guidance_scale=guidance_scale,
        # Slider values may arrive as floats; the scheduler and generator need ints.
        num_inference_steps=int(inference_steps),
        height=512,
        width=numref * 512,
        generator=torch.Generator(device="cpu").manual_seed(int(seed)),
        joint_attention_kwargs={'shared_attn': True, 'num': numref},
        return_dict=False,
    )[0][0]
    return _to_pil(output, numref)
def get_example():
    """Return the example rows for the ``gr.Examples`` table.

    Each row is ``[prompt, img1, img2, img3, guidance_scale, seed, rigid_object]``,
    in the same order as the ``inputs`` given to ``gr.Examples``. A fresh list is
    built on every call.
    """
    guidance_scale, seed, rigid_object = 3.5, 42, True
    prompts_and_images = (
        (
            "An action figure on top of a mountain. Sunset in the background. Realistic shot.",
            (
                "./imgs/test_cases/action_figure/0.jpg",
                "./imgs/test_cases/action_figure/1.jpg",
                "./imgs/test_cases/action_figure/2.jpg",
            ),
        ),
        (
            "A penguin plushie wearing pink sunglasses is lounging on a beach. Realistic shot.",
            (
                "./imgs/test_cases/penguin/0.jpg",
                "./imgs/test_cases/penguin/1.jpg",
                "./imgs/test_cases/penguin/2.jpg",
            ),
        ),
        (
            "A toy on a beach. Waves in the background. Realistic shot.",
            (
                "./imgs/test_cases/rc_car/02.jpg",
                "./imgs/test_cases/rc_car/03.jpg",
                "./imgs/test_cases/rc_car/04.jpg",
            ),
        ),
    )
    return [
        [prompt, *image_paths, guidance_scale, seed, rigid_object]
        for prompt, image_paths in prompts_and_images
    ]
def run_for_examples(text, img1, img2, img3, guidance_scale, seed, rigid_object, enable_cpu_offload=False):
    """Adapter used by ``gr.Examples``: example rows carry no step count, so 30 steps are used.

    All other arguments are forwarded unchanged to ``generate_image``, whose result
    is returned as-is.
    """
    return generate_image(
        text,
        img1,
        img2,
        img3,
        guidance_scale,
        inference_steps=30,
        seed=seed,
        rigid_object=rigid_object,
        enable_cpu_offload=enable_cpu_offload,
    )
# Markdown rendered under the page title; blank lines separate the paragraphs.
description = """
Synthetic Customization Dataset (SynCD) consists of multiple images of the same object in different contexts. We achieve this by promoting similar object identity, using either explicit 3D object assets or, more implicitly, masked shared attention across different views while generating images. Given this training data, we train a new encoder-based model for the task, which can successfully generate new compositions of a reference object using text prompts. You can download our dataset [here](https://huggingface.co/datasets/nupurkmr9/syncd).

Our model supports multiple input images of the same object as references. You can upload up to 3 images; results are better with 3 images than with 1.

**HF Spaces often encounter errors due to quota limitations, so we recommend running it locally.**
"""
# Markdown footer: citation, contact and acknowledgement sections, separated by blank lines.
article = """
---

**Citation**
<br>
If you find this repository useful, please consider giving it a star ⭐ and citing our paper:
```
@article{kumari2025syncd,
  title={Generating Multi-Image Synthetic Data for Text-to-Image Customization},
  author={Kumari, Nupur and Yin, Xi and Zhu, Jun-Yan and Misra, Ishan and Azadi, Samaneh},
  journal={ArXiv},
  year={2025}
}
```

**Contact**
<br>
If you have any questions, please feel free to open an issue or reach out to us directly via email.

**Acknowledgement**
<br>
This space was modified from the [OmniGen](https://huggingface.co/spaces/Shitao/OmniGen) space.
"""
# Gradio UI: inputs on the left, generated image on the right, examples and footer below.
with gr.Blocks() as demo:
    gr.Markdown("# SynCD: Generating Multi-Image Synthetic Data for Text-to-Image Customization [[paper](https://arxiv.org/abs/2502.01720)] [[code](https://github.com/nupurkmr9/syncd)]")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # text prompt
            prompt_input = gr.Textbox(
                label="Enter your prompt; more descriptive prompts lead to better results", placeholder="Type your prompt here..."
            )
            with gr.Row(equal_height=True):
                # input images (passed to the handler as file paths)
                image_input_1 = gr.Image(label="img1", type="filepath")
                image_input_2 = gr.Image(label="img2", type="filepath")
                image_input_3 = gr.Image(label="img3", type="filepath")
            guidance_scale_input = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=5.0, value=3.5, step=0.1
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=1, maximum=100, value=30, step=1
            )
            seed_input = gr.Slider(
                label="Seed", minimum=0, maximum=2147483647, value=42, step=1
            )
            rigid_object = gr.Checkbox(
                label="rigid_object", info="Whether it's a rigid object or a deformable object like pet animals, wearables, etc.", value=True,
            )
            enable_cpu_offload = gr.Checkbox(
                label="Enable CPU Offload", info="Enable CPU Offload to avoid memory issues", value=False,
            )
            # generate
            generate_button = gr.Button("Generate Image")
        with gr.Column():
            # output image
            output_image = gr.Image(label="Output Image")
    # click: input order must match generate_image's positional parameters
    generate_button.click(
        generate_image,
        inputs=[
            prompt_input,
            image_input_1,
            image_input_2,
            image_input_3,
            guidance_scale_input,
            num_inference_steps,
            seed_input,
            rigid_object,
            enable_cpu_offload,
        ],
        outputs=output_image,
    )
    # Example rows from get_example() map onto these seven inputs (no step count; see run_for_examples).
    gr.Examples(
        examples=get_example(),
        fn=run_for_examples,
        inputs=[
            prompt_input,
            image_input_1,
            image_input_2,
            image_input_3,
            guidance_scale_input,
            seed_input,
            rigid_object,
        ],
        outputs=output_image,
    )
    gr.Markdown(article)
# launch
demo.launch()