Fangrui Liu commited on
Commit
45180a0
·
1 Parent(s): d5a4cb4

add wikipedia

Browse files
app.py CHANGED
@@ -1,3 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import pandas as pd
3
  from os import environ
@@ -6,34 +28,156 @@ import datetime
6
  environ['TOKENIZERS_PARALLELISM'] = 'true'
7
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
8
 
9
- from langchain.vectorstores import MyScale, MyScaleSettings
10
- from langchain.embeddings import HuggingFaceInstructEmbeddings
11
- from langchain.retrievers.self_query.base import SelfQueryRetriever
12
- from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
13
- from langchain import OpenAI
14
- from langchain.chat_models import ChatOpenAI
15
 
16
- from langchain.prompts.prompt import PromptTemplate
17
- from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
18
- SystemMessagePromptTemplate, HumanMessagePromptTemplate
19
- from sqlalchemy import create_engine, MetaData
20
- from langchain.chains import LLMChain
21
 
22
- from langchain.utilities.sql_database import SQLDatabase
23
- from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
24
- from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
25
 
26
- from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
27
- from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
28
- from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
29
- ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
30
- ChatDataSQLAskCallBackHandler
31
- from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
32
 
33
 
34
- st.set_page_config(page_title="ChatData")
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- st.header("ChatData")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  def try_eval(x):
@@ -55,14 +199,14 @@ def display(dataframe, columns_=None, index=None):
55
  st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
56
 
57
 
58
- @st.cache_resource
59
- def build_retriever():
60
  with st.spinner("Loading Model..."):
61
- embeddings = HuggingFaceInstructEmbeddings(
62
- model_name='hkunlp/instructor-xl',
63
- embed_instruction="Represent the question for retrieving supporting scientific papers: ")
64
 
65
- with st.spinner("Connecting DB..."):
 
 
66
  myscale_connection = {
67
  "host": st.secrets['MYSCALE_HOST'],
68
  "port": st.secrets['MYSCALE_PORT'],
@@ -70,69 +214,40 @@ def build_retriever():
70
  "password": st.secrets['MYSCALE_PASSWORD'],
71
  }
72
 
73
- config = MyScaleSettings(**myscale_connection, table='ChatArXiv',
 
 
74
  column_map={
75
  "id": "id",
76
- "text": "abstract",
77
- "vector": "vector",
78
- "metadata": "metadata"
79
  })
80
- doc_search = MyScale(embeddings, config)
 
81
 
82
- with st.spinner("Building Self Query Retriever..."):
83
- metadata_field_info = [
84
- AttributeInfo(
85
- name=VirtualColumnName(name="pubdate"),
86
- description="The year the paper is published",
87
- type="timestamp",
88
- ),
89
- AttributeInfo(
90
- name="authors",
91
- description="List of author names",
92
- type="list[string]",
93
- ),
94
- AttributeInfo(
95
- name="title",
96
- description="Title of the paper",
97
- type="string",
98
- ),
99
- AttributeInfo(
100
- name="categories",
101
- description="arxiv categories to this paper",
102
- type="list[string]"
103
- ),
104
- AttributeInfo(
105
- name="length(categories)",
106
- description="length of arxiv categories to this paper",
107
- type="int"
108
- ),
109
- ]
110
  retriever = SelfQueryRetriever.from_llm(
111
- OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
112
  doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
113
- use_original_query=False)
114
-
115
-
116
- document_with_metadata_prompt = PromptTemplate(
117
- input_variables=["page_content", "id", "title", "ref_id",
118
- "authors", "pubdate", "categories"],
119
- template="Title for PDF #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}")
120
 
121
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
122
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
123
- (HumanMessagePromptTemplate, '{question}')])
124
  OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
125
 
