import streamlit as st
import yaml
import requests
import re
import os

from pdfParser import get_pdf_text

# Get the HuggingFace API key from the environment
api_key_name = "HUGGINGFACE_HUB_TOKEN"
api_key = os.getenv(api_key_name)
if api_key is None:
    st.error(f"Failed to read `{api_key_name}`. Ensure the token is available as an environment variable.")
| with open("config/model_config.yml", "r") as file: | |
| model_config = yaml.safe_load(file) | |
| system_message = model_config["system_message"] | |
| model_id = model_config["model_id"] | |


def query(payload, model_id):
    """Send the payload to the HuggingFace Inference API for `model_id` and return the parsed JSON."""
    headers = {"Authorization": f"Bearer {api_key}"}
    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
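
# For reference, the payload passed to query() below has the shape
#   {"inputs": "<prompt>", "parameters": {"max_new_tokens": ..., "temperature": ...}}
# and a successful text-generation call returns a list such as
#   [{"generated_text": "..."}]
# which is what `response[0]["generated_text"]` further down relies on. Error
# replies (for example while the model is still loading) come back as a plain
# dict instead, which the try/except around the parsing is meant to absorb.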


def prompt_generator(system_message, user_message):
    """Wrap the system and user messages in an instruction-formatted chat prompt."""
    return f"""
<s>[INST] <<SYS>>
{system_message}
<</SYS>>
{user_message} [/INST]
"""

# Pattern to strip the echoed prompt from the API response and keep only the
# model's answer, i.e. everything after the final [/INST] tag
pattern = r".*\[/INST\]([\s\S]*)$"

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Include PDF upload ability
pdf_upload = st.file_uploader(
    "Upload a .PDF here",
    type=".pdf",
)

# pdf_text stays None until a file has been uploaded and parsed
pdf_text = None
if pdf_upload is not None:
    pdf_text = get_pdf_text(pdf_upload)
| if "key_inputs" not in st.session_state: | |
| st.session_state.key_inputs = {} | |
| col1, col2, col3 = st.columns([3, 3, 2]) | |
with col1:
    key_name = st.text_input("Key/Column Name (e.g. patient_name)", key="key_name")
with col2:
    key_description = st.text_area(
        "*(Optional) Description of key/column", key="key_description"
    )
with col3:
    if st.button("Extract this column"):
        if key_description:
            st.session_state.key_inputs[key_name] = key_description
        else:
            st.session_state.key_inputs[key_name] = "No further description provided"

if st.session_state.key_inputs:
    st.write("\nKeys/Columns for extraction:")
    st.write(st.session_state.key_inputs)
    with st.spinner("Extracting requested data"):
        if st.button("Extract data!"):
            user_message = f"""
            Use the text provided and denoted by 3 backticks ```{pdf_text}```.
            Extract the following columns and return a table that could be uploaded to an SQL database.
            {'; '.join([key + ': ' + st.session_state.key_inputs[key] for key in st.session_state.key_inputs])}
            """
            the_prompt = prompt_generator(
                system_message=system_message, user_message=user_message
            )
            response = query(
                {
                    "inputs": the_prompt,
                    "parameters": {"max_new_tokens": 500, "temperature": 0.1},
                },
                model_id,
            )
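            # Generation settings: max_new_tokens caps the length of the reply,
            # and the low temperature (0.1) keeps the extraction nearly
            # deterministic rather than creative.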
            try:
                match = re.search(
                    pattern, response[0]["generated_text"], re.MULTILINE | re.DOTALL
                )
                if match:
                    response = match.group(1).strip()
                    # Note: eval() trusts the model to return a valid Python
                    # literal (e.g. a dict or list of rows); a malformed reply
                    # falls through to the error message below.
                    response = eval(response)
                st.success("Data Extracted Successfully!")
                st.write(response)
            except Exception:
                st.error("Unable to extract data from the model response. Please try again later.")