John Graham Reynolds committed on
Commit ef9e2e1 · 1 Parent(s): 8bb1747

add main working app code, similar to how we stream for RAG chain

Files changed (1)
  1. app.py +236 -1
app.py CHANGED
@@ -1 +1,236 @@
- # base file
+ import os
+ # import threading
+ import streamlit as st
+ from itertools import tee
+ from model import InferenceBuilder
+ # from chain import ChainBuilder
+
+ # DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
+ # DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
+ # remove these secrets from the container
+ # VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
+ # VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")
+
+ # if DATABRICKS_HOST is None:
+ #     raise ValueError("DATABRICKS_HOST environment variable must be set")
+ # if DATABRICKS_TOKEN is None:
+ #     raise ValueError("DATABRICKS_TOKEN environment variable must be set")
+
+ MODEL_AVATAR_URL = "./iphone_robot.png"
+ MAX_CHAT_TURNS = 10 # limit this for preliminary testing
+ MSG_MAX_TURNS_EXCEEDED = f"Sorry! The CyberSolve LinAlg playground is limited to {MAX_CHAT_TURNS} turns in a single history. Click the 'Clear Chat' button or refresh the page to start a new conversation."
+ # MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
+
+ EXAMPLE_PROMPTS = [
+     "How is a data lake used at Vanderbilt University Medical Center?",
+     "In a table, what are some of the greatest hurdles to healthcare in the United States?",
+     "What does EDW stand for in the context of Vanderbilt University Medical Center?",
+     "Code a sql statement that can query a database named 'VUMC'.",
+     "Write a short story about a country concert in Nashville, Tennessee.",
+     "Tell me about maximum out-of-pocket costs in healthcare.",
+ ]
+
+ TITLE = "CyberSolve LinAlg 1.2"
+ DESCRIPTION = """Welcome to the CyberSolve LinAlg 1.2 demo! \n
+
+ **Overview and Usage**: This 🤗 Space is designed to demo the abilities of the CyberSolve LinAlg 1.2 text-to-text language model,
+ which is augmented with additional organization-specific knowledge. In particular, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
+ terms like **EDW**, **HCERA**, **NRHA**, and **thousands more**. (Ask the assistant if you don't know what any of these terms mean!) On the left is a sidebar of **Examples**;
+ click any of these examples to issue the corresponding query to the AI.
+
+ **Feedback**: Feedback is welcomed, encouraged, and invaluable! To give feedback on one of the model's responses, click the **Give Feedback on Last Response** button just below
+ the user input bar. This lets you provide either positive or negative feedback on the model's most recent response. A **Feedback Form** will appear above the model's title.
+ Please be sure to select either 👍 or 👎 before adding additional notes about your choice. Be as brief or as detailed as you like! Note that you are making a difference; this
+ feedback allows us to later improve the model for your usage through a training technique known as reinforcement learning from human feedback. \n
+
+ Please provide any additional, larger feedback, ideas, or issues via email: **[email protected]**. Happy inference!"""
+
+ GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
+
+ # # To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
+ TOKEN_CHUNK_SIZE = 1 # test this number
+ # if TOKEN_CHUNK_SIZE_ENV is not None:
+ #     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
+
+ QUEUE_SIZE = 20 # maximize this value for adding enough places in the global queue?
+ # if QUEUE_SIZE_ENV is not None:
+ #     QUEUE_SIZE = int(QUEUE_SIZE_ENV)
+
+ # @st.cache_resource
+ # def get_global_semaphore():
+ #     return threading.BoundedSemaphore(QUEUE_SIZE)
+ # global_semaphore = get_global_semaphore()
+
+ st.set_page_config(layout="wide")
+
+ st.title(TITLE)
+ # st.image("sunrise.jpg", caption="Sunrise by the mountains") # TODO add a Vanderbilt related picture to the head of our Space!
+ st.markdown(DESCRIPTION)
+ st.markdown("\n")
+
+ # use this to format later
+ with open("./style.css") as css:
+     st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)
+
+ if "messages" not in st.session_state:
+     st.session_state["messages"] = []
+
+ if "feedback" not in st.session_state:
+     st.session_state["feedback"] = [None]
+
+ def clear_chat_history():
+     st.session_state["messages"] = []
+
+ st.button('Clear Chat', on_click=clear_chat_history)
+
+ # build our model and tokenizer outside the working body so they're only instantiated once - simply pass the chat history for chat completion
+ builder = InferenceBuilder()
+ tokenizer = builder.load_tokenizer()
+ model = builder.load_model()
+
+ def last_role_is_user():
+     return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"
+
+ def get_last_question():
+     return st.session_state["messages"][-1]["content"]
+
+ def text_stream(stream):
+     for chunk in stream:
+         if chunk["content"] is not None:
+             yield chunk["content"]
+
+ def get_stream_warning_error(stream):
+     error = None
+     warning = None
+     for chunk in stream:
+         if chunk["error"] is not None:
+             error = chunk["error"]
+         if chunk["warning"] is not None:
+             warning = chunk["warning"]
+     return warning, error
+
+ # # @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
+ # def chain_call(history):
+ #     input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
+ #     chat_completion = chain.stream(input)
+ #     return chat_completion
+
+ def model_inference(messages):
+     # input_ids = tokenizer(get_last_question(), return_tensors="pt").input_ids.to("cuda") # tokenize the input and put it on the GPU
+     input_ids = tokenizer(get_last_question(), return_tensors="pt").input_ids # move to GPU eventually
+     outputs = model.generate(input_ids)
+     for chunk in tokenizer.decode(outputs[0], skip_special_tokens=True):
+         yield chunk # yield each chunk of the predicted string character by character
+
+ def write_response():
+     stream = chat_completion(st.session_state["messages"])
+     content_stream, error_stream = tee(stream)
+     response = st.write_stream(text_stream(content_stream))
+     stream_warning, stream_error = get_stream_warning_error(error_stream)
+     if stream_warning is not None:
+         st.warning(stream_warning, icon="⚠️")
+     if stream_error is not None:
+         st.error(stream_error, icon="🚨")
+     # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
+     if isinstance(response, list):
+         response = None
+     return response, stream_warning, stream_error
+
+ def chat_completion(messages):
+     if (len(messages)-1)//2 >= MAX_CHAT_TURNS:
+         yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
+         return
+
+     chat_completion = None
+     error = None
+     # *** TODO add code for implementing a global queue with a bounded semaphore?
+     # wait to be in queue
+     # with global_semaphore:
+     #     try:
+     #         chat_completion = chat_api_call(history_dbrx_format)
+     #     except Exception as e:
+     #         error = e
+     # chat_completion = chain_call(history_dbrx_format)
+     chat_completion = model_inference(messages)
+     if error is not None:
+         yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
+         print(error)
+         return
+
+     max_token_warning = None
+     partial_message = ""
+     chunk_counter = 0
+     for chunk in chat_completion:
+         if chunk is not None:
+             chunk_counter += 1
+             partial_message += chunk
+             if chunk_counter % TOKEN_CHUNK_SIZE == 0:
+                 chunk_counter = 0
+                 yield {"content": partial_message, "error": None, "warning": None}
+                 partial_message = ""
+             # if chunk.choices[0].finish_reason == "length":
+             #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
+
+     yield {"content": partial_message, "error": None, "warning": max_token_warning}
+
+ # if assistant is the last message, we need to prompt the user
+ # if user is the last message, we need to retry the assistant.
+ def handle_user_input(user_input):
+     with history:
+         response, stream_warning, stream_error = [None, None, None]
+         if last_role_is_user():
+             # retry the assistant if the user tries to send a new message
+             with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
+                 response, stream_warning, stream_error = write_response()
+         else:
+             st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
+             with st.chat_message("user", avatar="🧑‍💻"):
+                 st.markdown(user_input)
+             # stream = chat_completion(st.session_state["messages"])
+             with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
+                 response, stream_warning, stream_error = write_response()
+
+         st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
+
+ def feedback():
+     with st.form("feedback_form"):
+         st.title("Feedback Form")
+         st.markdown("Please select either 👍 or 👎 before providing a reason for your review of the most recent response. Don't forget to click submit!")
+         rating = st.feedback()
+         feedback = st.text_input("Please detail your feedback: ")
+         # implement a method for writing these responses to storage!
+         submitted = st.form_submit_button("Submit Feedback")
+
+ main = st.container()
+ with main:
+     if st.session_state["feedback"][-1] is not None: # TODO clean this up in a fn?
+         st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
+     history = st.container(height=400)
+     with history:
+         for message in st.session_state["messages"]:
+             avatar = "🧑‍💻"
+             if message["role"] == "assistant":
+                 avatar = MODEL_AVATAR_URL
+             with st.chat_message(message["role"], avatar=avatar):
+                 if message["content"] is not None:
+                     st.markdown(message["content"])
+                 if message["error"] is not None:
+                     st.error(message["error"], icon="🚨")
+                 if message["warning"] is not None:
+                     st.warning(message["warning"], icon="⚠️")
+
+     if prompt := st.chat_input("Type a message!", max_chars=5000):
+         handle_user_input(prompt)
+     st.markdown("\n") # add some space for iphone users
+     gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
+     if gave_feedback: # TODO clean up the conditions here with a function
+         st.session_state["feedback"].append("given")
+     else:
+         st.session_state["feedback"].append(None)
+
+
+ with st.sidebar:
+     with st.container():
+         st.title("Examples")
+         for prompt in EXAMPLE_PROMPTS:
+             st.button(prompt, args=(prompt,), on_click=handle_user_input)
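
Note: model.py, which provides the InferenceBuilder imported at the top of app.py, is not part of this diff. For reference only, a minimal sketch of the interface the app assumes (load_tokenizer() and load_model()), written against Hugging Face transformers; the checkpoint name below is a placeholder, not the actual CyberSolve LinAlg 1.2 weights.

# sketch only - the real model.py is not included in this commit
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

class InferenceBuilder:
    # placeholder checkpoint; swap in the actual CyberSolve LinAlg 1.2 repo id
    CHECKPOINT = "google/flan-t5-base"

    def load_tokenizer(self):
        # tokenizer used by app.py to encode the last user question
        return AutoTokenizer.from_pretrained(self.CHECKPOINT)

    def load_model(self):
        # seq2seq model whose generate() output app.py decodes and streams
        return AutoModelForSeq2SeqLM.from_pretrained(self.CHECKPOINT)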
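
The commit message mentions streaming "similar to how we stream for RAG chain"; as written, model_inference waits for model.generate() to finish and then yields the decoded string one character at a time. If true token-level streaming is wanted later, one possible sketch (not part of this commit) uses transformers' TextIteratorStreamer with the same tokenizer and model objects built above:

# sketch only: token-level streaming via TextIteratorStreamer
from threading import Thread
from transformers import TextIteratorStreamer

def model_inference_streaming(messages):
    input_ids = tokenizer(get_last_question(), return_tensors="pt").input_ids
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    # run generation in a background thread so tokens can be consumed as they arrive
    thread = Thread(target=model.generate, kwargs={"input_ids": input_ids, "streamer": streamer})
    thread.start()
    for text_chunk in streamer:  # decoded text pieces, yielded as they are generated
        yield text_chunk
    thread.join()

With model.py in place alongside this app.py, the Space runs locally with `streamlit run app.py`.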