Update app.py
Browse files
@@ -56,6 +56,373 @@ def query(payload):
56 |
def get_output(prompt):
    """Pass *prompt* to ``query`` as an ``{"inputs": ...}`` payload.

    Returns whatever ``query`` returns; ``query`` is defined earlier in the
    file, outside this view.
    """
    return query({"inputs": prompt})
58 |
59 |
def main():
60 |
st.title("Medical Llama Test Bench with Inference Endpoints Llama 7B")
61 |
prompt = f"Write instructions to teach anyone to write a discharge plan. List the entities, features and relationships to CCDA and FHIR objects in boldface."
@@ -64,5 +431,168 @@ def main():
64 |
if st.button("Run Prompt With Dr Llama"):
65 |
66 |
67 |
if __name__ == "__main__":
68 |
56 |
def get_output(prompt):
    """Pass *prompt* to ``query`` as an ``{"inputs": ...}`` payload.

    Returns whatever ``query`` returns; ``query`` is defined earlier in the
    file, outside this view.
    """
    return query({"inputs": prompt})
58 |
59 |
60 |
61 |
62 |
63 |
import streamlit as st
64 |
import openai
65 |
import os
66 |
import base64
67 |
import glob
68 |
import json
69 |
import mistune
70 |
import pytz
71 |
import math
72 |
import requests
73 |
import time
74 |
import re
75 |
import textract
76 |
import zipfile # New import for zipping files
77 |
78 |
79 |
from datetime import datetime
80 |
from openai import ChatCompletion
81 |
from xml.etree import ElementTree as ET
82 |
from bs4 import BeautifulSoup
83 |
from collections import deque
84 |
from audio_recorder_streamlit import audio_recorder
85 |
from dotenv import load_dotenv
86 |
from PyPDF2 import PdfReader
87 |
from langchain.text_splitter import CharacterTextSplitter
88 |
from langchain.embeddings import OpenAIEmbeddings
89 |
from langchain.vectorstores import FAISS
90 |
from langchain.chat_models import ChatOpenAI
91 |
from langchain.memory import ConversationBufferMemory
92 |
from langchain.chains import ConversationalRetrievalChain
93 |
from templates import css, bot_template, user_template
94 |
95 |
# Page config and the sidebar "Save" checkbox are declared up front, at module level, so every function below can read them as globals (e.g. should_save).
96 |
st.set_page_config(page_title="GPT Streamlit Document Reasoner", layout="wide")
97 |
should_save = st.sidebar.checkbox("๐พ Save", value=True)
98 |
99 |
def generate_filename_old(prompt, file_type):
    """Legacy file-name builder: ``MMDD_HHMM_<prompt>.<file_type>``.

    Args:
        prompt: Text the name is derived from. Every non-alphanumeric
            character (whitespace included) is dropped and the remainder is
            capped at 90 characters.
        file_type: Extension to append, without the leading dot.

    Returns:
        The file name as a string. The timestamp uses the US/Central zone.
    """
    central = pytz.timezone('US/Central')
    safe_date_time = datetime.now(central).strftime("%m%d_%H%M")  # MMDD_HHMM
    safe_prompt = "".join(x for x in prompt if x.isalnum())[:90]
    return f"{safe_date_time}_{safe_prompt}.{file_type}"
104 |
105 |
def generate_filename(prompt, file_type):
    """Build a file name of the form ``MMDD_HHMM_<prompt>.<file_type>``.

    Args:
        prompt: Text the name is derived from. Spaces and newlines become
            underscores, any other character that is not alphanumeric or an
            underscore is dropped, and the result is capped at 90 characters.
            ``None`` is treated as an empty prompt.
        file_type: Extension to append, without the leading dot.

    Returns:
        The file name as a string. The timestamp uses the US/Central zone.
    """
    from itertools import islice

    central = pytz.timezone('US/Central')
    safe_date_time = datetime.now(central).strftime("%m%d_%H%M")

    # transcribe_audio passes response.json().get('text'), which can be None.
    prompt = prompt or ""

    # Lazily map/filter so that only as much of a long prompt (callers pass
    # whole documents) is scanned as is needed to collect 90 characters.
    kept = (
        "_" if ch in " \n" else ch
        for ch in prompt
        if ch in " \n" or ch == "_" or ch.isalnum()
    )
    safe_prompt = "".join(islice(kept, 90))
    return f"{safe_date_time}_{safe_prompt}.{file_type}"
