File size: 5,492 Bytes
4be498f
 
 
 
8b3ed0b
 
 
 
 
 
4be498f
 
 
 
 
 
 
 
8b3ed0b
 
4be498f
 
f0aa736
 
 
 
 
9a63147
 
f0aa736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be498f
 
 
8b3ed0b
8e6266f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b3ed0b
 
 
 
 
 
 
 
 
 
 
 
 
 
9a63147
8b3ed0b
8e6266f
 
 
 
 
 
 
ff03303
8e6266f
8b3ed0b
 
ff03303
 
 
 
 
 
 
8b3ed0b
 
 
0efa8d3
8e6266f
7498bae
8e6266f
 
0efa8d3
8e6266f
 
 
 
 
 
0efa8d3
 
 
 
 
 
ff03303
 
516628c
8e6266f
4be498f
 
 
 
8b3ed0b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#------------------------------------------------------------------------
# Import
#------------------------------------------------------------------------

import streamlit as st
import requests
from PIL import Image
import io
import os

#------------------------------------------------------------------------
# HF API
#------------------------------------------------------------------------

# Read the Hugging Face token from the process environment and abort start-up
# when it is missing or empty; the token is sent as a bearer credential below.
hf_api_key = os.environ.get("HF_API_KEY")
if hf_api_key is None or hf_api_key == "":
    raise ValueError("HF_API_KEY not set in environment variables")

# Inference endpoint for the Stable Diffusion XL base model, plus the
# Authorization header attached to every request made to it.
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
headers = dict(Authorization="Bearer " + hf_api_key)

#------------------------------------------------------------------------
# Configurations
#------------------------------------------------------------------------
# Streamlit page setup
st.set_page_config(
    page_title="Stxtement | Image Generation",
    page_icon=":art:",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        # NOTE(review): this address looks like it was replaced by an
        # email-protection placeholder; restore the real contact address.
        'Get Help': 'mailto:[email protected]',
        # The About text previously described a spreadsheet-analysis app
        # (copy-paste leftover); it now describes what this app does.
        'About': "This app generates images from text prompts.",
    },
)

#------------------------------------------------------------------------
# Sidebar
#------------------------------------------------------------------------
# (A password prompt used to be stubbed out here; it is currently disabled.)

# Logo at the top of the sidebar: source file and display width in pixels.
image_path = "mimtss.png"
image_width = 300
st.sidebar.image(image_path, width=image_width)

# Collapsible contact panel for help requests and bug reports.
with st.sidebar.expander("Need help and report a bug"):
    st.write("""
        **Contact**: Cheyne LeVesseur, PhD  
        **Email**: [email protected]
        """)

st.sidebar.divider()
st.sidebar.subheader('User Instructions')

# Short usage note rendered as Markdown beneath the subheader.
User_Instructions = """
    Enter a detailed description of the image you want to generate, and the app will create it based on your prompt.
    """
st.sidebar.markdown(User_Instructions)

#------------------------------------------------------------------------
# Define functions
#------------------------------------------------------------------------

# SIMPLE CODE
# def query(payload):
#     response = requests.post(API_URL, headers=headers, json=payload)
#     if response.status_code != 200:
#         st.error(f"Error: {response.status_code} - {response.text}")
#         return None
#     return response.content

# def generate_image(prompt):
#     image_bytes = query({"inputs": prompt})
#     if image_bytes:
#         return Image.open(io.BytesIO(image_bytes))
#     return None

# def main():
#     st.title("Stxtement | Image Generation")

#     prompt = st.text_input("Enter a prompt for image generation:")
    
#     if st.button("Generate Image"):
#         if prompt:
#             image = generate_image(prompt)
#             if image:
#                 st.image(image, caption="Generated Image")
#         else:
#             st.warning("Please enter a prompt.")

# COMPREHENSIVE CODE
def query(payload, timeout=120):
    """POST ``payload`` as JSON to ``API_URL`` and return the raw response body.

    Args:
        payload: JSON-serialisable request body (``main`` sends
            ``{"inputs": prompt}`` via ``generate_image``).
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would freeze the
            page on a stalled connection.

    Returns:
        The response body as ``bytes`` on HTTP 200, otherwise ``None``. Every
        failure is reported to the user with ``st.error`` instead of raising.
    """
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts: show a message, not a traceback.
        st.error(f"Error: request to the image API failed - {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None
    return response.content

def generate_image(prompt):
    """Request an image for ``prompt`` and return it as a PIL image.

    Args:
        prompt: Text description sent to the API as the ``inputs`` field.

    Returns:
        A ``PIL.Image.Image`` opened from the response bytes, or ``None`` when
        the request failed, returned no data, or returned bytes that cannot
        be opened as an image.
    """
    image_bytes = query({"inputs": prompt})
    if not image_bytes:
        return None

    try:
        return Image.open(io.BytesIO(image_bytes))
    except OSError as exc:
        # Pillow signals unreadable image data with OSError (including its
        # UnidentifiedImageError subclass); report it instead of crashing.
        st.error(f"Error: the API response could not be read as an image - {exc}")
        return None

def main():
    """Render the main page: prompt input, generation, download and reset.

    The generated image and the prompt that produced it are kept in
    ``st.session_state`` so they survive the script reruns that Streamlit
    performs on every widget interaction.
    """
    st.title("Stxtement | Image Generation")

    # Initialize session state variables for prompt and image
    if "image" not in st.session_state:
        st.session_state["image"] = None
    if "prompt" not in st.session_state:
        st.session_state["prompt"] = ""

    # Input field for the prompt, pre-filled with the last successful prompt
    prompt = st.text_input(
        "Enter a prompt for image generation:",
        value=st.session_state["prompt"],
    )

    if st.button("Generate Image"):
        if prompt:
            # Show a spinner while the API call is in flight
            with st.spinner('Generating image...'):
                image = generate_image(prompt)
            if image is not None:
                st.session_state["image"] = image
                st.session_state["prompt"] = prompt
        else:
            st.warning("Please enter a prompt.")

    stored_image = st.session_state["image"]
    if stored_image is None:
        return

    # Draw the image from session state on every run. Drawing it only inside
    # the "Generate Image" branch made it vanish as soon as another widget
    # (Download or Reset) triggered a rerun.
    st.image(stored_image, caption="Generated Image")

    # Download button: encode the stored image as PNG in memory
    png_buffer = io.BytesIO()
    stored_image.save(png_buffer, format='PNG')

    st.download_button(
        label="Download Image",
        data=png_buffer.getvalue(),
        file_name="generated_image.png",
        mime="image/png",
    )

    # Reset button
    if st.button("Reset"):
        # Clear session state variables
        st.session_state["image"] = None
        st.session_state["prompt"] = ""

        # Clear the URL query parameters in place. Assigning a new dict to
        # st.query_params would only rebind the module attribute.
        st.query_params.clear()

        # The widgets above were already drawn with the old state, so rerun
        # immediately to redraw the page from the cleared state.
        st.rerun()

#------------------------------------------------------------------------
# Main Guard
#------------------------------------------------------------------------

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()