RanM commited on
Commit
216a041
·
verified ·
1 Parent(s): 89a9597

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -28
app.py CHANGED
@@ -1,37 +1,27 @@
1
  import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
  from generate_propmts import generate_prompt
4
- from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from PIL import Image
6
- import threading
7
  import traceback
8
 
9
  # Load the model once outside of the function
10
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
11
 
12
- # Create a thread-local storage for step indices
13
- scheduler_step_storage = threading.local()
14
-
15
- def generate_image(prompt):
16
  try:
17
- # Initialize step index per thread if not already set
18
- if not hasattr(scheduler_step_storage, 'step'):
19
- scheduler_step_storage.step = 0
20
-
21
  # Use a sensible default for num_inference_steps
22
  num_inference_steps = 1 # Adjust this value as needed
23
-
24
- # Use the thread-local step index
25
- output = model(
26
- prompt=prompt,
 
27
  num_inference_steps=num_inference_steps, # Use a higher value for inference steps
28
  guidance_scale=0.0, # Typical value for guidance scale in image generation
29
  output_type="pil" # Directly get PIL Image objects
30
  )
31
 
32
- # Increment the step index after generating the image
33
- scheduler_step_storage.step += 1
34
-
35
  # Check for output validity and return
36
  if output.images:
37
  return output.images[0]
@@ -43,7 +33,7 @@ def generate_image(prompt):
43
  traceback.print_exc()
44
  return None # Return None on error to handle it gracefully in the UI
45
 
46
- def inference(sentence_mapping, character_dict, selected_style):
47
  images = []
48
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
49
  prompts = []
@@ -55,17 +45,12 @@ def inference(sentence_mapping, character_dict, selected_style):
55
  prompts.append(prompt)
56
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
57
 
58
- with ThreadPoolExecutor() as executor:
59
- futures = [executor.submit(generate_image, prompt) for prompt in prompts]
 
60
 
61
- for future in as_completed(futures):
62
- try:
63
- image = future.result()
64
- if image:
65
- images.append(image)
66
- except Exception as e:
67
- print(f"Error processing prompt: {e}")
68
- traceback.print_exc()
69
 
70
  return images
71
 
 
1
  import gradio as gr
2
  from diffusers import AutoPipelineForText2Image
3
  from generate_propmts import generate_prompt
 
4
  from PIL import Image
5
+ import asyncio
6
  import traceback
7
 
8
  # Load the model once outside of the function
9
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
10
 
11
+ async def generate_image(prompt):
 
 
 
12
  try:
 
 
 
 
13
  # Use a sensible default for num_inference_steps
14
  num_inference_steps = 1 # Adjust this value as needed
15
+
16
+ # Use the model to generate an image
17
+ output = await asyncio.to_thread(
18
+ model,
19
+ prompt=prompt,
20
  num_inference_steps=num_inference_steps, # Use a higher value for inference steps
21
  guidance_scale=0.0, # Typical value for guidance scale in image generation
22
  output_type="pil" # Directly get PIL Image objects
23
  )
24
 
 
 
 
25
  # Check for output validity and return
26
  if output.images:
27
  return output.images[0]
 
33
  traceback.print_exc()
34
  return None # Return None on error to handle it gracefully in the UI
35
 
36
+ async def inference(sentence_mapping, character_dict, selected_style):
37
  images = []
38
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
39
  prompts = []
 
45
  prompts.append(prompt)
46
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
47
 
48
+ # Use asyncio.gather to run generate_image in parallel
49
+ tasks = [generate_image(prompt) for prompt in prompts]
50
+ images = await asyncio.gather(*tasks)
51
 
52
+ # Filter out None values
53
+ images = [image for image in images if image is not None]
 
 
 
 
 
 
54
 
55
  return images
56