dipta007's picture
added 2 passes
3b827fa
import streamlit as st
st.set_page_config(page_title="GPT-4V Demo", page_icon="🧠", layout="wide")
from PIL import Image
import base64
from io import BytesIO
from utils import get_str_to_json
from pass1 import get_gpt4V_response_1
from pass2 import get_gpt4V_response_2
from examples import example_1, example_2
def clear_data():
    """Reset the app by deleting every key currently in ``st.session_state``.

    Called by the sidebar "Clear Inputs" button, as the entity radio's
    ``on_change`` callback, and before an example is loaded. This removes the
    widget-backed inputs (``story``, ``goal``, ``entity``), the loaded example
    (``data``) and the pass outputs/flags (``output_1``, ``output_2``,
    ``button_1``, ``button_2``).
    """
    # Snapshot the keys first: deleting entries while iterating over a live
    # view of the same mapping is unsafe.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    print(st.session_state)
# Sidebar: page-wide controls. `temperature` and `entity_opt` are read as
# module-level globals by main().
with st.sidebar:
    # Wipe every session-state entry (inputs, loaded example, pass outputs).
    if st.button("Clear Inputs"):
        clear_data()

    st.title("Parameters")
    st.write("This is a demo of GPT-4V model. It takes a story, goal, entity and an image as input and generates a response.")

    st.subheader("Sampling Temperature")
    temperature = st.slider(
        label="x",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.1,
        label_visibility='hidden',
    )
    st.write("The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")

    st.subheader("Entity?")
    # 1 -> the user types the entity; 0 -> the Pass 1 reply must contain it.
    # Switching modes resets all inputs through clear_data.
    entity_opt = st.radio(
        label="With or Without",
        options=[1, 0],
        format_func=lambda flag: ("Without", "With")[flag],
        on_change=clear_data,
    )

    st.subheader("Examples")
    cols = st.columns(2)
    for number, example in enumerate((example_1, example_2), start=1):
        with cols[(number - 1) % len(cols)]:
            if st.button(f"Example {number}", key=f"example{number}"):
                # Start from a clean state, then hand the example to main().
                clear_data()
                st.session_state["data"] = example
def _show_images(images, caption):
    """Render a list of PIL images side by side, one per equal-width column."""
    if not images:
        # Nothing to show; avoid asking st.columns for zero columns.
        return
    for column, image in zip(st.columns(len(images)), images):
        with column:
            st.image(image, caption=caption, use_column_width=True)


def _parse_pass1_response(response, entity_required):
    """Turn the raw Pass 1 reply into the dict stored as ``output_1``.

    Returns a dict with keys ``entity``, ``condition``, ``alternate_condition``
    and ``response``. If the reply cannot be parsed into a dict containing
    ``condition`` and ``alternate_condition`` (and ``entity`` when
    *entity_required* is true), the first three are "" and ``response`` holds
    the raw reply so the UI can display it verbatim.
    """
    try:
        parsed = get_str_to_json(response)
        # The membership tests and .get() calls below need a mapping.
        if not isinstance(parsed, dict):
            raise ValueError("Invalid JSON - not an object")
        if "condition" not in parsed or "alternate_condition" not in parsed:
            raise ValueError("Invalid JSON - 1")
        if entity_required and "entity" not in parsed:
            raise ValueError("Invalid JSON - 2")
    except Exception as e:  # get_str_to_json's failure modes are not visible here
        print("Exception 1", e)
        parsed = {
            "entity": "",
            "condition": "",
            "alternate_condition": "",
            "response": response,
        }
    return {
        "entity": parsed.get("entity", None),
        "condition": parsed.get("condition", None),
        "alternate_condition": parsed.get("alternate_condition", None),
        "response": parsed.get("response", ""),
    }


def _parse_pass2_response(response):
    """Extract the ``event`` text from the raw Pass 2 reply as ``{"event": ...}``.

    Falls back to the raw reply (and shows a warning) when it cannot be
    parsed into a dict containing ``event``.
    """
    try:
        parsed = get_str_to_json(response)
        if not isinstance(parsed, dict):
            raise ValueError("Invalid JSON - not an object")
        if "event" not in parsed:
            raise ValueError("Invalid JSON - 3")
    except Exception as e:  # get_str_to_json's failure modes are not visible here
        print("Exception 2", e)
        st.warning("Failed to parse JSON. Going for full output")
        parsed = {"event": response}
    return {"event": parsed.get("event", response)}


