Spaces:
Runtime error
Runtime error
File size: 7,282 Bytes
bd81242 1cd7e2c 37e5f7e bd81242 88c3112 1cd7e2c bd81242 61b9df0 1cd7e2c bd81242 0af4039 bd81242 74a4765 1cd7e2c dfa75b2 1cd7e2c dfa75b2 1cd7e2c 0af4039 1cd7e2c 0af4039 dfa75b2 3c4d0e1 37e5f7e 3c4d0e1 593deb8 bd81242 37e5f7e 593deb8 37e5f7e bd81242 37e5f7e 88c3112 bd81242 88c3112 37e5f7e bd81242 37e5f7e bd81242 37e5f7e bd81242 37e5f7e bd81242 37e5f7e 3b1dbeb 593deb8 dfa75b2 3b1dbeb bd81242 593deb8 73cb8f9 593deb8 3b1dbeb 593deb8 3b1dbeb 593deb8 3b1dbeb 593deb8 3b1dbeb 593deb8 bd81242 593deb8 bd81242 3b1dbeb 593deb8 37e5f7e bd81242 1cd7e2c 5fa41d5 dfa75b2 bd81242 5fa41d5 bd81242 f7a1102 88c3112 593deb8 3b1dbeb bd81242 61b9df0 bd81242 593deb8 37e5f7e 3c4d0e1 37e5f7e bd81242 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
# file stuff
import os
import sys
import zipfile
import requests
import tempfile
from io import BytesIO
import random
import string
#image generation stuff
from PIL import Image
# gradio / hf stuff
import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv
# stats stuff
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
load_dotenv()

# os.getenv returns None when a variable is unset; normalise that to "" so the
# "missing credential" checks below also catch variables that were never set,
# not only ones left at their template placeholder value.
openai_key = os.getenv("OPENAI_API_KEY") or ""
pw_key = os.getenv("PW") or ""

# Treat the untouched template placeholders as "not configured".
if openai_key == "<YOUR_OPENAI_API_KEY>":
    openai_key = ""
if pw_key == "<YOUR_PW>":
    pw_key = ""

if not pw_key:
    sys.exit("Please Provide A Password in the Environment Variables")
if not openai_key:
    sys.exit("Please Provide Your OpenAI API Key")

# Connect to MongoDB
uri = os.getenv("MONGO_URI")
mongo_client = MongoClient(uri, server_api=ServerApi('1'))
mongo_db = mongo_client.pdr
mongo_collection = mongo_db["images"]

# Send a ping to confirm a successful connection. This is best-effort: a
# failure is only printed, the app still starts.
try:
    mongo_client.admin.command('ping')
    print("Pinged your deployment. You successfully connected to MongoDB!")
except Exception as e:
    print(e)

# Results of the most recent generation, shared between the UI callbacks.
# They must exist before the first generation, otherwise toggling labels or
# clicking "Download All" on a fresh session raises NameError.
image_paths_global = []
image_labels_global = []
def update_labels(show_labels):
    """Rebuild the gallery items for the most recently generated images.

    Args:
        show_labels: When truthy, each stored image path is paired with its
            stored label; otherwise every path is paired with an empty string
            so the gallery displays no caption.

    Returns:
        A list of ``(path, label)`` tuples built from the module-level
        ``image_paths_global`` / ``image_labels_global`` lists.
    """
    if not show_labels:
        # An empty caption hides the label in the gallery.
        return [(image_path, "") for image_path in image_paths_global]
    return list(zip(image_paths_global, image_labels_global))
def generate_images_wrapper(prompts, pw, model, show_labels=True):
    """Generate images and remember them for the label toggle and zip download.

    Args:
        prompts: Semicolon-separated ``initials-prompt`` entries, passed
            through to ``generate_images``.
        pw: Password entered by the user, passed through to
            ``generate_images``.
        model: Image model name, passed through to ``generate_images``.
        show_labels: Whether the returned gallery items carry their labels.
            Defaults to True so event handlers that only supply the first
            three inputs keep working.

    Returns:
        A list of ``(image_path, label)`` tuples for the gallery; the label is
        an empty string when ``show_labels`` is falsy.

    Raises:
        gr.Error: Propagated from ``generate_images``.
    """
    global image_paths_global, image_labels_global

    image_paths, image_labels = generate_images(prompts, pw, model)

    # Always keep the real labels, even when they are hidden right now, so
    # that switching the "show labels" toggle back on can restore them.
    image_paths_global = image_paths
    image_labels_global = image_labels

    return [
        (path, label if show_labels else "")
        for path, label in zip(image_paths, image_labels)
    ]
