# Spaces: Sleeping
import streamlit as st | |
st.set_page_config(page_title="GPT-4V Demo", page_icon="🧠", layout="wide") | |
from PIL import Image | |
import base64 | |
from io import BytesIO | |
from utils import get_str_to_json | |
from pass1 import get_gpt4V_response_1 | |
from pass2 import get_gpt4V_response_2 | |
from examples import example_1, example_2 | |
def clear_data():
    """Reset the app by removing every key from ``st.session_state``.

    Called by the sidebar's "Clear Inputs" button, by the example buttons
    (before the chosen example is stored) and as the ``on_change`` callback of
    the "Entity?" radio. Removing the keys resets the keyed input widgets
    (``story``, ``goal``, ``entity``), the selected example (``data``) and the
    stored Pass 1 / Pass 2 results.
    """
    # Iterate over a snapshot of the keys: popping entries while iterating a
    # live keys view of a mapping is unsafe (a plain dict raises RuntimeError).
    for key in list(st.session_state.keys()):
        st.session_state.pop(key)
EXAMPLE_INPUTS = (example_1, example_2)


def _entity_option_label(flag):
    """Return the radio label for an "Entity?" option: 1 -> "With", 0 -> "Without"."""
    return ("Without", "With")[flag]


with st.sidebar:
    # Wipe all session state (inputs, chosen example, stored pass results).
    if st.button("Clear Inputs"):
        clear_data()

    st.title("Parameters")
    st.write("This is a demo of GPT-4V model. It takes a story, goal, entity and an image as input and generates a response.")

    st.subheader("Sampling Temperature")
    # The slider's own label is hidden; the subheader above acts as its caption.
    temperature = st.slider(
        label="x",
        label_visibility='hidden',
        min_value=0.1,
        max_value=1.0,
        step=0.1,
        value=0.5,
    )
    st.write("The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")

    st.subheader("Entity?")
    # 1: the user types the entity; 0: Pass 1's reply must supply it.
    # Switching the option resets all inputs.
    entity_opt = st.radio(
        label="With or Without",
        options=[1, 0],
        on_change=clear_data,
        format_func=_entity_option_label,
    )

    st.subheader("Examples")
    cols = st.columns(2)
    for number, example in enumerate(EXAMPLE_INPUTS, start=1):
        with cols[(number - 1) % len(cols)]:
            if st.button(f"Example {number}", key=f"example{number}"):
                # Start from a clean state, then remember the chosen example.
                clear_data()
                st.session_state["data"] = example
def _show_uploaded_images(uploads):
    """Preview each uploaded image file in its own column.

    Args:
        uploads: Non-empty list returned by ``st.file_uploader``.
    """
    for column, upload in zip(st.columns(len(uploads)), uploads):
        with column:
            # Decode from a copy of the bytes rather than the upload object
            # itself, so the object's read position is not consumed before it
            # is handed to the GPT-4V helpers.
            # NOTE(review): relies on uploads exposing getvalue() (BytesIO-like).
            image = Image.open(BytesIO(upload.getvalue()))
            st.image(image, caption="Uploaded Image", use_column_width=True)


def _show_example_images(images_base64):
    """Preview the selected example's base64-encoded images, one column each."""
    if not images_base64:
        return  # st.columns() rejects 0, and there is nothing to show anyway
    for column, encoded in zip(st.columns(len(images_base64)), images_base64):
        with column:
            image = Image.open(BytesIO(base64.b64decode(encoded)))
            st.image(image, caption="Example Image", use_column_width=True)


def _parse_pass_1(response, entity_given):
    """Convert the raw Pass 1 reply into the dict stored as ``output_1``.

    Args:
        response: Raw reply returned by ``get_gpt4V_response_1``.
        entity_given: Truthy when the user supplied the entity; otherwise the
            reply must contain an ``entity`` field.

    Returns:
        A dict with keys ``entity``, ``condition``, ``alternate_condition``
        and ``response``. ``response`` is non-empty only when the reply could
        not be parsed/validated; it then holds the raw reply so the UI can
        display it instead.
    """
    try:
        parsed = get_str_to_json(response)
        if not isinstance(parsed, dict):
            raise ValueError("Invalid JSON - not an object")
        if "condition" not in parsed or "alternate_condition" not in parsed:
            raise ValueError("Invalid JSON - 1")
        if not entity_given and "entity" not in parsed:
            raise ValueError("Invalid JSON - 2")
    except Exception as e:  # get_str_to_json's failure modes are not visible here
        print("Exception 1", e)
        return {
            "entity": "",
            "condition": "",
            "alternate_condition": "",
            "response": response,
        }
    return {
        "entity": parsed.get("entity"),
        "condition": parsed["condition"],
        "alternate_condition": parsed["alternate_condition"],
        # Empty on success, so the UI does not report a parse failure.
        "response": "",
    }


