Fangrui Liu committed on
Commit
9061790
·
1 Parent(s): abcac4c

Add text 2 sql query & ask

Browse files
Files changed (4) hide show
  1. app.py +207 -76
  2. callbacks/arxiv_callbacks.py +39 -3
  3. prompts/arxiv_prompt.py +99 -0
  4. requirements.txt +1 -1
app.py CHANGED
@@ -2,6 +2,7 @@ import re
2
  import pandas as pd
3
  from os import environ
4
  import streamlit as st
 
5
  environ['TOKENIZERS_PARALLELISM'] = 'true'
6
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
7
 
@@ -13,9 +14,17 @@ from langchain.chains import RetrievalQAWithSourcesChain
13
  from langchain import OpenAI
14
  from langchain.chat_models import ChatOpenAI
15
 
16
- from prompts.arxiv_prompt import combine_prompt_template
17
- from callbacks.arxiv_callbacks import ChatDataSearchCallBackHandler, ChatDataAskCallBackHandler
 
 
18
  from langchain.prompts.prompt import PromptTemplate
 
 
 
 
 
 
19
 
20
 
21
  st.set_page_config(page_title="ChatData")
@@ -25,13 +34,24 @@ st.header("ChatData")
25
  columns = ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
26
 
27
 
28
- def display(dataframe, columns):
29
- if len(docs) > 0:
30
- st.dataframe(dataframe[columns])
 
 
 
 
 
 
 
 
 
 
31
  else:
32
- st.write("Sorry 😵 we didn't find any articles related to your query.\nPlease use verbs that may match the datatype.", unsafe_allow_html=True)
33
 
34
- @st.cache_resource
 
35
  def build_retriever():
36
  with st.spinner("Loading Model..."):
37
  embeddings = HuggingFaceInstructEmbeddings(
@@ -88,79 +108,190 @@ def build_retriever():
88
  doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
89
  use_original_query=False)
90
 
91
- with st.spinner('Building RetrievalQAWith SourcesChain...'):
92
- document_with_metadata_prompt = PromptTemplate(
93
- input_variables=["page_content", "id", "title", "authors"],
94
- template="Content:\n\tTitle: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\nSOURCE: {id}")
95
- COMBINE_PROMPT = PromptTemplate(
96
- template=combine_prompt_template, input_variables=["summaries", "question"])
97
- chain = RetrievalQAWithSourcesChain.from_llm(
98
- llm=ChatOpenAI(
99
- openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.6),
100
- document_prompt=document_with_metadata_prompt,
101
- combine_prompt=COMBINE_PROMPT,
102
- retriever=retriever,
103
- return_source_documents=True,)
104
- return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], retriever, chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
 
107
  if 'retriever' not in st.session_state:
108
  st.session_state['metadata_columns'], \
109
  st.session_state['retriever'], \
110
- st.session_state['chain'] = \
111
- build_retriever()
112
- st.info("Chat with 2 milions arxiv papers, powered by [MyScale](https://myscale.com)", icon="🌟")
113
- st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n" +
114
- "For example: \n\n" +
115
- "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n" +
116
- "- What is neural network? Please use articles published by Geoffrey Hinton after 2018.\n" +
117
- "- Introduce some applications of GANs published around 2019.\n" +
118
- "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些?" +
 
 
 
 
119
  "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?")
120
- # or ask questions based on retrieved papers with button `Ask`
121
- st.info("You can retrieve papers with button `Query`", icon='💡')
122
- st.dataframe(st.session_state.metadata_columns)
123
- st.text_input("Ask a question:", key='query')
124
- cols = st.columns([1, 1, 7])
125
- cols[0].button("Query", key='search')
126
- # cols[1].button("Ask", key='ask')
127
- plc_hldr = st.empty()
128
-
129
- if st.session_state.search:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  plc_hldr = st.empty()
131
- with plc_hldr.expander('Query Log', expanded=True):
132
- call_back = None
133
- callback = ChatDataSearchCallBackHandler()
134
- try:
135
- docs = st.session_state.retriever.get_relevant_documents(
136
- st.session_state.query, callbacks=[callback])
137
- callback.progress_bar.progress(value=1.0, text="Done!")
138
- docs = pd.DataFrame(
139
- [{**d.metadata, 'abstract': d.page_content} for d in docs])
140
-
141
- display(docs, columns)
142
- except Exception as e:
143
- st.write('Oops 😵 Something bad happened...')
144
- # raise e
145
-
146
- # if st.session_state.ask:
147
- # plc_hldr = st.empty()
148
- # ctx = st.container()
149
- # with plc_hldr.expander('Chat Log', expanded=True):
150
- # call_back = None
151
- # callback = ChatDataAskCallBackHandler()
152
- # try:
153
- # ret = st.session_state.chain(
154
- # st.session_state.query, callbacks=[callback])
155
- # callback.progress_bar.progress(value=1.0, text="Done!")
156
- # st.markdown(
157
- # f"### Answer from LLM\n{ret['answer']}\n### References")
158
- # docs = ret['source_documents']
159
- # ref = re.findall(
160
- # '(http://arxiv.org/abs/\d{4}.\d+v\d)', ret['sources'])
161
- # docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
162
- # for d in docs if d.metadata['id'] in ref])
163
- # display(docs, columns)
164
- # except Exception as e:
165
- # st.write('Oops 😵 Something bad happened...')
166
- # # raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pandas as pd
3
  from os import environ
