multimodalart HF staff commited on
Commit
2c6e805
1 Parent(s): 1816d2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -55
app.py CHANGED
@@ -70,6 +70,8 @@ def update_selection(evt: gr.SelectData, selected_indices, width, height):
70
  # Initialize outputs
71
  selected_info_1 = ""
72
  selected_info_2 = ""
 
 
73
  if len(selected_indices) >= 1:
74
  lora1 = loras[selected_indices[0]]
75
  selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
@@ -89,6 +91,8 @@ def update_selection(evt: gr.SelectData, selected_indices, width, height):
89
  selected_info_1,
90
  selected_info_2,
91
  selected_indices,
 
 
92
  width,
93
  height,
94
  )
@@ -100,13 +104,15 @@ def remove_lora_1(selected_indices):
100
  # Update selected_info_1 and selected_info_2
101
  selected_info_1 = ""
102
  selected_info_2 = ""
 
 
103
  if len(selected_indices) >= 1:
104
  lora1 = loras[selected_indices[0]]
105
  selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
106
  if len(selected_indices) >= 2:
107
  lora2 = loras[selected_indices[1]]
108
  selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
109
- return selected_info_1, selected_info_2, selected_indices
110
 
111
  def remove_lora_2(selected_indices):
112
  selected_indices = selected_indices or []
@@ -115,13 +121,25 @@ def remove_lora_2(selected_indices):
115
  # Update selected_info_1 and selected_info_2
116
  selected_info_1 = ""
117
  selected_info_2 = ""
 
 
118
  if len(selected_indices) >= 1:
119
  lora1 = loras[selected_indices[0]]
120
  selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
121
  if len(selected_indices) >= 2:
122
  lora2 = loras[selected_indices[1]]
123
  selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
124
- return selected_info_1, selected_info_2, selected_indices
 
 
 
 
 
 
 
 
 
 
125
 
126
  @spaces.GPU(duration=70)
127
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
@@ -220,42 +238,44 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
220
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
221
 
222
  def get_huggingface_safetensors(link):
223
- split_link = link.split("/")
224
- if(len(split_link) == 2):
225
- model_card = ModelCard.load(link)
226
- base_model = model_card.data.get("base_model")
227
- print(base_model)
228
- if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
229
- raise Exception("Not a FLUX LoRA!")
230
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
231
- trigger_word = model_card.data.get("instance_prompt", "")
232
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
233
- fs = HfFileSystem()
234
- try:
235
- list_of_files = fs.ls(link, detail=False)
236
- for file in list_of_files:
237
- if(file.endswith(".safetensors")):
238
- safetensors_name = file.split("/")[-1]
239
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
240
- image_elements = file.split("/")
241
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
242
- except Exception as e:
243
- print(e)
244
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
245
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
246
- return split_link[1], link, safetensors_name, trigger_word, image_url
 
 
247
 
248
  def check_custom_model(link):
249
- if(link.startswith("https://")):
250
- if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
251
  link_split = link.split("huggingface.co/")
252
  return get_huggingface_safetensors(link_split[1])
253
  else:
254
  return get_huggingface_safetensors(link)
255
 
256
- def add_custom_lora(custom_lora):
257
  global loras
258
- if(custom_lora):
259
  try:
260
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
261
  print(f"Loaded custom LoRA: {repo}")
@@ -272,7 +292,7 @@ def add_custom_lora(custom_lora):
272
  </div>
273
  '''
274
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
275
- if(not existing_item_index):
276
  new_item = {
277
  "image": image,
278
  "title": title,
@@ -283,16 +303,46 @@ def add_custom_lora(custom_lora):
283
  print(new_item)
284
  existing_item_index = len(loras)
285
  loras.append(new_item)
286
-
287
- return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  except Exception as e:
289
- gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
290
- return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=True), gr.update(), "", None, ""
291
  else:
292
- return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
293
 
294
- def remove_custom_lora():
295
- return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  run_lora.zerogpu = True
298
 
@@ -303,6 +353,7 @@ css = '''
303
  #title img{width: 100px; margin-right: 0.5em}
304
  #gallery .grid-wrap{height: 10vh}
305
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
 
306
  .card_internal{display: flex;height: 100px;margin-top: .5em}
307
  .card_internal img{margin-right: 1em}
308
  .styler{--form-gap-width: 0px !important}
@@ -311,6 +362,7 @@ css = '''
311
  .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
312
  .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
313
  '''
 
314
  with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
315
  title = gr.HTML(
316
  """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> LoRA Lab</h1>""",
@@ -322,6 +374,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
322
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
323
  with gr.Column(scale=1, elem_id="gen_column"):
324
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
 
 
 
 
 
 
 
 
325
  with gr.Row():
326
  with gr.Column():
327
  gallery = gr.Gallery(
@@ -339,15 +399,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
339
  with gr.Column():
340
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
341
  result = gr.Image(label="Generated Image")
342
- with gr.Row():
343
- with gr.Column():
344
- selected_info_1 = gr.Markdown("")
345
- lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=0.95)
346
- remove_button_1 = gr.Button("Remove LoRA 1")
347
- with gr.Column():
348
- selected_info_2 = gr.Markdown("")
349
- lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=0.95)
350
- remove_button_2 = gr.Button("Remove LoRA 2")
351
  with gr.Row():
