collin committed on
Commit
082af7e
1 Parent(s): e0dcbd8

See if it works

Files changed (2)
  1. app.py +244 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,244 @@
+ import streamlit as st
+ import replicate
+ import requests
+ import time
+ import os
+ import re
+ from dotenv import load_dotenv
+ from io import BytesIO
+ from PIL import Image
+ import base64
+
+ # ******* For More Info on Flux.1 on Replicate *******
+ # *                                                   *
+ # *      https://replicate.com/black-forest-labs      *
+ # *                                                   *
+ # *****************************************************
+
+ st.set_page_config(layout="wide", page_title="Flux.1 in Streamlit with Replicate!", page_icon=":frame_with_picture:")
+
+ # Load environment variables
+ load_dotenv()
+
+ # Global error catch as I'm lazy
+ try:
+     # Initialize session state for prompt history
+     if 'prompt_history' not in st.session_state:
+         st.session_state.prompt_history = []
+
+     def wait_for_image(url, max_attempts=10, delay=2):
+         # Poll the output URL with HEAD requests until it returns 200 or we give up
+         for attempt in range(max_attempts):
+             response = requests.head(url)
+             if response.status_code == 200:
+                 return True
+             time.sleep(delay)
+         return False
+
+     def display_image(url):
+         # Download the generated image, render it in the app, and return the PIL image
+         response = requests.get(url)
+         if response.status_code == 200:
+             image = Image.open(BytesIO(response.content))
+             st.image(image, caption="Generated Image")
+             return image
+         else:
+             st.error(f"Failed to download image. Status code: {response.status_code}")
+             return None
+
+     def get_image_download_link(img, filename, text):
+         # Base64-encode the image and wrap it in a data-URI anchor tag for downloading
+         buffered = BytesIO()
+         img.save(buffered, format="PNG")
+         img_str = base64.b64encode(buffered.getvalue()).decode()
+         href = f'<a href="data:file/png;base64,{img_str}" download="{filename}">**{text}**</a>'
+         return href
+
+     # Streamlit app
+     st.title("Flux.1.X - Streamlit GUI")
+
+     # Create three columns
+     left_column, margin_col, right_column = st.columns([6, 1, 5])
+
+     # Left column contents
+     with left_column:
+         input_prompt = st.text_area("Enter your prompt:", height=100)
+
+         model_version = st.selectbox(
+             "Model Version (schnell: fastest and cheapest; dev: higher quality, moderate cost; pro: moderate render time, most expensive)",
+             options=["schnell", "dev", "pro", "1.1-pro"],
+             index=0
+         )
+
+         aspect_ratio = st.selectbox(
+             "Aspect Ratio",
+             options=["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"],
+             index=0
+         )
+
+         # some default values since different versions of the model require different params
+         guidance = None
+         steps = None
+         safety_checker = None
+         interval = None
+         safety_tolerance = None
+
+         if model_version == "dev":
+             guidance = st.slider(
+                 "Guidance - How closely the model follows your prompt, 0-10, default 3.5",
+                 min_value=0.0,
+                 max_value=10.0,
+                 value=3.5,
+                 step=0.01,
+                 format="%.2f"
+             )
+
+         if model_version.startswith("pro"):
+             guidance = st.slider(
+                 "Guidance - How closely the model follows your prompt, 2-5, default is 3",
+                 min_value=2.0,
+                 max_value=5.0,
+                 value=3.0,
+                 step=0.01,
+                 format="%.2f"
+             )
+
+             steps = st.slider(
+                 "Steps - Quality/Detail of render, 1-100, default 25.",
+                 min_value=1,
+                 max_value=100,
+                 value=25,
+                 step=1
+             )
+
+             interval = st.slider(
+                 "Interval - Variance of the image, 4 being the most varied, default is 1",
+                 min_value=1.0,
+                 max_value=4.0,
+                 value=1.0,
+                 step=0.01,
+                 format="%.2f"
+             )
+
+             safety_tolerance = st.slider(
+                 "Safety Tolerance - 1 to 5, 5 being least restrictive (model default is 1; this app defaults to 3)",
+                 min_value=1,
+                 max_value=5,
+                 value=3,
+                 step=1
+             )
+
+         if not model_version.startswith("pro"):
+             safety_checker = "On"
+             # safety_checker = st.radio(
+             #     "Safety Checker - Turn on model NSFW checking",
+             #     options=["Off", "On"],
+             #     index=1,
+             #     format_func=lambda x: "Disabled" if x == "On" else "Enabled"
+             # )
+
+         seed = st.number_input("Seed (optional)", min_value=0, max_value=2**32-1, step=1, value=None, key="seed")
+
+         replicate_key = st.text_input("Replicate Key - Required", key="rep_key")
+         if not replicate_key:  # st.text_input returns an empty string, not None, when nothing is entered
+             st.warning("You must provide a replicate auth token key for this to work.")
+             st.stop()
+         else:
+             os.environ["REPLICATE_API_TOKEN"] = replicate_key
+
+         col1, col2, col3 = st.columns([2, 2, 4])
+         with col1:
+             generate_button = st.button("Generate Image")
+
+         if generate_button:
+             if replicate_key is None or replicate_key == "":
+                 st.warning("You must provide a replicate auth token key for this to work.")
+                 st.stop()
+
+             if input_prompt:
+                 st.session_state.prompt_history.insert(0, input_prompt)
+                 with st.spinner():
+
+                     input_dict = {
+                         "prompt": input_prompt,
+                         "aspect_ratio": aspect_ratio,
+                         "output_format": "png",
+                         "output_quality": 100  # note this is ignored if output is .png
+                     }
+
+                     if seed is not None:
+                         input_dict["seed"] = seed
+
+                     if guidance is not None:
+                         input_dict["guidance"] = guidance
+
+                     if steps is not None:
+                         input_dict["steps"] = steps
+
+                     # the interval slider above was collected but never forwarded; pass it through for the pro models
+                     if interval is not None:
+                         input_dict["interval"] = interval
+
+                     if safety_checker is not None:
+                         input_dict["disable_safety_checker"] = safety_checker == "On"
+
+                     if safety_tolerance is not None:
+                         input_dict["safety_tolerance"] = safety_tolerance
+
+                     # Run the model with the prepared input
+                     try:
+                         client = replicate.Client(api_token=replicate_key)
+
+                         output = client.run(
+                             f"black-forest-labs/flux-{model_version}",
+                             input=input_dict
+                         )
+
+                         # some versions return a list of URLs; take the first one
+                         if isinstance(output, list) and len(output) > 0:
+                             output = output[0]
+
+                         if not isinstance(output, str):
+                             st.error(f"Unexpected output format: {output}")
+                         else:
+                             with st.spinner('Waiting for image to be ready...'):
+                                 if wait_for_image(output):
+                                     image = display_image(output)
+                                     if image:
+                                         # build a filename from the timestamp and a sanitized, shortened prompt
+                                         timestamp = int(time.time())
+                                         clean_prompt = re.sub(r'[^a-zA-Z0-9 ]', '', input_prompt)
+                                         clean_prompt = clean_prompt.strip()[:30]
+                                         clean_prompt = clean_prompt.replace(' ', '_')
+                                         filename = f"{timestamp}_{clean_prompt}.png"
+
+                                         st.markdown(get_image_download_link(image, filename, 'Download Image'), unsafe_allow_html=True)
+                                 else:
+                                     st.error("Timed out waiting for image to be ready.")
+
+                     except Exception as e:
+                         st.error(f"Error generating image: {str(e)}")
+
+             else:
+                 st.warning("Please enter a prompt.")
+
+     # Margin column (empty for spacing)
+     with margin_col:
+         st.empty()
+
+     # Right column contents
+     with right_column:
+         st.subheader("Prompt History")
+
+         prompt_history_container = st.container()
+
+         with prompt_history_container:
+             for i, prompt in enumerate(st.session_state.prompt_history):
+                 st.text(f"{i+1}. {prompt}")
+
+         st.markdown("""
+             <style>
+             .stContainer {
+                 max-height: 400px;
+                 overflow-y: auto;
+             }
+             </style>
+         """, unsafe_allow_html=True)
+ except Exception as ex:
+     st.error(f"Something errored out {ex}")
requirements.txt ADDED
Binary file (1.85 kB).
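
The contents of requirements.txt are not rendered in this view. Judging purely from the imports at the top of app.py, it would need to cover at least the packages listed below; this is an inference from the code, not the committed file, and pinned versions are unknown.

# inferred from the imports in app.py, not the actual committed requirements.txt
streamlit
replicate
requests
python-dotenv
Pillow

With those installed, the app would be launched the usual Streamlit way: streamlit run app.py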