4
  import streamlit as st
5
+ import datetime
6
  environ['TOKENIZERS_PARALLELISM'] = 'true'
7
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
8
 
 
14
  from langchain import OpenAI
15
  from langchain.chat_models import ChatOpenAI
16
 
17
+ from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
18
+ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
19
+ ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
20
+ ChatDataSQLAskCallBackHandler
21
  from langchain.prompts.prompt import PromptTemplate
22
+ from sqlalchemy import create_engine, MetaData
23
+ from langchain.chains.sql_database.base import SQLDatabaseChain
24
+ from langchain.chains.sql_database.parser import VectorSQLRetrieveAllOutputParser
25
+ from langchain.chains import LLMChain
26
+ from langchain.sql_database import SQLDatabase
27
+ from langchain.retrievers import SQLDatabaseChainRetriever
28
 
29
 
30
  st.set_page_config(page_title="ChatData")
 
34
  columns = ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
35
 
36
 
37
def try_eval(x):
    """Evaluate ``x`` as a Python expression, falling back to ``x`` itself.

    The expression is evaluated with the ``datetime`` module as the only
    explicitly provided global (Python still injects its builtins).

    Args:
        x: Value to evaluate; normally a string holding a Python expression.

    Returns:
        The evaluated result, or ``x`` unchanged when evaluation raises an
        ordinary exception (syntax error, unknown name, non-string input...).
    """
    try:
        # NOTE(review): eval() runs arbitrary expressions. Only feed it text
        # that originates from a trusted source (e.g. values read back from
        # our own database), never raw user input.
        return eval(x, {'datetime': datetime})
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # SystemExit / KeyboardInterrupt are not silently swallowed.
        return x
42
+
43
+
44
def display(dataframe, columns=None):
    """Render retrieved papers as a Streamlit table, or an apology if empty.

    Args:
        dataframe: Query results; callers in this file pass a pandas
            ``DataFrame``.
        columns: Optional list of column names to show. Names missing from
            ``dataframe`` are skipped instead of raising ``KeyError``. When
            omitted, empty, or when none of the names exist, the whole frame
            is shown.
    """
    if len(dataframe) == 0:
        st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
        return

    if columns:
        # The generated SQL decides which fields come back, so a requested
        # column is not guaranteed to exist in the result.
        present = [col for col in columns if col in dataframe.columns]
        if present:
            st.dataframe(dataframe[present])
            return

    st.dataframe(dataframe)
52
 
53
+
54
+ @st.cache_resource
55
  def build_retriever():
56
  with st.spinner("Loading Model..."):
57
  embeddings = HuggingFaceInstructEmbeddings(
 
108
  doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
109
  use_original_query=False)
110
 