126
- with st.spinner('Building QA Chain with Self-query...'):
127
  chain = ArXivQAwithSourcesChain(
128
  retriever=retriever,
129
  combine_documents_chain=ArXivStuffDocumentChain(
130
  llm_chain=LLMChain(
131
  prompt=COMBINE_PROMPT,
132
- llm=ChatOpenAI(model_name='gpt-3.5-turbo-16k',
133
- openai_api_key=OPENAI_API_KEY, temperature=0.6),
134
  ),
135
- document_prompt=document_with_metadata_prompt,
136
  document_variable_name="summaries",
137
 
138
  ),
@@ -140,23 +255,22 @@ def build_retriever():
140
  max_tokens_limit=12000,
141
  )
142
 
143
- with st.spinner('Building Vector SQL Database Retriever'):
144
  MYSCALE_USER = st.secrets['MYSCALE_USER']
145
  MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
146
  MYSCALE_HOST = st.secrets['MYSCALE_HOST']
147
  MYSCALE_PORT = st.secrets['MYSCALE_PORT']
148
  engine = create_engine(
149
- f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https')
150
  metadata = MetaData(bind=engine)
151
  PROMPT = PromptTemplate(
152
  input_variables=["input", "table_info", "top_k"],
153
  template=_myscale_prompt,
154
  )
155
-
156
  output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
157
- model=embeddings)
158
  sql_query_chain = VectorSQLDatabaseChain.from_llm(
159
- llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
160
  prompt=PROMPT,
161
  top_k=10,
162
  return_direct=True,
@@ -165,18 +279,18 @@ def build_retriever():
165
  native_format=True
166
  )
167
  sql_retriever = VectorSQLDatabaseChainRetriever(
168
- sql_db_chain=sql_query_chain, page_content_key="abstract")
169
 
170
- with st.spinner('Building QA Chain with Vector SQL...'):
171
  sql_chain = ArXivQAwithSourcesChain(
172
  retriever=sql_retriever,
173
  combine_documents_chain=ArXivStuffDocumentChain(
174
  llm_chain=LLMChain(
175
  prompt=COMBINE_PROMPT,
176
- llm=ChatOpenAI(model_name='gpt-3.5-turbo-16k',
177
- openai_api_key=OPENAI_API_KEY, temperature=0.6),
178
  ),
179
- document_prompt=document_with_metadata_prompt,
180
  document_variable_name="summaries",
181
 
182
  ),
@@ -184,48 +298,33 @@ def build_retriever():
184
  max_tokens_limit=12000,
185
  )
186
 
187
- return [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], \
188
- retriever, chain, sql_retriever, sql_chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
 
191
  if 'retriever' not in st.session_state:
192
- st.session_state['metadata_columns'], \
193
- st.session_state['retriever'], \
194
- st.session_state['chain'], \
195
- st.session_state['sql_retriever'], \
196
- st.session_state['sql_chain'] = build_retriever()
197
-
198
- st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
199
- "For example: \n\n"
200
- "*If you want to search papers with complex filters*:\n\n"
201
- "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
202
- "*If you want to ask questions based on papers in database*:\n\n"
203
- "- What is PageRank?\n"
204
- "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
205
- "- Introduce some applications of GANs published around 2019.\n"
206
- "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
207
- "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
208
- "- Is it possible to synthesize room temperature super conductive material?")
209
  tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
210
  with tab_sql:
211
- st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
212
- st.markdown('''```sql
213
- CREATE TABLE default.ChatArXiv (
214
- `abstract` String,
215
- `id` String,
216
- `vector` Array(Float32),
217
- `metadata` Object('JSON'),
218
- `pubdate` DateTime,
219
- `title` String,
220
- `categories` Array(String),
221
- `authors` Array(String),
222
- `comment` String,
223
- `primary_category` String,
224
- VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine'),
225
- CONSTRAINT vec_len CHECK length(vector) = 768)
226
- ENGINE = ReplacingMergeTree ORDER BY id
227
- ```''')
228
-
229
  st.text_input("Ask a question:", key='query_sql')
