#------------------------------------------------------------------------ | |
# Import | |
#------------------------------------------------------------------------ | |
import streamlit as st | |
import requests | |
from PIL import Image | |
import io | |
import os | |
#------------------------------------------------------------------------ | |
# HF API | |
#------------------------------------------------------------------------ | |
# Hugging Face Inference API credentials and endpoint.
# The token comes from the HF_API_KEY environment variable; a missing or
# empty value stops the app immediately.
hf_api_key = os.environ.get('HF_API_KEY')
if hf_api_key in (None, ""):
    raise ValueError("HF_API_KEY not set in environment variables")
# Text-to-image model served by the Inference API (Stable Diffusion XL base).
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "stabilityai/stable-diffusion-xl-base-1.0"
)
# Bearer-token header sent with every inference request.
headers = dict(Authorization="Bearer " + hf_api_key)
#------------------------------------------------------------------------ | |
# Configurations | |
#------------------------------------------------------------------------ | |
# Streamlit page setup: browser-tab title and icon, layout, initial sidebar
# state and the entries of the app's menu.
st.set_page_config(
    page_title="Stxtement | Image Generation",
    page_icon=":art:",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        'Get Help': 'mailto:[email protected]',
        # Describe what this page actually does: text-to-image generation
        # with the Stable Diffusion XL model behind API_URL.
        'About': "This app generates images from text prompts using Stable Diffusion XL.",
    },
)
#------------------------------------------------------------------------ | |
# Sidebar | |
#------------------------------------------------------------------------ | |
with st.sidebar:
    # NOTE: a password input (st.text_input, type="password") was planned
    # here and is currently disabled.

    # Logo at the top of the sidebar, `image_width` pixels wide.
    image_width = 300
    # Resolve the logo next to this script so it is found regardless of the
    # directory `streamlit run` was launched from; fall back to the plain
    # CWD-relative name if it is not there.
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "mimtss.png"
    )
    if not os.path.exists(image_path):
        image_path = "mimtss.png"
    st.image(image_path, width=image_width)

    # Collapsible contact details for help requests and bug reports.
    with st.expander("Need help and report a bug"):
        # Markdown merges consecutive lines into one paragraph; the explicit
        # hard break ("  \n") keeps the contact and e-mail on separate lines.
        st.write(
            "**Contact**: Cheyne LeVesseur, PhD  \n"
            "**Email**: [email protected]"
        )

    st.divider()
    st.subheader('User Instructions')
    # Short usage guide rendered as Markdown.
    User_Instructions = """
Enter a detailed description of the image you want to generate, and the app will create it based on your prompt.
"""
    st.markdown(User_Instructions)
#------------------------------------------------------------------------ | |
# Define functions | |
#------------------------------------------------------------------------ | |
# SIMPLE CODE | |
# def query(payload): | |
# response = requests.post(API_URL, headers=headers, json=payload) | |
# if response.status_code != 200: | |
# st.error(f"Error: {response.status_code} - {response.text}") | |
# return None | |
# return response.content | |
# def generate_image(prompt): | |
# image_bytes = query({"inputs": prompt}) | |
# if image_bytes: | |
# return Image.open(io.BytesIO(image_bytes)) | |
# return None | |
# def main(): | |
# st.title("Stxtement | Image Generation") | |
# prompt = st.text_input("Enter a prompt for image generation:") | |
# if st.button("Generate Image"): | |
# if prompt: | |
# image = generate_image(prompt) | |
# if image: | |
# st.image(image, caption="Generated Image") | |
# else: | |
# st.warning("Please enter a prompt.") | |
# COMPREHENSIVE CODE | |
def query(payload, timeout=120):
    """POST *payload* to the Hugging Face Inference API and return the raw body.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds passed to ``requests`` as the connect/read timeout
            (not a cap on total transfer time). Without it, a stalled server
            would block the Streamlit script run indefinitely.

    Returns:
        The response body as ``bytes`` on HTTP 200; otherwise ``None``, after
        an error message has been shown in the UI.
    """
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as exc:
        # Connection errors, DNS failures and timeouts are reported in the UI
        # like HTTP errors instead of crashing the app with a traceback.
        st.error(f"Error: request to the inference API failed - {exc}")
        return None
    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None
    return response.content
def generate_image(prompt):
    """Generate an image for *prompt* via the inference API.

    Returns:
        A fully decoded PIL image, or ``None`` if no image could be obtained
        (request failure, empty body, or a body that is not a decodable image).
    """
    image_bytes = query({"inputs": prompt})
    if not image_bytes:
        return None
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; load() forces decoding now so a corrupt or
        # non-image body fails here rather than later in st.image / save().
        image.load()
    except OSError as exc:
        # PIL's UnidentifiedImageError (raised for non-image data) subclasses
        # OSError, as do truncated-data errors raised by load().
        st.error(f"Error: the API response could not be decoded as an image - {exc}")
        return None
    return image
def main():
    """Render the image-generation page.

    The last generated image and the prompt that produced it are kept in
    ``st.session_state`` under ``"image"`` and ``"prompt"``. This lets them
    persist when Streamlit reruns the script after a widget interaction.
    """
    st.title("Stxtement | Image Generation")

    # Initialize session state on the first run of a session.
    if "image" not in st.session_state:
        st.session_state["image"] = None
    if "prompt" not in st.session_state:
        st.session_state["prompt"] = ""

    prompt = st.text_input(
        "Enter a prompt for image generation:",
        value=st.session_state["prompt"],
    )

    if st.button("Generate Image"):
        # A whitespace-only prompt is treated the same as an empty one.
        if prompt.strip():
            with st.spinner('Generating image...'):
                image = generate_image(prompt)
            if image is not None:
                st.session_state["image"] = image
                st.session_state["prompt"] = prompt
        else:
            st.warning("Please enter a prompt.")

    # The result area is driven by session state, not by the button click.
    # The image therefore stays visible on later reruns, such as the one
    # triggered by pressing "Download Image".
    image = st.session_state["image"]
    if image is not None:
        st.image(image, caption="Generated Image")

        # Encode the stored image as PNG for the download button.
        image_bytes = io.BytesIO()
        image.save(image_bytes, format='PNG')
        st.download_button(
            label="Download Image",
            data=image_bytes.getvalue(),
            file_name="generated_image.png",
            mime="image/png",
        )

        if st.button("Reset"):
            st.session_state["image"] = None
            st.session_state["prompt"] = ""
            # Widgets drawn earlier in this run (prompt box, image, download
            # button) still show the old values; rerun so the page is rebuilt
            # from the cleared state.
            st.rerun()
#------------------------------------------------------------------------ | |
# Main Guard | |
#------------------------------------------------------------------------ | |
if __name__ == '__main__':
    # Script entry point: build the page when this file is executed directly.
    main()