352
  with gr.Accordion("Advanced Settings", open=False):
353
  with gr.Row():
@@ -369,26 +420,32 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
369
  gallery.select(
370
  update_selection,
371
  inputs=[selected_indices, width, height],
372
- outputs=[prompt, selected_info_1, selected_info_2, selected_indices, width, height]
373
  )
374
  remove_button_1.click(
375
  remove_lora_1,
376
  inputs=[selected_indices],
377
- outputs=[selected_info_1, selected_info_2, selected_indices]
378
  )
379
  remove_button_2.click(
380
  remove_lora_2,
381
  inputs=[selected_indices],
382
- outputs=[selected_info_1, selected_info_2, selected_indices]
 
 
 
 
 
383
  )
384
- custom_lora.input(
385
  add_custom_lora,
386
- inputs=[custom_lora],
387
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info_1, selected_indices, prompt]
388
  )
389
  custom_lora_button.click(
390
  remove_custom_lora,
391
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info_1, selected_indices, custom_lora]
 
392
  )
393
  gr.on(
394
  triggers=[generate_button.click, prompt.submit],
@@ -398,4 +455,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
398
  )
399
 
400
  app.queue()
401
- app.launch()
 
70
  # Initialize outputs
71
  selected_info_1 = ""
72
  selected_info_2 = ""
73
+ lora_scale_1 = 0.95
74
+ lora_scale_2 = 0.95
75
  if len(selected_indices) >= 1:
76
  lora1 = loras[selected_indices[0]]
77
  selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
 
91
  selected_info_1,
92
  selected_info_2,
93
  selected_indices,
94
+ lora_scale_1,
95
+ lora_scale_2,
96
  width,
97
  height,
98
  )
 
104
  # Update selected_info_1 and selected_info_2
105
  selected_info_1 = ""
106
  selected_info_2 = ""
107
+ lora_scale_1 = 0.95
108
+ lora_scale_2 = 0.95
109
  if len(selected_indices) >= 1:
110
  lora1 = loras[selected_indices[0]]
111
  selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
112
  if len(selected_indices) >= 2:
113
  lora2 = loras[selected_indices[1]]
114
  selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
115
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2
116
 
117
  def remove_lora_2(selected_indices):
118
  selected_indices = selected_indices or []
 
121
  # Update selected_info_1 and selected_info_2
122
  selected_info_1 = ""
123
  selected_info_2 = ""
124
+ lora_scale_1 = 0.95
125
+ lora_scale_2 = 0.95
126
  if len(selected_indices) >= 1:
127
  lora1 = loras[selected_indices[0]]
128
  selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
129
  if len(selected_indices) >= 2:
130
  lora2 = loras[selected_indices[1]]
131
  selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
132
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2
133
+
134
+ def randomize_loras(selected_indices):
135
+ if len(loras) < 2:
136
+ raise gr.Error("Not enough LoRAs to randomize.")
137
+ selected_indices = random.sample(range(len(loras)), 2)
138
+ selected_info_1 = f"### LoRA 1 Selected: [{loras[selected_indices[0]]['title']}](https://huggingface.co/{loras[selected_indices[0]]['repo']}) ✨"
139
+ selected_info_2 = f"### LoRA 2 Selected: [{loras[selected_indices[1]]['title']}](https://huggingface.co/{loras[selected_indices[1]]['repo']}) ✨"
140
+ lora_scale_1 = 0.95
141
+ lora_scale_2 = 0.95
142
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2
143
 
144
  @spaces.GPU(duration=70)
145
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
 
238
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
239
 
240
  def get_huggingface_safetensors(link):
241
+ split_link = link.split("/")
242
+ if len(split_link) == 2:
243
+ model_card = ModelCard.load(link)
244
+ base_model = model_card.data.get("base_model")
245
+ print(base_model)
246
+ if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
247
+ raise Exception("Not a FLUX LoRA!")
248
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
249
+ trigger_word = model_card.data.get("instance_prompt", "")
250
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
251
+ fs = HfFileSystem()
252
+ safetensors_name = None
253
+ try:
254
+ list_of_files = fs.ls(link, detail=False)
255
+ for file in list_of_files:
256
+ if file.endswith(".safetensors"):
257
+ safetensors_name = file.split("/")[-1]
258
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
259
+ image_elements = file.split("/")
260
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
261
+ except Exception as e:
262
+ print(e)
263
+ raise Exception("Invalid Hugging Face repository with a *.safetensors LoRA")
264
+ if not safetensors_name:
265
+ raise Exception("No *.safetensors file found in the repository")
266
+ return split_link[1], link, safetensors_name, trigger_word, image_url
267
 
268
  def check_custom_model(link):
