RanM committed
Commit 53490b9 · verified · 1 Parent(s): 4f086ce

Update app.py

Files changed (1)
app.py +11 -12
app.py CHANGED
@@ -2,29 +2,28 @@ import os
 import asyncio
 from concurrent.futures import ProcessPoolExecutor
 from io import BytesIO
-from PIL import Image
 from diffusers import StableDiffusionPipeline
 import gradio as gr
 from generate_prompts import generate_prompt
 
 # Load the model once at the start
 print("Loading the Stable Diffusion model...")
-try:
-    model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
-    print("Model loaded successfully.")
-except Exception as e:
-    print(f"Error loading model: {e}")
-    model = None
+model = StableDiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo")
+print("Model loaded successfully.")
+
+def truncate_prompt(prompt, max_length=77):
+    tokens = prompt.split()
+    if len(tokens) > max_length:
+        prompt = " ".join(tokens[:max_length])
+    return prompt
 
 def generate_image(prompt, prompt_name):
     try:
-        if model is None:
-            raise ValueError("Model not loaded properly.")
-
         print(f"Generating image for {prompt_name} with prompt: {prompt}")
-        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
+        truncated_prompt = truncate_prompt(prompt)
+        output = model(prompt=truncated_prompt, num_inference_steps=1, guidance_scale=0.0)
         print(f"Model output for {prompt_name}: {output}")
-
+
         if output and hasattr(output, 'images') and output.images:
             print(f"Image generated for {prompt_name}")
             image = output.images[0]
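
Note on the new helper: truncate_prompt trims whitespace-separated words, but the 77 it targets is CLIP's token limit, so a word-count proxy can still overflow the text encoder. A token-accurate sketch, assuming the pipeline's bundled CLIP tokenizer (this helper is hypothetical, not part of the commit):

def truncate_prompt_tokens(prompt, max_length=77):
    # Assumption: count CLIP tokens via the pipeline's own tokenizer
    # rather than approximating with whitespace-separated words.
    ids = model.tokenizer(prompt, truncation=True, max_length=max_length).input_ids
    return model.tokenizer.decode(ids, skip_special_tokens=True)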
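
The asyncio and ProcessPoolExecutor imports survive the change, though their wiring sits outside this hunk. A minimal sketch of how generate_image is presumably dispatched off the event loop (the executor setup and wrapper below are assumptions, not shown in the diff):

executor = ProcessPoolExecutor()

async def generate_image_async(prompt, prompt_name):
    # Run the blocking pipeline call in a worker process so the Gradio
    # event loop stays responsive; each worker re-imports app.py and so
    # reloads the model once on first use.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, generate_image, prompt, prompt_name)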