111 |
112 |
def transcribe_audio(openai_key, file_path, model):
113 |
OPENAI_API_URL = "https://api.openai.com/v1/audio/transcriptions"
114 |
headers = {
115 |
"Authorization": f"Bearer {openai_key}",
116 |
117 |
with open(file_path, 'rb') as f:
118 |
data = {'file': f}
119 |
response = requests.post(OPENAI_API_URL, headers=headers, files=data, data={'model': model})
120 |
if response.status_code == 200:
121 |
122 |
chatResponse = chat_with_model(response.json().get('text'), '') # *************************************
123 |
transcript = response.json().get('text')
124 |
125 |
126 |
filename = generate_filename(transcript, 'txt')
127 |
#create_file(filename, transcript, chatResponse)
128 |
response = chatResponse
129 |
user_prompt = transcript
130 |
create_file(filename, user_prompt, response, should_save)
131 |
return transcript
132 |
133 |
134 |
st.error("Error in API call.")
135 |
return None
136 |
137 |
def save_and_play_audio(audio_recorder):
    """Record audio, save it to a ``.wav`` file and render a player for it.

    Args:
        audio_recorder: Zero-argument callable that returns the recorded
            bytes, or a falsy value when nothing was recorded.

    Returns:
        The name of the file that was written, or ``None`` when there was no
        recording.
    """
    audio_bytes = audio_recorder()
    if not audio_bytes:
        return None

    filename = generate_filename("Recording", "wav")
    with open(filename, 'wb') as f:
        # NOTE(review): the body of this `with` was missing from the extracted
        # source; writing the recorded bytes is the presumed intent - confirm.
        f.write(audio_bytes)
    st.audio(audio_bytes, format="audio/wav")
    return filename
146 |
147 |
def create_file(filename, prompt, response, should_save=True):
148 |
if not should_save:
149 |
150 |
151 |
# Step 2: Extract base filename without extension
152 |
base_filename, ext = os.path.splitext(filename)
153 |
154 |
# Step 3: Check if the response contains Python code
155 |
has_python_code = bool(re.search(r"```python([\s\S]*?)```", response))
156 |
157 |
# Step 4: Write files based on type
158 |
if ext in ['.txt', '.htm', '.md']:
159 |
# Create Prompt file
160 |
with open(f"{base_filename}-Prompt.txt", 'w') as file:
161 |
162 |
163 |
# Create Response file
164 |
with open(f"{base_filename}-Response.md", 'w') as file:
165 |
166 |
167 |
# Create Code file if Python code is present
168 |
if has_python_code:
169 |
# Extract Python code from the response
170 |
python_code = re.findall(r"```python([\s\S]*?)```", response)[0].strip()
171 |
172 |
with open(f"{base_filename}-Code.py", 'w') as file:
173 |
174 |
175 |
176 |
def create_file_old(filename, prompt, response, should_save=True):
177 |
if not should_save:
178 |
179 |
if filename.endswith(".txt"):
180 |
with open(filename, 'w') as file:
181 |
182 |
elif filename.endswith(".htm"):
183 |
with open(filename, 'w') as file:
184 |
file.write(f"{prompt} {response}")
185 |
elif filename.endswith(".md"):
186 |
with open(filename, 'w') as file:
187 |
188 |
189 |
def truncate_document(document, length):
    """Return at most the first *length* items of *document*.

    Args:
        document: Any sliceable sequence (the callers use strings).
        length: Maximum size of the result. A negative value is treated as
            zero, so it yields an empty result instead of silently slicing
            from the end of the document.

    Returns:
        The leading slice of *document*.
    """
    return document[:max(length, 0)]
191 |
def divide_document(document, max_length):
    """Split *document* into consecutive pieces of at most *max_length* items.

    Args:
        document: Any sliceable sequence (the callers use strings).
        max_length: Size of each piece; must be at least 1.

    Returns:
        A list of slices in order. The last one may be shorter; an empty
        document gives an empty list.

    Raises:
        ValueError: If *max_length* is smaller than 1. Without this check a
            zero step fails with an unrelated ``range()`` message and a
            negative step silently drops the whole document.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length!r}")
    return [
        document[start:start + max_length]
        for start in range(0, len(document), max_length)
    ]
193 |
194 |
def get_table_download_link(file_path):
    """Build an HTML ``<a>`` tag that downloads *file_path* via a data URI.

    The file content is embedded base64-encoded in the ``href``. The MIME type
    is chosen from the file extension, falling back to
    ``application/octet-stream`` for unknown extensions.

    Args:
        file_path: Path of the file to embed.

    Returns:
        The anchor tag as a string; both the link text and the ``download``
        attribute are the file's base name.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    from html import escape

    # Read raw bytes: base64 needs bytes anyway, and it lets non-text files
    # (main() passes the recorded .wav here) be linked instead of failing to
    # decode as text.
    with open(file_path, 'rb') as file:
        data = file.read()
    b64 = base64.b64encode(data).decode()

    file_name = os.path.basename(file_path)
    ext = os.path.splitext(file_name)[1]
    mime_types = {
        '.txt': 'text/plain',
        '.py': 'text/plain',
        '.xlsx': 'text/plain',
        '.csv': 'text/plain',
        '.htm': 'text/html',
        '.md': 'text/markdown',
    }
    mime_type = mime_types.get(ext, 'application/octet-stream')

    # The result is rendered with unsafe_allow_html=True, so keep quotes or
    # angle brackets in a file name from breaking out of the markup.
    safe_name = escape(file_name, quote=True)
    return f'<a href="data:{mime_type};base64,{b64}" target="_blank" download="{safe_name}">{safe_name}</a>'
220 |
221 |
def CompressXML(xml_text):
    """Parse *xml_text* and re-serialize it without ``Comment`` elements.

    Args:
        xml_text: XML document as a string.

    Returns:
        The document serialized back to a unicode string, with every
        non-root element whose tag contains ``'Comment'`` removed.

    Raises:
        xml.etree.ElementTree.ParseError: If *xml_text* is not well-formed.
    """
    root = ET.fromstring(xml_text)

    # ElementTree nodes carry no parent pointer, so build the child -> parent
    # map up front; it is needed to detach a node.
    parent_of = {child: parent for parent in root.iter() for child in parent}

    for elem in list(root.iter()):
        if isinstance(elem.tag, str) and 'Comment' in elem.tag:
            # NOTE(review): the body of this branch was missing from the
            # extracted source; removing the element is the presumed intent.
            parent = parent_of.get(elem)
            if parent is not None:  # the root itself cannot be detached
                parent.remove(elem)
    return ET.tostring(root, encoding='unicode', method="xml")
227 |
228 |
def read_file_content(file, max_length):
    """Return the text of an uploaded file, chosen by its ``type`` attribute.

    Args:
        file: File-like upload exposing ``type`` (a MIME string) plus the
            read methods used per branch (``read`` / ``getvalue``).
        max_length: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        * JSON: ``str()`` of the parsed object.
        * HTML: the document text as extracted by BeautifulSoup.
        * XML: the document after passing through ``CompressXML``.
        * Markdown: the output of the mistune renderer.
        * Plain text: the decoded bytes.
        * Any other type: an empty string.
    """
    mime = file.type

    if mime == "application/json":
        return str(json.load(file))

    if mime in ("text/html", "text/htm"):
        return BeautifulSoup(file, "html.parser").text

    if mime in ("application/xml", "text/xml"):
        root = ET.parse(file).getroot()
        return CompressXML(ET.tostring(root, encoding='unicode'))

    if mime in ("text/markdown", "text/md"):
        render = mistune.create_markdown()
        return render(file.read().decode())

    if mime == "text/plain":
        return file.getvalue().decode()

    # Accepted by the uploader but not handled here (e.g. pdf, xlsx, csv).
    return ""
248 |
249 |
def chat_with_model(prompt, document_section, model_choice='gpt-3.5-turbo'):
250 |
model = model_choice
251 |
conversation = [{'role': 'system', 'content': 'You are a helpful assistant.'}]
252 |
conversation.append({'role': 'user', 'content': prompt})
253 |
if len(document_section)>0:
254 |
conversation.append({'role': 'assistant', 'content': document_section})
255 |
256 |
start_time = time.time()
257 |
report = []
258 |
res_box = st.empty()
259 |
collected_chunks = []
260 |
collected_messages = []
261 |
262 |
for chunk in openai.ChatCompletion.create(
263 |
264 |
265 |
266 |
267 |
268 |
269 |
collected_chunks.append(chunk) # save the event response
270 |
chunk_message = chunk['choices'][0]['delta'] # extract the message
271 |
collected_messages.append(chunk_message) # save the message
272 |
273 |
274 |
275 |
276 |
277 |
if len(content) > 0:
278 |
result = "".join(report).strip()
279 |
#result = result.replace("\n", "")
280 |
281 |
282 |
st.write(' ')
283 |
284 |
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
285 |
st.write("Elapsed time:")
286 |
st.write(time.time() - start_time)
287 |
return full_reply_content
288 |
289 |
def chat_with_file_contents(prompt, file_content, model_choice='gpt-3.5-turbo'):
    """Send *prompt* (and optional file text) to the chat model in one call.

    Args:
        prompt: The user's message.
        file_content: Extra text appended to the conversation; skipped when
            empty or ``None``.
        model_choice: Model name passed to the API.

    Returns:
        The ``content`` of the first choice in the API response.
    """
    conversation = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt},
    ]
    if file_content:
        # NOTE(review): the file text is sent with the 'assistant' role, as in
        # the original code - confirm this is the intended role.
        conversation.append({'role': 'assistant', 'content': file_content})
    response = openai.ChatCompletion.create(model=model_choice, messages=conversation)
    return response['choices'][0]['message']['content']
