collin commited on
Commit
a3db6fe
1 Parent(s): b582e82
Files changed (1) hide show
  1. app.py +53 -8
app.py CHANGED
@@ -15,7 +15,8 @@ import base64
15
  # *
16
  # *****************************************************
17
 
18
- st.set_page_config(layout="wide", page_title="Flux.1 in Streamlit with Replicate!", page_icon=":frame_with_picture:")
 
19
 
20
  # Load environment variables
21
  load_dotenv()
@@ -52,7 +53,7 @@ try:
52
  return href
53
 
54
  # Streamlit app
55
- st.title("Flux.1.X - Streamlit GUI")
56
 
57
  # Create three columns
58
  left_column, margin_col, right_column = st.columns([6, 1, 5])
@@ -63,7 +64,7 @@ try:
63
 
64
  model_version = st.selectbox(
65
  "Model Version (schnell: fast and cheap, dev: quick and inexpensive, pro: moderate render time, most expensive)",
66
- options=["schnell", "dev", "pro","1.1-pro"],
67
  index=0
68
  )
69
 
@@ -79,6 +80,7 @@ try:
79
  safety_checker = None
80
  interval = None
81
  safety_tolerance = None
 
82
 
83
  if model_version == "dev":
84
  guidance = st.slider(
@@ -125,7 +127,7 @@ try:
125
  step=1
126
  )
127
 
128
- if not model_version.startswith("pro") and not model_version.startswith("1"):
129
  safety_checker = "On"
130
  # safety_checker = st.radio(
131
  # "Safety Checker - Turn on model NSFW checking",
@@ -133,6 +135,38 @@ try:
133
  # index=1,
134
  # format_func=lambda x: "Disabled" if x == "On" else "Enabled"
135
  # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  seed = st.number_input("Seed (optional)", min_value=0, max_value=2**32-1, step=1, value=None, key="seed")
138
 
@@ -151,9 +185,7 @@ try:
151
  if replicate_key is None or replicate_key == "":
152
  st.warning("You must provide a replicate auth token key for this to work.")
153
  st.stop()
154
-
155
-
156
-
157
 
158
  if input_prompt:
159
  st.session_state.prompt_history.insert(0, input_prompt)
@@ -182,13 +214,26 @@ try:
182
  if safety_tolerance is not None:
183
  input_dict["safety_tolerance"] = safety_tolerance
184
 
 
 
 
185
  # Run the model with the prepared input
186
  try:
187
 
188
  client = replicate.Client(api_token=replicate_key)
189
 
 
 
 
 
 
 
 
 
 
 
190
  output = client.run(
191
- f"black-forest-labs/flux-{model_version}",
192
  input=input_dict
193
  )
194
 
 
15
  # *
16
  # *****************************************************
17
 
18
+ st.set_page_config(layout="wide", page_title="Flux.1.X and SD 3.5 in Streamlit with Replicate!", page_icon=":frame_with_picture:")
19
+
20
 
21
  # Load environment variables
22
  load_dotenv()
 
53
  return href
54
 
55
  # Streamlit app
56
+ st.title("Flux.1.X / SD 3.5 Turbo - Streamlit GUI")
57
 
58
  # Create three columns
59
  left_column, margin_col, right_column = st.columns([6, 1, 5])
 
64
 
65
  model_version = st.selectbox(
66
  "Model Version (schnell: fast and cheap, dev: quick and inexpensive, pro: moderate render time, most expensive)",
67
+ options=["schnell", "dev", "pro","1.1-pro", "SD 3.5 Large Turbo", "SD 3.5 Large"],
68
  index=0
69
  )
70
 
 
80
  safety_checker = None
81
  interval = None
82
  safety_tolerance = None
83
+ cfg = None # sd models
84
 
85
  if model_version == "dev":
86
  guidance = st.slider(
 
127
  step=1
128
  )
129
 
130
+ if not model_version.startswith("pro") and not model_version.startswith("1") and not model_version.startswith("SD"):
131
  safety_checker = "On"
132
  # safety_checker = st.radio(
133
  # "Safety Checker - Turn on model NSFW checking",
 
135
  # index=1,
136
  # format_func=lambda x: "Disabled" if x == "On" else "Enabled"
137
  # )
138
+
139
+ if model_version == "SD 3.5 Large Turbo":
140
+ cfg = st.slider(
141
+ "CFG - Similarity to prompt, 0-20, default is 0 ",
142
+ min_value=0,
143
+ max_value=20,
144
+ value=0,
145
+ step=1
146
+ )
147
+ steps = st.slider(
148
+ "Steps - Quality/Detail of render, 1-10, default 4.",
149
+ min_value=1,
150
+ max_value=10,
151
+ value=4,
152
+ step=1
153
+ )
154
+
155
+ if model_version == "SD 3.5 Large":
156
+ cfg = st.slider(
157
+ "CFG - Similarity to prompt, 0-20, default is 3.5 ",
158
+ min_value=0.0,
159
+ max_value=20.0,
160
+ value=3.5,
161
+ step=.5
162
+ )
163
+ steps = st.slider(
164
+ "Steps - Quality/Detail of render, 1-50, default 35.",
165
+ min_value=1,
166
+ max_value=50,
167
+ value=35,
168
+ step=1
169
+ )
170
 
171
  seed = st.number_input("Seed (optional)", min_value=0, max_value=2**32-1, step=1, value=None, key="seed")
172
 
 
185
  if replicate_key is None or replicate_key == "":
186
  st.warning("You must provide a replicate auth token key for this to work.")
187
  st.stop()
188
+
 
 
189
 
190
  if input_prompt:
191
  st.session_state.prompt_history.insert(0, input_prompt)
 
214
  if safety_tolerance is not None:
215
  input_dict["safety_tolerance"] = safety_tolerance
216
 
217
+ if cfg is not None:
218
+ input_dict["cfg"] = cfg
219
+
220
  # Run the model with the prepared input
221
  try:
222
 
223
  client = replicate.Client(api_token=replicate_key)
224
 
225
+ # refactor if this model list gets any bigger
226
+ api_end_point = None
227
+
228
+ if model_version=="SD 3.5 Large Turbo":
229
+ api_end_point = "stability-ai/stable-diffusion-3.5-large-turbo"
230
+ elif model_version=="SD 3.5 Large":
231
+ api_end_point = "stability-ai/stable-diffusion-3.5-large"
232
+ else:
233
+ api_end_point = f"black-forest-labs/flux-{model_version}"
234
+
235
  output = client.run(
236
+ api_end_point,
237
  input=input_dict
238
  )
239