RanM commited on
Commit
6035350
·
verified ·
1 Parent(s): dd55a1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -14
app.py CHANGED
@@ -1,17 +1,14 @@
1
  import gradio as gr
2
- import torch
3
  from diffusers import AutoPipelineForText2Image
4
  from io import BytesIO
5
  from generate_propmts import generate_prompt
6
  from concurrent.futures import ThreadPoolExecutor, as_completed
7
- import json
8
 
9
  # Load the model once outside of the function
10
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
11
 
12
  def generate_image(prompt):
13
  try:
14
- # Truncate prompt if necessary
15
  output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
16
  print(f"Model output: {output}")
17
 
@@ -21,7 +18,6 @@ def generate_image(prompt):
21
  buffered = BytesIO()
22
  image.save(buffered, format="JPEG")
23
  image_bytes = buffered.getvalue()
24
- print(f'image_bytes:{image_bytes}')
25
  return image_bytes
26
  else:
27
  raise Exception("No images returned by the model.")
@@ -31,7 +27,7 @@ def generate_image(prompt):
31
  return None
32
 
33
  def inference(sentence_mapping, character_dict, selected_style):
34
- images = {}
35
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
36
  prompts = []
37
 
@@ -39,23 +35,21 @@ def inference(sentence_mapping, character_dict, selected_style):
39
  for paragraph_number, sentences in sentence_mapping.items():
40
  combined_sentence = " ".join(sentences)
41
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
42
- prompts.append((paragraph_number, prompt))
43
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
44
 
45
  with ThreadPoolExecutor() as executor:
46
- future_to_paragraph = {executor.submit(generate_image, prompt): paragraph_number for paragraph_number, prompt in prompts}
47
 
48
- for future in as_completed(future_to_paragraph):
49
- paragraph_number = future_to_paragraph[future]
50
  try:
51
  image = future.result()
52
  if image:
53
- images[paragraph_number] = image
54
  except Exception as e:
55
- print(f"Error processing paragraph {paragraph_number}: {e}")
56
 
57
- # Return the images sorted by paragraph number
58
- return [images[paragraph_number] for paragraph_number in sorted(images.keys())]
59
 
60
  gradio_interface = gr.Interface(
61
  fn=inference,
@@ -64,7 +58,7 @@ gradio_interface = gr.Interface(
64
  gr.JSON(label="Character Dict"),
65
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
66
  ],
67
- outputs="json"
68
  )
69
 
70
  if __name__ == "__main__":
 
1
  import gradio as gr
 
2
  from diffusers import AutoPipelineForText2Image
3
  from io import BytesIO
4
  from generate_propmts import generate_prompt
5
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
6
 
7
  # Load the model once outside of the function
8
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
9
 
10
  def generate_image(prompt):
11
  try:
 
12
  output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
13
  print(f"Model output: {output}")
14
 
 
18
  buffered = BytesIO()
19
  image.save(buffered, format="JPEG")
20
  image_bytes = buffered.getvalue()
 
21
  return image_bytes
22
  else:
23
  raise Exception("No images returned by the model.")
 
27
  return None
28
 
29
  def inference(sentence_mapping, character_dict, selected_style):
30
+ images = []
31
  print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
32
  prompts = []
33
 
 
35
  for paragraph_number, sentences in sentence_mapping.items():
36
  combined_sentence = " ".join(sentences)
37
  prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
38
+ prompts.append(prompt)
39
  print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
40
 
41
  with ThreadPoolExecutor() as executor:
42
+ futures = [executor.submit(generate_image, prompt) for prompt in prompts]
43
 
44
+ for future in as_completed(futures):
 
45
  try:
46
  image = future.result()
47
  if image:
48
+ images.append(image)
49
  except Exception as e:
50
+ print(f"Error processing prompt: {e}")
51
 
52
+ return images
 
53
 
54
  gradio_interface = gr.Interface(
55
  fn=inference,
 
58
  gr.JSON(label="Character Dict"),
59
  gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
60
  ],
61
+ outputs=gr.Gallery(label="Generated Images")
62
  )
63
 
64
  if __name__ == "__main__":