Spaces:
Running
Running
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt | |
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \ | |
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \ | |
ChatDataSQLAskCallBackHandler | |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain | |
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser | |
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain | |
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever | |
from langchain.utilities.sql_database import SQLDatabase | |
from langchain.chains import LLMChain | |
from sqlalchemy import create_engine, MetaData | |
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \ | |
SystemMessagePromptTemplate, HumanMessagePromptTemplate | |
from langchain.prompts.prompt import PromptTemplate | |
from langchain.chat_models import ChatOpenAI | |
from langchain import OpenAI | |
import re | |
import pandas as pd | |
from os import environ | |
import streamlit as st | |
import datetime | |
from helper import build_all, sel_map, display | |
# Expose the OpenAI API base URL from Streamlit secrets as an environment
# variable (presumably read by the OpenAI-backed LangChain clients — confirm).
environ.update(OPENAI_API_BASE=st.secrets['OPENAI_API_BASE'])

# Browser-tab title and on-page header both carry the app name.
st.set_page_config(page_title="ChatData")
st.header("ChatData")
# Build the per-knowledge-base objects (retrievers, chains, metadata columns
# read below via st.session_state.sel_map_obj[sel][...]) once per session.
# Streamlit re-executes this script on every widget interaction, so the guard
# must test the same key that is written here; the previous check on
# 'retriever' never matched and rebuilt everything on each rerun.
if 'sel_map_obj' not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()
# Knowledge-base picker; the chosen label is used as a key into `sel_map`.
sel = st.selectbox(
    'Choose the knowledge base you want to ask with:',
    options=['ArXiv Papers', 'Wikipedia'],
)
# Render the usage hint registered for the chosen knowledge base.
sel_map[sel]['hint']()

# Two retrieval front-ends, each in its own tab.
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
with tab_sql:
    # Vector SQL tab: `Query` lists documents from the SQL retriever, `Ask`
    # runs the SQL chain and shows its answer plus the source documents.
    sel_map[sel]['hint_sql']()
    st.text_input("Ask a question:", key='query_sql')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_sql')
    cols[1].button("Ask", key='ask_sql')
    # Single slot below the buttons that receives the log/result panel.
    plc_hldr = st.empty()
    if st.session_state.search_sql:
        print(st.session_state.query_sql)
        with plc_hldr.expander('Query Log', expanded=True):
            # Created inside the expander: the handler exposes a
            # `progress_bar`, presumably rendered on construction — confirm.
            callback = ChatDataSQLSearchCallBackHandler()
            try:
                docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                # One row per document: its metadata plus the text as 'abstract'.
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs)
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                # Re-raise unchanged so the error still surfaces.
                raise
    if st.session_state.ask_sql:
        print(st.session_state.query_sql)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSQLAskCallBackHandler()
            try:
                ret = st.session_state.sel_map_obj[sel]["sql_chain"](
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                # The chain result carries its cited documents under 'sources'.
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content}
                     for d in ret['sources']])
                display(
                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                raise
with tab_self_query:
    # Self-query tab: shows the knowledge base's metadata columns, then
    # `Query` lists documents from the retriever and `Ask` runs the chain.
    # The icon must be a real emoji; the previous mojibake string was rejected.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
    st.text_input("Ask a question:", key='query_self')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_self')
    cols[1].button("Ask", key='ask_self')
    # Single slot below the buttons that receives the log/result panel.
    plc_hldr = st.empty()
    if st.session_state.search_self:
        print(st.session_state.query_self)
        with plc_hldr.expander('Query Log', expanded=True):
            # Created inside the expander: the handler exposes a
            # `progress_bar`, presumably rendered on construction — confirm.
            callback = ChatDataSelfSearchCallBackHandler()
            try:
                docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                # One row per document: its metadata plus the text as 'abstract'.
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs, sel_map[sel]["must_have_cols"])
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                # Re-raise unchanged so the error still surfaces.
                raise
    if st.session_state.ask_self:
        print(st.session_state.query_self)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSelfAskCallBackHandler()
            try:
                ret = st.session_state.sel_map_obj[sel]["chain"](
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                # The chain result carries its cited documents under 'sources'.
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content}
                     for d in ret['sources']])
                display(
                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
            except Exception:
                st.write('Oops 😵 Something bad happened...')
                raise