|
|
|
|
|
|
|
|
|
import streamlit as st |
|
import requests |
|
from PIL import Image |
|
import io |
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Hugging Face Inference API configuration -------------------------------
# The access token comes from the environment. Refuse to start without it
# rather than failing later on the first request.
hf_api_key = os.environ.get("HF_API_KEY")
if hf_api_key is None or hf_api_key == "":
    raise ValueError("HF_API_KEY not set in environment variables")

# Hosted Stable Diffusion XL base model endpoint.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "stabilityai/stable-diffusion-xl-base-1.0"
)
# Bearer-token auth header sent with every request to API_URL.
headers = dict(Authorization="Bearer " + hf_api_key)
|
|
|
|
|
|
|
|
|
|
|
# Page-level settings. Keep this before any other Streamlit command:
# older Streamlit versions raise if set_page_config is not called first.
st.set_page_config(
    page_title="Stxtement | Image Generation",
    page_icon=":art:",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        'Get Help': 'mailto:[email protected]',
        # The previous text ("...built to support spreadsheet analysis") was
        # copied from another app and did not describe this page.
        'About': "This app generates images from text prompts using Stable Diffusion XL.",
    },
)
|
|
|
|
|
|
|
|
|
# Sidebar: app title, a help/contact expander and short usage instructions.
# (The unused `image_width = 300` leftover was removed; nothing referenced it.)
with st.sidebar:
    st.title("MTSS.ai")

    with st.expander("Need help and report a bug"):
        # Explicit blank line between entries: in Markdown, adjacent lines are
        # merged into one paragraph, which would put Contact and Email on the
        # same line.
        st.write(
            "**Contact**: Cheyne LeVesseur, PhD\n\n"
            "**Email**: [email protected]"
        )

    st.divider()
    st.subheader('User Instructions')

    user_instructions = (
        "Enter a detailed description of the image you want to generate, "
        "and the app will create it based on your prompt."
    )
    st.markdown(user_instructions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def query(payload, timeout=120):
    """POST *payload* to the Hugging Face inference endpoint and return the raw body.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` waits indefinitely and the page hangs.

    Returns:
        The response body as ``bytes`` when the server answers HTTP 200;
        otherwise ``None``, after reporting the problem with ``st.error``.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors and timeouts would otherwise escape as an
        # uncaught traceback rendered in the app.
        st.error(f"Error: request to the image API failed - {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None

    return response.content
|
|
|
def generate_image(prompt):
    """Generate an image for *prompt* through the inference API.

    Args:
        prompt: Text description of the desired image.

    Returns:
        A decoded ``PIL.Image.Image`` on success; ``None`` when the request
        failed or the response body could not be decoded as an image.
    """
    image_bytes = query({"inputs": prompt})
    if not image_bytes:
        return None

    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open only reads the header. Decode now so corrupt or truncated
        # data fails here, not later inside st.image() or Image.save().
        image.load()
    except OSError as exc:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f"Error: could not decode the generated image - {exc}")
        return None
    return image
|
|
|
def generate_image_callback():
    """Generate an image from the current text input and store it in session state.

    Reads ``st.session_state["prompt_input"]`` (the text-input widget key).
    Records the prompt under ``"prompt"`` and, on success, the PIL image
    under ``"image"``. Whitespace-only input counts as empty and produces a
    warning instead of an API call.
    """
    # `or ""` guards against the key having been set to None.
    prompt = (st.session_state.get("prompt_input") or "").strip()
    if not prompt:
        st.warning("Please enter a prompt.")
        return

    st.session_state["prompt"] = prompt
    image = generate_image(prompt)
    if image is not None:
        st.session_state["image"] = image
|
|
|
def reset_callback():
    """Clear the stored prompt, the text-input widget value and any generated image.

    Registered as the Reset button's ``on_click`` handler in ``main()``.
    Streamlit runs button callbacks before the script reruns, so overwriting
    the widget's keyed state here is permitted.
    """
    st.session_state.update({"prompt": "", "prompt_input": "", "image": None})
|
|
|
def main():
    """Render the image-generation page.

    Layout: prompt input, Generate button, then (when available) the
    generated image with a PNG download button, and a Reset button.
    Generated output is kept in ``st.session_state`` so it survives reruns.
    """
    st.title("Stxtement | Image Generation")

    # `key` alone binds the widget to st.session_state["prompt_input"].
    # Also passing `value=` makes Streamlit warn that the widget "was created
    # with a default value but also had its value set via the Session State
    # API" once reset_callback has written to that key.
    st.text_input("Enter a prompt for image generation:", key="prompt_input")

    if st.button("Generate Image"):
        if st.session_state.get("prompt_input", ""):
            # Spinner rendered directly in the page: nesting it inside a
            # single-element st.empty() can let the spinner's cleanup clear
            # any st.error/st.warning emitted during generation.
            with st.spinner('Generating image...'):
                generate_image_callback()
        else:
            st.warning("Please enter a prompt.")

    image = st.session_state.get("image")
    if image is not None:
        st.image(image, caption="Generated Image")

        # Re-encode as PNG so the download matches the advertised MIME type.
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        st.download_button(
            label="Download Image",
            data=buffer.getvalue(),
            file_name="generated_image.png",
            mime="image/png",
        )

    # Offer Reset whenever there is something to clear.
    if image is not None or st.session_state.get("prompt_input"):
        st.button("Reset", on_click=reset_callback)
|
|
|
|
|
|
|
|
|
|
|
# Script entry point. NOTE: relies on `streamlit run` executing this file as
# `__main__` on each rerun. A plain `import` of the module does not call main(),
# although the module-level config and sidebar code above still execute.
if __name__ == "__main__":
    main()