File size: 2,966 Bytes
f0fc5f8
 
 
 
38ed905
 
 
 
 
 
f0fc5f8
 
 
 
 
887905a
139fefe
393b23a
f0fc5f8
 
4b4bf28
 
f0fc5f8
139fefe
 
 
 
f0fc5f8
 
 
6b43c86
887905a
f0fc5f8
 
 
139fefe
4b4bf28
139fefe
 
 
 
f0fc5f8
 
 
 
9a9100e
f0fc5f8
139fefe
 
f0fc5f8
 
 
 
 
9a9100e
f0fc5f8
 
139fefe
f0fc5f8
 
 
 
 
887905a
 
f0fc5f8
 
 
 
99e91d8
f0fc5f8
 
4b4bf28
139fefe
f0fc5f8
 
139fefe
 
f0fc5f8
139fefe
f0fc5f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# https://github.com/langchain-ai/langchain/issues/8623

import pandas as pd

from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

from typing import List
from pydantic import Field

class ClimateQARetriever(BaseRetriever):
    """Retriever that mixes summary passages with full-report passages.

    Two similarity searches are run against ``vectorstore``: one restricted
    to documents whose ``report_type`` is ``"SPM"`` (the "summaries dataset"),
    and one restricted to every other ``report_type``. Summary hits are kept
    only if their score is above ``threshold``; the second search then fills
    the remaining slots up to ``k_total``. Passages whose text is not longer
    than ``min_size`` characters are dropped from the combined result.
    """

    vectorstore: VectorStore
    # Allowed values are "IPCC", "IPBES" and "IPOS". Used as the metadata
    # filter only when `reports` is empty.
    sources: list = ["IPCC", "IPBES", "IPOS"]
    # When non-empty, documents are filtered on `short_name` instead of `source`.
    reports: list = []
    # Minimum score a summary hit must exceed to be kept.
    # NOTE(review): assumes a higher score means a closer match — confirm
    # against the vector store implementation in use.
    threshold: float = 0.6
    # Number of summary documents requested from the vector store.
    k_summary: int = 3
    # Upper bound on the number of documents requested overall.
    k_total: int = 10
    # Not read anywhere in this class.
    # No trailing comma after these two defaults: a trailing comma would turn
    # each value into a one-element tuple and break the size comparison below.
    namespace: str = "vectors"
    # Passages must have more than this many characters to be returned.
    min_size: int = 200

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return summary and full-report documents matching ``query``.

        Args:
            query: Text passed to the vector store similarity search.
            run_manager: Callback manager supplied by the retriever
                framework; not used by this implementation.

        Returns:
            Documents from the summary search followed by documents from the
            full-report search (no re-sorting by score). Each document's
            ``page_content`` has ``"\\r\\n"`` replaced by a space, and its
            metadata gains ``similarity_score`` and ``content`` and has
            ``page_number`` converted to ``int`` and incremented by one.

        Raises:
            AssertionError: If ``sources`` is not a non-empty list of allowed
                values, or if ``k_total`` is not greater than ``k_summary``.
        """
        allowed_sources = ("IPCC", "IPBES", "IPOS")

        # Validate the configuration before querying the vector store.
        assert isinstance(self.sources, list)
        assert self.sources
        assert all(source in allowed_sources for source in self.sources)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Base filter: explicit report names take precedence over sources.
        filters = {}
        if self.reports:
            filters["short_name"] = {"$in": self.reports}
        else:
            filters["source"] = {"$in": self.sources}

        # Search for k_summary documents in the summaries dataset and keep
        # only the hits scoring above the threshold.
        filters_summaries = {
            **filters,
            "report_type": {"$in": ["SPM"]},
        }
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_summaries, k=self.k_summary
        )
        docs_summaries = [
            (doc, score) for doc, score in docs_summaries if score > self.threshold
        ]

        # Fill the remaining slots from the full reports. The threshold is
        # not applied to these hits.
        filters_full = {
            **filters,
            "report_type": {"$nin": ["SPM"]},
        }
        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_full, k=k_full
        )

        # Concatenate documents: summaries first, then full-report passages.
        docs = docs_summaries + docs_full

        # Drop passages that are too short.
        docs = [
            (doc, score) for doc, score in docs
            if len(doc.page_content) > self.min_size
        ]

        # Normalise the text and expose the score through the metadata.
        results = []
        for doc, score in docs:
            doc.page_content = doc.page_content.replace("\r\n", " ")
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            # NOTE(review): presumably converts a 0-based page index to a
            # 1-based page number — confirm against the indexing pipeline.
            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
            results.append(doc)

        return results