ruslanmv committed
Commit 2f46b7a · verified · 1 Parent(s): 1d05326

Update app.py

Files changed (1):
  1. app.py +345 -493
app.py CHANGED
@@ -1,57 +1,69 @@
- ##############################
- # ===== Standard Imports =====
- ##############################
  import os
- import sys
  import time
  import random
- import json
  from typing import Any, Dict, List, Optional, Union

  import torch
- import numpy as np
  from PIL import Image
  import gradio as gr
- import spaces

- # Diffusers imports
  from diffusers import (
  DiffusionPipeline,
  AutoencoderTiny,
  AutoencoderKL,
  AutoPipelineForImage2Image,
  )
  from diffusers.utils import load_image

- # Hugging Face Hub imports
- from huggingface_hub import ModelCard, HfFileSystem
-
- ##############################
- # ===== config.py =====
- ##############################
- DTYPE = torch.bfloat16
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- BASE_MODEL = "black-forest-labs/FLUX.1-dev"
- TAEF1_MODEL = "madebyollin/taef1"
- MAX_SEED = 2**32 - 1
-
- ##############################
- # ===== utilities.py =====
- ##############################
- def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
  m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
  b = base_shift - m * base_seq_len
  mu = image_seq_len * m + b
  return mu

- def retrieve_timesteps(scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
- **kwargs):
  if timesteps is not None and sigmas is not None:
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
  if timesteps is not None:
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
  timesteps = scheduler.timesteps
