diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..056072503827445de516bb91226150edf90d49fe 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 34e9de4e6e8e9ac57f099e3ede376625fa49e6cf..833544f68e941bd9a5c761c620fbb9b150621849 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ sdk_version: 5.23.1
 app_file: app.py
 pinned: false
 license: apache-2.0
+short_description: A subject-driven image generation control toolkit
 ---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app.py b/app.py
index 5c89a9143cd5c1cffd768e81c76cd3cec86a80ae..71614c6e03ef01876a06211dd7769f9bffe7d584 100644
--- a/app.py
+++ b/app.py
@@ -1,89 +1,106 @@
-import os
-import base64
-import io
-from typing import TypedDict
-import requests
+import spaces
 import gradio as gr
-from PIL import Image
+import torch
+from typing import TypedDict
+from PIL import Image, ImageDraw, ImageFont
+from diffusers.pipelines import FluxPipeline
+from diffusers import FluxTransformer2DModel
+import numpy as np
+import examples_db
-# Read Baseten configuration from environment variables.
-BTEN_API_KEY = os.getenv("API_KEY")
-URL = os.getenv("URL")
-def image_to_base64(image: Image.Image) -> str:
-    """Convert a PIL image to a base64-encoded PNG string."""
-    with io.BytesIO() as buffer:
-        image.save(buffer, format="PNG")
-        return base64.b64encode(buffer.getvalue()).decode("utf-8")
+from flux.condition import Condition
+from flux.generate import seed_everything, generate
+from flux.lora_controller import set_lora_scale
-def ensure_image(img) -> Image.Image:
-    """
-    Ensure the input is a PIL Image.
-    If it's already a PIL Image, return it.
-    If it's a string (file path), open it.
-    If it's a dict with a "name" key, open the file at that path.
- """ - if isinstance(img, Image.Image): - return img - elif isinstance(img, str): - return Image.open(img) - elif isinstance(img, dict) and "name" in img: - return Image.open(img["name"]) +pipe = None +current_adapter = None +use_int8 = False +model_config = { "union_cond_attn": True, "add_cond_attn": False, "latent_lora": False, "independent_condition": True} + +def get_gpu_memory(): + return torch.cuda.get_device_properties(0).total_memory / 1024**3 + + +def init_pipeline(): + global pipe + if use_int8 or get_gpu_memory() < 33: + transformer_model = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-schell-int8wo-improved", + torch_dtype=torch.bfloat16, + use_safetensors=False, + ) + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + transformer=transformer_model, + torch_dtype=torch.bfloat16, + ) else: - raise ValueError("Cannot convert input to a PIL Image.") - - -def call_baseten_generate( - image: Image.Image, - prompt: str, - steps: int, - strength: float, - height: int, - width: int, - lora_name: str, - remove_bg: bool, -) -> Image.Image | None: + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 + ) + pipe = pipe.to("cuda") + + + # Optional: Load additional LoRA weights + pipe.load_lora_weights( + "fotographerai/zenctrl_tools", + weight_name="weights/zen2con_1024_10000/" + "pytorch_lora_weights.safetensors", + adapter_name="subject" + ) + + # Optional: Load additional LoRA weights + #pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism") + + +def paste_on_white_background(image: Image.Image) -> Image.Image: """ - Call the Baseten /predict endpoint with provided parameters and return the generated image. + Pastes a transparent image onto a white background of the same size. """ - image = ensure_image(image) - b64_image = image_to_base64(image) - payload = { - "image": b64_image, - "prompt": prompt, - "steps": steps, - "strength": strength, - "height": height, - "width": width, - "lora_name": lora_name, - "bgrm": remove_bg, - } - if not BTEN_API_KEY: - headers = {"Authorization": f"Api-Key {os.getenv('API_KEY')}"} - else: - headers = {"Authorization": f"Api-Key {BTEN_API_KEY}"} - try: - if not URL: - raise ValueError("The URL environment variable is not set.") - - response = requests.post(URL, headers=headers, json=payload) - if response.status_code == 200: - data = response.json() - gen_b64 = data.get("generated_image", None) - if gen_b64: - return Image.open(io.BytesIO(base64.b64decode(gen_b64))) - else: - return None - else: - print(f"Error: HTTP {response.status_code}\n{response.text}") - return None - except Exception as e: - print(f"Error: {e}") - return None - - -# Mode defaults for each tab. 
+ if image.mode != "RGBA": + image = image.convert("RGBA") + + # Create white background + white_bg = Image.new("RGBA", image.size, (255, 255, 255, 255)) + white_bg.paste(image, (0, 0), mask=image) + return white_bg.convert("RGB") # Convert back to RGB if you don't need alpha + +#@spaces.GPU +def process_image_and_text(image, text, steps=8, strength_sub=1.0, strength_spat=1.0, size=1024): + # center crop image + w, h, min_size = image.size[0], image.size[1], min(image.size) + image = image.crop( + ( + (w - min_size) // 2, + (h - min_size) // 2, + (w + min_size) // 2, + (h + min_size) // 2, + ) + ) + image = image.resize((size, size)) + image = paste_on_white_background(image) + condition0 = Condition("subject", image, position_delta=(0, size // 16)) + condition1 = Condition("subject", image, position_delta=(0, -size // 16)) + + pipe = get_pipeline() + + with set_lora_scale(["subject"], scale=3.0): + result_img = generate( + pipe, + prompt=text.strip(), + conditions=[condition0, condition1], + num_inference_steps=steps, + height=1024, + width=1024, + condition_scale = [strength_sub,strength_spat], + model_config=model_config, + ).images[0] + + return result_img + +# ================== MODE CONFIG ===================== Mode = TypedDict( "Mode", @@ -98,77 +115,76 @@ Mode = TypedDict( }, ) +MODEL_TO_LORA: dict[str, str] = { + # dropdown-value # relative path inside the HF repo + "zen2con_1024_10000": "weights/zen2con_1024_10000/pytorch_lora_weights.safetensors", + "zen2con_1440_17000": "weights/zen2con_1440_17000/pytorch_lora_weights.safetensors", + "zen_sub_sub_1024_10000": "weights/zen_sub_sub_1024_10000/pytorch_lora_weights.safetensors", + "zen_toys_1024_4000": "weights/zen_toys_1024_4000/12000/pytorch_lora_weights.safetensors", + "zen_toys_1024_15000": "weights/zen_toys_1024_4000/zen_toys_1024_15000/pytorch_lora_weights.safetensors", + # add more as you upload them +} + + MODE_DEFAULTS: dict[str, Mode] = { "Subject Generation": { - "model": "subject_99000_512", - "prompt": "A detailed portrait with soft lighting", - "default_strength": 1.2, - "default_height": 512, - "default_width": 512, - "models": [ - "zendsd_512_146000", - "subject_99000_512", - # "zen_pers_11000", - "zen_26000_512", - ], - "remove_bg": True, - }, - "Background Generation": { - "model": "bg_canny_58000_1024", + "model": "zen2con_1024_10000", "prompt": "A vibrant background with dynamic lighting and textures", "default_strength": 1.2, "default_height": 1024, "default_width": 1024, - "models": [ - "bgwlight_15000_1024", - # "rmgb_12000_1024", - "bg_canny_58000_1024", - # "gen_back_3000_1024", - "gen_back_7000_1024", - # "gen_bckgnd_18000_512", - # "gen_bckgnd_18000_512", - # "loose_25000_512", - # "looser_23000_1024", - # "looser_bg_gen_21000_1280", - # "old_looser_46000_1024", - # "relight_bg_gen_31000_1024", - ], - "remove_bg": True, - }, - "Canny": { - "model": "canny_21000_1024", - "prompt": "A futuristic cityscape with neon lights", - "default_strength": 1.2, - "default_height": 1024, - "default_width": 1024, - "models": ["canny_21000_1024"], + "models": list(MODEL_TO_LORA.keys()), "remove_bg": True, }, - "Depth": { - "model": "depth_9800_1024", - "prompt": "A scene with pronounced depth and perspective", - "default_strength": 1.2, - "default_height": 1024, - "default_width": 1024, - "models": [ - "depth_9800_1024", - ], - "remove_bg": True, - }, - "Deblurring": { - "model": "deblurr_1024_10000", - "prompt": "A scene with pronounced depth and perspective", - "default_strength": 1.2, - "default_height": 1024, - 
"default_width": 1024, - "models": ["deblurr_1024_10000"], # "slight_deblurr_18000", - "remove_bg": False, - }, -} + #"Image fix": { + # "model": "zen_toys_1024_4000", + # "prompt": "A detailed portrait with soft lighting", + # "default_strength": 1.2, + # "default_height": 1024, + # "default_width": 1024, + # "models": ["weights/zen_toys_1024_4000/12000/", "weights/zen_toys_1024_4000/12000/"], + # "remove_bg": True, + #} + } + + +def get_pipeline(): + """Lazy-build the pipeline inside the GPU worker.""" + global pipe + if pipe is None: + init_pipeline() # safe here – this fn is @spaces.GPU wrapped + return pipe + +def get_samples(): + sample_list = [ + { + "image": "samples/1.png", + "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'", + }, + { + "image": "samples/2.png", + "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'", + }, + { + "image": "samples/3.png", + "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.", + }, + { + "image": "samples/4.png", + "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.", + }, + { + "image": "samples/5.png", + "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.", + }, + ] + return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list] + +# =============== UI =============== header = """ -

-        🌍 ZenCtrl / FLUX
+        🌍 ZenCtrl medium
         GitHub Repo   HuggingFace Space
@@ -178,136 +194,110 @@ header = """
""" -defaults = MODE_DEFAULTS["Subject Generation"] - - -with gr.Blocks(title="🌍 ZenCtrl") as demo: +with gr.Blocks(title="🌍 ZenCtrl-medium") as demo: + # ---------- banner ---------- gr.HTML(header) gr.Markdown( """ # ZenCtrl Demo - [WIP] One Agent to Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject imageβ€”without fine-tuning. + One framework to Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject imageβ€”without fine-tuning. We are first releasing some of the task specific weights and will release the codes soon. The goal is to unify all of the visual content generation tasks with a single LLM... - **Modes:** - - **Subject Generation:** Focuses on generating detailed subject portraits. - - **Background Generation:** Creates dynamic, vibrant backgrounds: - You can generate part of the image from sketch while keeping part of it as it is. - - **Canny:** Emphasizes strong edge detection. - - **Depth:** Produces images with realistic depth and perspective. + **Mode:** + - **Subject-driven Image Generation:** Generate in-context images of your subject with high fidelity and in different perspectives. For more details, shoot us a message on discord. """ ) + + # ---------- tab bar ---------- with gr.Tabs(): - for mode in MODE_DEFAULTS: - with gr.Tab(mode): - defaults = MODE_DEFAULTS[mode] - gr.Markdown(f"### {mode} Mode") - gr.Markdown(f"**Default Model:** {defaults['model']}") + for mode_name, defaults in MODE_DEFAULTS.items(): + with gr.Tab(mode_name): + gr.Markdown(f"### {mode_name}") + # -------- left (input) column -------- with gr.Row(): - with gr.Column(scale=2, min_width=370): - input_image = gr.Image( - label="Upload Image", - type="pil", - scale=3, - height=370, - min_width=100, + with gr.Column(scale=2): + input_image = gr.Image(label="Input Image", type="pil") + model_dropdown = gr.Dropdown( + label="Model (LoRA adapter)", + choices=defaults["models"], + value=defaults["model"], + interactive=True, ) - generate_button = gr.Button("Generate") - with gr.Blocks(title="Options"): - model_dropdown = gr.Dropdown( - label="Model", - choices=defaults["models"], - value=defaults["model"], - interactive=True, - ) - remove_bg_checkbox = gr.Checkbox( - label="Remove Background", value=defaults["remove_bg"] - ) + prompt_box = gr.Textbox(label="Prompt", + value=defaults["prompt"], lines=2) + generate_btn = gr.Button("Generate") + + with gr.Accordion("Generation Parameters", open=False): + step_slider = gr.Slider(2, 28, value=12, step=2, label="Steps") + strength_sub_slider = gr.Slider(0.0, 2.0, + value=defaults["default_strength"], + step=0.1, label="Strength (subject)") + strength_spat_slider = gr.Slider(0.0, 2.0, + value=defaults["default_strength"], + step=0.1, label="Strength (spatial)") + size_slider = gr.Slider(512, 2048, + value=defaults["default_height"], + step=64, label="Size (px)") + # -------- right (output) column -------- with gr.Column(scale=2): - output_image = gr.Image( - label="Generated Image", - type="pil", - height=573, - scale=4, - min_width=100, - ) + output_image = gr.Image(label="Output Image", type="pil") - gr.Markdown("#### Prompt") - prompt_box = gr.Textbox( - label="Prompt", value=defaults["prompt"], lines=2 - ) + # ---------- click handler ---------- + @spaces.GPU + def _run(image, model_name, prompt, steps, s_sub, s_spat, size): + global current_adapter - # Wrap generation parameters in an Accordion for collapsible view. 
- with gr.Accordion("Generation Parameters", open=False): - with gr.Row(): - step_slider = gr.Slider( - minimum=2, maximum=28, value=2, step=2, label="Steps" - ) - strength_slider = gr.Slider( - minimum=0.5, - maximum=2.0, - value=defaults["default_strength"], - step=0.1, - label="Strength", - ) - with gr.Row(): - height_slider = gr.Slider( - minimum=512, - maximum=1360, - value=defaults["default_height"], - step=1, - label="Height", - ) - width_slider = gr.Slider( - minimum=512, - maximum=1360, - value=defaults["default_width"], - step=1, - label="Width", + pipe = get_pipeline() + + # ── switch adapter if needed ────────────────────────── + if model_name != current_adapter: + lora_path = MODEL_TO_LORA[model_name] + # load & activate the chosen adapter + pipe.load_lora_weights( + "fotographerai/zenctrl_tools", + weight_name=lora_path, + adapter_name=model_name, ) + pipe.set_adapters([model_name]) + current_adapter = model_name - def on_generate_click( - model_name, - prompt, - steps, - strength, - height, - width, - remove_bg, - image, - ): - return call_baseten_generate( - image, - prompt, - steps, - strength, - height, - width, - model_name, - remove_bg, + # ── run generation ─────────────────────────────────── + delta = size // 16 + return process_image_and_text( + image, prompt, steps=steps, + strength_sub=s_sub, strength_spat=s_spat, size=size ) - generate_button.click( - fn=on_generate_click, - inputs=[ - model_dropdown, - prompt_box, - step_slider, - strength_slider, - height_slider, - width_slider, - remove_bg_checkbox, - input_image, - ], + generate_btn.click( + fn=_run, + inputs=[input_image, model_dropdown, prompt_box, + step_slider, strength_sub_slider, + strength_spat_slider, size_slider], outputs=[output_image], - concurrency_limit=None ) + # ---------------- Templates -------------------- + if examples_db.MODE_EXAMPLES.get(mode_name): + gr.Examples( + examples=examples_db.MODE_EXAMPLES[mode_name], + inputs=[ input_image, # Image widget + model_dropdown, # Dropdown for adapter + prompt_box, # Textbox for prompt + output_image, # Gallery for output + ], + label="Presets (Image / Model / Prompt)", + examples_per_page=15, + ) +# =============== launch =============== if __name__ == "__main__": - demo.launch() \ No newline at end of file + #init_pipeline() + demo.launch( + debug=True, + share=True + ) \ No newline at end of file diff --git a/examples_db.py b/examples_db.py new file mode 100644 index 0000000000000000000000000000000000000000..071f98ed27429df85d64ecc834abc73a0b745f83 --- /dev/null +++ b/examples_db.py @@ -0,0 +1,96 @@ +MODE_EXAMPLES = { + "Subject Generation": [ + [ + "samples/7.png", + "zen2con_1440_17000", + "A man wearing white shoes stepping outside, closeup on the shoes", + "samples/22_out.webp", + ], + [ + "samples/7.png", + "zen2con_1440_17000", + "Low angle photography, shoes, stepping in water in a futuristic cityscape with neon lights, in the back in large, the words 'ZenCtrl 2025' are written in a futuristic font", + "samples/2_out.webp", + ], + [ + "samples/8.png", + "zen2con_1024_10000", + "a watch, resting on wet volcanic rock, ocean spray mist, golden sunrise back-light, 50 mm lens, shallow DOF, 8k realism", + "samples/3_out.webp", + ], + ["samples/9.png","zen_sub_sub_1024_10000", "a bag, placed on a desk in a cozy bedroom, sunny light coming through the window", "samples/4_out.webp"], + ["samples/10.png", "zen2con_1440_17000","A bottle sits poolside at a luxury hotel. 
The bottle rests in front of a clear blue pool and resort landscape in the background. Light reflects off the bottle highlighting its sophisticated design. An elegant glass and tropical fruits are included creating a relaxing atmosphere. The image should convey luxury sophistication and an otherworldly refreshment.", "samples/5_out.webp"], + + #[ + # "samples/11.png", + # "zen_toys_1024_4000", + # "Low angle photography, shoes, stepping in water in a futuristic cityscape with neon lights, in the back in large, the words 'ZenCtrl 2025' are written in a futuristic font", + # "samples/1.png", + #], + [ + "samples/12.png", + "zen2con_1024_10000", + "a woman , wearing a pair of sunglasses, in front of the beach", + "samples/6_out.webp", + ], + [ + "samples/13.png", + "zen2con_1440_17000", + "a woman , standing outside , in the streets, next to a cafe", + "samples/7_out.webp", + ], + ["samples/14.png","zen_sub_sub_1024_10000", "a bag , held by woman , walking outside", "samples/8_out.webp"], + ["samples/15.png","zen_sub_sub_1024_10000", "a man wearing a suit, sitting on a chair in a living room", "samples/9_out.webp"], + [ + "samples/6.png", + "zen_toys_1024_4000", + "A kid playing with a toy figurine , indoor, on a sunny day", + "samples/11_out.webp", + ], + [ + "samples/6.png", + "zen_toys_1024_4000", + "A small figurine placed on a table, surrounded by toys. A child is placing it", + "samples/12_out.webp", + ], + [ + "samples/17.png", + "zen2con_1024_10000", + "A black man wearing black wireless headphones at a basketball game", + "samples/20_out.webp", + ], + [ + "samples/21_1.png", + "zen2con_1024_10000", + "a man holding a camera facing the objective.", + "samples/21_out.webp", + ], + ], + "Image fix": [ + [ + "samples/1.png", + "placed on a dark marble table in a bathroom of luxury hotel modern light authentic atmosphere", + "samples/1.png", + ], + [ + "samples/1.png", + "sitting on the middle of the city road on a sunny day very bright day front view", + "samples/1.png", + ], + [ + "samples/1.png", + "A creative capture in an art gallery, with soft, focused lighting highlighting both the person’s features and the abstract surroundings, exuding sophistication.", + "samples/1.png", + ], + [ + "samples/1.png", + "In a rain-soaked urban nightscape, with headlights piercing through the mist and wet streets reflecting the city’s vibrant neon colors, creating an atmosphere of mystery and modern elegance.", + "samples/1.png", + ], + [ + "samples/1.png", + "An elegant room scene featuring a minimalist table and chairs, next to a flower vase, illuminated by ambient lighting that casts gentle shadows and enhances the refined, contemporary decor.", + "samples/1.png", + ], + ], +} diff --git a/flux/__init__.py b/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flux/block.py b/flux/block.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b1c5c7eb72fe387064e62a264c1067b18b5541 --- /dev/null +++ b/flux/block.py @@ -0,0 +1,461 @@ +# Recycled from Ominicontrol and modified to accept an extra condition. +# While Zenctrl pursued a similar idea, it diverged structurally. +# We appreciate the clarity of Omini's implementation and decided to align with it. 
+ +import torch +from typing import List, Union, Optional, Dict, Any, Callable +from diffusers.models.attention_processor import Attention, F +from .lora_controller import enable_lora +from diffusers.models.embeddings import apply_rotary_emb + +def attn_forward( + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + condition_latents: torch.FloatTensor = None, + extra_condition_latents: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + cond_rotary_emb: Optional[torch.Tensor] = None, + extra_cond_rotary_emb: Optional[torch.Tensor] = None, + model_config: Optional[Dict[str, Any]] = {}, +) -> torch.FloatTensor: + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + with enable_lora( + (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False) + ): + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q( + encoder_hidden_states_query_proj + ) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k( + encoder_hidden_states_key_proj + ) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if condition_latents is not None: + cond_query = attn.to_q(condition_latents) + cond_key = attn.to_k(condition_latents) + cond_value = attn.to_v(condition_latents) + + cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + if attn.norm_q is not None: + cond_query = attn.norm_q(cond_query) + if attn.norm_k is not None: + cond_key = attn.norm_k(cond_key) + + #extra condition + if extra_condition_latents is not None: + 
extra_cond_query = attn.to_q(extra_condition_latents) + extra_cond_key = attn.to_k(extra_condition_latents) + extra_cond_value = attn.to_v(extra_condition_latents) + + extra_cond_query = extra_cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + extra_cond_key = extra_cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + extra_cond_value = extra_cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( + 1, 2 + ) + if attn.norm_q is not None: + extra_cond_query = attn.norm_q(extra_cond_query) + if attn.norm_k is not None: + extra_cond_key = attn.norm_k(extra_cond_key) + + + if extra_cond_rotary_emb is not None: + extra_cond_query = apply_rotary_emb(extra_cond_query, extra_cond_rotary_emb) + extra_cond_key = apply_rotary_emb(extra_cond_key, extra_cond_rotary_emb) + + if cond_rotary_emb is not None: + cond_query = apply_rotary_emb(cond_query, cond_rotary_emb) + cond_key = apply_rotary_emb(cond_key, cond_rotary_emb) + + if condition_latents is not None: + if extra_condition_latents is not None: + + query = torch.cat([query, cond_query, extra_cond_query], dim=2) + key = torch.cat([key, cond_key, extra_cond_key], dim=2) + value = torch.cat([value, cond_value, extra_cond_value], dim=2) + else: + query = torch.cat([query, cond_query], dim=2) + key = torch.cat([key, cond_key], dim=2) + value = torch.cat([value, cond_value], dim=2) + print("concat Omini latents: ", query.shape, key.shape, value.shape) + + + if not model_config.get("union_cond_attn", True): + + attention_mask = torch.ones( + query.shape[2], key.shape[2], device=query.device, dtype=torch.bool + ) + condition_n = cond_query.shape[2] + attention_mask[-condition_n:, :-condition_n] = False + attention_mask[:-condition_n, -condition_n:] = False + elif model_config.get("independent_condition", False): + attention_mask = torch.ones( + query.shape[2], key.shape[2], device=query.device, dtype=torch.bool + ) + condition_n = cond_query.shape[2] + attention_mask[-condition_n:, :-condition_n] = False + + if hasattr(attn, "c_factor"): + attention_mask = torch.zeros( + query.shape[2], key.shape[2], device=query.device, dtype=query.dtype + ) + condition_n = cond_query.shape[2] + condition_e = extra_cond_query.shape[2] + bias = torch.log(attn.c_factor[0]) + attention_mask[-condition_n-condition_e:-condition_e, :-condition_n-condition_e] = bias + attention_mask[:-condition_n-condition_e, -condition_n-condition_e:-condition_e] = bias + + bias = torch.log(attn.c_factor[1]) + attention_mask[-condition_e:, :-condition_n-condition_e] = bias + attention_mask[:-condition_n-condition_e, -condition_e:] = bias + + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + if condition_latents is not None: + if extra_condition_latents is not None: + encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[ + :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]*2 + ], + hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]], + hidden_states[:, -condition_latents.shape[1] :], #extra condition latents + ) + else: + encoder_hidden_states, hidden_states, condition_latents = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + 
hidden_states[ + :, encoder_hidden_states.shape[1] : -condition_latents.shape[1] + ], + hidden_states[:, -condition_latents.shape[1] :] + ) + else: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)): + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if condition_latents is not None: + condition_latents = attn.to_out[0](condition_latents) + condition_latents = attn.to_out[1](condition_latents) + + if extra_condition_latents is not None: + extra_condition_latents = attn.to_out[0](extra_condition_latents) + extra_condition_latents = attn.to_out[1](extra_condition_latents) + + + return ( + # (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents) + (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents) + if condition_latents is not None + else (hidden_states, encoder_hidden_states) + ) + elif condition_latents is not None: + # if there are condition_latents, we need to separate the hidden_states and the condition_latents + if extra_condition_latents is not None: + hidden_states, condition_latents, extra_condition_latents = ( + hidden_states[:, : -condition_latents.shape[1]*2], + hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]], + hidden_states[:, -condition_latents.shape[1] :], + ) + else: + hidden_states, condition_latents = ( + hidden_states[:, : -condition_latents.shape[1]], + hidden_states[:, -condition_latents.shape[1] :], + ) + return hidden_states, condition_latents, extra_condition_latents + else: + return hidden_states + + +def block_forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + condition_latents: torch.FloatTensor, + extra_condition_latents: torch.FloatTensor, + temb: torch.FloatTensor, + cond_temb: torch.FloatTensor, + extra_cond_temb: torch.FloatTensor, + cond_rotary_emb=None, + extra_cond_rotary_emb=None, + image_rotary_emb=None, + model_config: Optional[Dict[str, Any]] = {}, +): + use_cond = condition_latents is not None + + use_extra_cond = extra_condition_latents is not None + with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb + ) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + self.norm1_context(encoder_hidden_states, emb=temb) + ) + + if use_cond: + ( + norm_condition_latents, + cond_gate_msa, + cond_shift_mlp, + cond_scale_mlp, + cond_gate_mlp, + ) = self.norm1(condition_latents, emb=cond_temb) + ( + norm_extra_condition_latents, + extra_cond_gate_msa, + extra_cond_shift_mlp, + extra_cond_scale_mlp, + extra_cond_gate_mlp, + ) = self.norm1(extra_condition_latents, emb=extra_cond_temb) + + # Attention. 
+ result = attn_forward( + self.attn, + model_config=model_config, + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + condition_latents=norm_condition_latents if use_cond else None, + extra_condition_latents=norm_extra_condition_latents if use_cond else None, + image_rotary_emb=image_rotary_emb, + cond_rotary_emb=cond_rotary_emb if use_cond else None, + extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_cond else None, + ) + # print("in self block: ", result.shape) + attn_output, context_attn_output = result[:2] + cond_attn_output = result[2] if use_cond else None + extra_condition_output = result[3] + + # Process attention outputs for the `hidden_states`. + # 1. hidden_states + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + # 2. encoder_hidden_states + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + + encoder_hidden_states = encoder_hidden_states + context_attn_output + # 3. condition_latents + if use_cond: + cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output + condition_latents = condition_latents + cond_attn_output + #need to make new condition_extra and add extra_condition_output + if use_extra_cond: + extra_condition_output = extra_cond_gate_msa.unsqueeze(1) * extra_condition_output + extra_condition_latents = extra_condition_latents + extra_condition_output + + if model_config.get("add_cond_attn", False): + hidden_states += cond_attn_output + hidden_states += extra_condition_output + + + # LayerNorm + MLP. + # 1. hidden_states + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + # 2. encoder_hidden_states + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) + # 3. condition_latents + if use_cond: + norm_condition_latents = self.norm2(condition_latents) + norm_condition_latents = ( + norm_condition_latents * (1 + cond_scale_mlp[:, None]) + + cond_shift_mlp[:, None] + ) + + if use_extra_cond: + #added conditions + extra_norm_condition_latents = self.norm2(extra_condition_latents) + extra_norm_condition_latents = ( + extra_norm_condition_latents * (1 + extra_cond_scale_mlp[:, None]) + + extra_cond_shift_mlp[:, None] + ) + + # Feed-forward. + with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)): + # 1. hidden_states + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + # 2. encoder_hidden_states + context_ff_output = self.ff_context(norm_encoder_hidden_states) + context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output + # 3. condition_latents + if use_cond: + cond_ff_output = self.ff(norm_condition_latents) + cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output + + if use_extra_cond: + extra_cond_ff_output = self.ff(extra_norm_condition_latents) + extra_cond_ff_output = extra_cond_gate_mlp.unsqueeze(1) * extra_cond_ff_output + + # Process feed-forward outputs. + hidden_states = hidden_states + ff_output + encoder_hidden_states = encoder_hidden_states + context_ff_output + if use_cond: + condition_latents = condition_latents + cond_ff_output + if use_extra_cond: + extra_condition_latents = extra_condition_latents + extra_cond_ff_output + + # Clip to avoid overflow. 
+ if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents if use_cond else None + + +def single_block_forward( + self, + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + condition_latents: torch.FloatTensor = None, + extra_condition_latents: torch.FloatTensor = None, + cond_temb: torch.FloatTensor = None, + extra_cond_temb: torch.FloatTensor = None, + cond_rotary_emb=None, + extra_cond_rotary_emb=None, + model_config: Optional[Dict[str, Any]] = {}, +): + + using_cond = condition_latents is not None + using_extra_cond = extra_condition_latents is not None + residual = hidden_states + with enable_lora( + ( + self.norm.linear, + self.proj_mlp, + ), + model_config.get("latent_lora", False), + ): + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + if using_cond: + residual_cond = condition_latents + norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb) + mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents)) + + if using_extra_cond: + extra_residual_cond = extra_condition_latents + extra_norm_condition_latents, extra_cond_gate = self.norm(extra_condition_latents, emb=extra_cond_temb) + extra_mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(extra_norm_condition_latents)) + + attn_output = attn_forward( + self.attn, + model_config=model_config, + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **( + { + "condition_latents": norm_condition_latents, + "cond_rotary_emb": cond_rotary_emb if using_cond else None, + "extra_condition_latents": extra_norm_condition_latents if using_cond else None, + "extra_cond_rotary_emb": extra_cond_rotary_emb if using_cond else None, + } + if using_cond + else {} + ), + ) + + if using_cond: + attn_output, cond_attn_output, extra_cond_attn_output = attn_output + + + with enable_lora((self.proj_out,), model_config.get("latent_lora", False)): + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if using_cond: + condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) + cond_gate = cond_gate.unsqueeze(1) + condition_latents = cond_gate * self.proj_out(condition_latents) + condition_latents = residual_cond + condition_latents + + extra_condition_latents = torch.cat([extra_cond_attn_output, extra_mlp_cond_hidden_states], dim=2) + extra_cond_gate = extra_cond_gate.unsqueeze(1) + extra_condition_latents = extra_cond_gate * self.proj_out(extra_condition_latents) + extra_condition_latents = extra_residual_cond + extra_condition_latents + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states if not using_cond else (hidden_states, condition_latents, extra_condition_latents) diff --git a/flux/condition.py b/flux/condition.py new file mode 100644 index 0000000000000000000000000000000000000000..77b27c33bc408591f513f5b40077e5ebda746a5f --- /dev/null +++ b/flux/condition.py @@ -0,0 +1,89 @@ +# Recycled from Ominicontrol and modified to accept an extra condition. +# While Zenctrl pursued a similar idea, it diverged structurally. +# We appreciate the clarity of Omini's implementation and decided to align with it. 
+ +import torch +from typing import Optional, Union, List, Tuple +from diffusers.pipelines import FluxPipeline +from PIL import Image, ImageFilter +import numpy as np +import cv2 + +# from pipeline_tools import encode_images +from .pipeline_tools import encode_images + +condition_dict = { + "subject": 1, + "sr": 2, + "cot": 3, +} + + +class Condition(object): + def __init__( + self, + condition_type: str, + raw_img: Union[Image.Image, torch.Tensor] = None, + condition: Union[Image.Image, torch.Tensor] = None, + position_delta=None, + ) -> None: + self.condition_type = condition_type + assert raw_img is not None or condition is not None + if raw_img is not None: + self.condition = self.get_condition(condition_type, raw_img) + else: + self.condition = condition + self.position_delta = position_delta + + + def get_condition( + self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor] + ) -> Union[Image.Image, torch.Tensor]: + """ + Returns the condition image. + """ + if condition_type == "subject": + return raw_img + elif condition_type == "sr": + return raw_img + elif condition_type == "cot": + return raw_img.convert("RGB") + return self.condition + + + @property + def type_id(self) -> int: + """ + Returns the type id of the condition. + """ + return condition_dict[self.condition_type] + + def encode( + self, pipe: FluxPipeline, empty: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Encodes the condition into tokens, ids and type_id. + """ + if self.condition_type in [ + "subject", + "sr", + "cot" + ]: + if empty: + # make the condition black + e_condition = Image.new("RGB", self.condition.size, (0, 0, 0)) + e_condition = e_condition.convert("RGB") + tokens, ids = encode_images(pipe, e_condition) + else: + tokens, ids = encode_images(pipe, self.condition) + else: + raise NotImplementedError( + f"Condition type {self.condition_type} not implemented" + ) + if self.position_delta is None and self.condition_type == "subject": + self.position_delta = [0, -self.condition.size[0] // 16] + if self.position_delta is not None: + ids[:, 1] += self.position_delta[0] + ids[:, 2] += self.position_delta[1] + type_id = torch.ones_like(ids[:, :1]) * self.type_id + return tokens, ids, type_id diff --git a/flux/generate.py b/flux/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..567aa2fdc018cd38ef02178915cc3a3187637c87 --- /dev/null +++ b/flux/generate.py @@ -0,0 +1,337 @@ +# Recycled from Ominicontrol and modified to accept an extra condition. +# While Zenctrl pursued a similar idea, it diverged structurally. +# We appreciate the clarity of Omini's implementation and decided to align with it. 
+ +import torch +import yaml, os +from diffusers.pipelines import FluxPipeline +from typing import List, Union, Optional, Dict, Any, Callable +from .transformer import tranformer_forward +from .condition import Condition + + +from diffusers.pipelines.flux.pipeline_flux import ( + FluxPipelineOutput, + calculate_shift, + retrieve_timesteps, + np, +) + + +def get_config(config_path: str = None): + config_path = config_path or os.environ.get("XFL_CONFIG") + if not config_path: + return {} + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + + +def prepare_params( + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + **kwargs: dict, +): + return ( + prompt, + prompt_2, + height, + width, + num_inference_steps, + timesteps, + guidance_scale, + num_images_per_prompt, + generator, + latents, + prompt_embeds, + pooled_prompt_embeds, + output_type, + return_dict, + joint_attention_kwargs, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) + + +def seed_everything(seed: int = 42): + torch.backends.cudnn.deterministic = True + torch.manual_seed(seed) + np.random.seed(seed) + + +@torch.no_grad() +def generate( + pipeline: FluxPipeline, + conditions: List[Condition] = None, + config_path: str = None, + model_config: Optional[Dict[str, Any]] = {}, + condition_scale: float = [1, 1], + default_lora: bool = False, + image_guidance_scale: float = 1.0, + **params: dict, +): + model_config = model_config or get_config(config_path).get("model", {}) + if condition_scale != [1,1]: + for name, module in pipeline.transformer.named_modules(): + if not name.endswith(".attn"): + continue + module.c_factor = torch.tensor(condition_scale) + + self = pipeline + ( + prompt, + prompt_2, + height, + width, + num_inference_steps, + timesteps, + guidance_scale, + num_images_per_prompt, + generator, + latents, + prompt_embeds, + pooled_prompt_embeds, + output_type, + return_dict, + joint_attention_kwargs, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ) = prepare_params(**params) + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) + if self.joint_attention_kwargs is not None + else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 4.1. Prepare conditions + condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3)) + extra_condition_latents, extra_condition_ids, extra_condition_type_ids = ([] for _ in range(3)) + use_condition = conditions is not None or [] + if use_condition: + if not default_lora: + pipeline.set_adapters(conditions[1].condition_type) + # for condition in conditions: + tokens, ids, type_id = conditions[0].encode(self) + condition_latents.append(tokens) # [batch_size, token_n, token_dim] + condition_ids.append(ids) # [token_n, id_dim(3)] + condition_type_ids.append(type_id) # [token_n, 1] + condition_latents = torch.cat(condition_latents, dim=1) + condition_ids = torch.cat(condition_ids, dim=0) + condition_type_ids = torch.cat(condition_type_ids, dim=0) + + tokens, ids, type_id = conditions[1].encode(self) + extra_condition_latents.append(tokens) # [batch_size, token_n, token_dim] + extra_condition_ids.append(ids) # [token_n, id_dim(3)] + extra_condition_type_ids.append(type_id) # [token_n, 1] + extra_condition_latents = torch.cat(extra_condition_latents, dim=1) + extra_condition_ids = torch.cat(extra_condition_ids, dim=0) + extra_condition_type_ids = torch.cat(extra_condition_type_ids, dim=0) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + noise_pred = tranformer_forward( + self.transformer, + model_config=model_config, + # Inputs of the condition (new feature) + condition_latents=condition_latents if use_condition else None, + condition_ids=condition_ids if use_condition else None, + condition_type_ids=condition_type_ids if use_condition else None, + extra_condition_latents=extra_condition_latents if use_condition else None, + extra_condition_ids=extra_condition_ids if use_condition else None, + extra_condition_type_ids=extra_condition_type_ids if use_condition else None, + # Inputs to the original transformer + hidden_states=latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if image_guidance_scale != 1.0: + uncondition_latents = conditions.encode(self, empty=True)[0] + unc_pred = tranformer_forward( + self.transformer, + model_config=model_config, + # Inputs of the condition (new feature) + condition_latents=uncondition_latents if use_condition else None, + condition_ids=condition_ids if use_condition else None, + condition_type_ids=condition_type_ids if use_condition else None, + # Inputs to the original transformer + hidden_states=latents, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=torch.ones_like(guidance), + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = ( + latents / self.vae.config.scaling_factor + ) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if condition_scale != [1,1]: + for name, module in pipeline.transformer.named_modules(): + if not name.endswith(".attn"): + continue + del module.c_factor + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/flux/lora_controller.py b/flux/lora_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..01cd2b256c1907e823b1b0933c8f73be12883785 --- /dev/null +++ b/flux/lora_controller.py @@ -0,0 +1,82 @@ +#As is from OminiControl +from peft.tuners.tuners_utils import BaseTunerLayer +from typing import List, Any, Optional, Type +from .condition import condition_dict + + +class enable_lora: + def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None: + self.activated: bool = activated + if activated: + return + self.lora_modules: List[BaseTunerLayer] = [ + each for each in lora_modules if isinstance(each, BaseTunerLayer) + ] + self.scales = [ + { + active_adapter: lora_module.scaling[active_adapter] + for active_adapter in lora_module.active_adapters + } + for lora_module in self.lora_modules + ] + + def __enter__(self) -> None: + if self.activated: + return + + for lora_module in self.lora_modules: + if not isinstance(lora_module, BaseTunerLayer): + continue + for active_adapter in lora_module.active_adapters: + if ( + active_adapter in condition_dict.keys() + or active_adapter == "default" + ): + lora_module.scaling[active_adapter] = 0.0 + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self.activated: + return + for i, lora_module in enumerate(self.lora_modules): + if not isinstance(lora_module, BaseTunerLayer): + continue + for active_adapter in lora_module.active_adapters: + lora_module.scaling[active_adapter] = self.scales[i][active_adapter] + + +class set_lora_scale: + def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None: + self.lora_modules: List[BaseTunerLayer] = [ + each for each in lora_modules if isinstance(each, BaseTunerLayer) + ] + self.scales = [ + { + active_adapter: lora_module.scaling[active_adapter] + for active_adapter in lora_module.active_adapters + } + for lora_module in self.lora_modules + ] + self.scale = scale + + def __enter__(self) -> None: + for lora_module in self.lora_modules: + if not isinstance(lora_module, BaseTunerLayer): + continue + lora_module.scale_layer(self.scale) + + def __exit__( + self, + 
exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + for i, lora_module in enumerate(self.lora_modules): + if not isinstance(lora_module, BaseTunerLayer): + continue + for active_adapter in lora_module.active_adapters: + lora_module.scaling[active_adapter] = self.scales[i][active_adapter] \ No newline at end of file diff --git a/flux/pipeline_tools.py b/flux/pipeline_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc7d6aec5355ad93135a4aa5ac77ebd653668b3 --- /dev/null +++ b/flux/pipeline_tools.py @@ -0,0 +1,53 @@ +#As is from OminiControl +from diffusers.pipelines import FluxPipeline +from diffusers.utils import logging +from diffusers.pipelines.flux.pipeline_flux import logger +from torch import Tensor + + +def encode_images(pipeline: FluxPipeline, images: Tensor): + images = pipeline.image_processor.preprocess(images) + images = images.to(pipeline.device).to(pipeline.dtype) + images = pipeline.vae.encode(images).latent_dist.sample() + images = ( + images - pipeline.vae.config.shift_factor + ) * pipeline.vae.config.scaling_factor + images_tokens = pipeline._pack_latents(images, *images.shape) + images_ids = pipeline._prepare_latent_image_ids( + images.shape[0], + images.shape[2], + images.shape[3], + pipeline.device, + pipeline.dtype, + ) + if images_tokens.shape[1] != images_ids.shape[0]: + images_ids = pipeline._prepare_latent_image_ids( + images.shape[0], + images.shape[2] // 2, + images.shape[3] // 2, + pipeline.device, + pipeline.dtype, + ) + return images_tokens, images_ids + + +def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512): + # Turn off warnings (CLIP overflow) + logger.setLevel(logging.ERROR) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = pipeline.encode_prompt( + prompt=prompts, + prompt_2=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + device=pipeline.device, + num_images_per_prompt=1, + max_sequence_length=max_sequence_length, + lora_scale=None, + ) + # Turn on warnings + logger.setLevel(logging.WARNING) + return prompt_embeds, pooled_prompt_embeds, text_ids diff --git a/flux/transformer.py b/flux/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e0480561f5eee5711361fdf45f39d5a661e4c4 --- /dev/null +++ b/flux/transformer.py @@ -0,0 +1,286 @@ +# Recycled from Ominicontrol and modified to accept an extra condition. +# While Zenctrl pursued a similar idea, it diverged structurally. +# We appreciate the clarity of Omini's implementation and decided to align with it. 
+ +import torch +from diffusers.pipelines import FluxPipeline +from typing import List, Union, Optional, Dict, Any, Callable +from .block import block_forward, single_block_forward +from .lora_controller import enable_lora +from accelerate.utils import is_torch_version +from diffusers.models.transformers.transformer_flux import ( + FluxTransformer2DModel, + Transformer2DModelOutput, + USE_PEFT_BACKEND, + scale_lora_layers, + unscale_lora_layers, + logger, +) +import numpy as np + + +def prepare_params( + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + **kwargs: dict, +): + return ( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + img_ids, + txt_ids, + guidance, + joint_attention_kwargs, + controlnet_block_samples, + controlnet_single_block_samples, + return_dict, + ) + + +def tranformer_forward( + transformer: FluxTransformer2DModel, + condition_latents: torch.Tensor, + extra_condition_latents: torch.Tensor, + condition_ids: torch.Tensor, + condition_type_ids: torch.Tensor, + extra_condition_ids: torch.Tensor, + extra_condition_type_ids: torch.Tensor, + model_config: Optional[Dict[str, Any]] = {}, + c_t=0, + **params: dict, +): + self = transformer + use_condition = condition_latents is not None + use_extra_condition = extra_condition_latents is not None + + ( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + img_ids, + txt_ids, + guidance, + joint_attention_kwargs, + controlnet_block_samples, + controlnet_single_block_samples, + return_dict, + ) = prepare_params(**params) + + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)): + hidden_states = self.x_embedder(hidden_states) + condition_latents = self.x_embedder(condition_latents) if use_condition else None + extra_condition_latents = self.x_embedder(extra_condition_latents) if use_extra_condition else None + + timestep = timestep.to(hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + + cond_temb = ( + self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) + if guidance is None + else self.time_text_embed( + torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections + ) + ) + extra_cond_temb = ( + self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) + if guidance is None + else self.time_text_embed( + torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections + ) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + if use_condition: + # condition_ids[:, :1] = condition_type_ids + cond_rotary_emb = self.pos_embed(condition_ids) + + if use_extra_condition: + extra_cond_rotary_emb = self.pos_embed(extra_condition_ids) + + + # hidden_states = torch.cat([hidden_states, condition_latents], dim=1) + + #print("here!") + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = ( + torch.utils.checkpoint.checkpoint( + block_forward, + self=block, + model_config=model_config, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + condition_latents=condition_latents if use_condition else None, + extra_condition_latents=extra_condition_latents if use_extra_condition else None, + temb=temb, + cond_temb=cond_temb if use_condition else None, + cond_rotary_emb=cond_rotary_emb if use_condition else None, + extra_cond_temb=extra_cond_temb if use_extra_condition else None, + extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None, + image_rotary_emb=image_rotary_emb, + **ckpt_kwargs, + ) + ) + + else: + encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = block_forward( + block, + model_config=model_config, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + condition_latents=condition_latents if use_condition else None, + extra_condition_latents=extra_condition_latents if use_extra_condition else None, + temb=temb, + cond_temb=cond_temb if use_condition else None, + cond_rotary_emb=cond_rotary_emb if use_condition else None, + extra_cond_temb=cond_temb if use_extra_condition else None, + 
extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None, + image_rotary_emb=image_rotary_emb, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len( + controlnet_block_samples + ) + interval_control = int(np.ceil(interval_control)) + hidden_states = ( + hidden_states + + controlnet_block_samples[index_block // interval_control] + ) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + result = torch.utils.checkpoint.checkpoint( + single_block_forward, + self=block, + model_config=model_config, + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + **( + { + "condition_latents": condition_latents, + "extra_condition_latents": extra_condition_latents, + "cond_temb": cond_temb, + "cond_rotary_emb": cond_rotary_emb, + "extra_cond_temb": extra_cond_temb, + "extra_cond_rotary_emb": extra_cond_rotary_emb, + } + if use_condition + else {} + ), + **ckpt_kwargs, + ) + + else: + result = single_block_forward( + block, + model_config=model_config, + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + **( + { + "condition_latents": condition_latents, + "extra_condition_latents": extra_condition_latents, + "cond_temb": cond_temb, + "cond_rotary_emb": cond_rotary_emb, + "extra_cond_temb": extra_cond_temb, + "extra_cond_rotary_emb": extra_cond_rotary_emb, + } + if use_condition + else {} + ), + ) + if use_condition: + hidden_states, condition_latents, extra_condition_latents = result + else: + hidden_states = result + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len( + controlnet_single_block_samples + ) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
+ + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/imgs/bg_i1.png b/imgs/bg_i1.png new file mode 100644 index 0000000000000000000000000000000000000000..ac07eab92116f09f5592e86a260fb16721f7030f --- /dev/null +++ b/imgs/bg_i1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8233ff6e5eaaf97f6157599708ed67e62380f3e09820d67bb2ebd472d84165a7 +size 97739 diff --git a/imgs/bg_i2.png b/imgs/bg_i2.png new file mode 100644 index 0000000000000000000000000000000000000000..5c6919aff4d06d9690fb033c0a246e923c82ff81 --- /dev/null +++ b/imgs/bg_i2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41ff9fabfb2e31cce35cc97f2c5962165c3b68bce60732e6669692307ec5ebed +size 196618 diff --git a/imgs/bg_i3.png b/imgs/bg_i3.png new file mode 100644 index 0000000000000000000000000000000000000000..55dbf1a217a2584c260f7cc5e992e7b11c225581 --- /dev/null +++ b/imgs/bg_i3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89815263bfd1b72e5be817ddcefc770a728b46bdb7d8264258cae8b5a4270493 +size 279403 diff --git a/imgs/bg_i4.png b/imgs/bg_i4.png new file mode 100644 index 0000000000000000000000000000000000000000..53813bd918ac8cda07da7b6808f0a9b541e7e071 --- /dev/null +++ b/imgs/bg_i4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebedef7bdaf6b6e184cad63665f40d2f86dd5a6c54334b400d1ce479c5ec339e +size 1004461 diff --git a/imgs/bg_i5.png b/imgs/bg_i5.png new file mode 100644 index 0000000000000000000000000000000000000000..297584180289e5b8e5afe1b42449cade3fcbffa4 --- /dev/null +++ b/imgs/bg_i5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3065ff3010bdf54d9eed2784a9f998597fb3f10646f0a6ba71e7d2abd9041c2 +size 946929 diff --git a/imgs/bg_o1.png b/imgs/bg_o1.png new file mode 100644 index 0000000000000000000000000000000000000000..62fabe8ae7a543658962c329535e0d947c11daf2 --- /dev/null +++ b/imgs/bg_o1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e80909e8b091f02d8d00f8614194f5e6606b076bf86d4fd88e58ffcfeaaa4ed +size 621107 diff --git a/imgs/bg_o2.png b/imgs/bg_o2.png new file mode 100644 index 0000000000000000000000000000000000000000..14754b40677c5dca576038331915608b2f351a31 --- /dev/null +++ b/imgs/bg_o2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e32d9497741ea352725220043e54d1772b1fe4e81910cfaf668cc95ee6e9a23f +size 944735 diff --git a/imgs/bg_o3.jpg b/imgs/bg_o3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9336cda35a50353c495fe39cc380485309785217 --- /dev/null +++ b/imgs/bg_o3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffef49f49bbed31e53e10552a3e6bedc8c492d0b69b3e5284e5f612705f84983 +size 58127 diff --git a/imgs/bg_o4.jpg b/imgs/bg_o4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e23b454771a76a883be493478a7e8f019f49f01b --- /dev/null +++ b/imgs/bg_o4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f26c092a93c15a781ae013a4980f2a4b0489277687263e7f6619f5ea454fa645 +size 69668 diff --git a/imgs/bg_o5.jpg b/imgs/bg_o5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b9f3c2a7eda087c80eff9c69f28a53a10b1fd386 --- /dev/null +++ b/imgs/bg_o5.jpg @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ebe1a2421d9a36a108d9aca97064d96f1c5016bc6ea378dee637fc843a0357c +size 37044 diff --git a/imgs/sub_i1.png b/imgs/sub_i1.png new file mode 100644 index 0000000000000000000000000000000000000000..5c6919aff4d06d9690fb033c0a246e923c82ff81 --- /dev/null +++ b/imgs/sub_i1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41ff9fabfb2e31cce35cc97f2c5962165c3b68bce60732e6669692307ec5ebed +size 196618 diff --git a/imgs/sub_i2.png b/imgs/sub_i2.png new file mode 100644 index 0000000000000000000000000000000000000000..53813bd918ac8cda07da7b6808f0a9b541e7e071 --- /dev/null +++ b/imgs/sub_i2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebedef7bdaf6b6e184cad63665f40d2f86dd5a6c54334b400d1ce479c5ec339e +size 1004461 diff --git a/imgs/sub_i3.png b/imgs/sub_i3.png new file mode 100644 index 0000000000000000000000000000000000000000..297584180289e5b8e5afe1b42449cade3fcbffa4 --- /dev/null +++ b/imgs/sub_i3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3065ff3010bdf54d9eed2784a9f998597fb3f10646f0a6ba71e7d2abd9041c2 +size 946929 diff --git a/imgs/sub_i4.png b/imgs/sub_i4.png new file mode 100644 index 0000000000000000000000000000000000000000..5b68514fc4b59a786ce76f5bb3b3ea9097e812a2 --- /dev/null +++ b/imgs/sub_i4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24deed93849e951f8c7c44de47226240db7c5a42a289f3fdc7c0fa5621f65609 +size 280772 diff --git a/imgs/sub_i5.png b/imgs/sub_i5.png new file mode 100644 index 0000000000000000000000000000000000000000..8ce13aa317153d416d8dc07589f6054cc3b2bede --- /dev/null +++ b/imgs/sub_i5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50b05fff1d2d404da6d6045c36971cd810c2e0d425168cdafa1511f95c5ba269 +size 249692 diff --git a/imgs/sub_o1.webp b/imgs/sub_o1.webp new file mode 100644 index 0000000000000000000000000000000000000000..fa9b70b11ffbe47b98905845511da54c1facce47 --- /dev/null +++ b/imgs/sub_o1.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1c9eef884328c26cc58de4f35680b2ca5f211071c2b9bd5254223b0dec0757e +size 32074 diff --git a/imgs/sub_o2.webp b/imgs/sub_o2.webp new file mode 100644 index 0000000000000000000000000000000000000000..753136203c30dec326c713a05b57a037e889d095 --- /dev/null +++ b/imgs/sub_o2.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:026e18cab72811630eb367bf4ae2e38933fb9c31c0146b66990c5f8879b8b71b +size 27314 diff --git a/imgs/sub_o3.webp b/imgs/sub_o3.webp new file mode 100644 index 0000000000000000000000000000000000000000..d8ff14507c56d53bff6012a302aec6a97b0c4ad6 --- /dev/null +++ b/imgs/sub_o3.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0b2abf5d005c747c092c60837b3a80b8d96e5c52625b2d288475799603c8147 +size 28268 diff --git a/imgs/sub_o4.webp b/imgs/sub_o4.webp new file mode 100644 index 0000000000000000000000000000000000000000..6b8b1bc359fa06aed0efd835e623972fa1362ee7 --- /dev/null +++ b/imgs/sub_o4.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f77c4b61c9ad96cbf312af18c631f9d558bdcf1f51beb1317bc483bfe700cf18 +size 16900 diff --git a/imgs/sub_o5.webp b/imgs/sub_o5.webp new file mode 100644 index 0000000000000000000000000000000000000000..d69e7531af9fd2a0f2b50ae32367b13c9e65944a --- /dev/null +++ b/imgs/sub_o5.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:14ef4c02f4a2970a26a1bd92ce4c937b8aad477a91463fef5a8a34dd61afe3e8 +size 17432 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a716da4a358e116c752571bbd733b2205b0733cb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +torch +torchvision +torchaudio + +diffusers==0.31.0 +transformers +accelerate +huggingface_hub +sentencepiece + +numpy +pillow>=9.0.0 +einops>=0.7.0 +safetensors>=0.4.0 +opencv-python-headless +peft==0.15.2 +spaces \ No newline at end of file diff --git a/samples/1.png b/samples/1.png new file mode 100644 index 0000000000000000000000000000000000000000..4fb28f1e76650b6fbe294505e022e256e78773d3 --- /dev/null +++ b/samples/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee99aba59d42e4c8221785aa032fb5dc3b623478f27a4a9255e42e19bc3308d7 +size 809438 diff --git a/samples/10.png b/samples/10.png new file mode 100644 index 0000000000000000000000000000000000000000..f680bf69a449f02b44826d3d3be8fd4c7c3d4479 --- /dev/null +++ b/samples/10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6108e472742ee174e2e37d3a8acbfb14c51bdf0c252139086c0813026fb6b73 +size 404613 diff --git a/samples/11.png b/samples/11.png new file mode 100644 index 0000000000000000000000000000000000000000..800222f4450b055443719ad5967f572d6fc08fce --- /dev/null +++ b/samples/11.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8f107773e71b227d9f95692b7af837d41c4b4a39bda48940755db399d619d21 +size 390691 diff --git a/samples/11_out.webp b/samples/11_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..1c99858ce9055111d0dda64bd3339f7e550677ec --- /dev/null +++ b/samples/11_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ec7669bc2d1423c69edd87e6a6229d15b76043100ae91d7ff2b07b6197646aa +size 40050 diff --git a/samples/12.png b/samples/12.png new file mode 100644 index 0000000000000000000000000000000000000000..20b310eb19ccce7e7592fe3257504dac9d6fc6c7 --- /dev/null +++ b/samples/12.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ede3b034dcb0a378073cf7d18841b8bf58c20c7f2161239718520dd3c36fe338 +size 493753 diff --git a/samples/12_out.webp b/samples/12_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..2ba75294c72aa86955fec72d1f7b078474169969 --- /dev/null +++ b/samples/12_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c67d211f75e77242af15d259c20a3d527b844b9f5e9c2a199ff2797e0bc459b +size 30236 diff --git a/samples/13.png b/samples/13.png new file mode 100644 index 0000000000000000000000000000000000000000..6594ee87b1a38288e825b11e96fd3858019098c2 --- /dev/null +++ b/samples/13.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6740a2d29fde73c310dc9d95b82e02163ad2fe014ec3aad8bfbfe4aed964948 +size 1132115 diff --git a/samples/14.png b/samples/14.png new file mode 100644 index 0000000000000000000000000000000000000000..1c48a349fd7465371cf1617ac76fd9e08e4bcfeb --- /dev/null +++ b/samples/14.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1180b403e72baed790b92903729f5f07a09f3fdb1ab750330a26b3dca5dfdf22 +size 487379 diff --git a/samples/15.png b/samples/15.png new file mode 100644 index 0000000000000000000000000000000000000000..6f6250292dc8c7617b6a1ac71bc3ff4cc0f82899 --- /dev/null +++ b/samples/15.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:da6cceb7325b7d4819cf5218768e81fcc486140b5d478c43cc086a7aa3ffed71 +size 959236 diff --git a/samples/16.png b/samples/16.png new file mode 100644 index 0000000000000000000000000000000000000000..29ffc03b42a5051b85d16108600c24e842e77df7 --- /dev/null +++ b/samples/16.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7caca84161a104189ff5a7571dbec8342de97238f99280feeda92b2b768a51fb +size 2224530 diff --git a/samples/17.png b/samples/17.png new file mode 100644 index 0000000000000000000000000000000000000000..a0293ec170c8cfdafa174353cef4ee8e8e8e8fef --- /dev/null +++ b/samples/17.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15290c3bdcb2bcc575a293a91a1c1e0debdbfcfe2494a8f36458b59403f5fcea +size 845999 diff --git a/samples/18.png b/samples/18.png new file mode 100644 index 0000000000000000000000000000000000000000..8d50d442953971fe55ce8c3443a99f31e4d51cf2 --- /dev/null +++ b/samples/18.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b5771918d8ba7605198be2d670d2cf713b642c29da4a04f069d98cd5e45baa7 +size 676802 diff --git a/samples/19.png b/samples/19.png new file mode 100644 index 0000000000000000000000000000000000000000..c49cc441a5cd320f44a4f91b80bb1aae4170ca51 --- /dev/null +++ b/samples/19.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df6a779b5df871fdce101c5a2b13c77c25160b658c4c729f0d4a3ab34ae11996 +size 952727 diff --git a/samples/1_out.webp b/samples/1_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..b159edebbc6363374eb4c385a01393abf783764e --- /dev/null +++ b/samples/1_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ff52fb3f22b7413cfaf090bd65f93be38491d4b250c668e67a33b4cc6c88a8a +size 40780 diff --git a/samples/2.png b/samples/2.png new file mode 100644 index 0000000000000000000000000000000000000000..424bc505f1bc57505ee17615f65e527d84b2082f --- /dev/null +++ b/samples/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a12bcdfac8d901c8d0947213391541dedff0e49964174b9f436726db21a3ad6 +size 899506 diff --git a/samples/20.png b/samples/20.png new file mode 100644 index 0000000000000000000000000000000000000000..6c986d504ff5913022f2a26ccf0cb2cc0d78765a --- /dev/null +++ b/samples/20.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afa037f9006cdbdf0707db41b88ae6e8109c196f5ceb5f37523a8364b467d534 +size 192685 diff --git a/samples/20_out.webp b/samples/20_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..6f429847fb80d2bbf7b9d66916a2b84e5966a127 --- /dev/null +++ b/samples/20_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45e65cbbadd94d062f4d56a0204dc37d7e64a8e3dcc05b78dbc54cc6d65b8cb1 +size 36774 diff --git a/samples/21.png b/samples/21.png new file mode 100644 index 0000000000000000000000000000000000000000..246828667373ddfbd3784afcbf70342cd8577d12 --- /dev/null +++ b/samples/21.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecd7d68258b5d342ff9a316c9dfddefda11bc9558e8b582a4bb01977f3da87e5 +size 302983 diff --git a/samples/21_1.png b/samples/21_1.png new file mode 100644 index 0000000000000000000000000000000000000000..e8abb7609276bae3f9af83b2f5a5a04fad21360a --- /dev/null +++ b/samples/21_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf2f5448e75bb285b32265c779b5844ae9c182e53264ecc4950eabca2b563545 +size 4971528 diff --git a/samples/21_out.webp 
b/samples/21_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..fbde820786401d1137876ff12af4ce4994d17c80 --- /dev/null +++ b/samples/21_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54e64f200816a238acbabe0095ef4541f2b1e1d8adde1a331a3ce7328a14a6b2 +size 40026 diff --git a/samples/22.png b/samples/22.png new file mode 100644 index 0000000000000000000000000000000000000000..8c333f6d216f963bc5f2962b59d3a52710dd8c5a --- /dev/null +++ b/samples/22.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00fd1d1ea48dc25ca4855cf1d6099141cb1375bfd527fa2fab08b0d07208f815 +size 235874 diff --git a/samples/22_out.webp b/samples/22_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..c3e386ae8ecb67e87f5a035ebfb6afa461b29fe0 --- /dev/null +++ b/samples/22_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02d87428b86f2923f7207d130f812bb36b4dc67f76b4ea495c20324eeb791021 +size 22016 diff --git a/samples/23.png b/samples/23.png new file mode 100644 index 0000000000000000000000000000000000000000..97b94e9e607526d30000add0db7eba7d90fb4c36 --- /dev/null +++ b/samples/23.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42b9683272e6bbd92a796ccdfb4caca5b6f317ae9abfa4f345c5003a735ddc9e +size 229596 diff --git a/samples/24.png b/samples/24.png new file mode 100644 index 0000000000000000000000000000000000000000..615f14b52a0fbb88bd03658af3db16723827f3be --- /dev/null +++ b/samples/24.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e498022e215bb0d33b0179d192c05e473ee91f7ebfe6d424317c140cf92bf16 +size 324358 diff --git a/samples/25.png b/samples/25.png new file mode 100644 index 0000000000000000000000000000000000000000..f140155c862979e03637301ed9ab86854e4454d0 --- /dev/null +++ b/samples/25.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ee43af3e149026a1a36c30d64b5ef488bd8a7399b0a5a4f0ea0fd1fb9fe60fb +size 222029 diff --git a/samples/26.png b/samples/26.png new file mode 100644 index 0000000000000000000000000000000000000000..f60e3d2d8a8682f7959336a46db81d49c2a43e42 --- /dev/null +++ b/samples/26.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89e771ba69a141a3f94821be4434edbdf440aa33ec293bfc85285131e95fe3f2 +size 459854 diff --git a/samples/2_out.webp b/samples/2_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..c1fb2b689c7b2ac6d012d8388b14deb624d2f40a --- /dev/null +++ b/samples/2_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f64c0edefd4f15d5d88d55c0776fa95c9db746b5b13c05669211f4e342dff9ab +size 85752 diff --git a/samples/3.png b/samples/3.png new file mode 100644 index 0000000000000000000000000000000000000000..0c05c53ad0bfce0ad48ffd2a6009c93005b2d5df --- /dev/null +++ b/samples/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b87e33444befcd05e121183603c47b3666c1fd2f5656741dfb1d3bb124835702 +size 470075 diff --git a/samples/3_out.webp b/samples/3_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..19299821fe13125519c7885071e790fd302255d9 --- /dev/null +++ b/samples/3_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b8f8b8c43439c6c26d4847a6ac4c77a8e4832695cc4aa3ea86f15dd90028bca +size 42180 diff --git a/samples/4.png b/samples/4.png new file mode 100644 index 
0000000000000000000000000000000000000000..2da7fb284ec99b6b92b1f072c6f46fd4837965ce --- /dev/null +++ b/samples/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29689ad98946560acfd07c5be9e06c27c94a19e94dcdd73a388f3dbc46b26259 +size 359441 diff --git a/samples/4_out.webp b/samples/4_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..c86ed9611a883357c4baa5fba4dbd6779db6a751 --- /dev/null +++ b/samples/4_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11fd666873415763a916edd08a9e19ca534162cd73217bffa0dcfcc75d3e13a4 +size 35576 diff --git a/samples/5.png b/samples/5.png new file mode 100644 index 0000000000000000000000000000000000000000..8b183c75bd540c5a7fa05b7eba8a56657a0d078d --- /dev/null +++ b/samples/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b0def9aa8c45a63bf1ee00e02765ad357456b355fbea66e405f53662bf86db7 +size 268539 diff --git a/samples/5_out.webp b/samples/5_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..d9406c69f90fcb57980d6828b31255f699348831 --- /dev/null +++ b/samples/5_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f625e21ab93e5f6b90954683fcd18c0c2c1fd8465c9b4cf67182dd6d1941d0e2 +size 60376 diff --git a/samples/6.png b/samples/6.png new file mode 100644 index 0000000000000000000000000000000000000000..86cfb99b4437b3e9c9996b68c0ab474634e09816 --- /dev/null +++ b/samples/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85a0ba550ba4c14c29edb02e3ce12d77b50004489035d0c914d0327197c2f55c +size 352262 diff --git a/samples/6_out.webp b/samples/6_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..f9254ea1d1b5194994ccbf9d7728e30c46bdb828 --- /dev/null +++ b/samples/6_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9616e390eeb36ba16cd881bf65869b5bfd72305a3bdabb94cded20ce30bcafa +size 26108 diff --git a/samples/7.png b/samples/7.png new file mode 100644 index 0000000000000000000000000000000000000000..2a08a68b64e248d0dcc5d71fb418cb1a9261b4b1 --- /dev/null +++ b/samples/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a8a55d994069cdd39da729c64713a40b24c1d7fe5b6999ef3040a4c36c69a83 +size 791285 diff --git a/samples/7_out.webp b/samples/7_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..d8ba03a98ef51df7e4b0925cf17a9d8fe94ea71d --- /dev/null +++ b/samples/7_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d2defc4bac2563da5f3486b194c03dbd31688e71d3677c61c965cc211c5ef42 +size 63218 diff --git a/samples/8.png b/samples/8.png new file mode 100644 index 0000000000000000000000000000000000000000..cb7046e336f6086f3719fc18f1f9a500f33f8ce5 --- /dev/null +++ b/samples/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f7047eab7c212fa34f908ca6b0184c4b056144d44be10fe9e6afbac365fd7c9 +size 483145 diff --git a/samples/8_out.webp b/samples/8_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..d27f06448d417399d57d90170e2739016f4ea9b0 --- /dev/null +++ b/samples/8_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3254176ae81676d0d199ba82d4d6e194a4f4a18820b2441992b8a96caf074f2 +size 35850 diff --git a/samples/9.png b/samples/9.png new file mode 100644 index 0000000000000000000000000000000000000000..9bb0f7bbca111a0394fcaf56ea993d5a11733463 --- /dev/null +++ 
b/samples/9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b68eb987d1d23f178f4e7d385e32d14506a26d0031ab63c079247d495a0c19b5 +size 743147 diff --git a/samples/9_out.webp b/samples/9_out.webp new file mode 100644 index 0000000000000000000000000000000000000000..9d9c819e50e6a0354d18db1adadd02c6f1b514f7 --- /dev/null +++ b/samples/9_out.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05aed5335b5f3c4b0d1214a5d42ad958e4ddbdd1886344798b862ab0dd382d4d +size 43570
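
For context on the `__exit__` restore loop in `flux/lora_controller.py` above: the controller records each PEFT layer's LoRA scaling factors before overriding them and puts the originals back when the `with` block exits. Below is a minimal, self-contained sketch of that save/override/restore pattern, assuming peft's `BaseTunerLayer`; the `ScaledLoraContext` name and the single uniform `scale` override are illustrative, not the module's actual API.

```python
# Illustrative only -- not part of the patch. Sketch of the save/override/restore
# pattern behind the __exit__ shown in flux/lora_controller.py.
from typing import Any, Dict, List, Optional, Type

from peft.tuners.tuners_utils import BaseTunerLayer


class ScaledLoraContext:
    def __init__(self, lora_modules: List[Any], scale: float) -> None:
        self.lora_modules = lora_modules
        self.scale = scale
        self.scales: List[Dict[str, float]] = []

    def __enter__(self) -> "ScaledLoraContext":
        for lora_module in self.lora_modules:
            if not isinstance(lora_module, BaseTunerLayer):
                self.scales.append({})  # keep indices aligned for __exit__
                continue
            # Snapshot the current scaling of every adapter, then override
            # the active adapters with the requested scale.
            self.scales.append(dict(lora_module.scaling))
            for active_adapter in lora_module.active_adapters:
                lora_module.scaling[active_adapter] = self.scale
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        # Mirror of the restore loop in the patch: put the recorded scales back.
        for i, lora_module in enumerate(self.lora_modules):
            if not isinstance(lora_module, BaseTunerLayer):
                continue
            for active_adapter in lora_module.active_adapters:
                lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
```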
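
The helpers added in `flux/pipeline_tools.py` turn a condition image into packed latent tokens plus their position ids, and encode a prompt through both text encoders with CLIP-overflow warnings suppressed. A rough usage sketch follows, assuming a FLUX.1-schnell pipeline on a CUDA device; the image path, resolution, and prompt are placeholders.

```python
# Illustrative only -- not part of the patch. Rough usage of encode_images and
# prepare_text_input from flux/pipeline_tools.py.
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from flux.pipeline_tools import encode_images, prepare_text_input

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Any RGB image works; imgs/sub_i1.png is one of the bundled examples.
condition_image = Image.open("imgs/sub_i1.png").convert("RGB").resize((512, 512))
cond_tokens, cond_ids = encode_images(pipe, condition_image)

# Prompt embeddings from both text encoders, with overflow warnings silenced.
prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
    pipe, prompts="a product photo on a marble table"
)

print(cond_tokens.shape, cond_ids.shape, prompt_embeds.shape)
```

The `cond_tokens` / `cond_ids` pair is the kind of input the extra-condition forward in `flux/transformer.py` consumes as `condition_latents` and `condition_ids` (and likewise for the extra condition), though the full wiring lives in `flux/generate.py`.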