230
  cols = st.columns([1, 1, 7])
231
  cols[0].button("Query", key='search_sql')
@@ -237,7 +336,7 @@ ENGINE = ReplacingMergeTree ORDER BY id
237
  with plc_hldr.expander('Query Log', expanded=True):
238
  callback = ChatDataSQLSearchCallBackHandler()
239
  try:
240
- docs = st.session_state.sql_retriever.get_relevant_documents(
241
  st.session_state.query_sql, callbacks=[callback])
242
  callback.progress_bar.progress(value=1.0, text="Done!")
243
  docs = pd.DataFrame(
@@ -253,14 +352,16 @@ ENGINE = ReplacingMergeTree ORDER BY id
253
  with plc_hldr.expander('Chat Log', expanded=True):
254
  callback = ChatDataSQLAskCallBackHandler()
255
  try:
256
- ret = st.session_state.sql_chain(
257
  st.session_state.query_sql, callbacks=[callback])
258
  callback.progress_bar.progress(value=1.0, text="Done!")
259
  st.markdown(
260
  f"### Answer from LLM\n{ret['answer']}\n### References")
261
  docs = ret['sources']
262
- docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])
263
- display(docs, ['ref_id', 'title', 'id', 'categories', 'abstract', 'authors', 'pubdate'], index='ref_id')
 
 
264
  except Exception as e:
265
  st.write('Oops 😵 Something bad happened...')
266
  raise e
@@ -268,7 +369,7 @@ ENGINE = ReplacingMergeTree ORDER BY id
268
 
269
  with tab_self_query:
270
  st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
271
- st.dataframe(st.session_state.metadata_columns)
272
  st.text_input("Ask a question:", key='query_self')
273
  cols = st.columns([1, 1, 7])
274
  cols[0].button("Query", key='search_self')
@@ -281,13 +382,13 @@ with tab_self_query:
281
  call_back = None
282
  callback = ChatDataSelfSearchCallBackHandler()
283
  try:
284
- docs = st.session_state.retriever.get_relevant_documents(
285
  st.session_state.query_self, callbacks=[callback])
 
286
  callback.progress_bar.progress(value=1.0, text="Done!")
287
  docs = pd.DataFrame(
288
  [{**d.metadata, 'abstract': d.page_content} for d in docs])
289
-
290
- display(docs, ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'])
291
  except Exception as e:
292
  st.write('Oops 😵 Something bad happened...')
293
  raise e
@@ -299,14 +400,16 @@ with tab_self_query:
299
  call_back = None
300
  callback = ChatDataSelfAskCallBackHandler()
301
  try:
302
- ret = st.session_state.chain(
303
  st.session_state.query_self, callbacks=[callback])
304
  callback.progress_bar.progress(value=1.0, text="Done!")
305
  st.markdown(
306
  f"### Answer from LLM\n{ret['answer']}\n### References")
307
  docs = ret['sources']
308
- docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])
309
- display(docs, ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'], index='ref_id')
 
 
310
  except Exception as e:
311
  st.write('Oops 😵 Something bad happened...')
312
  raise e
 
1
+ from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
2
+ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
3
+ ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
4
+ ChatDataSQLAskCallBackHandler
5
+ from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
6
+ from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
7
+ from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
8
+ from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
9
+ from langchain.utilities.sql_database import SQLDatabase
10
+ from langchain.chains import LLMChain
11
+ from sqlalchemy import create_engine, MetaData
12
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
13
+ SystemMessagePromptTemplate, HumanMessagePromptTemplate
14
+ from langchain.prompts.prompt import PromptTemplate
15
+ from langchain.chat_models import ChatOpenAI
16
+ from langchain import OpenAI
17
+ from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
18
+ from langchain.retrievers.self_query.base import SelfQueryRetriever
19
+ from langchain.retrievers.self_query.myscale import MyScaleTranslator
20
+ from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
21
+ from langchain.vectorstores import MyScaleSettings
22
+ from chains.arxiv_chains import MyScaleWithoutMetadataJson
23
  import re
24
  import pandas as pd
25
  from os import environ
 
28
  environ['TOKENIZERS_PARALLELISM'] = 'true'
29
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
30
 
 
 
 
 
 
 
31
 
32
+ st.set_page_config(page_title="ChatData")
 
 
 
 
33
 
34
+ st.header("ChatData")
 
 
35
 
36
+ # query_model_name = "gpt-3.5-turbo-instruct"
37
+ query_model_name = "text-davinci-003"
38
+ chat_model_name = "gpt-3.5-turbo-16k"
 
 
 
39
 
40
 
41
+ def hint_arxiv():
42
+ st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
43
+ "For example: \n\n"
44
+ "*If you want to search papers with complex filters*:\n\n"
45
+ "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
46
+ "*If you want to ask questions based on papers in database*:\n\n"
47
+ "- What is PageRank?\n"
48
+ "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
49
+ "- Introduce some applications of GANs published around 2019.\n"
50
+ "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
51
+ "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
52
+ "- Is it possible to synthesize room temperature super conductive material?")
53
 
