Spaces:
Running
Running
Fangrui Liu
committed on
Commit
·
9061790
1
Parent(s):
abcac4c
Add text 2 sql query & ask
Browse files- app.py +207 -76
- callbacks/arxiv_callbacks.py +39 -3
- prompts/arxiv_prompt.py +99 -0
- requirements.txt +1 -1
app.py
CHANGED
@@ -2,6 +2,7 @@ import re
|
|
2 |
import pandas as pd
|
3 |
from os import environ
|
4 |
import streamlit as st
|
|
|
5 |
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
6 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
7 |
|
@@ -13,9 +14,17 @@ from langchain.chains import RetrievalQAWithSourcesChain
|
|
13 |
from langchain import OpenAI
|
14 |
from langchain.chat_models import ChatOpenAI
|
15 |
|
16 |
-
from prompts.arxiv_prompt import combine_prompt_template
|
17 |
-
from callbacks.arxiv_callbacks import
|
|
|
|
|
18 |
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
st.set_page_config(page_title="ChatData")
|
@@ -25,13 +34,24 @@ st.header("ChatData")
|
|
25 |
columns = ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
|
26 |
|
27 |
|
28 |
-
def
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
else:
|
32 |
-
st.write("Sorry 😵 we didn't find any articles related to your query.\nPlease use verbs that may match the datatype.", unsafe_allow_html=True)
|
33 |
|
34 |
-
|
|
|
35 |
def build_retriever():
|
36 |
with st.spinner("Loading Model..."):
|
37 |
embeddings = HuggingFaceInstructEmbeddings(
|
@@ -88,79 +108,190 @@ def build_retriever():
|
|
88 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
89 |
use_original_query=False)
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
|
107 |
if 'retriever' not in st.session_state:
|
108 |
st.session_state['metadata_columns'], \
|
109 |
st.session_state['retriever'], \
|
110 |
-
st.session_state['chain']
|
111 |
-
|
112 |
-
st.
|
113 |
-
|
114 |
-
|
115 |
-
"
|
116 |
-
|
117 |
-
"-
|
118 |
-
"
|
|
|
|
|
|
|
|
|
119 |
"- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?")
|
120 |
-
|
121 |
-
|
122 |
-
st.
|
123 |
-
st.
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
plc_hldr = st.empty()
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import pandas as pd
|
3 |
from os import environ
|
4 |
import streamlit as st
|
5 |
+
import datetime
|
6 |
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
7 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
8 |
|
|
|
14 |
from langchain import OpenAI
|
15 |
from langchain.chat_models import ChatOpenAI
|
16 |
|
17 |
+
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
18 |
+
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
19 |
+
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
20 |
+
ChatDataSQLAskCallBackHandler
|
21 |
from langchain.prompts.prompt import PromptTemplate
|
22 |
+
from sqlalchemy import create_engine, MetaData
|
23 |
+
from langchain.chains.sql_database.base import SQLDatabaseChain
|
24 |
+
from langchain.chains.sql_database.parser import VectorSQLRetrieveAllOutputParser
|
25 |
+
from langchain.chains import LLMChain
|
26 |
+
from langchain.sql_database import SQLDatabase
|
27 |
+
from langchain.retrievers import SQLDatabaseChainRetriever
|
28 |
|
29 |
|
30 |
st.set_page_config(page_title="ChatData")
|
|
|
34 |
columns = ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
|
35 |
|
36 |
|
37 |
+
def try_eval(x):
    """Evaluate ``x`` as a Python expression, falling back to ``x`` itself.

    The expression is evaluated with only the ``datetime`` module exposed as
    a global name (Python still injects ``__builtins__``).

    Args:
        x: Value to evaluate; normally a string holding a Python literal or
            expression such as ``"datetime.date(2020, 1, 1)"``.

    Returns:
        The evaluated object, or ``x`` unchanged when evaluation fails for
        any reason (syntax error, unknown name, non-string input, ...).
    """
    try:
        # SECURITY NOTE(review): eval() executes arbitrary code. Only call
        # this on trusted values; never on raw user or LLM output.
        return eval(x, {'datetime': datetime})
    except Exception:
        # Deliberately best-effort, but do not trap KeyboardInterrupt or
        # SystemExit the way a bare ``except:`` would.
        return x
