Spaces:
Running
Running
File size: 7,298 Bytes
45180a0 eb820e1 45180a0 eb820e1 45180a0 eb820e1 526644e cd8e68f 45180a0 eb820e1 4a0bec4 45180a0 4a0bec4 92dac1d 4a0bec4 eb820e1 45180a0 eb820e1 45180a0 eb820e1 45180a0 eb820e1 45180a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import logging
import inspect
from typing import Dict, Any, Optional, List, Tuple
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import Callbacks
from langchain.schema.prompt_template import format_document
from langchain.docstore.document import Document
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.vectorstores.myscale import MyScale, MyScaleSettings
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
# Module-wide logger. No name is passed, so this is the root logger and
# records go to whatever handlers are configured on it.
logger = logging.getLogger()
class MyScaleWithoutMetadataJson(MyScale):
    """MyScale vector store whose document metadata comes from named columns.

    Every query selects the configured text column, the computed distance and
    each column listed in ``must_have_cols``; those extra columns become the
    ``metadata`` of the returned documents.
    """

    def __init__(
        self,
        embedding: Embeddings,
        config: Optional[MyScaleSettings] = None,
        must_have_cols: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Create the store.

        Args:
            embedding: Embedding model forwarded to ``MyScale``.
            config: Optional MyScale settings forwarded to ``MyScale``.
            must_have_cols: Column names to select and expose as metadata.
                ``None`` (the default) or an empty list means no extra columns.
            **kwargs: Forwarded to ``MyScale.__init__``.
        """
        super().__init__(embedding, config, **kwargs)
        # Copy into a fresh list so instances never share (or mutate) a
        # default list object or the caller's list.
        self.must_have_cols: List[str] = list(must_have_cols) if must_have_cols else []

    def _build_qstr(
        self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        """Build the vector-search SQL statement.

        Args:
            q_emb: Query embedding, rendered as a comma-separated array literal.
            topk: Value used for the ``LIMIT`` clause.
            where_str: Optional raw SQL condition, emitted as a ``PREWHERE``
                clause. It is interpolated verbatim, so it must come from a
                trusted source.

        Returns:
            The SQL query string.
        """
        q_emb_str = ",".join(map(str, q_emb))
        prewhere_clause = f"PREWHERE {where_str}" if where_str else ""
        # Join all selected columns in one go so an empty ``must_have_cols``
        # does not leave a dangling comma after ``dist``.
        select_cols = ", ".join(
            [self.config.column_map['text'], "dist", *self.must_have_cols]
        )
        q_str = f"""
            SELECT {select_cols}
            FROM {self.config.database}.{self.config.table}
            {prewhere_clause}
            ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk}
            """
        return q_str

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run a vector search and wrap each result row in a ``Document``.

        Args:
            embedding: Query embedding.
            k: Number of rows requested (used as the query ``LIMIT``).
            where_str: Optional raw SQL filter passed to ``_build_qstr``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            One ``Document`` per result row, with ``page_content`` taken from
            the text column and ``metadata`` from ``must_have_cols``. Any
            exception raised while querying or building documents is logged
            and an empty list is returned instead.
        """
        q_str = self._build_qstr(embedding, k, where_str)
        text_col = self.config.column_map["text"]
        try:
            return [
                Document(
                    page_content=row[text_col],
                    # ``col`` rather than ``k`` so the top-k parameter is not shadowed.
                    metadata={col: row[col] for col in self.must_have_cols},
                )
                for row in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            # Deliberate best-effort: report the failure and return no results.
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
    """Output parser based on ``VectorSQLOutputParser``.

    Before delegating to the parent parser, it rewrites the column list of the
    first ``SELECT ... FROM`` so that the query selects ``must_have_columns``.
    """

    # Columns that replace whatever the generated query selected.
    must_have_columns: List[str]

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        """Replace the selected columns, then parse with the parent class.

        Args:
            text: Generated SQL text.

        Returns:
            Whatever ``VectorSQLOutputParser.parse`` returns for the rewritten
            text. The text is passed through unchanged (apart from stripping)
            when no ``SELECT`` followed by ``FROM`` is found.
        """
        import re  # function-scope import keeps this block self-contained

        text = text.strip()
        # Match whole keywords on the original text: word boundaries stop
        # identifiers such as ``from_year`` from being taken for FROM, and
        # searching ``text`` itself (not ``text.upper()``) keeps indices valid.
        select_match = re.search(r"\bSELECT\b", text, flags=re.IGNORECASE)
        if select_match:
            from_match = re.compile(r"\bFROM\b", re.IGNORECASE).search(
                text, select_match.end()
            )
            if from_match:
                columns = ", ".join(self.must_have_columns)
                # Splice by position rather than ``str.replace`` so only the
                # column list changes, even when the same substring (e.g.
                # ``*``) appears elsewhere in the query or the list is empty.
                text = (
                    f"{text[:select_match.end()]} {columns} "
                    f"{text[from_match.start():]}"
                )
        return super().parse(text)
class ArXivStuffDocumentChain(StuffDocumentsChain):
    """Stuff-documents chain that tags every arXiv document with a reference number."""

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Build the LLM chain inputs from the documents and keyword arguments.

        Each document is stamped with a temporary ``ref_id`` (its position in
        ``docs``), has its newlines replaced by spaces, and is rendered with
        ``self.document_prompt``. The rendered documents are then joined into
        the single variable named ``self.document_variable_name``. Both the
        metadata and the page content of the documents are modified in place.

        Args:
            docs: Documents to format and join into one input.
            **kwargs: Extra chain inputs; only those named in the LLM chain
                prompt's input variables are kept.

        Returns:
            Dictionary of inputs for the LLM chain.
        """
        rendered: List[str] = []
        for position, document in enumerate(docs):
            # Temporary reference number, read back when sources are resolved.
            document.metadata['ref_id'] = position
            document.page_content = document.page_content.replace('\n', ' ')
            rendered.append(format_document(document, self.document_prompt))

        # Keep only the keyword arguments the prompt actually consumes.
        chain_inputs = {}
        for name, value in kwargs.items():
            if name in self.llm_chain.prompt.input_variables:
                chain_inputs[name] = value

        joined_docs = self.document_separator.join(rendered)
        chain_inputs[self.document_variable_name] = joined_docs
        return chain_inputs

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Put all documents into one prompt and send it to the LLM.

        Args:
            docs: Documents to merge into a single prompt variable.
            callbacks: Optional callbacks handed to the LLM chain.
            **kwargs: Additional parameters used to build the LLM chain inputs.

        Returns:
            A pair of the LLM's string output and an empty dictionary.
        """
        prompt_inputs = self._get_inputs(docs, **kwargs)
        llm_output = self.llm_chain.predict(callbacks=callbacks, **prompt_inputs)
        return llm_output, {}

    @property
    def _chain_type(self) -> str:
        return "referenced_stuff_documents_chain"
class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
    """QA-with-sources chain for the Chat ArXiv app, with references.

    Each retrieved document carries a temporary ``ref_id`` in its metadata.
    After the answer is generated, every ``#<ref_id>`` / ``Doc #<ref_id>``
    marker in it is replaced by the document title plus a citation number.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Retrieve documents, generate an answer and resolve its references.

        Args:
            inputs: Chain inputs, passed to ``_get_docs`` and to the
                combine-documents chain.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            A dict with the rewritten answer under ``self.answer_key`` and the
            cited documents under ``self.sources_answer_key``. When
            ``self.return_source_documents`` is set, all retrieved documents
            are also returned under ``"source_documents"``.

        Raises:
            KeyError: If a retrieved document has no ``ref_id`` in its
                metadata, or a cited one has no ``title``.
        """
        import re  # function-scope import keeps this block self-contained

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Only pass ``run_manager`` when this ``_get_docs`` accepts it.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )

        # Collect complete reference numbers. Matching whole digit runs (not
        # substrings) keeps "#1" from being found inside "#10".
        cited_ids = set(re.findall(r"#(\d+)", answer))

        # Number the cited documents in retrieval order and prepare the text
        # that replaces each marker.
        sources = []
        replacements: Dict[str, str] = {}
        for d in docs:
            ref_id = str(d.metadata['ref_id'])
            if ref_id not in cited_ids:
                continue
            ref_cnt = len(sources) + 1
            title = d.metadata['title'].replace('\n', '')
            d.metadata['ref_id'] = ref_cnt
            replacements[ref_id] = f"{title} [{ref_cnt}]"
            sources.append(d)

        def _substitute(match: "re.Match[str]") -> str:
            # Markers that do not map to a retrieved document stay as written.
            return replacements.get(match.group(1), match.group(0))

        # One pass over the answer, so text inserted for one reference (e.g. a
        # title containing "#3") is never rewritten by a later replacement.
        answer = re.sub(r"(?:Doc )?#(\d+)", _substitute, answer)

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async execution is not supported by this chain."""
        raise NotImplementedError

    @property
    def _chain_type(self) -> str:
        return "arxiv_qa_with_sources_chain"