def main():
    """Render the "What can go wrong?" page.

    The left column collects the story, goal, optional entity (shown only when
    the sidebar's ``entity_opt`` is set) and images. When nothing is uploaded,
    the example stored under ``st.session_state["data"]`` supplies the images.
    "Pass 1" sends these to ``get_gpt4V_response_1`` and stores the parsed
    result as ``output_1``. "Pass 2" sends the story, goal, alternate condition
    and the same images to ``get_gpt4V_response_2`` and stores ``output_2``.
    Both passes use the sidebar ``temperature``.
    """
    st.title('What can go wrong?')
    # Example object loaded by a sidebar button (None if none is loaded); it
    # exposes .story, .goal, .entity and .images_base64.
    data = st.session_state.get("data", None)
    col1, col2 = st.columns(2)

    with col1:
        story = st.text_area("Story", placeholder="Enter the story here", value=(data.story if data else ""), key="story")
        entity = None
        if entity_opt:
            entity = st.text_input("Entity", placeholder="Enter the entity here", value=(data.entity if data else ""), key="entity")
        goal = st.text_area("Goal", placeholder="Enter the goal here", value=(data.goal if data else ""), key="goal")
        images = st.file_uploader("Upload Image", type=['jpg', 'png'], accept_multiple_files=True)

        # Uploaded images take priority over the example's images. The same
        # choice is sent to BOTH passes.
        image_to_send = None
        if images:
            # Decode from a copy of the bytes so the uploaded file objects'
            # read position stays untouched for the pass functions.
            _show_images([Image.open(BytesIO(upload.getvalue())) for upload in images], "Uploaded Image")
            image_to_send = images
        elif data:
            # Column count now follows the list that is actually rendered.
            _show_images([Image.open(BytesIO(base64.b64decode(b64))) for b64 in data.images_base64], "Example Image")
            image_to_send = data.images_base64

        if st.button("Pass 1"):
            # A new Pass 1 invalidates any previous Pass 2 result.
            st.session_state["button_2"] = False
            if not story or not goal or (entity_opt and not entity) or not image_to_send:
                st.error("Please fill all the fields")
                return
            with col2:
                with st.status("Generating response...", expanded=True):
                    response = get_gpt4V_response_1(story, goal, entity, image_to_send, temperature=temperature)
                    # When no entity was typed, the reply must supply one.
                    st.session_state["output_1"] = _parse_pass1_response(response, entity_required=not entity_opt)
                    st.session_state["button_1"] = True

    with col2:
        if not st.session_state.get("button_1", False):  # Pass 1 not done yet
            return
        output_1 = st.session_state.get("output_1", {})
        if output_1.get("response"):
            # A non-empty "response" means the Pass 1 JSON could not be parsed.
            st.warning("Failed to parse JSON. Going for full output")
            st.write(output_1["response"])
        predicted_entity = output_1.get("entity", "")
        condition = output_1.get("condition", "")
        alternate_condition = output_1.get("alternate_condition", "")
        if not entity_opt:
            st.text_input("Entity", value=predicted_entity)
        st.text_area("Condition", value=condition)
        # Use the on-screen (possibly user-edited) text for Pass 2.
        alternate_condition = st.text_area("Alternate Condition", value=alternate_condition)

        if st.button("Pass 2"):
            st.session_state["button_2"] = True
            with st.status("Generating response...", expanded=True):
                # Same images as Pass 1: uploads, otherwise the example's
                # base64 images. NOTE(review): assumes get_gpt4V_response_2
                # accepts both forms, as get_gpt4V_response_1 does; confirm.
                response = get_gpt4V_response_2(story, goal, alternate_condition, image_to_send, temperature=temperature)
                st.session_state["output_2"] = _parse_pass2_response(response)

        if st.session_state.get("button_2", False):  # Pass 2 done
            output_2 = st.session_state.get("output_2", {})
            st.subheader("Event Leads to Alternate Condition")
            st.write(output_2.get("event", ""))


main()