multimodalart HF staff commited on
Commit
e300c6e
1 Parent(s): 1500e0d
Files changed (1) hide show
  1. app.py +63 -20
app.py CHANGED
@@ -5,9 +5,9 @@ import logging
5
  import torch
6
  from PIL import Image
7
  import spaces
8
- from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
9
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
10
-
11
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
12
  import copy
13
  import random
@@ -25,6 +25,15 @@ base_model = "black-forest-labs/FLUX.1-dev"
25
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
26
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
27
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
 
 
 
 
 
 
 
 
 
28
 
29
  MAX_SEED = 2**32-1
30
 
@@ -88,7 +97,26 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
88
  ):
89
  yield img
90
 
91
- def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  if selected_index is None:
93
  raise gr.Error("You must select a LoRA before proceeding.")
94
  selected_lora = loras[selected_index]
@@ -107,32 +135,44 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
107
 
108
  with calculateDuration("Unloading LoRA"):
109
  pipe.unload_lora_weights()
 
110
 
111
  # Load LoRA weights
112
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
113
- if "weights" in selected_lora:
114
- pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
 
 
 
115
  else:
116
- pipe.load_lora_weights(lora_path)
117
-
 
 
 
118
  # Set random seed for reproducibility
119
  with calculateDuration("Randomizing seed"):
120
  if randomize_seed:
121
  seed = random.randint(0, MAX_SEED)
122
-
123
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
 
 
 
 
 
124
 
125
- # Consume the generator to get the final image
126
- final_image = None
127
- step_counter = 0
128
- for image in image_generator:
129
- step_counter+=1
130
- final_image = image
131
- progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
132
- yield image, seed, gr.update(value=progress_bar, visible=True)
 
 
133
 
134
- yield final_image, seed, gr.update(value=progress_bar, visible=False)
135
-
136
  def get_huggingface_safetensors(link):
137
  split_link = link.split("/")
138
  if(len(split_link) == 2):
@@ -257,6 +297,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
257
 
258
  with gr.Row():
259
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
260
  with gr.Column():
261
  with gr.Row():
262
  cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
@@ -288,7 +331,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
288
  gr.on(
289
  triggers=[generate_button.click, prompt.submit],
290
  fn=run_lora,
291
- inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
292
  outputs=[result, seed, progress_bar]
293
  )
294
 
 
5
  import torch
6
  from PIL import Image
7
  import spaces
8
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
9
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
10
+ from diffusers.utils import load_image
11
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
12
  import copy
13
  import random
 
25
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
26
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
27
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
28
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model,
29
+ vae=good_vae,
30
+ transformer=pipe.transformer,
31
+ text_encoder=pipe.text_encoder,
32
+ tokenizer=pipe.tokenizer,
33
+ text_encoder_2=pipe.text_encoder_2,
34
+ tokenizer_2=pipe.tokenizer_2,
35
+ torch_dtype=dtype
36
+ )
37
 
38
  MAX_SEED = 2**32-1
39
 
 
97
  ):
98
  yield img
99
 
100
+ @spaces.GPU(duration=70)
101
+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
102
+ generator = torch.Generator(device="cuda").manual_seed(seed)
103
+ pipe_i2i.to("cuda")
104
+ image_input = load_image(image_input_path)
105
+ final_image = pipe_i2i(
106
+ prompt=prompt_mash,
107
+ image=image_input,
108
+ strength=image_strength,
109
+ num_inference_steps=steps,
110
+ guidance_scale=cfg_scale,
111
+ width=width,
112
+ height=height,
113
+ generator=generator,
114
+ joint_attention_kwargs={"scale": lora_scale},
115
+ output_type="pil",
116
+ ).images[0]
117
+ return final_image
118
+
119
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
120
  if selected_index is None:
121
  raise gr.Error("You must select a LoRA before proceeding.")
122
  selected_lora = loras[selected_index]
 
135
 
136
  with calculateDuration("Unloading LoRA"):
137
  pipe.unload_lora_weights()
138
+ pipe_i2i.unload_lora_weights()
139
 
140
  # Load LoRA weights
141
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
142
+ if(image_input is not None):
143
+ if "weights" in selected_lora:
144
+ pipe_i2i.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
145
+ else:
146
+ pipe_i2i.load_lora_weights(lora_path)
147
  else:
148
+ if "weights" in selected_lora:
149
+ pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
150
+ else:
151
+ pipe.load_lora_weights(lora_path)
152
+
153
  # Set random seed for reproducibility
154
  with calculateDuration("Randomizing seed"):
155
  if randomize_seed:
156
  seed = random.randint(0, MAX_SEED)
157
+
158
+ if(image_input is not None):
159
+
160
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
161
+ yield final_image, seed, gr.update(visible=False)
162
+ else:
163
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
164
 
165
+ # Consume the generator to get the final image
166
+ final_image = None
167
+ step_counter = 0
168
+ for image in image_generator:
169
+ step_counter+=1
170
+ final_image = image
171
+ progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
172
+ yield image, seed, gr.update(value=progress_bar, visible=True)
173
+
174
+ yield final_image, seed, gr.update(value=progress_bar, visible=False)
175
 
 
 
176
  def get_huggingface_safetensors(link):
177
  split_link = link.split("/")
178
  if(len(split_link) == 2):
 
297
 
298
  with gr.Row():
299
  with gr.Accordion("Advanced Settings", open=False):
300
+ with gr.Row():
301
+ input_image = gr.Image(label="Input image", type="filepath")
302
+ image_strength = gr.Slider(label="Image Strength", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
303
  with gr.Column():
304
  with gr.Row():
305
  cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
 
331
  gr.on(
332
  triggers=[generate_button.click, prompt.submit],
333
  fn=run_lora,
334
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
335
  outputs=[result, seed, progress_bar]
336
  )
337