def download_image(url, timeout=30):
    """Fetch the raw bytes of the image at ``url``.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely on a stalled
            connection.

    Returns:
        The response body as bytes.

    Raises:
        Exception: If the server answers with a status code other than 200.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to download image from URL: {url}")
    return response.content
def zip_images(image_paths_and_labels):
    """Download every image and bundle them into a temporary zip archive.

    Args:
        image_paths_and_labels: Iterable of ``(image_url, label)`` pairs. Only
            the URL is used; each image is stored under a random ``.png`` name.

    Returns:
        Path of the created zip file. The file is not deleted automatically.

    Raises:
        Exception: Whatever ``download_image`` raises; the partially written
            archive is removed before the error propagates.
    """
    # Reserve a path, then close the handle immediately so it is not leaked
    # (and so the file can be reopened by ZipFile on every platform).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp:
        zip_file_path = tmp.name

    alphabet = string.ascii_letters + string.digits
    used_names = set()
    try:
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for image_url, _ in image_paths_and_labels:
                image_content = download_image(image_url)
                # Generate a random filename, retrying on the (unlikely)
                # collision so no archive entry is duplicated.
                random_filename = ''.join(random.choices(alphabet, k=10)) + ".png"
                while random_filename in used_names:
                    random_filename = ''.join(random.choices(alphabet, k=10)) + ".png"
                used_names.add(random_filename)
                zipf.writestr(random_filename, image_content)
    except Exception:
        # Do not leave a half-written archive behind in the temp directory.
        os.remove(zip_file_path)
        raise
    return zip_file_path
def download_all_images():
    """Zip the most recently generated images and hand back the archive path.

    Returns:
        Path of the zip file produced by ``zip_images``.

    Raises:
        gr.Error: If no generated images are currently stored.
    """
    global image_paths_global, image_labels_global

    if image_paths_global:
        archive_path = zip_images(list(zip(image_paths_global, image_labels_global)))
        # The stored results are cleared once they have been packaged.
        image_paths_global, image_labels_global = [], []
        return archive_path

    raise gr.Error("No images to download.")
def _parse_prompt_entries(prompts):
    """Split a ``initials-prompt;initials-prompt`` string into parsed entries.

    Returns a list of ``(raw_entry, initials, prompt_text)`` tuples. Blank
    segments (for example from a trailing semicolon) are ignored. Raises
    ``gr.Error`` if a segment has no dash or if nothing usable remains.
    """
    parsed = []
    for entry in prompts.split(';'):
        if not entry.strip():
            continue
        # Split on the first dash only, so the prompt text may contain dashes.
        initials, dash, text = entry.partition('-')
        if not dash:
            raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
        parsed.append((entry, initials.strip(), text.strip()))
    if not parsed:
        raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
    return parsed


def _request_image_url(text, model):
    """Ask the OpenAI images API for one image and return its URL."""
    client = OpenAI(api_key=openai_key)
    response = client.images.generate(
        prompt=text,
        model=model,  # dall-e-2 or dall-e-3
        quality="standard",  # standard or hd
        size="512x512" if model == "dall-e-2" else "1024x1024",  # varies for dalle-2 and dalle-3
        n=1,  # Number of images to generate
    )
    return response.data[0].url


def generate_images(prompts, pw, model):
    """Generate one image per prompt and record each request in MongoDB.

    Args:
        prompts: Semicolon-separated entries, each in ``initials-prompt`` form.
        pw: Password supplied by the user; must equal the ``PW`` environment
            variable.
        model: Image model name (``"dall-e-2"`` or ``"dall-e-3"``).

    Returns:
        A ``(image_paths, image_labels)`` tuple of equally long lists: the
        generated image URLs and a ``"User: ..., Prompt: ..."`` label for each.

    Raises:
        gr.Error: On a wrong password, a malformed prompt entry, an image
            generation failure, or a database write failure.
    """
    # Check for a valid password
    if pw != os.getenv("PW"):
        raise gr.Error("Invalid password. Please try again.")

    # Validate every entry before contacting the API so a malformed entry
    # later in the list does not waste generations for the earlier ones.
    entries = _parse_prompt_entries(prompts)

    image_paths = []
    image_labels = []
    for entry, user_initials, text in entries:
        try:
            image_url = _request_image_url(text, model)
        except Exception as error:
            print(str(error))
            raise gr.Error(f"An error occurred while generating the image for: {entry}") from error

        # Kept outside the generation try-block so a database failure is
        # reported with its own message instead of being re-labelled as a
        # generation failure.
        try:
            mongo_collection.insert_one({"user": user_initials, "text": text, "model": model, "image_url": image_url})
        except Exception as error:
            print(error)
            raise gr.Error("An error occurred while saving the prompt to the database.") from error

        image_paths.append(image_url)
        image_labels.append(f"User: {user_initials}, Prompt: {text}")

    return image_paths, image_labels  # Return both image paths and labels
# Build the Gradio interface: password + prompt inputs, a gallery for the
# generated images, a label toggle, and a "download everything as zip" button.
with gr.Blocks() as demo:
    gr.Markdown("# <center> Prompt de Resistance Image Generator</center>")
    gr.Markdown("**Instructions**: To use this service, please enter the password. Then generate an image from the prompt field below, then click the download arrow from the top right of the image to save it.")
    pw = gr.Textbox(label="Password", type="password",
                    placeholder="Enter the password to unlock the service")
    text = gr.Textbox(label="What do you want to create?",
                      placeholder="Enter your text and then click on the \"Image Generate\" button")
    model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-2")
    show_labels = gr.Checkbox(label="Show Image Labels", value=True)  # Default is to show labels
    btn = gr.Button("Generate Images")
    output_images = gr.Gallery(label="Image Outputs", show_label=True, columns=[3], rows=[1], object_fit="contain",
                               height="auto", allow_preview=False)

    # Both triggers pass the same four inputs; the wrapper takes
    # (prompts, pw, model, show_labels), so pressing Enter in the prompt box
    # must supply the checkbox value too.
    generation_inputs = [text, pw, model, show_labels]
    text.submit(fn=generate_images_wrapper, inputs=generation_inputs, outputs=output_images, api_name="generate_image")
    btn.click(fn=generate_images_wrapper, inputs=generation_inputs, outputs=output_images, api_name=False)

    # Re-render the gallery with or without captions when the toggle changes.
    show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])

    download_all_btn = gr.Button("Download All")
    download_link = gr.File(label="Download Zip")
    download_all_btn.click(fn=download_all_images, inputs=[], outputs=download_link)

demo.launch(share=False)