Spaces: Running on Zero
update main file, fix local bugs
app.py (CHANGED)
@@ -6,8 +6,8 @@ import numpy as np
import random
import spaces
import torch
-from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_sft
+from huggingface_hub import hf_hub_download

from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -29,10 +29,10 @@ def calculate_shift(

def retrieve_timesteps(
    scheduler,
-    num_inference_steps: Optional
-    device: Optional
-    timesteps: Optional
-    sigmas: Optional
+    num_inference_steps: Optional = None,
+    device: Optional = None,
+    timesteps: Optional = None,
+    sigmas: Optional = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
@@ -54,23 +54,23 @@ def retrieve_timesteps(
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
-    prompt
-    prompt_2
-    height
-    width
+    prompt = None,
+    prompt_2 = None,
+    height = None,
+    width = None,
    num_inference_steps: int = 28,
-    timesteps
+    timesteps = None,
    guidance_scale: float = 3.5,
-    num_images_per_prompt
-    generator
-    latents
-    prompt_embeds
-    pooled_prompt_embeds
-    output_type
-    return_dict
-    joint_attention_kwargs
-    max_sequence_length
-    good_vae
+    num_images_per_prompt = 1,
+    generator = None,
+    latents = None,
+    prompt_embeds = None,
+    pooled_prompt_embeds = None,
+    output_type = "pil",
+    return_dict = True,
+    joint_attention_kwargs = None,
+    max_sequence_length = 512,
+    good_vae = None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
@@ -92,7 +92,10 @@ def flux_pipe_call_that_returns_an_iterable_of_images(

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
-
+    try:
+        device = self._execution_device
+    except:
+        device = torch.device('cuda:0')

    # 3. Encode prompt
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
@@ -107,7 +110,7 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
        lora_scale=lora_scale,
    )
    # 4. Prepare latent variables
-    num_channels_latents = self.transformer.
+    num_channels_latents = self.transformer.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
@@ -139,26 +142,25 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
    self._num_timesteps = len(timesteps)

    # Handle guidance
-    guidance = torch.full([1], guidance_scale, device=device, dtype=
+    guidance = torch.full([1], guidance_scale, device=device, dtype=dtype).expand(latents.shape[0])  # if self.transformer.params.guidance_embeds else None

+    # print(latent_image_ids.shape, text_ids.shape, pooled_prompt_embeds.shape)
    # 6. Denoising loop
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

-        timestep = t.expand(latents.shape[0]).to(
+        timestep = t.expand(latents.shape[0]).to(dtype)

        noise_pred = self.transformer(
-
-
-            guidance=guidance,
-
-
-            txt_ids=text_ids,
-            img_ids=latent_image_ids,
-
-            return_dict=False,
-        )[0]
+            img=latents.to(dtype).to(device),
+            timesteps=(timestep / 1000).to(dtype),
+            guidance=guidance.to(dtype).to(device),
+            y=pooled_prompt_embeds.to(dtype).to(device),
+            txt=prompt_embeds.to(dtype).to(device),
+            txt_ids=text_ids.to(dtype).to(device),
+            img_ids=latent_image_ids.to(dtype).to(device),
+        )
        # Yield intermediate result
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
@@ -184,6 +186,7 @@ class ModelSpec:
    repo_flow: str
    repo_ae: str
    repo_id_ae: str
+    ckpt_path: str


config = ModelSpec(
@@ -191,6 +194,7 @@ config = ModelSpec(
    repo_flow="flux-mini.safetensors",
    repo_id_ae="black-forest-labs/FLUX.1-dev",
    repo_ae="ae.safetensors",
+    ckpt_path=None,
    params=FluxParams(
        in_channels=64,
        vec_in_dim=768,
@@ -209,11 +213,14 @@ config = ModelSpec(


def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
-    if (config.
+    if (config.ckpt_path is None
+        and config.repo_id is not None
        and config.repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))
+    else:
+        ckpt_path = config.ckpt_path

    model = Flux(config.params)
    if ckpt_path is not None:
@@ -226,12 +233,12 @@ dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
-vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
-text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder").to(device)
+good_vae = vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
+text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder", torch_dtype=dtype).to(device)
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
-text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2").to(device)
+text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype).to(device)
tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
-transformer = load_flow_model2(config, device)
+transformer = load_flow_model2(config, device).to(dtype).to(device)

pipe = FluxPipeline(
    scheduler,
@@ -245,19 +252,20 @@ pipe = FluxPipeline(
torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE =
+MAX_IMAGE_SIZE = 1024

pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

@spaces.GPU(duration=75)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+    torch.cuda.empty_cache()
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
-        guidance_scale=
+        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
@@ -265,12 +273,13 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
        output_type="pil",
        good_vae=good_vae,
    ):
-
+        pass
+    return img, seed

examples = [
+    "a lovely cat",
    "thousands of luminous oysters on a shore reflecting and refracting the sunset",
-    "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,"
-    "ghosts, astronauts, robots, cats, superhero costumes, line drawings, naive, simple, exploring a strange planet, coloured pencil crayons, , black canvas background, drawn by 5 year old child",
+    "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,"
]

css="""
@@ -365,4 +374,4 @@ A 3.2B param rectified flow transformer distilled from [FLUX.1 [dev]](https://bl
        outputs = [result, seed]
    )

-demo.launch()
+demo.launch(server_name='0.0.0.0', server_port=12345)