ruslanmv committed
Commit 0b25329 (verified) · 1 Parent(s): aadfbd4

Update app.py

Files changed (1): app.py +56 -140
app.py CHANGED
@@ -6,21 +6,12 @@ import sys
import time
import random
import json
- from math import floor
from typing import Any, Dict, List, Optional, Union

- # Local import for default LoRA list (if available)
- try:
-     from flux_app.lora import loras
- except ImportError:
-     loras = [
-         {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
-     ]
-
import torch
import numpy as np
- import requests
from PIL import Image
+ import gradio as gr
import spaces

# Diffusers imports
@@ -32,12 +23,9 @@ from diffusers import (
)
from diffusers.utils import load_image

- # Hugging Face Hub
+ # Hugging Face Hub imports
from huggingface_hub import ModelCard, HfFileSystem

- # Gradio (UI)
- import gradio as gr
-
##############################
# ===== config.py =====
##############################
@@ -50,28 +38,20 @@ MAX_SEED = 2**32 - 1
##############################
# ===== utilities.py =====
##############################
- def calculate_shift(
-     image_seq_len,
-     base_seq_len: int = 256,
-     max_seq_len: int = 4096,
-     base_shift: float = 0.5,
-     max_shift: float = 1.16,
- ):
+ def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

- def retrieve_timesteps(
-     scheduler,
-     num_inference_steps: Optional[int] = None,
-     device: Optional[Union[str, torch.device]] = None,
-     timesteps: Optional[List[int]] = None,
-     sigmas: Optional[List[float]] = None,
-     **kwargs,
- ):
+ def retrieve_timesteps(scheduler,
+                        num_inference_steps: Optional[int] = None,
+                        device: Optional[Union[str, torch.device]] = None,
+                        timesteps: Optional[List[int]] = None,
+                        sigmas: Optional[List[float]] = None,
+                        **kwargs):
    if timesteps is not None and sigmas is not None:
-         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
@@ -86,11 +66,9 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps

def load_image_from_path(image_path: str):
-     """Loads an image from a given file path."""
    return load_image(image_path)

def randomize_seed_if_needed(randomize_seed: bool, seed: int, max_seed: int) -> int:
-     """Randomizes the seed if requested."""
    if randomize_seed:
        return random.randint(0, max_seed)
    return seed
@@ -98,40 +76,29 @@ def randomize_seed_if_needed(randomize_seed: bool, seed: int, max_seed: int) ->
class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name
-
    def __enter__(self):
        self.start_time = time.time()
        return self
-
    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
-         self.elapsed_time = self.end_time - self.start_time
+         elapsed = self.end_time - self.start_time
        if self.activity_name:
-             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+             print(f"Elapsed time for {self.activity_name}: {elapsed:.6f} seconds")
        else:
-             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+             print(f"Elapsed time: {elapsed:.6f} seconds")

##############################
# ===== enhance.py =====
##############################
def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
-     """
-     Generates an enhanced prompt using a streaming Hugging Face API.
-     Enhances the given prompt under 100 words without changing its essence.
-     """
    SYSTEM_PROMPT = (
        "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
        "without changing the essence, only write the enhanced prompt and nothing else."
    )
    timestamp = time.time()
-     formatted_prompt = (
-         f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
-         f"[INST] {message} {timestamp} [/INST]"
-     )
-
+     formatted_prompt = f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST][INST] {message} {timestamp} [/INST]"
    api_url = "https://ruslanmv-hf-llm-api.hf.space/api/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
-
    payload = {
        "model": "mixtral-8x7b",
        "messages": [{"role": "user", "content": formatted_prompt}],
@@ -141,12 +108,10 @@ def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetitio
        "use_cache": False,
        "stream": True
    }
-
    try:
        response = requests.post(api_url, headers=headers, json=payload, stream=True)
        response.raise_for_status()
        full_output = ""
-
        for line in response.iter_lines():
            if not line:
                continue