def _parse_pass_2(response):
    """Convert the raw Pass 2 reply into the dict stored as ``output_2``.

    Returns ``{"event": ...}``. If the reply cannot be parsed into an object
    with an ``event`` field, a warning is shown and the raw reply is used as
    the event text.
    """
    try:
        parsed = get_str_to_json(response)
        if not isinstance(parsed, dict) or "event" not in parsed:
            raise ValueError("Invalid JSON - 3")
    except Exception as e:  # get_str_to_json's failure modes are not visible here
        print("Exception 2", e)
        st.warning("Failed to parse JSON. Going for full output")
        return {"event": response}
    return {"event": parsed["event"]}


def main():
    """Render the two-pass "What can go wrong?" page.

    Left column: collects the story, the entity (only when the sidebar's
    ``entity_opt`` is set), the goal and the images (uploads, or the selected
    example's base64 images). "Pass 1" validates the inputs, calls
    ``get_gpt4V_response_1`` and stores the parsed result in
    ``st.session_state["output_1"]``.

    Right column: once Pass 1 has run, shows its output and lets the user run
    "Pass 2", which calls ``get_gpt4V_response_2`` with the (possibly edited)
    alternate condition and stores the result in
    ``st.session_state["output_2"]``.

    Reads the module-level ``temperature`` and ``entity_opt`` set by the
    sidebar widgets.
    """
    st.title('What can go wrong?')
    data = st.session_state.get("data", None)
    col1, col2 = st.columns(2)

    with col1:
        story = st.text_area("Story", placeholder="Enter the story here", value=(data.story if data else ""), key="story")
        entity = None
        if entity_opt:
            entity = st.text_input("Entity", placeholder="Enter the entity here", value=(data.entity if data else ""), key="entity")
        goal = st.text_area("Goal", placeholder="Enter the goal here", value=(data.goal if data else ""), key="goal")
        images = st.file_uploader("Upload Image", type=['jpg', 'png'], accept_multiple_files=True)
        if images:
            _show_uploaded_images(images)
        elif data:
            _show_example_images(data.images_base64)

        # Uploaded files take precedence over the selected example's images.
        # Both passes must send the same images.
        image_to_send = images or (data.images_base64 if data else None)

        if st.button("Pass 1"):
            st.session_state["button_2"] = False
            if not story or not goal or (entity_opt and not entity) or not image_to_send:
                st.error("Please fill all the fields")
                return
            with col2:
                with st.status("Generating response...", expanded=True):
                    response = get_gpt4V_response_1(story, goal, entity, image_to_send, temperature=temperature)
                    st.session_state["output_1"] = _parse_pass_1(response, entity_opt)
            st.session_state["button_1"] = True

    with col2:
        if not st.session_state.get("button_1", False):
            return  # Pass 1 has not run yet
        output_1 = st.session_state.get("output_1", {})
        if output_1.get("response"):
            st.warning("Failed to parse JSON. Going for full output")
            st.write(output_1["response"])
        pass_1_entity = output_1.get("entity", "")
        condition = output_1.get("condition", "")
        alternate_condition = output_1.get("alternate_condition", "")
        if not entity_opt:
            st.text_input("Entity", value=pass_1_entity)
        st.text_area("Condition", value=condition)
        # The field is editable; Pass 2 uses whatever text it currently holds.
        alternate_condition = st.text_area("Alternate Condition", value=alternate_condition)

        if st.button("Pass 2"):
            st.session_state["button_2"] = True
            with st.status("Generating response...", expanded=True):
                response = get_gpt4V_response_2(story, goal, alternate_condition, image_to_send, temperature=temperature)
                st.session_state["output_2"] = _parse_pass_2(response)

        if st.session_state.get("button_2", False):  # Pass 2 is done
            output_2 = st.session_state.get("output_2", {})
            st.subheader("Event Leads to Alternate Condition")
            st.write(output_2.get("event", ""))
# Render the main page. Called unconditionally at module level, so it runs on
# every execution of this script.
main()