111
+ with st.spinner('Building RetrievalQAWith SourcesChain...'):
112
+ document_with_metadata_prompt = PromptTemplate(
113
+ input_variables=["page_content", "id", "title",
114
+ "authors", "pubdate", "categories"],
115
+ template="Content:\n\tTitle: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}")
116
+ COMBINE_PROMPT = PromptTemplate(
117
+ template=combine_prompt_template, input_variables=["summaries", "question"])
118
+ chain = RetrievalQAWithSourcesChain.from_chain_type(
119
+ ChatOpenAI(model_name='gpt-3.5-turbo-16k',
120
+ openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.6),
121
+ retriever=retriever,
122
+ chain_type='stuff',
123
+ chain_type_kwargs={
124
+ 'prompt': COMBINE_PROMPT,
125
+ 'document_prompt': document_with_metadata_prompt,
126
+ }, return_source_documents=True)
127
+
128
+ with st.spinner('Building Vector SQL Database Chain'):
129
+ MYSCALE_USER = st.secrets['MYSCALE_USER']
130
+ MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
131
+ MYSCALE_HOST = st.secrets['MYSCALE_HOST']
132
+ MYSCALE_PORT = st.secrets['MYSCALE_PORT']
133
+ engine = create_engine(
134
+ f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https')
135
+ metadata = MetaData(bind=engine)
136
+ PROMPT = PromptTemplate(
137
+ input_variables=["input", "table_info", "top_k"],
138
+ template=_myscale_prompt,
139
+ )
140
+
141
+ output_parser = VectorSQLRetrieveAllOutputParser.from_embeddings(
142
+ model=embeddings)
143
+ sql_query_chain = SQLDatabaseChain.from_llm(
144
+ llm=OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
145
+ prompt=PROMPT,
146
+ top_k=10,
147
+ return_direct=True,
148
+ db=SQLDatabase(engine, None, metadata, max_string_length=1024),
149
+ sql_cmd_parser=output_parser,
150
+ native_format=True
151
+ )
152
+ sql_retriever = SQLDatabaseChainRetriever(
153
+ sql_db_chain=sql_query_chain, page_content_key="abstract")
154
+ sql_chain = RetrievalQAWithSourcesChain.from_chain_type(
155
+ ChatOpenAI(model_name='gpt-3.5-turbo-16k',
156
+ openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.6),
157
+ retriever=sql_retriever,
158
+ chain_type='stuff',
159
+ chain_type_kwargs={
160
+ 'prompt': COMBINE_PROMPT,
161
+ 'document_prompt': document_with_metadata_prompt,
162
+ }, return_source_documents=True)
163
+
164
+ return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], retriever, chain, sql_retriever, sql_chain
165
 
166
 
167
# Build the retrievers and QA chains once per session and keep them in
# Streamlit's session state so later reruns of the script reuse them.
if 'retriever' not in st.session_state:
    (built_columns,
     built_retriever,
     built_chain,
     built_sql_retriever,
     built_sql_chain) = build_retriever()
    st.session_state['metadata_columns'] = built_columns
    st.session_state['retriever'] = built_retriever
    st.session_state['chain'] = built_chain
    st.session_state['sql_retriever'] = built_sql_retriever
    st.session_state['sql_chain'] = built_sql_chain
173
+
174
# Usage hints shown above the tabs, assembled from individual markdown lines.
usage_hint = "".join([
    "We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n",
    "For example: \n\n",
    "*If you want to search papers with complex filters*:\n\n",
    "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n",
    "*If you want to ask questions based on papers in database*:\n\n",
    "- What is PageRank?\n",
    "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n",
    "- Introduce some applications of GANs published around 2019.\n",
    "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n",
    "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?",
])
st.info(usage_hint)

