Charbel Malo committed on
Commit
6bfe0a3
·
verified ·
1 Parent(s): 9782b0c

Update app.py

Files changed (1)
  1. app.py +423 -236
app.py CHANGED
@@ -15,7 +15,7 @@ import time
15
  import requests
16
  import pandas as pd
17
 
18
- # Load prompts for randomization
19
  df = pd.read_csv('prompts.csv', header=None)
20
  prompt_values = df.values.flatten()
21
 
@@ -81,9 +81,339 @@ def download_file(url, directory=None):
81
  file.write(response.content)
82
 
83
  return filepath
84
 
85
  def check_custom_model(link):
86
- # Your existing implementation of check_custom_model
87
  if link.endswith(".safetensors"):
88
  # Treat as direct link to the LoRA weights
89
  title = os.path.basename(link)
@@ -102,7 +432,6 @@ def check_custom_model(link):
102
  # Assume it's a Hugging Face model path
103
  return get_huggingface_safetensors(link)
104
 
105
-
106
  def update_history(new_image, history):
107
  """Updates the history gallery with the new image."""
108
  if history is None:
@@ -133,7 +462,7 @@ css = '''
133
  #component-11{align-self: stretch;}
134
  '''
135
 
136
- with gr.Blocks(css=css, theme=gr.themes.Default(), delete_cache=(60, 60)) as app:
137
  title = gr.HTML(
138
  """<h1><img src="https://i.imgur.com/wMh2Oek.png" alt="LoRA"> LoRA Lab [beta]</h1><br><span style="
139
  margin-top: -25px !important;
@@ -144,250 +473,108 @@ with gr.Blocks(css=css, theme=gr.themes.Default(), delete_cache=(60, 60)) as app
144
  )
145
  loras_state = gr.State(loras)
146
  selected_indices = gr.State([])
147
-
148
- # Define UI components
149
  with gr.Row():
150
  with gr.Column(scale=3):
151
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
152
  with gr.Column(scale=1):
153
  generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
154
-
155
  with gr.Row(elem_id="loaded_loras"):
156
- randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
157
- # We'll dynamically render the LoRA selections below using @gr.render
158
-
159
- with gr.Group():
160
- with gr.Row(elem_id="custom_lora_structure"):
161
- custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="multimodalart/vintage-ads-flux", scale=3, min_width=150)
162
- add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
163
- remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
164
- gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
165
- gallery = gr.Gallery(
166
- [(item["image"], item["title"]) for item in loras],
167
- label="Or pick from the LoRA Explorer gallery",
168
- allow_preview=False,
169
- columns=5,
170
- elem_id="gallery",
171
- show_share_button=False,
172
- interactive=True # Set to True to allow selection
173
- )
174
- progress_bar = gr.Markdown(elem_id="progress", visible=False)
175
- result = gr.Image(label="Generated Image", interactive=False, show_share_button=False)
176
- with gr.Accordion("History", open=False):
177
- history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
178
-
179
- with gr.Accordion("Advanced Settings", open=False):
180
- with gr.Row():
181
- input_image = gr.Image(label="Input image", type="filepath", show_share_button=False)
182
- image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
183
- with gr.Column():
184
  with gr.Row():
185
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
186
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
187
-
188
  with gr.Row():
189
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
190
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
191
 
192
  with gr.Row():
193
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
194
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
195
-
196
- # Define states for LoRA selections
197
- selected_loras_state = gr.State([]) # List of selected LoRA indices
198
- lora_scales_state = gr.State([]) # List of corresponding scales
199
-
200
- # Function to handle gallery selection
201
- @gr.render(inputs=[gallery], outputs=[], triggers=[gallery.select])
202
- def update_lora_selection(selected_gallery):
203
- selected_index = selected_gallery.index
204
- selected_indices = selected_loras_state.value.copy()
205
- if selected_index in selected_indices:
206
- selected_indices.remove(selected_index)
207
- else:
208
- if len(selected_indices) < 6:
209
- selected_indices.append(selected_index)
210
- else:
211
- gr.Warning("You can select up to 6 LoRAs. Remove one to select a new one.")
212
- selected_loras_state.value = selected_indices
213
-
214
- # Function to render the LoRA selection components dynamically
215
- @gr.render(inputs=[selected_loras_state], outputs=[], triggers=[selected_loras_state.change])
216
- def render_lora_selections(selected_indices):
217
- lora_scales = lora_scales_state.value
218
- if len(lora_scales) != len(selected_indices):
219
- lora_scales = [1.15] * len(selected_indices)
220
- lora_scales_state.value = lora_scales
221
-
222
- # Clear previous components
223
- gr.Markdown("### Selected LoRAs")
224
- if not selected_indices:
225
- gr.Markdown("No LoRAs selected.")
226
- else:
227
- for idx, sel_idx in enumerate(selected_indices):
228
- lora = loras_state.value[sel_idx]
229
  with gr.Row():
230
- with gr.Column(scale=1, min_width=50):
231
- gr.Image(value=lora['image'], interactive=False, show_label=False, height=50)
232
- with gr.Column(scale=3):
233
- gr.Markdown(f"### LoRA {idx+1}: [{lora['title']}](https://huggingface.co/{lora['repo']}) ✨")
234
- with gr.Column(scale=2):
235
- scale_slider = gr.Slider(label=f"Scale {idx+1}", minimum=0, maximum=3, step=0.01, value=lora_scales[idx])
236
- scale_slider.change(lambda val, idx=idx: update_lora_scale(idx, val), inputs=[scale_slider], outputs=[])
237
- with gr.Column(scale=1, min_width=50):
238
- remove_btn = gr.Button("Remove", size="sm")
239
- remove_btn.click(lambda idx=idx: remove_lora(idx), inputs=[], outputs=[])
240
-
241
- # Helper function to update LoRA scales
242
- def update_lora_scale(idx, value):
243
- lora_scales = lora_scales_state.value
244
- lora_scales[idx] = value
245
- lora_scales_state.value = lora_scales
246
-
247
- # Helper function to remove a LoRA
248
- def remove_lora(idx):
249
- selected_indices = selected_loras_state.value
250
- lora_scales = lora_scales_state.value
251
- if idx < len(selected_indices):
252
- selected_indices.pop(idx)
253
- lora_scales.pop(idx)
254
- selected_loras_state.value = selected_indices
255
- lora_scales_state.value = lora_scales
256
-
257
- # Randomize LoRAs
258
- def randomize_loras():
259
- num_loras = min(6, len(loras_state.value))
260
- selected_indices = random.sample(range(len(loras_state.value)), num_loras)
261
- lora_scales = [1.15] * num_loras
262
- selected_loras_state.value = selected_indices
263
- lora_scales_state.value = lora_scales
264
- random_prompt = random.choice(prompt_values)
265
- prompt.value = random_prompt
266
-
267
- randomize_button.click(randomize_loras, inputs=[], outputs=[])
268
-
269
- # Add custom LoRA
270
- def add_custom_lora_fn(custom_lora_input):
271
- if custom_lora_input:
272
- try:
273
- title, repo, path, trigger_word, image = check_custom_model(custom_lora_input)
274
- existing_item_index = next((index for (index, item) in enumerate(loras_state.value) if item['repo'] == repo), None)
275
- if existing_item_index is None:
276
- if repo.endswith(".safetensors") and repo.startswith("http"):
277
- repo = download_file(repo)
278
- new_item = {
279
- "image": image if image else "/home/user/app/custom.png",
280
- "title": title,
281
- "repo": repo,
282
- "weights": path,
283
- "trigger_word": trigger_word
284
- }
285
- existing_item_index = len(loras_state.value)
286
- loras_state.value.append(new_item)
287
- # Update gallery
288
- gallery.value = [(item["image"], item["title"]) for item in loras_state.value]
289
- if len(selected_loras_state.value) < 6:
290
- selected_loras_state.value.append(existing_item_index)
291
- lora_scales_state.value.append(1.15)
292
- else:
293
- gr.Warning("You can select up to 6 LoRAs. Remove one to select a new one.")
294
- except Exception as e:
295
- gr.Warning(str(e))
296
-
297
- add_custom_lora_button.click(add_custom_lora_fn, inputs=[custom_lora], outputs=[])
298
-
299
- # Run the LoRA generation
300
- @spaces.GPU(duration=75)
301
- def run_lora(prompt_text, image_input_path, image_strength_value, cfg_scale_value, steps_value, randomize_seed_value, seed_value, width_value, height_value):
302
- selected_indices = selected_loras_state.value
303
- lora_scales = lora_scales_state.value
304
-
305
- if not selected_indices:
306
- raise gr.Error("You must select at least one LoRA before proceeding.")
307
-
308
- selected_loras = [loras_state.value[idx] for idx in selected_indices]
309
-
310
- # Build the prompt with trigger words
311
- prepends = []
312
- appends = []
313
- for lora in selected_loras:
314
- trigger_word = lora.get('trigger_word', '')
315
- if trigger_word:
316
- if lora.get("trigger_position") == "prepend":
317
- prepends.append(trigger_word)
318
- else:
319
- appends.append(trigger_word)
320
- prompt_mash = " ".join(prepends + [prompt_text] + appends)
321
- print("Prompt Mash: ", prompt_mash)
322
- # Unload previous LoRA weights
323
- with calculateDuration("Unloading LoRA"):
324
- pipe.unload_lora_weights()
325
- pipe_i2i.unload_lora_weights()
326
-
327
- print(pipe.get_active_adapters())
328
- # Load LoRA weights with respective scales
329
- lora_names = []
330
- lora_weights = []
331
- with calculateDuration("Loading LoRA weights"):
332
- for idx, lora in enumerate(selected_loras):
333
- lora_name = f"lora_{idx}"
334
- lora_names.append(lora_name)
335
- print(f"Lora Name: {lora_name}")
336
- lora_weights.append(lora_scales[idx])
337
- lora_path = lora['repo']
338
- weight_name = lora.get("weights")
339
- print(f"Lora Path: {lora_path}")
340
- pipe_to_use = pipe_i2i if image_input_path is not None else pipe
341
- pipe_to_use.load_lora_weights(
342
- lora_path,
343
- weight_name=weight_name if weight_name else None,
344
- low_cpu_mem_usage=True,
345
- adapter_name=lora_name
346
- )
347
- print("Loaded LoRAs:", lora_names)
348
- print("Adapter weights:", lora_weights)
349
- if image_input_path is not None:
350
- pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
351
- else:
352
- pipe.set_adapters(lora_names, adapter_weights=lora_weights)
353
- print(pipe.get_active_adapters())
354
- # Set random seed for reproducibility
355
- with calculateDuration("Randomizing seed"):
356
- if randomize_seed_value:
357
- seed_value = random.randint(0, MAX_SEED)
358
 
359
- # Generate image
360
- if image_input_path is not None:
361
- final_image = generate_image_to_image(prompt_mash, image_input_path, image_strength_value, steps_value, cfg_scale_value, width_value, height_value, seed_value)
362
- yield final_image, seed_value, gr.update(visible=False)
363
- else:
364
- image_generator = generate_image(prompt_mash, steps_value, seed_value, cfg_scale_value, width_value, height_value, progress=gr.Progress(track_tqdm=True))
365
- # Consume the generator to get the final image
366
- final_image = None
367
- step_counter = 0
368
- for image in image_generator:
369
- step_counter += 1
370
- final_image = image
371
- progress_bar_html = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps_value};"></div></div>'
372
- yield image, seed_value, gr.update(value=progress_bar_html, visible=True)
373
- yield final_image, seed_value, gr.update(value=progress_bar_html, visible=False)
374
-
375
- run_lora.zerogpu = True
376
-
377
- # Bind the generate button to run_lora function
378
- generate_button.click(
379
- run_lora,
380
- inputs=[
381
- prompt,
382
- input_image,
383
- image_strength,
384
- cfg_scale,
385
- steps,
386
- randomize_seed,
387
- seed,
388
- width,
389
- height
390
- ],
391
  outputs=[result, seed, progress_bar]
392
  ).then(
393
  fn=lambda x, history: update_history(x, history),
 
15
  import requests
16
  import pandas as pd
17
 
18
+ #Load prompts for randomization
19
  df = pd.read_csv('prompts.csv', header=None)
20
  prompt_values = df.values.flatten()
21
 
 
81
  file.write(response.content)
82
 
83
  return filepath
84
+
85
+ def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
86
+ selected_index = evt.index
87
+ selected_indices = selected_indices or []
88
+ if selected_index in selected_indices:
89
+ selected_indices.remove(selected_index)
90
+ else:
91
+ if len(selected_indices) < 2:
92
+ selected_indices.append(selected_index)
93
+ else:
94
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
95
+ return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update()
96
+
97
+ selected_info_1 = "Select a LoRA 1"
98
+ selected_info_2 = "Select a LoRA 2"
99
+ lora_scale_1 = 1.15
100
+ lora_scale_2 = 1.15
101
+ lora_image_1 = None
102
+ lora_image_2 = None
103
+ if len(selected_indices) >= 1:
104
+ lora1 = loras_state[selected_indices[0]]
105
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
106
+ lora_image_1 = lora1['image']
107
+ if len(selected_indices) >= 2:
108
+ lora2 = loras_state[selected_indices[1]]
109
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
110
+ lora_image_2 = lora2['image']
111
+
112
+ if selected_indices:
113
+ last_selected_lora = loras_state[selected_indices[-1]]
114
+ new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
115
+ else:
116
+ new_placeholder = "Type a prompt after selecting a LoRA"
117
+
118
+ return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2
119
+
120
+ def remove_lora_1(selected_indices, loras_state):
121
+ if len(selected_indices) >= 1:
122
+ selected_indices.pop(0)
123
+ selected_info_1 = "Select a LoRA 1"
124
+ selected_info_2 = "Select a LoRA 2"
125
+ lora_scale_1 = 1.15
126
+ lora_scale_2 = 1.15
127
+ lora_image_1 = None
128
+ lora_image_2 = None
129
+ if len(selected_indices) >= 1:
130
+ lora1 = loras_state[selected_indices[0]]
131
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
132
+ lora_image_1 = lora1['image']
133
+ if len(selected_indices) >= 2:
134
+ lora2 = loras_state[selected_indices[1]]
135
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
136
+ lora_image_2 = lora2['image']
137
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
138
+
139
+ def remove_lora_2(selected_indices, loras_state):
140
+ if len(selected_indices) >= 2:
141
+ selected_indices.pop(1)
142
+ selected_info_1 = "Select LoRA 1"
143
+ selected_info_2 = "Select LoRA 2"
144
+ lora_scale_1 = 1.15
145
+ lora_scale_2 = 1.15
146
+ lora_image_1 = None
147
+ lora_image_2 = None
148
+ if len(selected_indices) >= 1:
149
+ lora1 = loras_state[selected_indices[0]]
150
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
151
+ lora_image_1 = lora1['image']
152
+ if len(selected_indices) >= 2:
153
+ lora2 = loras_state[selected_indices[1]]
154
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
155
+ lora_image_2 = lora2['image']
156
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
157
+
158
+ def randomize_loras(selected_indices, loras_state):
159
+ if len(loras_state) < 2:
160
+ raise gr.Error("Not enough LoRAs to randomize.")
161
+ selected_indices = random.sample(range(len(loras_state)), 2)
162
+ lora1 = loras_state[selected_indices[0]]
163
+ lora2 = loras_state[selected_indices[1]]
164
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
165
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
166
+ lora_scale_1 = 1.15
167
+ lora_scale_2 = 1.15
168
+ lora_image_1 = lora1['image']
169
+ lora_image_2 = lora2['image']
170
+ random_prompt = random.choice(prompt_values)
171
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
172
+
173
+ def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
174
+ if custom_lora:
175
+ try:
176
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
177
+ print(f"Loaded custom LoRA: {repo}")
178
+ existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
179
+ if existing_item_index is None:
180
+ if repo.endswith(".safetensors") and repo.startswith("http"):
181
+ repo = download_file(repo)
182
+ new_item = {
183
+ "image": image if image else "/home/user/app/custom.png",
184
+ "title": title,
185
+ "repo": repo,
186
+ "weights": path,
187
+ "trigger_word": trigger_word
188
+ }
189
+ print(f"New LoRA: {new_item}")
190
+ existing_item_index = len(current_loras)
191
+ current_loras.append(new_item)
192
+
193
+ # Update gallery
194
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
195
+ # Update selected_indices if there's room
196
+ if len(selected_indices) < 2:
197
+ selected_indices.append(existing_item_index)
198
+ else:
199
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
200
+
201
+ # Update selected_info and images
202
+ selected_info_1 = "Select a LoRA 1"
203
+ selected_info_2 = "Select a LoRA 2"
204
+ lora_scale_1 = 1.15
205
+ lora_scale_2 = 1.15
206
+ lora_image_1 = None
207
+ lora_image_2 = None
208
+ if len(selected_indices) >= 1:
209
+ lora1 = current_loras[selected_indices[0]]
210
+ selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
211
+ lora_image_1 = lora1['image'] if lora1['image'] else None
212
+ if len(selected_indices) >= 2:
213
+ lora2 = current_loras[selected_indices[1]]
214
+ selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
215
+ lora_image_2 = lora2['image'] if lora2['image'] else None
216
+ print("Finished adding custom LoRA")
217
+ return (
218
+ current_loras,
219
+ gr.update(value=gallery_items),
220
+ selected_info_1,
221
+ selected_info_2,
222
+ selected_indices,
223
+ lora_scale_1,
224
+ lora_scale_2,
225
+ lora_image_1,
226
+ lora_image_2
227
+ )
228
+ except Exception as e:
229
+ print(e)
230
+ gr.Warning(str(e))
231
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
232
+ else:
233
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
234
+
235
+ def remove_custom_lora(selected_indices, current_loras, gallery):
236
+ if current_loras:
237
+ custom_lora_repo = current_loras[-1]['repo']
238
+ # Remove from loras list
239
+ current_loras = current_loras[:-1]
240
+ # Remove from selected_indices if selected
241
+ custom_lora_index = len(current_loras)
242
+ if custom_lora_index in selected_indices:
243
+ selected_indices.remove(custom_lora_index)
244
+ # Update gallery
245
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
246
+ # Update selected_info and images
247
+ selected_info_1 = "Select a LoRA 1"
248
+ selected_info_2 = "Select a LoRA 2"
249
+ lora_scale_1 = 1.15
250
+ lora_scale_2 = 1.15
251
+ lora_image_1 = None
252
+ lora_image_2 = None
253
+ if len(selected_indices) >= 1:
254
+ lora1 = current_loras[selected_indices[0]]
255
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
256
+ lora_image_1 = lora1['image']
257
+ if len(selected_indices) >= 2:
258
+ lora2 = current_loras[selected_indices[1]]
259
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
260
+ lora_image_2 = lora2['image']
261
+ return (
262
+ current_loras,
263
+ gr.update(value=gallery_items),
264
+ selected_info_1,
265
+ selected_info_2,
266
+ selected_indices,
267
+ lora_scale_1,
268
+ lora_scale_2,
269
+ lora_image_1,
270
+ lora_image_2
271
+ )
272
+
273
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
274
+ print("Generating image...")
275
+ pipe.to("cuda")
276
+ generator = torch.Generator(device="cuda").manual_seed(seed)
277
+ with calculateDuration("Generating image"):
278
+ # Generate image
279
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
280
+ prompt=prompt_mash,
281
+ num_inference_steps=steps,
282
+ guidance_scale=cfg_scale,
283
+ width=width,
284
+ height=height,
285
+ generator=generator,
286
+ joint_attention_kwargs={"scale": 1.0},
287
+ output_type="pil",
288
+ good_vae=good_vae,
289
+ ):
290
+ yield img
291
+
292
+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
293
+ pipe_i2i.to("cuda")
294
+ generator = torch.Generator(device="cuda").manual_seed(seed)
295
+ image_input = load_image(image_input_path)
296
+ final_image = pipe_i2i(
297
+ prompt=prompt_mash,
298
+ image=image_input,
299
+ strength=image_strength,
300
+ num_inference_steps=steps,
301
+ guidance_scale=cfg_scale,
302
+ width=width,
303
+ height=height,
304
+ generator=generator,
305
+ joint_attention_kwargs={"scale": 1.0},
306
+ output_type="pil",
307
+ ).images[0]
308
+ return final_image
309
+
310
+ @spaces.GPU(duration=75)
311
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
312
+ if not selected_indices:
313
+ raise gr.Error("You must select at least one LoRA before proceeding.")
314
+
315
+ selected_loras = [loras_state[idx] for idx in selected_indices]
316
+
317
+ # Build the prompt with trigger words
318
+ prepends = []
319
+ appends = []
320
+ for lora in selected_loras:
321
+ trigger_word = lora.get('trigger_word', '')
322
+ if trigger_word:
323
+ if lora.get("trigger_position") == "prepend":
324
+ prepends.append(trigger_word)
325
+ else:
326
+ appends.append(trigger_word)
327
+ prompt_mash = " ".join(prepends + [prompt] + appends)
328
+ print("Prompt Mash: ", prompt_mash)
329
+ # Unload previous LoRA weights
330
+ with calculateDuration("Unloading LoRA"):
331
+ pipe.unload_lora_weights()
332
+ pipe_i2i.unload_lora_weights()
333
+
334
+ print(pipe.get_active_adapters())
335
+ # Load LoRA weights with respective scales
336
+ lora_names = []
337
+ lora_weights = []
338
+ with calculateDuration("Loading LoRA weights"):
339
+ for idx, lora in enumerate(selected_loras):
340
+ lora_name = f"lora_{idx}"
341
+ lora_names.append(lora_name)
342
+ print(f"Lora Name: {lora_name}")
343
+ lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
344
+ lora_path = lora['repo']
345
+ weight_name = lora.get("weights")
346
+ print(f"Lora Path: {lora_path}")
347
+ pipe_to_use = pipe_i2i if image_input is not None else pipe
348
+ pipe_to_use.load_lora_weights(
349
+ lora_path,
350
+ weight_name=weight_name if weight_name else None,
351
+ low_cpu_mem_usage=True,
352
+ adapter_name=lora_name
353
+ )
354
+ # if image_input is not None: pipe_i2i = pipe_to_use
355
+ # else: pipe = pipe_to_use
356
+ print("Loaded LoRAs:", lora_names)
357
+ print("Adapter weights:", lora_weights)
358
+ if image_input is not None:
359
+ pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
360
+ else:
361
+ pipe.set_adapters(lora_names, adapter_weights=lora_weights)
362
+ print(pipe.get_active_adapters())
363
+ # Set random seed for reproducibility
364
+ with calculateDuration("Randomizing seed"):
365
+ if randomize_seed:
366
+ seed = random.randint(0, MAX_SEED)
367
+
368
+ # Generate image
369
+ if image_input is not None:
370
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
371
+ yield final_image, seed, gr.update(visible=False)
372
+ else:
373
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
374
+ # Consume the generator to get the final image
375
+ final_image = None
376
+ step_counter = 0
377
+ for image in image_generator:
378
+ step_counter += 1
379
+ final_image = image
380
+ progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
381
+ yield image, seed, gr.update(value=progress_bar, visible=True)
382
+ yield final_image, seed, gr.update(value=progress_bar, visible=False)
383
+
384
+ run_lora.zerogpu = True
385
+
386
+ def get_huggingface_safetensors(link):
387
+ split_link = link.split("/")
388
+ if len(split_link) == 2:
389
+ model_card = ModelCard.load(link)
390
+ base_model = model_card.data.get("base_model")
391
+ print(f"Base model: {base_model}")
392
+ if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
393
+ raise Exception("Not a FLUX LoRA!")
394
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
395
+ trigger_word = model_card.data.get("instance_prompt", "")
396
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
397
+ fs = HfFileSystem()
398
+ safetensors_name = None
399
+ try:
400
+ list_of_files = fs.ls(link, detail=False)
401
+ for file in list_of_files:
402
+ if file.endswith(".safetensors"):
403
+ safetensors_name = file.split("/")[-1]
404
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
405
+ image_elements = file.split("/")
406
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
407
+ except Exception as e:
408
+ print(e)
409
+ raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
410
+ if not safetensors_name:
411
+ raise gr.Error("No *.safetensors file found in the repository")
412
+ return split_link[1], link, safetensors_name, trigger_word, image_url
413
+ else:
414
+ raise gr.Error("Invalid Hugging Face repository link")
415
 
416
  def check_custom_model(link):
 
417
  if link.endswith(".safetensors"):
418
  # Treat as direct link to the LoRA weights
419
  title = os.path.basename(link)
 
432
  # Assume it's a Hugging Face model path
433
  return get_huggingface_safetensors(link)
434
 
 
435
  def update_history(new_image, history):
436
  """Updates the history gallery with the new image."""
437
  if history is None:
 
462
  #component-11{align-self: stretch;}
463
  '''
464
 
465
+ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
466
  title = gr.HTML(
467
  """<h1><img src="https://i.imgur.com/wMh2Oek.png" alt="LoRA"> LoRA Lab [beta]</h1><br><span style="
468
  margin-top: -25px !important;
 
473
  )
