File size: 6,748 Bytes
c4e381e
 
 
3b827fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4e381e
3b827fa
 
 
c4e381e
 
 
 
3b827fa
c4e381e
 
 
3b827fa
c4e381e
3b827fa
 
 
 
 
 
 
c4e381e
 
 
 
 
3b827fa
 
c4e381e
 
 
3b827fa
c4e381e
 
 
3b827fa
c4e381e
3b827fa
946021d
f49ac62
c4e381e
f49ac62
 
 
 
 
 
 
3b827fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4e381e
 
 
 
3b827fa
0276c23
3b827fa
0276c23
 
 
 
 
 
3b827fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0276c23
3b827fa
 
 
 
c4e381e
0276c23
3b827fa
0276c23
3b827fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4e381e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import streamlit as st
# Streamlit requires set_page_config to be the first Streamlit command the
# script runs, which is why it sits above the remaining imports (presumably in
# case a local module issues st.* calls at import time — confirm).
st.set_page_config(page_title="GPT-4V Demo", page_icon="🧠", layout="wide")

from PIL import Image
import base64
from io import BytesIO
from utils import get_str_to_json
from pass1 import get_gpt4V_response_1
from pass2 import get_gpt4V_response_2
from examples import example_1, example_2

def clear_data():
    """Reset the app by deleting every key from ``st.session_state``.

    Called by the "Clear Inputs" button, as the ``on_change`` callback of the
    entity radio, and before an example is loaded. This discards the keyed
    input widgets' state ("story", "goal", "entity"), any loaded example
    ("data") and the stored pass results ("output_1", "button_1",
    "output_2", "button_2").
    """
    # Iterate over a snapshot of the keys: removing entries while iterating a
    # live key view is unsafe for mappings in general.
    for key in list(st.session_state.keys()):
        st.session_state.pop(key, None)

# Sidebar: generation parameters, entity mode and example loaders. The
# module-level `temperature` and `entity_opt` defined here are read by main().
with st.sidebar:
    if st.button("Clear Inputs"):
        clear_data()

    st.title("Parameters")
    st.write("This is a demo of GPT-4V model. It takes a story, goal, entity and an image as input and generates a response.")

    st.subheader("Sampling Temperature")
    temperature = st.slider(label="x", min_value=0.1, max_value=1.0, value=0.5, step=0.1, label_visibility='hidden')
    st.write("The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")

    st.subheader("Entity?")
    # entity_opt is 1 ("With") when the user types the entity, 0 ("Without")
    # when no entity is entered and pass 1 is expected to return one.
    # Switching modes clears all inputs.
    entity_opt = st.radio(label="With or Without", options=[1, 0], format_func=lambda x: ["Without", "With"][x], on_change=clear_data)

    st.subheader("Examples")
    cols = st.columns(2)
    for i, example in enumerate([example_1, example_2]):
        with cols[i % len(cols)]:
            if st.button(f"Example {i+1}", key=f"example{i+1}"):
                # Start from a clean state, then load the example; main()
                # reads it back from st.session_state["data"].
                clear_data()
                st.session_state["data"] = example

def _parse_pass1_response(response, entity_opt):
    """Convert the raw pass-1 model reply into the dict stored as ``output_1``.

    Args:
        response: Raw reply returned by ``get_gpt4V_response_1``.
        entity_opt: Truthy when the user typed the entity themselves; when
            falsy, the model's JSON must also contain an "entity" field.

    Returns:
        A dict with "entity", "condition", "alternate_condition" and
        "response" keys. On success the first three come from the model's
        JSON and "response" is "". If the reply cannot be parsed into a JSON
        object with the required fields, the first three are "" and
        "response" holds the raw reply so the UI can show it instead.
    """
    try:
        response_json = get_str_to_json(response)
        # Guard against non-dict results: `in` on a str is a substring test,
        # and the lookups below would then fail outside this try.
        if not isinstance(response_json, dict):
            raise ValueError("Invalid JSON - not an object")
        if "condition" not in response_json or "alternate_condition" not in response_json:
            raise ValueError("Invalid JSON - 1")
        if not entity_opt and "entity" not in response_json:
            raise ValueError("Invalid JSON - 2")
    except Exception as e:  # any parse/validation failure -> fall back to raw text
        print("Exception 1", e)
        return {
            "entity": "",
            "condition": "",
            "alternate_condition": "",
            "response": response,
        }

    return {
        "entity": response_json.get("entity", ""),
        "condition": response_json["condition"],
        "alternate_condition": response_json["alternate_condition"],
        # Empty on success, so the "Failed to parse JSON" warning is shown
        # only when parsing actually failed.
        "response": "",
    }