@@ -65,119 +77,31 @@ def retrieve_timesteps(scheduler,
  timesteps = scheduler.timesteps
  return timesteps, num_inference_steps

- def load_image_from_path(image_path: str):
- return load_image(image_path)
-
- def randomize_seed_if_needed(randomize_seed: bool, seed: int, max_seed: int) -> int:
- if randomize_seed:
- return random.randint(0, max_seed)
- return seed
-
- class calculateDuration:
- def __init__(self, activity_name=""):
- self.activity_name = activity_name
- def __enter__(self):
- self.start_time = time.time()
- return self
- def __exit__(self, exc_type, exc_value, traceback):
- self.end_time = time.time()
- elapsed = self.end_time - self.start_time
- if self.activity_name:
- print(f"Elapsed time for {self.activity_name}: {elapsed:.6f} seconds")
- else:
- print(f"Elapsed time: {elapsed:.6f} seconds")
-
- ##############################
- # ===== Helper: truncate_prompt =====
- ##############################
- def truncate_prompt(prompt: str) -> str:
- """
- Uses the global pipeline's tokenizer (assumed available as `pipe.tokenizer`)
- to truncate the prompt to the maximum allowed length.
- """
- try:
- tokenizer = pipe.tokenizer
- tokenized = tokenizer(prompt, truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt")
- return tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
- except Exception as e:
- print(f"Error in truncate_prompt: {e}")
- return prompt
-
- ##############################
- # ===== enhance.py =====
- ##############################
- def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
- SYSTEM_PROMPT = (
- "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
- "without changing the essence, only write the enhanced prompt and nothing else."
- )
- timestamp = time.time()
- formatted_prompt = f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST][INST] {message} {timestamp} [/INST]"
- api_url = "https://ruslanmv-hf-llm-api.hf.space/api/v1/chat/completions"
- headers = {"Content-Type": "application/json"}
- payload = {
- "model": "mixtral-8x7b",
- "messages": [{"role": "user", "content": formatted_prompt}],
- "temperature": temperature,
- "top_p": top_p,
- "max_tokens": max_new_tokens,
- "use_cache": False,
- "stream": True
- }
- try:
- response = requests.post(api_url, headers=headers, json=payload, stream=True)
- response.raise_for_status()
- full_output = ""
- for line in response.iter_lines():
- if not line:
- continue
- decoded_line = line.decode("utf-8").strip()
- if decoded_line.startswith("data:"):
- decoded_line = decoded_line[len("data:"):].strip()
- if decoded_line == "[DONE]":
- break
- try:
- json_data = json.loads(decoded_line)
- for choice in json_data.get("choices", []):
- delta = choice.get("delta", {})
- content = delta.get("content", "")
- full_output += content
- yield full_output
- if choice.get("finish_reason") == "stop":
- return
- except json.JSONDecodeError:
- continue
- except requests.exceptions.RequestException as e:
- yield f"Error during generation: {str(e)}"
-
- ##############################
- # ===== lora_handling.py =====
- ##############################
- loras = [
- {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
- ]
-
  @torch.inference_mode()
- def flux_pipe_call_that_returns_an_iterable_of_images(self,
- prompt: Union[str, List[str]] = None,
- prompt_2: Optional[Union[str, List[str]]] = None,
- height: Optional[int] = None,
- width: Optional[int] = None,
- num_inference_steps: int = 28,
- timesteps: List[int] = None,
- guidance_scale: float = 3.5,
- num_images_per_prompt: Optional[int] = 1,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- output_type: Optional[str] = "pil",
- return_dict: bool = True,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- max_sequence_length: int = 512,
- good_vae: Optional[Any] = None):
  height = height or self.default_sample_size * self.vae_scale_factor
  width = width or self.default_sample_size * self.vae_scale_factor
  self.check_inputs(
  prompt,
  prompt_2,
@@ -187,11 +111,14 @@ def flux_pipe_call_that_returns_an_iterable_of_images(self,
  pooled_prompt_embeds=pooled_prompt_embeds,
  max_sequence_length=max_sequence_length,
  )
  self._guidance_scale = guidance_scale
  self._joint_attention_kwargs = joint_attention_kwargs
  self._interrupt = False
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
  device = self._execution_device
  lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
  prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
  prompt=prompt,
@@ -203,6 +130,7 @@ def flux_pipe_call_that_returns_an_iterable_of_images(self,
  max_sequence_length=max_sequence_length,
  lora_scale=lora_scale,
  )
  num_channels_latents = self.transformer.config.in_channels // 4
  latents, latent_image_ids = self.prepare_latents(
  batch_size * num_images_per_prompt,
@@ -214,6 +142,7 @@ def flux_pipe_call_that_returns_an_iterable_of_images(self,
  generator,
  latents,
  )
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
  image_seq_len = latents.shape[1]
  mu = calculate_shift(
@@ -232,13 +161,15 @@ def flux_pipe_call_that_returns_an_iterable_of_images(self,
  mu=mu,
  )
  self._num_timesteps = len(timesteps)
- guidance = (torch.full([1], guidance_scale, device=device, dtype=torch.float32)
- .expand(latents.shape[0])
- if self.transformer.config.guidance_embeds else None)
  for i, t in enumerate(timesteps):
  if self.interrupt:
  continue
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
  noise_pred = self.transformer(
  hidden_states=latents,
  timestep=timestep / 1000,
@@ -250,12 +181,14 @@ def flux_pipe_call_that_returns_an_iterable_of_images(self,
  joint_attention_kwargs=self.joint_attention_kwargs,
  return_dict=False,
  )[0]
  latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
  latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
  image = self.vae.decode(latents_for_image, return_dict=False)[0]
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
  torch.cuda.empty_cache()
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
  latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
  image = good_vae.decode(latents, return_dict=False)[0]
@@ -263,13 +196,157 @@ def flux_pipe_call_that_returns_an_iterable_of_images(self,
  torch.cuda.empty_cache()
  yield self.image_processor.postprocess(image, output_type=output_type)[0]

- def get_huggingface_safetensors(link: str) -> tuple:
  split_link = link.split("/")
- if len(split_link) == 2:
  model_card = ModelCard.load(link)
- base_model_card = model_card.data.get("base_model")
- print(base_model_card)
- if base_model_card not in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"):
  raise Exception("Flux LoRA Not Found!")
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
  trigger_word = model_card.data.get("instance_prompt", "")
@@ -278,50 +355,47 @@ def get_huggingface_safetensors(link: str) -> tuple:
  try:
  list_of_files = fs.ls(link, detail=False)
  for file in list_of_files:
- if file.endswith(".safetensors"):
  safetensors_name = file.split("/")[-1]
- if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
  image_elements = file.split("/")
  image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
  except Exception as e:
  print(e)
- raise Exception("Invalid LoRA repository")
  return split_link[1], link, safetensors_name, trigger_word, image_url
  else:
  raise Exception("Invalid LoRA link format")

- def check_custom_model(link: str) -> tuple:
- if link.startswith("https://"):
- if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
  link_split = link.split("huggingface.co/")
  return get_huggingface_safetensors(link_split[1])
- return get_huggingface_safetensors(link)
-
- def create_lora_card(title: str, repo: str, trigger_word: str, image: str) -> str:
- trigger_word_info = (f"Using: <code><b>{trigger_word}</b></code> as the trigger word"
- if trigger_word else "No trigger word found. Include it in your prompt")
- return f'''
- <div class="custom_lora_card">
- <span>Loaded custom LoRA:</span>
- <div class="card_internal">
- <img src="{image}" />
- <div>
- <h3>{title}</h3>
- <small>{trigger_word_info}<br></small>
- </div>
- </div>
- </div>
- '''

- def add_custom_lora(custom_lora: str) -> tuple:
  global loras
- if custom_lora:
  try:
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
  print(f"Loaded custom LoRA: {repo}")
- card = create_lora_card(title, repo, trigger_word, image)
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
- if existing_item_index is None:
  new_item = {
  "image": image,
  "title": title,
@@ -330,334 +404,112 @@ def add_custom_lora(custom_lora: str) -> tuple:
  "trigger_word": trigger_word
  }
  print(new_item)
  loras.append(new_item)
- existing_item_index = len(loras) - 1
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
  except Exception as e:
- print(f"Error loading LoRA: {e}")
- return gr.update(visible=True, value="Invalid LoRA"), gr.update(visible=False), gr.update(), "", None, ""
  else:
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- def remove_custom_lora() -> tuple:
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- def prepare_prompt(prompt: str, selected_index: Optional[int], loras_list: list) -> str:
- if selected_index is None:
- raise gr.Error("You must select a LoRA before proceeding.🧨")
- selected_lora = loras_list[selected_index]
- trigger_word = selected_lora.get("trigger_word")
- if trigger_word:
- trigger_position = selected_lora.get("trigger_position", "append")
- if trigger_position == "prepend":
- prompt_mash = f"{trigger_word} {prompt}"
- else:
- prompt_mash = f"{prompt} {trigger_word}"
- else:
- prompt_mash = prompt
- # Truncate the prompt using the tokenizer to ensure token indices are in range.
- prompt_mash = truncate_prompt(prompt_mash)
- return prompt_mash
-
- def unload_lora_weights(pipe, pipe_i2i):
- if pipe is not None:
- pipe.unload_lora_weights()
- if pipe_i2i is not None:
- pipe_i2i.unload_lora_weights()
-
- def load_lora_weights_into_pipeline(pipe_to_use, lora_path: str, weight_name: Optional[str]):
- pipe_to_use.load_lora_weights(
- lora_path,
- weight_name=weight_name,
- low_cpu_mem_usage=True
  )
-
- def update_selection(evt: gr.SelectData, width, height) -> tuple:
- selected_lora = loras[evt.index]
- new_placeholder = f"Type a prompt for {selected_lora['title']}"
- lora_repo = selected_lora["repo"]
- updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
- if "aspect" in selected_lora:
- if selected_lora["aspect"] == "portrait":
- width = 768
- height = 1024
- elif selected_lora["aspect"] == "landscape":
- width = 1024
- height = 768
- else:
- width = 1024
- height = 1024
- return (
- gr.update(placeholder=new_placeholder),
- updated_text,
- evt.index,
- width,
- height,
- )
-
- ##############################
- # ===== backend.py =====
- ##############################
- class ModelManager:
- def __init__(self, hf_token=None):
- self.hf_token = hf_token
- self.pipe = None
- self.pipe_i2i = None
- self.good_vae = None
- self.taef1 = None
- self.initialize_models()
-
- def initialize_models(self):
- self.taef1 = AutoencoderTiny.from_pretrained(TAEF1_MODEL, torch_dtype=DTYPE).to(DEVICE)
- self.good_vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE).to(DEVICE)
- self.pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, vae=self.taef1).to(DEVICE)
- self.pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
- BASE_MODEL,
- vae=self.good_vae,
- transformer=self.pipe.transformer,
- text_encoder=self.pipe.text_encoder,
- tokenizer=self.pipe.tokenizer,
- text_encoder_2=self.pipe.text_encoder_2,
- tokenizer_2=self.pipe.tokenizer_2,
- torch_dtype=DTYPE,
- ).to(DEVICE)
- # Bind the custom LoRA method to the pipeline class (to avoid __slots__ issues)
- self.pipe.__class__.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images
-
- @spaces.GPU(duration=100)
- def generate_image(self, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
- generator = torch.Generator(device=DEVICE).manual_seed(seed)
- with calculateDuration("Generating image"):
- for img in self.pipe.flux_pipe_call_that_returns_an_iterable_of_images(
- prompt=prompt_mash,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- generator=generator,
- joint_attention_kwargs={"scale": lora_scale},
- output_type="pil",
- good_vae=self.good_vae,
- ):
- yield img
-
- @spaces.GPU(duration=100)
- def generate_image_to_image(self, prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
- generator = torch.Generator(device=DEVICE).manual_seed(seed)
- image_input = load_image_from_path(image_input_path)
- with calculateDuration("Generating image to image"):
- final_image = self.pipe_i2i(
- prompt=prompt_mash,
- image=image_input,
- strength=image_strength,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- generator=generator,
- joint_attention_kwargs={"scale": lora_scale},
- output_type="pil",
- ).images[0]
- return final_image
-
- ##############################
- # ===== frontend.py =====
- ##############################
- class Frontend:
- def __init__(self, model_manager: ModelManager):
- self.model_manager = model_manager
- self.loras = loras
- self.load_initial_loras()
- self.css = self.define_css()
-
- def define_css(self):
- return '''
- /* Title Styling */
- #title {
- text-align: center;
- margin-bottom: 20px;
- }
- #title h1 {
- font-size: 2.5rem;
- margin: 0;
- color: #333;
- }
- /* Button and Column Styling */
- #gen_btn {
- width: 100%;
- padding: 12px;
- font-weight: bold;
- border-radius: 5px;
- }
- #gen_column {
- display: flex;
- align-items: center;
- justify-content: center;
- }
- /* Gallery and List Styling */
- #gallery .grid-wrap {
- margin-top: 15px;
- }
- #lora_list {
- background-color: #f5f5f5;
- padding: 10px;
- border-radius: 4px;
- font-size: 0.9rem;
- }
- .card_internal {
- display: flex;
- align-items: center;
- height: 100px;
- margin-top: 10px;
- }
- .card_internal img {
- margin-right: 10px;
- }
- .styler {
- --form-gap-width: 0px !important;
- }
- /* Progress Bar Styling */
- .progress-container {
- width: 100%;
- height: 20px;
- background-color: #e0e0e0;
- border-radius: 10px;
- overflow: hidden;
- margin-bottom: 20px;
- }
- .progress-bar {
- height: 100%;
- background-color: #4f46e5;
- transition: width 0.3s ease-in-out;
- width: calc(var(--current) / var(--total) * 100%);
- }
- '''
-
- def load_initial_loras(self):
- try:
- from lora import loras as loras_list
- self.loras = loras_list
- except ImportError:
- print("Warning: lora.py not found, using placeholder LoRAs.")
-
- @spaces.GPU(duration=100)
- def run_lora(self, prompt, image_input, image_strength, cfg_scale, steps, selected_index,
- randomize_seed, seed, width, height, lora_scale, use_enhancer,
- progress=gr.Progress(track_tqdm=True)):
- seed = randomize_seed_if_needed(randomize_seed, seed, MAX_SEED)
- prompt_mash = prepare_prompt(prompt, selected_index, self.loras)
- enhanced_text = ""
- if use_enhancer:
- for enhanced_chunk in generate(prompt_mash):
- enhanced_text = enhanced_chunk
- yield None, seed, gr.update(visible=False), enhanced_text
- prompt_mash = enhanced_text
- else:
- enhanced_text = ""
- selected_lora = self.loras[selected_index]
- unload_lora_weights(self.model_manager.pipe, self.model_manager.pipe_i2i)
- pipe_to_use = self.model_manager.pipe_i2i if image_input is not None else self.model_manager.pipe
- load_lora_weights_into_pipeline(pipe_to_use, selected_lora["repo"], selected_lora.get("weights"))
- if image_input is not None:
- final_image = self.model_manager.generate_image_to_image(
- prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed
  )
- yield final_image, seed, gr.update(visible=False), enhanced_text
- else:
- image_generator = self.model_manager.generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
- final_image = None
- step_counter = 0
- for image in image_generator:
- step_counter += 1
- final_image = image
- progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
- yield image, seed, gr.update(value=progress_bar, visible=True), enhanced_text
- yield final_image, seed, gr.update(value=progress_bar, visible=False), enhanced_text
-
- def create_ui(self):
- with gr.Blocks(theme=gr.themes.Base(), css=self.css, title="Flux LoRA Generation") as app:
- title = gr.HTML("<h1>Flux LoRA Generation</h1>", elem_id="title")
- selected_index = gr.State(None)
- with gr.Row():
- with gr.Column(scale=3):
- prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Choose the LoRA and type the prompt")
- with gr.Column(scale=1, elem_id="gen_column"):
- generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
  with gr.Row():
- with gr.Column():
- selected_info = gr.Markdown("")
- gallery = gr.Gallery(
- [(item["image"], item["title"]) for item in self.loras],
- label="LoRA Collection",
- allow_preview=False,
- columns=3,
- elem_id="gallery",
- show_share_button=False
- )
- with gr.Group():
- custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
- gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
- custom_lora_info = gr.HTML(visible=False)
- custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
- with gr.Column():
- progress_bar = gr.Markdown(elem_id="progress", visible=False)
- result = gr.Image(label="Generated Image")
- with gr.Row():
- with gr.Accordion("Advanced Settings", open=False):
- with gr.Row():
- input_image = gr.Image(label="Input image", type="filepath")
- image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
- with gr.Column():
- with gr.Row():
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
- with gr.Row():
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
- with gr.Row():
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
- with gr.Row():
- use_enhancer = gr.Checkbox(value=False, label="Use Prompt Enhancer")
- show_enhanced_prompt = gr.Checkbox(value=False, label="Display Enhanced Prompt")
- enhanced_prompt_box = gr.Textbox(label="Enhanced Prompt", visible=False)
- gallery.select(
- update_selection,
- inputs=[width, height],
- outputs=[prompt, selected_info, selected_index, width, height]
- )
- custom_lora.input(
- add_custom_lora,
- inputs=[custom_lora],
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
- )
- custom_lora_button.click(
- remove_custom_lora,
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
- )
  show_enhanced_prompt.change(fn=lambda show: gr.update(visible=show),
  inputs=show_enhanced_prompt,
  outputs=enhanced_prompt_box)
- gr.on(
- triggers=[generate_button.click, prompt.submit],
- fn=self.run_lora,
- inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer],
- outputs=[result, seed, progress_bar, enhanced_prompt_box]
- )
- with gr.Row():
- gr.HTML("<div style='text-align:center; font-size:0.9em; margin-top:20px;'>Credits: <a href='https://ruslanmv.com' target='_blank'>ruslanmv.com</a></div>")
- return app
-
- ##############################
- # ===== Main app.py =====
- ##############################
- if __name__ == "__main__":
- hf_token = os.environ.get("HF_TOKEN")
- if not hf_token:
- raise ValueError("Hugging Face token (HF_TOKEN) not found in environment variables. Please set it.")
- model_manager = ModelManager(hf_token=hf_token)
- frontend = Frontend(model_manager)
- app = frontend.create_ui()
- app.queue()
- app.launch(share=False, debug=True)
 
 
 
 
 
  import os
+ import json
+ import copy
  import time
  import random
+ import logging
+ import numpy as np
  from typing import Any, Dict, List, Optional, Union

  import torch
  from PIL import Image
  import gradio as gr

  from diffusers import (
  DiffusionPipeline,
  AutoencoderTiny,
  AutoencoderKL,
  AutoPipelineForImage2Image,
+ FluxPipeline,
+ FlowMatchEulerDiscreteScheduler
+ )
+
+ from huggingface_hub import (
+ hf_hub_download,
+ HfFileSystem,
+ ModelCard,
+ snapshot_download
  )
+
  from diffusers.utils import load_image

+ import spaces
+
+ # Attempt to import loras from lora.py; otherwise use a default placeholder.
+ try:
+ from lora import loras
+ except ImportError:
+ loras = [
+ {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
+ ]
+
+ #---if workspace = local or colab---
+ # (Optional: add Hugging Face login code here)
+
+ def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+ ):
  m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
  b = base_shift - m * base_seq_len
  mu = image_seq_len * m + b
  return mu

+ def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+ ):
  if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
  if timesteps is not None:
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
  timesteps = scheduler.timesteps
  timesteps = scheduler.timesteps
  return timesteps, num_inference_steps

+ # FLUX pipeline
  @torch.inference_mode()
+ def flux_pipe_call_that_returns_an_iterable_of_images(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ max_sequence_length: int = 512,
+ good_vae: Optional[Any] = None,
+ ):
  height = height or self.default_sample_size * self.vae_scale_factor
  width = width or self.default_sample_size * self.vae_scale_factor
+
  self.check_inputs(
  prompt,
  prompt_2,
  pooled_prompt_embeds=pooled_prompt_embeds,
  max_sequence_length=max_sequence_length,
  )
+
  self._guidance_scale = guidance_scale
  self._joint_attention_kwargs = joint_attention_kwargs
  self._interrupt = False
+
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
  device = self._execution_device
+
  lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
  prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
  prompt=prompt,
  max_sequence_length=max_sequence_length,
  lora_scale=lora_scale,
  )
+
  num_channels_latents = self.transformer.config.in_channels // 4
  latents, latent_image_ids = self.prepare_latents(
  batch_size * num_images_per_prompt,
  generator,
  latents,
  )
+
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
  image_seq_len = latents.shape[1]
  mu = calculate_shift(
  mu=mu,
  )
  self._num_timesteps = len(timesteps)
+
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
  for i, t in enumerate(timesteps):
  if self.interrupt:
  continue
+
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
  noise_pred = self.transformer(
  hidden_states=latents,
  timestep=timestep / 1000,
  joint_attention_kwargs=self.joint_attention_kwargs,
  return_dict=False,
  )[0]
+
  latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
  latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
  image = self.vae.decode(latents_for_image, return_dict=False)[0]
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
  torch.cuda.empty_cache()
+
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
  latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
  image = good_vae.decode(latents, return_dict=False)[0]
  torch.cuda.empty_cache()
  yield self.image_processor.postprocess(image, output_type=output_type)[0]

+ #--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ base_model = "black-forest-labs/FLUX.1-dev"
+
+ # TAEF1 is a very tiny autoencoder which uses the same "latent API" as FLUX.1's VAE.
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model,
+ vae=good_vae,
+ transformer=pipe.transformer,
+ text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer,
+ text_encoder_2=pipe.text_encoder_2,
+ tokenizer_2=pipe.tokenizer_2,
+ torch_dtype=dtype
+ )
+ MAX_SEED = 2**32-1
+
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+
+ class calculateDuration:
+ def __init__(self, activity_name=""):
+ self.activity_name = activity_name
+ def __enter__(self):
+ self.start_time = time.time()
+ return self
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.end_time = time.time()
+ self.elapsed_time = self.end_time - self.start_time
+ if self.activity_name:
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+ else:
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+ def update_selection(evt: gr.SelectData, width, height):
+ selected_lora = loras[evt.index]
+ new_placeholder = f"Type a prompt for {selected_lora['title']}"
+ lora_repo = selected_lora["repo"]
+ updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
+ if "aspect" in selected_lora:
+ if selected_lora["aspect"] == "portrait":
+ width = 768
+ height = 1024
+ elif selected_lora["aspect"] == "landscape":
+ width = 1024
+ height = 768
+ else:
+ width = 1024
+ height = 1024
+ return (
+ gr.update(placeholder=new_placeholder),
+ updated_text,
+ evt.index,
+ width,
+ height,
+ )
+
+ @spaces.GPU(duration=100)
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
+ pipe.to("cuda")
+ generator = torch.Generator(device="cuda").manual_seed(seed)
+ with calculateDuration("Generating image"):
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+ prompt=prompt_mash,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ generator=generator,
+ joint_attention_kwargs={"scale": lora_scale},
+ output_type="pil",
+ good_vae=good_vae,
+ ):
+ yield img
+
+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
+ generator = torch.Generator(device="cuda").manual_seed(seed)
+ pipe_i2i.to("cuda")
+ image_input = load_image(image_input_path)
+ final_image = pipe_i2i(
+ prompt=prompt_mash,
+ image=image_input,
+ strength=image_strength,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ generator=generator,
+ joint_attention_kwargs={"scale": lora_scale},
+ output_type="pil",
+ ).images[0]
+ return final_image
+
+ @spaces.GPU(duration=100)
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+ if selected_index is None:
+ raise gr.Error("You must select a LoRA before proceeding.🧨")
+ selected_lora = loras[selected_index]
+ lora_path = selected_lora["repo"]
+ trigger_word = selected_lora["trigger_word"]
+ if(trigger_word):
+ if "trigger_position" in selected_lora:
+ if selected_lora["trigger_position"] == "prepend":
+ prompt_mash = f"{trigger_word} {prompt}"
+ else:
+ prompt_mash = f"{prompt} {trigger_word}"
+ else:
+ prompt_mash = f"{trigger_word} {prompt}"
+ else:
+ prompt_mash = prompt
+
+ with calculateDuration("Unloading LoRA"):
+ pipe.unload_lora_weights()
+ pipe_i2i.unload_lora_weights()
+
+ with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+ pipe_to_use = pipe_i2i if image_input is not None else pipe
+ weight_name = selected_lora.get("weights", None)
+ pipe_to_use.load_lora_weights(
+ lora_path,
+ weight_name=weight_name,
+ low_cpu_mem_usage=True
+ )
+
+ with calculateDuration("Randomizing seed"):
+ if randomize_seed:
+ seed = random.randint(0, MAX_SEED)
+
+ if(image_input is not None):
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
+ yield final_image, seed, gr.update(visible=False)
+ else:
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
+ final_image = None
+ step_counter = 0
+ for image in image_generator:
+ step_counter += 1
+ final_image = image
+ progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
+ yield image, seed, gr.update(value=progress_bar, visible=True)
+ yield final_image, seed, gr.update(value=progress_bar, visible=False)
+
+ def get_huggingface_safetensors(link):
  split_link = link.split("/")
+ if(len(split_link) == 2):
  model_card = ModelCard.load(link)
+ base_model = model_card.data.get("base_model")
+ print(base_model)
+ if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
  raise Exception("Flux LoRA Not Found!")
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
  trigger_word = model_card.data.get("instance_prompt", "")
  try:
  list_of_files = fs.ls(link, detail=False)
  for file in list_of_files:
+ if(file.endswith(".safetensors")):
  safetensors_name = file.split("/")[-1]
+ if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
  image_elements = file.split("/")
  image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
  except Exception as e:
  print(e)
+ gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+ raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
  return split_link[1], link, safetensors_name, trigger_word, image_url
  else:
  raise Exception("Invalid LoRA link format")

+ def check_custom_model(link):
+ if(link.startswith("https://")):
+ if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
  link_split = link.split("huggingface.co/")
  return get_huggingface_safetensors(link_split[1])
+ else:
+ return get_huggingface_safetensors(link)

+ def add_custom_lora(custom_lora):
  global loras
+ if(custom_lora):
  try:
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
  print(f"Loaded custom LoRA: {repo}")
+ card = f'''
+ <div class="custom_lora_card">
+ <span>Loaded custom LoRA:</span>
+ <div class="card_internal">
+ <img src="{image}" />
+ <div>
+ <h3>{title}</h3>
+ <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
+ </div>
+ </div>
+ </div>
+ '''
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
+ if(not existing_item_index):
  new_item = {
  "image": image,
  "title": title,
  "trigger_word": trigger_word
  }
  print(new_item)
+ existing_item_index = len(loras)
  loras.append(new_item)
+
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
  except Exception as e:
+ gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
+ return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=False), gr.update(), "", None, ""
  else:
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

