heaversm commited on
Commit
bd81242
·
1 Parent(s): 37e5f7e

add labels to images, modify download to append the prompt to the file name

Browse files
Files changed (1) hide show
  1. app.py +43 -39
app.py CHANGED
@@ -1,18 +1,23 @@
 
1
  import os
2
  import sys
3
  import zipfile
 
 
 
 
4
 
 
5
  import gradio as gr
6
  from openai import OpenAI
7
  from dotenv import load_dotenv
8
 
 
9
  from pymongo.mongo_client import MongoClient
10
  from pymongo.server_api import ServerApi
11
 
12
- # file extension stuff
13
- import requests
14
- import tempfile
15
- from io import BytesIO
16
 
17
  load_dotenv()
18
 
@@ -48,25 +53,36 @@ except Exception as e:
48
  image_paths_global = []
49
 
50
  def generate_images_wrapper(prompts, pw, model):
51
- global image_paths_global
52
- image_paths = generate_images(prompts, pw, model)
53
  image_paths_global = image_paths # Store paths globally
54
- return image_paths # You might want to return something else for display
 
55
 
56
- def zip_images(image_paths):
 
 
 
 
 
 
 
57
  zip_file_path = tempfile.NamedTemporaryFile(delete=False, suffix='.zip').name
58
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
59
- for path in image_paths:
60
- zipf.write(path, arcname=os.path.basename(path))
61
- os.remove(path) # Clean up the temp image file
62
  return zip_file_path
63
 
 
64
  def download_all_images():
65
- global image_paths_global
66
  if not image_paths_global:
67
  raise gr.Error("No images to download.")
68
- zip_path = zip_images(image_paths_global)
 
69
  image_paths_global = [] # Reset the global variable
 
70
  return zip_path
71
 
72
  def generate_images(prompts, pw, model):
@@ -76,6 +92,7 @@ def generate_images(prompts, pw, model):
76
  raise gr.Error("Invalid password. Please try again.")
77
 
78
  image_paths = [] # Initialize a list to hold paths of generated images
 
79
  # Split the prompts string into individual prompts based on comma separation
80
  prompts_list = prompts.split(';')
81
  for prompt in prompts_list:
@@ -92,8 +109,8 @@ def generate_images(prompts, pw, model):
92
  n=1, # Number of images to generate
93
  )
94
 
95
-
96
  image_url = response.data[0].url
 
97
 
98
  try:
99
  mongo_collection.insert_one({"text": text, "model": model, "image_url": image_url})
@@ -101,47 +118,34 @@ def generate_images(prompts, pw, model):
101
  print(e)
102
  raise gr.Error("An error occurred while saving the prompt to the database.")
103
 
104
- # create a temporary file to store the image with extension
105
- image_response = requests.get(image_url)
106
- if image_response.status_code == 200:
107
- # Use a temporary file to automatically clean up after the file is closed
108
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
109
- temp_file.write(image_response.content)
110
- temp_file.close()
111
- # return the file with extension for download
112
- # return temp_file.name
113
- # append the file with extension to the list of image paths
114
- print(temp_file.name)
115
- image_paths.append(temp_file.name)
116
- else:
117
- raise gr.Error("Failed to download the image.")
118
  except Exception as error:
119
  print(str(error))
120
  raise gr.Error(f"An error occurred while generating the image for: {prompt}")
121
 
122
- return image_paths
123
-
124
 
125
  with gr.Blocks() as demo:
126
  gr.Markdown("# <center> Prompt de Resistance Image Generator</center>")
127
  gr.Markdown("**Instructions**: To use this service, please enter the password. Then generate an image from the prompt field below, then click the download arrow from the top right of the image to save it.")
128
  pw = gr.Textbox(label="Password", type="password",
129
- placeholder="Enter the password to unlock the service")
130
  text = gr.Textbox(label="What do you want to create?",
131
- placeholder="Enter your text and then click on the \"Image Generate\" button")
132
 
133
  model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-3")
134
  btn = gr.Button("Generate Images")
135
- # output_image = gr.Image(label="Image Output")
136
- output_images = gr.Gallery(label="Image Outputs",columns=[3], rows=[1], object_fit="contain", height="auto",allow_preview=False)
137
 
138
- text.submit(fn=generate_images_wrapper, inputs=[text,pw,model], outputs=output_images, api_name="generate_image")
139
- btn.click(fn=generate_images_wrapper, inputs=[text,pw,model], outputs=output_images, api_name=False)
140
 
141
  download_all_btn = gr.Button("Download All")
142
  download_link = gr.File(label="Download Zip")
143
  download_all_btn.click(fn=download_all_images, inputs=[], outputs=download_link)