def _parse_pass2_response(response):
    """Convert the raw pass-2 model reply into the dict stored as ``output_2``.

    Returns:
        ``(output, parsed)``. ``output["event"]`` is the "event" field of the
        model's JSON object. If the reply cannot be parsed into a JSON object
        containing "event", it is the raw reply instead and ``parsed`` is False.
    """
    try:
        response_json = get_str_to_json(response)
        if not isinstance(response_json, dict):
            raise ValueError("Invalid JSON - not an object")
        if "event" not in response_json:
            raise ValueError("Invalid JSON - 3")
    except Exception as e:  # any parse/validation failure -> fall back to raw text
        print("Exception 2", e)
        return {"event": response}, False
    return {"event": response_json["event"]}, True


def main():
    """Render the main page and run the two GPT-4V passes on demand.

    The left column collects the story, optional entity, goal and images,
    falling back to the example stored in ``st.session_state["data"]`` by the
    sidebar. "Pass 1" asks the model for a condition / alternate condition.
    "Pass 2" asks for an event leading to the alternate condition. Results
    are kept in ``st.session_state`` ("output_1", "button_1", "output_2",
    "button_2") so they stay visible across reruns. Reads the module-level
    ``temperature`` and ``entity_opt`` set in the sidebar.
    """
    st.title('What can go wrong?')

    data = st.session_state.get("data", None)

    col1, col2 = st.columns(2)

    with col1:
        story = st.text_area("Story", placeholder="Enter the story here", value=(data.story if data else ""), key="story")

        entity = None
        if entity_opt:
            entity = st.text_input("Entity", placeholder="Enter the entity here", value=(data.entity if data else ""), key="entity")

        goal = st.text_area("Goal", placeholder="Enter the goal here", value=(data.goal if data else ""), key="goal")

        images = st.file_uploader("Upload Image", type=['jpg', 'png'], accept_multiple_files=True)

        # Images sent to BOTH passes: uploads take precedence over the
        # example's base64-encoded images.
        image_to_send = images
        if images:
            for col, uploaded in zip(st.columns(len(images)), images):
                with col:
                    st.image(Image.open(uploaded), caption="Uploaded Image", use_column_width=True)
                    # Decoding advances the upload's read position; rewind it
                    # so the model helpers read the file from the start.
                    uploaded.seek(0)
        elif data and data.images_base64:
            image_to_send = data.images_base64
            # Size the columns from the same list that is iterated.
            for col, imb64 in zip(st.columns(len(data.images_base64)), data.images_base64):
                with col:
                    image = Image.open(BytesIO(base64.b64decode(imb64)))
                    st.image(image, caption="Example Image", use_column_width=True)

        if st.button("Pass 1"):
            # A new pass 1 hides any earlier pass 2 result.
            st.session_state["button_2"] = False

            if not story or not goal or (entity_opt and not entity) or not image_to_send:
                st.error("Please fill all the fields")
                return

            with col2:
                with st.status("Generating response...", expanded=True):
                    response = get_gpt4V_response_1(story, goal, entity, image_to_send, temperature=temperature)
                st.session_state["output_1"] = _parse_pass1_response(response, entity_opt)
                st.session_state["button_1"] = True

    with col2:
        if st.session_state.get("button_1", False):  # pass 1 has completed
            output_1 = st.session_state.get("output_1", {})

            # A non-empty "response" means pass 1 could not be parsed.
            if output_1.get("response"):
                st.warning("Failed to parse JSON. Going for full output")
                st.write(output_1["response"])

            if not entity_opt:
                st.text_input("Entity", value=output_1.get("entity") or "")
            st.text_area("Condition", value=output_1.get("condition") or "")
            # The field is editable, so pass 2 uses the (possibly user-edited)
            # value returned by the widget rather than the stored pass-1 value.
            alternate_condition = st.text_area("Alternate Condition", value=output_1.get("alternate_condition") or "")

            if st.button("Pass 2"):
                st.session_state["button_2"] = True

                # NOTE(review): assumes get_gpt4V_response_2 accepts the same
                # image inputs as get_gpt4V_response_1 (uploaded files or the
                # example's base64 strings) — confirm.
                with st.status("Generating response...", expanded=True):
                    response = get_gpt4V_response_2(story, goal, alternate_condition, image_to_send, temperature=temperature)

                output_2, parsed = _parse_pass2_response(response)
                if not parsed:
                    st.warning("Failed to parse JSON. Going for full output")
                st.session_state["output_2"] = output_2

        if st.session_state.get("button_2", False):  # pass 2 has completed
            output_2 = st.session_state.get("output_2", {})
            st.subheader("Event Leads to Alternate Condition")
            st.write(output_2.get("event", ""))

# Render the page. main() is called unconditionally at module level (no
# __main__ guard), so it runs every time this script is executed.
main()