Linaqruf committed on
Commit 769de07 · 1 Parent(s): dd37f9c

dev -> prod
Files changed (9)
  1. .gitignore +24 -0
  2. LICENSE +21 -0
  3. README.md +6 -7
  4. app.py +482 -0
  5. config.py +33 -0
  6. config.toml +114 -0
  7. requirements.txt +8 -0
  8. style.css +212 -0
  9. utils.py +187 -0
.gitignore ADDED
@@ -0,0 +1,24 @@
+ # Python cache files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Virtual environment
+ venv/
+ env/
+ ENV/
+
+ # Model checkpoints
+ checkpoints/
+
+ # Outputs
+ outputs/
+
+ # IDE specific files
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # Environment variables
+ .env
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 hysts
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,13 @@
  ---
- title: Animagine Xl 4.0
- emoji: 👀
+ title: Animagine XL 4.0
+ emoji: 🌍
  colorFrom: gray
  colorTo: purple
  sdk: gradio
- sdk_version: 5.13.1
+ sdk_version: 4.44.1
  app_file: app.py
- pinned: false
  license: mit
- short_description: The Ultimate Anime-themed SDXL Model
+ pinned: false
+ suggested_hardware: a10g-small
+ short_description: The Ultimate Anime-themed SDXL model
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
app.py ADDED
@@ -0,0 +1,482 @@
+ import os
+ import gc
+ import gradio as gr
+ import numpy as np
+ import torch
+ import json
+ import spaces
+ import config
+ import utils
+ import logging
+ from PIL import Image, PngImagePlugin
+ from datetime import datetime
+ from diffusers.models import AutoencoderKL
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+ from config import (
+     MODEL,
+     MIN_IMAGE_SIZE,
+     MAX_IMAGE_SIZE,
+     USE_TORCH_COMPILE,
+     ENABLE_CPU_OFFLOAD,
+     OUTPUT_DIR,
+     DEFAULT_NEGATIVE_PROMPT,
+     DEFAULT_ASPECT_RATIO,
+     examples,
+     sampler_list,
+     aspect_ratios,
+     style_list,
+ )
+ import time
+ from typing import List, Dict, Tuple, Optional
+
+ # Enhanced logging configuration
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     datefmt='%Y-%m-%d %H:%M:%S'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Constants
+ IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
+
+ # PyTorch settings for better performance and determinism
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")
+
+ class GenerationError(Exception):
+     """Custom exception for generation errors"""
+     pass
+
+ def validate_prompt(prompt: str) -> str:
+     """Validate and clean up the input prompt."""
+     if not isinstance(prompt, str):
+         raise GenerationError("Prompt must be a string")
+     try:
+         # Ensure proper UTF-8 encoding/decoding
+         prompt = prompt.encode('utf-8').decode('utf-8')
+         # Add space between ! and ,
+         prompt = prompt.replace("!,", "! ,")
+     except UnicodeError:
+         raise GenerationError("Invalid characters in prompt")
+
+     # Only check if the prompt is completely empty or only whitespace
+     if not prompt or prompt.isspace():
+         raise GenerationError("Prompt cannot be empty")
+     return prompt.strip()
+
+ def validate_dimensions(width: int, height: int) -> None:
+     """Validate image dimensions."""
+     if not MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE:
+         raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
+
+     if not MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE:
+         raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
+
+ @spaces.GPU
+ def generate(
+     prompt: str,
+     negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
+     seed: int = 0,
+     custom_width: int = 1024,
+     custom_height: int = 1024,
+     guidance_scale: float = 6.0,
+     num_inference_steps: int = 25,
+     sampler: str = "Euler a",
+     aspect_ratio_selector: str = DEFAULT_ASPECT_RATIO,
+     style_selector: str = "(None)",
+     use_upscaler: bool = False,
+     upscaler_strength: float = 0.55,
+     upscale_by: float = 1.5,
+     add_quality_tags: bool = True,
+     progress: gr.Progress = gr.Progress(track_tqdm=True),
+ ) -> Tuple[List[str], Dict]:
+     """Generate images based on the given parameters."""
+     start_time = time.time()
+     upscaler_pipe = None
+     backup_scheduler = None
+
+     try:
+         # Memory management
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         # Input validation
+         prompt = validate_prompt(prompt)
+         if negative_prompt:
+             negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
+
+         validate_dimensions(custom_width, custom_height)
+
+         # Set up generation
+         generator = utils.seed_everything(seed)
+         width, height = utils.aspect_ratio_handler(
+             aspect_ratio_selector,
+             custom_width,
+             custom_height,
+         )
+
+         # Process prompts
+         if add_quality_tags:
+             prompt = "masterpiece, high score, great score, absurdres, {prompt}".format(prompt=prompt)
+
+         prompt, negative_prompt = utils.preprocess_prompt(
+             styles, style_selector, prompt, negative_prompt
+         )
+
+         width, height = utils.preprocess_image_dimensions(width, height)
+
+         # Set up pipeline
+         backup_scheduler = pipe.scheduler
+         pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
+
+         if use_upscaler:
+             upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
+
+         # Prepare metadata
+         metadata = {
+             "prompt": prompt,
+             "negative_prompt": negative_prompt,
+             "resolution": f"{width} x {height}",
+             "guidance_scale": guidance_scale,
+             "num_inference_steps": num_inference_steps,
+             "style_preset": style_selector,
+             "seed": seed,
+             "sampler": sampler,
+             "Model": "Animagine XL 4.0",
+             "Model hash": "e3c47aedb0",
+         }
+
+         if use_upscaler:
+             new_width = int(width * upscale_by)
+             new_height = int(height * upscale_by)
+             metadata["use_upscaler"] = {
+                 "upscale_method": "nearest-exact",
+                 "upscaler_strength": upscaler_strength,
+                 "upscale_by": upscale_by,
+                 "new_resolution": f"{new_width} x {new_height}",
+             }
+         else:
+             metadata["use_upscaler"] = None
+
+         logger.info(f"Starting generation with parameters: {json.dumps(metadata, indent=4)}")
+
+         # Generate images
+         if use_upscaler:
+             latents = pipe(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 width=width,
+                 height=height,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 generator=generator,
+                 output_type="latent",
+             ).images
+             upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
+             images = upscaler_pipe(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 image=upscaled_latents,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 strength=upscaler_strength,
+                 generator=generator,
+                 output_type="pil",
+             ).images
+         else:
+             images = pipe(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 width=width,
+                 height=height,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 generator=generator,
+                 output_type="pil",
+             ).images
+
+         # Save images
+         if images:
+             total = len(images)
+             image_paths = []
+             for idx, image in enumerate(images, 1):
+                 progress(idx/total, desc="Saving images...")
+                 path = utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
+                 image_paths.append(path)
+                 logger.info(f"Image {idx}/{total} saved as {path}")
+
+         generation_time = time.time() - start_time
+         logger.info(f"Generation completed successfully in {generation_time:.2f} seconds")
+         metadata["generation_time"] = f"{generation_time:.2f}s"
+
+         return image_paths, metadata
+
+     except GenerationError as e:
+         logger.warning(f"Generation validation error: {str(e)}")
+         raise gr.Error(str(e))
+     except Exception as e:
+         logger.exception("Unexpected error during generation")
+         raise gr.Error(f"Generation failed: {str(e)}")
+     finally:
+         # Cleanup
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         if upscaler_pipe is not None:
+             del upscaler_pipe
+
+         if backup_scheduler is not None and pipe is not None:
+             pipe.scheduler = backup_scheduler
+
+         utils.free_memory()
+
+ # Model initialization
+ if torch.cuda.is_available():
+     try:
+         logger.info("Loading VAE and pipeline...")
+         vae = AutoencoderKL.from_pretrained(
+             "madebyollin/sdxl-vae-fp16-fix",
+             torch_dtype=torch.float16,
+         )
+         pipe = utils.load_pipeline(MODEL, device, vae=vae)
+         logger.info("Pipeline loaded successfully on GPU!")
+     except Exception as e:
+         logger.error(f"Error loading VAE, falling back to default: {e}")
+         pipe = utils.load_pipeline(MODEL, device)
+ else:
+     logger.warning("CUDA not available, running on CPU")
+     pipe = None
+
+ # Process styles
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
+
+ with gr.Blocks(css="style.css", theme="Nymbo/Nymbo_Theme_5") as demo:
+     gr.HTML(
+         """
+         <div class="header">
+             <div class="title">ANIM4GINE</div>
+             <div class="subtitle">Gradio demo for <a href="https://huggingface.co/CagliostroLab/Animagine-XL-4.0" target="_blank">Animagine XL 4.0</a></div>
+         </div>
+         """,
+     )
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             with gr.Group():
+                 prompt = gr.Text(
+                     label="Prompt",
+                     max_lines=5,
+                     placeholder="Describe what you want to generate",
+                     info="Enter your image generation prompt here. Be specific and descriptive for better results.",
+                 )
+                 negative_prompt = gr.Text(
+                     label="Negative Prompt",
+                     max_lines=5,
+                     placeholder="Describe what you want to avoid",
+                     value=DEFAULT_NEGATIVE_PROMPT,
+                     info="Specify elements you don't want in the image.",
+                 )
+                 add_quality_tags = gr.Checkbox(
+                     label="Quality Tags",
+                     value=True,
+                     info="Add quality-enhancing tags to your prompt automatically.",
+                 )
+             with gr.Accordion(label="More Settings", open=False):
+                 with gr.Group():
+                     aspect_ratio_selector = gr.Radio(
+                         label="Aspect Ratio",
+                         choices=aspect_ratios,
+                         value=DEFAULT_ASPECT_RATIO,
+                         container=True,
+                         info="Choose the dimensions of your image.",
+                     )
+                 with gr.Group(visible=False) as custom_resolution:
+                     with gr.Row():
+                         custom_width = gr.Slider(
+                             label="Width",
+                             minimum=MIN_IMAGE_SIZE,
+                             maximum=MAX_IMAGE_SIZE,
+                             step=8,
+                             value=1024,
+                             info=f"Image width (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
+                         )
+                         custom_height = gr.Slider(
+                             label="Height",
+                             minimum=MIN_IMAGE_SIZE,
+                             maximum=MAX_IMAGE_SIZE,
+                             step=8,
+                             value=1024,
+                             info=f"Image height (must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
+                         )
+                 with gr.Group():
+                     use_upscaler = gr.Checkbox(
+                         label="Use Upscaler",
+                         value=False,
+                         info="Enable high-resolution upscaling.",
+                     )
+                     with gr.Row() as upscaler_row:
+                         upscaler_strength = gr.Slider(
+                             label="Strength",
+                             minimum=0,
+                             maximum=1,
+                             step=0.05,
+                             value=0.55,
+                             visible=False,
+                             info="Control how much the upscaler affects the final image.",
+                         )
+                         upscale_by = gr.Slider(
+                             label="Upscale by",
+                             minimum=1,
+                             maximum=1.5,
+                             step=0.1,
+                             value=1.5,
+                             visible=False,
+                             info="Multiplier for the final image resolution.",
+                         )
+             with gr.Accordion(label="Advanced Parameters", open=False):
+                 with gr.Group():
+                     style_selector = gr.Dropdown(
+                         label="Style Preset",
+                         interactive=True,
+                         choices=list(styles.keys()),
+                         value="(None)",
+                         info="Apply a predefined style to your generation.",
+                     )
+                 with gr.Group():
+                     sampler = gr.Dropdown(
+                         label="Sampler",
+                         choices=sampler_list,
+                         interactive=True,
+                         value="Euler a",
+                         info="Different samplers can produce varying results.",
+                     )
+                 with gr.Group():
+                     seed = gr.Slider(
+                         label="Seed",
+                         minimum=0,
+                         maximum=utils.MAX_SEED,
+                         step=1,
+                         value=0,
+                         info="Set a specific seed for reproducible results.",
+                     )
+                     randomize_seed = gr.Checkbox(
+                         label="Randomize seed",
+                         value=True,
+                         info="Generate a new random seed for each image.",
+                     )
+                 with gr.Group():
+                     with gr.Row():
+                         guidance_scale = gr.Slider(
+                             label="Guidance scale",
+                             minimum=1,
+                             maximum=12,
+                             step=0.1,
+                             value=6.0,
+                             info="Higher values make the image more closely match your prompt.",
+                         )
+                         num_inference_steps = gr.Slider(
+                             label="Number of inference steps",
+                             minimum=1,
+                             maximum=50,
+                             step=1,
+                             value=25,
+                             info="More steps generally mean higher quality but slower generation.",
+                         )
+
+         with gr.Column(scale=3):
+             with gr.Blocks():
+                 run_button = gr.Button("Generate", variant="primary", elem_id="generate-button")
+                 result = gr.Gallery(
+                     label="Generated Images",
+                     columns=1,
+                     height='768px',
+                     preview=True,
+                     show_label=True,
+                 )
+                 with gr.Accordion(label="Generation Parameters", open=False):
+                     gr_metadata = gr.JSON(
+                         label="Image Metadata",
+                         show_label=True,
+                     )
+                 gr.Examples(
+                     examples=examples,
+                     inputs=prompt,
+                     outputs=[result, gr_metadata],
+                     fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
+                     cache_examples=CACHE_EXAMPLES,
+                 )
+
+     # Discord button in a new full row
+     with gr.Row():
+         gr.HTML(
+             """
+             <a href="https://discord.com/invite/cqh9tZgbGc" target="_blank" class="discord-btn">
+                 <svg class="discord-icon" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 127.14 96.36"><path fill="currentColor" d="M107.7,8.07A105.15,105.15,0,0,0,81.47,0a72.06,72.06,0,0,0-3.36,6.83A97.68,97.68,0,0,0,49,6.83,72.37,72.37,0,0,0,45.64,0,105.89,105.89,0,0,0,19.39,8.09C2.79,32.65-1.71,56.6.54,80.21h0A105.73,105.73,0,0,0,32.71,96.36,77.7,77.7,0,0,0,39.6,85.25a68.42,68.42,0,0,1-10.85-5.18c.91-.66,1.8-1.34,2.66-2a75.57,75.57,0,0,0,64.32,0c.87.71,1.76,1.39,2.66,2a68.68,68.68,0,0,1-10.87,5.19,77,77,0,0,0,6.89,11.1A105.25,105.25,0,0,0,126.6,80.22h0C129.24,52.84,122.09,29.11,107.7,8.07ZM42.45,65.69C36.18,65.69,31,60,31,53s5-12.74,11.43-12.74S54,46,53.89,53,48.84,65.69,42.45,65.69Zm42.24,0C78.41,65.69,73.25,60,73.25,53s5-12.74,11.44-12.74S96.23,46,96.12,53,91.08,65.69,84.69,65.69Z"/></svg>
+                 <span class="discord-text">Join our Discord Server</span>
+             </a>
+             """
+         )
+
+     use_upscaler.change(
+         fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
+         inputs=use_upscaler,
+         outputs=[upscaler_strength, upscale_by],
+         queue=False,
+         api_name=False,
+     )
+     aspect_ratio_selector.change(
+         fn=lambda x: gr.update(visible=x == "Custom"),
+         inputs=aspect_ratio_selector,
+         outputs=custom_resolution,
+         queue=False,
+         api_name=False,
+     )
+
+     # Combine all triggers including keyboard shortcuts
+     gr.on(
+         triggers=[
+             prompt.submit,
+             negative_prompt.submit,
+             run_button.click,
+         ],
+         fn=utils.randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=lambda: gr.update(interactive=False, value="Generating..."),
+         outputs=run_button,
+     ).then(
+         fn=generate,
+         inputs=[
+             prompt,
+             negative_prompt,
+             seed,
+             custom_width,
+             custom_height,
+             guidance_scale,
+             num_inference_steps,
+             sampler,
+             aspect_ratio_selector,
+             style_selector,
+             use_upscaler,
+             upscaler_strength,
+             upscale_by,
+             add_quality_tags,
+         ],
+         outputs=[result, gr_metadata],
+     ).then(
+         fn=lambda: gr.update(interactive=True, value="Generate"),
+         outputs=run_button,
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
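
A note on the two-pass path in generate() above: when "Use Upscaler" is on, the first pipeline call returns latents, which are resized before the img2img pass. Below is a minimal, runnable sketch of just that latent resize, assuming only torch is installed; the shapes are illustrative (SDXL latents are (batch, 4, height/8, width/8)), not taken from a real run.

    import torch

    latents = torch.randn(1, 4, 152, 104)  # latent for an 832 x 1216 image
    upscaled = torch.nn.functional.interpolate(
        latents,
        size=(round(152 * 1.5), round(104 * 1.5)),  # upscale_by = 1.5
        mode="nearest-exact",                       # same mode utils.upscale() uses
    )
    print(upscaled.shape)  # torch.Size([1, 4, 228, 156])
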
config.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import tomli
+ from typing import Dict, Any
+
+ def fix_escaping(text: str) -> str:
+     # When JSON is loaded, \\\\ becomes \\ automatically
+     # So we don't need to do any transformation
+     return text
+
+ def load_config() -> Dict[str, Any]:
+     config_path = os.path.join(os.path.dirname(__file__), 'config.toml')
+     with open(config_path, 'rb') as f:
+         config = tomli.load(f)
+     return config
+
+ # Load configuration
+ config = load_config()
+
+ # Export variables for backward compatibility
+ MODEL = os.getenv("MODEL", config['model']['path'])
+ MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", config['model']['min_image_size']))
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", config['model']['max_image_size']))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", str(config['model']['use_torch_compile'])).lower() == "true"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", str(config['model']['enable_cpu_offload'])).lower() == "true"
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", config['model']['output_dir'])
+
+ DEFAULT_NEGATIVE_PROMPT = config['prompts']['default_negative']
+ DEFAULT_ASPECT_RATIO = config['prompts']['default_aspect_ratio']
+
+ examples = config['prompts']['examples']
+ sampler_list = config['samplers']['list']
+ aspect_ratios = config['aspect_ratios']['list']
+ style_list = config['styles']
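
The exports above follow an env-first pattern: an environment variable, when set, overrides the value read from config.toml. A small self-contained sketch of that precedence (the values are illustrative, not the real config):

    import os

    toml_fallback = 512                                    # what config.toml would supply
    os.environ["MIN_IMAGE_SIZE"] = "768"                   # simulate a deployment override
    min_image_size = int(os.getenv("MIN_IMAGE_SIZE", toml_fallback))
    print(min_image_size)  # 768
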
config.toml ADDED
@@ -0,0 +1,114 @@
+ [model]
+ path = "/workspace/animagine-xl-4.0"
+ min_image_size = 512
+ max_image_size = 2048
+ use_torch_compile = false
+ enable_cpu_offload = false
+ output_dir = "./outputs"
+
+ [prompts]
+ default_negative = "lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry"
+ default_aspect_ratio = "832 x 1216"
+ examples = [
+     "1girl, souryuu asuka langley, neon genesis evangelion, eyepatch, red plugsuit, sitting, on throne, crossed legs, head tilt, holding weapon, lance of longinus \\\\(evangelion\\\\), cowboy shot, depth of field, faux traditional media, painterly, impressionism, photo background",
+     "1boy, vash the stampede, trigun stampede, red jacket, sunglasses, gun, hand on own hip, aiming, standing, looking at viewer, upper body, desert, cliff, cowboy shot",
+     "1girl, vertin \\(reverse:1999\\), reverse:1999, black umbrella, headwear, suitcase, looking at viewer, rain, night, city, bridge, from side, dutch angle, upper body",
+     "1girl, 1boy, c.c., lelouch vi britannia, year 2024, code geass, ascot, bare shoulders, black choker, black hair, blue flower, blue rose, breasts, bush, choker, closed mouth, collarbone, couple, dress, flower, frills, green hair, hetero, long hair, long sleeves, looking at another, medium breasts, off-shoulder dress, off shoulder, parted lips, purple ascot, purple eyes, rose, short hair, sitting, straight hair, yellow eyes",
+     "1girl, hatsune miku, vocaloid, blue eyes, blue hair, bowl, can, chopsticks, collared shirt, detached sleeves, eating, elbow rest, fish \\(food\\), food, holding, holding chopsticks, katsudon \\(food\\), long hair, long sleeves, looking at viewer, meal, nail polish, necktie, noodles, onigiri, plate, ramen, sashimi, shirt, shrimp, shrimp tempura, sleeveless, sleeveless shirt, solo, spring onion, tempura, twintails",
+     "4girls, multiple girls, gotoh hitori, ijichi nijika, kita ikuyo, yamada ryo, bocchi the rock!, ahoge, black shirt, blank eyes, blonde hair, blue eyes, blue hair, brown sweater, collared shirt, cube hair ornament, detached ahoge, empty eyes, green eyes, hair ornament, hairclip, kessoku band, long sleeves, looking at viewer, medium hair, mole, mole under eye, one side up, pink hair, pink track suit, red eyes, red hair, sailor collar, school uniform, serafuku, shirt, shuka high school uniform, side ahoge, side ponytail, sweater, sweater vest, track suit, white shirt, yellow eyes, painterly, impressionism, faux traditional media, v, double v, waving",
+     "1other, solo, outdoors, sky, arm up, night, earth, helmet, outstretched arm, star \\(sky\\), night sky, full moon, floating, starry sky, reaching, jumping, space, cowboy shot, ambiguous gender, spacesuit, moonlight, space helmet, astronaut, horror, black and white, monochromatic, high contrast, abstract background, dutch angle, dark, depth of field, chromatic aberration, faux traditional media"
+ ]
+
+ [samplers]
+ list = [
+     "DPM++ 2M Karras",
+     "DPM++ SDE Karras",
+     "DPM++ 2M SDE Karras",
+     "Euler",
+     "Euler a",
+     "DDIM"
+ ]
+
+ [aspect_ratios]
+ list = [
+     "1024 x 1024",
+     "1152 x 896",
+     "896 x 1152",
+     "1216 x 832",
+     "832 x 1216",
+     "1344 x 768",
+     "768 x 1344",
+     "1536 x 640",
+     "640 x 1536",
+     "Custom"
+ ]
+
+ [[styles]]
+ name = "(None)"
+ prompt = "{prompt}"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Anim4gine"
+ prompt = "{prompt}, depth of field, faux traditional media, painterly, impressionism, photo background"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Painting"
+ prompt = "{prompt}, painterly, painting (medium)"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Pixel art"
+ prompt = "{prompt}, pixel art"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "1980s"
+ prompt = "{prompt}, 1980s (style), retro artstyle"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "1990s"
+ prompt = "{prompt}, 1990s (style), retro artstyle"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "2000s"
+ prompt = "{prompt}, 2000s (style), retro artstyle"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Toon"
+ prompt = "{prompt}, toon (style)"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Lineart"
+ prompt = "{prompt}, lineart, thick lineart"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Art Nouveau"
+ prompt = "{prompt}, art nouveau"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Western Comics"
+ prompt = "{prompt}, western comics (style)"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "3D"
+ prompt = "{prompt}, 3d"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Realistic"
+ prompt = "{prompt}, realistic, photorealistic"
+ negative_prompt = ""
+
+ [[styles]]
+ name = "Neonpunk"
+ prompt = "{prompt}, neonpunk"
+ negative_prompt = ""
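
Each [[styles]] block above is a TOML array-of-tables entry; config.py exposes the array as style_list, and app.py reduces it to a name -> (prompt, negative_prompt) lookup. A self-contained sketch of that round trip, using an inline TOML string instead of the real file:

    import tomli

    doc = tomli.loads('''
    [[styles]]
    name = "(None)"
    prompt = "{prompt}"
    negative_prompt = ""

    [[styles]]
    name = "Pixel art"
    prompt = "{prompt}, pixel art"
    negative_prompt = ""
    ''')
    styles = {s["name"]: (s["prompt"], s["negative_prompt"]) for s in doc["styles"]}
    print(styles["Pixel art"][0].format(prompt="1girl, hatsune miku"))
    # 1girl, hatsune miku, pixel art
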
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate>=1.2.1
+ diffusers>=0.32.1
+ gradio==4.44.1
+ hf-transfer>=0.1.9
+ spaces>=0.32.0
+ torch>=2.4.0
+ transformers>=4.48.0
+ tomli>=2.0.1
style.css ADDED
@@ -0,0 +1,212 @@
+ :root {
+     --title-font-size: clamp(1.5rem, 6vw, 3rem);
+     --subtitle-font-size: clamp(1rem, 2vw, 1.2rem);
+     --text-color: #fff;
+     --font-family: 'Helvetica Neue', sans-serif;
+     --gradient-primary: linear-gradient(45deg, #4EACEF, #28b485);
+     --primary-color: #1565c0;
+     --primary-hover: #1976d2;
+     --discord-color: #5865F2;
+     --discord-hover: #4752C4;
+     --border-radius: 12px;
+     --box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
+ }
+
+ body {
+     font-family: var(--font-family);
+     color: var(--text-color);
+     margin: 0;
+     padding: 0;
+     min-height: 100vh;
+     background-color: #f5f5f5;
+ }
+
+ .header {
+     text-align: center;
+     padding: 1rem 0;
+     margin-bottom: 0.5rem;
+ }
+
+ .title {
+     font-size: var(--title-font-size);
+     font-weight: 700;
+     text-transform: uppercase;
+     margin-bottom: 0.25rem;
+     background-image: var(--gradient-primary);
+     -webkit-text-fill-color: transparent;
+     -webkit-background-clip: text;
+     background-clip: text;
+     display: inline-block;
+ }
+
+ .subtitle {
+     font-size: var(--subtitle-font-size);
+     color: #999;
+     margin-bottom: 0.5rem;
+ }
+
+ .status {
+     display: none;
+ }
+
+ #duplicate-button {
+     margin: 1.5rem auto 2.5rem;
+     color: #fff;
+     background: #1565c0;
+     border-radius: 100vh;
+     padding: 0.75rem 1.5rem;
+     font-weight: 500;
+     box-shadow: 0 2px 8px rgba(21, 101, 192, 0.25);
+     transition: all 0.2s ease;
+     display: block;
+ }
+
+ #duplicate-button:hover {
+     background: #1976d2;
+     box-shadow: 0 4px 12px rgba(21, 101, 192, 0.35);
+     transform: translateY(-1px);
+ }
+
+ .contain {
+     max-width: 80%;
+     margin: 2rem auto;
+     padding: 2rem 1.5rem;
+ }
+
+ /* Component styling */
+ .gr-box {
+     border-radius: var(--border-radius);
+     border: 1px solid #e0e0e0;
+     background: #ffffff;
+     box-shadow: var(--box-shadow);
+     transition: box-shadow 0.2s ease;
+ }
+
+ .gr-box:hover {
+     box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
+ }
+
+ .gr-button.primary {
+     background: var(--primary-color);
+     border-radius: var(--border-radius);
+     padding: 0.8rem 2rem;
+     font-weight: 500;
+     box-shadow: 0 2px 8px rgba(21, 101, 192, 0.25);
+     transition: all 0.2s ease;
+     text-transform: uppercase;
+     letter-spacing: 0.5px;
+ }
+
+ .gr-button.primary:hover {
+     background: var(--primary-hover);
+     box-shadow: 0 4px 12px rgba(21, 101, 192, 0.35);
+     transform: translateY(-1px);
+ }
+
+ /* Form elements */
+ .gr-form {
+     background: #fff;
+     padding: 1.5rem;
+     border-radius: var(--border-radius);
+     box-shadow: var(--box-shadow);
+ }
+
+ .gr-input, .gr-textarea {
+     border: 1px solid #e0e0e0;
+     border-radius: 8px;
+     padding: 0.8rem;
+     transition: all 0.2s ease;
+ }
+
+ .gr-input:focus, .gr-textarea:focus {
+     border-color: var(--primary-color);
+     box-shadow: 0 0 0 2px rgba(21, 101, 192, 0.1);
+ }
+
+ /* Accordion styling */
+ .gr-accordion {
+     border: none;
+     margin: 1rem 0;
+ }
+
+ .gr-accordion-header {
+     background: #f8f9fa;
+     border-radius: var(--border-radius);
+     padding: 1rem;
+     font-weight: 500;
+ }
+
+ /* Gallery styling */
+ .gr-gallery {
+     background: #fff;
+     padding: 1rem;
+     border-radius: var(--border-radius);
+     box-shadow: var(--box-shadow);
+ }
+
+ /* Discord button */
+ .discord-btn {
+     display: inline-flex;
+     align-items: center;
+     justify-content: center;
+     background-color: var(--discord-color);
+     color: white !important;
+     text-decoration: none;
+     padding: 12px 24px;
+     border-radius: var(--border-radius);
+     transition: all 0.3s ease;
+     margin-top: 1rem;
+     font-size: 16px;
+     font-weight: 500;
+     width: 100%;
+     border: none;
+     cursor: pointer;
+     box-shadow: 0 2px 8px rgba(88, 101, 242, 0.25);
+ }
+
+ .discord-btn:hover {
+     background-color: var(--discord-hover);
+     transform: translateY(-2px);
+     box-shadow: 0 4px 12px rgba(88, 101, 242, 0.4);
+ }
+
+ .discord-icon {
+     width: 24px;
+     height: 24px;
+     margin-right: 12px;
+ }
+
+ .discord-text {
+     letter-spacing: 0.5px;
+ }
+
+ /* Tooltips */
+ .gr-form small {
+     color: #666;
+     font-size: 0.875rem;
+     margin-top: 0.25rem;
+     display: block;
+ }
+
+ /* Responsive layout */
+ @media (max-width: 768px) {
+     .contain {
+         max-width: 90%;
+         padding: 1rem;
+     }
+
+     .gr-box {
+         margin: 0.5rem 0;
+     }
+
+     .gr-button.primary {
+         width: 100%;
+     }
+ }
+
+ @media (min-width: 1200px) {
+     .contain {
+         max-width: 1400px;
+         padding: 2.5rem 2rem;
+     }
+ }
utils.py ADDED
@@ -0,0 +1,187 @@
+ import gc
+ import os
+ import random
+ import numpy as np
+ import json
+ import torch
+ import uuid
+ from PIL import Image, PngImagePlugin
+ from datetime import datetime
+ from dataclasses import dataclass
+ from typing import Callable, Dict, Optional, Tuple, Any, List
+ from diffusers import (
+     DDIMScheduler,
+     DPMSolverMultistepScheduler,
+     DPMSolverSinglestepScheduler,
+     EulerAncestralDiscreteScheduler,
+     EulerDiscreteScheduler,
+     AutoencoderKL,
+     StableDiffusionXLPipeline,
+ )
+ import logging
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+
+ @dataclass
+ class StyleConfig:
+     prompt: str
+     negative_prompt: str
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
+ def seed_everything(seed: int) -> torch.Generator:
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     generator = torch.Generator()
+     generator.manual_seed(seed)
+     return generator
+
+
+ def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
+     if aspect_ratio == "Custom":
+         return None
+     width, height = aspect_ratio.split(" x ")
+     return int(width), int(height)
+
+
+ def aspect_ratio_handler(
+     aspect_ratio: str, custom_width: int, custom_height: int
+ ) -> Tuple[int, int]:
+     if aspect_ratio == "Custom":
+         return custom_width, custom_height
+     else:
+         width, height = parse_aspect_ratio(aspect_ratio)
+         return width, height
+
+
+ def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
+     scheduler_factory_map = {
+         "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
+             scheduler_config, use_karras_sigmas=True
+         ),
+         "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
+             scheduler_config, use_karras_sigmas=True
+         ),
+         "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
+             scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
+         ),
+         "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
+         "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
+             scheduler_config
+         ),
+         "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
+     }
+     return scheduler_factory_map.get(name, lambda: None)()
+
+
+ def free_memory() -> None:
+     """Free up GPU and system memory."""
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()
+     gc.collect()
+
+
+ def preprocess_prompt(
+     style_dict,
+     style_name: str,
+     positive: str,
+     negative: str = "",
+     add_style: bool = True,
+ ) -> Tuple[str, str]:
+     p, n = style_dict.get(style_name, style_dict["(None)"])
+
+     if add_style and positive.strip():
+         formatted_positive = p.format(prompt=positive)
+     else:
+         formatted_positive = positive
+
+     combined_negative = n
+     if negative.strip():
+         if combined_negative:
+             combined_negative += ", " + negative
+         else:
+             combined_negative = negative
+
+     return formatted_positive, combined_negative
+
+
+ def common_upscale(
+     samples: torch.Tensor,
+     width: int,
+     height: int,
+     upscale_method: str,
+ ) -> torch.Tensor:
+     return torch.nn.functional.interpolate(
+         samples, size=(height, width), mode=upscale_method
+     )
+
+
+ def upscale(
+     samples: torch.Tensor, upscale_method: str, scale_by: float
+ ) -> torch.Tensor:
+     width = round(samples.shape[3] * scale_by)
+     height = round(samples.shape[2] * scale_by)
+     return common_upscale(samples, width, height, upscale_method)
+
+
+ def preprocess_image_dimensions(width, height):
+     if width % 8 != 0:
+         width = width - (width % 8)
+     if height % 8 != 0:
+         height = height - (height % 8)
+     return width, height
+
+
+ def save_image(image, metadata, output_dir, is_colab):
+     if is_colab:
+         current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+         filename = f"image_{current_time}.png"
+     else:
+         filename = str(uuid.uuid4()) + ".png"
+     os.makedirs(output_dir, exist_ok=True)
+     filepath = os.path.join(output_dir, filename)
+     metadata_str = json.dumps(metadata)
+     info = PngImagePlugin.PngInfo()
+     info.add_text("parameters", metadata_str)
+     image.save(filepath, "PNG", pnginfo=info)
+     return filepath
+
+
+ def is_google_colab():
+     try:
+         import google.colab
+         return True
+     except:
+         return False
+
+
+ def load_pipeline(model_name: str, device: torch.device, hf_token: Optional[str] = None, vae: Optional[AutoencoderKL] = None) -> Any:
+     """Load the Stable Diffusion pipeline."""
+     try:
+         pipeline = (
+             StableDiffusionXLPipeline.from_single_file
+             if model_name.endswith(".safetensors")
+             else StableDiffusionXLPipeline.from_pretrained
+         )
+
+         pipe = pipeline(
+             model_name,
+             vae=vae,
+             torch_dtype=torch.float16,
+             custom_pipeline="lpw_stable_diffusion_xl",
+             use_safetensors=True,
+             add_watermarker=False
+         )
+         pipe.to(device)
+         return pipe
+     except Exception as e:
+         logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
+         raise
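
A quick check of the seeding contract in seed_everything() above: two generators seeded identically produce identical draws, which is what lets a seed logged in the image metadata reproduce a generation. Minimal sketch, torch only:

    import torch

    g1 = torch.Generator().manual_seed(42)
    g2 = torch.Generator().manual_seed(42)
    print(torch.randn(4, generator=g1).equal(torch.randn(4, generator=g2)))  # True
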