+ def remove_custom_lora():
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

+ run_lora.zerogpu = True
+
+ css = '''
+ #gen_btn{height: 100%}
+ #gen_column{align-self: stretch}
+ #title{text-align: center}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #gallery .grid-wrap{height: 10vh}
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
+ .card_internal img{margin-right: 1em}
+ .styler{--form-gap-width: 0px !important}
+ #progress{height:30px}
+ #progress .generating{display:none}
+ .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
+ .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
+ '''
+
+ with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
+ title = gr.HTML(
+ """<h1>FLUX LoRA DLC🥳</h1>""",
+ elem_id="title",
  )
+ selected_index = gr.State(None)
+ with gr.Row():
+ with gr.Column(scale=3):
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder=":/ choose the LoRA and type the prompt ")
+ with gr.Column(scale=1, elem_id="gen_column"):
+ generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+ with gr.Row():
+ with gr.Column():
+ selected_info = gr.Markdown("")
+ gallery = gr.Gallery(
+ [(item["image"], item["title"]) for item in loras],
+ label="LoRA DLC's",
+ allow_preview=False,
+ columns=3,
+ elem_id="gallery",
+ show_share_button=False
  )
+ with gr.Group():
+ custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
+ gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
+ custom_lora_info = gr.HTML(visible=False)
+ custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
+ with gr.Column():
+ progress_bar = gr.Markdown(elem_id="progress",visible=False)
+ result = gr.Image(label="Generated Image")
+ with gr.Row():
+ with gr.Accordion("Advanced Settings", open=False):
  with gr.Row():
