mpsk commited on
Commit
526644e
Β·
1 Parent(s): 3e40ebd

update to new preview

Browse files
Files changed (2) hide show
  1. app.py +9 -8
  2. chains/arxiv_chains.py +1 -1
app.py CHANGED
@@ -9,7 +9,7 @@ environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
9
  from langchain.vectorstores import MyScale, MyScaleSettings
10
  from langchain.embeddings import HuggingFaceInstructEmbeddings
11
  from langchain.retrievers.self_query.base import SelfQueryRetriever
12
- from langchain.chains.query_constructor.base import AttributeInfo
13
  from langchain import OpenAI
14
  from langchain.chat_models import ChatOpenAI
15
 
@@ -19,9 +19,9 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
19
  from sqlalchemy import create_engine, MetaData
20
  from langchain.chains import LLMChain
21
 
22
- from langchain_experimental.utilities.sql_database import SQLDatabase
23
- from langchain_experimental.retrievers.sql_database import SQLDatabaseChainRetriever
24
- from langchain_experimental.sql.base import SQLDatabaseChain
25
 
26
  from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
27
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
@@ -82,7 +82,7 @@ def build_retriever():
82
  with st.spinner("Building Self Query Retriever..."):
83
  metadata_field_info = [
84
  AttributeInfo(
85
- name="pubdate",
86
  description="The year the paper is published",
87
  type="timestamp",
88
  ),
@@ -155,7 +155,7 @@ def build_retriever():
155
 
156
  output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
157
  model=embeddings)
158
- sql_query_chain = SQLDatabaseChain.from_llm(
159
  llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
160
  prompt=PROMPT,
161
  top_k=10,
@@ -164,7 +164,7 @@ def build_retriever():
164
  sql_cmd_parser=output_parser,
165
  native_format=True
166
  )
167
- sql_retriever = SQLDatabaseChainRetriever(
168
  sql_db_chain=sql_query_chain, page_content_key="abstract")
169
 
170
  with st.spinner('Building QA Chain with Vector SQL...'):
@@ -184,7 +184,8 @@ def build_retriever():
184
  max_tokens_limit=12000,
185
  )
186
 
187
- return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], retriever, chain, sql_retriever, sql_chain
 
188
 
189
 
190
  if 'retriever' not in st.session_state:
 
9
  from langchain.vectorstores import MyScale, MyScaleSettings
10
  from langchain.embeddings import HuggingFaceInstructEmbeddings
11
  from langchain.retrievers.self_query.base import SelfQueryRetriever
12
+ from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
13
  from langchain import OpenAI
14
  from langchain.chat_models import ChatOpenAI
15
 
 
19
  from sqlalchemy import create_engine, MetaData
20
  from langchain.chains import LLMChain
21
 
22
+ from langchain.utilities.sql_database import SQLDatabase
23
+ from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
24
+ from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
25
 
26
  from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
27
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
 
82
  with st.spinner("Building Self Query Retriever..."):
83
  metadata_field_info = [
84
  AttributeInfo(
85
+ name=VirtualColumnName(name="pubdate"),
86
  description="The year the paper is published",
87
  type="timestamp",
88
  ),
 
155
 
156
  output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
157
  model=embeddings)
158
+ sql_query_chain = VectorSQLDatabaseChain.from_llm(
159
  llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
160
  prompt=PROMPT,
161
  top_k=10,
 
164
  sql_cmd_parser=output_parser,
165
  native_format=True
166
  )
167
+ sql_retriever = VectorSQLDatabaseChainRetriever(
168
  sql_db_chain=sql_query_chain, page_content_key="abstract")
169
 
170
  with st.spinner('Building QA Chain with Vector SQL...'):
 
184
  max_tokens_limit=12000,
185
  )
186
 
187
+ return [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], \
188
+ retriever, chain, sql_retriever, sql_chain
189
 
190
 
191
  if 'retriever' not in st.session_state:
chains/arxiv_chains.py CHANGED
@@ -15,7 +15,7 @@ from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesCha
15
  from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
16
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
17
 
18
- from langchain_experimental.sql.parser import VectorSQLOutputParser
19
 
20
 
21
  class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
 
15
  from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
16
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
17
 
18
+ from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
19
 
20
 
21
  class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):