Spaces:
Sleeping
Sleeping
File size: 6,748 Bytes
c4e381e 3b827fa c4e381e 3b827fa c4e381e 3b827fa c4e381e 3b827fa c4e381e 3b827fa c4e381e 3b827fa c4e381e 3b827fa c4e381e 3b827fa c4e381e 3b827fa 946021d f49ac62 c4e381e f49ac62 3b827fa c4e381e 3b827fa 0276c23 3b827fa 0276c23 3b827fa 0276c23 3b827fa c4e381e 0276c23 3b827fa 0276c23 3b827fa c4e381e |
|
import streamlit as st
st.set_page_config(page_title="GPT-4V Demo", page_icon="🧠", layout="wide")
from PIL import Image
import base64
from io import BytesIO
from utils import get_str_to_json
from pass1 import get_gpt4V_response_1
from pass2 import get_gpt4V_response_2
from examples import example_1, example_2
def clear_data():
    """Remove every entry from ``st.session_state``.

    Called by the "Clear Inputs" button, as the entity radio's ``on_change``
    callback, and before an example is loaded. This drops the keyed input
    widgets' state ("story", "entity", "goal"), the loaded example ("data")
    and the stored pass results/flags ("output_1", "button_1", "output_2",
    "button_2").

    NOTE: files in the uploader are not cleared — that widget is created
    without a ``key``, so its state is not addressable here.
    """
    # Iterate over a snapshot of the keys so deleting entries cannot disturb
    # the iteration. Deleting outright (instead of first assigning "" to the
    # widget keys, which were then popped anyway) also avoids writing to a
    # widget key after that widget has been instantiated in the current run.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
# Sidebar controls. `temperature` and `entity_opt` are bound at module level
# here and read by main().
with st.sidebar:
    if st.button("Clear Inputs"):
        clear_data()

    st.title("Parameters")
    st.write(
        "This is a demo of GPT-4V model. It takes a story, goal, entity and an image as input and generates a response."
    )

    st.subheader("Sampling Temperature")
    temperature = st.slider(
        label="x",  # placeholder label, hidden; the subheader above names the control
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.1,
        label_visibility='hidden',
    )
    st.write(
        "The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic."
    )

    st.subheader("Entity?")
    # 1 ("With"): the user types the entity; 0 ("Without"): the Pass 1 reply
    # must contain one. Switching modes wipes session state via clear_data.
    entity_opt = st.radio(
        label="With or Without",
        options=[1, 0],
        format_func=lambda flag: "With" if flag else "Without",
        on_change=clear_data,
    )

    st.subheader("Examples")
    cols = st.columns(2)
    for i, example in enumerate((example_1, example_2)):
        with cols[i % len(cols)]:
            # Loading an example wipes state first, then stores the example
            # so main() can prefill the inputs from it.
            if st.button(f"Example {i+1}", key=f"example{i+1}"):
                clear_data()
                st.session_state["data"] = example
def _show_images(images, data):
    """Render the story's images and return the ones to send to the model.

    Uploaded files take priority; otherwise the loaded example's
    ``images_base64`` entries are decoded and displayed.

    Args:
        images: list returned by the multi-file uploader (may be empty).
        data: the example stored in ``st.session_state["data"]``, or None.

    Returns:
        The uploaded files, the example's base64 strings, or ``[]`` when
        there is nothing to show.
    """
    if images:
        for column, uploaded in zip(st.columns(len(images)), images):
            with column:
                st.image(Image.open(uploaded), caption="Uploaded Image", use_column_width=True)
                # Rewind after PIL has consumed the stream, in case the model
                # helpers read it with .read() — NOTE(review): confirm how
                # pass1/pass2 consume uploaded files.
                uploaded.seek(0)
        return images
    encoded = data.images_base64 if data else None
    if encoded:
        # Size the columns from the list actually iterated (previously
        # len(data.images)); the guard above also skips st.columns(0).
        for column, imb64 in zip(st.columns(len(encoded)), encoded):
            with column:
                image = Image.open(BytesIO(base64.b64decode(imb64)))
                st.image(image, caption="Example Image", use_column_width=True)
        return encoded
    return []


def _try_parse(response, tag):
    """Return ``get_str_to_json(response)``, or None if it raises (logged with *tag*)."""
    try:
        return get_str_to_json(response)
    except Exception as exc:  # model replies are free-form; parse failures are expected
        print(tag, exc)
        return None