+ input_image = gr.Image(label="Input image", type="filepath")
+ image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
+ with gr.Column():
+ with gr.Row():
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
+ with gr.Row():
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+ with gr.Row():
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
+ with gr.Row():
+ use_enhancer = gr.Checkbox(value=False, label="Use Prompt Enhancer")
+ show_enhanced_prompt = gr.Checkbox(value=False, label="Display Enhanced Prompt")
+ enhanced_prompt_box = gr.Textbox(label="Enhanced Prompt", visible=False)
+ # Add the change event so that the enhanced prompt box visibility toggles.
  show_enhanced_prompt.change(fn=lambda show: gr.update(visible=show),
  inputs=show_enhanced_prompt,
  outputs=enhanced_prompt_box)
+ gallery.select(
+ update_selection,
+ inputs=[width, height],
+ outputs=[prompt, selected_info, selected_index, width, height]
+ )
+ custom_lora.input(
+ add_custom_lora,
+ inputs=[custom_lora],
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
+ )
+ custom_lora_button.click(
+ remove_custom_lora,
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
+ )
+ gr.on(
+ triggers=[generate_button.click, prompt.submit],
+ fn=run_lora,
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
+ outputs=[result, seed, progress_bar]
+ )
+
+ app.queue()
+ app.launch(debug=True)