ruslanmv committed on
Commit
817b405
·
verified ·
1 Parent(s): f7ebc46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +329 -1
app.py CHANGED
@@ -1,5 +1,333 @@
1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import torch
5
  import gradio as gr
@@ -221,7 +549,7 @@ if __name__ == "__main__":
221
  download_all_models()
222
  interface.launch(debug=True)
223
 
224
-
225
 
226
 
227
 
 
1
 
2
+ import gradio as gr
3
+ import numpy as np
4
+ import random
5
+ import spaces
6
+ import torch
7
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, StableDiffusionPipeline
8
+ from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
9
+ from PIL import Image
10
+ import io
11
+ import zipfile
12
+ from huggingface_hub import HfApi
13
import os

# `HfApi().token` merely echoes the ``token`` argument passed to the
# constructor, so without an argument it is always None. Read the token from
# the environment instead. None remains a valid value here: it is passed
# straight through as ``token=None`` to the download call below.
_HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# --- Constants ---
DTYPE = torch.bfloat16  # dtype requested for every pipeline / VAE load
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider and random seeds
MAX_IMAGE_SIZE = 2048  # upper bound for the width / height sliders
20
+
21
# --- Model Definitions ---
# FluxPipeline is used by the FLUX.1-schnell entry below; import it here so
# the table does not raise NameError when the module is loaded.
from diffusers import FluxPipeline

# Registry of selectable models, keyed by the name shown in the UI.
# Each entry provides:
#   "model_id"       - Hugging Face repo id passed to from_pretrained/download
#   "pipeline_class" - diffusers pipeline class used to load the model
#   "description"    - Markdown shown in the "Model Descriptions" tab
#   "vae_id"         - (optional) repo id of a tiny preview VAE; its presence
#                      makes load_models() also return the full-quality VAE
#   "config"         - (optional) extra kwargs forwarded to from_pretrained/download
#
# All models live in ONE dict: a second `MODELS = {...}` assignment would
# silently discard the entries above it (including "FLUX.1-dev", which the
# start-up code and the inference function refer to by name).
MODELS = {
    "FLUX.1-dev": {
        "model_id": "black-forest-labs/FLUX.1-dev",
        "vae_id": "madebyollin/taef1",
        "pipeline_class": DiffusionPipeline,
        "description": """
## FLUX.1 [dev]
12B parameter rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/).
- **License:** [Non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
- **Blog:** [Read the announcement](https://blackforestlabs.ai/announcing-black-forest-labs/)
- **Model Card:** [View on Hugging Face](https://huggingface.co/black-forest-labs/FLUX.1-dev)
- **Developed by:** [Blackforest Labs](https://blackforestlabs.ai/)
""",
    },
    "stable-diffusion-v1-5": {
        "model_id": "runwayml/stable-diffusion-v1-5",
        "pipeline_class": StableDiffusionPipeline,
        "description": """
## Stable Diffusion v1-5
Stable Diffusion v1-5 is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.
- **Model Card:** [View on Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5)
- **Developed by:** RunwayML
- **License:** [CreativeML Open RAIL-M License](https://huggingface.co/spaces/CompVis/stable-diffusion-license)
""",
        "config": {
            "requires_safety_checker": False,
        },
    },
    "deliberate-v3": {
        "model_id": "XpucT/deliberate",
        "pipeline_class": StableDiffusionPipeline,
        "description": """
## Deliberate V3
Deliberate V3 is a model merging designed for high quality image generation.
- **Model Card:** [View on Hugging Face](https://huggingface.co/XpucT/deliberate)
- **Developed by:** XpucT
- **License:** [CreativeML Open RAIL-M License](https://huggingface.co/spaces/CompVis/stable-diffusion-license)
""",
        "config": {
            "requires_safety_checker": False,
        },
    },
    "dreamshaper-8": {
        "model_id": "Lykon/dreamshaper-8",
        "pipeline_class": StableDiffusionPipeline,
        "description": """
## DreamShaper 8
DreamShaper 8 is another iteration of the fine-tuned stable diffusion model that is capable of producing high-quality and detailed images.
- **Model Card:** [View on Hugging Face](https://huggingface.co/Lykon/dreamshaper-8)
- **Developed by:** Lykon
- **License:** [CreativeML Open RAIL-M License](https://huggingface.co/spaces/CompVis/stable-diffusion-license)
""",
        "config": {
            "requires_safety_checker": False,
        },
    },
    "FLUX.1-schnell": {
        "pipeline_class": FluxPipeline,
        "model_id": "black-forest-labs/FLUX.1-schnell",
        # No "config" with torch_dtype here: load_models() already passes
        # torch_dtype=DTYPE (bfloat16) explicitly, and supplying it a second
        # time through **config raises "got multiple values for keyword
        # argument 'torch_dtype'".
        "description": "**FLUX.1-schnell** is a fast and efficient model designed for quick image generation. It excels at producing high-quality images rapidly, making it ideal for applications where speed is crucial. However, its rapid generation may slightly compromise on the level of detail compared to slower, more meticulous models.",
    },
    # Add other models here in the same format
}
91
# --- Function to pre-download models ---
def download_all_models():
    """Fetch every pipeline registered in ``MODELS`` ahead of first use.

    For each entry, the entry's ``pipeline_class.download`` is called with the
    entry's ``model_id``, the module-level ``_HF_TOKEN`` and any keyword
    arguments stored under the optional ``"config"`` key. This is best-effort:
    a failure for one model is printed and the loop moves on to the next one.
    """
    print("Downloading all models...")
    for name, entry in MODELS.items():
        try:
            downloader = entry["pipeline_class"]
            repo_id = entry["model_id"]
            extra_kwargs = entry.get("config", {})
            # Attempt to download the pipeline without loading it into memory
            downloader.download(repo_id, token=_HF_TOKEN, **extra_kwargs)

            print(f"Model '{name}' downloaded successfully.")
        except Exception as err:
            # Deliberately broad: one bad entry must not block the others.
            print(f"Error downloading model '{name}': {err}")
    print("Model download process complete.")