54
+
55
+ def hint_sql_arxiv():
56
+ st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
57
+ st.markdown('''```sql
58
+ CREATE TABLE default.ChatArXiv (
59
+ `abstract` String,
60
+ `id` String,
61
+ `vector` Array(Float32),
62
+ `metadata` Object('JSON'),
63
+ `pubdate` DateTime,
64
+ `title` String,
65
+ `categories` Array(String),
66
+ `authors` Array(String),
67
+ `comment` String,
68
+ `primary_category` String,
69
+ VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
70
+ CONSTRAINT vec_len CHECK length(vector) = 768)
71
+ ENGINE = ReplacingMergeTree ORDER BY id
72
+ ```''')
73
+
74
+
75
+ def hint_wiki():
76
+ st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
77
+ "For example: \n\n"
78
+ "- Which company did Elon Musk found?\n"
79
+ "- What is Iron Gwazi?\n"
80
+ "- What is a Ring in mathematics?\n"
81
+ "- 苹果的发源地是那里?\n")
82
+
83
+
84
+ def hint_sql_wiki():
85
+ st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
86
+ st.markdown('''```sql
87
+ CREATE TABLE wiki.Wikipedia (
88
+ `id` String,
89
+ `title` String,
90
+ `text` String,
91
+ `url` String,
92
+ `wiki_id` UInt64,
93
+ `views` Float32,
94
+ `paragraph_id` UInt64,
95
+ `langs` UInt32,
96
+ `emb` Array(Float32),
97
+ VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
98
+ CONSTRAINT emb_len CHECK length(emb) = 768)
99
+ ENGINE = ReplacingMergeTree ORDER BY id
100
+ ```''')
101
+
102
+
103
+ sel_map = {
104
+ 'Wikipedia': {
105
+ "database": "wiki",
106
+ "table": "Wikipedia",
107
+ "hint": hint_wiki,
108
+ "hint_sql": hint_sql_wiki,
109
+ "doc_prompt": PromptTemplate(
110
+ input_variables=["page_content", "url", "title", "ref_id", "views"],
111
+ template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
112
+ "metadata_cols": [
113
+ AttributeInfo(
114
+ name="title",
115
+ description="title of the wikipedia page",
116
+ type="string",
117
+ ),
118
+ AttributeInfo(
119
+ name="text",
120
+ description="paragraph from this wiki page",
121
+ type="string",
122
+ ),
123
+ AttributeInfo(
124
+ name="views",
125
+ description="number of views",
126
+ type="float"
127
+ ),
128
+ ],
129
+ "must_have_cols": ['id', 'title', 'url', 'text', 'views'],
130
+ "vector_col": "emb",
131
+ "text_col": "text",
132
+ "metadata_col": "metadata",
133
+ "emb_model": lambda: SentenceTransformerEmbeddings(
134
+ model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',)
135
+ },
136
+ 'ArXiv Papers': {
137
+ "database": "default",
138
+ "table": "ChatArXiv",
139
+ "hint": hint_arxiv,
140
+ "hint_sql": hint_sql_arxiv,
141
+ "doc_prompt": PromptTemplate(
142
+ input_variables=["page_content", "id", "title", "ref_id",
143
+ "authors", "pubdate", "categories"],
144
+ template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"),
145
+ "metadata_cols": [
146
+ AttributeInfo(
147
+ name=VirtualColumnName(name="pubdate"),
148
+ description="The year the paper is published",
149
+ type="timestamp",
150
+ ),
151
+ AttributeInfo(
152
+ name="authors",
153
+ description="List of author names",
154
+ type="list[string]",
155
+ ),
156
+ AttributeInfo(
157
+ name="title",
158
+ description="Title of the paper",
159
+ type="string",
160
+ ),
161
+ AttributeInfo(
162
+ name="categories",
163
+ description="arxiv categories to this paper",
164
+ type="list[string]"
165
+ ),
166
+ AttributeInfo(
167
+ name="length(categories)",
168
+ description="length of arxiv categories to this paper",
169
+ type="int"
170
+ ),
171
+ ],
172
+ "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
173
+ "vector_col": "vector",
174
+ "text_col": "abstract",
175
+ "metadata_col": "metadata",
176
+ "emb_model": lambda: HuggingFaceInstructEmbeddings(
177
+ model_name='hkunlp/instructor-xl',
178
+ embed_instruction="Represent the question for retrieving supporting scientific papers: ")
179
+ }
180
+ }
181
 
182
 
183
  def try_eval(x):
 
199
  st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
200
 
201
 
202
+ def build_embedding_model(_sel):
 
203
  with st.spinner("Loading Model..."):
204
+ embeddings = sel_map[_sel]["emb_model"]()
205
+ return embeddings
 
206
 
207
+
208
+ def build_retriever(_sel):
209
+ with st.spinner(f"Connecting DB for {_sel}..."):
210
  myscale_connection = {
211
  "host": st.secrets['MYSCALE_HOST'],
212
  "port": st.secrets['MYSCALE_PORT'],
 
214
  "password": st.secrets['MYSCALE_PASSWORD'],
215
  }
216
 
217
+ config = MyScaleSettings(**myscale_connection,
218
+ database=sel_map[_sel]["database"],
219
+ table=sel_map[_sel]["table"],
220
  column_map={
221
  "id": "id",
222
+ "text": sel_map[_sel]["text_col"],
223
+ "vector": sel_map[_sel]["vector_col"],
224
+ "metadata": sel_map[_sel]["metadata_col"]
225
  })
226
+ doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
227
+ must_have_cols=sel_map[_sel]['must_have_cols'])
228
 
229
+ with st.spinner(f"Building Self Query Retriever for {_sel}..."):
230
+ metadata_field_info = sel_map[_sel]["metadata_cols"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  retriever = SelfQueryRetriever.from_llm(
232
+ OpenAI(model_name=query_model_name, openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
233
  doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
234
+ use_original_query=False, structured_query_translator=MyScaleTranslator())
 
 
 
 
 
 
235
 
236
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
237
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
238
+ (HumanMessagePromptTemplate, '{question}')])
239
  OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
240
 
241
+ with st.spinner(f'Building QA Chain with Self-query for {_sel}...'):
242
  chain = ArXivQAwithSourcesChain(
243
  retriever=retriever,
244
  combine_documents_chain=ArXivStuffDocumentChain(
245
  llm_chain=LLMChain(
246
  prompt=COMBINE_PROMPT,
247
+ llm=ChatOpenAI(model_name=chat_model_name,
248
+ openai_api_key=OPENAI_API_KEY, temperature=0.6),
249
  ),
250
+ document_prompt=sel_map[_sel]["doc_prompt"],
251
  document_variable_name="summaries",
252
 
253
  ),
 
255
  max_tokens_limit=12000,
256
  )
257
 
258
+ with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
259
  MYSCALE_USER = st.secrets['MYSCALE_USER']
260
  MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
261
  MYSCALE_HOST = st.secrets['MYSCALE_HOST']
262
  MYSCALE_PORT = st.secrets['MYSCALE_PORT']
263
  engine = create_engine(
264
+ f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https')
265
  metadata = MetaData(bind=engine)
266
  PROMPT = PromptTemplate(
267
  input_variables=["input", "table_info", "top_k"],
268
  template=_myscale_prompt,
269
  )
 
270
  output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
271
+ model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
272
  sql_query_chain = VectorSQLDatabaseChain.from_llm(
273
+ llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
274
  prompt=PROMPT,
275
  top_k=10,
276
  return_direct=True,
 
279
  native_format=True
280
  )
281
  sql_retriever = VectorSQLDatabaseChainRetriever(
282
+ sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
283
 
284
+ with st.spinner(f'Building QA Chain with Vector SQL for {_sel}...'):
285
  sql_chain = ArXivQAwithSourcesChain(
286
  retriever=sql_retriever,
287
  combine_documents_chain=ArXivStuffDocumentChain(
288
  llm_chain=LLMChain(
289
  prompt=COMBINE_PROMPT,
290
+ llm=ChatOpenAI(model_name=chat_model_name,
291
+ openai_api_key=OPENAI_API_KEY, temperature=0.6),
292
  ),
293
+ document_prompt=sel_map[_sel]["doc_prompt"],
294
  document_variable_name="summaries",
295
 
296
  ),
 
298
  max_tokens_limit=12000,
299
  )
300
 
301
+ return {
302
+ "metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
303
+ "retriever": retriever,
304
+ "chain": chain,
305
+ "sql_retriever": sql_retriever,
306
+ "sql_chain": sql_chain
307
+ }
308
+
309
+
310
+ @st.cache_resource
311
+ def build_all():
312
+ sel_map_obj = {}
313
+ for k in sel_map:
314
+ st.session_state[f'emb_model_{k}'] = build_embedding_model(k)
315
+ sel_map_obj[k] = build_retriever(k)
316
+ return sel_map_obj
317
 
318
 
319
  if 'retriever' not in st.session_state:
320
+ st.session_state["sel_map_obj"] = build_all()
321
+
322
+ sel = st.selectbox('Choose the knowledge base you want to ask with:',
323
+ options=['ArXiv Papers', 'Wikipedia'])
324
+ sel_map[sel]['hint']()
 
 
 
 
 
 
 
 
 
 
 
 
325
  tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
326
  with tab_sql:
327
+ sel_map[sel]['hint_sql']()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  st.text_input("Ask a question:", key='query_sql')
329
  cols = st.columns([1, 1, 7])
330
  cols[0].button("Query", key='search_sql')
 
336
  with plc_hldr.expander('Query Log', expanded=True):
337
  callback = ChatDataSQLSearchCallBackHandler()
338
  try:
339
+ docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
340
  st.session_state.query_sql, callbacks=[callback])
341
  callback.progress_bar.progress(value=1.0, text="Done!")
342
  docs = pd.DataFrame(
 
352
  with plc_hldr.expander('Chat Log', expanded=True):
353
  callback = ChatDataSQLAskCallBackHandler()
354
  try:
355
+ ret = st.session_state.sel_map_obj[sel]["sql_chain"](
356
  st.session_state.query_sql, callbacks=[callback])
357
  callback.progress_bar.progress(value=1.0, text="Done!")
358
  st.markdown(
359
  f"### Answer from LLM\n{ret['answer']}\n### References")
360
  docs = ret['sources']
361
+ docs = pd.DataFrame(
362
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
363
+ display(
364
+ docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
365
  except Exception as e:
366
  st.write('Oops 😵 Something bad happened...')
367
  raise e
 
369
 
370
  with tab_self_query:
371
  st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
372
+ st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
373
  st.text_input("Ask a question:", key='query_self')
374
  cols = st.columns([1, 1, 7])
375
  cols[0].button("Query", key='search_self')
 
382
  call_back = None
383
  callback = ChatDataSelfSearchCallBackHandler()
384
  try:
385
+ docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
386
  st.session_state.query_self, callbacks=[callback])
387
+ print(docs)
388
  callback.progress_bar.progress(value=1.0, text="Done!")
389
  docs = pd.DataFrame(
390
  [{**d.metadata, 'abstract': d.page_content} for d in docs])
391
+ display(docs, sel_map[sel]["must_have_cols"])
 
392
  except Exception as e:
393
  st.write('Oops 😵 Something bad happened...')
394
  raise e
 
400
  call_back = None
401
  callback = ChatDataSelfAskCallBackHandler()
402
  try:
403
+ ret = st.session_state.sel_map_obj[sel]["chain"](
404
  st.session_state.query_self, callbacks=[callback])
405
  callback.progress_bar.progress(value=1.0, text="Done!")
406
  st.markdown(
407
  f"### Answer from LLM\n{ret['answer']}\n### References")
408
  docs = ret['sources']
409
+ docs = pd.DataFrame(
410
+ [{**d.metadata, 'abstract': d.page_content} for d in docs])
411
+ display(
412
+ docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
413
  except Exception as e:
414
  st.write('Oops 😵 Something bad happened...')
415
  raise e
callbacks/arxiv_callbacks.py CHANGED
@@ -90,4 +90,4 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
90
  self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
91
  self.status_bar = st.empty()
92
  self.prog_value = 0
93
- self.prog_interval = 0.1
 
90
  self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
91
  self.status_bar = st.empty()
92
  self.prog_value = 0
93
+ self.prog_interval = 0.1
chains/arxiv_chains.py CHANGED
@@ -1,4 +1,4 @@
1
- import re
2
  import inspect
3
  from typing import Dict, Any, Optional, List, Tuple
4
 
@@ -7,21 +7,62 @@ from langchain.callbacks.manager import (
7
  AsyncCallbackManagerForChainRun,
8
  CallbackManagerForChainRun,
9
  )
 
10
  from langchain.schema import BaseRetriever
11
  from langchain.callbacks.manager import Callbacks
12
  from langchain.schema.prompt_template import format_document
13
  from langchain.docstore.document import Document
14
  from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
15
- from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
16
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
17
 
18
  from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
22
  """Based on VectorSQLOutputParser
23
  It also modify the SQL to get all columns
24
  """
 
25
 
26
  @property
27
  def _type(self) -> str:
@@ -123,12 +164,15 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
123
  ref_cnt = 1
124
  for d in docs:
125
  ref_id = d.metadata['ref_id']
126
- if f"PDF #{ref_id}" in answer:
 
 
127
  title = d.metadata['title'].replace('\n', '')
128
  d.metadata['ref_id'] = ref_cnt
129
- answer = answer.replace(f"PDF #{ref_id}", f"{title} [{ref_cnt}]")
130
  sources.append(d)
131
  ref_cnt += 1
 
132
 
133
  result: Dict[str, Any] = {
134
  self.answer_key: answer,
@@ -147,4 +191,4 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
147
 
148
  @property
149
  def _chain_type(self) -> str:
150
- return "arxiv_qa_with_sources_chain"
 
1
+ import logging
2
  import inspect
3
  from typing import Dict, Any, Optional, List, Tuple
4
 
 
7
  AsyncCallbackManagerForChainRun,
8
  CallbackManagerForChainRun,
9
  )
10
+ from langchain.embeddings.base import Embeddings
11
  from langchain.schema import BaseRetriever
12
  from langchain.callbacks.manager import Callbacks
13
  from langchain.schema.prompt_template import format_document
14
  from langchain.docstore.document import Document
15
  from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
16
+ from langchain.vectorstores.myscale import MyScale, MyScaleSettings
17
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
18
 
19
  from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
20
 
21
+ logger = logging.getLogger()
22
+
23
+ class MyScaleWithoutMetadataJson(MyScale):
24
+ def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
25
+ super().__init__(embedding, config, **kwargs)
26
+ self.must_have_cols: List[str] = must_have_cols
27
+
28
+ def _build_qstr(
29
+ self, q_emb: List[float], topk: int, where_str: Optional[str] = None
30
+ ) -> str:
31
+ q_emb_str = ",".join(map(str, q_emb))
32
+ if where_str:
33
+ where_str = f"PREWHERE {where_str}"
34
+ else:
35
+ where_str = ""
36
+
37
+ q_str = f"""
38
+ SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)}
39
+ FROM {self.config.database}.{self.config.table}
40
+ {where_str}
41
+ ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
42
+ AS dist {self.dist_order}
43
+ LIMIT {topk}
44
+ """
45
+ return q_str
46
+
47
+ def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
48
+ q_str = self._build_qstr(embedding, k, where_str)
49
+ try:
50
+ return [
51
+ Document(
52
+ page_content=r[self.config.column_map["text"]],
53
+ metadata={k: r[k] for k in self.must_have_cols},
54
+ )
55
+ for r in self.client.query(q_str).named_results()
56
+ ]
57
+ except Exception as e:
58
+ logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
59
+ return []
60
 
61
  class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
62
  """Based on VectorSQLOutputParser
63
  It also modify the SQL to get all columns
64
  """
65
+ must_have_columns: List[str]
66
 
67
  @property
68
  def _type(self) -> str:
 
164
  ref_cnt = 1
165
  for d in docs:
166
  ref_id = d.metadata['ref_id']
167
+ if f"Doc #{ref_id}" in answer:
168
+ answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
169
+ if f"#{ref_id}" in answer:
170
  title = d.metadata['title'].replace('\n', '')
171
  d.metadata['ref_id'] = ref_cnt
172
+ answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
173
  sources.append(d)
174
  ref_cnt += 1
175
+
176
 
177
  result: Dict[str, Any] = {
178
  self.answer_key: answer,
 
191
 
192
  @property
193
  def _chain_type(self) -> str:
194
+ return "arxiv_qa_with_sources_chain"
prompts/arxiv_prompt.py CHANGED
@@ -1,12 +1,12 @@
1
  combine_prompt_template = (
2
- "You are a helpful PDF assistant. Your task is to provide information and answer any questions "
3
- + "related to PDFs given below. You should use the sections, title and abstract of the selected PDFs as your source of information "
4
  + "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
5
  + "relevant information in the given sections, you will need to let the user know that the source does not contain "
6
  + "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
7
  + "corresponding section name and page that you refer to when answering. The following is the related information "
8
- + "about the PDF file that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
9
- + "Now you should anwser user's question. Remember you must use the PDF # to refer papers:\n\n"
10
  )
11
 
12
  _myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
 
1
  combine_prompt_template = (
2
+ "You are a helpful document assistant. Your task is to provide information and answer any questions "
3
+ + "related to documents given below. You should use the sections, title and abstract of the selected documents as your source of information "
4
  + "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
5
  + "relevant information in the given sections, you will need to let the user know that the source does not contain "
6
  + "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
7
  + "corresponding section name and page that you refer to when answering. The following is the related information "
8
+ + "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
9
+ + "Now you should anwser user's question. Remember you must use `Doc #` to refer papers:\n\n"
10
  )
11
 
12
  _myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.