296 |
297 |
def extract_mime_type(file):
    """Return the MIME type of an upload, or of its string representation.

    Args:
        file: Either a string containing ``type='...'`` (such as the repr of
            an uploaded file) or an object exposing a string ``type``
            attribute.

    Returns:
        The MIME type string.

    Raises:
        ValueError: If *file* is a string with no ``type='...'`` part.
        TypeError: If *file* is neither a string nor an object with a string
            ``type`` attribute.
    """
    if isinstance(file, str):
        match = re.search(r"type='(.*?)'", file)
        if match:
            return match.group(1)
        raise ValueError(f"Unable to extract MIME type from {file}")

    # Duck-typed on purpose: the module is imported as `st`, so the previous
    # isinstance check against `streamlit.UploadedFile` raised NameError.
    mime_type = getattr(file, "type", None)
    if isinstance(mime_type, str):
        return mime_type
    raise TypeError("Input should be a string or a streamlit.UploadedFile object")
311 |
312 |
from io import BytesIO
313 |
import re
314 |
315 |
def extract_file_extension(file):
    """Return the extension of an uploaded file, without the dot.

    Args:
        file: Object exposing the file name as ``name``.

    Returns:
        Everything after the LAST dot of the name, unchanged in case
        (``"my.report.PDF"`` gives ``"PDF"``).

    Raises:
        ValueError: If the name contains no dot.
    """
    file_name = file.name
    # Split on the last dot. The previous lazy regex captured everything after
    # the first dot, so "my.report.pdf" came back as "report.pdf" and was not
    # recognized by callers comparing against "pdf".
    _, dot, extension = file_name.rpartition(".")
    if dot:
        return extension
    raise ValueError(f"Unable to extract file extension from {file_name}")