269
+ if link.startswith("https://"):
270
+ if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
271
  link_split = link.split("huggingface.co/")
272
  return get_huggingface_safetensors(link_split[1])
273
  else:
274
  return get_huggingface_safetensors(link)
275
 
276
+ def add_custom_lora(custom_lora, selected_indices):
277
  global loras
278
+ if custom_lora:
279
  try:
280
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
281
  print(f"Loaded custom LoRA: {repo}")
 
292
  </div>
293
  '''
294
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
295
+ if existing_item_index is None:
296
  new_item = {
297
  "image": image,
298
  "title": title,
 
303
  print(new_item)
304
  existing_item_index = len(loras)
305
  loras.append(new_item)
306
+
307
+ # Update gallery
308
+ gallery_items = [(item["image"], item["title"]) for item in loras]
309
+ # Update selected_indices if there's room
310
+ if len(selected_indices) < 2:
311
+ selected_indices.append(existing_item_index)
312
+ selected_info_1 = ""
313
+ selected_info_2 = ""
314
+ lora_scale_1 = 0.95
315
+ lora_scale_2 = 0.95
316
+ if len(selected_indices) >= 1:
317
+ lora1 = loras[selected_indices[0]]
318
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
319
+ if len(selected_indices) >= 2:
320
+ lora2 = loras[selected_indices[1]]
321
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
322
+ return (gr.update(visible=True, value=card), gr.update(visible=True), gr.update(value=gallery_items),
323
+ selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2)
324
+ else:
325
+ return (gr.update(visible=True, value=card), gr.update(visible=True), gr.update(value=gallery_items),
326
+ gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange())
327
  except Exception as e:
328
+ print(e)
329
+ return gr.update(visible=True, value=str(e)), gr.update(visible=True), gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange()
330
  else:
331
+ return gr.update(visible=False), gr.update(visible=False), gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange()
332
 
333
+ def remove_custom_lora(custom_lora_info, custom_lora_button, selected_indices):
334
+ global loras
335
+ if loras:
336
+ custom_lora_repo = loras[-1]['repo']
337
+ # Remove from loras list
338
+ loras = loras[:-1]
339
+ # Remove from selected_indices if selected
340
+ custom_lora_index = len(loras)
341
+ if custom_lora_index in selected_indices:
342
+ selected_indices.remove(custom_lora_index)
343
+ # Update gallery
344
+ gallery_items = [(item["image"], item["title"]) for item in loras]
345
+ return gr.update(visible=False), gr.update(visible=False), gr.update(value=gallery_items), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange()
346
 
347
  run_lora.zerogpu = True
348
 
 
353
  #title img{width: 100px; margin-right: 0.5em}
354
  #gallery .grid-wrap{height: 10vh}
355
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
356
+ .custom_lora_card{margin-bottom: 1em}
357
  .card_internal{display: flex;height: 100px;margin-top: .5em}
358
  .card_internal img{margin-right: 1em}
359
  .styler{--form-gap-width: 0px !important}
 
362
  .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
363
  .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
364
  '''
365
+
366
  with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
367
  title = gr.HTML(
368
  """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> LoRA Lab</h1>""",
 
374
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
375
  with gr.Column(scale=1, elem_id="gen_column"):
376
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
377
+ randomize_button = gr.Button("🎲", variant="secondary")
378
+ with gr.Row():
379
+ selected_info_1 = gr.Markdown("")
380
+ lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=0.95)
381
+ remove_button_1 = gr.Button("Remove LoRA 1")
382
+ selected_info_2 = gr.Markdown("")
383
+ lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=0.95)
384
+ remove_button_2 = gr.Button("Remove LoRA 2")
385
  with gr.Row():
386
  with gr.Column():
387
  gallery = gr.Gallery(
 
399
  with gr.Column():
400
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
401
  result = gr.Image(label="Generated Image")
 
 
 
 
 
 
 
 
 
402
  with gr.Row():
403
  with gr.Accordion("Advanced Settings", open=False):
404
  with gr.Row():
 
420
  gallery.select(
421
  update_selection,
422
  inputs=[selected_indices, width, height],
423
+ outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height]
424
  )
425
  remove_button_1.click(
426
  remove_lora_1,
427
  inputs=[selected_indices],
428
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2]
429
  )
430
  remove_button_2.click(
431
  remove_lora_2,
432
  inputs=[selected_indices],
433
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2]
434
+ )
435
+ randomize_button.click(
436
+ randomize_loras,
437
+ inputs=[selected_indices],
438
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2]
439
  )
440
+ custom_lora.submit(
441
  add_custom_lora,
442
+ inputs=[custom_lora, selected_indices],
443
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2]
444
  )
445
  custom_lora_button.click(
446
  remove_custom_lora,
447
+ inputs=[custom_lora_info, custom_lora_button, selected_indices],
448
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2]
449
  )
450
  gr.on(
451
  triggers=[generate_button.click, prompt.submit],
 
455
  )
456
 
457
  app.queue()
458
+ app.launch()