@@ -172,32 +137,30 @@
##############################
# ===== lora_handling.py =====
##############################
- # A default list of LoRAs for the UI
+ # Default LoRA list for initial UI setup
loras = [
    {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
]

@torch.inference_mode()
- def flux_pipe_call_that_returns_an_iterable_of_images(
-     self,
-     prompt: Union[str, List[str]] = None,
-     prompt_2: Optional[Union[str, List[str]]] = None,
-     height: Optional[int] = None,
-     width: Optional[int] = None,
-     num_inference_steps: int = 28,
-     timesteps: List[int] = None,
-     guidance_scale: float = 3.5,
-     num_images_per_prompt: Optional[int] = 1,
-     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-     latents: Optional[torch.FloatTensor] = None,
-     prompt_embeds: Optional[torch.FloatTensor] = None,
-     pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-     output_type: Optional[str] = "pil",
-     return_dict: bool = True,
-     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-     max_sequence_length: int = 512,
-     good_vae: Optional[Any] = None,
- ):
+ def flux_pipe_call_that_returns_an_iterable_of_images(self,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Optional[Union[str, List[str]]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 28,
+         timesteps: List[int] = None,
+         guidance_scale: float = 3.5,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         max_sequence_length: int = 512,
+         good_vae: Optional[Any] = None):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

@@ -210,14 +173,11 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
-
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False
-
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
-
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
@@ -229,7 +189,6 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
-
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
@@ -241,7 +200,6 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
        generator,
        latents,
    )
-
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
@@ -260,17 +218,13 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
        mu=mu,
    )
    self._num_timesteps = len(timesteps)
-
    guidance = (torch.full([1], guidance_scale, device=device, dtype=torch.float32)
                .expand(latents.shape[0])
                if self.transformer.config.guidance_embeds else None)
-
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue
-
        timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
@@ -282,14 +236,12 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]
-
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        yield self.image_processor.postprocess(image, output_type=output_type)[0]
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()
-
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
    image = good_vae.decode(latents, return_dict=False)[0]
@@ -301,12 +253,10 @@ def get_huggingface_safetensors(link: str) -> tuple:
    split_link = link.split("/")
    if len(split_link) == 2:
        model_card = ModelCard.load(link)
-         base_model = model_card.data.get("base_model")
-         print(base_model)
-
-         if base_model not in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"):
+         base_model_card = model_card.data.get("base_model")
+         print(base_model_card)
+         if base_model_card not in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"):
            raise Exception("Flux LoRA Not Found!")
-
        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
        trigger_word = model_card.data.get("instance_prompt", "")
        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
@@ -319,12 +269,12 @@ def get_huggingface_safetensors(link: str) -> tuple:
            if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
                image_elements = file.split("/")
                image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
-             return split_link[1], link, safetensors_name, trigger_word, image_url
        except Exception as e:
            print(e)
-             raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+             raise Exception("Invalid LoRA repository")
+         return split_link[1], link, safetensors_name, trigger_word, image_url
    else:
-         raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+         raise Exception("Invalid LoRA link format")

def check_custom_model(link: str) -> tuple:
    if link.startswith("https://"):
@@ -334,11 +284,8 @@ def check_custom_model(link: str) -> tuple:
    return get_huggingface_safetensors(link)

def create_lora_card(title: str, repo: str, trigger_word: str, image: str) -> str:
-     trigger_word_info = (
-         f"Using: <code><b>{trigger_word}</b></code> as the trigger word"
-         if trigger_word
-         else "No trigger word found. If there's a trigger word, include it in your prompt"
-     )
+     trigger_word_info = (f"Using: <code><b>{trigger_word}</b></code> as the trigger word"
+                          if trigger_word else "No trigger word found. Include it in your prompt")
    return f'''
<div class="custom_lora_card">
<span>Loaded custom LoRA:</span>
@@ -352,14 +299,14 @@ def create_lora_card(title: str, repo: str, trigger_word: str, image: str) -> st
</div>
'''