324 |
325 |
def pdf2txt(docs):
    """Concatenate the text of every uploaded document in *docs*.

    Text-like files (py, txt, html, htm, xml, json) are decoded as UTF-8 and
    PDFs are read page by page with PyPDF2. Files with any other extension
    contribute nothing. A failure while reading one file is reported in the
    page and does not stop the remaining files.

    Args:
        docs: Iterable of uploaded files exposing ``name`` and ``getvalue()``.

    Returns:
        All extracted text joined into a single string.

    Raises:
        ValueError: From ``extract_file_extension`` when a file name has no
            extension (this happens outside the per-file error handling).
    """
    parts = []
    for file in docs:
        file_extension = extract_file_extension(file)
        st.write(f"File type extension: {file_extension}")

        kind = file_extension.lower()
        try:
            if kind in ('py', 'txt', 'html', 'htm', 'xml', 'json'):
                parts.append(file.getvalue().decode('utf-8'))
            elif kind == 'pdf':
                from PyPDF2 import PdfReader
                pdf = PdfReader(BytesIO(file.getvalue()))
                for page in pdf.pages:
                    # Fall back to "" so a page that yields no text cannot
                    # abort the rest of the file.
                    parts.append(page.extract_text() or "")
        except Exception as e:
            # Deliberately broad: this is a best-effort import, so report the
            # problem and carry on with the next file.
            st.write(f"Error processing file {file.name}: {e}")

    return "".join(parts)
