ultimaxxl committed
Commit 2aedb18 · 2 parents: 6eb3ba1, 421487b

Merge branch 'main' into pr/5

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 5.23.1
app_file: app.py
pinned: false
license: apache-2.0
+ short_description: a subject-driven image generation control toolkit
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,89 +1,106 @@
1
- import os
2
- import base64
3
- import io
4
- from typing import TypedDict
5
- import requests
6
  import gradio as gr
7
- from PIL import Image
 
 
8
 
9
- # Read Baseten configuration from environment variables.
10
- BTEN_API_KEY = os.getenv("API_KEY")
11
- URL = os.getenv("URL")
12
 
13
- def image_to_base64(image: Image.Image) -> str:
14
- """Convert a PIL image to a base64-encoded PNG string."""
15
- with io.BytesIO() as buffer:
16
- image.save(buffer, format="PNG")
17
- return base64.b64encode(buffer.getvalue()).decode("utf-8")
18
 
 
 
 
19
 
20
- def ensure_image(img) -> Image.Image:
21
- """
22
- Ensure the input is a PIL Image.
23
- If it's already a PIL Image, return it.
24
- If it's a string (file path), open it.
25
- If it's a dict with a "name" key, open the file at that path.
26
- """
27
- if isinstance(img, Image.Image):
28
- return img
29
- elif isinstance(img, str):
30
- return Image.open(img)
31
- elif isinstance(img, dict) and "name" in img:
32
- return Image.open(img["name"])
 
 
33
  else:
34
- raise ValueError("Cannot convert input to a PIL Image.")
35
-
36
-
37
- def call_baseten_generate(
38
- image: Image.Image,
39
- prompt: str,
40
- steps: int,
41
- strength: float,
42
- height: int,
43
- width: int,
44
- lora_name: str,
45
- remove_bg: bool,
46
- ) -> Image.Image | None:
 
 
47
  """
48
- Call the Baseten /predict endpoint with provided parameters and return the generated image.
49
  """
50
- image = ensure_image(image)
51
- b64_image = image_to_base64(image)
52
- payload = {
53
- "image": b64_image,
54
- "prompt": prompt,
55
- "steps": steps,
56
- "strength": strength,
57
- "height": height,
58
- "width": width,
59
- "lora_name": lora_name,
60
- "bgrm": remove_bg,
61
- }
62
- if not BTEN_API_KEY:
63
- headers = {"Authorization": f"Api-Key {os.getenv('API_KEY')}"}
64
- else:
65
- headers = {"Authorization": f"Api-Key {BTEN_API_KEY}"}
66
- try:
67
- if not URL:
68
- raise ValueError("The URL environment variable is not set.")
69
-
70
- response = requests.post(URL, headers=headers, json=payload)
71
- if response.status_code == 200:
72
- data = response.json()
73
- gen_b64 = data.get("generated_image", None)
74
- if gen_b64:
75
- return Image.open(io.BytesIO(base64.b64decode(gen_b64)))
76
- else:
77
- return None
78
- else:
79
- print(f"Error: HTTP {response.status_code}\n{response.text}")
80
- return None
81
- except Exception as e:
82
- print(f"Error: {e}")
83
- return None
84
-
85
-
86
- # Mode defaults for each tab.
 
 
87
 
88
  Mode = TypedDict(
89
  "Mode",
@@ -98,77 +115,76 @@ Mode = TypedDict(
98
  },
99
  )
100
 
 
 
101
  MODE_DEFAULTS: dict[str, Mode] = {
102
  "Subject Generation": {
103
- "model": "subject_99000_512",
104
- "prompt": "A detailed portrait with soft lighting",
105
- "default_strength": 1.2,
106
- "default_height": 512,
107
- "default_width": 512,
108
- "models": [
109
- "zendsd_512_146000",
110
- "subject_99000_512",
111
- # "zen_pers_11000",
112
- "zen_26000_512",
113
- ],
114
- "remove_bg": True,
115
- },
116
- "Background Generation": {
117
- "model": "bg_canny_58000_1024",
118
  "prompt": "A vibrant background with dynamic lighting and textures",
119
  "default_strength": 1.2,
120
  "default_height": 1024,
121
  "default_width": 1024,
122
- "models": [
123
- "bgwlight_15000_1024",
124
- # "rmgb_12000_1024",
125
- "bg_canny_58000_1024",
126
- # "gen_back_3000_1024",
127
- "gen_back_7000_1024",
128
- # "gen_bckgnd_18000_512",
129
- # "gen_bckgnd_18000_512",
130
- # "loose_25000_512",
131
- # "looser_23000_1024",
132
- # "looser_bg_gen_21000_1280",
133
- # "old_looser_46000_1024",
134
- # "relight_bg_gen_31000_1024",
135
- ],
136
- "remove_bg": True,
137
- },
138
- "Canny": {
139
- "model": "canny_21000_1024",
140
- "prompt": "A futuristic cityscape with neon lights",
141
- "default_strength": 1.2,
142
- "default_height": 1024,
143
- "default_width": 1024,
144
- "models": ["canny_21000_1024"],
145
  "remove_bg": True,
146
  },
147
- "Depth": {
148
- "model": "depth_9800_1024",
149
- "prompt": "A scene with pronounced depth and perspective",
150
- "default_strength": 1.2,
151
- "default_height": 1024,
152
- "default_width": 1024,
153
- "models": [
154
- "depth_9800_1024",
155
- ],
156
- "remove_bg": True,
157
- },
158
- "Deblurring": {
159
- "model": "deblurr_1024_10000",
160
- "prompt": "A scene with pronounced depth and perspective",
161
- "default_strength": 1.2,
162
- "default_height": 1024,
163
- "default_width": 1024,
164
- "models": ["deblurr_1024_10000"], # "slight_deblurr_18000",
165
- "remove_bg": False,
166
- },
167
- }
168
 
 
169
 
170
  header = """
171
- <h1>🌍 ZenCtrl / FLUX</h1>
172
  <div align="center" style="line-height: 1;">
173
  <a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank" style="margin: 2px;" name="github_repo_link"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg" alt="GitHub Repo" style="display: inline-block; vertical-align: middle;"></a>
174
  <a href="https://huggingface.co/spaces/fotographerai/ZenCtrl" target="_blank" name="huggingface_space_link"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace Space" style="display: inline-block; vertical-align: middle;"></a>
@@ -178,136 +194,110 @@ header = """
178
  </div>
179
  """
180
 
181
- defaults = MODE_DEFAULTS["Subject Generation"]
182
-
183
-
184
- with gr.Blocks(title="🌍 ZenCtrl") as demo:
185
  gr.HTML(header)
186
  gr.Markdown(
187
  """
188
  # ZenCtrl Demo
189
- [WIP] One Agent to Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image—without fine-tuning.
190
  We are first releasing some of the task specific weights and will release the codes soon.
191
  The goal is to unify all of the visual content generation tasks with a single LLM...
192
 
193
- **Modes:**
194
- - **Subject Generation:** Focuses on generating detailed subject portraits.
195
- - **Background Generation:** Creates dynamic, vibrant backgrounds:
196
- You can generate part of the image from sketch while keeping part of it as it is.
197
- - **Canny:** Emphasizes strong edge detection.
198
- - **Depth:** Produces images with realistic depth and perspective.
199
 
200
  For more details, shoot us a message on discord.
201
  """
202
  )
 
 
203
  with gr.Tabs():
204
- for mode in MODE_DEFAULTS:
205
- with gr.Tab(mode):
206
- defaults = MODE_DEFAULTS[mode]
207
- gr.Markdown(f"### {mode} Mode")
208
- gr.Markdown(f"**Default Model:** {defaults['model']}")
209
 
 
210
  with gr.Row():
211
- with gr.Column(scale=2, min_width=370):
212
- input_image = gr.Image(
213
- label="Upload Image",
214
- type="pil",
215
- scale=3,
216
- height=370,
217
- min_width=100,
218
  )
219
- generate_button = gr.Button("Generate")
220
- with gr.Blocks(title="Options"):
221
- model_dropdown = gr.Dropdown(
222
- label="Model",
223
- choices=defaults["models"],
224
- value=defaults["model"],
225
- interactive=True,
226
- )
227
- remove_bg_checkbox = gr.Checkbox(
228
- label="Remove Background", value=defaults["remove_bg"]
229
- )
 
 
 
 
230
 
 
231
  with gr.Column(scale=2):
232
- output_image = gr.Image(
233
- label="Generated Image",
234
- type="pil",
235
- height=573,
236
- scale=4,
237
- min_width=100,
238
- )
239
 
240
- gr.Markdown("#### Prompt")
241
- prompt_box = gr.Textbox(
242
- label="Prompt", value=defaults["prompt"], lines=2
243
- )
244
 
245
- # Wrap generation parameters in an Accordion for collapsible view.
246
- with gr.Accordion("Generation Parameters", open=False):
247
- with gr.Row():
248
- step_slider = gr.Slider(
249
- minimum=2, maximum=28, value=2, step=2, label="Steps"
250
- )
251
- strength_slider = gr.Slider(
252
- minimum=0.5,
253
- maximum=2.0,
254
- value=defaults["default_strength"],
255
- step=0.1,
256
- label="Strength",
257
- )
258
- with gr.Row():
259
- height_slider = gr.Slider(
260
- minimum=512,
261
- maximum=1360,
262
- value=defaults["default_height"],
263
- step=1,
264
- label="Height",
265
- )
266
- width_slider = gr.Slider(
267
- minimum=512,
268
- maximum=1360,
269
- value=defaults["default_width"],
270
- step=1,
271
- label="Width",
272
  )
 
 
273
 
274
- def on_generate_click(
275
- model_name,
276
- prompt,
277
- steps,
278
- strength,
279
- height,
280
- width,
281
- remove_bg,
282
- image,
283
- ):
284
- return call_baseten_generate(
285
- image,
286
- prompt,
287
- steps,
288
- strength,
289
- height,
290
- width,
291
- model_name,
292
- remove_bg,
293
  )
294
 
295
- generate_button.click(
296
- fn=on_generate_click,
297
- inputs=[
298
- model_dropdown,
299
- prompt_box,
300
- step_slider,
301
- strength_slider,
302
- height_slider,
303
- width_slider,
304
- remove_bg_checkbox,
305
- input_image,
306
- ],
307
  outputs=[output_image],
308
- concurrency_limit=None
309
  )
310
 
 
 
311
 
 
312
  if __name__ == "__main__":
313
- demo.launch()
 
 
 
 
 
1
+ import spaces
 
 
 
 
2
  import gradio as gr
3
+ import torch
4
+ from typing import TypedDict
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ from diffusers.pipelines import FluxPipeline
7
+ from diffusers import FluxTransformer2DModel
8
+ import numpy as np
9
+ import examples_db
10
 
 
 
 
11
 
 
 
 
 
 
12
 
13
+ from flux.condition import Condition
14
+ from flux.generate import seed_everything, generate
15
+ from flux.lora_controller import set_lora_scale
16
 
17
+ pipe = None
18
+ current_adapter = None
19
+ use_int8 = False
20
+ model_config = { "union_cond_attn": True, "add_cond_attn": False, "latent_lora": False, "independent_condition": True}
21
+
22
+ def get_gpu_memory():
23
+ return torch.cuda.get_device_properties(0).total_memory / 1024**3
24
+
25
+
26
+ def init_pipeline():
27
+ global pipe
28
+ if use_int8 or get_gpu_memory() < 33:
29
+ transformer_model = FluxTransformer2DModel.from_pretrained(
30
+ "sayakpaul/flux.1-schell-int8wo-improved",
31
+ torch_dtype=torch.bfloat16,
32
+ use_safetensors=False,
33
+ )
34
+ pipe = FluxPipeline.from_pretrained(
35
+ "black-forest-labs/FLUX.1-schnell",
36
+ transformer=transformer_model,
37
+ torch_dtype=torch.bfloat16,
38
+ )
39
  else:
40
+ pipe = FluxPipeline.from_pretrained(
41
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
42
+ )
43
+ pipe = pipe.to("cuda")
44
+
45
+
46
+ # Optional: Load additional LoRA weights
47
+ pipe.load_lora_weights(
48
+ "fotographerai/zenctrl_tools",
49
+ weight_name="weights/zen2con_1024_10000/"
50
+ "pytorch_lora_weights.safetensors",
51
+ adapter_name="subject"
52
+ )
53
+
54
+ # Optional: Load additional LoRA weights
55
+ #pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism")
56
+
57
+
58
+ def paste_on_white_background(image: Image.Image) -> Image.Image:
59
  """
60
+ Pastes a transparent image onto a white background of the same size.
61
  """
62
+ if image.mode != "RGBA":
63
+ image = image.convert("RGBA")
64
+
65
+ # Create white background
66
+ white_bg = Image.new("RGBA", image.size, (255, 255, 255, 255))
67
+ white_bg.paste(image, (0, 0), mask=image)
68
+ return white_bg.convert("RGB") # Convert back to RGB if you don't need alpha
69
+
70
+ #@spaces.GPU
71
+ def process_image_and_text(image, text, steps=8, strength_sub=1.0, strength_spat=1.0, size=1024):
72
+ # center crop image
73
+ w, h, min_size = image.size[0], image.size[1], min(image.size)
74
+ image = image.crop(
75
+ (
76
+ (w - min_size) // 2,
77
+ (h - min_size) // 2,
78
+ (w + min_size) // 2,
79
+ (h + min_size) // 2,
80
+ )
81
+ )
82
+ image = image.resize((size, size))
83
+ image = paste_on_white_background(image)
84
+ condition0 = Condition("subject", image, position_delta=(0, size // 16))
85
+ condition1 = Condition("subject", image, position_delta=(0, -size // 16))
86
+
87
+ pipe = get_pipeline()
88
+
89
+ with set_lora_scale(["subject"], scale=3.0):
90
+ result_img = generate(
91
+ pipe,
92
+ prompt=text.strip(),
93
+ conditions=[condition0, condition1],
94
+ num_inference_steps=steps,
95
+ height=1024,
96
+ width=1024,
97
+ condition_scale = [strength_sub,strength_spat],
98
+ model_config=model_config,
99
+ ).images[0]
100
+
101
+ return result_img
102
+
103
+ # ================== MODE CONFIG =====================
104
 
105
  Mode = TypedDict(
106
  "Mode",
 
115
  },
116
  )
117
 
118
+ MODEL_TO_LORA: dict[str, str] = {
119
+ # dropdown-value # relative path inside the HF repo
120
+ "zen2con_1024_10000": "weights/zen2con_1024_10000/pytorch_lora_weights.safetensors",
121
+ "zen2con_1440_17000": "weights/zen2con_1440_17000/pytorch_lora_weights.safetensors",
122
+ "zen_sub_sub_1024_10000": "weights/zen_sub_sub_1024_10000/pytorch_lora_weights.safetensors",
123
+ "zen_toys_1024_4000": "weights/zen_toys_1024_4000/12000/pytorch_lora_weights.safetensors",
124
+ "zen_toys_1024_15000": "weights/zen_toys_1024_4000/zen_toys_1024_15000/pytorch_lora_weights.safetensors",
125
+ # add more as you upload them
126
+ }
127
+
128
+
129
  MODE_DEFAULTS: dict[str, Mode] = {
130
  "Subject Generation": {
131
+ "model": "zen2con_1024_10000",
 
 
132
  "prompt": "A vibrant background with dynamic lighting and textures",
133
  "default_strength": 1.2,
134
  "default_height": 1024,
135
  "default_width": 1024,
136
+ "models": list(MODEL_TO_LORA.keys()),
 
 
137
  "remove_bg": True,
138
  },
139
+ #"Image fix": {
140
+ # "model": "zen_toys_1024_4000",
141
+ # "prompt": "A detailed portrait with soft lighting",
142
+ # "default_strength": 1.2,
143
+ # "default_height": 1024,
144
+ # "default_width": 1024,
145
+ # "models": ["weights/zen_toys_1024_4000/12000/", "weights/zen_toys_1024_4000/12000/"],
146
+ # "remove_bg": True,
147
+ #}
148
+ }
149
+
150
+
151
+ def get_pipeline():
152
+ """Lazy-build the pipeline inside the GPU worker."""
153
+ global pipe
154
+ if pipe is None:
155
+ init_pipeline() # safe here – this fn is @spaces.GPU wrapped
156
+ return pipe
157
+
 
 
158
 
159
+ def get_samples():
160
+ sample_list = [
161
+ {
162
+ "image": "samples/1.png",
163
+ "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
164
+ },
165
+ {
166
+ "image": "samples/2.png",
167
+ "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
168
+ },
169
+ {
170
+ "image": "samples/3.png",
171
+ "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
172
+ },
173
+ {
174
+ "image": "samples/4.png",
175
+ "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
176
+ },
177
+ {
178
+ "image": "samples/5.png",
179
+ "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her.",
180
+ },
181
+ ]
182
+ return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
183
+
184
+ # =============== UI ===============
185
 
186
  header = """
187
+ <h1>🌍 ZenCtrl medium</h1>
188
  <div align="center" style="line-height: 1;">
189
  <a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank" style="margin: 2px;" name="github_repo_link"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg" alt="GitHub Repo" style="display: inline-block; vertical-align: middle;"></a>
190
  <a href="https://huggingface.co/spaces/fotographerai/ZenCtrl" target="_blank" name="huggingface_space_link"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace Space" style="display: inline-block; vertical-align: middle;"></a>
 
194
  </div>
195
  """
196
 
197
+ with gr.Blocks(title="🌍 ZenCtrl-medium") as demo:
198
+ # ---------- banner ----------
 
 
199
  gr.HTML(header)
200
  gr.Markdown(
201
  """
202
  # ZenCtrl Demo
203
+ One framework to generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image, without fine-tuning.
204
  We are first releasing some of the task specific weights and will release the codes soon.
205
  The goal is to unify all of the visual content generation tasks with a single LLM...
206
 
207
+ **Mode:**
208
+ - **Subject-driven Image Generation:** Generate in-context images of your subject with high fidelity and from different perspectives.
 
 
 
 
209
 
210
  For more details, shoot us a message on discord.
211
  """
212
  )
213
+
214
+ # ---------- tab bar ----------
215
  with gr.Tabs():
216
+ for mode_name, defaults in MODE_DEFAULTS.items():
217
+ with gr.Tab(mode_name):
218
+ gr.Markdown(f"### {mode_name}")
 
 
219
 
220
+ # -------- left (input) column --------
221
  with gr.Row():
222
+ with gr.Column(scale=2):
223
+ input_image = gr.Image(label="Input Image", type="pil")
224
+ model_dropdown = gr.Dropdown(
225
+ label="Model (LoRA adapter)",
226
+ choices=defaults["models"],
227
+ value=defaults["model"],
228
+ interactive=True,
229
  )
230
+ prompt_box = gr.Textbox(label="Prompt",
231
+ value=defaults["prompt"], lines=2)
232
+ generate_btn = gr.Button("Generate")
233
+
234
+ with gr.Accordion("Generation Parameters", open=False):
235
+ step_slider = gr.Slider(2, 28, value=12, step=2, label="Steps")
236
+ strength_sub_slider = gr.Slider(0.0, 2.0,
237
+ value=defaults["default_strength"],
238
+ step=0.1, label="Strength (subject)")
239
+ strength_spat_slider = gr.Slider(0.0, 2.0,
240
+ value=defaults["default_strength"],
241
+ step=0.1, label="Strength (spatial)")
242
+ size_slider = gr.Slider(512, 2048,
243
+ value=defaults["default_height"],
244
+ step=64, label="Size (px)")
245
 
246
+ # -------- right (output) column --------
247
  with gr.Column(scale=2):
248
+ output_image = gr.Image(label="Output Image", type="pil")
 
 
249
 
250
+ # ---------- click handler ----------
251
+ @spaces.GPU
252
+ def _run(image, model_name, prompt, steps, s_sub, s_spat, size):
253
+ global current_adapter
254
 
255
+ pipe = get_pipeline()
256
+
257
+ # ── switch adapter if needed ──────────────────────────
258
+ if model_name != current_adapter:
259
+ lora_path = MODEL_TO_LORA[model_name]
260
+ # load & activate the chosen adapter
261
+ pipe.load_lora_weights(
262
+ "fotographerai/zenctrl_tools",
263
+ weight_name=lora_path,
264
+ adapter_name=model_name,
 
 
265
  )
266
+ pipe.set_adapters([model_name])
267
+ current_adapter = model_name
268
 
269
+ # ── run generation ───────────────────────────────────
270
+ delta = size // 16
271
+ return process_image_and_text(
272
+ image, prompt, steps=steps,
273
+ strength_sub=s_sub, strength_spat=s_spat, size=size
 
 
274
  )
275
 
276
+ generate_btn.click(
277
+ fn=_run,
278
+ inputs=[input_image, model_dropdown, prompt_box,
279
+ step_slider, strength_sub_slider,
280
+ strength_spat_slider, size_slider],
 
 
281
  outputs=[output_image],
 
282
  )
283
 
284
+ # ---------------- Templates --------------------
285
+ if examples_db.MODE_EXAMPLES.get(mode_name):
286
+ gr.Examples(
287
+ examples=examples_db.MODE_EXAMPLES[mode_name],
288
+ inputs=[ input_image, # Image widget
289
+ model_dropdown, # Dropdown for adapter
290
+ prompt_box, # Textbox for prompt
291
+ output_image, # Gallery for output
292
+ ],
293
+ label="Presets (Image / Model / Prompt)",
294
+ examples_per_page=15,
295
+ )
296
 
297
+ # =============== launch ===============
298
  if __name__ == "__main__":
299
+ #init_pipeline()
300
+ demo.launch(
301
+ debug=True,
302
+ share=True
303
+ )
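A note for readers who want to exercise the new generation path in app.py outside of Gradio: the sketch below is illustrative only and not part of this commit. It assumes a CUDA GPU, this repo's flux/ package on the import path, a local samples/1.png, and the zen2con_1024_10000 weights listed in MODEL_TO_LORA; it follows the same steps as init_pipeline() and process_image_and_text() above.

```python
# Minimal sketch of the same code path as process_image_and_text() in app.py.
# Assumptions: CUDA GPU, this repo's flux/ package importable, sample image on disk.
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from flux.condition import Condition
from flux.generate import generate, seed_everything
from flux.lora_controller import set_lora_scale

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "fotographerai/zenctrl_tools",
    weight_name="weights/zen2con_1024_10000/pytorch_lora_weights.safetensors",
    adapter_name="subject",
)

size = 1024
image = Image.open("samples/1.png").convert("RGB").resize((size, size))

# Two copies of the subject, offset vertically in token space, as in app.py.
condition0 = Condition("subject", image, position_delta=(0, size // 16))
condition1 = Condition("subject", image, position_delta=(0, -size // 16))

seed_everything(42)
with set_lora_scale(["subject"], scale=3.0):
    result = generate(
        pipe,
        prompt="A vibrant background with dynamic lighting and textures",
        conditions=[condition0, condition1],
        num_inference_steps=8,
        height=size,
        width=size,
        condition_scale=[1.0, 1.0],
        model_config={
            "union_cond_attn": True,
            "add_cond_attn": False,
            "latent_lora": False,
            "independent_condition": True,
        },
    ).images[0]

result.save("zenctrl_out.png")
```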
examples_db.py ADDED
@@ -0,0 +1,96 @@
1
+ MODE_EXAMPLES = {
2
+ "Subject Generation": [
3
+ [
4
+ "samples/7.png",
5
+ "zen2con_1440_17000",
6
+ "A man wearing white shoes stepping outside, closeup on the shoes",
7
+ "samples/22_out.webp",
8
+ ],
9
+ [
10
+ "samples/7.png",
11
+ "zen2con_1440_17000",
12
+ "Low angle photography, shoes, stepping in water in a futuristic cityscape with neon lights, in the back in large, the words 'ZenCtrl 2025' are written in a futuristic font",
13
+ "samples/2_out.webp",
14
+ ],
15
+ [
16
+ "samples/8.png",
17
+ "zen2con_1024_10000",
18
+ "a watch, resting on wet volcanic rock, ocean spray mist, golden sunrise back-light, 50 mm lens, shallow DOF, 8k realism",
19
+ "samples/3_out.webp",
20
+ ],
21
+ ["samples/9.png","zen_sub_sub_1024_10000", "a bag, placed on a desk in a cozy bedroom, sunny light coming through the window", "samples/4_out.webp"],
22
+ ["samples/10.png", "zen2con_1440_17000","A bottle sits poolside at a luxury hotel. The bottle rests in front of a clear blue pool and resort landscape in the background. Light reflects off the bottle highlighting its sophisticated design. An elegant glass and tropical fruits are included creating a relaxing atmosphere. The image should convey luxury sophistication and an otherworldly refreshment.", "samples/5_out.webp"],
23
+
24
+ #[
25
+ # "samples/11.png",
26
+ # "zen_toys_1024_4000",
27
+ # "Low angle photography, shoes, stepping in water in a futuristic cityscape with neon lights, in the back in large, the words 'ZenCtrl 2025' are written in a futuristic font",
28
+ # "samples/1.png",
29
+ #],
30
+ [
31
+ "samples/12.png",
32
+ "zen2con_1024_10000",
33
+ "a woman , wearing a pair of sunglasses, in front of the beach",
34
+ "samples/6_out.webp",
35
+ ],
36
+ [
37
+ "samples/13.png",
38
+ "zen2con_1440_17000",
39
+ "a woman , standing outside , in the streets, next to a cafe",
40
+ "samples/7_out.webp",
41
+ ],
42
+ ["samples/14.png","zen_sub_sub_1024_10000", "a bag , held by woman , walking outside", "samples/8_out.webp"],
43
+ ["samples/15.png","zen_sub_sub_1024_10000", "a man wearing a suit, sitting on a chair in a living room", "samples/9_out.webp"],
44
+ [
45
+ "samples/6.png",
46
+ "zen_toys_1024_4000",
47
+ "A kid playing with a toy figurine , indoor, on a sunny day",
48
+ "samples/11_out.webp",
49
+ ],
50
+ [
51
+ "samples/6.png",
52
+ "zen_toys_1024_4000",
53
+ "A small figurine placed on a table, surrounded by toys. A child is placing it",
54
+ "samples/12_out.webp",
55
+ ],
56
+ [
57
+ "samples/17.png",
58
+ "zen2con_1024_10000",
59
+ "A black man wearing black wireless headphones at a basketball game",
60
+ "samples/20_out.webp",
61
+ ],
62
+ [
63
+ "samples/21_1.png",
64
+ "zen2con_1024_10000",
65
+ "a man holding a camera facing the objective.",
66
+ "samples/21_out.webp",
67
+ ],
68
+ ],
69
+ "Image fix": [
70
+ [
71
+ "samples/1.png",
72
+ "placed on a dark marble table in a bathroom of luxury hotel modern light authentic atmosphere",
73
+ "samples/1.png",
74
+ ],
75
+ [
76
+ "samples/1.png",
77
+ "sitting on the middle of the city road on a sunny day very bright day front view",
78
+ "samples/1.png",
79
+ ],
80
+ [
81
+ "samples/1.png",
82
+ "A creative capture in an art gallery, with soft, focused lighting highlighting both the person’s features and the abstract surroundings, exuding sophistication.",
83
+ "samples/1.png",
84
+ ],
85
+ [
86
+ "samples/1.png",
87
+ "In a rain-soaked urban nightscape, with headlights piercing through the mist and wet streets reflecting the city’s vibrant neon colors, creating an atmosphere of mystery and modern elegance.",
88
+ "samples/1.png",
89
+ ],
90
+ [
91
+ "samples/1.png",
92
+ "An elegant room scene featuring a minimalist table and chairs, next to a flower vase, illuminated by ambient lighting that casts gentle shadows and enhances the refined, contemporary decor.",
93
+ "samples/1.png",
94
+ ],
95
+ ],
96
+ }
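Each preset row above is positional: gr.Examples in app.py maps it to [input image, adapter name, prompt, output image], while the commented-out "Image fix" rows omit the adapter column and are currently unused. The standalone check below is illustrative only; it assumes it is run from the Space root, and the adapter names are copied by hand from MODEL_TO_LORA in app.py rather than imported:

```python
# Illustrative check only (not part of this commit): verify that every preset
# references an existing sample file and a known LoRA adapter name.
import os
import examples_db

# Copied from MODEL_TO_LORA in app.py; kept in sync by hand for this sketch.
KNOWN_ADAPTERS = {
    "zen2con_1024_10000",
    "zen2con_1440_17000",
    "zen_sub_sub_1024_10000",
    "zen_toys_1024_4000",
    "zen_toys_1024_15000",
}

for mode, rows in examples_db.MODE_EXAMPLES.items():
    for row in rows:
        if len(row) == 4:                      # [image_in, adapter, prompt, image_out]
            image_in, adapter, _prompt, image_out = row
            assert adapter in KNOWN_ADAPTERS, f"{mode}: unknown adapter {adapter!r}"
        else:                                  # "Image fix" rows: [image_in, prompt, image_out]
            image_in, _prompt, image_out = row
        for path in (image_in, image_out):
            assert os.path.exists(path), f"{mode}: missing sample file {path}"

print("all preset rows reference known adapters and existing files")
```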
flux/__init__.py ADDED
File without changes
flux/block.py ADDED
@@ -0,0 +1,461 @@
1
+ # Recycled from Ominicontrol and modified to accept an extra condition.
2
+ # While Zenctrl pursued a similar idea, it diverged structurally.
3
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
4
+
5
+ import torch
6
+ from typing import List, Union, Optional, Dict, Any, Callable
7
+ from diffusers.models.attention_processor import Attention, F
8
+ from .lora_controller import enable_lora
9
+ from diffusers.models.embeddings import apply_rotary_emb
10
+
11
+ def attn_forward(
12
+ attn: Attention,
13
+ hidden_states: torch.FloatTensor,
14
+ encoder_hidden_states: torch.FloatTensor = None,
15
+ condition_latents: torch.FloatTensor = None,
16
+ extra_condition_latents: torch.FloatTensor = None,
17
+ attention_mask: Optional[torch.FloatTensor] = None,
18
+ image_rotary_emb: Optional[torch.Tensor] = None,
19
+ cond_rotary_emb: Optional[torch.Tensor] = None,
20
+ extra_cond_rotary_emb: Optional[torch.Tensor] = None,
21
+ model_config: Optional[Dict[str, Any]] = {},
22
+ ) -> torch.FloatTensor:
23
+ batch_size, _, _ = (
24
+ hidden_states.shape
25
+ if encoder_hidden_states is None
26
+ else encoder_hidden_states.shape
27
+ )
28
+
29
+ with enable_lora(
30
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
31
+ ):
32
+ # `sample` projections.
33
+ query = attn.to_q(hidden_states)
34
+ key = attn.to_k(hidden_states)
35
+ value = attn.to_v(hidden_states)
36
+
37
+ inner_dim = key.shape[-1]
38
+ head_dim = inner_dim // attn.heads
39
+
40
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
41
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
42
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
43
+
44
+ if attn.norm_q is not None:
45
+ query = attn.norm_q(query)
46
+ if attn.norm_k is not None:
47
+ key = attn.norm_k(key)
48
+
49
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
50
+ if encoder_hidden_states is not None:
51
+ # `context` projections.
52
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
53
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
54
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
55
+
56
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
57
+ batch_size, -1, attn.heads, head_dim
58
+ ).transpose(1, 2)
59
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
60
+ batch_size, -1, attn.heads, head_dim
61
+ ).transpose(1, 2)
62
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
63
+ batch_size, -1, attn.heads, head_dim
64
+ ).transpose(1, 2)
65
+
66
+ if attn.norm_added_q is not None:
67
+ encoder_hidden_states_query_proj = attn.norm_added_q(
68
+ encoder_hidden_states_query_proj
69
+ )
70
+ if attn.norm_added_k is not None:
71
+ encoder_hidden_states_key_proj = attn.norm_added_k(
72
+ encoder_hidden_states_key_proj
73
+ )
74
+
75
+ # attention
76
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
77
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
78
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
79
+
80
+ if image_rotary_emb is not None:
81
+
82
+
83
+ query = apply_rotary_emb(query, image_rotary_emb)
84
+ key = apply_rotary_emb(key, image_rotary_emb)
85
+
86
+ if condition_latents is not None:
87
+ cond_query = attn.to_q(condition_latents)
88
+ cond_key = attn.to_k(condition_latents)
89
+ cond_value = attn.to_v(condition_latents)
90
+
91
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
92
+ 1, 2
93
+ )
94
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
95
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
96
+ 1, 2
97
+ )
98
+ if attn.norm_q is not None:
99
+ cond_query = attn.norm_q(cond_query)
100
+ if attn.norm_k is not None:
101
+ cond_key = attn.norm_k(cond_key)
102
+
103
+ #extra condition
104
+ if extra_condition_latents is not None:
105
+ extra_cond_query = attn.to_q(extra_condition_latents)
106
+ extra_cond_key = attn.to_k(extra_condition_latents)
107
+ extra_cond_value = attn.to_v(extra_condition_latents)
108
+
109
+ extra_cond_query = extra_cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
110
+ 1, 2
111
+ )
112
+ extra_cond_key = extra_cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
113
+ extra_cond_value = extra_cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
114
+ 1, 2
115
+ )
116
+ if attn.norm_q is not None:
117
+ extra_cond_query = attn.norm_q(extra_cond_query)
118
+ if attn.norm_k is not None:
119
+ extra_cond_key = attn.norm_k(extra_cond_key)
120
+
121
+
122
+ if extra_cond_rotary_emb is not None:
123
+ extra_cond_query = apply_rotary_emb(extra_cond_query, extra_cond_rotary_emb)
124
+ extra_cond_key = apply_rotary_emb(extra_cond_key, extra_cond_rotary_emb)
125
+
126
+ if cond_rotary_emb is not None:
127
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
128
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
129
+
130
+ if condition_latents is not None:
131
+ if extra_condition_latents is not None:
132
+
133
+ query = torch.cat([query, cond_query, extra_cond_query], dim=2)
134
+ key = torch.cat([key, cond_key, extra_cond_key], dim=2)
135
+ value = torch.cat([value, cond_value, extra_cond_value], dim=2)
136
+ else:
137
+ query = torch.cat([query, cond_query], dim=2)
138
+ key = torch.cat([key, cond_key], dim=2)
139
+ value = torch.cat([value, cond_value], dim=2)
140
+ print("concat Omini latents: ", query.shape, key.shape, value.shape)
141
+
142
+
143
+ if not model_config.get("union_cond_attn", True):
144
+
145
+ attention_mask = torch.ones(
146
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
147
+ )
148
+ condition_n = cond_query.shape[2]
149
+ attention_mask[-condition_n:, :-condition_n] = False
150
+ attention_mask[:-condition_n, -condition_n:] = False
151
+ elif model_config.get("independent_condition", False):
152
+ attention_mask = torch.ones(
153
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
154
+ )
155
+ condition_n = cond_query.shape[2]
156
+ attention_mask[-condition_n:, :-condition_n] = False
157
+
158
+ if hasattr(attn, "c_factor"):
159
+ attention_mask = torch.zeros(
160
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
161
+ )
162
+ condition_n = cond_query.shape[2]
163
+ condition_e = extra_cond_query.shape[2]
164
+ bias = torch.log(attn.c_factor[0])
165
+ attention_mask[-condition_n-condition_e:-condition_e, :-condition_n-condition_e] = bias
166
+ attention_mask[:-condition_n-condition_e, -condition_n-condition_e:-condition_e] = bias
167
+
168
+ bias = torch.log(attn.c_factor[1])
169
+ attention_mask[-condition_e:, :-condition_n-condition_e] = bias
170
+ attention_mask[:-condition_n-condition_e, -condition_e:] = bias
171
+
172
+ hidden_states = F.scaled_dot_product_attention(
173
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
174
+ )
175
+ hidden_states = hidden_states.transpose(1, 2).reshape(
176
+ batch_size, -1, attn.heads * head_dim
177
+ )
178
+ hidden_states = hidden_states.to(query.dtype)
179
+
180
+ if encoder_hidden_states is not None:
181
+ if condition_latents is not None:
182
+ if extra_condition_latents is not None:
183
+ encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
184
+ hidden_states[:, : encoder_hidden_states.shape[1]],
185
+ hidden_states[
186
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]*2
187
+ ],
188
+ hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
189
+ hidden_states[:, -condition_latents.shape[1] :], #extra condition latents
190
+ )
191
+ else:
192
+ encoder_hidden_states, hidden_states, condition_latents = (
193
+ hidden_states[:, : encoder_hidden_states.shape[1]],
194
+ hidden_states[
195
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
196
+ ],
197
+ hidden_states[:, -condition_latents.shape[1] :]
198
+ )
199
+ else:
200
+ encoder_hidden_states, hidden_states = (
201
+ hidden_states[:, : encoder_hidden_states.shape[1]],
202
+ hidden_states[:, encoder_hidden_states.shape[1] :],
203
+ )
204
+
205
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
206
+ # linear proj
207
+ hidden_states = attn.to_out[0](hidden_states)
208
+ # dropout
209
+ hidden_states = attn.to_out[1](hidden_states)
210
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
211
+
212
+ if condition_latents is not None:
213
+ condition_latents = attn.to_out[0](condition_latents)
214
+ condition_latents = attn.to_out[1](condition_latents)
215
+
216
+ if extra_condition_latents is not None:
217
+ extra_condition_latents = attn.to_out[0](extra_condition_latents)
218
+ extra_condition_latents = attn.to_out[1](extra_condition_latents)
219
+
220
+
221
+ return (
222
+ # (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
223
+ (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
224
+ if condition_latents is not None
225
+ else (hidden_states, encoder_hidden_states)
226
+ )
227
+ elif condition_latents is not None:
228
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
229
+ if extra_condition_latents is not None:
230
+ hidden_states, condition_latents, extra_condition_latents = (
231
+ hidden_states[:, : -condition_latents.shape[1]*2],
232
+ hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
233
+ hidden_states[:, -condition_latents.shape[1] :],
234
+ )
235
+ else:
236
+ hidden_states, condition_latents = (
237
+ hidden_states[:, : -condition_latents.shape[1]],
238
+ hidden_states[:, -condition_latents.shape[1] :],
239
+ )
240
+ return hidden_states, condition_latents, extra_condition_latents
241
+ else:
242
+ return hidden_states
243
+
244
+
245
+ def block_forward(
246
+ self,
247
+ hidden_states: torch.FloatTensor,
248
+ encoder_hidden_states: torch.FloatTensor,
249
+ condition_latents: torch.FloatTensor,
250
+ extra_condition_latents: torch.FloatTensor,
251
+ temb: torch.FloatTensor,
252
+ cond_temb: torch.FloatTensor,
253
+ extra_cond_temb: torch.FloatTensor,
254
+ cond_rotary_emb=None,
255
+ extra_cond_rotary_emb=None,
256
+ image_rotary_emb=None,
257
+ model_config: Optional[Dict[str, Any]] = {},
258
+ ):
259
+ use_cond = condition_latents is not None
260
+
261
+ use_extra_cond = extra_condition_latents is not None
262
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
263
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
264
+ hidden_states, emb=temb
265
+ )
266
+
267
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
268
+ self.norm1_context(encoder_hidden_states, emb=temb)
269
+ )
270
+
271
+ if use_cond:
272
+ (
273
+ norm_condition_latents,
274
+ cond_gate_msa,
275
+ cond_shift_mlp,
276
+ cond_scale_mlp,
277
+ cond_gate_mlp,
278
+ ) = self.norm1(condition_latents, emb=cond_temb)
279
+ (
280
+ norm_extra_condition_latents,
281
+ extra_cond_gate_msa,
282
+ extra_cond_shift_mlp,
283
+ extra_cond_scale_mlp,
284
+ extra_cond_gate_mlp,
285
+ ) = self.norm1(extra_condition_latents, emb=extra_cond_temb)
286
+
287
+ # Attention.
288
+ result = attn_forward(
289
+ self.attn,
290
+ model_config=model_config,
291
+ hidden_states=norm_hidden_states,
292
+ encoder_hidden_states=norm_encoder_hidden_states,
293
+ condition_latents=norm_condition_latents if use_cond else None,
294
+ extra_condition_latents=norm_extra_condition_latents if use_cond else None,
295
+ image_rotary_emb=image_rotary_emb,
296
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
297
+ extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_cond else None,
298
+ )
299
+ # print("in self block: ", result.shape)
300
+ attn_output, context_attn_output = result[:2]
301
+ cond_attn_output = result[2] if use_cond else None
302
+ extra_condition_output = result[3]
303
+
304
+ # Process attention outputs for the `hidden_states`.
305
+ # 1. hidden_states
306
+ attn_output = gate_msa.unsqueeze(1) * attn_output
307
+ hidden_states = hidden_states + attn_output
308
+ # 2. encoder_hidden_states
309
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
310
+
311
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
312
+ # 3. condition_latents
313
+ if use_cond:
314
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
315
+ condition_latents = condition_latents + cond_attn_output
316
+ #need to make new condition_extra and add extra_condition_output
317
+ if use_extra_cond:
318
+ extra_condition_output = extra_cond_gate_msa.unsqueeze(1) * extra_condition_output
319
+ extra_condition_latents = extra_condition_latents + extra_condition_output
320
+
321
+ if model_config.get("add_cond_attn", False):
322
+ hidden_states += cond_attn_output
323
+ hidden_states += extra_condition_output
324
+
325
+
326
+ # LayerNorm + MLP.
327
+ # 1. hidden_states
328
+ norm_hidden_states = self.norm2(hidden_states)
329
+ norm_hidden_states = (
330
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
331
+ )
332
+ # 2. encoder_hidden_states
333
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
334
+ norm_encoder_hidden_states = (
335
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
336
+ )
337
+ # 3. condition_latents
338
+ if use_cond:
339
+ norm_condition_latents = self.norm2(condition_latents)
340
+ norm_condition_latents = (
341
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
342
+ + cond_shift_mlp[:, None]
343
+ )
344
+
345
+ if use_extra_cond:
346
+ #added conditions
347
+ extra_norm_condition_latents = self.norm2(extra_condition_latents)
348
+ extra_norm_condition_latents = (
349
+ extra_norm_condition_latents * (1 + extra_cond_scale_mlp[:, None])
350
+ + extra_cond_shift_mlp[:, None]
351
+ )
352
+
353
+ # Feed-forward.
354
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
355
+ # 1. hidden_states
356
+ ff_output = self.ff(norm_hidden_states)
357
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
358
+ # 2. encoder_hidden_states
359
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
360
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
361
+ # 3. condition_latents
362
+ if use_cond:
363
+ cond_ff_output = self.ff(norm_condition_latents)
364
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
365
+
366
+ if use_extra_cond:
367
+ extra_cond_ff_output = self.ff(extra_norm_condition_latents)
368
+ extra_cond_ff_output = extra_cond_gate_mlp.unsqueeze(1) * extra_cond_ff_output
369
+
370
+ # Process feed-forward outputs.
371
+ hidden_states = hidden_states + ff_output
372
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
373
+ if use_cond:
374
+ condition_latents = condition_latents + cond_ff_output
375
+ if use_extra_cond:
376
+ extra_condition_latents = extra_condition_latents + extra_cond_ff_output
377
+
378
+ # Clip to avoid overflow.
379
+ if encoder_hidden_states.dtype == torch.float16:
380
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
381
+
382
+ return encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents if use_cond else None
383
+
384
+
385
+ def single_block_forward(
386
+ self,
387
+ hidden_states: torch.FloatTensor,
388
+ temb: torch.FloatTensor,
389
+ image_rotary_emb=None,
390
+ condition_latents: torch.FloatTensor = None,
391
+ extra_condition_latents: torch.FloatTensor = None,
392
+ cond_temb: torch.FloatTensor = None,
393
+ extra_cond_temb: torch.FloatTensor = None,
394
+ cond_rotary_emb=None,
395
+ extra_cond_rotary_emb=None,
396
+ model_config: Optional[Dict[str, Any]] = {},
397
+ ):
398
+
399
+ using_cond = condition_latents is not None
400
+ using_extra_cond = extra_condition_latents is not None
401
+ residual = hidden_states
402
+ with enable_lora(
403
+ (
404
+ self.norm.linear,
405
+ self.proj_mlp,
406
+ ),
407
+ model_config.get("latent_lora", False),
408
+ ):
409
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
410
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
411
+ if using_cond:
412
+ residual_cond = condition_latents
413
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
414
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
415
+
416
+ if using_extra_cond:
417
+ extra_residual_cond = extra_condition_latents
418
+ extra_norm_condition_latents, extra_cond_gate = self.norm(extra_condition_latents, emb=extra_cond_temb)
419
+ extra_mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(extra_norm_condition_latents))
420
+
421
+ attn_output = attn_forward(
422
+ self.attn,
423
+ model_config=model_config,
424
+ hidden_states=norm_hidden_states,
425
+ image_rotary_emb=image_rotary_emb,
426
+ **(
427
+ {
428
+ "condition_latents": norm_condition_latents,
429
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
430
+ "extra_condition_latents": extra_norm_condition_latents if using_cond else None,
431
+ "extra_cond_rotary_emb": extra_cond_rotary_emb if using_cond else None,
432
+ }
433
+ if using_cond
434
+ else {}
435
+ ),
436
+ )
437
+
438
+ if using_cond:
439
+ attn_output, cond_attn_output, extra_cond_attn_output = attn_output
440
+
441
+
442
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
443
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
444
+ gate = gate.unsqueeze(1)
445
+ hidden_states = gate * self.proj_out(hidden_states)
446
+ hidden_states = residual + hidden_states
447
+ if using_cond:
448
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
449
+ cond_gate = cond_gate.unsqueeze(1)
450
+ condition_latents = cond_gate * self.proj_out(condition_latents)
451
+ condition_latents = residual_cond + condition_latents
452
+
453
+ extra_condition_latents = torch.cat([extra_cond_attn_output, extra_mlp_cond_hidden_states], dim=2)
454
+ extra_cond_gate = extra_cond_gate.unsqueeze(1)
455
+ extra_condition_latents = extra_cond_gate * self.proj_out(extra_condition_latents)
456
+ extra_condition_latents = extra_residual_cond + extra_condition_latents
457
+
458
+ if hidden_states.dtype == torch.float16:
459
+ hidden_states = hidden_states.clip(-65504, 65504)
460
+
461
+ return hidden_states if not using_cond else (hidden_states, condition_latents, extra_condition_latents)
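One detail of attn_forward above that is easy to miss: when attn.c_factor is set (via condition_scale in generate()), the per-condition strengths are written into the attention mask as log-biases. Because F.scaled_dot_product_attention adds a floating-point attn_mask to the scaled logits, a bias of log(c) multiplies the corresponding attention weights by c before the softmax renormalizes them. The snippet below is a self-contained check of that identity using plain PyTorch and stand-in tensor shapes; nothing in it depends on the Flux blocks.

```python
# Illustrative only: shows why attn_forward() writes log(condition_scale) into
# the attention mask. A float attn_mask is added to the scaled QK^T logits.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
head_dim = 8
q = torch.randn(1, 1, 4, head_dim)   # (batch, heads, image tokens, head_dim)
k = torch.randn(1, 1, 6, head_dim)   # last 2 key tokens stand in for condition tokens
v = torch.randn(1, 1, 6, head_dim)

scale = 1.5                                    # e.g. one entry of condition_scale
bias = torch.zeros(4, 6)
bias[:, -2:] = torch.log(torch.tensor(scale))  # boost image -> condition attention

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# Manual equivalent: softmax(QK^T / sqrt(d) + bias) V
logits = q @ k.transpose(-1, -2) / head_dim**0.5
manual = torch.softmax(logits + bias, dim=-1) @ v
print(torch.allclose(out, manual, atol=1e-5))  # True
```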
flux/condition.py ADDED
@@ -0,0 +1,89 @@
1
+ # Recycled from Ominicontrol and modified to accept an extra condition.
2
+ # While Zenctrl pursued a similar idea, it diverged structurally.
3
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
4
+
5
+ import torch
6
+ from typing import Optional, Union, List, Tuple
7
+ from diffusers.pipelines import FluxPipeline
8
+ from PIL import Image, ImageFilter
9
+ import numpy as np
10
+ import cv2
11
+
12
+ # from pipeline_tools import encode_images
13
+ from .pipeline_tools import encode_images
14
+
15
+ condition_dict = {
16
+ "subject": 1,
17
+ "sr": 2,
18
+ "cot": 3,
19
+ }
20
+
21
+
22
+ class Condition(object):
23
+ def __init__(
24
+ self,
25
+ condition_type: str,
26
+ raw_img: Union[Image.Image, torch.Tensor] = None,
27
+ condition: Union[Image.Image, torch.Tensor] = None,
28
+ position_delta=None,
29
+ ) -> None:
30
+ self.condition_type = condition_type
31
+ assert raw_img is not None or condition is not None
32
+ if raw_img is not None:
33
+ self.condition = self.get_condition(condition_type, raw_img)
34
+ else:
35
+ self.condition = condition
36
+ self.position_delta = position_delta
37
+
38
+
39
+ def get_condition(
40
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
41
+ ) -> Union[Image.Image, torch.Tensor]:
42
+ """
43
+ Returns the condition image.
44
+ """
45
+ if condition_type == "subject":
46
+ return raw_img
47
+ elif condition_type == "sr":
48
+ return raw_img
49
+ elif condition_type == "cot":
50
+ return raw_img.convert("RGB")
51
+ return self.condition
52
+
53
+
54
+ @property
55
+ def type_id(self) -> int:
56
+ """
57
+ Returns the type id of the condition.
58
+ """
59
+ return condition_dict[self.condition_type]
60
+
61
+ def encode(
62
+ self, pipe: FluxPipeline, empty: bool = False
63
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
64
+ """
65
+ Encodes the condition into tokens, ids and type_id.
66
+ """
67
+ if self.condition_type in [
68
+ "subject",
69
+ "sr",
70
+ "cot"
71
+ ]:
72
+ if empty:
73
+ # make the condition black
74
+ e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
75
+ e_condition = e_condition.convert("RGB")
76
+ tokens, ids = encode_images(pipe, e_condition)
77
+ else:
78
+ tokens, ids = encode_images(pipe, self.condition)
79
+ else:
80
+ raise NotImplementedError(
81
+ f"Condition type {self.condition_type} not implemented"
82
+ )
83
+ if self.position_delta is None and self.condition_type == "subject":
84
+ self.position_delta = [0, -self.condition.size[0] // 16]
85
+ if self.position_delta is not None:
86
+ ids[:, 1] += self.position_delta[0]
87
+ ids[:, 2] += self.position_delta[1]
88
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
89
+ return tokens, ids, type_id
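For orientation on position_delta in encode() above: the ids returned by encode_images index 16x16-pixel patches (the VAE's 8x downsample times the 2x2 latent packing), and position_delta is added to those row and column indices. The default [0, -width // 16] therefore shifts the whole condition one image-width to the left of the generation canvas, so its tokens never share a position with the target grid. Below is a minimal sketch of that bookkeeping with a hypothetical stand-in for the id tensor; no pipeline is needed.

```python
# Illustrative only: how position_delta offsets the condition's position ids.
import torch

size = 512                        # condition image is size x size pixels
tokens_per_side = size // 16      # one token per 16x16-pixel patch -> 32

# Stand-in for the (token_n, 3) ids produced by encode_images:
# column 0 is a type slot, columns 1 and 2 are row / column patch indices.
rows, cols = torch.meshgrid(
    torch.arange(tokens_per_side), torch.arange(tokens_per_side), indexing="ij"
)
ids = torch.stack([torch.zeros_like(rows), rows, cols], dim=-1).reshape(-1, 3).float()

position_delta = [0, -size // 16]  # default for "subject" conditions
ids[:, 1] += position_delta[0]
ids[:, 2] += position_delta[1]

# The condition grid now spans columns [-32, -1], i.e. one full image-width
# to the left of the target canvas, whose columns start at 0.
print(int(ids[:, 2].min()), int(ids[:, 2].max()))  # -32 -1
```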
flux/generate.py ADDED
@@ -0,0 +1,337 @@
1
+ # Recycled from Ominicontrol and modified to accept an extra condition.
2
+ # While Zenctrl pursued a similar idea, it diverged structurally.
3
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
4
+
5
+ import torch
6
+ import yaml, os
7
+ from diffusers.pipelines import FluxPipeline
8
+ from typing import List, Union, Optional, Dict, Any, Callable
9
+ from .transformer import tranformer_forward
10
+ from .condition import Condition
11
+
12
+
13
+ from diffusers.pipelines.flux.pipeline_flux import (
14
+ FluxPipelineOutput,
15
+ calculate_shift,
16
+ retrieve_timesteps,
17
+ np,
18
+ )
19
+
20
+
21
+ def get_config(config_path: str = None):
22
+ config_path = config_path or os.environ.get("XFL_CONFIG")
23
+ if not config_path:
24
+ return {}
25
+ with open(config_path, "r") as f:
26
+ config = yaml.safe_load(f)
27
+ return config
28
+
29
+
30
+ def prepare_params(
31
+ prompt: Union[str, List[str]] = None,
32
+ prompt_2: Optional[Union[str, List[str]]] = None,
33
+ height: Optional[int] = 512,
34
+ width: Optional[int] = 512,
35
+ num_inference_steps: int = 28,
36
+ timesteps: List[int] = None,
37
+ guidance_scale: float = 3.5,
38
+ num_images_per_prompt: Optional[int] = 1,
39
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
40
+ latents: Optional[torch.FloatTensor] = None,
41
+ prompt_embeds: Optional[torch.FloatTensor] = None,
42
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
43
+ output_type: Optional[str] = "pil",
44
+ return_dict: bool = True,
45
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
46
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
47
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
48
+ max_sequence_length: int = 512,
49
+ **kwargs: dict,
50
+ ):
51
+ return (
52
+ prompt,
53
+ prompt_2,
54
+ height,
55
+ width,
56
+ num_inference_steps,
57
+ timesteps,
58
+ guidance_scale,
59
+ num_images_per_prompt,
60
+ generator,
61
+ latents,
62
+ prompt_embeds,
63
+ pooled_prompt_embeds,
64
+ output_type,
65
+ return_dict,
66
+ joint_attention_kwargs,
67
+ callback_on_step_end,
68
+ callback_on_step_end_tensor_inputs,
69
+ max_sequence_length,
70
+ )
71
+
72
+
73
+ def seed_everything(seed: int = 42):
74
+ torch.backends.cudnn.deterministic = True
75
+ torch.manual_seed(seed)
76
+ np.random.seed(seed)
77
+
78
+
79
+ @torch.no_grad()
80
+ def generate(
81
+ pipeline: FluxPipeline,
82
+ conditions: List[Condition] = None,
83
+ config_path: str = None,
84
+ model_config: Optional[Dict[str, Any]] = {},
85
+ condition_scale: float = [1, 1],
86
+ default_lora: bool = False,
87
+ image_guidance_scale: float = 1.0,
88
+ **params: dict,
89
+ ):
90
+ model_config = model_config or get_config(config_path).get("model", {})
91
+ if condition_scale != [1,1]:
92
+ for name, module in pipeline.transformer.named_modules():
93
+ if not name.endswith(".attn"):
94
+ continue
95
+ module.c_factor = torch.tensor(condition_scale)
96
+
97
+ self = pipeline
98
+ (
99
+ prompt,
100
+ prompt_2,
101
+ height,
102
+ width,
103
+ num_inference_steps,
104
+ timesteps,
105
+ guidance_scale,
106
+ num_images_per_prompt,
107
+ generator,
108
+ latents,
109
+ prompt_embeds,
110
+ pooled_prompt_embeds,
111
+ output_type,
112
+ return_dict,
113
+ joint_attention_kwargs,
114
+ callback_on_step_end,
115
+ callback_on_step_end_tensor_inputs,
116
+ max_sequence_length,
117
+ ) = prepare_params(**params)
118
+
119
+ height = height or self.default_sample_size * self.vae_scale_factor
120
+ width = width or self.default_sample_size * self.vae_scale_factor
121
+
122
+ # 1. Check inputs. Raise error if not correct
123
+ self.check_inputs(
124
+ prompt,
125
+ prompt_2,
126
+ height,
127
+ width,
128
+ prompt_embeds=prompt_embeds,
129
+ pooled_prompt_embeds=pooled_prompt_embeds,
130
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
131
+ max_sequence_length=max_sequence_length,
132
+ )
133
+
134
+ self._guidance_scale = guidance_scale
135
+ self._joint_attention_kwargs = joint_attention_kwargs
136
+ self._interrupt = False
137
+
138
+ # 2. Define call parameters
139
+ if prompt is not None and isinstance(prompt, str):
140
+ batch_size = 1
141
+ elif prompt is not None and isinstance(prompt, list):
142
+ batch_size = len(prompt)
143
+ else:
144
+ batch_size = prompt_embeds.shape[0]
145
+
146
+ device = self._execution_device
147
+
148
+ lora_scale = (
149
+ self.joint_attention_kwargs.get("scale", None)
150
+ if self.joint_attention_kwargs is not None
151
+ else None
152
+ )
153
+ (
154
+ prompt_embeds,
155
+         pooled_prompt_embeds,
+         text_ids,
+     ) = self.encode_prompt(
+         prompt=prompt,
+         prompt_2=prompt_2,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         device=device,
+         num_images_per_prompt=num_images_per_prompt,
+         max_sequence_length=max_sequence_length,
+         lora_scale=lora_scale,
+     )
+
+     # 4. Prepare latent variables
+     num_channels_latents = self.transformer.config.in_channels // 4
+     latents, latent_image_ids = self.prepare_latents(
+         batch_size * num_images_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         prompt_embeds.dtype,
+         device,
+         generator,
+         latents,
+     )
+
+     # 4.1. Prepare conditions
+     condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
+     extra_condition_latents, extra_condition_ids, extra_condition_type_ids = ([] for _ in range(3))
+     use_condition = conditions is not None and len(conditions) > 0
+     if use_condition:
+         if not default_lora:
+             pipeline.set_adapters(conditions[1].condition_type)
+         # for condition in conditions:
+         tokens, ids, type_id = conditions[0].encode(self)
+         condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
+         condition_ids.append(ids)  # [token_n, id_dim(3)]
+         condition_type_ids.append(type_id)  # [token_n, 1]
+         condition_latents = torch.cat(condition_latents, dim=1)
+         condition_ids = torch.cat(condition_ids, dim=0)
+         condition_type_ids = torch.cat(condition_type_ids, dim=0)
+
+         tokens, ids, type_id = conditions[1].encode(self)
+         extra_condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
+         extra_condition_ids.append(ids)  # [token_n, id_dim(3)]
+         extra_condition_type_ids.append(type_id)  # [token_n, 1]
+         extra_condition_latents = torch.cat(extra_condition_latents, dim=1)
+         extra_condition_ids = torch.cat(extra_condition_ids, dim=0)
+         extra_condition_type_ids = torch.cat(extra_condition_type_ids, dim=0)
+
+     # 5. Prepare timesteps
+     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+     image_seq_len = latents.shape[1]
+     mu = calculate_shift(
+         image_seq_len,
+         self.scheduler.config.base_image_seq_len,
+         self.scheduler.config.max_image_seq_len,
+         self.scheduler.config.base_shift,
+         self.scheduler.config.max_shift,
+     )
+     timesteps, num_inference_steps = retrieve_timesteps(
+         self.scheduler,
+         num_inference_steps,
+         device,
+         timesteps,
+         sigmas,
+         mu=mu,
+     )
+     num_warmup_steps = max(
+         len(timesteps) - num_inference_steps * self.scheduler.order, 0
+     )
+     self._num_timesteps = len(timesteps)
+
+     # 6. Denoising loop
+     with self.progress_bar(total=num_inference_steps) as progress_bar:
+         for i, t in enumerate(timesteps):
+             if self.interrupt:
+                 continue
+
+             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+             timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+             # handle guidance
+             if self.transformer.config.guidance_embeds:
+                 guidance = torch.tensor([guidance_scale], device=device)
+                 guidance = guidance.expand(latents.shape[0])
+             else:
+                 guidance = None
+             noise_pred = tranformer_forward(
+                 self.transformer,
+                 model_config=model_config,
+                 # Inputs of the condition (new feature)
+                 condition_latents=condition_latents if use_condition else None,
+                 condition_ids=condition_ids if use_condition else None,
+                 condition_type_ids=condition_type_ids if use_condition else None,
+                 extra_condition_latents=extra_condition_latents if use_condition else None,
+                 extra_condition_ids=extra_condition_ids if use_condition else None,
+                 extra_condition_type_ids=extra_condition_type_ids if use_condition else None,
+                 # Inputs to the original transformer
+                 hidden_states=latents,
+                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                 timestep=timestep / 1000,
+                 guidance=guidance,
+                 pooled_projections=pooled_prompt_embeds,
+                 encoder_hidden_states=prompt_embeds,
+                 txt_ids=text_ids,
+                 img_ids=latent_image_ids,
+                 joint_attention_kwargs=self.joint_attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+             if image_guidance_scale != 1.0:
+                 uncondition_latents = conditions.encode(self, empty=True)[0]
+                 unc_pred = tranformer_forward(
+                     self.transformer,
+                     model_config=model_config,
+                     # Inputs of the condition (new feature)
+                     condition_latents=uncondition_latents if use_condition else None,
+                     condition_ids=condition_ids if use_condition else None,
+                     condition_type_ids=condition_type_ids if use_condition else None,
+                     # Inputs to the original transformer
+                     hidden_states=latents,
+                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                     timestep=timestep / 1000,
+                     guidance=torch.ones_like(guidance),
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents_dtype = latents.dtype
+             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             if latents.dtype != latents_dtype:
+                 if torch.backends.mps.is_available():
+                     # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                     latents = latents.to(latents_dtype)
+
+             if callback_on_step_end is not None:
+                 callback_kwargs = {}
+                 for k in callback_on_step_end_tensor_inputs:
+                     callback_kwargs[k] = locals()[k]
+                 callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                 latents = callback_outputs.pop("latents", latents)
+                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+             # call the callback, if provided
+             if i == len(timesteps) - 1 or (
+                 (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+             ):
+                 progress_bar.update()
+
+     if output_type == "latent":
+         image = latents
+
+     else:
+         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+         latents = (
+             latents / self.vae.config.scaling_factor
+         ) + self.vae.config.shift_factor
+         image = self.vae.decode(latents, return_dict=False)[0]
+         image = self.image_processor.postprocess(image, output_type=output_type)
+
+     # Offload all models
+     self.maybe_free_model_hooks()
+
+     if condition_scale != [1, 1]:
+         for name, module in pipeline.transformer.named_modules():
+             if not name.endswith(".attn"):
+                 continue
+             del module.c_factor
+
+     if not return_dict:
+         return (image,)
+
+     return FluxPipelineOutput(images=image)
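
Editor's note: when `image_guidance_scale != 1.0`, the loop above blends a conditional and an unconditional prediction in the usual classifier-free-guidance way. A minimal, self-contained sketch of just that blending step (the function name is illustrative, not part of the repository):

import torch

def blend_guidance(cond_pred: torch.Tensor, uncond_pred: torch.Tensor, image_guidance_scale: float) -> torch.Tensor:
    # Push the prediction along the conditional direction, away from the unconditional one.
    return uncond_pred + image_guidance_scale * (cond_pred - uncond_pred)

# With image_guidance_scale == 1.0 this returns cond_pred unchanged, which is why the
# pipeline skips the unconditional forward pass entirely in that case.
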
flux/lora_controller.py ADDED
@@ -0,0 +1,82 @@
+ # As is from OminiControl
+ from peft.tuners.tuners_utils import BaseTunerLayer
+ from typing import List, Any, Optional, Type
+ from .condition import condition_dict
+
+
+ class enable_lora:
+     def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
+         self.activated: bool = activated
+         if activated:
+             return
+         self.lora_modules: List[BaseTunerLayer] = [
+             each for each in lora_modules if isinstance(each, BaseTunerLayer)
+         ]
+         self.scales = [
+             {
+                 active_adapter: lora_module.scaling[active_adapter]
+                 for active_adapter in lora_module.active_adapters
+             }
+             for lora_module in self.lora_modules
+         ]
+
+     def __enter__(self) -> None:
+         if self.activated:
+             return
+
+         for lora_module in self.lora_modules:
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 if (
+                     active_adapter in condition_dict.keys()
+                     or active_adapter == "default"
+                 ):
+                     lora_module.scaling[active_adapter] = 0.0
+
+     def __exit__(
+         self,
+         exc_type: Optional[Type[BaseException]],
+         exc_val: Optional[BaseException],
+         exc_tb: Optional[Any],
+     ) -> None:
+         if self.activated:
+             return
+         for i, lora_module in enumerate(self.lora_modules):
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
+
+
+ class set_lora_scale:
+     def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
+         self.lora_modules: List[BaseTunerLayer] = [
+             each for each in lora_modules if isinstance(each, BaseTunerLayer)
+         ]
+         self.scales = [
+             {
+                 active_adapter: lora_module.scaling[active_adapter]
+                 for active_adapter in lora_module.active_adapters
+             }
+             for lora_module in self.lora_modules
+         ]
+         self.scale = scale
+
+     def __enter__(self) -> None:
+         for lora_module in self.lora_modules:
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             lora_module.scale_layer(self.scale)
+
+     def __exit__(
+         self,
+         exc_type: Optional[Type[BaseException]],
+         exc_val: Optional[BaseException],
+         exc_tb: Optional[Any],
+     ) -> None:
+         for i, lora_module in enumerate(self.lora_modules):
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
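
Editor's note: a brief usage sketch of the two context managers above. `enable_lora(..., activated=False)` temporarily zeroes the scaling of the condition adapters (anything named in `condition_dict`, plus "default") and restores it on exit, while `set_lora_scale` rescales every collected LoRA layer and likewise restores the saved scales. The `transformer` object here is a placeholder for a PEFT-wrapped Flux transformer, not something defined in this file:

from peft.tuners.tuners_utils import BaseTunerLayer
from flux.lora_controller import enable_lora, set_lora_scale

lora_layers = [m for m in transformer.modules() if isinstance(m, BaseTunerLayer)]  # hypothetical model

with enable_lora(lora_layers, activated=False):
    ...  # condition adapters run with scaling 0.0 inside this block

with set_lora_scale(lora_layers, scale=0.8):
    ...  # every collected LoRA layer is scaled by 0.8 here, then restored
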
flux/pipeline_tools.py ADDED
@@ -0,0 +1,53 @@
+ # As is from OminiControl
+ from diffusers.pipelines import FluxPipeline
+ from diffusers.utils import logging
+ from diffusers.pipelines.flux.pipeline_flux import logger
+ from torch import Tensor
+
+
+ def encode_images(pipeline: FluxPipeline, images: Tensor):
+     images = pipeline.image_processor.preprocess(images)
+     images = images.to(pipeline.device).to(pipeline.dtype)
+     images = pipeline.vae.encode(images).latent_dist.sample()
+     images = (
+         images - pipeline.vae.config.shift_factor
+     ) * pipeline.vae.config.scaling_factor
+     images_tokens = pipeline._pack_latents(images, *images.shape)
+     images_ids = pipeline._prepare_latent_image_ids(
+         images.shape[0],
+         images.shape[2],
+         images.shape[3],
+         pipeline.device,
+         pipeline.dtype,
+     )
+     if images_tokens.shape[1] != images_ids.shape[0]:
+         images_ids = pipeline._prepare_latent_image_ids(
+             images.shape[0],
+             images.shape[2] // 2,
+             images.shape[3] // 2,
+             pipeline.device,
+             pipeline.dtype,
+         )
+     return images_tokens, images_ids
+
+
+ def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
+     # Turn off warnings (CLIP overflow)
+     logger.setLevel(logging.ERROR)
+     (
+         prompt_embeds,
+         pooled_prompt_embeds,
+         text_ids,
+     ) = pipeline.encode_prompt(
+         prompt=prompts,
+         prompt_2=None,
+         prompt_embeds=None,
+         pooled_prompt_embeds=None,
+         device=pipeline.device,
+         num_images_per_prompt=1,
+         max_sequence_length=max_sequence_length,
+         lora_scale=None,
+     )
+     # Turn on warnings
+     logger.setLevel(logging.WARNING)
+     return prompt_embeds, pooled_prompt_embeds, text_ids
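
Editor's note: these helpers mirror the pipeline's own latent packing and prompt encoding, so a typical call looks like the minimal sketch below. The checkpoint id and the random tensor are placeholders for illustration only, not something this repository pins down here:

import torch
from diffusers.pipelines import FluxPipeline
from flux.pipeline_tools import encode_images, prepare_text_input

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)  # assumed checkpoint
condition = torch.rand(1, 3, 512, 512)  # stand-in for a condition image batch in [0, 1]

cond_tokens, cond_ids = encode_images(pipe, condition)  # packed VAE latents + positional ids
prompt_embeds, pooled_embeds, text_ids = prepare_text_input(pipe, ["a studio product photo"])
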
flux/transformer.py ADDED
@@ -0,0 +1,286 @@
+ # Recycled from OminiControl and modified to accept an extra condition.
+ # While Zenctrl pursued a similar idea, it diverged structurally.
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
+
+ import torch
+ from diffusers.pipelines import FluxPipeline
+ from typing import List, Union, Optional, Dict, Any, Callable
+ from .block import block_forward, single_block_forward
+ from .lora_controller import enable_lora
+ from accelerate.utils import is_torch_version
+ from diffusers.models.transformers.transformer_flux import (
+     FluxTransformer2DModel,
+     Transformer2DModelOutput,
+     USE_PEFT_BACKEND,
+     scale_lora_layers,
+     unscale_lora_layers,
+     logger,
+ )
+ import numpy as np
+
+
+ def prepare_params(
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor = None,
+     pooled_projections: torch.Tensor = None,
+     timestep: torch.LongTensor = None,
+     img_ids: torch.Tensor = None,
+     txt_ids: torch.Tensor = None,
+     guidance: torch.Tensor = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     controlnet_block_samples=None,
+     controlnet_single_block_samples=None,
+     return_dict: bool = True,
+     **kwargs: dict,
+ ):
+     return (
+         hidden_states,
+         encoder_hidden_states,
+         pooled_projections,
+         timestep,
+         img_ids,
+         txt_ids,
+         guidance,
+         joint_attention_kwargs,
+         controlnet_block_samples,
+         controlnet_single_block_samples,
+         return_dict,
+     )
+
+
+ def tranformer_forward(
+     transformer: FluxTransformer2DModel,
+     condition_latents: torch.Tensor,
+     extra_condition_latents: torch.Tensor,
+     condition_ids: torch.Tensor,
+     condition_type_ids: torch.Tensor,
+     extra_condition_ids: torch.Tensor,
+     extra_condition_type_ids: torch.Tensor,
+     model_config: Optional[Dict[str, Any]] = {},
+     c_t=0,
+     **params: dict,
+ ):
+     self = transformer
+     use_condition = condition_latents is not None
+     use_extra_condition = extra_condition_latents is not None
+
+     (
+         hidden_states,
+         encoder_hidden_states,
+         pooled_projections,
+         timestep,
+         img_ids,
+         txt_ids,
+         guidance,
+         joint_attention_kwargs,
+         controlnet_block_samples,
+         controlnet_single_block_samples,
+         return_dict,
+     ) = prepare_params(**params)
+
+     if joint_attention_kwargs is not None:
+         joint_attention_kwargs = joint_attention_kwargs.copy()
+         lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+     else:
+         lora_scale = 1.0
+
+     if USE_PEFT_BACKEND:
+         # weight the lora layers by setting `lora_scale` for each PEFT layer
+         scale_lora_layers(self, lora_scale)
+     else:
+         if (
+             joint_attention_kwargs is not None
+             and joint_attention_kwargs.get("scale", None) is not None
+         ):
+             logger.warning(
+                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+
+     with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
+         hidden_states = self.x_embedder(hidden_states)
+         condition_latents = self.x_embedder(condition_latents) if use_condition else None
+         extra_condition_latents = self.x_embedder(extra_condition_latents) if use_extra_condition else None
+
+     timestep = timestep.to(hidden_states.dtype) * 1000
+
+     if guidance is not None:
+         guidance = guidance.to(hidden_states.dtype) * 1000
+     else:
+         guidance = None
+
+     temb = (
+         self.time_text_embed(timestep, pooled_projections)
+         if guidance is None
+         else self.time_text_embed(timestep, guidance, pooled_projections)
+     )
+
+     cond_temb = (
+         self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
+         if guidance is None
+         else self.time_text_embed(
+             torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
+         )
+     )
+     extra_cond_temb = (
+         self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
+         if guidance is None
+         else self.time_text_embed(
+             torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
+         )
+     )
+     encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+     if txt_ids.ndim == 3:
+         logger.warning(
+             "Passing `txt_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         txt_ids = txt_ids[0]
+     if img_ids.ndim == 3:
+         logger.warning(
+             "Passing `img_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         img_ids = img_ids[0]
+
+     ids = torch.cat((txt_ids, img_ids), dim=0)
+     image_rotary_emb = self.pos_embed(ids)
+     if use_condition:
+         # condition_ids[:, :1] = condition_type_ids
+         cond_rotary_emb = self.pos_embed(condition_ids)
+
+     if use_extra_condition:
+         extra_cond_rotary_emb = self.pos_embed(extra_condition_ids)
+
+
+     # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
+
+     # print("here!")
+     for index_block, block in enumerate(self.transformer_blocks):
+         if self.training and self.gradient_checkpointing:
+             ckpt_kwargs: Dict[str, Any] = (
+                 {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+             )
+             encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
+                 torch.utils.checkpoint.checkpoint(
+                     block_forward,
+                     self=block,
+                     model_config=model_config,
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     condition_latents=condition_latents if use_condition else None,
+                     extra_condition_latents=extra_condition_latents if use_extra_condition else None,
+                     temb=temb,
+                     cond_temb=cond_temb if use_condition else None,
+                     cond_rotary_emb=cond_rotary_emb if use_condition else None,
+                     extra_cond_temb=extra_cond_temb if use_extra_condition else None,
+                     extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
+                     image_rotary_emb=image_rotary_emb,
+                     **ckpt_kwargs,
+                 )
+             )
+
+         else:
+             encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = block_forward(
+                 block,
+                 model_config=model_config,
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 condition_latents=condition_latents if use_condition else None,
+                 extra_condition_latents=extra_condition_latents if use_extra_condition else None,
+                 temb=temb,
+                 cond_temb=cond_temb if use_condition else None,
+                 cond_rotary_emb=cond_rotary_emb if use_condition else None,
+                 extra_cond_temb=extra_cond_temb if use_extra_condition else None,
+                 extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
+                 image_rotary_emb=image_rotary_emb,
+             )
+
+         # controlnet residual
+         if controlnet_block_samples is not None:
+             interval_control = len(self.transformer_blocks) / len(
+                 controlnet_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             hidden_states = (
+                 hidden_states
+                 + controlnet_block_samples[index_block // interval_control]
+             )
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+
+     for index_block, block in enumerate(self.single_transformer_blocks):
+         if self.training and self.gradient_checkpointing:
+             ckpt_kwargs: Dict[str, Any] = (
+                 {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+             )
+             result = torch.utils.checkpoint.checkpoint(
+                 single_block_forward,
+                 self=block,
+                 model_config=model_config,
+                 hidden_states=hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 **(
+                     {
+                         "condition_latents": condition_latents,
+                         "extra_condition_latents": extra_condition_latents,
+                         "cond_temb": cond_temb,
+                         "cond_rotary_emb": cond_rotary_emb,
+                         "extra_cond_temb": extra_cond_temb,
+                         "extra_cond_rotary_emb": extra_cond_rotary_emb,
+                     }
+                     if use_condition
+                     else {}
+                 ),
+                 **ckpt_kwargs,
+             )
+
+         else:
+             result = single_block_forward(
+                 block,
+                 model_config=model_config,
+                 hidden_states=hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 **(
+                     {
+                         "condition_latents": condition_latents,
+                         "extra_condition_latents": extra_condition_latents,
+                         "cond_temb": cond_temb,
+                         "cond_rotary_emb": cond_rotary_emb,
+                         "extra_cond_temb": extra_cond_temb,
+                         "extra_cond_rotary_emb": extra_cond_rotary_emb,
+                     }
+                     if use_condition
+                     else {}
+                 ),
+             )
+         if use_condition:
+             hidden_states, condition_latents, extra_condition_latents = result
+         else:
+             hidden_states = result
+
+         # controlnet residual
+         if controlnet_single_block_samples is not None:
+             interval_control = len(self.single_transformer_blocks) / len(
+                 controlnet_single_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                 hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                 + controlnet_single_block_samples[index_block // interval_control]
+             )
+
+     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+     hidden_states = self.norm_out(hidden_states, temb)
+     output = self.proj_out(hidden_states)
+
+     if USE_PEFT_BACKEND:
+         # remove `lora_scale` from each PEFT layer
+         unscale_lora_layers(self, lora_scale)
+
+     if not return_dict:
+         return (output,)
+     return Transformer2DModelOutput(sample=output)
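
Editor's note: `tranformer_forward` keeps the stock FluxTransformer2DModel inputs (its spelling matches the call site in the denoising loop, so it is left as-is) and routes them through `prepare_params`, which simply reorders the keyword arguments and drops anything unexpected via `**kwargs`. A tiny self-contained check of that plumbing, for illustration only:

import torch
from flux.transformer import prepare_params

hidden = torch.zeros(1, 16, 64)
out = prepare_params(hidden_states=hidden, timestep=torch.tensor([1.0]), unused_flag=True)
# The returned tuple follows the argument order in the signature; extras such as
# unused_flag are silently discarded.
assert out[0] is hidden   # hidden_states passes through untouched
assert out[-1] is True    # return_dict keeps its default of True
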
imgs/bg_i1.png ADDED

Git LFS Details

  • SHA256: 8233ff6e5eaaf97f6157599708ed67e62380f3e09820d67bb2ebd472d84165a7
  • Pointer size: 130 Bytes
  • Size of remote file: 97.7 kB
imgs/bg_i2.png ADDED

Git LFS Details

  • SHA256: 41ff9fabfb2e31cce35cc97f2c5962165c3b68bce60732e6669692307ec5ebed
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
imgs/bg_i3.png ADDED

Git LFS Details

  • SHA256: 89815263bfd1b72e5be817ddcefc770a728b46bdb7d8264258cae8b5a4270493
  • Pointer size: 131 Bytes
  • Size of remote file: 279 kB
imgs/bg_i4.png ADDED

Git LFS Details

  • SHA256: ebedef7bdaf6b6e184cad63665f40d2f86dd5a6c54334b400d1ce479c5ec339e
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
imgs/bg_i5.png ADDED

Git LFS Details

  • SHA256: f3065ff3010bdf54d9eed2784a9f998597fb3f10646f0a6ba71e7d2abd9041c2
  • Pointer size: 131 Bytes
  • Size of remote file: 947 kB
imgs/bg_o1.png ADDED

Git LFS Details

  • SHA256: 8e80909e8b091f02d8d00f8614194f5e6606b076bf86d4fd88e58ffcfeaaa4ed
  • Pointer size: 131 Bytes
  • Size of remote file: 621 kB
imgs/bg_o2.png ADDED

Git LFS Details

  • SHA256: e32d9497741ea352725220043e54d1772b1fe4e81910cfaf668cc95ee6e9a23f
  • Pointer size: 131 Bytes
  • Size of remote file: 945 kB
imgs/bg_o3.jpg ADDED

Git LFS Details

  • SHA256: ffef49f49bbed31e53e10552a3e6bedc8c492d0b69b3e5284e5f612705f84983
  • Pointer size: 130 Bytes
  • Size of remote file: 58.1 kB
imgs/bg_o4.jpg ADDED

Git LFS Details

  • SHA256: f26c092a93c15a781ae013a4980f2a4b0489277687263e7f6619f5ea454fa645
  • Pointer size: 130 Bytes
  • Size of remote file: 69.7 kB
imgs/bg_o5.jpg ADDED

Git LFS Details

  • SHA256: 3ebe1a2421d9a36a108d9aca97064d96f1c5016bc6ea378dee637fc843a0357c
  • Pointer size: 130 Bytes
  • Size of remote file: 37 kB
imgs/sub_i1.png ADDED

Git LFS Details

  • SHA256: 41ff9fabfb2e31cce35cc97f2c5962165c3b68bce60732e6669692307ec5ebed
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
imgs/sub_i2.png ADDED

Git LFS Details

  • SHA256: ebedef7bdaf6b6e184cad63665f40d2f86dd5a6c54334b400d1ce479c5ec339e
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
imgs/sub_i3.png ADDED

Git LFS Details

  • SHA256: f3065ff3010bdf54d9eed2784a9f998597fb3f10646f0a6ba71e7d2abd9041c2
  • Pointer size: 131 Bytes
  • Size of remote file: 947 kB
imgs/sub_i4.png ADDED

Git LFS Details

  • SHA256: 24deed93849e951f8c7c44de47226240db7c5a42a289f3fdc7c0fa5621f65609
  • Pointer size: 131 Bytes
  • Size of remote file: 281 kB
imgs/sub_i5.png ADDED

Git LFS Details

  • SHA256: 50b05fff1d2d404da6d6045c36971cd810c2e0d425168cdafa1511f95c5ba269
  • Pointer size: 131 Bytes
  • Size of remote file: 250 kB
imgs/sub_o1.webp ADDED

Git LFS Details

  • SHA256: c1c9eef884328c26cc58de4f35680b2ca5f211071c2b9bd5254223b0dec0757e
  • Pointer size: 130 Bytes
  • Size of remote file: 32.1 kB
imgs/sub_o2.webp ADDED

Git LFS Details

  • SHA256: 026e18cab72811630eb367bf4ae2e38933fb9c31c0146b66990c5f8879b8b71b
  • Pointer size: 130 Bytes
  • Size of remote file: 27.3 kB
imgs/sub_o3.webp ADDED

Git LFS Details

  • SHA256: f0b2abf5d005c747c092c60837b3a80b8d96e5c52625b2d288475799603c8147
  • Pointer size: 130 Bytes
  • Size of remote file: 28.3 kB
imgs/sub_o4.webp ADDED

Git LFS Details

  • SHA256: f77c4b61c9ad96cbf312af18c631f9d558bdcf1f51beb1317bc483bfe700cf18
  • Pointer size: 130 Bytes
  • Size of remote file: 16.9 kB
imgs/sub_o5.webp ADDED

Git LFS Details

  • SHA256: 14ef4c02f4a2970a26a1bd92ce4c937b8aad477a91463fef5a8a34dd61afe3e8
  • Pointer size: 130 Bytes
  • Size of remote file: 17.4 kB
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ torch
+ torchvision
+ torchaudio
+
+ diffusers==0.31.0
+ transformers
+ accelerate
+ huggingface_hub
+ sentencepiece
+
+ numpy
+ pillow>=9.0.0
+ einops>=0.7.0
+ safetensors>=0.4.0
+ opencv-python-headless
+ peft==0.15.2
+ spaces
samples/1.png ADDED

Git LFS Details

  • SHA256: ee99aba59d42e4c8221785aa032fb5dc3b623478f27a4a9255e42e19bc3308d7
  • Pointer size: 131 Bytes
  • Size of remote file: 809 kB
samples/10.png ADDED

Git LFS Details

  • SHA256: e6108e472742ee174e2e37d3a8acbfb14c51bdf0c252139086c0813026fb6b73
  • Pointer size: 131 Bytes
  • Size of remote file: 405 kB
samples/11.png ADDED

Git LFS Details

  • SHA256: a8f107773e71b227d9f95692b7af837d41c4b4a39bda48940755db399d619d21
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
samples/11_out.webp ADDED

Git LFS Details

  • SHA256: 6ec7669bc2d1423c69edd87e6a6229d15b76043100ae91d7ff2b07b6197646aa
  • Pointer size: 130 Bytes
  • Size of remote file: 40.1 kB
samples/12.png ADDED

Git LFS Details

  • SHA256: ede3b034dcb0a378073cf7d18841b8bf58c20c7f2161239718520dd3c36fe338
  • Pointer size: 131 Bytes
  • Size of remote file: 494 kB
samples/12_out.webp ADDED

Git LFS Details

  • SHA256: 6c67d211f75e77242af15d259c20a3d527b844b9f5e9c2a199ff2797e0bc459b
  • Pointer size: 130 Bytes
  • Size of remote file: 30.2 kB
samples/13.png ADDED

Git LFS Details

  • SHA256: a6740a2d29fde73c310dc9d95b82e02163ad2fe014ec3aad8bfbfe4aed964948
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
samples/14.png ADDED

Git LFS Details

  • SHA256: 1180b403e72baed790b92903729f5f07a09f3fdb1ab750330a26b3dca5dfdf22
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
samples/15.png ADDED

Git LFS Details

  • SHA256: da6cceb7325b7d4819cf5218768e81fcc486140b5d478c43cc086a7aa3ffed71
  • Pointer size: 131 Bytes
  • Size of remote file: 959 kB
samples/16.png ADDED

Git LFS Details

  • SHA256: 7caca84161a104189ff5a7571dbec8342de97238f99280feeda92b2b768a51fb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.22 MB
samples/17.png ADDED

Git LFS Details

  • SHA256: 15290c3bdcb2bcc575a293a91a1c1e0debdbfcfe2494a8f36458b59403f5fcea
  • Pointer size: 131 Bytes
  • Size of remote file: 846 kB
samples/18.png ADDED

Git LFS Details

  • SHA256: 4b5771918d8ba7605198be2d670d2cf713b642c29da4a04f069d98cd5e45baa7
  • Pointer size: 131 Bytes
  • Size of remote file: 677 kB
samples/19.png ADDED

Git LFS Details

  • SHA256: df6a779b5df871fdce101c5a2b13c77c25160b658c4c729f0d4a3ab34ae11996
  • Pointer size: 131 Bytes
  • Size of remote file: 953 kB
samples/1_out.webp ADDED

Git LFS Details

  • SHA256: 4ff52fb3f22b7413cfaf090bd65f93be38491d4b250c668e67a33b4cc6c88a8a
  • Pointer size: 130 Bytes
  • Size of remote file: 40.8 kB
samples/2.png ADDED

Git LFS Details

  • SHA256: 1a12bcdfac8d901c8d0947213391541dedff0e49964174b9f436726db21a3ad6
  • Pointer size: 131 Bytes
  • Size of remote file: 900 kB
samples/20.png ADDED

Git LFS Details

  • SHA256: afa037f9006cdbdf0707db41b88ae6e8109c196f5ceb5f37523a8364b467d534
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
samples/20_out.webp ADDED

Git LFS Details

  • SHA256: 45e65cbbadd94d062f4d56a0204dc37d7e64a8e3dcc05b78dbc54cc6d65b8cb1
  • Pointer size: 130 Bytes
  • Size of remote file: 36.8 kB
samples/21.png ADDED

Git LFS Details

  • SHA256: ecd7d68258b5d342ff9a316c9dfddefda11bc9558e8b582a4bb01977f3da87e5
  • Pointer size: 131 Bytes
  • Size of remote file: 303 kB