474
  loras_state = gr.State(loras)
475
  selected_indices = gr.State([])
 
 
476
  with gr.Row():
477
  with gr.Column(scale=3):
478
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
479
  with gr.Column(scale=1):
480
  generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
 
481
  with gr.Row(elem_id="loaded_loras"):
482
+ with gr.Column(scale=1, min_width=25):
483
+ randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
484
+ with gr.Column(scale=8):
485
  with gr.Row():
486
+ with gr.Column(scale=0, min_width=50):
487
+ lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
488
+ with gr.Column(scale=3, min_width=100):
489
+ selected_info_1 = gr.Markdown("Select a LoRA 1")
490
+ with gr.Column(scale=5, min_width=50):
491
+ lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
492
+ with gr.Row():
493
+ remove_button_1 = gr.Button("Remove", size="sm")
494
+ with gr.Column(scale=8):
495
  with gr.Row():
496
+ with gr.Column(scale=0, min_width=50):
497
+ lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
498
+ with gr.Column(scale=3, min_width=100):
499
+ selected_info_2 = gr.Markdown("Select a LoRA 2")
500
+ with gr.Column(scale=5, min_width=50):
501
+ lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
502
+ with gr.Row():
503
+ remove_button_2 = gr.Button("Remove", size="sm")
504
+ with gr.Row():
505
+ with gr.Column():
506
+ with gr.Group():
507
+ with gr.Row(elem_id="custom_lora_structure"):
508
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="multimodalart/vintage-ads-flux", scale=3, min_width=150)
509
+ add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
510
+ remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
511
+ gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
512
+ gallery = gr.Gallery(
513
+ [(item["image"], item["title"]) for item in loras],
514
+ label="Or pick from the LoRA Explorer gallery",
515
+ allow_preview=False,
516
+ columns=5,
517
+ elem_id="gallery",
518
+ show_share_button=False,
519
+ interactive=False
520
+ )
521
+ with gr.Column():
522
+ progress_bar = gr.Markdown(elem_id="progress", visible=False)
523
+ result = gr.Image(label="Generated Image", interactive=False, show_share_button=False)
524
+ with gr.Accordion("History", open=False):
525
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
526
 