- def add_custom_lora(custom_lora: str, loras_list: list) -> tuple:
+ def add_custom_lora(custom_lora: str) -> tuple:
+     global loras
    if custom_lora:
        try:
            title, repo, path, trigger_word, image = check_custom_model(custom_lora)
            print(f"Loaded custom LoRA: {repo}")
            card = create_lora_card(title, repo, trigger_word, image)
-
-             existing_item_index = next((index for (index, item) in enumerate(loras_list) if item['repo'] == repo), None)
+             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
            if existing_item_index is None:
                new_item = {
                    "image": image,
@@ -369,11 +316,9 @@ def add_custom_lora(custom_lora: str, loras_list: list) -> tuple:
                    "trigger_word": trigger_word
                }
                print(new_item)
-                 loras_list.append(new_item)
-                 existing_item_index = len(loras_list) - 1
-
+                 loras.append(new_item)
+                 existing_item_index = len(loras) - 1
            return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
-
        except Exception as e:
            print(f"Error loading LoRA: {e}")
            return gr.update(visible=True, value="Invalid LoRA"), gr.update(visible=False), gr.update(), "", None, ""
@@ -386,7 +331,6 @@ def remove_custom_lora() -> tuple:
def prepare_prompt(prompt: str, selected_index: Optional[int], loras_list: list) -> str:
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.🧨")
-
    selected_lora = loras_list[selected_index]
    trigger_word = selected_lora.get("trigger_word")
    if trigger_word:
@@ -412,8 +356,8 @@ def load_lora_weights_into_pipeline(pipe_to_use, lora_path: str, weight_name: Op
        low_cpu_mem_usage=True
    )

- def update_selection(evt: gr.SelectData, width, height, loras_list):
-     selected_lora = loras_list[evt.index]
+ def update_selection(evt: gr.SelectData, width, height) -> tuple:
+     selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
@@ -448,12 +392,9 @@ class ModelManager:
        self.initialize_models()

    def initialize_models(self):
-         """Initializes the diffusion pipelines and autoencoders."""
        self.taef1 = AutoencoderTiny.from_pretrained(TAEF1_MODEL, torch_dtype=DTYPE).to(DEVICE)
        self.good_vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE).to(DEVICE)
-         # Optionally, pass use_auth_token=self.hf_token if needed.
-         self.pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, vae=self.taef1)
-         self.pipe = self.pipe.to(DEVICE)
+         self.pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, vae=self.taef1).to(DEVICE)
        self.pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
            BASE_MODEL,
            vae=self.good_vae,
@@ -464,14 +405,11 @@
            tokenizer_2=self.pipe.tokenizer_2,
            torch_dtype=DTYPE,
        ).to(DEVICE)
-         # Instead of binding to the instance (which fails due to __slots__),
-         # bind the custom method to the pipeline’s class.
+         # Bind custom LoRA method to the pipeline class (to avoid __slots__ issues)
        self.pipe.__class__.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images

