# https://github.com/langchain-ai/langchain/issues/8623
import pandas as pd
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from typing import List
from pydantic import Field
class ClimateQARetriever(BaseRetriever):
    """Retrieve climate-report passages from a vector store.

    Fetches up to ``k_summary`` passages from summary documents
    (``report_type == "SPM"``) that score above ``threshold``, then fills the
    remaining slots (up to ``k_total`` in total) from the full reports.
    Passages shorter than ``min_size`` characters are dropped.
    """

    vectorstore: VectorStore  # store queried via similarity_search_with_score
    sources: list = ["IPCC", "IPBES", "IPOS"]  # allowed "source" metadata values
    reports: list = []  # optional short_names to restrict the search to
    threshold: float = 0.6  # minimum similarity score for summary hits
    k_summary: int = 3  # max passages taken from summaries (SPM)
    k_total: int = 10  # total number of passages returned
    # BUG FIX: the original definitions carried trailing commas, which made the
    # defaults 1-tuples — ("vectors",) and (200,) — instead of a str and an int.
    # The tuple min_size then caused a TypeError in the length filter below.
    namespace: str = "vectors"
    min_size: int = 200  # drop passages shorter than this many characters

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``k_total`` relevant documents for ``query``.

        Each returned Document gets ``similarity_score``, ``content`` and a
        1-indexed ``page_number`` added to its metadata.

        Raises:
            ValueError: if ``sources`` is empty or contains unknown values,
                or if ``k_total <= k_summary``.
        """
        # Validate configuration explicitly — `assert` is stripped under -O,
        # so it must not be relied on for input validation.
        if not isinstance(self.sources, list) or not self.sources:
            raise ValueError("sources must be a non-empty list")
        if not all(x in ["IPCC", "IPBES", "IPOS"] for x in self.sources):
            raise ValueError("sources must only contain IPCC, IPBES or IPOS")
        if self.k_total <= self.k_summary:
            raise ValueError("k_total should be greater than k_summary")

        # Restrict by explicit reports when given, otherwise by source.
        filters = {}
        if len(self.reports) > 0:
            filters["short_name"] = {"$in": self.reports}
        else:
            filters["source"] = {"$in": self.sources}

        # First pass: up to k_summary passages from the summaries dataset.
        filters_summaries = {
            **filters,
            "report_type": {"$in": ["SPM"]},
        }
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_summaries, k=self.k_summary
        )
        # Only summary hits above the similarity threshold are kept;
        # full-report hits below are intentionally NOT score-filtered.
        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]

        # Second pass: fill the remaining slots from the full reports.
        filters_full = {
            **filters,
            "report_type": {"$nin": ["SPM"]},
        }
        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_full, k=k_full
        )

        # Concatenate and drop passages that are too short to be useful.
        docs = docs_summaries + docs_full
        docs = [x for x in docs if len(x[0].page_content) > self.min_size]

        # Flatten (doc, score) pairs into Documents with enriched metadata.
        results = []
        for doc, score in docs:
            doc.page_content = doc.page_content.replace("\r\n", " ")
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            # Stored pages appear 0-indexed; expose 1-indexed page numbers.
            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
            results.append(doc)
        return results
|