daoyuan98 committed
Commit f492855 · verified · 1 Parent(s): 8c37c22

Update app.py

Files changed (1):
  1. app.py +296 -83
app.py CHANGED
@@ -1,74 +1,294 @@
import gradio as gr
import numpy as np
import random
-
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
import torch

-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use

-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32

-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)

-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)

-    generator = torch.Generator().manual_seed(seed)

-    image = pipe(
        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]

-    return image, seed


examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
]

-css = """
#col-container {
    margin: 0 auto;
-    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
@@ -76,19 +296,13 @@ with gr.Blocks(css=css) as demo:
                placeholder="Enter your prompt",
                container=False,
            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
        result = gr.Image(label="Result", show_label=False)
-
        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
            seed = gr.Slider(
                label="Seed",
                minimum=0,
@@ -96,59 +310,58 @@ with gr.Blocks(css=css) as demo:
                step=1,
                value=0,
            )
-
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
                )
-
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
                )
-
            with gr.Row():
                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
                )
-
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
-                    value=2,  # Replace with defaults that work for your model
                )

-        gr.Examples(examples=examples, inputs=[prompt])
        gr.on(
            triggers=[run_button.click, prompt.submit],
-            fn=infer,
-            inputs=[
-                prompt,
-                negative_prompt,
-                seed,
-                randomize_seed,
-                width,
-                height,
-                guidance_scale,
-                num_inference_steps,
-            ],
-            outputs=[result, seed],
        )

-if __name__ == "__main__":
-    demo.launch()

+from dataclasses import dataclass
+from typing import Union, Optional, List, Any, Dict
+
import gradio as gr
import numpy as np
import random
+import spaces
import torch
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file as load_sft  # needed by load_flow_model2 below; missing from the original imports

+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

+from model import Flux, FluxParams


+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
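`calculate_shift` linearly interpolates the scheduler's timestep-shift parameter `mu` from `base_shift` at 256 image tokens up to `max_shift` at 4096 tokens (a 1024×1024 generation packs to 64×64 = 4096 latent tokens). A minimal standalone check of that interpolation, using only the defaults above:

```python
# Linear interpolation used by calculate_shift: mu runs from 0.5 at 256 tokens
# to 1.16 at 4096 tokens. Pure arithmetic; no model or GPU required.
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.16):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

print(round(calculate_shift(256), 4))    # 0.5
print(round(calculate_shift(1024), 4))   # 0.632  (512x512 image -> 32*32 packed tokens)
print(round(calculate_shift(4096), 4))   # 1.16   (1024x1024 image -> 64*64 packed tokens)
```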

+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps


+@torch.inference_mode()
+def flux_pipe_call_that_returns_an_iterable_of_images(
+    self,
+    prompt: Union[str, List[str]] = None,
+    prompt_2: Optional[Union[str, List[str]]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 28,
+    timesteps: List[int] = None,
+    guidance_scale: float = 3.5,
+    num_images_per_prompt: Optional[int] = 1,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    max_sequence_length: int = 512,
+    good_vae: Optional[Any] = None,
+):
+    height = height or self.default_sample_size * self.vae_scale_factor
+    width = width or self.default_sample_size * self.vae_scale_factor
+
+    # 1. Check inputs
+    self.check_inputs(
+        prompt,
+        prompt_2,
+        height,
+        width,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        max_sequence_length=max_sequence_length,
+    )
+
+    self._guidance_scale = guidance_scale
+    self._joint_attention_kwargs = joint_attention_kwargs
+    self._interrupt = False
+
+    # 2. Define call parameters
+    batch_size = 1 if isinstance(prompt, str) else len(prompt)
+    device = self._execution_device
+
+    # 3. Encode prompt
+    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
+        prompt_2=prompt_2,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        max_sequence_length=max_sequence_length,
+        lora_scale=lora_scale,
+    )
+    # 4. Prepare latent variables
+    num_channels_latents = self.transformer.config.in_channels // 4
+    latents, latent_image_ids = self.prepare_latents(
+        batch_size * num_images_per_prompt,
+        num_channels_latents,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        latents,
+    )
+    # 5. Prepare timesteps
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+    image_seq_len = latents.shape[1]
+    mu = calculate_shift(
+        image_seq_len,
+        self.scheduler.config.base_image_seq_len,
+        self.scheduler.config.max_image_seq_len,
+        self.scheduler.config.base_shift,
+        self.scheduler.config.max_shift,
+    )
+    timesteps, num_inference_steps = retrieve_timesteps(
+        self.scheduler,
+        num_inference_steps,
+        device,
+        timesteps,
+        sigmas,
+        mu=mu,
+    )
+    self._num_timesteps = len(timesteps)
+
+    # Handle guidance
+    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
+    # 6. Denoising loop
+    for i, t in enumerate(timesteps):
+        if self.interrupt:
+            continue
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+        noise_pred = self.transformer(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
+        # Yield intermediate result
+        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+        image = self.vae.decode(latents_for_image, return_dict=False)[0]
+        yield self.image_processor.postprocess(image, output_type=output_type)[0]
+
+        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+        torch.cuda.empty_cache()
+
+    # Final image using good_vae
+    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
+    image = good_vae.decode(latents, return_dict=False)[0]
+    self.maybe_free_model_hooks()
+    torch.cuda.empty_cache()
+    yield self.image_processor.postprocess(image, output_type=output_type)[0]
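In step 5 above, the sigma schedule handed to `retrieve_timesteps` is simply `np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)`, which the scheduler then shifts using `mu`. A quick standalone look at that raw grid for a small step count:

```python
import numpy as np

# The raw sigma grid built in "# 5. Prepare timesteps": num_inference_steps
# values evenly spaced from 1.0 down to 1 / num_inference_steps.
num_inference_steps = 4
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
print(sigmas)  # [1.   0.75 0.5  0.25]
```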
+
+
+@dataclass
+class ModelSpec:
+    params: FluxParams
+    repo_id: str
+    repo_flow: str
+    repo_ae: str
+    repo_id_ae: str


+config = ModelSpec(
+    repo_id="TencentARC/flux-mini",
+    repo_flow="flux-mini.safetensors",
+    repo_id_ae="black-forest-labs/FLUX.1-dev",
+    repo_ae="ae.safetensors",
+    params=FluxParams(
+        in_channels=64,
+        vec_in_dim=768,
+        context_in_dim=4096,
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=5,
+        depth_single_blocks=10,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    )
+)


+def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
+    if (config.repo_id is not None
+        and config.repo_flow is not None
+        and hf_download
+    ):
+        ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))
+
+    model = Flux(config.params)
+    if ckpt_path is not None:
+        sd = load_sft(ckpt_path, device=str(device))
+        missing, unexpected = model.load_state_dict(sd, strict=True)
+    return model
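`load_flow_model2` fetches the checkpoint with `hf_hub_download` and reads it with `load_sft` (safetensors' `load_file`) before `load_state_dict`. A self-contained sketch of that read path with a throwaway state dict (the file name and tensor names here are arbitrary):

```python
import torch
from safetensors.torch import save_file, load_file

# Round-trip a tiny state dict through a .safetensors file, mirroring how the
# downloaded flux-mini checkpoint is read before being loaded into Flux.
state_dict = {"proj.weight": torch.randn(8, 8), "proj.bias": torch.zeros(8)}
save_file(state_dict, "tiny_ckpt.safetensors")

restored = load_file("tiny_ckpt.safetensors", device="cpu")
assert torch.equal(state_dict["proj.bias"], restored["proj.bias"])
```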
+
+
+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
+vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
+good_vae = vae  # assumed alias: infer() below passes good_vae, which is otherwise undefined in this file
+text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder").to(device)
+tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
+text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2").to(device)
+tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
+transformer = load_flow_model2(config, device)
+
+pipe = FluxPipeline(
+    scheduler,
+    vae,
+    text_encoder,
+    tokenizer,
+    text_encoder_2,
+    tokenizer_2,
+    transformer
+)
+torch.cuda.empty_cache()
+
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 2048
+
+pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
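The `__get__` call just above turns the standalone generator into a method of this particular `pipe` instance, so later code can call it as `pipe.flux_pipe_call_that_returns_an_iterable_of_images(...)` with `self` already bound. A minimal pure-Python illustration of the same descriptor trick (the `Greeter` class is purely illustrative):

```python
# Plain functions are descriptors: func.__get__(obj) returns a method bound to
# obj, which is exactly how the streaming call is attached to `pipe` above.
class Greeter:
    name = "flux-mini"

def shout(self):
    return self.name.upper()

g = Greeter()
g.shout = shout.__get__(g)  # bound to this single instance only
print(g.shout())            # FLUX-MINI
```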
+
+@spaces.GPU(duration=75)
+def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator().manual_seed(seed)
+
+    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+        prompt=prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        generator=generator,
+        output_type="pil",
+        good_vae=good_vae,
+    ):
+        yield img, seed
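Because `infer` is a generator, Gradio pushes a new update to its outputs on every `yield`, which is how the per-step previews stream into the result image while the final full-VAE decode arrives last. A minimal sketch of the same streaming pattern with text outputs (all names here are illustrative, not part of the app):

```python
import gradio as gr

# A generator event handler: each yield re-renders the outputs, mirroring how
# infer() streams intermediate images before the final one.
def fake_infer(prompt):
    for step in range(1, 4):
        yield f"{prompt} (step {step}/3)", 42

with gr.Blocks() as tiny_demo:
    prompt_box = gr.Textbox(label="Prompt")
    preview = gr.Textbox(label="Preview")
    seed_out = gr.Number(label="Seed")
    prompt_box.submit(fake_infer, inputs=[prompt_box], outputs=[preview, seed_out])

# tiny_demo.launch()  # uncomment to try it locally
```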
+
examples = [
+    "thousands of luminous oysters on a shore reflecting and refracting the sunset",
+    "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,",
+    "ghosts, astronauts, robots, cats, superhero costumes, line drawings, naive, simple, exploring a strange planet, coloured pencil crayons, , black canvas background, drawn by 5 year old child",
]

+css="""
#col-container {
    margin: 0 auto;
+    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
+
    with gr.Column(elem_id="col-container"):
+        gr.Markdown(f"""# FLUX-Mini
+        A 3.2B param rectified flow transformer distilled from [FLUX.1 [dev]](https://blackforestlabs.ai/)
+        [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]
+        """)
+
        with gr.Row():
+
            prompt = gr.Text(
                label="Prompt",
                show_label=False,

                placeholder="Enter your prompt",
                container=False,
            )
+
+            run_button = gr.Button("Run", scale=0)
+
        result = gr.Image(label="Result", show_label=False)
+
        with gr.Accordion("Advanced Settings", open=False):
+
            seed = gr.Slider(
                label="Seed",
                minimum=0,

                step=1,
                value=0,
            )
+
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
            with gr.Row():
+
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
+                    value=1024,
                )
+
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
+                    value=1024,
                )
+
            with gr.Row():
+
                guidance_scale = gr.Slider(
+                    label="Guidance Scale",
+                    minimum=1,
+                    maximum=15,
                    step=0.1,
+                    value=3.5,
                )
+
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
+                    value=28,
                )
+
+        gr.Examples(
+            examples = examples,
+            fn = infer,
+            inputs = [prompt],
+            outputs = [result, seed],
+            cache_examples="lazy"
+        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
+        fn = infer,
+        inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+        outputs = [result, seed]
    )

+demo.launch()
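For completeness, a hedged sketch of exercising `infer` outside the UI, for example appended temporarily to the bottom of this file on a GPU machine; only the last yielded image (the full-VAE decode) is kept, and the output file name is arbitrary:

```python
# Drive the streaming generator directly and keep the final image it yields.
final_image, used_seed = None, None
for image, seed in infer("a lighthouse at dusk", randomize_seed=True, num_inference_steps=8):
    final_image, used_seed = image, seed

final_image.save("lighthouse.png")
print("seed used:", used_seed)
```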