345 |
346 |
def pdf2txt_old(pdf_docs):
    """Legacy PDF extractor: report each file's MIME type, then read all text.

    Args:
        pdf_docs: Sequence of uploaded PDF files. It is iterated twice, so it
            must not be a one-shot iterator.

    Returns:
        The text of every page of every file, concatenated in order.
    """
    for file in pdf_docs:
        mime_type = extract_mime_type(file)
        st.write(f"MIME type of file: {mime_type}")

    pages_text = []
    for pdf in pdf_docs:
        for page in PdfReader(pdf).pages:
            pages_text.append(page.extract_text())
    return "".join(pages_text)
358 |
359 |
def txt2chunks(text):
    """Split *text* into overlapping chunks with ``CharacterTextSplitter``.

    The splitter is configured with a newline separator, a chunk size of
    1000 and an overlap of 200, both measured with ``len``.

    Returns:
        Whatever ``split_text`` returns for *text*.
    """
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return text_splitter.split_text(text)
362 |
363 |
def vector_store(text_chunks):
    """Build a FAISS store from *text_chunks* using OpenAI embeddings.

    The API key is read from the ``OPENAI_API_KEY`` environment variable and
    handed to ``OpenAIEmbeddings``.

    Returns:
        The object returned by ``FAISS.from_texts``.
    """
    key = os.getenv('OPENAI_API_KEY')
    embeddings = OpenAIEmbeddings(openai_api_key=key)
    return FAISS.from_texts(texts=text_chunks, embedding=embeddings)
367 |
368 |
def get_chain(vectorstore):
    """Create a conversational retrieval chain backed by *vectorstore*.

    The chain uses a default ``ChatOpenAI`` model, the store's retriever and
    a ``ConversationBufferMemory`` kept under the ``chat_history`` key.

    Returns:
        The object returned by ``ConversationalRetrievalChain.from_llm``.
    """
    llm = ChatOpenAI()
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
    )