527
+ with gr.Row():
528
+ with gr.Accordion("Advanced Settings", open=False):
529
  with gr.Row():
530
+ input_image = gr.Image(label="Input image", type="filepath", show_share_button=False)
531
+ image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
532
+ with gr.Column():
533
  with gr.Row():
534
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
535
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
536
 
537
+ with gr.Row():
538
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
539
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
540
+
541
+ with gr.Row():
542
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
543
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
544
+
545
+ gallery.select(
546
+ update_selection,
547
+ inputs=[selected_indices, loras_state, width, height],
548
+ outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
549
+ remove_button_1.click(
550
+ remove_lora_1,
551
+ inputs=[selected_indices, loras_state],
552
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
553
+ )
554
+ remove_button_2.click(
555
+ remove_lora_2,
556
+ inputs=[selected_indices, loras_state],
557
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
558
+ )
559
+ randomize_button.click(
560
+ randomize_loras,
561
+ inputs=[selected_indices, loras_state],
562
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, prompt]
563
+ )
564
+ add_custom_lora_button.click(
565
+ add_custom_lora,
566
+ inputs=[custom_lora, selected_indices, loras_state, gallery],
567
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
568
+ )
569
+ remove_custom_lora_button.click(
570
+ remove_custom_lora,
571
+ inputs=[selected_indices, loras_state, gallery],
572
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
573
+ )
574
+ gr.on(
575
+ triggers=[generate_button.click, prompt.submit],
576
+ fn=run_lora,
577
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
578
  outputs=[result, seed, progress_bar]
579
  ).then(
580
  fn=lambda x, history: update_history(x, history),