File size: 4,863 Bytes
6b43c86 ccd4b9e 6b43c86 ccd4b9e 6b43c86 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import sys
import os
from contextlib import contextmanager
from ..reranker import rerank_docs
from ..graph_retriever import GraphRetriever
from ...utils import remove_duplicates_keep_highest_score
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integer shares that differ by at most one.

    The first ``target % parts`` shares get one extra unit, so for a positive
    ``parts`` the shares add up to ``target``. A ``parts`` of zero raises
    ``ZeroDivisionError``; a negative ``parts`` yields an empty list.

    >>> divide_into_parts(10, 3)
    [4, 3, 3]
    """
    share, leftover = divmod(target, parts)
    # Indices below ``leftover`` absorb the remainder, one unit each.
    return [share + 1 if index < leftover else share for index in range(parts)]
@contextmanager
def suppress_output():
    """Silence everything written to ``sys.stdout`` and ``sys.stderr``.

    For the duration of the ``with`` body both ``sys`` stream objects point at
    a writable handle on ``os.devnull``. The previous stream objects are put
    back afterwards, even if the body raises. Only the ``sys`` attributes are
    swapped; file descriptors 1 and 2 are left untouched.
    """
    with open(os.devnull, 'w') as sink:
        saved_streams = (sys.stdout, sys.stderr)
        # Chained assignment binds sys.stdout first, then sys.stderr.
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams
def _resolve_graph_sources(sources_input, possible_sources):
    """Return the graph sources to query for the requested ``sources_input``.

    If ``"auto"`` is among the requested sources, every entry of
    ``possible_sources`` is used. Otherwise the requested sources are
    filtered down to those present in ``possible_sources``, keeping the
    requested order. An empty result means none of the requested sources
    is available.
    """
    if "auto" in sources_input:
        return list(possible_sources)
    return [source for source in sources_input if source in possible_sources]


def _reranking_score(doc):
    """Sort key: the ``reranking_score`` stored in a document's metadata."""
    return doc.metadata["reranking_score"]


def _score_question_docs(retriever, reranker, question):
    """Retrieve graphs for one question and give each a ``reranking_score``."""
    docs = retriever.get_relevant_documents(question)
    if reranker is not None and docs:
        # Anything printed while reranking is discarded.
        with suppress_output():
            docs = rerank_docs(reranker, docs, question)
    else:
        # No reranker (or nothing to rerank): reuse the similarity score so
        # later sorting on "reranking_score" works uniformly.
        for doc in docs:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
    return docs


def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
    """Build the graph-retrieval node function.

    Args:
        vectorstore: Store handed to ``GraphRetriever`` for the search.
        reranker: Model handed to ``rerank_docs``. When ``None``, each
            document's ``similarity_score`` is reused as its
            ``reranking_score``.
        rerank_by_question: When True, ``k_final`` is split across the
            sub-questions with ``divide_into_parts``, and each sub-question
            keeps only its share of best-scored documents before merging.
            When False, all retrieved documents are merged and only the
            global top ``k_final`` is kept.
        k_final: Maximum number of documents returned.
        k_before_reranking: Passed to ``GraphRetriever`` as ``k_total``
            (presumably the number of candidates fetched per sub-question;
            confirm against ``GraphRetriever``).

    Returns:
        A function ``retrieve_graphs(state)`` that reads the
        ``"remaining_questions"``, ``"query"`` and ``"sources_input"`` keys of
        ``state`` and returns ``{"recommended_content": docs}``. The docs are
        de-duplicated and sorted by descending ``reranking_score``.
    """

    def retrieve_graphs(state):
        print("---- Retrieving graphs ----")
        POSSIBLE_SOURCES = ["IEA", "OWID"]

        # Fall back to the main query when there are no sub-questions left.
        questions = state["remaining_questions"] or [state["query"]]
        sources_input = state["sources_input"]

        # Source selection does not depend on the sub-question, so it is
        # resolved once rather than on every loop iteration.
        sources = _resolve_graph_sources(sources_input, POSSIBLE_SOURCES)

        # Per-question document budgets; only needed when truncating per question.
        k_by_question = divide_into_parts(k_final, len(questions)) if rerank_by_question else None

        docs = []
        if not sources:
            print(
                "There are no graphs which match the sources filtered on. "
                f"Sources filtered on: {sources_input}. "
                f"Sources available: {POSSIBLE_SOURCES}."
            )
        else:
            # Every constructor argument is loop-invariant, so one retriever
            # serves all sub-questions.
            retriever = GraphRetriever(
                vectorstore=vectorstore,
                sources=sources,
                k_total=k_before_reranking,
                threshold=0.5,
            )
            for i, q in enumerate(questions):
                question = q["question"] if isinstance(q, dict) else q
                print(f"Subquestion {i + 1}: {question}")

                docs_question = _score_question_docs(retriever, reranker, question)

                if rerank_by_question:
                    # Keep this question's best documents. The stable sort is a
                    # no-op when the input is already ordered by score, and it
                    # guarantees "top" also holds on the no-reranker path.
                    docs_question = sorted(docs_question, key=_reranking_score, reverse=True)
                    docs_question = docs_question[:k_by_question[i]]

                # Record which sources were searched; each doc gets its own
                # list so later mutation of one doc cannot leak into another.
                for doc in docs_question:
                    doc.metadata["sources_used"] = list(sources)

                print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
                docs.extend(docs_question)

        # Drop duplicates (keeping the best-scored copy), then keep the global top k.
        docs = remove_duplicates_keep_highest_score(docs)
        docs = sorted(docs, key=_reranking_score, reverse=True)[:k_final]
        return {"recommended_content": docs}

    return retrieve_graphs