144
 
145
-
146
-
147
- demo.launch(share=True)
 
1
+ # file stuff
2
  import os
3
  import sys
4
  import zipfile
5
+ import requests
6
+ import tempfile
7
+ import json
8
+ from io import BytesIO
9
 
10
+ # gradio / hf stuff
11
  import gradio as gr
12
  from openai import OpenAI
13
  from dotenv import load_dotenv
14
 
15
+ # stats stuff
16
  from pymongo.mongo_client import MongoClient
17
  from pymongo.server_api import ServerApi
18
 
19
+
20
+
 
 
21
 
22
  load_dotenv()
23
 
 
53
  image_paths_global = []
54
 
55
  def generate_images_wrapper(prompts, pw, model):
56
+ global image_paths_global, image_labels_global
57
+ image_paths, image_labels = generate_images(prompts, pw, model)
58
  image_paths_global = image_paths # Store paths globally
59
+ image_labels_global = image_labels # Store labels globally
60
+ return list(zip(image_paths, image_labels)) # Return a list of tuples containing image paths and labels
61
 
62
+ def download_image(url):
63
+ response = requests.get(url)
64
+ if response.status_code == 200:
65
+ return response.content
66
+ else:
67
+ raise Exception(f"Failed to download image from URL: {url}")
68
+
69
+ def zip_images(image_paths_and_labels):
70
  zip_file_path = tempfile.NamedTemporaryFile(delete=False, suffix='.zip').name
71
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
72
+ for image_url, label in image_paths_and_labels:
73
+ image_content = download_image(image_url)
74
+ zipf.writestr(label + ".png", image_content)
75
  return zip_file_path
76
 
77
+
78
  def download_all_images():
79
+ global image_paths_global, image_labels_global
80
  if not image_paths_global:
81
  raise gr.Error("No images to download.")
82
+ image_paths_and_labels = list(zip(image_paths_global, image_labels_global))
83
+ zip_path = zip_images(image_paths_and_labels)
84
  image_paths_global = [] # Reset the global variable
85
+ image_labels_global = [] # Reset the global variable
86
  return zip_path
87
 
88
  def generate_images(prompts, pw, model):
 
92
  raise gr.Error("Invalid password. Please try again.")
93
 
94
  image_paths = [] # Initialize a list to hold paths of generated images
95
+ image_labels = [] # Initialize a list to hold labels of generated images
96
  # Split the prompts string into individual prompts based on comma separation
97
  prompts_list = prompts.split(';')
98
  for prompt in prompts_list:
 
109
  n=1, # Number of images to generate
110
  )
111
 
 
112
  image_url = response.data[0].url
113
+ image_label = f"Prompt: {text}" # Creating a label for the image
114
 
115
  try:
116
  mongo_collection.insert_one({"text": text, "model": model, "image_url": image_url})
 
118
  print(e)
119
  raise gr.Error("An error occurred while saving the prompt to the database.")
120
 
121
+ # append the image URL to the list of image paths
122
+ image_paths.append(image_url)
123
+ image_labels.append(image_label) # Append the label to the list of labels
124
+
 
 
 
 
 
 
 
 
 
 
125
  except Exception as error:
126
  print(str(error))
127
  raise gr.Error(f"An error occurred while generating the image for: {prompt}")
128
 
129
+ return image_paths, image_labels # Return both image paths and labels
 
130
 
131
  with gr.Blocks() as demo:
132
  gr.Markdown("# <center> Prompt de Resistance Image Generator</center>")
133
  gr.Markdown("**Instructions**: To use this service, please enter the password. Then generate an image from the prompt field below, then click the download arrow from the top right of the image to save it.")
134
  pw = gr.Textbox(label="Password", type="password",
135
+ placeholder="Enter the password to unlock the service")
136
  text = gr.Textbox(label="What do you want to create?",
137
+ placeholder="Enter your text and then click on the \"Image Generate\" button")
138
 
139
  model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-3")
140
  btn = gr.Button("Generate Images")
141
+ output_images = gr.Gallery(label="Image Outputs", show_label=True, columns=[3], rows=[1], object_fit="contain",
142
+ height="auto", allow_preview=False)
143
 
144
+ text.submit(fn=generate_images_wrapper, inputs=[text, pw, model], outputs=output_images, api_name="generate_image")
145
+ btn.click(fn=generate_images_wrapper, inputs=[text, pw, model], outputs=output_images, api_name=False)
146
 
147
  download_all_btn = gr.Button("Download All")
148
  download_link = gr.File(label="Download Zip")
149
  download_all_btn.click(fn=download_all_images, inputs=[], outputs=download_link)
150
 
151
+ demo.launch(share=False)