|
import streamlit as st |
|
import replicate |
|
import requests |
|
import time |
|
import os |
|
import re |
|
from dotenv import load_dotenv |
|
from io import BytesIO |
|
from PIL import Image |
|
import base64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-level settings: Streamlit requires this to be the first st.* call.
st.set_page_config(
    layout="wide",
    page_title="Flux.1 in Streamlit with Replicate!",
    page_icon=":frame_with_picture:",
)

# Load variables from a local .env file (if one exists) into os.environ.
# NOTE(review): nothing visible in this script reads REPLICATE_API_TOKEN back
# from the environment; the token comes from the text input instead.
load_dotenv()
|
|
|
|
|
try: |
|
|
|
if 'prompt_history' not in st.session_state: |
|
st.session_state.prompt_history = [] |
|
|
|
def wait_for_image(url, max_attempts=10, delay=2, timeout=10):
    """Poll ``url`` with HEAD requests until it answers HTTP 200.

    Args:
        url: URL of the generated image.
        max_attempts: Maximum number of HEAD requests to send.
        delay: Seconds to sleep between consecutive attempts.
        timeout: Per-request timeout in seconds, so one stalled connection
            cannot block the Streamlit script indefinitely.

    Returns:
        True as soon as a request returns status 200; False once every
        attempt has failed.

    Raises:
        requests.RequestException: for non-transient request errors
        (e.g. an invalid URL); connection errors and timeouts are retried.
    """
    for attempt in range(max_attempts):
        try:
            # requests does not follow redirects for HEAD by default (it does
            # for GET), so opt in to probe the resource a GET would fetch.
            response = requests.head(url, allow_redirects=True, timeout=timeout)
        except (requests.ConnectionError, requests.Timeout):
            # Transient network trouble counts as "not ready yet".
            response = None

        if response is not None and response.status_code == 200:
            return True

        # Sleeping after the final attempt would only delay the failure.
        if attempt < max_attempts - 1:
            time.sleep(delay)

    return False
|
|
|
def display_image(url, timeout=30):
    """Download the image at ``url`` and render it in the current Streamlit container.

    Args:
        url: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up; without a
            timeout, requests can wait forever on an unresponsive host.

    Returns:
        The decoded PIL image when the download returns HTTP 200, otherwise
        None (after showing an error message in the UI).

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        st.error(f"Failed to download image. Status code: {response.status_code}")
        return None

    image = Image.open(BytesIO(response.content))
    st.image(image, caption="Generated Image")
    return image
|
|
|
def get_image_download_link(img, filename, text):
    """Build an HTML anchor that downloads ``img`` as an inline PNG data URI.

    Args:
        img: Image object exposing a PIL-style ``save(fp, format=...)`` method.
        filename: File name the browser should suggest for the download.
        text: Link label; wrapped in ``**`` (Markdown bold).

    Returns:
        The ``<a>`` element as a string, intended for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import html

    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # Escape interpolated values: they are placed inside an HTML attribute
    # and HTML text, and are rendered with unsafe_allow_html.
    safe_name = html.escape(filename, quote=True)
    safe_text = html.escape(text, quote=False)

    # "image/png" is the registered MIME type ("file/png" is not a real type).
    return (
        f'<a href="data:image/png;base64,{img_str}" download="{safe_name}">'
        f"**{safe_text}**</a>"
    )
|
|
|
|
|
st.title("Flux.1.X - Streamlit GUI")

