Spaces:
Running
Running
File size: 5,414 Bytes
eb820e1 9fee0ab eb820e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import re
import inspect
from typing import Dict, Any, Optional, List, Tuple
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import Callbacks
from langchain.schema.prompt_template import format_document
from langchain.docstore.document import Document
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
    """Output parser built on top of ``VectorSQLOutputParser``.

    Before delegating to the parent parser, the column list of the first
    ``SELECT ... FROM`` clause in the generated SQL is replaced with a fixed
    set of article columns, so every row carries the same fields.
    """

    @property
    def _type(self) -> str:
        """Identifier string of this parser."""
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        """Force a fixed column list into the SQL, then run the parent parser.

        Args:
            text: Raw text expected to contain a SQL ``SELECT`` statement.

        Returns:
            Whatever the parent ``parse`` returns for the rewritten text.
            When no ``SELECT ... FROM`` pair is present, the stripped text is
            handed to the parent parser unchanged.
        """
        columns = "title, abstract, authors, pubdate, categories, id"
        text = text.strip()
        # Locate the first SELECT and the first FROM that follows it. Matching
        # on the original text (case-insensitively) keeps the offsets valid,
        # and requiring FROM *after* SELECT avoids slicing with a bogus index
        # when the keyword is missing.
        match = re.search(
            r"\bSELECT\b(.*?)\bFROM\b", text, re.IGNORECASE | re.DOTALL
        )
        if match is not None:
            # Splice by position rather than str.replace(): replace() would
            # also rewrite every later occurrence of the same column text
            # (e.g. an ``id`` in the WHERE clause) and misbehaves on an empty
            # column list.
            head = text[: match.start(1)]
            tail = text[match.end(1):]
            text = head + " " + columns + " " + tail
        return super().parse(text)
class ArXivStuffDocumentChain(StuffDocumentsChain):
    """Stuff-documents chain that numbers each arXiv document for referencing."""

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Build the input mapping for the LLM chain from ``docs`` and ``kwargs``.

        Every document is formatted with ``self.document_prompt`` and the
        results are joined with ``self.document_separator`` into a single
        value stored under ``self.document_variable_name``. Keyword arguments
        are forwarded only when the LLM chain's prompt lists them as input
        variables.

        Note that the documents are modified in place: each one receives a
        temporary ``ref_id`` metadata entry (its position in ``docs``) and has
        the newlines in its page content replaced by spaces.

        Args:
            docs: Documents to format and join into one prompt input.
            **kwargs: Additional chain inputs; those the prompt does not
                declare are dropped.

        Returns:
            Dictionary of inputs for the LLM chain.
        """
        rendered: List[str] = []
        for position, document in enumerate(docs):
            # Temporary reference number, stored in the document metadata.
            document.metadata["ref_id"] = position
            document.page_content = document.page_content.replace("\n", " ")
            rendered.append(format_document(document, self.document_prompt))

        # Keep only the extra arguments the prompt actually accepts.
        inputs = {}
        for name, value in kwargs.items():
            if name in self.llm_chain.prompt.input_variables:
                inputs[name] = value

        # All formatted documents go into the prompt as one joined string.
        inputs[self.document_variable_name] = self.document_separator.join(rendered)
        return inputs

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Put all documents into a single prompt and query the LLM.

        Args:
            docs: Documents to merge into one prompt variable.
            callbacks: Optional callbacks forwarded to the LLM chain.
            **kwargs: Extra parameters used when building the LLM chain inputs.

        Returns:
            A pair whose first element is the LLM's string output and whose
            second element is a dictionary of additional keys to return
            (always empty here).
        """
        llm_inputs = self._get_inputs(docs, **kwargs)
        prediction = self.llm_chain.predict(callbacks=callbacks, **llm_inputs)
        return prediction, {}

    @property
    def _chain_type(self) -> str:
        """Identifier string of this chain type."""
        return "referenced_stuff_documents_chain"
class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
    """QA-with-sources chain for the Chat ArXiv app, with references.

    Documents are expected to carry a ``ref_id`` metadata entry by the time
    the combine-documents chain has run. ``PDF #<ref_id>`` markers in the
    answer are then rewritten to ``<title> [<n>]`` and the cited documents
    are returned as the sources.
    """

    @staticmethod
    def _resolve_references(
        answer: str, docs: List[Document]
    ) -> Tuple[str, List[Document]]:
        """Rewrite ``PDF #<ref_id>`` markers in ``answer`` as titled citations.

        Cited documents are renumbered from 1 in the order they appear in
        ``docs``; their ``ref_id`` metadata entry is overwritten with the new
        number. Markers that do not correspond to any document are left as-is.

        Args:
            answer: Text produced by the combine-documents chain.
            docs: Retrieved documents; each must have ``ref_id`` metadata,
                and each cited one must also have ``title`` metadata.

        Returns:
            The rewritten answer and the list of cited documents.
        """
        marker = r"PDF #([0-9]+)"
        # Whole-number matches only, so "PDF #1" is not considered cited
        # merely because "PDF #10" appears in the answer.
        cited = set(re.findall(marker, answer))

        labels: Dict[str, str] = {}
        sources: List[Document] = []
        for doc in docs:
            key = str(doc.metadata["ref_id"])
            if key not in cited or key in labels:
                continue
            title = doc.metadata["title"].replace("\n", "")
            new_id = len(sources) + 1
            doc.metadata["ref_id"] = new_id
            labels[key] = f"{title} [{new_id}]"
            sources.append(doc)

        def _substitute(match: "re.Match[str]") -> str:
            # Unknown reference numbers are kept verbatim.
            return labels.get(match.group(1), match.group(0))

        # Single pass over the original answer: text inserted for one
        # reference (e.g. a title) is never rescanned for further markers.
        # A callable replacement also keeps backslashes in titles literal.
        return re.sub(marker, _substitute, answer), sources

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents, generate an answer and resolve its references.

        Args:
            inputs: Chain inputs; forwarded to ``_get_docs`` and to the
                combine-documents chain.
            run_manager: Optional callback manager for this run; a no-op
                manager is used when omitted.

        Returns:
            Dictionary with the rewritten answer under ``self.answer_key`` and
            the cited documents under ``self.sources_answer_key``. When
            ``self.return_source_documents`` is set, all retrieved documents
            are added under ``"source_documents"``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # _get_docs may or may not accept a run_manager argument; inspect its
        # signature and call it accordingly.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )

        # Parse sources by their ref_id and turn markers into citations.
        answer, sources = self._resolve_references(answer, docs)

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronous execution is not implemented for this chain."""
        raise NotImplementedError

    @property
    def _chain_type(self) -> str:
        """Identifier string of this chain type."""
        return "arxiv_qa_with_sources_chain"
|