105
+
106
# --- Function to clear GPU memory ---
def clear_gpu_memory():
    """Release GPU memory that is no longer in use.

    Does nothing unless ``DEVICE`` is ``"cuda"``. Otherwise runs the Python
    garbage collector and then asks PyTorch to return its cached, unoccupied
    CUDA blocks.
    """
    if DEVICE != "cuda":
        return

    import gc

    # Collect first: objects that are only kept alive by reference cycles
    # still own their tensors, and empty_cache() can only hand back blocks
    # that no live tensor occupies.
    gc.collect()
    # No torch.no_grad() wrapper: emptying the cache does not involve autograd.
    torch.cuda.empty_cache()
111
+
112
# --- Function to load models and setup pipeline ---
def load_models(model_key):
    """Load the pipeline registered under *model_key* and move it to ``DEVICE``.

    Args:
        model_key: Key into ``MODELS``.

    Returns:
        A ``(pipe, good_vae)`` tuple. For entries that declare a ``"vae_id"``,
        the pipeline is built around that tiny VAE, gets the streaming call
        ``flux_pipe_call_that_returns_an_iterable_of_images`` bound onto it,
        and ``good_vae`` is the model's own VAE (loaded from its ``"vae"``
        subfolder). For all other entries ``good_vae`` is ``None``.

    Raises:
        KeyError: If *model_key* is not registered in ``MODELS``.
    """
    clear_gpu_memory()
    model_info = MODELS[model_key]
    pipeline_class = model_info["pipeline_class"]
    model_id = model_info["model_id"]

    # Merge the default dtype with the entry's own kwargs instead of passing
    # both separately: an entry whose "config" also sets torch_dtype would
    # otherwise fail with "got multiple values for keyword argument". The
    # entry's value wins when both are present.
    load_kwargs = {"torch_dtype": DTYPE, **model_info.get("config", {})}

    if "vae_id" not in model_info:
        pipe = pipeline_class.from_pretrained(model_id, **load_kwargs).to(DEVICE)
        return pipe, None

    # Tiny VAE goes into the pipeline; the full VAE is returned separately so
    # the caller can hand it to the streaming call.
    vae = AutoencoderTiny.from_pretrained(model_info["vae_id"], torch_dtype=DTYPE).to(DEVICE)
    good_vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=DTYPE).to(DEVICE)
    pipe = pipeline_class.from_pretrained(model_id, vae=vae, **load_kwargs).to(DEVICE)
    # Bind the helper as a method of this pipeline instance.
    pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
    return pipe, good_vae
