diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..056072503827445de516bb91226150edf90d49fe 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 34e9de4e6e8e9ac57f099e3ede376625fa49e6cf..833544f68e941bd9a5c761c620fbb9b150621849 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ sdk_version: 5.23.1
app_file: app.py
pinned: false
license: apache-2.0
+short_description: A subject-driven image generation control toolkit
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app.py b/app.py
index 5c89a9143cd5c1cffd768e81c76cd3cec86a80ae..71614c6e03ef01876a06211dd7769f9bffe7d584 100644
--- a/app.py
+++ b/app.py
@@ -1,89 +1,106 @@
-import os
-import base64
-import io
-from typing import TypedDict
-import requests
+import spaces
import gradio as gr
-from PIL import Image
+import torch
+from typing import TypedDict
+from PIL import Image, ImageDraw, ImageFont
+from diffusers.pipelines import FluxPipeline
+from diffusers import FluxTransformer2DModel
+import numpy as np
+import examples_db
-# Read Baseten configuration from environment variables.
-BTEN_API_KEY = os.getenv("API_KEY")
-URL = os.getenv("URL")
-def image_to_base64(image: Image.Image) -> str:
- """Convert a PIL image to a base64-encoded PNG string."""
- with io.BytesIO() as buffer:
- image.save(buffer, format="PNG")
- return base64.b64encode(buffer.getvalue()).decode("utf-8")
+from flux.condition import Condition
+from flux.generate import seed_everything, generate
+from flux.lora_controller import set_lora_scale
-def ensure_image(img) -> Image.Image:
- """
- Ensure the input is a PIL Image.
- If it's already a PIL Image, return it.
- If it's a string (file path), open it.
- If it's a dict with a "name" key, open the file at that path.
- """
- if isinstance(img, Image.Image):
- return img
- elif isinstance(img, str):
- return Image.open(img)
- elif isinstance(img, dict) and "name" in img:
- return Image.open(img["name"])
+pipe = None
+current_adapter = None
+use_int8 = False
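+# Flags consumed by the custom transformer forward in flux/block.py (attention masking / LoRA gating).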
+model_config = {"union_cond_attn": True, "add_cond_attn": False, "latent_lora": False, "independent_condition": True}
+
+def get_gpu_memory():
+ return torch.cuda.get_device_properties(0).total_memory / 1024**3
+
+
+def init_pipeline():
+ global pipe
+ if use_int8 or get_gpu_memory() < 33:
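+        # Use the int8 weight-only ("int8wo") quantized transformer when VRAM is limited (< 33 GB).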
+ transformer_model = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/flux.1-schell-int8wo-improved",
+ torch_dtype=torch.bfloat16,
+ use_safetensors=False,
+ )
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ transformer=transformer_model,
+ torch_dtype=torch.bfloat16,
+ )
else:
- raise ValueError("Cannot convert input to a PIL Image.")
-
-
-def call_baseten_generate(
- image: Image.Image,
- prompt: str,
- steps: int,
- strength: float,
- height: int,
- width: int,
- lora_name: str,
- remove_bg: bool,
-) -> Image.Image | None:
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+ )
+ pipe = pipe.to("cuda")
+
+
+ # Optional: Load additional LoRA weights
+ pipe.load_lora_weights(
+ "fotographerai/zenctrl_tools",
+ weight_name="weights/zen2con_1024_10000/"
+ "pytorch_lora_weights.safetensors",
+ adapter_name="subject"
+ )
+
+ # Optional: Load additional LoRA weights
+ #pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism")
+
+
+def paste_on_white_background(image: Image.Image) -> Image.Image:
"""
- Call the Baseten /predict endpoint with provided parameters and return the generated image.
+ Pastes a transparent image onto a white background of the same size.
"""
- image = ensure_image(image)
- b64_image = image_to_base64(image)
- payload = {
- "image": b64_image,
- "prompt": prompt,
- "steps": steps,
- "strength": strength,
- "height": height,
- "width": width,
- "lora_name": lora_name,
- "bgrm": remove_bg,
- }
- if not BTEN_API_KEY:
- headers = {"Authorization": f"Api-Key {os.getenv('API_KEY')}"}
- else:
- headers = {"Authorization": f"Api-Key {BTEN_API_KEY}"}
- try:
- if not URL:
- raise ValueError("The URL environment variable is not set.")
-
- response = requests.post(URL, headers=headers, json=payload)
- if response.status_code == 200:
- data = response.json()
- gen_b64 = data.get("generated_image", None)
- if gen_b64:
- return Image.open(io.BytesIO(base64.b64decode(gen_b64)))
- else:
- return None
- else:
- print(f"Error: HTTP {response.status_code}\n{response.text}")
- return None
- except Exception as e:
- print(f"Error: {e}")
- return None
-
-
-# Mode defaults for each tab.
+ if image.mode != "RGBA":
+ image = image.convert("RGBA")
+
+ # Create white background
+ white_bg = Image.new("RGBA", image.size, (255, 255, 255, 255))
+ white_bg.paste(image, (0, 0), mask=image)
+ return white_bg.convert("RGB") # Convert back to RGB if you don't need alpha
+
+#@spaces.GPU
+def process_image_and_text(image, text, steps=8, strength_sub=1.0, strength_spat=1.0, size=1024):
+ # center crop image
+ w, h, min_size = image.size[0], image.size[1], min(image.size)
+ image = image.crop(
+ (
+ (w - min_size) // 2,
+ (h - min_size) // 2,
+ (w + min_size) // 2,
+ (h + min_size) // 2,
+ )
+ )
+ image = image.resize((size, size))
+ image = paste_on_white_background(image)
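+    # Note: `size` sets the condition-image resolution; the final image below is generated at a fixed 1024x1024.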
+ condition0 = Condition("subject", image, position_delta=(0, size // 16))
+ condition1 = Condition("subject", image, position_delta=(0, -size // 16))
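+    # position_delta offsets the condition tokens on the packed latent grid (1/16 of the pixel size);
+    # the two subject conditions are shifted to opposite sides of the generated image.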
+
+ pipe = get_pipeline()
+
+ with set_lora_scale(["subject"], scale=3.0):
+ result_img = generate(
+ pipe,
+ prompt=text.strip(),
+ conditions=[condition0, condition1],
+ num_inference_steps=steps,
+ height=1024,
+ width=1024,
+            condition_scale=[strength_sub, strength_spat],
+ model_config=model_config,
+ ).images[0]
+
+ return result_img
+
+# ================== MODE CONFIG =====================
Mode = TypedDict(
"Mode",
@@ -98,77 +115,76 @@ Mode = TypedDict(
},
)
+MODEL_TO_LORA: dict[str, str] = {
+    # dropdown value -> relative path inside the HF repo
+ "zen2con_1024_10000": "weights/zen2con_1024_10000/pytorch_lora_weights.safetensors",
+ "zen2con_1440_17000": "weights/zen2con_1440_17000/pytorch_lora_weights.safetensors",
+ "zen_sub_sub_1024_10000": "weights/zen_sub_sub_1024_10000/pytorch_lora_weights.safetensors",
+ "zen_toys_1024_4000": "weights/zen_toys_1024_4000/12000/pytorch_lora_weights.safetensors",
+ "zen_toys_1024_15000": "weights/zen_toys_1024_4000/zen_toys_1024_15000/pytorch_lora_weights.safetensors",
+ # add more as you upload them
+}
+
+
MODE_DEFAULTS: dict[str, Mode] = {
"Subject Generation": {
- "model": "subject_99000_512",
- "prompt": "A detailed portrait with soft lighting",
- "default_strength": 1.2,
- "default_height": 512,
- "default_width": 512,
- "models": [
- "zendsd_512_146000",
- "subject_99000_512",
- # "zen_pers_11000",
- "zen_26000_512",
- ],
- "remove_bg": True,
- },
- "Background Generation": {
- "model": "bg_canny_58000_1024",
+ "model": "zen2con_1024_10000",
"prompt": "A vibrant background with dynamic lighting and textures",
"default_strength": 1.2,
"default_height": 1024,
"default_width": 1024,
- "models": [
- "bgwlight_15000_1024",
- # "rmgb_12000_1024",
- "bg_canny_58000_1024",
- # "gen_back_3000_1024",
- "gen_back_7000_1024",
- # "gen_bckgnd_18000_512",
- # "gen_bckgnd_18000_512",
- # "loose_25000_512",
- # "looser_23000_1024",
- # "looser_bg_gen_21000_1280",
- # "old_looser_46000_1024",
- # "relight_bg_gen_31000_1024",
- ],
- "remove_bg": True,
- },
- "Canny": {
- "model": "canny_21000_1024",
- "prompt": "A futuristic cityscape with neon lights",
- "default_strength": 1.2,
- "default_height": 1024,
- "default_width": 1024,
- "models": ["canny_21000_1024"],
+ "models": list(MODEL_TO_LORA.keys()),
"remove_bg": True,
},
- "Depth": {
- "model": "depth_9800_1024",
- "prompt": "A scene with pronounced depth and perspective",
- "default_strength": 1.2,
- "default_height": 1024,
- "default_width": 1024,
- "models": [
- "depth_9800_1024",
- ],
- "remove_bg": True,
- },
- "Deblurring": {
- "model": "deblurr_1024_10000",
- "prompt": "A scene with pronounced depth and perspective",
- "default_strength": 1.2,
- "default_height": 1024,
- "default_width": 1024,
- "models": ["deblurr_1024_10000"], # "slight_deblurr_18000",
- "remove_bg": False,
- },
-}
+ #"Image fix": {
+ # "model": "zen_toys_1024_4000",
+ # "prompt": "A detailed portrait with soft lighting",
+ # "default_strength": 1.2,
+ # "default_height": 1024,
+ # "default_width": 1024,
+ # "models": ["weights/zen_toys_1024_4000/12000/", "weights/zen_toys_1024_4000/12000/"],
+ # "remove_bg": True,
+ #}
+ }
+
+
+def get_pipeline():
+ """Lazy-build the pipeline inside the GPU worker."""
+ global pipe
+ if pipe is None:
+        init_pipeline()  # safe here: called from the @spaces.GPU-wrapped handler
+ return pipe
+
+def get_samples():
+ sample_list = [
+ {
+ "image": "samples/1.png",
+ "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
+ },
+ {
+ "image": "samples/2.png",
+ "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
+ },
+ {
+ "image": "samples/3.png",
+            "text": "A film style shot. On the moon, this item drives across the moon surface. In the background, Earth looms large.",
+ },
+ {
+ "image": "samples/4.png",
+ "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
+ },
+ {
+ "image": "samples/5.png",
+            "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her.",
+ },
+ ]
+ return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
+
+# =============== UI ===============
header = """
-
π ZenCtrl / FLUX
+π ZenCtrl medium

@@ -178,136 +194,110 @@ header = """
"""
-defaults = MODE_DEFAULTS["Subject Generation"]
-
-
-with gr.Blocks(title="π ZenCtrl") as demo:
+with gr.Blocks(title="π ZenCtrl-medium") as demo:
+ # ---------- banner ----------
gr.HTML(header)
gr.Markdown(
"""
# ZenCtrl Demo
- [WIP] One Agent to Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject imageβwithout fine-tuning.
+    One framework to generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image, without fine-tuning.
We are first releasing some of the task specific weights and will release the codes soon.
The goal is to unify all of the visual content generation tasks with a single LLM...
- **Modes:**
- - **Subject Generation:** Focuses on generating detailed subject portraits.
- - **Background Generation:** Creates dynamic, vibrant backgrounds:
- You can generate part of the image from sketch while keeping part of it as it is.
- - **Canny:** Emphasizes strong edge detection.
- - **Depth:** Produces images with realistic depth and perspective.
+ **Mode:**
+    - **Subject-driven Image Generation:** Generate in-context images of your subject with high fidelity and from different perspectives.
For more details, shoot us a message on discord.
"""
)
+
+ # ---------- tab bar ----------
with gr.Tabs():
- for mode in MODE_DEFAULTS:
- with gr.Tab(mode):
- defaults = MODE_DEFAULTS[mode]
- gr.Markdown(f"### {mode} Mode")
- gr.Markdown(f"**Default Model:** {defaults['model']}")
+ for mode_name, defaults in MODE_DEFAULTS.items():
+ with gr.Tab(mode_name):
+ gr.Markdown(f"### {mode_name}")
+ # -------- left (input) column --------
with gr.Row():
- with gr.Column(scale=2, min_width=370):
- input_image = gr.Image(
- label="Upload Image",
- type="pil",
- scale=3,
- height=370,
- min_width=100,
+ with gr.Column(scale=2):
+ input_image = gr.Image(label="Input Image", type="pil")
+ model_dropdown = gr.Dropdown(
+ label="Model (LoRA adapter)",
+ choices=defaults["models"],
+ value=defaults["model"],
+ interactive=True,
)
- generate_button = gr.Button("Generate")
- with gr.Blocks(title="Options"):
- model_dropdown = gr.Dropdown(
- label="Model",
- choices=defaults["models"],
- value=defaults["model"],
- interactive=True,
- )
- remove_bg_checkbox = gr.Checkbox(
- label="Remove Background", value=defaults["remove_bg"]
- )
+ prompt_box = gr.Textbox(label="Prompt",
+ value=defaults["prompt"], lines=2)
+ generate_btn = gr.Button("Generate")
+
+ with gr.Accordion("Generation Parameters", open=False):
+ step_slider = gr.Slider(2, 28, value=12, step=2, label="Steps")
+ strength_sub_slider = gr.Slider(0.0, 2.0,
+ value=defaults["default_strength"],
+ step=0.1, label="Strength (subject)")
+ strength_spat_slider = gr.Slider(0.0, 2.0,
+ value=defaults["default_strength"],
+ step=0.1, label="Strength (spatial)")
+ size_slider = gr.Slider(512, 2048,
+ value=defaults["default_height"],
+ step=64, label="Size (px)")
+ # -------- right (output) column --------
with gr.Column(scale=2):
- output_image = gr.Image(
- label="Generated Image",
- type="pil",
- height=573,
- scale=4,
- min_width=100,
- )
+ output_image = gr.Image(label="Output Image", type="pil")
- gr.Markdown("#### Prompt")
- prompt_box = gr.Textbox(
- label="Prompt", value=defaults["prompt"], lines=2
- )
+ # ---------- click handler ----------
+ @spaces.GPU
+ def _run(image, model_name, prompt, steps, s_sub, s_spat, size):
+ global current_adapter
- # Wrap generation parameters in an Accordion for collapsible view.
- with gr.Accordion("Generation Parameters", open=False):
- with gr.Row():
- step_slider = gr.Slider(
- minimum=2, maximum=28, value=2, step=2, label="Steps"
- )
- strength_slider = gr.Slider(
- minimum=0.5,
- maximum=2.0,
- value=defaults["default_strength"],
- step=0.1,
- label="Strength",
- )
- with gr.Row():
- height_slider = gr.Slider(
- minimum=512,
- maximum=1360,
- value=defaults["default_height"],
- step=1,
- label="Height",
- )
- width_slider = gr.Slider(
- minimum=512,
- maximum=1360,
- value=defaults["default_width"],
- step=1,
- label="Width",
+ pipe = get_pipeline()
+
+                    # -- switch adapter if needed --------------------------------
+ if model_name != current_adapter:
+ lora_path = MODEL_TO_LORA[model_name]
+ # load & activate the chosen adapter
+ pipe.load_lora_weights(
+ "fotographerai/zenctrl_tools",
+ weight_name=lora_path,
+ adapter_name=model_name,
)
+ pipe.set_adapters([model_name])
+ current_adapter = model_name
- def on_generate_click(
- model_name,
- prompt,
- steps,
- strength,
- height,
- width,
- remove_bg,
- image,
- ):
- return call_baseten_generate(
- image,
- prompt,
- steps,
- strength,
- height,
- width,
- model_name,
- remove_bg,
+                    # -- run generation ------------------------------------------
+ return process_image_and_text(
+ image, prompt, steps=steps,
+ strength_sub=s_sub, strength_spat=s_spat, size=size
)
- generate_button.click(
- fn=on_generate_click,
- inputs=[
- model_dropdown,
- prompt_box,
- step_slider,
- strength_slider,
- height_slider,
- width_slider,
- remove_bg_checkbox,
- input_image,
- ],
+ generate_btn.click(
+ fn=_run,
+ inputs=[input_image, model_dropdown, prompt_box,
+ step_slider, strength_sub_slider,
+ strength_spat_slider, size_slider],
outputs=[output_image],
- concurrency_limit=None
)
+ # ---------------- Templates --------------------
+ if examples_db.MODE_EXAMPLES.get(mode_name):
+ gr.Examples(
+ examples=examples_db.MODE_EXAMPLES[mode_name],
+                        inputs=[
+                            input_image,     # Image component (input)
+                            model_dropdown,  # Dropdown for adapter
+                            prompt_box,      # Textbox for prompt
+                            output_image,    # Image component (preset output)
+                        ],
+ label="Presets (Image / Model / Prompt)",
+ examples_per_page=15,
+ )
+# =============== launch ===============
if __name__ == "__main__":
- demo.launch()
\ No newline at end of file
+ #init_pipeline()
+ demo.launch(
+ debug=True,
+ share=True
+ )
\ No newline at end of file
diff --git a/examples_db.py b/examples_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..071f98ed27429df85d64ecc834abc73a0b745f83
--- /dev/null
+++ b/examples_db.py
@@ -0,0 +1,96 @@
+MODE_EXAMPLES = {
+ "Subject Generation": [
+ [
+ "samples/7.png",
+ "zen2con_1440_17000",
+ "A man wearing white shoes stepping outside, closeup on the shoes",
+ "samples/22_out.webp",
+ ],
+ [
+ "samples/7.png",
+ "zen2con_1440_17000",
+ "Low angle photography, shoes, stepping in water in a futuristic cityscape with neon lights, in the back in large, the words 'ZenCtrl 2025' are written in a futuristic font",
+ "samples/2_out.webp",
+ ],
+ [
+ "samples/8.png",
+ "zen2con_1024_10000",
+ "a watch, resting on wet volcanic rock, ocean spray mist, golden sunrise back-light, 50 mm lens, shallow DOF, 8k realism",
+ "samples/3_out.webp",
+ ],
+ ["samples/9.png","zen_sub_sub_1024_10000", "a bag, placed on a desk in a cozy bedroom, sunny light coming through the window", "samples/4_out.webp"],
+ ["samples/10.png", "zen2con_1440_17000","A bottle sits poolside at a luxury hotel. The bottle rests in front of a clear blue pool and resort landscape in the background. Light reflects off the bottle highlighting its sophisticated design. An elegant glass and tropical fruits are included creating a relaxing atmosphere. The image should convey luxury sophistication and an otherworldly refreshment.", "samples/5_out.webp"],
+
+ #[
+ # "samples/11.png",
+ # "zen_toys_1024_4000",
+ # "Low angle photography, shoes, stepping in water in a futuristic cityscape with neon lights, in the back in large, the words 'ZenCtrl 2025' are written in a futuristic font",
+ # "samples/1.png",
+ #],
+ [
+ "samples/12.png",
+ "zen2con_1024_10000",
+            "a woman, wearing a pair of sunglasses, in front of the beach",
+ "samples/6_out.webp",
+ ],
+ [
+ "samples/13.png",
+ "zen2con_1440_17000",
+            "a woman, standing outside, in the streets, next to a cafe",
+ "samples/7_out.webp",
+ ],
+        ["samples/14.png", "zen_sub_sub_1024_10000", "a bag, held by a woman, walking outside", "samples/8_out.webp"],
+ ["samples/15.png","zen_sub_sub_1024_10000", "a man wearing a suit, sitting on a chair in a living room", "samples/9_out.webp"],
+ [
+ "samples/6.png",
+ "zen_toys_1024_4000",
+            "A kid playing with a toy figurine, indoors, on a sunny day",
+ "samples/11_out.webp",
+ ],
+ [
+ "samples/6.png",
+ "zen_toys_1024_4000",
+ "A small figurine placed on a table, surrounded by toys. A child is placing it",
+ "samples/12_out.webp",
+ ],
+ [
+ "samples/17.png",
+ "zen2con_1024_10000",
+ "A black man wearing black wireless headphones at a basketball game",
+ "samples/20_out.webp",
+ ],
+ [
+ "samples/21_1.png",
+ "zen2con_1024_10000",
+            "a man holding a camera, facing the lens.",
+ "samples/21_out.webp",
+ ],
+ ],
+ "Image fix": [
+ [
+ "samples/1.png",
+ "placed on a dark marble table in a bathroom of luxury hotel modern light authentic atmosphere",
+ "samples/1.png",
+ ],
+ [
+ "samples/1.png",
+ "sitting on the middle of the city road on a sunny day very bright day front view",
+ "samples/1.png",
+ ],
+ [
+ "samples/1.png",
+            "A creative capture in an art gallery, with soft, focused lighting highlighting both the person's features and the abstract surroundings, exuding sophistication.",
+ "samples/1.png",
+ ],
+ [
+ "samples/1.png",
+            "In a rain-soaked urban nightscape, with headlights piercing through the mist and wet streets reflecting the city's vibrant neon colors, creating an atmosphere of mystery and modern elegance.",
+ "samples/1.png",
+ ],
+ [
+ "samples/1.png",
+ "An elegant room scene featuring a minimalist table and chairs, next to a flower vase, illuminated by ambient lighting that casts gentle shadows and enhances the refined, contemporary decor.",
+ "samples/1.png",
+ ],
+ ],
+}
diff --git a/flux/__init__.py b/flux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/flux/block.py b/flux/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6b1c5c7eb72fe387064e62a264c1067b18b5541
--- /dev/null
+++ b/flux/block.py
@@ -0,0 +1,461 @@
+# Recycled from Ominicontrol and modified to accept an extra condition.
+# While Zenctrl pursued a similar idea, it diverged structurally.
+# We appreciate the clarity of Omini's implementation and decided to align with it.
+
+import torch
+from typing import List, Union, Optional, Dict, Any, Callable
+from diffusers.models.attention_processor import Attention, F
+from .lora_controller import enable_lora
+from diffusers.models.embeddings import apply_rotary_emb
+
+def attn_forward(
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ condition_latents: torch.FloatTensor = None,
+ extra_condition_latents: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ cond_rotary_emb: Optional[torch.Tensor] = None,
+ extra_cond_rotary_emb: Optional[torch.Tensor] = None,
+ model_config: Optional[Dict[str, Any]] = {},
+) -> torch.FloatTensor:
+ batch_size, _, _ = (
+ hidden_states.shape
+ if encoder_hidden_states is None
+ else encoder_hidden_states.shape
+ )
+
+ with enable_lora(
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
+ ):
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(
+ encoder_hidden_states_query_proj
+ )
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(
+ encoder_hidden_states_key_proj
+ )
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
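+        # Text (context) tokens are placed before image tokens; the post-attention split relies on this order.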
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ if condition_latents is not None:
+ cond_query = attn.to_q(condition_latents)
+ cond_key = attn.to_k(condition_latents)
+ cond_value = attn.to_v(condition_latents)
+
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
+ 1, 2
+ )
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
+ 1, 2
+ )
+ if attn.norm_q is not None:
+ cond_query = attn.norm_q(cond_query)
+ if attn.norm_k is not None:
+ cond_key = attn.norm_k(cond_key)
+
+ #extra condition
+ if extra_condition_latents is not None:
+ extra_cond_query = attn.to_q(extra_condition_latents)
+ extra_cond_key = attn.to_k(extra_condition_latents)
+ extra_cond_value = attn.to_v(extra_condition_latents)
+
+ extra_cond_query = extra_cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
+ 1, 2
+ )
+ extra_cond_key = extra_cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ extra_cond_value = extra_cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
+ 1, 2
+ )
+ if attn.norm_q is not None:
+ extra_cond_query = attn.norm_q(extra_cond_query)
+ if attn.norm_k is not None:
+ extra_cond_key = attn.norm_k(extra_cond_key)
+
+
+ if extra_cond_rotary_emb is not None:
+ extra_cond_query = apply_rotary_emb(extra_cond_query, extra_cond_rotary_emb)
+ extra_cond_key = apply_rotary_emb(extra_cond_key, extra_cond_rotary_emb)
+
+ if cond_rotary_emb is not None:
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
+
+ if condition_latents is not None:
+ if extra_condition_latents is not None:
+
+ query = torch.cat([query, cond_query, extra_cond_query], dim=2)
+ key = torch.cat([key, cond_key, extra_cond_key], dim=2)
+ value = torch.cat([value, cond_value, extra_cond_value], dim=2)
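+            # Condition tokens are appended last; the slicing after attention relies on this ordering.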
+ else:
+ query = torch.cat([query, cond_query], dim=2)
+ key = torch.cat([key, cond_key], dim=2)
+ value = torch.cat([value, cond_value], dim=2)
+ print("concat Omini latents: ", query.shape, key.shape, value.shape)
+
+
+ if not model_config.get("union_cond_attn", True):
+
+ attention_mask = torch.ones(
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
+ )
+ condition_n = cond_query.shape[2]
+ attention_mask[-condition_n:, :-condition_n] = False
+ attention_mask[:-condition_n, -condition_n:] = False
+ elif model_config.get("independent_condition", False):
+ attention_mask = torch.ones(
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
+ )
+ condition_n = cond_query.shape[2]
+ attention_mask[-condition_n:, :-condition_n] = False
+
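+        # `c_factor` is attached in generate() from `condition_scale`; log(scale) is added below as an
+        # additive attention bias between the main tokens and each condition's tokens (float masks are additive in SDPA).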
+ if hasattr(attn, "c_factor"):
+ attention_mask = torch.zeros(
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
+ )
+ condition_n = cond_query.shape[2]
+ condition_e = extra_cond_query.shape[2]
+ bias = torch.log(attn.c_factor[0])
+ attention_mask[-condition_n-condition_e:-condition_e, :-condition_n-condition_e] = bias
+ attention_mask[:-condition_n-condition_e, -condition_n-condition_e:-condition_e] = bias
+
+ bias = torch.log(attn.c_factor[1])
+ attention_mask[-condition_e:, :-condition_n-condition_e] = bias
+ attention_mask[:-condition_n-condition_e, -condition_e:] = bias
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ if condition_latents is not None:
+ if extra_condition_latents is not None:
+ encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]*2
+ ],
+ hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
+ hidden_states[:, -condition_latents.shape[1] :], #extra condition latents
+ )
+ else:
+ encoder_hidden_states, hidden_states, condition_latents = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
+ ],
+ hidden_states[:, -condition_latents.shape[1] :]
+ )
+ else:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if condition_latents is not None:
+ condition_latents = attn.to_out[0](condition_latents)
+ condition_latents = attn.to_out[1](condition_latents)
+
+ if extra_condition_latents is not None:
+ extra_condition_latents = attn.to_out[0](extra_condition_latents)
+ extra_condition_latents = attn.to_out[1](extra_condition_latents)
+
+
+ return (
+ # (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
+ (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
+ if condition_latents is not None
+ else (hidden_states, encoder_hidden_states)
+ )
+ elif condition_latents is not None:
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
+ if extra_condition_latents is not None:
+ hidden_states, condition_latents, extra_condition_latents = (
+ hidden_states[:, : -condition_latents.shape[1]*2],
+ hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
+ hidden_states[:, -condition_latents.shape[1] :],
+ )
+ else:
+ hidden_states, condition_latents = (
+ hidden_states[:, : -condition_latents.shape[1]],
+ hidden_states[:, -condition_latents.shape[1] :],
+ )
+ return hidden_states, condition_latents, extra_condition_latents
+ else:
+ return hidden_states
+
+
+def block_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ condition_latents: torch.FloatTensor,
+ extra_condition_latents: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ cond_temb: torch.FloatTensor,
+ extra_cond_temb: torch.FloatTensor,
+ cond_rotary_emb=None,
+ extra_cond_rotary_emb=None,
+ image_rotary_emb=None,
+ model_config: Optional[Dict[str, Any]] = {},
+):
+ use_cond = condition_latents is not None
+
+ use_extra_cond = extra_condition_latents is not None
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, emb=temb
+ )
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
+ self.norm1_context(encoder_hidden_states, emb=temb)
+ )
+
+ if use_cond:
+ (
+ norm_condition_latents,
+ cond_gate_msa,
+ cond_shift_mlp,
+ cond_scale_mlp,
+ cond_gate_mlp,
+ ) = self.norm1(condition_latents, emb=cond_temb)
+ (
+ norm_extra_condition_latents,
+ extra_cond_gate_msa,
+ extra_cond_shift_mlp,
+ extra_cond_scale_mlp,
+ extra_cond_gate_mlp,
+ ) = self.norm1(extra_condition_latents, emb=extra_cond_temb)
+
+ # Attention.
+ result = attn_forward(
+ self.attn,
+ model_config=model_config,
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ condition_latents=norm_condition_latents if use_cond else None,
+ extra_condition_latents=norm_extra_condition_latents if use_cond else None,
+ image_rotary_emb=image_rotary_emb,
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
+ extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_cond else None,
+ )
+ # print("in self block: ", result.shape)
+ attn_output, context_attn_output = result[:2]
+ cond_attn_output = result[2] if use_cond else None
+ extra_condition_output = result[3]
+
+ # Process attention outputs for the `hidden_states`.
+ # 1. hidden_states
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+ # 2. encoder_hidden_states
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+ # 3. condition_latents
+ if use_cond:
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
+ condition_latents = condition_latents + cond_attn_output
+ #need to make new condition_extra and add extra_condition_output
+ if use_extra_cond:
+ extra_condition_output = extra_cond_gate_msa.unsqueeze(1) * extra_condition_output
+ extra_condition_latents = extra_condition_latents + extra_condition_output
+
+ if model_config.get("add_cond_attn", False):
+ hidden_states += cond_attn_output
+ hidden_states += extra_condition_output
+
+
+ # LayerNorm + MLP.
+ # 1. hidden_states
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = (
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ )
+ # 2. encoder_hidden_states
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = (
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ )
+ # 3. condition_latents
+ if use_cond:
+ norm_condition_latents = self.norm2(condition_latents)
+ norm_condition_latents = (
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
+ + cond_shift_mlp[:, None]
+ )
+
+ if use_extra_cond:
+ #added conditions
+ extra_norm_condition_latents = self.norm2(extra_condition_latents)
+ extra_norm_condition_latents = (
+ extra_norm_condition_latents * (1 + extra_cond_scale_mlp[:, None])
+ + extra_cond_shift_mlp[:, None]
+ )
+
+ # Feed-forward.
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
+ # 1. hidden_states
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ # 2. encoder_hidden_states
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
+ # 3. condition_latents
+ if use_cond:
+ cond_ff_output = self.ff(norm_condition_latents)
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
+
+ if use_extra_cond:
+ extra_cond_ff_output = self.ff(extra_norm_condition_latents)
+ extra_cond_ff_output = extra_cond_gate_mlp.unsqueeze(1) * extra_cond_ff_output
+
+ # Process feed-forward outputs.
+ hidden_states = hidden_states + ff_output
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
+ if use_cond:
+ condition_latents = condition_latents + cond_ff_output
+ if use_extra_cond:
+ extra_condition_latents = extra_condition_latents + extra_cond_ff_output
+
+ # Clip to avoid overflow.
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+    return encoder_hidden_states, hidden_states, condition_latents, (extra_condition_latents if use_cond else None)
+
+
+def single_block_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ image_rotary_emb=None,
+ condition_latents: torch.FloatTensor = None,
+ extra_condition_latents: torch.FloatTensor = None,
+ cond_temb: torch.FloatTensor = None,
+ extra_cond_temb: torch.FloatTensor = None,
+ cond_rotary_emb=None,
+ extra_cond_rotary_emb=None,
+ model_config: Optional[Dict[str, Any]] = {},
+):
+
+ using_cond = condition_latents is not None
+ using_extra_cond = extra_condition_latents is not None
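+    # Note: the paths below assume both condition streams are provided together (both present or both absent).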
+ residual = hidden_states
+ with enable_lora(
+ (
+ self.norm.linear,
+ self.proj_mlp,
+ ),
+ model_config.get("latent_lora", False),
+ ):
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ if using_cond:
+ residual_cond = condition_latents
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
+
+ if using_extra_cond:
+ extra_residual_cond = extra_condition_latents
+ extra_norm_condition_latents, extra_cond_gate = self.norm(extra_condition_latents, emb=extra_cond_temb)
+ extra_mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(extra_norm_condition_latents))
+
+ attn_output = attn_forward(
+ self.attn,
+ model_config=model_config,
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **(
+ {
+ "condition_latents": norm_condition_latents,
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
+ "extra_condition_latents": extra_norm_condition_latents if using_cond else None,
+ "extra_cond_rotary_emb": extra_cond_rotary_emb if using_cond else None,
+ }
+ if using_cond
+ else {}
+ ),
+ )
+
+ if using_cond:
+ attn_output, cond_attn_output, extra_cond_attn_output = attn_output
+
+
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if using_cond:
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
+ cond_gate = cond_gate.unsqueeze(1)
+ condition_latents = cond_gate * self.proj_out(condition_latents)
+ condition_latents = residual_cond + condition_latents
+
+ extra_condition_latents = torch.cat([extra_cond_attn_output, extra_mlp_cond_hidden_states], dim=2)
+ extra_cond_gate = extra_cond_gate.unsqueeze(1)
+ extra_condition_latents = extra_cond_gate * self.proj_out(extra_condition_latents)
+ extra_condition_latents = extra_residual_cond + extra_condition_latents
+
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states if not using_cond else (hidden_states, condition_latents, extra_condition_latents)
diff --git a/flux/condition.py b/flux/condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..77b27c33bc408591f513f5b40077e5ebda746a5f
--- /dev/null
+++ b/flux/condition.py
@@ -0,0 +1,89 @@
+# Recycled from Ominicontrol and modified to accept an extra condition.
+# While Zenctrl pursued a similar idea, it diverged structurally.
+# We appreciate the clarity of Omini's implementation and decided to align with it.
+
+import torch
+from typing import Optional, Union, List, Tuple
+from diffusers.pipelines import FluxPipeline
+from PIL import Image, ImageFilter
+import numpy as np
+import cv2
+
+# from pipeline_tools import encode_images
+from .pipeline_tools import encode_images
+
+condition_dict = {
+ "subject": 1,
+ "sr": 2,
+ "cot": 3,
+}
+
+
+class Condition(object):
+ def __init__(
+ self,
+ condition_type: str,
+ raw_img: Union[Image.Image, torch.Tensor] = None,
+ condition: Union[Image.Image, torch.Tensor] = None,
+ position_delta=None,
+ ) -> None:
+ self.condition_type = condition_type
+ assert raw_img is not None or condition is not None
+ if raw_img is not None:
+ self.condition = self.get_condition(condition_type, raw_img)
+ else:
+ self.condition = condition
+ self.position_delta = position_delta
+
+
+ def get_condition(
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
+ ) -> Union[Image.Image, torch.Tensor]:
+ """
+ Returns the condition image.
+ """
+ if condition_type == "subject":
+ return raw_img
+ elif condition_type == "sr":
+ return raw_img
+ elif condition_type == "cot":
+ return raw_img.convert("RGB")
+ return self.condition
+
+
+ @property
+ def type_id(self) -> int:
+ """
+ Returns the type id of the condition.
+ """
+ return condition_dict[self.condition_type]
+
+ def encode(
+ self, pipe: FluxPipeline, empty: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Encodes the condition into tokens, ids and type_id.
+ """
+ if self.condition_type in [
+ "subject",
+ "sr",
+ "cot"
+ ]:
+ if empty:
+ # make the condition black
+ e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
+ e_condition = e_condition.convert("RGB")
+ tokens, ids = encode_images(pipe, e_condition)
+ else:
+ tokens, ids = encode_images(pipe, self.condition)
+ else:
+ raise NotImplementedError(
+ f"Condition type {self.condition_type} not implemented"
+ )
+ if self.position_delta is None and self.condition_type == "subject":
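+            # Default: shift the subject tokens by the condition's packed width (pixels // 16), i.e. one image-width on the latent grid.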
+ self.position_delta = [0, -self.condition.size[0] // 16]
+ if self.position_delta is not None:
+ ids[:, 1] += self.position_delta[0]
+ ids[:, 2] += self.position_delta[1]
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
+ return tokens, ids, type_id
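+
+
+# Minimal usage sketch (assuming a loaded FluxPipeline `pipe` and a PIL image `img`):
+#   cond = Condition("subject", img, position_delta=(0, img.size[0] // 16))
+#   tokens, ids, type_id = cond.encode(pipe)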
diff --git a/flux/generate.py b/flux/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..567aa2fdc018cd38ef02178915cc3a3187637c87
--- /dev/null
+++ b/flux/generate.py
@@ -0,0 +1,337 @@
+# Recycled from Ominicontrol and modified to accept an extra condition.
+# While Zenctrl pursued a similar idea, it diverged structurally.
+# We appreciate the clarity of Omini's implementation and decided to align with it.
+
+import torch
+import yaml, os
+from diffusers.pipelines import FluxPipeline
+from typing import List, Union, Optional, Dict, Any, Callable
+from .transformer import tranformer_forward
+from .condition import Condition
+
+
+from diffusers.pipelines.flux.pipeline_flux import (
+ FluxPipelineOutput,
+ calculate_shift,
+ retrieve_timesteps,
+ np,
+)
+
+
+def get_config(config_path: str = None):
+ config_path = config_path or os.environ.get("XFL_CONFIG")
+ if not config_path:
+ return {}
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ return config
+
+
+def prepare_params(
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ **kwargs: dict,
+):
+ return (
+ prompt,
+ prompt_2,
+ height,
+ width,
+ num_inference_steps,
+ timesteps,
+ guidance_scale,
+ num_images_per_prompt,
+ generator,
+ latents,
+ prompt_embeds,
+ pooled_prompt_embeds,
+ output_type,
+ return_dict,
+ joint_attention_kwargs,
+ callback_on_step_end,
+ callback_on_step_end_tensor_inputs,
+ max_sequence_length,
+ )
+
+
+def seed_everything(seed: int = 42):
+ torch.backends.cudnn.deterministic = True
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+
+@torch.no_grad()
+def generate(
+ pipeline: FluxPipeline,
+ conditions: List[Condition] = None,
+ config_path: str = None,
+ model_config: Optional[Dict[str, Any]] = {},
+ condition_scale: float = [1, 1],
+ default_lora: bool = False,
+ image_guidance_scale: float = 1.0,
+ **params: dict,
+):
+ model_config = model_config or get_config(config_path).get("model", {})
+ if condition_scale != [1,1]:
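+        # Attach the per-condition scales to every attention module; attn_forward reads `c_factor` to bias attention.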
+ for name, module in pipeline.transformer.named_modules():
+ if not name.endswith(".attn"):
+ continue
+ module.c_factor = torch.tensor(condition_scale)
+
+ self = pipeline
+ (
+ prompt,
+ prompt_2,
+ height,
+ width,
+ num_inference_steps,
+ timesteps,
+ guidance_scale,
+ num_images_per_prompt,
+ generator,
+ latents,
+ prompt_embeds,
+ pooled_prompt_embeds,
+ output_type,
+ return_dict,
+ joint_attention_kwargs,
+ callback_on_step_end,
+ callback_on_step_end_tensor_inputs,
+ max_sequence_length,
+ ) = prepare_params(**params)
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None)
+ if self.joint_attention_kwargs is not None
+ else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 4.1. Prepare conditions
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
+ extra_condition_latents, extra_condition_ids, extra_condition_type_ids = ([] for _ in range(3))
+    use_condition = conditions is not None and len(conditions) > 0
+ if use_condition:
+ if not default_lora:
+ pipeline.set_adapters(conditions[1].condition_type)
+ # for condition in conditions:
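+        # conditions[0] is the primary (subject) stream and conditions[1] the extra stream;
+        # condition_scale = [strength_sub, strength_spat] in app.py weights them in that order.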
+ tokens, ids, type_id = conditions[0].encode(self)
+ condition_latents.append(tokens) # [batch_size, token_n, token_dim]
+ condition_ids.append(ids) # [token_n, id_dim(3)]
+ condition_type_ids.append(type_id) # [token_n, 1]
+ condition_latents = torch.cat(condition_latents, dim=1)
+ condition_ids = torch.cat(condition_ids, dim=0)
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
+
+ tokens, ids, type_id = conditions[1].encode(self)
+ extra_condition_latents.append(tokens) # [batch_size, token_n, token_dim]
+ extra_condition_ids.append(ids) # [token_n, id_dim(3)]
+ extra_condition_type_ids.append(type_id) # [token_n, 1]
+ extra_condition_latents = torch.cat(extra_condition_latents, dim=1)
+ extra_condition_ids = torch.cat(extra_condition_ids, dim=0)
+ extra_condition_type_ids = torch.cat(extra_condition_type_ids, dim=0)
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.tensor([guidance_scale], device=device)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+ noise_pred = tranformer_forward(
+ self.transformer,
+ model_config=model_config,
+ # Inputs of the condition (new feature)
+ condition_latents=condition_latents if use_condition else None,
+ condition_ids=condition_ids if use_condition else None,
+ condition_type_ids=condition_type_ids if use_condition else None,
+ extra_condition_latents=extra_condition_latents if use_condition else None,
+ extra_condition_ids=extra_condition_ids if use_condition else None,
+ extra_condition_type_ids=extra_condition_type_ids if use_condition else None,
+ # Inputs to the original transformer
+ hidden_states=latents,
+                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if image_guidance_scale != 1.0:
+                uncondition_latents = conditions[0].encode(self, empty=True)[0]
+                unc_pred = tranformer_forward(
+                    self.transformer,
+                    model_config=model_config,
+                    # Inputs of the condition (new feature)
+                    condition_latents=uncondition_latents if use_condition else None,
+                    condition_ids=condition_ids if use_condition else None,
+                    condition_type_ids=condition_type_ids if use_condition else None,
+                    # (extra condition kwargs are required by tranformer_forward; passed unchanged here)
+                    extra_condition_latents=extra_condition_latents if use_condition else None,
+                    extra_condition_ids=extra_condition_ids if use_condition else None,
+                    extra_condition_type_ids=extra_condition_type_ids if use_condition else None,
+ # Inputs to the original transformer
+ hidden_states=latents,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+ timestep=timestep / 1000,
+                    guidance=torch.ones_like(guidance) if guidance is not None else None,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (
+ latents / self.vae.config.scaling_factor
+ ) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if condition_scale != [1,1]:
+ for name, module in pipeline.transformer.named_modules():
+ if not name.endswith(".attn"):
+ continue
+ del module.c_factor
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/flux/lora_controller.py b/flux/lora_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..01cd2b256c1907e823b1b0933c8f73be12883785
--- /dev/null
+++ b/flux/lora_controller.py
@@ -0,0 +1,82 @@
+# As-is from OminiControl
+from peft.tuners.tuners_utils import BaseTunerLayer
+from typing import List, Any, Optional, Type
+from .condition import condition_dict
+
+
+class enable_lora:
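+    """Context manager that temporarily zeroes the scaling of active condition/default LoRA adapters (no-op when `activated` is True)."""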
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
+ self.activated: bool = activated
+ if activated:
+ return
+ self.lora_modules: List[BaseTunerLayer] = [
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
+ ]
+ self.scales = [
+ {
+ active_adapter: lora_module.scaling[active_adapter]
+ for active_adapter in lora_module.active_adapters
+ }
+ for lora_module in self.lora_modules
+ ]
+
+ def __enter__(self) -> None:
+ if self.activated:
+ return
+
+ for lora_module in self.lora_modules:
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ for active_adapter in lora_module.active_adapters:
+ if (
+ active_adapter in condition_dict.keys()
+ or active_adapter == "default"
+ ):
+ lora_module.scaling[active_adapter] = 0.0
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[Any],
+ ) -> None:
+ if self.activated:
+ return
+ for i, lora_module in enumerate(self.lora_modules):
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ for active_adapter in lora_module.active_adapters:
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
+
+
+class set_lora_scale:
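+    """Context manager that rescales the given LoRA layers on entry and restores their original per-adapter scaling on exit."""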
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
+ self.lora_modules: List[BaseTunerLayer] = [
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
+ ]
+ self.scales = [
+ {
+ active_adapter: lora_module.scaling[active_adapter]
+ for active_adapter in lora_module.active_adapters
+ }
+ for lora_module in self.lora_modules
+ ]
+ self.scale = scale
+
+ def __enter__(self) -> None:
+ for lora_module in self.lora_modules:
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ lora_module.scale_layer(self.scale)
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[Any],
+ ) -> None:
+ for i, lora_module in enumerate(self.lora_modules):
+ if not isinstance(lora_module, BaseTunerLayer):
+ continue
+ for active_adapter in lora_module.active_adapters:
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
\ No newline at end of file
diff --git a/flux/pipeline_tools.py b/flux/pipeline_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc7d6aec5355ad93135a4aa5ac77ebd653668b3
--- /dev/null
+++ b/flux/pipeline_tools.py
@@ -0,0 +1,53 @@
+# As-is from OminiControl
+from diffusers.pipelines import FluxPipeline
+from diffusers.utils import logging
+from diffusers.pipelines.flux.pipeline_flux import logger
+from torch import Tensor
+
+
+def encode_images(pipeline: FluxPipeline, images: Tensor):
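+    """Encode images into packed Flux latent tokens and their positional ids."""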
+ images = pipeline.image_processor.preprocess(images)
+ images = images.to(pipeline.device).to(pipeline.dtype)
+ images = pipeline.vae.encode(images).latent_dist.sample()
+ images = (
+ images - pipeline.vae.config.shift_factor
+ ) * pipeline.vae.config.scaling_factor
+ images_tokens = pipeline._pack_latents(images, *images.shape)
+ images_ids = pipeline._prepare_latent_image_ids(
+ images.shape[0],
+ images.shape[2],
+ images.shape[3],
+ pipeline.device,
+ pipeline.dtype,
+ )
+ if images_tokens.shape[1] != images_ids.shape[0]:
+ images_ids = pipeline._prepare_latent_image_ids(
+ images.shape[0],
+ images.shape[2] // 2,
+ images.shape[3] // 2,
+ pipeline.device,
+ pipeline.dtype,
+ )
+ return images_tokens, images_ids
+
+
+def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
+ # Turn off warnings (CLIP overflow)
+ logger.setLevel(logging.ERROR)
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = pipeline.encode_prompt(
+ prompt=prompts,
+ prompt_2=None,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ device=pipeline.device,
+ num_images_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ lora_scale=None,
+ )
+ # Turn on warnings
+ logger.setLevel(logging.WARNING)
+ return prompt_embeds, pooled_prompt_embeds, text_ids
diff --git a/flux/transformer.py b/flux/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e0480561f5eee5711361fdf45f39d5a661e4c4
--- /dev/null
+++ b/flux/transformer.py
@@ -0,0 +1,286 @@
+# Recycled from Ominicontrol and modified to accept an extra condition.
+# While Zenctrl pursued a similar idea, it diverged structurally.
+# We appreciate the clarity of Omini's implementation and decided to align with it.
+
+import torch
+from diffusers.pipelines import FluxPipeline
+from typing import List, Union, Optional, Dict, Any, Callable
+from .block import block_forward, single_block_forward
+from .lora_controller import enable_lora
+from accelerate.utils import is_torch_version
+from diffusers.models.transformers.transformer_flux import (
+ FluxTransformer2DModel,
+ Transformer2DModelOutput,
+ USE_PEFT_BACKEND,
+ scale_lora_layers,
+ unscale_lora_layers,
+ logger,
+)
+import numpy as np
+
+
+def prepare_params(
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ **kwargs: dict,
+):
+ return (
+ hidden_states,
+ encoder_hidden_states,
+ pooled_projections,
+ timestep,
+ img_ids,
+ txt_ids,
+ guidance,
+ joint_attention_kwargs,
+ controlnet_block_samples,
+ controlnet_single_block_samples,
+ return_dict,
+ )
+
+
+def tranformer_forward(
+ transformer: FluxTransformer2DModel,
+ condition_latents: torch.Tensor,
+ extra_condition_latents: torch.Tensor,
+ condition_ids: torch.Tensor,
+ condition_type_ids: torch.Tensor,
+ extra_condition_ids: torch.Tensor,
+ extra_condition_type_ids: torch.Tensor,
+ model_config: Optional[Dict[str, Any]] = {},
+ c_t=0,
+ **params: dict,
+):
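+    # Extends FluxTransformer2DModel.forward with two extra packed-token streams
+    # (`condition_latents` and `extra_condition_latents`) that carry their own ids
+    # and are embedded at a fixed timestep c_t instead of the denoising timestep.
+    #
+    # Hedged call sketch (argument names in **params follow the standard Flux
+    # forward signature; the actual call site is expected to be flux/generate.py):
+    #
+    #     tranformer_forward(
+    #         pipe.transformer,
+    #         condition_latents=cond_tokens, extra_condition_latents=extra_tokens,
+    #         condition_ids=cond_ids, condition_type_ids=None,
+    #         extra_condition_ids=extra_ids, extra_condition_type_ids=None,
+    #         model_config=model_config,
+    #         hidden_states=latents, encoder_hidden_states=prompt_embeds,
+    #         pooled_projections=pooled, timestep=timestep,
+    #         img_ids=latent_image_ids, txt_ids=text_ids, guidance=None,
+    #         return_dict=False,
+    #     )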
+ self = transformer
+ use_condition = condition_latents is not None
+ use_extra_condition = extra_condition_latents is not None
+
+ (
+ hidden_states,
+ encoder_hidden_states,
+ pooled_projections,
+ timestep,
+ img_ids,
+ txt_ids,
+ guidance,
+ joint_attention_kwargs,
+ controlnet_block_samples,
+ controlnet_single_block_samples,
+ return_dict,
+ ) = prepare_params(**params)
+
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if (
+ joint_attention_kwargs is not None
+ and joint_attention_kwargs.get("scale", None) is not None
+ ):
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
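+    # When the PEFT backend is active, the scaling applied above is undone by
+    # unscale_lora_layers() just before this function returns.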
+
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
+ hidden_states = self.x_embedder(hidden_states)
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
+ extra_condition_latents = self.x_embedder(extra_condition_latents) if use_extra_condition else None
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+
+    guidance = guidance.to(hidden_states.dtype) * 1000 if guidance is not None else None
+
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+
+ cond_temb = (
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
+ )
+ )
+ extra_cond_temb = (
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
+ )
+ )
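+    # The condition embeddings above use the fixed timestep c_t (0 by default), so
+    # conditions are treated as clean latents regardless of the current denoising
+    # step, with a guidance value of 1.0 when guidance embedding is enabled.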
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
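+    # Each condition stream gets its own rotary embedding from the shared
+    # pos_embed, so its ids can place it on (or offset it from) the image grid.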
+ if use_condition:
+ # condition_ids[:, :1] = condition_type_ids
+ cond_rotary_emb = self.pos_embed(condition_ids)
+
+    if use_extra_condition:
+        extra_cond_rotary_emb = self.pos_embed(extra_condition_ids)
+    else:
+        extra_cond_rotary_emb = None
+
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = (
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ )
+ encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
+ torch.utils.checkpoint.checkpoint(
+ block_forward,
+ self=block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ condition_latents=condition_latents if use_condition else None,
+ extra_condition_latents=extra_condition_latents if use_extra_condition else None,
+ temb=temb,
+ cond_temb=cond_temb if use_condition else None,
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
+ extra_cond_temb=extra_cond_temb if use_extra_condition else None,
+ extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
+ image_rotary_emb=image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ )
+
+ else:
+ encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = block_forward(
+ block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ condition_latents=condition_latents if use_condition else None,
+ extra_condition_latents=extra_condition_latents if use_extra_condition else None,
+ temb=temb,
+ cond_temb=cond_temb if use_condition else None,
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
+                extra_cond_temb=extra_cond_temb if use_extra_condition else None,
+ extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(
+ controlnet_block_samples
+ )
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = (
+ hidden_states
+ + controlnet_block_samples[index_block // interval_control]
+ )
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
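+    # From here on `hidden_states` is the joint [text | image] sequence consumed by
+    # the single-stream blocks; condition tokens stay in separate tensors and are
+    # threaded through each block individually.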
+
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = (
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ )
+ result = torch.utils.checkpoint.checkpoint(
+ single_block_forward,
+ self=block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ **(
+ {
+ "condition_latents": condition_latents,
+ "extra_condition_latents": extra_condition_latents,
+ "cond_temb": cond_temb,
+ "cond_rotary_emb": cond_rotary_emb,
+ "extra_cond_temb": extra_cond_temb,
+ "extra_cond_rotary_emb": extra_cond_rotary_emb,
+ }
+ if use_condition
+ else {}
+ ),
+ **ckpt_kwargs,
+ )
+
+ else:
+ result = single_block_forward(
+ block,
+ model_config=model_config,
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ **(
+ {
+ "condition_latents": condition_latents,
+ "extra_condition_latents": extra_condition_latents,
+ "cond_temb": cond_temb,
+ "cond_rotary_emb": cond_rotary_emb,
+ "extra_cond_temb": extra_cond_temb,
+ "extra_cond_rotary_emb": extra_cond_rotary_emb,
+ }
+ if use_condition
+ else {}
+ ),
+ )
+ if use_condition:
+ hidden_states, condition_latents, extra_condition_latents = result
+ else:
+ hidden_states = result
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(
+ controlnet_single_block_samples
+ )
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
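+    # Keep only the image tokens: the text prefix is dropped before the final
+    # normalization and projection back to packed latent space.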
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/imgs/bg_i1.png b/imgs/bg_i1.png
new file mode 100644
index 0000000000000000000000000000000000000000..ac07eab92116f09f5592e86a260fb16721f7030f
--- /dev/null
+++ b/imgs/bg_i1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8233ff6e5eaaf97f6157599708ed67e62380f3e09820d67bb2ebd472d84165a7
+size 97739
diff --git a/imgs/bg_i2.png b/imgs/bg_i2.png
new file mode 100644
index 0000000000000000000000000000000000000000..5c6919aff4d06d9690fb033c0a246e923c82ff81
--- /dev/null
+++ b/imgs/bg_i2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41ff9fabfb2e31cce35cc97f2c5962165c3b68bce60732e6669692307ec5ebed
+size 196618
diff --git a/imgs/bg_i3.png b/imgs/bg_i3.png
new file mode 100644
index 0000000000000000000000000000000000000000..55dbf1a217a2584c260f7cc5e992e7b11c225581
--- /dev/null
+++ b/imgs/bg_i3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89815263bfd1b72e5be817ddcefc770a728b46bdb7d8264258cae8b5a4270493
+size 279403
diff --git a/imgs/bg_i4.png b/imgs/bg_i4.png
new file mode 100644
index 0000000000000000000000000000000000000000..53813bd918ac8cda07da7b6808f0a9b541e7e071
--- /dev/null
+++ b/imgs/bg_i4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebedef7bdaf6b6e184cad63665f40d2f86dd5a6c54334b400d1ce479c5ec339e
+size 1004461
diff --git a/imgs/bg_i5.png b/imgs/bg_i5.png
new file mode 100644
index 0000000000000000000000000000000000000000..297584180289e5b8e5afe1b42449cade3fcbffa4
--- /dev/null
+++ b/imgs/bg_i5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3065ff3010bdf54d9eed2784a9f998597fb3f10646f0a6ba71e7d2abd9041c2
+size 946929
diff --git a/imgs/bg_o1.png b/imgs/bg_o1.png
new file mode 100644
index 0000000000000000000000000000000000000000..62fabe8ae7a543658962c329535e0d947c11daf2
--- /dev/null
+++ b/imgs/bg_o1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e80909e8b091f02d8d00f8614194f5e6606b076bf86d4fd88e58ffcfeaaa4ed
+size 621107
diff --git a/imgs/bg_o2.png b/imgs/bg_o2.png
new file mode 100644
index 0000000000000000000000000000000000000000..14754b40677c5dca576038331915608b2f351a31
--- /dev/null
+++ b/imgs/bg_o2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e32d9497741ea352725220043e54d1772b1fe4e81910cfaf668cc95ee6e9a23f
+size 944735
diff --git a/imgs/bg_o3.jpg b/imgs/bg_o3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9336cda35a50353c495fe39cc380485309785217
--- /dev/null
+++ b/imgs/bg_o3.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffef49f49bbed31e53e10552a3e6bedc8c492d0b69b3e5284e5f612705f84983
+size 58127
diff --git a/imgs/bg_o4.jpg b/imgs/bg_o4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e23b454771a76a883be493478a7e8f019f49f01b
--- /dev/null
+++ b/imgs/bg_o4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f26c092a93c15a781ae013a4980f2a4b0489277687263e7f6619f5ea454fa645
+size 69668
diff --git a/imgs/bg_o5.jpg b/imgs/bg_o5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b9f3c2a7eda087c80eff9c69f28a53a10b1fd386
--- /dev/null
+++ b/imgs/bg_o5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ebe1a2421d9a36a108d9aca97064d96f1c5016bc6ea378dee637fc843a0357c
+size 37044
diff --git a/imgs/sub_i1.png b/imgs/sub_i1.png
new file mode 100644
index 0000000000000000000000000000000000000000..5c6919aff4d06d9690fb033c0a246e923c82ff81
--- /dev/null
+++ b/imgs/sub_i1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41ff9fabfb2e31cce35cc97f2c5962165c3b68bce60732e6669692307ec5ebed
+size 196618
diff --git a/imgs/sub_i2.png b/imgs/sub_i2.png
new file mode 100644
index 0000000000000000000000000000000000000000..53813bd918ac8cda07da7b6808f0a9b541e7e071
--- /dev/null
+++ b/imgs/sub_i2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebedef7bdaf6b6e184cad63665f40d2f86dd5a6c54334b400d1ce479c5ec339e
+size 1004461
diff --git a/imgs/sub_i3.png b/imgs/sub_i3.png
new file mode 100644
index 0000000000000000000000000000000000000000..297584180289e5b8e5afe1b42449cade3fcbffa4
--- /dev/null
+++ b/imgs/sub_i3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3065ff3010bdf54d9eed2784a9f998597fb3f10646f0a6ba71e7d2abd9041c2
+size 946929
diff --git a/imgs/sub_i4.png b/imgs/sub_i4.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b68514fc4b59a786ce76f5bb3b3ea9097e812a2
--- /dev/null
+++ b/imgs/sub_i4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24deed93849e951f8c7c44de47226240db7c5a42a289f3fdc7c0fa5621f65609
+size 280772
diff --git a/imgs/sub_i5.png b/imgs/sub_i5.png
new file mode 100644
index 0000000000000000000000000000000000000000..8ce13aa317153d416d8dc07589f6054cc3b2bede
--- /dev/null
+++ b/imgs/sub_i5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50b05fff1d2d404da6d6045c36971cd810c2e0d425168cdafa1511f95c5ba269
+size 249692
diff --git a/imgs/sub_o1.webp b/imgs/sub_o1.webp
new file mode 100644
index 0000000000000000000000000000000000000000..fa9b70b11ffbe47b98905845511da54c1facce47
--- /dev/null
+++ b/imgs/sub_o1.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1c9eef884328c26cc58de4f35680b2ca5f211071c2b9bd5254223b0dec0757e
+size 32074
diff --git a/imgs/sub_o2.webp b/imgs/sub_o2.webp
new file mode 100644
index 0000000000000000000000000000000000000000..753136203c30dec326c713a05b57a037e889d095
--- /dev/null
+++ b/imgs/sub_o2.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:026e18cab72811630eb367bf4ae2e38933fb9c31c0146b66990c5f8879b8b71b
+size 27314
diff --git a/imgs/sub_o3.webp b/imgs/sub_o3.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d8ff14507c56d53bff6012a302aec6a97b0c4ad6
--- /dev/null
+++ b/imgs/sub_o3.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0b2abf5d005c747c092c60837b3a80b8d96e5c52625b2d288475799603c8147
+size 28268
diff --git a/imgs/sub_o4.webp b/imgs/sub_o4.webp
new file mode 100644
index 0000000000000000000000000000000000000000..6b8b1bc359fa06aed0efd835e623972fa1362ee7
--- /dev/null
+++ b/imgs/sub_o4.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f77c4b61c9ad96cbf312af18c631f9d558bdcf1f51beb1317bc483bfe700cf18
+size 16900
diff --git a/imgs/sub_o5.webp b/imgs/sub_o5.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d69e7531af9fd2a0f2b50ae32367b13c9e65944a
--- /dev/null
+++ b/imgs/sub_o5.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14ef4c02f4a2970a26a1bd92ce4c937b8aad477a91463fef5a8a34dd61afe3e8
+size 17432
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a716da4a358e116c752571bbd733b2205b0733cb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+torch
+torchvision
+torchaudio
+
+diffusers==0.31.0
+transformers
+accelerate
+huggingface_hub
+sentencepiece
+
+numpy
+pillow>=9.0.0
+einops>=0.7.0
+safetensors>=0.4.0
+opencv-python-headless
+peft==0.15.2
+spaces
\ No newline at end of file
diff --git a/samples/1.png b/samples/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..4fb28f1e76650b6fbe294505e022e256e78773d3
--- /dev/null
+++ b/samples/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee99aba59d42e4c8221785aa032fb5dc3b623478f27a4a9255e42e19bc3308d7
+size 809438
diff --git a/samples/10.png b/samples/10.png
new file mode 100644
index 0000000000000000000000000000000000000000..f680bf69a449f02b44826d3d3be8fd4c7c3d4479
--- /dev/null
+++ b/samples/10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6108e472742ee174e2e37d3a8acbfb14c51bdf0c252139086c0813026fb6b73
+size 404613
diff --git a/samples/11.png b/samples/11.png
new file mode 100644
index 0000000000000000000000000000000000000000..800222f4450b055443719ad5967f572d6fc08fce
--- /dev/null
+++ b/samples/11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8f107773e71b227d9f95692b7af837d41c4b4a39bda48940755db399d619d21
+size 390691
diff --git a/samples/11_out.webp b/samples/11_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..1c99858ce9055111d0dda64bd3339f7e550677ec
--- /dev/null
+++ b/samples/11_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ec7669bc2d1423c69edd87e6a6229d15b76043100ae91d7ff2b07b6197646aa
+size 40050
diff --git a/samples/12.png b/samples/12.png
new file mode 100644
index 0000000000000000000000000000000000000000..20b310eb19ccce7e7592fe3257504dac9d6fc6c7
--- /dev/null
+++ b/samples/12.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ede3b034dcb0a378073cf7d18841b8bf58c20c7f2161239718520dd3c36fe338
+size 493753
diff --git a/samples/12_out.webp b/samples/12_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2ba75294c72aa86955fec72d1f7b078474169969
--- /dev/null
+++ b/samples/12_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c67d211f75e77242af15d259c20a3d527b844b9f5e9c2a199ff2797e0bc459b
+size 30236
diff --git a/samples/13.png b/samples/13.png
new file mode 100644
index 0000000000000000000000000000000000000000..6594ee87b1a38288e825b11e96fd3858019098c2
--- /dev/null
+++ b/samples/13.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6740a2d29fde73c310dc9d95b82e02163ad2fe014ec3aad8bfbfe4aed964948
+size 1132115
diff --git a/samples/14.png b/samples/14.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c48a349fd7465371cf1617ac76fd9e08e4bcfeb
--- /dev/null
+++ b/samples/14.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1180b403e72baed790b92903729f5f07a09f3fdb1ab750330a26b3dca5dfdf22
+size 487379
diff --git a/samples/15.png b/samples/15.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f6250292dc8c7617b6a1ac71bc3ff4cc0f82899
--- /dev/null
+++ b/samples/15.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da6cceb7325b7d4819cf5218768e81fcc486140b5d478c43cc086a7aa3ffed71
+size 959236
diff --git a/samples/16.png b/samples/16.png
new file mode 100644
index 0000000000000000000000000000000000000000..29ffc03b42a5051b85d16108600c24e842e77df7
--- /dev/null
+++ b/samples/16.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7caca84161a104189ff5a7571dbec8342de97238f99280feeda92b2b768a51fb
+size 2224530
diff --git a/samples/17.png b/samples/17.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0293ec170c8cfdafa174353cef4ee8e8e8e8fef
--- /dev/null
+++ b/samples/17.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:15290c3bdcb2bcc575a293a91a1c1e0debdbfcfe2494a8f36458b59403f5fcea
+size 845999
diff --git a/samples/18.png b/samples/18.png
new file mode 100644
index 0000000000000000000000000000000000000000..8d50d442953971fe55ce8c3443a99f31e4d51cf2
--- /dev/null
+++ b/samples/18.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b5771918d8ba7605198be2d670d2cf713b642c29da4a04f069d98cd5e45baa7
+size 676802
diff --git a/samples/19.png b/samples/19.png
new file mode 100644
index 0000000000000000000000000000000000000000..c49cc441a5cd320f44a4f91b80bb1aae4170ca51
--- /dev/null
+++ b/samples/19.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df6a779b5df871fdce101c5a2b13c77c25160b658c4c729f0d4a3ab34ae11996
+size 952727
diff --git a/samples/1_out.webp b/samples/1_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b159edebbc6363374eb4c385a01393abf783764e
--- /dev/null
+++ b/samples/1_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ff52fb3f22b7413cfaf090bd65f93be38491d4b250c668e67a33b4cc6c88a8a
+size 40780
diff --git a/samples/2.png b/samples/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..424bc505f1bc57505ee17615f65e527d84b2082f
--- /dev/null
+++ b/samples/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a12bcdfac8d901c8d0947213391541dedff0e49964174b9f436726db21a3ad6
+size 899506
diff --git a/samples/20.png b/samples/20.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c986d504ff5913022f2a26ccf0cb2cc0d78765a
--- /dev/null
+++ b/samples/20.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afa037f9006cdbdf0707db41b88ae6e8109c196f5ceb5f37523a8364b467d534
+size 192685
diff --git a/samples/20_out.webp b/samples/20_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..6f429847fb80d2bbf7b9d66916a2b84e5966a127
--- /dev/null
+++ b/samples/20_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45e65cbbadd94d062f4d56a0204dc37d7e64a8e3dcc05b78dbc54cc6d65b8cb1
+size 36774
diff --git a/samples/21.png b/samples/21.png
new file mode 100644
index 0000000000000000000000000000000000000000..246828667373ddfbd3784afcbf70342cd8577d12
--- /dev/null
+++ b/samples/21.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ecd7d68258b5d342ff9a316c9dfddefda11bc9558e8b582a4bb01977f3da87e5
+size 302983
diff --git a/samples/21_1.png b/samples/21_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e8abb7609276bae3f9af83b2f5a5a04fad21360a
--- /dev/null
+++ b/samples/21_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf2f5448e75bb285b32265c779b5844ae9c182e53264ecc4950eabca2b563545
+size 4971528
diff --git a/samples/21_out.webp b/samples/21_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..fbde820786401d1137876ff12af4ce4994d17c80
--- /dev/null
+++ b/samples/21_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54e64f200816a238acbabe0095ef4541f2b1e1d8adde1a331a3ce7328a14a6b2
+size 40026
diff --git a/samples/22.png b/samples/22.png
new file mode 100644
index 0000000000000000000000000000000000000000..8c333f6d216f963bc5f2962b59d3a52710dd8c5a
--- /dev/null
+++ b/samples/22.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00fd1d1ea48dc25ca4855cf1d6099141cb1375bfd527fa2fab08b0d07208f815
+size 235874
diff --git a/samples/22_out.webp b/samples/22_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c3e386ae8ecb67e87f5a035ebfb6afa461b29fe0
--- /dev/null
+++ b/samples/22_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02d87428b86f2923f7207d130f812bb36b4dc67f76b4ea495c20324eeb791021
+size 22016
diff --git a/samples/23.png b/samples/23.png
new file mode 100644
index 0000000000000000000000000000000000000000..97b94e9e607526d30000add0db7eba7d90fb4c36
--- /dev/null
+++ b/samples/23.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42b9683272e6bbd92a796ccdfb4caca5b6f317ae9abfa4f345c5003a735ddc9e
+size 229596
diff --git a/samples/24.png b/samples/24.png
new file mode 100644
index 0000000000000000000000000000000000000000..615f14b52a0fbb88bd03658af3db16723827f3be
--- /dev/null
+++ b/samples/24.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e498022e215bb0d33b0179d192c05e473ee91f7ebfe6d424317c140cf92bf16
+size 324358
diff --git a/samples/25.png b/samples/25.png
new file mode 100644
index 0000000000000000000000000000000000000000..f140155c862979e03637301ed9ab86854e4454d0
--- /dev/null
+++ b/samples/25.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ee43af3e149026a1a36c30d64b5ef488bd8a7399b0a5a4f0ea0fd1fb9fe60fb
+size 222029
diff --git a/samples/26.png b/samples/26.png
new file mode 100644
index 0000000000000000000000000000000000000000..f60e3d2d8a8682f7959336a46db81d49c2a43e42
--- /dev/null
+++ b/samples/26.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89e771ba69a141a3f94821be4434edbdf440aa33ec293bfc85285131e95fe3f2
+size 459854
diff --git a/samples/2_out.webp b/samples/2_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c1fb2b689c7b2ac6d012d8388b14deb624d2f40a
--- /dev/null
+++ b/samples/2_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f64c0edefd4f15d5d88d55c0776fa95c9db746b5b13c05669211f4e342dff9ab
+size 85752
diff --git a/samples/3.png b/samples/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c05c53ad0bfce0ad48ffd2a6009c93005b2d5df
--- /dev/null
+++ b/samples/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b87e33444befcd05e121183603c47b3666c1fd2f5656741dfb1d3bb124835702
+size 470075
diff --git a/samples/3_out.webp b/samples/3_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..19299821fe13125519c7885071e790fd302255d9
--- /dev/null
+++ b/samples/3_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b8f8b8c43439c6c26d4847a6ac4c77a8e4832695cc4aa3ea86f15dd90028bca
+size 42180
diff --git a/samples/4.png b/samples/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..2da7fb284ec99b6b92b1f072c6f46fd4837965ce
--- /dev/null
+++ b/samples/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29689ad98946560acfd07c5be9e06c27c94a19e94dcdd73a388f3dbc46b26259
+size 359441
diff --git a/samples/4_out.webp b/samples/4_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c86ed9611a883357c4baa5fba4dbd6779db6a751
--- /dev/null
+++ b/samples/4_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11fd666873415763a916edd08a9e19ca534162cd73217bffa0dcfcc75d3e13a4
+size 35576
diff --git a/samples/5.png b/samples/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..8b183c75bd540c5a7fa05b7eba8a56657a0d078d
--- /dev/null
+++ b/samples/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b0def9aa8c45a63bf1ee00e02765ad357456b355fbea66e405f53662bf86db7
+size 268539
diff --git a/samples/5_out.webp b/samples/5_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d9406c69f90fcb57980d6828b31255f699348831
--- /dev/null
+++ b/samples/5_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f625e21ab93e5f6b90954683fcd18c0c2c1fd8465c9b4cf67182dd6d1941d0e2
+size 60376
diff --git a/samples/6.png b/samples/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..86cfb99b4437b3e9c9996b68c0ab474634e09816
--- /dev/null
+++ b/samples/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85a0ba550ba4c14c29edb02e3ce12d77b50004489035d0c914d0327197c2f55c
+size 352262
diff --git a/samples/6_out.webp b/samples/6_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..f9254ea1d1b5194994ccbf9d7728e30c46bdb828
--- /dev/null
+++ b/samples/6_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9616e390eeb36ba16cd881bf65869b5bfd72305a3bdabb94cded20ce30bcafa
+size 26108
diff --git a/samples/7.png b/samples/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..2a08a68b64e248d0dcc5d71fb418cb1a9261b4b1
--- /dev/null
+++ b/samples/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a8a55d994069cdd39da729c64713a40b24c1d7fe5b6999ef3040a4c36c69a83
+size 791285
diff --git a/samples/7_out.webp b/samples/7_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d8ba03a98ef51df7e4b0925cf17a9d8fe94ea71d
--- /dev/null
+++ b/samples/7_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d2defc4bac2563da5f3486b194c03dbd31688e71d3677c61c965cc211c5ef42
+size 63218
diff --git a/samples/8.png b/samples/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb7046e336f6086f3719fc18f1f9a500f33f8ce5
--- /dev/null
+++ b/samples/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f7047eab7c212fa34f908ca6b0184c4b056144d44be10fe9e6afbac365fd7c9
+size 483145
diff --git a/samples/8_out.webp b/samples/8_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d27f06448d417399d57d90170e2739016f4ea9b0
--- /dev/null
+++ b/samples/8_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3254176ae81676d0d199ba82d4d6e194a4f4a18820b2441992b8a96caf074f2
+size 35850
diff --git a/samples/9.png b/samples/9.png
new file mode 100644
index 0000000000000000000000000000000000000000..9bb0f7bbca111a0394fcaf56ea993d5a11733463
--- /dev/null
+++ b/samples/9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b68eb987d1d23f178f4e7d385e32d14506a26d0031ab63c079247d495a0c19b5
+size 743147
diff --git a/samples/9_out.webp b/samples/9_out.webp
new file mode 100644
index 0000000000000000000000000000000000000000..9d9c819e50e6a0354d18db1adadd02c6f1b514f7
--- /dev/null
+++ b/samples/9_out.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05aed5335b5f3c4b0d1214a5d42ad958e4ddbdd1886344798b862ab0dd382d4d
+size 43570