Spaces:
Running
Running
Update chains/arxiv_chains.py (#2)
Browse files- Update chains/arxiv_chains.py (9fee0ab6e399bb735c6cd7d0de5e6f75be61def0)
- Update app.py (0dd8fb3e3644e91781ff973e6c5c1932f08159ae)
- app.py +2 -2
- chains/arxiv_chains.py +17 -0
app.py
CHANGED
@@ -22,8 +22,8 @@ from langchain.chains import LLMChain
|
|
22 |
from langchain_experimental.utilities.sql_database import SQLDatabase
|
23 |
from langchain_experimental.retrievers.sql_database import SQLDatabaseChainRetriever
|
24 |
from langchain_experimental.sql.base import SQLDatabaseChain
|
25 |
-
from langchain_experimental.sql.parser import VectorSQLRetrieveAllOutputParser
|
26 |
|
|
|
27 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
28 |
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
29 |
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
@@ -155,7 +155,7 @@ def build_retriever():
|
|
155 |
template=_myscale_prompt,
|
156 |
)
|
157 |
|
158 |
-
output_parser =
|
159 |
model=embeddings)
|
160 |
sql_query_chain = SQLDatabaseChain.from_llm(
|
161 |
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
|
|
22 |
from langchain_experimental.utilities.sql_database import SQLDatabase
|
23 |
from langchain_experimental.retrievers.sql_database import SQLDatabaseChainRetriever
|
24 |
from langchain_experimental.sql.base import SQLDatabaseChain
|
|
|
25 |
|
26 |
+
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
27 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
28 |
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
29 |
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
|
|
155 |
template=_myscale_prompt,
|
156 |
)
|
157 |
|
158 |
+
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
159 |
model=embeddings)
|
160 |
sql_query_chain = SQLDatabaseChain.from_llm(
|
161 |
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
chains/arxiv_chains.py
CHANGED
@@ -16,6 +16,23 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
20 |
"""Combine arxiv documents with PDF reference number"""
|
21 |
|
|
|
16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
17 |
|
18 |
|
19 |
+
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
20 |
+
"""Based on VectorSQLOutputParser
|
21 |
+
It also modify the SQL to get all columns
|
22 |
+
"""
|
23 |
+
|
24 |
+
@property
|
25 |
+
def _type(self) -> str:
|
26 |
+
return "vector_sql_retrieve_custom"
|
27 |
+
|
28 |
+
def parse(self, text: str) -> Dict[str, Any]:
|
29 |
+
text = text.strip()
|
30 |
+
start = text.upper().find("SELECT")
|
31 |
+
if start >= 0:
|
32 |
+
end = text.upper().find("FROM")
|
33 |
+
text = text.replace(text[start + len("SELECT") + 1 : end - 1], "title, abstract, authors, pubdate, categories, id")
|
34 |
+
return super().parse(text)
|
35 |
+
|
36 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
37 |
"""Combine arxiv documents with PDF reference number"""
|
38 |
|