Spaces:
Running
Running
update to new preview
Browse files
- app.py +9 -8
- chains/arxiv_chains.py +1 -1
app.py
CHANGED
@@ -9,7 +9,7 @@ environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
|
9 |
from langchain.vectorstores import MyScale, MyScaleSettings
|
10 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
11 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
12 |
-
from langchain.chains.query_constructor.base import AttributeInfo
|
13 |
from langchain import OpenAI
|
14 |
from langchain.chat_models import ChatOpenAI
|
15 |
|
@@ -19,9 +19,9 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
|
19 |
from sqlalchemy import create_engine, MetaData
|
20 |
from langchain.chains import LLMChain
|
21 |
|
22 |
-
from
|
23 |
-
from langchain_experimental.retrievers.
|
24 |
-
from langchain_experimental.sql.
|
25 |
|
26 |
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
27 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
@@ -82,7 +82,7 @@ def build_retriever():
|
|
82 |
with st.spinner("Building Self Query Retriever..."):
|
83 |
metadata_field_info = [
|
84 |
AttributeInfo(
|
85 |
-
name="pubdate",
|
86 |
description="The year the paper is published",
|
87 |
type="timestamp",
|
88 |
),
|
@@ -155,7 +155,7 @@ def build_retriever():
|
|
155 |
|
156 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
157 |
model=embeddings)
|
158 |
-
sql_query_chain =
|
159 |
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
160 |
prompt=PROMPT,
|
161 |
top_k=10,
|
@@ -164,7 +164,7 @@ def build_retriever():
|
|
164 |
sql_cmd_parser=output_parser,
|
165 |
native_format=True
|
166 |
)
|
167 |
-
sql_retriever =
|
168 |
sql_db_chain=sql_query_chain, page_content_key="abstract")
|
169 |
|
170 |
with st.spinner('Building QA Chain with Vector SQL...'):
|
@@ -184,7 +184,8 @@ def build_retriever():
|
|
184 |
max_tokens_limit=12000,
|
185 |
)
|
186 |
|
187 |
-
return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
|
|
|
188 |
|
189 |
|
190 |
if 'retriever' not in st.session_state:
|
|
|
9 |
from langchain.vectorstores import MyScale, MyScaleSettings
|
10 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
11 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
12 |
+
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
13 |
from langchain import OpenAI
|
14 |
from langchain.chat_models import ChatOpenAI
|
15 |
|
|
|
19 |
from sqlalchemy import create_engine, MetaData
|
20 |
from langchain.chains import LLMChain
|
21 |
|
22 |
+
from langchain.utilities.sql_database import SQLDatabase
|
23 |
+
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
24 |
+
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
25 |
|
26 |
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
27 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
|
|
82 |
with st.spinner("Building Self Query Retriever..."):
|
83 |
metadata_field_info = [
|
84 |
AttributeInfo(
|
85 |
+
name=VirtualColumnName(name="pubdate"),
|
86 |
description="The year the paper is published",
|
87 |
type="timestamp",
|
88 |
),
|
|
|
155 |
|
156 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
157 |
model=embeddings)
|
158 |
+
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
159 |
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
160 |
prompt=PROMPT,
|
161 |
top_k=10,
|
|
|
164 |
sql_cmd_parser=output_parser,
|
165 |
native_format=True
|
166 |
)
|
167 |
+
sql_retriever = VectorSQLDatabaseChainRetriever(
|
168 |
sql_db_chain=sql_query_chain, page_content_key="abstract")
|
169 |
|
170 |
with st.spinner('Building QA Chain with Vector SQL...'):
|
|
|
184 |
max_tokens_limit=12000,
|
185 |
)
|
186 |
|
187 |
+
return [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], \
|
188 |
+
retriever, chain, sql_retriever, sql_chain
|
189 |
|
190 |
|
191 |
if 'retriever' not in st.session_state:
|
chains/arxiv_chains.py
CHANGED
@@ -15,7 +15,7 @@ from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesCha
|
|
15 |
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
17 |
|
18 |
-
from langchain_experimental.sql.
|
19 |
|
20 |
|
21 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
|
|
15 |
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
17 |
|
18 |
+
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
|
19 |
|
20 |
|
21 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|