128
+
129
# --- Initial model load ---
# Start with FLUX.1-dev when it is registered. If the MODELS table no longer
# contains it, fall back to the first registered model instead of crashing
# at import time with a KeyError.
_PREFERRED_MODEL_KEY = "FLUX.1-dev"
current_model_key = _PREFERRED_MODEL_KEY if _PREFERRED_MODEL_KEY in MODELS else next(iter(MODELS))
pipe, good_vae = load_models(current_model_key)
132
+
133
# --- Inference function ---
@spaces.GPU(duration=75)
def infer(model_key, prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate images for *prompt*, streaming results as they become available.

    Args:
        model_key: Key into ``MODELS``; the pipeline is reloaded when it
            differs from the currently loaded one.
        prompt: Text prompt.
        seed: Seed for the random generator (ignored when *randomize_seed*).
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead.
        width, height: Output size passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of steps passed to the pipeline.
        progress: Not used in the body; presumably present so Gradio tracks
            tqdm progress for this handler — confirm against Gradio docs.

    Yields:
        ``(image, seed, archive)`` tuples. ``archive`` is ``None`` for every
        intermediate item; the last item carries an in-memory ZIP
        (``io.BytesIO``, positioned at its start) with all collected images
        as ``image_<i>.png``. If nothing was produced, a single
        ``(None, seed, None)`` is yielded at the end.
    """
    global pipe, good_vae, current_model_key
    if model_key != current_model_key:
        # Drop our references BEFORE loading: load_models() starts by clearing
        # GPU memory, which cannot free weights that are still referenced from
        # here. Resetting the key as well means a failed load is retried on
        # the next call instead of leaving a stale key next to pipe=None.
        pipe = good_vae = None
        current_model_key = None
        pipe, good_vae = load_models(model_key)
        current_model_key = model_key

    # UI sliders may hand over floats; the generator seed and the pipeline
    # size/step arguments must be integers.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)
    generator = torch.Generator().manual_seed(seed)

    images = []

    if good_vae is not None:
        # load_models() returns a full VAE exactly for the entries it also
        # equips with the streaming call, so this covers "FLUX.1-dev" and any
        # future entry that declares a "vae_id".
        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            output_type="pil",
            good_vae=good_vae,
        ):
            images.append(img)
            yield img, seed, None
    else:
        result = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        )
        images.extend(result.images)

        for img in result.images:
            yield img, seed, None

    if not images:
        yield None, seed, None
        return

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zf:
        for i, img in enumerate(images):
            img_buffer = io.BytesIO()
            img.save(img_buffer, format="PNG")
            zf.writestr(f"image_{i}.png", img_buffer.getvalue())
    # Rewind so a consumer calling read() gets the archive, not an empty tail.
    zip_buffer.seek(0)
    yield images[-1], seed, zip_buffer
185
+
186
# --- Example prompts ---
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

# --- CSS for styling ---
css = """
#col-container {
margin-left: auto;
margin-right: auto;
text-align: center;
}
.text-center {
text-align: center;
}
.title {
font-size: 1.5rem;
font-weight: bold;
margin-bottom: 1rem;
}
.footer {
text-align: center;
margin-top: 1rem;
}
.description-text {
text-align: left;
margin-bottom: 1rem;

}
"""

# --- Download plumbing ---
import os
import tempfile


def _archive_to_path(archive):
    """Turn the archive item yielded by ``infer`` into something a DownloadButton can serve.

    ``None`` and path-like values pass through unchanged. An in-memory buffer
    is written to a temporary ``.zip`` file and that file's path is returned.
    The temporary file is not removed here (NOTE(review): add periodic
    clean-up if disk usage matters).
    """
    if archive is None or isinstance(archive, (str, os.PathLike)):
        return archive
    payload = archive.getvalue() if hasattr(archive, "getvalue") else archive.read()
    with tempfile.NamedTemporaryFile(prefix="images_", suffix=".zip", delete=False) as handle:
        handle.write(payload)
    return handle.name


def _infer_for_ui(model_key, prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """UI-facing wrapper around ``infer``.

    Forwards every ``(image, seed, archive)`` item unchanged except that the
    archive is converted to a file path, because a button component cannot
    take a BytesIO as its value. The signature mirrors ``infer`` so the
    Examples component can call it with just ``(model_key, prompt)``.
    """
    for image, used_seed, archive in infer(model_key, prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
        yield image, used_seed, _archive_to_path(archive)


# --- Gradio Interface ---
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
<div class="title">
🖼️ AI Image Generator 🖼️
</div>
<div class="text-center">
Choose a model and generate stunning images with AI!
</div>
""",
        )

        with gr.Tab("Generator"):
            with gr.Row():
                model_selector = gr.Dropdown(
                    label="Select Model",
                    choices=list(MODELS.keys()),
                    value=current_model_key,
                )

            with gr.Row():
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                run_button = gr.Button("Run", scale=0)
            result = gr.Image(label="Result", show_label=False, elem_id="result-image")
            # DownloadButton (not Button): its value is a file path that the
            # browser downloads on click, which is what the ZIP produced by
            # the inference function is meant for.
            download_button = gr.DownloadButton("Download Results")
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=15,
                        step=0.1,
                        value=3.5,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=28,
                    )
            gr.Examples(
                examples=examples,
                fn=_infer_for_ui,
                inputs=[model_selector, prompt],
                outputs=[result, seed, download_button],
                cache_examples="lazy",
            )
        with gr.Tab("Model Descriptions"):
            for model_key, model_info in MODELS.items():
                with gr.Accordion(model_key, open=False):
                    gr.Markdown(model_info["description"], elem_classes="description-text")

        gr.Markdown(
            """
<div class="footer">
<p>
⚡ Powered by <a href="https://www.gradio.app/" target="_blank">Gradio</a> and <a href="https://huggingface.co/spaces" target="_blank">🤗 Spaces</a>.
</p>
</div>
""",
        )

    # --- Event handlers ---
    inputs = [model_selector, prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
    outputs = [result, seed, download_button]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=_infer_for_ui,
        inputs=inputs,
        outputs=outputs,
    )
    # No click handler on download_button: the previous one copied the
    # button's value back onto itself, and DownloadButton serves its file on
    # click without one.

# --- Pre-download all models at startup ---
download_all_models()

# --- Launch the demo ---
demo.queue().launch(debug=True)
329
+
330
+ '''
331
  import os
332
  import torch
333
  import gradio as gr
 
549
  download_all_models()
550
  interface.launch(debug=True)
551
 
552
+ '''
553
 
554
 
555