|
42 |
+
|
43 |
+
|
44 |
+
def display(dataframe, columns=None):
    """Render ``dataframe`` in the Streamlit page, or an apology if it is empty.

    Args:
        dataframe: Table of results to show.
        columns: Optional column selection; when truthy only
            ``dataframe[columns]`` is rendered, otherwise the whole table.
    """
    # Guard clause: nothing to show, so tell the user and stop.
    if len(dataframe) == 0:
        st.write("Sorry 😵 we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
        return

    visible = dataframe[columns] if columns else dataframe
    st.dataframe(visible)
|
52 |
|
53 |
+
|
54 |
+
@st.cache_resource
|
55 |
def build_retriever():
|
56 |
with st.spinner("Loading Model..."):
|
57 |
embeddings = HuggingFaceInstructEmbeddings(
|
|
|
108 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
109 |
use_original_query=False)
|
110 |
|
111 |
+
with st.spinner('Building RetrievalQAWith SourcesChain...'):
|
112 |
+
document_with_metadata_prompt = PromptTemplate(
|
113 |
+
input_variables=["page_content", "id", "title",
|
114 |
+
"authors", "pubdate", "categories"],
|
115 |
+
template="Content:\n\tTitle: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}")
|
116 |
+
COMBINE_PROMPT = PromptTemplate(
|
117 |
+
template=combine_prompt_template, input_variables=["summaries", "question"])
|
118 |
+
chain = RetrievalQAWithSourcesChain.from_chain_type(
|
119 |
+
ChatOpenAI(model_name='gpt-3.5-turbo-16k',
|
120 |
+
openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.6),
|
121 |
+
retriever=retriever,
|
122 |
+
chain_type='stuff',
|
123 |
+
chain_type_kwargs={
|
124 |
+
'prompt': COMBINE_PROMPT,
|
125 |
+
'document_prompt': document_with_metadata_prompt,
|
126 |
+
}, return_source_documents=True)
|
127 |
+
|
128 |
+
with st.spinner('Building Vector SQL Database Chain'):
|
129 |
+
MYSCALE_USER = st.secrets['MYSCALE_USER']
|
130 |
+
MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
|
131 |
+
MYSCALE_HOST = st.secrets['MYSCALE_HOST']
|
132 |
+
MYSCALE_PORT = st.secrets['MYSCALE_PORT']
|
133 |
+
engine = create_engine(
|
134 |
+
f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https')
|
135 |
+
metadata = MetaData(bind=engine)
|
136 |
+
PROMPT = PromptTemplate(
|
137 |
+
input_variables=["input", "table_info", "top_k"],
|
138 |
+
template=_myscale_prompt,
|
139 |
+
)
|
140 |
+
|
141 |
+
output_parser = VectorSQLRetrieveAllOutputParser.from_embeddings(
|
142 |
+
model=embeddings)
|
143 |
+
sql_query_chain = SQLDatabaseChain.from_llm(
|
144 |
+
llm=OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
|
145 |
+
prompt=PROMPT,
|
146 |
+
top_k=10,
|
147 |
+
return_direct=True,
|
148 |
+
db=SQLDatabase(engine, None, metadata, max_string_length=1024),
|
149 |
+
sql_cmd_parser=output_parser,
|
150 |
+
native_format=True
|
151 |
+
)
|
152 |
+
sql_retriever = SQLDatabaseChainRetriever(
|
153 |
+
sql_db_chain=sql_query_chain, page_content_key="abstract")
|
154 |
+
sql_chain = RetrievalQAWithSourcesChain.from_chain_type(
|
155 |
+
ChatOpenAI(model_name='gpt-3.5-turbo-16k',
|
156 |
+
openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0.6),
|
157 |
+
retriever=sql_retriever,
|
158 |
+
chain_type='stuff',
|
159 |
+
chain_type_kwargs={
|
160 |
+
'prompt': COMBINE_PROMPT,
|
161 |
+
'document_prompt': document_with_metadata_prompt,
|
162 |
+
}, return_source_documents=True)
|
163 |
+
|
164 |
+
return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], retriever, chain, sql_retriever, sql_chain
|
165 |
|
166 |
|
167 |
# Populate the session with the retrievers and QA chains, but only when the
# session does not hold a 'retriever' entry yet.
if 'retriever' not in st.session_state:
    (
        st.session_state['metadata_columns'],
        st.session_state['retriever'],
        st.session_state['chain'],
        st.session_state['sql_retriever'],
        st.session_state['sql_chain'],
    ) = build_retriever()
|
173 |
+
|
174 |
+
# Introductory banner: explains how to phrase filters and lists sample
# questions (English, Chinese and French).
st.info("".join((
    "We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n",
    "For example: \n\n",
    "*If you want to search papers with complex filters*:\n\n",
    "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n",
    "*If you want to ask questions based on papers in database*:\n\n",
    "- What is PageRank?\n",
    "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n",
    "- Introduce some applications of GANs published around 2019.\n",
    "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n",
    "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?",
)))
|
184 |
+
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
with tab_sql:
    # Vector SQL tab: the LLM writes a SQL query; `Query` lists the matching
    # papers, `Ask` additionally answers the question from them.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    # Show the table schema so users can see which columns exist.
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
`abstract` String,
`id` String,
`vector` Array(Float32),
`metadata` Object('JSON'),
`pubdate` DateTime,
`title` String,
`categories` Array(String),
`authors` Array(String),
`comment` String,
`primary_category` String,
VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine'),
CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')

    st.text_input("Ask a question:", key='query_sql')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_sql')
    cols[1].button("Ask", key='ask_sql')
    plc_hldr = st.empty()

    if st.session_state.search_sql:
        plc_hldr = st.empty()
        print(st.session_state.query_sql)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSQLSearchCallBackHandler()
            try:
                docs = st.session_state.sql_retriever.get_relevant_documents(
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs)
            except Exception as e:
                # Best-effort UI: keep the page alive, but leave a trace on
                # stdout (the file's existing logging channel) instead of
                # discarding the error silently.
                print(f"Vector SQL query failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')

    if st.session_state.ask_sql:
        plc_hldr = st.empty()
        print(st.session_state.query_sql)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSQLAskCallBackHandler()
            try:
                ret = st.session_state.sql_chain(
                    st.session_state.query_sql, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                # Raw string avoids invalid-escape warnings; dots are escaped
                # and the version suffix accepts several digits (v10, v11...).
                arxiv_url = re.compile(
                    r'(http://arxiv\.org/abs/\d{4}\.\d+v\d+)')
                # Collect every cited URL once, then filter with O(1) lookups.
                cited = set(arxiv_url.findall(ret['sources']))
                cited.update(arxiv_url.findall(ret['answer']))
                docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
                                     for d in ret['source_documents']
                                     if d.metadata['id'] in cited])
                display(docs, columns)
            except Exception as e:
                # Same best-effort policy as the `Query` branch above.
                print(f"Vector SQL ask failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')
|
247 |
+
|
248 |
+
|
249 |
+
with tab_self_query:
    # Self-query tab: the retriever derives metadata filters from the
    # question; `Query` lists papers, `Ask` answers from them.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    # Show the metadata fields that can be used in filters.
    st.dataframe(st.session_state.metadata_columns)
    st.text_input("Ask a question:", key='query_self')
    cols = st.columns([1, 1, 7])
    cols[0].button("Query", key='search_self')
    cols[1].button("Ask", key='ask_self')
    plc_hldr = st.empty()

    if st.session_state.search_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Query Log', expanded=True):
            callback = ChatDataSelfSearchCallBackHandler()
            try:
                docs = st.session_state.retriever.get_relevant_documents(
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                docs = pd.DataFrame(
                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
                display(docs, columns)
            except Exception as e:
                # Best-effort UI: keep the page alive, but leave a trace on
                # stdout (the file's existing logging channel) instead of
                # discarding the error silently.
                print(f"Self-query search failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')

    if st.session_state.ask_self:
        plc_hldr = st.empty()
        print(st.session_state.query_self)
        with plc_hldr.expander('Chat Log', expanded=True):
            callback = ChatDataSelfAskCallBackHandler()
            try:
                ret = st.session_state.chain(
                    st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(value=1.0, text="Done!")
                st.markdown(
                    f"### Answer from LLM\n{ret['answer']}\n### References")
                # Raw string avoids invalid-escape warnings; dots are escaped
                # and the version suffix accepts several digits (v10, v11...).
                arxiv_url = re.compile(
                    r'(http://arxiv\.org/abs/\d{4}\.\d+v\d+)')
                # Collect every cited URL once, then filter with O(1) lookups.
                cited = set(arxiv_url.findall(ret['sources']))
                cited.update(arxiv_url.findall(ret['answer']))
                docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
                                     for d in ret['source_documents']
                                     if d.metadata['id'] in cited])
                display(docs, columns)
            except Exception as e:
                # Same best-effort policy as the `Query` branch above.
                print(f"Self-query ask failed: {e!r}")
                st.write('Oops 😵 Something bad happened...')
|
callbacks/arxiv_callbacks.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import streamlit as st
|
|
|
|
|
2 |
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
3 |
|
4 |
-
class
|
5 |
def __init__(self) -> None:
|
6 |
self.progress_bar = st.progress(value=0.0, text="Working...")
|
7 |
self.tokens_stream = ""
|
@@ -20,7 +22,7 @@ class ChatDataSearchCallBackHandler(StreamlitCallbackHandler):
|
|
20 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
21 |
pass
|
22 |
|
23 |
-
class
|
24 |
def __init__(self) -> None:
|
25 |
self.progress_bar = st.progress(value=0.0, text='Searching DB...')
|
26 |
self.status_bar = st.empty()
|
@@ -47,4 +49,38 @@ class ChatDataAskCallBackHandler(StreamlitCallbackHandler):
|
|
47 |
self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')
|
48 |
|
49 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
50 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from typing import Dict, Any
|
3 |
+
from sql_formatter.core import format_sql
|
4 |
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
5 |
|
6 |
+
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
7 |
def __init__(self) -> None:
|
8 |
self.progress_bar = st.progress(value=0.0, text="Working...")
|
9 |
self.tokens_stream = ""
|
|
|
22 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
23 |
pass
|
24 |
|
25 |
+
class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
|
26 |
def __init__(self) -> None:
|
27 |
self.progress_bar = st.progress(value=0.0, text='Searching DB...')
|
28 |
self.status_bar = st.empty()
|
|
|
49 |
self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')
|
50 |
|
51 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
52 |
+
pass
|
53 |
+
|
54 |
+
|
55 |
+
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Callback handler that shows Vector SQL search progress in Streamlit.

    Creates a progress bar and a status placeholder on construction, echoes
    the generated SQL when it arrives through ``on_text`` and advances the
    bar by ``prog_interval`` on every chain start and SQL text event.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
        self.status_bar = st.empty()
        self.prog_value = 0
        # Fraction of the bar consumed by each reported step.
        self.prog_interval = 0.2

    def _advance(self, text: str) -> None:
        """Move the bar one step forward and update its caption.

        The value is clamped to 1.0: the number of callback events is not
        known in advance and ``st.progress`` rejects floats above 1.0.
        """
        self.prog_value = min(1.0, self.prog_value + self.prog_interval)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Ignore LLM start events; nothing is displayed for them."""
        pass

    def on_text(self, text: str, **kwargs) -> None:
        """Display ``text`` as formatted SQL when it starts with ``SELECT``."""
        if text.startswith('SELECT'):
            st.write('We generated Vector SQL for you:')
            st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
            print(f"Vector SQL: {text}")
            self._advance("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Advance the bar and name the chain (dotted ``serialized['id']``)."""
        cid = '.'.join(serialized['id'])
        self._advance(f'Running Chain `{cid}`...')

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Ignore chain end events; nothing is displayed for them."""
        pass
|
80 |
+
|
81 |
+
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """Vector SQL handler for the `Ask` flow; uses a smaller progress step."""

    def __init__(self) -> None:
        # The parent builds the same progress bar, status placeholder and
        # zeroed counter; only the per-step increment differs here.
        super().__init__()
        self.prog_interval = 0.1
|
prompts/arxiv_prompt.py
CHANGED
@@ -10,3 +10,102 @@ combine_prompt_template_ = (
|
|
10 |
|
11 |
combine_prompt_template = combine_prompt_template_ + combine_prompt_template
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
combine_prompt_template = combine_prompt_template_ + combine_prompt_template
|
12 |
|
13 |
+
# Few-shot prompt for generating MyScale vector SQL. Placeholders filled at
# run time: {top_k}, {table_info} and {input}. Each example must only use
# columns that exist in its own table and must match its own question.
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
MyScale queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accept an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.
Pay attention to the data type when using functions. Always use `AND` to connect conditions in `WHERE` and never use comma.
Make sure you never write an isolated `WHERE` keyword and never use undesired condition to constrain the query.

Use the following format:

======== table info ========
<some table infos>

Question: "Question here"
SQLQuery: "SQL Query to run"


Here are some examples:

======== table info ========
CREATE TABLE "ChatPaper" (
abstract String,
id String,
vector Array(Float32),
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is Feature Pyramid Network?
SQLQuery: SELECT ChatPaper.abstract, ChatPaper.id FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(Feature Pyramid Network)) LIMIT {top_k}


======== table info ========
CREATE TABLE "ChatPaper" (
abstract String,
id String,
vector Array(Float32),
categories Array(String),
pubdate DateTime,
title String,
authors Array(String),
primary_category String
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is PaperRank? What is the contribution of those works? Use paper with more than 2 categories.
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}


======== table info ========
CREATE TABLE "ChatArXiv" (
primary_category String
categories Array(String),
pubdate DateTime,
abstract String,
title String,
paper_id String,
vector Array(Float32),
authors Array(String),
) ENGINE = MergeTree()
ORDER BY paper_id
PRIMARY KEY paper_id

Question: Did Geoffrey Hinton wrote about Capsule Neural Networks? Please use articles published later than 2021.
SQLQuery: SELECT ChatArXiv.title, ChatArXiv.paper_id, ChatArXiv.authors FROM ChatArXiv WHERE has(authors, 'Geoffrey Hinton') AND pubdate > parseDateTimeBestEffort('2021-01-01') ORDER BY DISTANCE(vector, NeuralArray(Capsule Neural Networks)) LIMIT {top_k}


======== table info ========
CREATE TABLE "PaperDatabase" (
abstract String,
categories Array(String),
vector Array(Float32),
pubdate DateTime,
id String,
comments String,
title String,
authors Array(String),
primary_category String
) ENGINE = MergeTree()
ORDER BY id
PRIMARY KEY id

Question: Find papers whose abstract has Mutual Information in it.
SQLQuery: SELECT PaperDatabase.title, PaperDatabase.id FROM PaperDatabase WHERE abstract ILIKE '%Mutual Information%' ORDER BY DISTANCE(vector, NeuralArray(Mutual Information)) LIMIT {top_k}


Let's begin:

======== table info ========
{table_info}

Question: {input}
SQLQuery: """
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
langchain @ git+https://github.com/myscale/langchain.git@
|
2 |
InstructorEmbedding
|
3 |
pandas
|
4 |
sentence_transformers
|
|
|
1 |
+
langchain @ git+https://github.com/myscale/langchain.git@preview
|
2 |
InstructorEmbedding
|
3 |
pandas
|
4 |
sentence_transformers
|