daoyuan98 committed
Commit 6e29804 (verified)
1 Parent(s): 1d45c1e

update main file, fix local bugs

Files changed (1)
  1. app.py +54 -45
app.py CHANGED
@@ -6,8 +6,8 @@ import numpy as np
 import random
 import spaces
 import torch
-from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file as load_sft
+from huggingface_hub import hf_hub_download
 
 from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -29,10 +29,10 @@ def calculate_shift(
 
 def retrieve_timesteps(
     scheduler,
-    num_inference_steps: Optional[int] = None,
-    device: Optional[Union[str, torch.device]] = None,
-    timesteps: Optional[List[int]] = None,
-    sigmas: Optional[List[float]] = None,
+    num_inference_steps: Optional = None,
+    device: Optional = None,
+    timesteps: Optional = None,
+    sigmas: Optional = None,
     **kwargs,
 ):
     if timesteps is not None and sigmas is not None:
@@ -54,23 +54,23 @@ def retrieve_timesteps(
 @torch.inference_mode()
 def flux_pipe_call_that_returns_an_iterable_of_images(
     self,
-    prompt: Union[str, List[str]] = None,
-    prompt_2: Optional[Union[str, List[str]]] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
+    prompt = None,
+    prompt_2 = None,
+    height = None,
+    width = None,
     num_inference_steps: int = 28,
-    timesteps: List[int] = None,
+    timesteps = None,
     guidance_scale: float = 3.5,
-    num_images_per_prompt: Optional[int] = 1,
-    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-    latents: Optional[torch.FloatTensor] = None,
-    prompt_embeds: Optional[torch.FloatTensor] = None,
-    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-    output_type: Optional[str] = "pil",
-    return_dict: bool = True,
-    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    max_sequence_length: int = 512,
-    good_vae: Optional[Any] = None,
+    num_images_per_prompt = 1,
+    generator = None,
+    latents = None,
+    prompt_embeds = None,
+    pooled_prompt_embeds = None,
+    output_type = "pil",
+    return_dict = True,
+    joint_attention_kwargs = None,
+    max_sequence_length = 512,
+    good_vae = None,
 ):
     height = height or self.default_sample_size * self.vae_scale_factor
     width = width or self.default_sample_size * self.vae_scale_factor
@@ -92,7 +92,10 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
 
     # 2. Define call parameters
     batch_size = 1 if isinstance(prompt, str) else len(prompt)
-    device = self._execution_device
+    try:
+        device = self._execution_device
+    except:
+        device = torch.device('cuda:0')
 
     # 3. Encode prompt
     lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
@@ -107,7 +110,7 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
         lora_scale=lora_scale,
     )
     # 4. Prepare latent variables
-    num_channels_latents = self.transformer.config.in_channels // 4
+    num_channels_latents = self.transformer.in_channels // 4
     latents, latent_image_ids = self.prepare_latents(
         batch_size * num_images_per_prompt,
         num_channels_latents,
@@ -139,26 +142,25 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
     self._num_timesteps = len(timesteps)
 
     # Handle guidance
-    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+    guidance = torch.full([1], guidance_scale, device=device, dtype=dtype).expand(latents.shape[0]) # if self.transformer.params.guidance_embeds else None
 
+    # print(latent_image_ids.shape, text_ids.shape, pooled_prompt_embeds.shape)
     # 6. Denoising loop
     for i, t in enumerate(timesteps):
         if self.interrupt:
             continue
 
-        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+        timestep = t.expand(latents.shape[0]).to(dtype)
 
         noise_pred = self.transformer(
-            hidden_states=latents,
-            timestep=timestep / 1000,
-            guidance=guidance,
-            pooled_projections=pooled_prompt_embeds,
-            encoder_hidden_states=prompt_embeds,
-            txt_ids=text_ids,
-            img_ids=latent_image_ids,
-            joint_attention_kwargs=self.joint_attention_kwargs,
-            return_dict=False,
-        )[0]
+            img=latents.to(dtype).to(device),
+            timesteps=(timestep / 1000).to(dtype),
+            guidance=guidance.to(dtype).to(device),
+            y=pooled_prompt_embeds.to(dtype).to(device),
+            txt=prompt_embeds.to(dtype).to(device),
+            txt_ids=text_ids.to(dtype).to(device),
+            img_ids=latent_image_ids.to(dtype).to(device),
+        )
         # Yield intermediate result
         latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
@@ -184,6 +186,7 @@ class ModelSpec:
     repo_flow: str
     repo_ae: str
     repo_id_ae: str
+    ckpt_path: str
 
 
 config = ModelSpec(
@@ -191,6 +194,7 @@ config = ModelSpec(
     repo_flow="flux-mini.safetensors",
     repo_id_ae="black-forest-labs/FLUX.1-dev",
     repo_ae="ae.safetensors",
+    ckpt_path=None,
     params=FluxParams(
         in_channels=64,
         vec_in_dim=768,
@@ -209,11 +213,14 @@ config = ModelSpec(
 
 
 def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
-    if (config.repo_id is not None
+    if (config.ckpt_path is None
+        and config.repo_id is not None
         and config.repo_flow is not None
         and hf_download
     ):
         ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))
+    else:
+        ckpt_path = config.ckpt_path
 
     model = Flux(config.params)
     if ckpt_path is not None:
@@ -226,12 +233,12 @@ dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
-vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
-text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder").to(device)
+good_vae = vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
+text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder", torch_dtype=dtype).to(device)
 tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
-text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2").to(device)
+text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype).to(device)
 tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
-transformer = load_flow_model2(config, device)
+transformer = load_flow_model2(config, device).to(dtype).to(device)
 
 pipe = FluxPipeline(
     scheduler,
@@ -245,19 +252,20 @@ pipe = FluxPipeline(
 torch.cuda.empty_cache()
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 2048
+MAX_IMAGE_SIZE = 1024
 
 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
 @spaces.GPU(duration=75)
 def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+    torch.cuda.empty_cache()
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed)
 
     for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
             prompt=prompt,
-            guidance_scale=guidance_scale,
+            guidance_scale=guidance_scale0,
             num_inference_steps=num_inference_steps,
             width=width,
            height=height,
@@ -265,12 +273,13 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
             output_type="pil",
             good_vae=good_vae,
     ):
-        yield img, seed
+        pass
+    return img, seed
 
 examples = [
+    "a lovely cat",
     "thousands of luminous oysters on a shore reflecting and refracting the sunset",
-    "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,",
-    "ghosts, astronauts, robots, cats, superhero costumes, line drawings, naive, simple, exploring a strange planet, coloured pencil crayons, , black canvas background, drawn by 5 year old child",
+    "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,"
 ]
 
 css="""
@@ -365,4 +374,4 @@ A 3.2B param rectified flow transformer distilled from [FLUX.1 [dev]](https://bl
         outputs = [result, seed]
     )
 
-demo.launch()
+demo.launch(server_name='0.0.0.0', server_port=12345)