def _pass_1_output(parsed, raw_response, entity_required):
    """Validate a parsed Pass 1 reply and normalise it into the stored shape.

    Args:
        parsed: the value produced by ``get_str_to_json`` (None if parsing failed).
        raw_response: the unparsed model reply, kept when validation fails.
        entity_required: True when the user gave no entity, so the reply
            must contain an "entity" key.

    Returns:
        A dict with "entity", "condition", "alternate_condition" and
        "response". On validation failure the first three are "" and
        "response" holds *raw_response*; otherwise "response" defaults to "".
    """
    required = ("condition", "alternate_condition") + (("entity",) if entity_required else ())
    # isinstance guard: a non-dict (e.g. a str) could satisfy the `in` checks
    # by substring and then fail on .get().
    if not isinstance(parsed, dict) or any(name not in parsed for name in required):
        print("Exception 1", "Invalid JSON")
        return {"entity": "", "condition": "", "alternate_condition": "", "response": raw_response}
    return {
        "entity": parsed.get("entity", None),
        "condition": parsed.get("condition", None),
        "alternate_condition": parsed.get("alternate_condition", None),
        "response": parsed.get("response", ""),
    }


def _pass_2_output(parsed, raw_response):
    """Normalise a Pass 2 reply into ``{"event": ...}``.

    Returns:
        ``(output, ok)``: when *parsed* is a dict with an "event" key, output
        carries that value and ok is True; otherwise output carries
        *raw_response* and ok is False.
    """
    if isinstance(parsed, dict) and "event" in parsed:
        return {"event": parsed["event"]}, True
    return {"event": raw_response}, False


def main():
    """Render the page and drive the two model passes.

    The left column holds the inputs (story, optional entity, goal, images).
    "Pass 1" validates them, calls get_gpt4V_response_1 and stores the
    normalised reply in ``st.session_state["output_1"]`` (flag "button_1").
    The right column then shows that result and offers "Pass 2", which calls
    get_gpt4V_response_2 and stores ``st.session_state["output_2"]``
    (flag "button_2").

    Reads the sidebar globals ``temperature`` and ``entity_opt``.
    """
    st.title('What can go wrong?')
    data = st.session_state.get("data", None)
    col1, col2 = st.columns(2)

    with col1:
        story = st.text_area("Story", placeholder="Enter the story here", value=(data.story if data else ""), key="story")
        entity = None
        if entity_opt:
            entity = st.text_input("Entity", placeholder="Enter the entity here", value=(data.entity if data else ""), key="entity")
        goal = st.text_area("Goal", placeholder="Enter the goal here", value=(data.goal if data else ""), key="goal")
        images = st.file_uploader("Upload Image", type=['jpg', 'png'], accept_multiple_files=True)
        # Resolved on every run (not only on the Pass 1 click) so Pass 2,
        # which happens on a later rerun, sends the same images as Pass 1.
        image_to_send = _show_images(images, data)

        if st.button("Pass 1"):
            st.session_state["button_2"] = False  # hide any stale Pass 2 result
            if not story or not goal or (entity_opt and not entity) or not image_to_send:
                st.error("Please fill all the fields")
                return
            with col2:
                with st.status("Generating response...", expanded=True):
                    response = get_gpt4V_response_1(story, goal, entity, image_to_send, temperature=temperature)
                    parsed = _try_parse(response, "Exception 1")
                    st.session_state["output_1"] = _pass_1_output(parsed, response, entity_required=not entity_opt)
                    st.session_state["button_1"] = True

    with col2:
        if not st.session_state.get("button_1", False):  # Pass 1 not done yet
            return
        output_1 = st.session_state.get("output_1", {})
        if output_1.get("response"):
            st.warning("Failed to parse JSON. Going for full output")
            st.write(output_1["response"])
        # Stored values may be None (key missing or null in the reply);
        # the text widgets expect strings.
        model_entity = output_1.get("entity") or ""
        condition = output_1.get("condition") or ""
        alternate_condition = output_1.get("alternate_condition") or ""
        if not entity_opt:
            st.text_input("Entity", value=model_entity)
        st.text_area("Condition", value=condition)
        # Pass 2 uses the text as shown in the widget, including user edits.
        alternate_condition = st.text_area("Alternate Condition", value=alternate_condition)

        if st.button("Pass 2"):
            st.session_state["button_2"] = True
            with st.status("Generating response...", expanded=True):
                # Previously the raw uploader list was sent here, so runs that
                # started from an example sent no images. NOTE(review): assumes
                # get_gpt4V_response_2 accepts the same image forms passed to
                # get_gpt4V_response_1 (uploaded files or base64 strings).
                response = get_gpt4V_response_2(story, goal, alternate_condition, image_to_send, temperature=temperature)
                output_2, parsed_ok = _pass_2_output(_try_parse(response, "Exception 2"), response)
                if not parsed_ok:
                    st.warning("Failed to parse JSON. Going for full output")
                st.session_state["output_2"] = output_2

        if st.session_state.get("button_2", False):  # Pass 2 done
            output_2 = st.session_state.get("output_2", {})
            st.subheader("Event Leads to Alternate Condition")
            st.write(output_2.get("event", ""))
# Entry point: Streamlit runs this script top to bottom on every rerun.
main()