# One tab per retrieval strategy; each tab owns its own input box and buttons.
tab_labels = ['Vector SQL', 'Self-Query Retrievers']
tab_sql, tab_self_query = st.tabs(tab_labels)
185
with tab_sql:
    # Vector SQL tab: an LLM writes a MyScale SQL query from the question.
    # `Query` lists the matching papers, `Ask` additionally answers the
    # question from them.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    # Show the table schema so users know which columns they can filter on.
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    `pubdate` DateTime,
    `title` String,
    `categories` Array(String),
    `authors` Array(String),
    `comment` String,
    `primary_category` String,
    VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine'),
    CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')

    st.text_input("Ask a question:", key='query_sql')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_sql')
    cols[1].button("Ask", key='ask_sql')
    plc_hldr = st.empty()
    if st.session_state.search_sql:
        plc_hldr = st.empty()
        print(st.session_state.query_sql)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSQLSearchCallBackHandler()
            try:
                docs = st.session_state.sql_retriever.get_relevant_documents(
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs)
            except Exception as e:
                # Best-effort UI: keep the page alive, but leave a trace in
                # the server log instead of dropping the error entirely.
                print(f"Vector SQL query failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')

    if st.session_state.ask_sql:
        plc_hldr = st.empty()
        print(st.session_state.query_sql)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSQLAskCallBackHandler()
            try:
                ret = st.session_state.sql_chain(
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                docs = ret['source_documents']
                # Raw string with escaped dots; `v\d+` keeps multi-digit
                # version suffixes (v10, v11, ...) intact so they still match
                # the document id.
                arxiv_ref = re.compile(r'(http://arxiv\.org/abs/\d{4}\.\d+v\d+)')
                # Papers cited either in the sources field or in the answer.
                cited_ids = set(arxiv_ref.findall(ret['sources']))
                cited_ids.update(arxiv_ref.findall(ret['answer']))
                docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
                                     for d in docs if d.metadata['id'] in cited_ids])
                display(docs, columns)
            except Exception as e:
                # Best-effort UI: keep the page alive, but leave a trace in
                # the server log instead of dropping the error entirely.
                print(f"Vector SQL ask failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')
247
+
248
+
249
with tab_self_query:
    # Self-query tab: the retriever turns the question into metadata filters
    # plus a vector search. `Query` lists the matching papers, `Ask`
    # additionally answers the question from them.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    # Show the metadata fields users can refer to in their questions.
    st.dataframe(st.session_state.metadata_columns)
    st.text_input("Ask a question:", key='query_self')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_self')
    cols[1].button("Ask", key='ask_self')
    plc_hldr = st.empty()
    if st.session_state.search_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSelfSearchCallBackHandler()
            try:
                docs = st.session_state.retriever.get_relevant_documents(
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])

                display(docs, columns)
            except Exception as e:
                # Best-effort UI: keep the page alive, but leave a trace in
                # the server log instead of dropping the error entirely.
                print(f"Self-query search failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')

    if st.session_state.ask_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSelfAskCallBackHandler()
            try:
                ret = st.session_state.chain(
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                docs = ret['source_documents']
                # Raw string with escaped dots; `v\d+` keeps multi-digit
                # version suffixes (v10, v11, ...) intact so they still match
                # the document id.
                arxiv_ref = re.compile(r'(http://arxiv\.org/abs/\d{4}\.\d+v\d+)')
                # Papers cited either in the sources field or in the answer.
                cited_ids = set(arxiv_ref.findall(ret['sources']))
                cited_ids.update(arxiv_ref.findall(ret['answer']))
                docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
                                     for d in docs if d.metadata['id'] in cited_ids])
                display(docs, columns)
            except Exception as e:
                # Best-effort UI: keep the page alive, but leave a trace in
                # the server log instead of dropping the error entirely.
                print(f"Self-query ask failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')
callbacks/arxiv_callbacks.py CHANGED
@@ -1,7 +1,9 @@
1
  import streamlit as st
 
 
2
  from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
3
 
4
- class ChatDataSearchCallBackHandler(StreamlitCallbackHandler):
5
  def __init__(self) -> None:
6
  self.progress_bar = st.progress(value=0.0, text="Working...")
7
  self.tokens_stream = ""
@@ -20,7 +22,7 @@ class ChatDataSearchCallBackHandler(StreamlitCallbackHandler):
20
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
21
  pass
22
 
23
- class ChatDataAskCallBackHandler(StreamlitCallbackHandler):
24
  def __init__(self) -> None:
25
  self.progress_bar = st.progress(value=0.0, text='Searching DB...')
26
  self.status_bar = st.empty()
@@ -47,4 +49,38 @@ class ChatDataAskCallBackHandler(StreamlitCallbackHandler):
47
  self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')
48
 
49
  def on_chain_end(self, outputs, **kwargs) -> None:
50
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from typing import Dict, Any
3
+ from sql_formatter.core import format_sql
4
  from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
5
 
6
+ class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
7
  def __init__(self) -> None:
8
  self.progress_bar = st.progress(value=0.0, text="Working...")
9
  self.tokens_stream = ""
 
22
  def on_chain_start(self, serialized, inputs, **kwargs) -> None:
23
  pass
24
 
25
+ class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
26
  def __init__(self) -> None:
27
  self.progress_bar = st.progress(value=0.0, text='Searching DB...')
28
  self.status_bar = st.empty()
 
49
  self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')
50
 
51
  def on_chain_end(self, outputs, **kwargs) -> None:
52
+ pass
53
+
54
+
55
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Streamlit progress reporting for the Vector SQL `Query` action.

    Each chain start and each text event advances a progress bar by
    ``prog_interval``; text that starts with ``SELECT`` is also shown to the
    user as formatted SQL.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.2

    def _advance(self, text: str) -> None:
        """Move the bar one step forward, never past 100%."""
        # st.progress rejects floats above 1.0, and the number of callback
        # events is not bounded, so clamp instead of letting the sum overflow.
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        """Show generated SQL (if any) and advance the progress bar."""
        if text.startswith('SELECT'):
            st.write('We generated Vector SQL for you:')
            st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
            print(f"Vector SQL: {text}")
        self._advance("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the bar and name the chain that just started."""
        cid = '.'.join(serialized['id'])
        self._advance(f'Running Chain `{cid}`...')

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
80
+
81
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """Progress reporting for the Vector SQL `Ask` action.

    Identical to the search handler except that each event advances the
    progress bar by 0.1 instead of 0.2.
    """

    def __init__(self) -> None:
        # The parent creates the progress bar, status placeholder and counter;
        # only the step size differs here.
        super().__init__()
        self.prog_interval = 0.1
prompts/arxiv_prompt.py CHANGED
@@ -10,3 +10,102 @@ combine_prompt_template_ = (
10
 
11
  combine_prompt_template = combine_prompt_template_ + combine_prompt_template
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  combine_prompt_template = combine_prompt_template_ + combine_prompt_template
12
 
13
+ _myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
14
+ MyScale queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
15
+ When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.
16
+
17
+ *NOTICE*: `DISTANCE(column, array)` only accept an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.
18
+
19
+ Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
20
+ Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
21
+ Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
22
+ Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.
23
+ Pay attention to the data type when using functions. Always use `AND` to connect conditions in `WHERE` and never use comma.
24
+ Make sure you never write an isolated `WHERE` keyword and never use undesired condition to conrtain the query.
25
+
26
+ Use the following format:
27
+
28
+ ======== table info ========
29
+ <some table infos>
30
+
31
+ Question: "Question here"
32
+ SQLQuery: "SQL Query to run"
33
+
34
+
35
+ Here are some examples:
36
+
37
+ ======== table info ========
38
+ CREATE TABLE "ChatPaper" (
39
+ abstract String,
40
+ id String,
41
+ vector Array(Float32),
42
+ ) ENGINE = ReplicatedReplacingMergeTree()
43
+ ORDER BY id
44
+ PRIMARY KEY id
45
+
46
+ Question: What is Feartue Pyramid Network?
47
+ SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}
48
+
49
+
50
+ ======== table info ========
51
+ CREATE TABLE "ChatPaper" (
52
+ abstract String,
53
+ id String,
54
+ vector Array(Float32),
55
+ categories Array(String),
56
+ pubdate DateTime,
57
+ title String,
58
+ authors Array(String),
59
+ primary_category String
60
+ ) ENGINE = ReplicatedReplacingMergeTree()
61
+ ORDER BY id
62
+ PRIMARY KEY id
63
+
64
+ Question: What is PaperRank? What is the contribution of those works? Use paper with more than 2 categories.
65
+ SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}
66
+
67
+
68
+ ======== table info ========
69
+ CREATE TABLE "ChatArXiv" (
70
+ primary_category String
71
+ categories Array(String),
72
+ pubdate DateTime,
73
+ abstract String,
74
+ title String,
75
+ paper_id String,
76
+ vector Array(Float32),
77
+ authors Array(String),
78
+ ) ENGINE = MergeTree()
79
+ ORDER BY paper_id
80
+ PRIMARY KEY paper_id
81
+
82
+ Question: Did Geoffrey Hinton wrote about Capsule Neural Networks? Please use articles published later than 2021.
83
+ SQLQuery: SELECT ChatArXiv.title, ChatArXiv.paper_id, ChatArXiv.authors FROM ChatArXiv WHERE has(authors, 'Geoffrey Hinton') AND pubdate > parseDateTimeBestEffort('2021-01-01') ORDER BY DISTANCE(vector, NeuralArray(Capsule Neural Networks)) LIMIT {top_k}
84
+
85
+
86
+ ======== table info ========
87
+ CREATE TABLE "PaperDatabase" (
88
+ abstract String,
89
+ categories Array(String),
90
+ vector Array(Float32),
91
+ pubdate DateTime,
92
+ id String,
93
+ comments String,
94
+ title String,
95
+ authors Array(String),
96
+ primary_category String
97
+ ) ENGINE = MergeTree()
98
+ ORDER BY id
99
+ PRIMARY KEY id
100
+
101
+ Question: Find papers whose abstract has Mutual Information in it.
102
+ SQLQuery: SELECT PaperDatabase.title, PaperDatabase.id FROM PaperDatabase WHERE abstract ILIKE '%Mutual Information%' ORDER BY DISTANCE(vector, NeuralArray(Mutual Information)) LIMIT {top_k}
103
+
104
+
105
+ Let's begin:
106
+
107
+ ======== table info ========
108
+ {table_info}
109
+
110
+ Question: {input}
111
+ SQLQuery: """
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- langchain @ git+https://github.com/myscale/langchain.git@master
2
  InstructorEmbedding
3
  pandas
4
  sentence_transformers
 
1
+ langchain @ git+https://github.com/myscale/langchain.git@preview
2
  InstructorEmbedding
3
  pandas
4
  sentence_transformers