ruslanmv commited on
Commit
1d05326
·
verified ·
1 Parent(s): 0b25329

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -87,6 +87,22 @@ class calculateDuration:
87
  else:
88
  print(f"Elapsed time: {elapsed:.6f} seconds")
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  ##############################
91
  # ===== enhance.py =====
92
  ##############################
@@ -137,7 +153,6 @@ def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetitio
137
  ##############################
138
  # ===== lora_handling.py =====
139
  ##############################
140
- # Default LoRA list for initial UI setup
141
  loras = [
142
  {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
143
  ]
@@ -163,7 +178,6 @@ def flux_pipe_call_that_returns_an_iterable_of_images(self,
163
  good_vae: Optional[Any] = None):
164
  height = height or self.default_sample_size * self.vae_scale_factor
165
  width = width or self.default_sample_size * self.vae_scale_factor
166
-
167
  self.check_inputs(
168
  prompt,
169
  prompt_2,
@@ -341,6 +355,8 @@ def prepare_prompt(prompt: str, selected_index: Optional[int], loras_list: list)
341
  prompt_mash = f"{prompt} {trigger_word}"
342
  else:
343
  prompt_mash = prompt
 
 
344
  return prompt_mash
345
 
346
  def unload_lora_weights(pipe, pipe_i2i):
@@ -405,10 +421,10 @@ class ModelManager:
405
  tokenizer_2=self.pipe.tokenizer_2,
406
  torch_dtype=DTYPE,
407
  ).to(DEVICE)
408
- # Bind custom LoRA method to the pipeline class (to avoid __slots__ issues)
409
  self.pipe.__class__.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images
410
-
411
- @spaces.GPU(duration=100)
412
  def generate_image(self, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
413
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
414
  with calculateDuration("Generating image"):
@@ -425,6 +441,7 @@ class ModelManager:
425
  ):
426
  yield img
427
 
 
428
  def generate_image_to_image(self, prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
429
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
430
  image_input = load_image_from_path(image_input_path)
 
87
  else:
88
  print(f"Elapsed time: {elapsed:.6f} seconds")
89
 
90
+ ##############################
91
+ # ===== Helper: truncate_prompt =====
92
+ ##############################
93
+ def truncate_prompt(prompt: str) -> str:
94
+ """
95
+ Uses the global pipeline's tokenizer (assumed available as `pipe.tokenizer`)
96
+ to truncate the prompt to the maximum allowed length.
97
+ """
98
+ try:
99
+ tokenizer = pipe.tokenizer
100
+ tokenized = tokenizer(prompt, truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt")
101
+ return tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
102
+ except Exception as e:
103
+ print(f"Error in truncate_prompt: {e}")
104
+ return prompt
105
+
106
  ##############################
107
  # ===== enhance.py =====
108
  ##############################
 
153
  ##############################
154
  # ===== lora_handling.py =====
155
  ##############################
 
156
  loras = [
157
  {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""}
158
  ]
 
178
  good_vae: Optional[Any] = None):
179
  height = height or self.default_sample_size * self.vae_scale_factor
180
  width = width or self.default_sample_size * self.vae_scale_factor
 
181
  self.check_inputs(
182
  prompt,
183
  prompt_2,
 
355
  prompt_mash = f"{prompt} {trigger_word}"
356
  else:
357
  prompt_mash = prompt
358
+ # Truncate the prompt using the tokenizer to ensure token indices are in range.
359
+ prompt_mash = truncate_prompt(prompt_mash)
360
  return prompt_mash
361
 
362
  def unload_lora_weights(pipe, pipe_i2i):
 
421
  tokenizer_2=self.pipe.tokenizer_2,
422
  torch_dtype=DTYPE,
423
  ).to(DEVICE)
424
+ # Bind the custom LoRA method to the pipeline class (to avoid __slots__ issues)
425
  self.pipe.__class__.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images
426
+
427
+ @spaces.GPU(duration=100)
428
  def generate_image(self, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
429
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
430
  with calculateDuration("Generating image"):
 
441
  ):
442
  yield img
443
 
444
+ @spaces.GPU(duration=100)
445
  def generate_image_to_image(self, prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
446
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
447
  image_input = load_image_from_path(image_input_path)