372 |
373 |
def process_user_input(user_question):
374 |
response = st.session_state.conversation({'question': user_question})
375 |
st.session_state.chat_history = response['chat_history']
376 |
for i, message in enumerate(st.session_state.chat_history):
377 |
template = user_template if i % 2 == 0 else bot_template
378 |
st.write(template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
379 |
# Save file output from PDF query results
380 |
filename = generate_filename(user_question, 'txt')
381 |
#create_file(filename, user_question, message.content)
382 |
response = message.content
383 |
user_prompt = user_question
384 |
create_file(filename, user_prompt, response, should_save)
385 |
#st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
386 |
387 |
def divide_prompt(prompt, max_length):
    """Greedily pack the words of *prompt* into chunks of limited length.

    Words are split on whitespace and re-joined with single spaces. A word is
    added to the current chunk while the joined chunk stays within
    *max_length* characters; otherwise a new chunk is started.

    Args:
        prompt: Text to divide.
        max_length: Maximum number of characters per chunk.

    Returns:
        A list of chunk strings. A single word longer than *max_length* is
        kept whole in a chunk of its own. An empty prompt yields ``['']``.
    """
    chunks = []
    current_chunk = []
    # Length of ' '.join(current_chunk) plus one separator, i.e. the offset
    # at which the next word would start.
    current_length = 0
    for word in prompt.split():
        if len(word) + current_length <= max_length:
            current_chunk.append(word)
            current_length += len(word) + 1
        else:
            # Skip the flush when nothing is pending (oversized first word),
            # otherwise an empty chunk would be emitted.
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            # Count the separator here too; leaving it out let the chunk that
            # follows a rollover grow one character past max_length.
            current_length = len(word) + 1
    chunks.append(' '.join(current_chunk))  # Append the final chunk
    return chunks
402 |
403 |
def create_zip_of_files(files):
    """
    Create a zip file from a list of files.

    Args:
        files: Paths of the files to add.

    Returns:
        The name of the archive, always ``"all_files.zip"`` in the current
        working directory. An existing archive of that name is overwritten.
    """
    zip_name = "all_files.zip"
    archive_path = os.path.abspath(zip_name)
    with zipfile.ZipFile(zip_name, 'w') as zipf:
        for file in files:
            # Never add the archive to itself if it shows up in the listing.
            if os.path.abspath(file) == archive_path:
                continue
            # NOTE(review): the loop body was missing from the extracted
            # source; adding each file is the presumed intent - confirm.
            zipf.write(file)
    return zip_name
412 |
413 |
414 |
def get_zip_download_link(zip_file):
    """
    Generate a link to download the zip file.

    Args:
        zip_file: Path of the archive to embed.

    Returns:
        An HTML ``<a>`` tag labelled "Download All" whose ``href`` is a
        base64 ``application/zip`` data URI of the file content and whose
        ``download`` attribute is *zip_file* as given.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(zip_file, 'rb') as f:
        payload = f.read()
    b64 = base64.b64encode(payload).decode()
    return f'<a href="data:application/zip;base64,{b64}" download="{zip_file}">Download All</a>'
423 |
424 |
425 |
426 |
def main():
427 |
st.title("Medical Llama Test Bench with Inference Endpoints Llama 7B")
428 |
prompt = f"Write instructions to teach anyone to write a discharge plan. List the entities, features and relationships to CCDA and FHIR objects in boldface."
431 |
if st.button("Run Prompt With Dr Llama"):
432 |
433 |
434 |
# clip ---
435 |
436 |
openai.api_key = os.getenv('OPENAI_API_KEY')
437 |
438 |
# File type for output, model choice
439 |
menu = ["txt", "htm", "xlsx", "csv", "md", "py"]
440 |
choice = st.sidebar.selectbox("Output File Type:", menu)
441 |
model_choice = st.sidebar.radio("Select Model:", ('gpt-3.5-turbo', 'gpt-3.5-turbo-0301'))
442 |
443 |
# Audio, transcribe, GPT:
444 |
filename = save_and_play_audio(audio_recorder)
445 |
if filename is not None:
446 |
transcription = transcribe_audio(openai.api_key, filename, "whisper-1")
447 |
st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
448 |
filename = None
449 |
450 |
# prompt interfaces
451 |
user_prompt = st.text_area("Enter prompts, instructions & questions:", '', height=100)
452 |
453 |
# file section interface for prompts against large documents as context
454 |
collength, colupload = st.columns([2,3]) # adjust the ratio as needed
455 |
with collength:
456 |
max_length = st.slider("File section length for large files", min_value=1000, max_value=128000, value=12000, step=1000)
457 |
with colupload:
458 |
uploaded_file = st.file_uploader("Add a file for context:", type=["pdf", "xml", "json", "xlsx", "csv", "html", "htm", "md", "txt"])
459 |
460 |
461 |
# Document section chat
462 |
463 |
document_sections = deque()
464 |
document_responses = {}
465 |
if uploaded_file is not None:
466 |
file_content = read_file_content(uploaded_file, max_length)
467 |
document_sections.extend(divide_document(file_content, max_length))
468 |
if len(document_sections) > 0:
469 |
if st.button("๐๏ธ View Upload"):
470 |
st.markdown("**Sections of the uploaded file:**")
471 |
for i, section in enumerate(list(document_sections)):
472 |
st.markdown(f"**Section {i+1}**\n{section}")
473 |
st.markdown("**Chat with the model:**")
474 |
for i, section in enumerate(list(document_sections)):
475 |
if i in document_responses:
476 |
st.markdown(f"**Section {i+1}**\n{document_responses[i]}")
477 |
478 |
if st.button(f"Chat about Section {i+1}"):
479 |
st.write('Reasoning with your inputs...')
480 |
response = chat_with_model(user_prompt, section, model_choice) # *************************************
481 |
482 |
483 |
document_responses[i] = response
484 |
filename = generate_filename(f"{user_prompt}_section_{i+1}", choice)
485 |
create_file(filename, user_prompt, response, should_save)
486 |
st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
487 |
488 |
if st.button('๐ฌ Chat'):
489 |
st.write('Reasoning with your inputs...')
490 |
491 |
#response = chat_with_model(user_prompt, ''.join(list(document_sections,)), model_choice) # *************************************
492 |
493 |
# Divide the user_prompt into smaller sections
494 |
user_prompt_sections = divide_prompt(user_prompt, max_length)
495 |
full_response = ''
496 |
for prompt_section in user_prompt_sections:
497 |
# Process each section with the model
498 |
response = chat_with_model(prompt_section, ''.join(list(document_sections)), model_choice)
499 |
full_response += response + '\n' # Combine the responses
500 |
501 |
502 |
503 |
504 |
response = full_response
505 |
506 |
507 |
508 |
filename = generate_filename(user_prompt, choice)
509 |
create_file(filename, user_prompt, response, should_save)
510 |
st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
511 |
512 |
all_files = glob.glob("*.*")
513 |
all_files = [file for file in all_files if len(os.path.splitext(file)[0]) >= 20] # exclude files with short names
514 |
all_files.sort(key=lambda x: (os.path.splitext(x)[1], x), reverse=True) # sort by file type and file name in descending order
515 |
516 |
# Added "Delete All" button
517 |
if st.sidebar.button("๐ Delete All"):
518 |
for file in all_files:
519 |
520 |
521 |
522 |
# Added "Download All" button
523 |
if st.sidebar.button("โฌ๏ธ Download All"):
524 |
zip_file = create_zip_of_files(all_files)
525 |
st.sidebar.markdown(get_zip_download_link(zip_file), unsafe_allow_html=True)
526 |
527 |
# Sidebar of Files Saving History and surfacing files as context of prompts and responses
528 |
529 |
530 |
for file in all_files:
531 |
col1, col2, col3, col4, col5 = st.sidebar.columns([1,6,1,1,1]) # adjust the ratio as needed
532 |
with col1:
533 |
if st.button("๐", key="md_"+file): # md emoji button
534 |
with open(file, 'r') as f:
535 |
file_contents = f.read()
536 |
537 |
with col2:
538 |
st.markdown(get_table_download_link(file), unsafe_allow_html=True)
539 |
with col3:
540 |
if st.button("๐", key="open_"+file): # open emoji button
541 |
with open(file, 'r') as f:
542 |
file_contents = f.read()
543 |
544 |
with col4:
545 |
if st.button("๐", key="read_"+file): # search emoji button
546 |
with open(file, 'r') as f:
547 |
file_contents = f.read()
548 |
549 |
with col5:
550 |
if st.button("๐", key="delete_"+file):
551 |
552 |
553 |
554 |
if len(file_contents) > 0:
555 |
if next_action=='open':
556 |
file_content_area = st.text_area("File Contents:", file_contents, height=500)
557 |
if next_action=='md':
558 |
559 |
if next_action=='search':
560 |
file_content_area = st.text_area("File Contents:", file_contents, height=500)
561 |
st.write('Reasoning with your inputs...')
562 |
response = chat_with_model(user_prompt, file_contents, model_choice)
563 |
filename = generate_filename(file_contents, choice)
564 |
create_file(filename, user_prompt, response, should_save)
565 |
566 |
567 |
#st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
568 |
569 |
570 |
571 |
st.write(css, unsafe_allow_html=True)
572 |
573 |
st.header("Chat with documents :books:")
574 |
user_question = st.text_input("Ask a question about your documents:")
575 |
if user_question:
576 |
577 |
578 |
with st.sidebar:
579 |
st.subheader("Your documents")
580 |
docs = st.file_uploader("import documents", accept_multiple_files=True)
581 |
with st.spinner("Processing"):
582 |
raw = pdf2txt(docs)
583 |
if len(raw) > 0:
584 |
length = str(len(raw))
585 |
text_chunks = txt2chunks(raw)
586 |
vectorstore = vector_store(text_chunks)
587 |
st.session_state.conversation = get_chain(vectorstore)
588 |
st.markdown('# AI Search Index of Length:' + length + ' Created.') # add timing
589 |
filename = generate_filename(raw, 'txt')
590 |
create_file(filename, raw, '', should_save)
591 |
#create_file(filename, raw, '')
592 |
593 |
594 |
595 |
596 |
597 |
if __name__ == "__main__":
598 |