-     @spaces.GPU(duration=100)
+     @spaces.GPU(duration=100)
    def generate_image(self, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
-         """Generates an image using the text-to-image pipeline."""
-         self.pipe.to(DEVICE)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        with calculateDuration("Generating image"):
            for img in self.pipe.flux_pipe_call_that_returns_an_iterable_of_images(
@@ -488,9 +426,7 @@
                yield img

    def generate_image_to_image(self, prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
-         """Generates an image using the image-to-image pipeline."""
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
-         self.pipe_i2i.to(DEVICE)
        image_input = load_image_from_path(image_input_path)
        with calculateDuration("Generating image to image"):
            final_image = self.pipe_i2i(
@@ -513,12 +449,11 @@
class Frontend:
    def __init__(self, model_manager: ModelManager):
        self.model_manager = model_manager
-         self.loras = loras  # Use the default LoRA list defined above.
+         self.loras = loras
        self.load_initial_loras()
        self.css = self.define_css()

    def define_css(self):
-         # Clean and professional CSS styling.
        return '''
        /* Title Styling */
        #title {
@@ -587,18 +522,14 @@ class Frontend:
            self.loras = loras_list
        except ImportError:
            print("Warning: lora.py not found, using placeholder LoRAs.")
-             pass

    @spaces.GPU(duration=100)
    def run_lora(self, prompt, image_input, image_strength, cfg_scale, steps, selected_index,
                 randomize_seed, seed, width, height, lora_scale, use_enhancer,
                 progress=gr.Progress(track_tqdm=True)):
        seed = randomize_seed_if_needed(randomize_seed, seed, MAX_SEED)
-         # Prepare the prompt using the selected LoRA trigger word.
        prompt_mash = prepare_prompt(prompt, selected_index, self.loras)
        enhanced_text = ""
-
-         # Optionally enhance the prompt.
        if use_enhancer:
            for enhanced_chunk in generate(prompt_mash):
                enhanced_text = enhanced_chunk
@@ -606,12 +537,10 @@
            prompt_mash = enhanced_text
        else:
            enhanced_text = ""
-
        selected_lora = self.loras[selected_index]
        unload_lora_weights(self.model_manager.pipe, self.model_manager.pipe_i2i)
        pipe_to_use = self.model_manager.pipe_i2i if image_input is not None else self.model_manager.pipe
        load_lora_weights_into_pipeline(pipe_to_use, selected_lora["repo"], selected_lora.get("weights"))
-
        if image_input is not None:
            final_image = self.model_manager.generate_image_to_image(
                prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed
@@ -630,12 +559,8 @@

    def create_ui(self):
        with gr.Blocks(theme=gr.themes.Base(), css=self.css, title="Flux LoRA Generation") as app:
-             title = gr.HTML(
-                 """<h1>Flux LoRA Generation</h1>""",
-                 elem_id="title",
-             )
+             title = gr.HTML("<h1>Flux LoRA Generation</h1>", elem_id="title")
            selected_index = gr.State(None)
-
            with gr.Row():
                with gr.Column(scale=3):
                    prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Choose the LoRA and type the prompt")
@@ -660,7 +585,6 @@
                with gr.Column():
                    progress_bar = gr.Markdown(elem_id="progress", visible=False)
                    result = gr.Image(label="Generated Image")
-
            with gr.Row():
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
@@ -681,44 +605,37 @@
                        use_enhancer = gr.Checkbox(value=False, label="Use Prompt Enhancer")
                        show_enhanced_prompt = gr.Checkbox(value=False, label="Display Enhanced Prompt")
                        enhanced_prompt_box = gr.Textbox(label="Enhanced Prompt", visible=False)
-
            gallery.select(
                update_selection,
-                 inputs=[width, height, gr.State(self.loras)],
+                 inputs=[width, height],
                outputs=[prompt, selected_info, selected_index, width, height]
            )
            custom_lora.input(
                add_custom_lora,
-                 inputs=[custom_lora, gr.State(self.loras)],
+                 inputs=[custom_lora],
                outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
            )
            custom_lora_button.click(
                remove_custom_lora,
                outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
            )
-
            show_enhanced_prompt.change(fn=lambda show: gr.update(visible=show),
                                        inputs=show_enhanced_prompt,
                                        outputs=enhanced_prompt_box)
-
            gr.on(
                triggers=[generate_button.click, prompt.submit],
                fn=self.run_lora,
-                 inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index,
-                         randomize_seed, seed, width, height, lora_scale, use_enhancer],
+                 inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer],
                outputs=[result, seed, progress_bar, enhanced_prompt_box]
            )
-
            with gr.Row():
                gr.HTML("<div style='text-align:center; font-size:0.9em; margin-top:20px;'>Credits: <a href='https://ruslanmv.com' target='_blank'>ruslanmv.com</a></div>")
-
        return app

##############################
# ===== Main app.py =====
##############################
if __name__ == "__main__":
-     # Get the Hugging Face token from the environment.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError("Hugging Face token (HF_TOKEN) not found in environment variables. Please set it.")
@@ -726,5 +643,4 @@ if __name__ == "__main__":
    frontend = Frontend(model_manager)
    app = frontend.create_ui()
    app.queue()
-     # Set share=True to create a public link if desired.
    app.launch(share=False, debug=True)