Fangrui Liu committed
Commit eb820e1 · Parent(s): 1a24bbc

revised prompt
Files changed (3)
  1. app.py +65 -52
  2. chains/arxiv_chains.py +131 -0
  3. prompts/arxiv_prompt.py +7 -8
app.py CHANGED
@@ -10,15 +10,12 @@ from langchain.vectorstores import MyScale, MyScaleSettings
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.retrievers.self_query.base import SelfQueryRetriever
 from langchain.chains.query_constructor.base import AttributeInfo
-from langchain.chains import RetrievalQAWithSourcesChain
 from langchain import OpenAI
 from langchain.chat_models import ChatOpenAI
 
-from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
-from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
-    ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
-    ChatDataSQLAskCallBackHandler
 from langchain.prompts.prompt import PromptTemplate
+from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
+    SystemMessagePromptTemplate, HumanMessagePromptTemplate
 from sqlalchemy import create_engine, MetaData
 from langchain.chains.sql_database.base import SQLDatabaseChain
 from langchain.chains.sql_database.parser import VectorSQLRetrieveAllOutputParser
@@ -26,12 +23,17 @@ from langchain.chains import LLMChain
 from langchain.sql_database import SQLDatabase
 from langchain.retrievers import SQLDatabaseChainRetriever
 
+from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
+from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
+    ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
+    ChatDataSQLAskCallBackHandler
+from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
 
 st.set_page_config(page_title="ChatData")
 
 st.header("ChatData")
 
-columns = ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
+columns = ['ref_id', 'title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
 
 
 def try_eval(x):
@@ -41,7 +43,9 @@ def try_eval(x):
     return x
 
 
-def display(dataframe, columns=None):
+def display(dataframe, columns=None, index=None):
+    if index:
+        dataframe = dataframe.set_index(index)
     if len(dataframe) > 0:
         if columns:
             st.dataframe(dataframe[columns])
@@ -108,24 +112,35 @@ def build_retriever():
         doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
         use_original_query=False)
 
-    with st.spinner('Building RetrievalQAWithSourcesChain...'):
-        document_with_metadata_prompt = PromptTemplate(
-            input_variables=["page_content", "id", "title",
-                             "authors", "pubdate", "categories"],
-            template="Content:\n\tTitle: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}")
-        COMBINE_PROMPT = PromptTemplate(
-            template=combine_prompt_template, input_variables=["summaries", "question"])
-        chain = RetrievalQAWithSourcesChain.from_chain_type(
-            ChatOpenAI(model_name='gpt-3.5-turbo-16k',
-                       openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.6),
+
+    document_with_metadata_prompt = PromptTemplate(
+        input_variables=["page_content", "id", "title", "ref_id",
+                         "authors", "pubdate", "categories"],
+        template="Title for PDF #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}")
+
+    COMBINE_PROMPT = ChatPromptTemplate.from_strings(
+        string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
+                         (HumanMessagePromptTemplate, '{question}')])
+    OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
+
+    with st.spinner('Building QA Chain with Self-query...'):
+        chain = ArXivQAwithSourcesChain(
             retriever=retriever,
-            chain_type='stuff',
-            chain_type_kwargs={
-                'prompt': COMBINE_PROMPT,
-                'document_prompt': document_with_metadata_prompt,
-            }, return_source_documents=True)
+            combine_documents_chain=ArXivStuffDocumentChain(
+                llm_chain=LLMChain(
+                    prompt=COMBINE_PROMPT,
+                    llm=ChatOpenAI(model_name='gpt-3.5-turbo-16k',
+                                   openai_api_key=OPENAI_API_KEY, temperature=0.6),
+                ),
+                document_prompt=document_with_metadata_prompt,
+                document_variable_name="summaries",
 
-    with st.spinner('Building Vector SQL Database Chain'):
+            ),
+            return_source_documents=True,
+            max_tokens_limit=12000,
+        )
+
+    with st.spinner('Building Vector SQL Database Retriever'):
         MYSCALE_USER = st.secrets['MYSCALE_USER']
         MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
         MYSCALE_HOST = st.secrets['MYSCALE_HOST']
@@ -141,7 +156,7 @@ def build_retriever():
         output_parser = VectorSQLRetrieveAllOutputParser.from_embeddings(
            model=embeddings)
         sql_query_chain = SQLDatabaseChain.from_llm(
-            llm=OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
+            llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
            prompt=PROMPT,
            top_k=10,
            return_direct=True,
@@ -151,15 +166,23 @@ def build_retriever():
         )
         sql_retriever = SQLDatabaseChainRetriever(
            sql_db_chain=sql_query_chain, page_content_key="abstract")
-    sql_chain = RetrievalQAWithSourcesChain.from_chain_type(
-        ChatOpenAI(model_name='gpt-3.5-turbo-16k',
-                   openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.6),
+
+    with st.spinner('Building QA Chain with Vector SQL...'):
+        sql_chain = ArXivQAwithSourcesChain(
             retriever=sql_retriever,
-        chain_type='stuff',
-        chain_type_kwargs={
-            'prompt': COMBINE_PROMPT,
-            'document_prompt': document_with_metadata_prompt,
-        }, return_source_documents=True)
+            combine_documents_chain=ArXivStuffDocumentChain(
+                llm_chain=LLMChain(
+                    prompt=COMBINE_PROMPT,
+                    llm=ChatOpenAI(model_name='gpt-3.5-turbo-16k',
+                                   openai_api_key=OPENAI_API_KEY, temperature=0.6),
+                ),
+                document_prompt=document_with_metadata_prompt,
+                document_variable_name="summaries",
 
+            ),
+            return_source_documents=True,
+            max_tokens_limit=12000,
+        )
 
     return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], retriever, chain, sql_retriever, sql_chain
 
@@ -220,7 +243,7 @@ ENGINE = ReplacingMergeTree ORDER BY id
             display(docs)
         except Exception as e:
             st.write('Oops 😡 Something bad happened...')
-            # raise e
+            raise e
 
     if st.session_state.ask_sql:
         plc_hldr = st.empty()
@@ -233,17 +256,12 @@ ENGINE = ReplacingMergeTree ORDER BY id
             callback.progress_bar.progress(value=1.0, text="Done!")
             st.markdown(
                 f"### Answer from LLM\n{ret['answer']}\n### References")
-            docs = ret['source_documents']
-            ref = re.findall(
-                '(http://arxiv.org/abs/\d{4}.\d+v\d)', ret['sources'])
-            ref += re.findall(
-                '(http://arxiv.org/abs/\d{4}.\d+v\d)', ret['answer'])
-            docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
-                                 for d in docs if d.metadata['id'] in set(ref)])
-            display(docs, columns)
+            docs = ret['sources']
+            docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])
+            display(docs, columns, index='ref_id')
         except Exception as e:
             st.write('Oops 😡 Something bad happened...')
-            # raise e
+            raise e
 
 
 with tab_self_query:
@@ -270,7 +288,7 @@ with tab_self_query:
             display(docs, columns)
         except Exception as e:
             st.write('Oops 😡 Something bad happened...')
-            # raise e
+            raise e
 
     if st.session_state.ask_self:
         plc_hldr = st.empty()
@@ -284,14 +302,9 @@ with tab_self_query:
             callback.progress_bar.progress(value=1.0, text="Done!")
             st.markdown(
                 f"### Answer from LLM\n{ret['answer']}\n### References")
-            docs = ret['source_documents']
-            ref = re.findall(
-                '(http://arxiv.org/abs/\d{4}.\d+v\d)', ret['sources'])
-            ref += re.findall(
-                '(http://arxiv.org/abs/\d{4}.\d+v\d)', ret['answer'])
-            docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
-                                 for d in docs if d.metadata['id'] in set(ref)])
-            display(docs, columns)
+            docs = ret['sources']
+            docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])
+            display(docs, columns, index='ref_id')
         except Exception as e:
             st.write('Oops 😡 Something bad happened...')
-            # raise e
+            raise e
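
Taken together, the app.py changes drop the regex-over-URLs reference matching: `ret['sources']` now carries LangChain `Document` objects whose metadata already includes the `ref_id` assigned by the new chain. A minimal sketch of the new rendering path follows; the stub `Doc` class and all sample values are made up for illustration, and only the pandas calls mirror the committed code:

```python
import pandas as pd

# Stand-in for langchain.docstore.document.Document (illustration only).
class Doc:
    def __init__(self, page_content, metadata):
        self.page_content = page_content
        self.metadata = metadata

# Assumed shape of the dict returned by ArXivQAwithSourcesChain.
ret = {
    'answer': 'The Transformer was introduced in Attention Is All You Need [1].',
    'sources': [Doc('We propose a new simple network architecture...',
                    {'ref_id': 1, 'title': 'Attention Is All You Need',
                     'id': 'http://arxiv.org/abs/1706.03762v5',
                     'authors': 'Vaswani et al.', 'pubdate': '2017-06-12',
                     'categories': 'cs.CL'})],
}

# Mirrors the new app.py path: no regex over ret['sources'] is needed.
docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
                     for d in ret['sources']])
print(docs.set_index('ref_id'))  # what display(docs, columns, index='ref_id') renders
```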
chains/arxiv_chains.py ADDED
@@ -0,0 +1,131 @@
+import re
+import inspect
+from typing import Dict, Any, Optional, List, Tuple
+
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
+from langchain.schema import BaseRetriever
+from langchain.callbacks.manager import Callbacks
+from langchain.schema.prompt_template import format_document
+from langchain.docstore.document import Document
+from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
+from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+
+
+class ArXivStuffDocumentChain(StuffDocumentsChain):
+    """Combine arXiv documents with a PDF reference number"""
+
+    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
+        """Construct inputs from kwargs and docs.
+
+        Format and then join all the documents together into one input with name
+        `self.document_variable_name`. Then pluck any additional variables
+        from **kwargs.
+
+        Args:
+            docs: List of documents to format and then join into single input
+            **kwargs: additional inputs to chain, will pluck any other required
+                arguments from here.
+
+        Returns:
+            dictionary of inputs to LLMChain
+        """
+        # Format each document according to the prompt
+        doc_strings = []
+        for doc_id, doc in enumerate(docs):
+            # add temp reference number in metadata
+            doc.metadata.update({'ref_id': doc_id})
+            doc.page_content = doc.page_content.replace('\n', ' ')
+            doc_strings.append(format_document(doc, self.document_prompt))
+        # Join the documents together to put them in the prompt.
+        inputs = {
+            k: v
+            for k, v in kwargs.items()
+            if k in self.llm_chain.prompt.input_variables
+        }
+        inputs[self.document_variable_name] = self.document_separator.join(
+            doc_strings)
+        return inputs
+
+    def combine_docs(
+        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Stuff all documents into one prompt and pass to LLM.
+
+        Args:
+            docs: List of documents to join together into one variable
+            callbacks: Optional callbacks to pass along
+            **kwargs: additional parameters to use to get inputs to LLMChain.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
+        inputs = self._get_inputs(docs, **kwargs)
+        # Call predict on the LLM.
+        output = self.llm_chain.predict(callbacks=callbacks, **inputs)
+        return output, {}
+
+    @property
+    def _chain_type(self) -> str:
+        return "referenced_stuff_documents_chain"
+
+
+class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
+    """QA with sources chain for the Chat ArXiv app, with references
+
+    This chain automatically assigns a reference number to each article,
+    then parses it back to titles or anything else.
+    """
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        accepts_run_manager = (
+            "run_manager" in inspect.signature(self._get_docs).parameters
+        )
+        if accepts_run_manager:
+            docs = self._get_docs(inputs, run_manager=_run_manager)
+        else:
+            docs = self._get_docs(inputs)  # type: ignore[call-arg]
+
+        answer = self.combine_documents_chain.run(
+            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
+        )
+        # parse source with ref_id
+        sources = []
+        ref_cnt = 1
+        for d in docs:
+            ref_id = d.metadata['ref_id']
+            if f"PDF #{ref_id}" in answer:
+                title = d.metadata['title'].replace('\n', '')
+                d.metadata['ref_id'] = ref_cnt
+                answer = answer.replace(f"PDF #{ref_id}", f"{title} [{ref_cnt}]")
+                sources.append(d)
+                ref_cnt += 1
+
+        result: Dict[str, Any] = {
+            self.answer_key: answer,
+            self.sources_answer_key: sources,
+        }
+        if self.return_source_documents:
+            result["source_documents"] = docs
+        return result
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @property
+    def _chain_type(self) -> str:
+        return "arxiv_qa_with_sources_chain"
prompts/arxiv_prompt.py CHANGED
@@ -1,15 +1,14 @@
-from langchain.chains.qa_with_sources.map_reduce_prompt import combine_prompt_template
-combine_prompt_template_ = (
-    "You are a helpful paper assistant. Your task is to provide information and answer any questions "
-    + "related to PDFs given below. You should only use the abstract of the selected papers as your source of information "
+combine_prompt_template = (
+    "You are a helpful PDF assistant. Your task is to provide information and answer any questions "
+    + "related to PDFs given below. You should use the sections, title and abstract of the selected PDFs as your source of information "
     + "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
     + "relevant information in the given sections, you will need to let the user know that the source does not contain "
-    + "relevant information but still try to provide an answer based on your general knowledge. The following is the related information "
-    + "about the paper that will help you answer users' questions, you MUST answer it using question's language:\n\n"
+    + "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
+    + "corresponding section name and page that you refer to when answering. The following is the related information "
+    + "about the PDF file that will help you answer users' questions, you MUST answer it using the question's language:\n\n {summaries}"
+    + "Now you should answer the user's question. Remember you must use the PDF # to refer to papers:\n\n"
 )
 
-combine_prompt_template = combine_prompt_template_ + combine_prompt_template
-
 _myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
 MyScale queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
 When the query asks for the {top_k} closest rows, you have to use this distance function to calculate the distance to the entity's array on the vector column and order by the distance to retrieve relevant rows.
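
The revised template embeds the stuffed `{summaries}` directly in the system message, leaving only `{question}` for the human message — which is why app.py now builds `COMBINE_PROMPT` via `ChatPromptTemplate.from_strings` instead of a plain `PromptTemplate`. A quick sketch of how the two slots get filled at run time, using plain `str.format` and a shortened stand-in template (all sample texts are made up):

```python
from string import Formatter

# Shortened stand-in for the revised combine_prompt_template.
combine_prompt_template = (
    "You are a helpful PDF assistant. The following is the related information "
    "about the PDF file that will help you answer users' questions:\n\n {summaries}"
    "Now you should answer the user's question. Remember you must use the PDF # "
    "to refer to papers:\n\n"
)

# The system message owns {summaries}; the human message is just the question.
fields = [f for _, f, _, _ in Formatter().parse(combine_prompt_template) if f]
print(fields)  # ['summaries'] -- filled by ArXivStuffDocumentChain at run time
system = combine_prompt_template.format(
    summaries="Title for PDF #0: Attention Is All You Need\n\tAbstract: ...")
human = "What architecture does PDF #0 propose?"
print(system + "\n---\n" + human)
```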