# Three columns with relative widths 6:1:5 — controls, a narrow spacer,
# and the prompt-history panel.
column_widths = [6, 1, 5]
left_column, margin_col, right_column = st.columns(column_widths)
|
|
|
|
|
with left_column:
    # --- Prompt and model selection -------------------------------------
    input_prompt = st.text_area("Enter your prompt:", height=100)

    model_version = st.selectbox(
        "Model Version (schnell: fast and cheap, dev: quick and inexpensive, pro: moderate render time, most expensive)",
        options=["schnell", "dev", "pro", "1.1-pro"],
        index=0,
    )

    aspect_ratio = st.selectbox(
        "Aspect Ratio",
        options=["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"],
        index=0,
    )

    # --- Model-specific parameters --------------------------------------
    # None means "not shown for this model, so not sent to Replicate".
    guidance = None
    steps = None
    safety_checker = None
    interval = None
    safety_tolerance = None

    if model_version == "dev":
        guidance = st.slider(
            "Guidance - How closely the model follows your prompt, 1-10, default 3.5",
            min_value=0.0,
            max_value=10.0,
            value=3.5,
            step=0.01,
            format="%.2f",
        )

    # NOTE(review): "1.1-pro".startswith("pro") is False, so flux-1.1-pro gets
    # none of these controls (and gets safety_checker below). Confirm intended.
    if model_version.startswith("pro"):
        guidance = st.slider(
            "Guidance - How closely the model follows your prompt, 2-5, default is 3",
            min_value=2.0,
            max_value=5.0,
            value=3.0,
            step=0.01,
            format="%.2f",
        )

        steps = st.slider(
            "Steps - Quality/Detail of render, 1-100, default 25.",
            min_value=1,
            max_value=100,
            value=25,
            step=1,
        )

        interval = st.slider(
            "Interval - Variance of the image, 4 being the most varied, default is 1",
            min_value=1.0,
            max_value=4.0,
            value=1.0,
            step=0.01,
            format="%.2f",
        )

        safety_tolerance = st.slider(
            "Safety Tolerance - 1 to 5, 5 being least restrictive, 1 default (3 on default on here)",
            min_value=1,
            max_value=5,
            value=3,
            step=1,
        )

    if not model_version.startswith("pro"):
        # NOTE(review): "On" is forwarded as disable_safety_checker=True below,
        # i.e. it turns Replicate's safety checker OFF. Confirm intended.
        safety_checker = "On"

    seed = st.number_input("Seed (optional)", min_value=0, max_value=2**32-1, step=1, value=None, key="seed")

    # --- API key ---------------------------------------------------------
    # The token is a credential, so mask it on screen.
    replicate_key = st.text_input("Replicate Key - Required", key="rep_key", type="password")

    # st.text_input returns "" (not None) when left empty, so test falsiness.
    # No st.stop() here: it would also hide the Generate button and the
    # prompt-history column; the generate path below stops instead.
    if not replicate_key:
        st.warning("You must provide a replicate auth token key for this to work.")
    else:
        # NOTE(review): os.environ is process-wide and may be shared by every
        # session served by this Streamlit process; the client below already
        # receives the token explicitly.
        os.environ["REPLICATE_API_TOKEN"] = replicate_key

    col1, col2, col3 = st.columns([2, 2, 4])
    with col1:
        generate_button = st.button("Generate Image")

    # --- Generation --------------------------------------------------------
    if generate_button:
        if not replicate_key:
            # The missing-key warning is already displayed above.
            st.stop()

        if input_prompt:
            st.session_state.prompt_history.insert(0, input_prompt)
            with st.spinner():
                input_dict = {
                    "prompt": input_prompt,
                    "aspect_ratio": aspect_ratio,
                    "output_format": "png",
                    "output_quality": 100,
                }

                # Forward only the parameters the selected model exposed above.
                # (interval was previously collected but never sent.)
                optional_inputs = {
                    "seed": seed,
                    "guidance": guidance,
                    "steps": steps,
                    "interval": interval,
                    "safety_tolerance": safety_tolerance,
                }
                input_dict.update(
                    {name: value for name, value in optional_inputs.items() if value is not None}
                )

                if safety_checker is not None:
                    input_dict["disable_safety_checker"] = safety_checker == "On"

                try:
                    client = replicate.Client(api_token=replicate_key)

                    output = client.run(
                        f"black-forest-labs/flux-{model_version}",
                        input=input_dict,
                    )

                    # Some models return a list of outputs; use the first one.
                    if isinstance(output, list) and len(output) > 0:
                        output = output[0]

                    # Newer replicate clients may return file objects exposing
                    # a .url attribute instead of a plain URL string.
                    if not isinstance(output, str) and hasattr(output, "url"):
                        output = output.url

                    if not isinstance(output, str):
                        st.error(f"Unexpected output format: {output}")
                    else:
                        with st.spinner('Waiting for image to be ready...'):
                            if wait_for_image(output):
                                image = display_image(output)
                                if image:
                                    # File name: unix timestamp + first 30 alphanumeric
                                    # characters of the prompt, spaces -> underscores.
                                    timestamp = int(time.time())
                                    clean_prompt = re.sub(r'[^a-zA-Z0-9 ]', '', input_prompt)
                                    clean_prompt = clean_prompt.strip()[:30]
                                    clean_prompt = clean_prompt.replace(' ', '_')
                                    filename = f"{timestamp}_{clean_prompt}.png"

                                    st.markdown(get_image_download_link(image, filename, 'Download Image'), unsafe_allow_html=True)
                            else:
                                st.error("Timed out waiting for image to be ready.")

                except Exception as e:
                    st.error(f"Error generating image: {str(e)}")

        else:
            st.warning("Please enter a prompt.")
|
|
|
|
|
# The middle column only provides horizontal spacing; drop an empty
# placeholder into it directly via the column object.
margin_col.empty()
|
|
|
|
|
with right_column:
    st.subheader("Prompt History")

    # Render every stored prompt, numbered from 1, in stored list order.
    history_box = st.container()
    with history_box:
        for number, past_prompt in enumerate(st.session_state.prompt_history, start=1):
            st.text(f"{number}. {past_prompt}")

    # Intended to cap the history height and make it scroll.
    # NOTE(review): confirm `.stContainer` matches an element in the running
    # Streamlit version's DOM; if not, this CSS has no visible effect.
    st.markdown(
        """
        <style>
        .stContainer {
            max-height: 400px;
            overflow-y: auto;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
|
except Exception as ex: |
|
st.error(f"Something errored out {ex}") |