|
import sys |
|
import os |
|
from contextlib import contextmanager |
|
|
|
from ..reranker import rerank_docs |
|
from ..graph_retriever import GraphRetriever |
|
from ...utils import remove_duplicates_keep_highest_score |
|
|
|
|
|
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that differ by at most one.

    The first ``target % parts`` entries receive one extra unit, so for a
    positive ``parts`` the result sums to ``target`` and never increases,
    e.g. ``divide_into_parts(10, 3) == [4, 3, 3]``.

    Raises:
        ZeroDivisionError: if ``parts`` is 0.
    """
    share, leftover = divmod(target, parts)
    return [share + 1 if index < leftover else share for index in range(parts)]
|
|
|
|
|
@contextmanager
def suppress_output():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the body of a ``with`` block.

    Both streams are pointed at ``os.devnull``; the previous stream objects are
    put back on exit, including when the block raises. Because this rebinds the
    module-level ``sys`` attributes, the redirection is visible process-wide
    while it is active.
    """
    with open(os.devnull, 'w') as sink:
        saved_streams = (sys.stdout, sys.stderr)
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams
|
|
|
|
|
# Graph sources GraphRetriever can search; "auto" in sources_input selects all of them.
_POSSIBLE_GRAPH_SOURCES = ["IEA", "OWID"]


def _resolve_graph_sources(sources_input, possible_sources):
    """Return the sources to search for a request, in request order.

    ``"auto"`` in ``sources_input`` selects every entry of ``possible_sources``;
    otherwise only requested entries that also appear in ``possible_sources``
    are kept. An empty result means nothing searchable was requested.
    """
    requested = possible_sources if "auto" in sources_input else sources_input
    return [source for source in requested if source in possible_sources]


def _retrieve_graphs_for_question(question, vectorstore, reranker, sources, k_total, threshold):
    """Retrieve and score graph documents for one question, best first.

    Every returned document is expected to carry ``metadata["reranking_score"]``:
    presumably set by ``rerank_docs`` when a reranker is used, otherwise copied
    here from ``metadata["similarity_score"]``.
    """
    retriever = GraphRetriever(
        vectorstore=vectorstore,
        sources=sources,
        k_total=k_total,
        threshold=threshold,
    )
    docs = retriever.get_relevant_documents(question)

    if reranker is not None and docs:
        # Silence anything the reranker writes to stdout/stderr.
        with suppress_output():
            docs = rerank_docs(reranker, docs, question)
    else:
        # No reranking: reuse the retriever's similarity score so every
        # document has a comparable "reranking_score".
        for doc in docs:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Sort explicitly instead of relying on the helpers' output order, so the
    # caller's per-question truncation keeps the highest-scoring documents.
    # sorted() is stable, so already-ordered input is left unchanged.
    return sorted(docs, key=lambda doc: doc.metadata["reranking_score"], reverse=True)


def make_graph_retriever_node(
    vectorstore,
    reranker,
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    threshold=0.5,
):
    """Build a node function that retrieves recommended graphs for a query.

    Args:
        vectorstore: Store handed to ``GraphRetriever`` for the search.
        reranker: Reranker handed to ``rerank_docs``; ``None`` disables reranking
            and the retriever's ``similarity_score`` is used as the score.
        rerank_by_question: If True, ``k_final`` is split across the questions
            with ``divide_into_parts`` and each question keeps only its share of
            top documents before the results are merged. If False, every
            question's documents are merged and the global top ``k_final`` kept.
        k_final: Maximum number of documents returned.
        k_before_reranking: ``k_total`` passed to ``GraphRetriever`` per question.
        threshold: ``threshold`` passed to ``GraphRetriever`` (default 0.5, the
            previously hard-coded value).

    Returns:
        A callable ``retrieve_graphs(state)`` returning
        ``{"recommended_content": docs}``, where ``docs`` is sorted by
        ``metadata["reranking_score"]`` (highest first), de-duplicated, and
        truncated to ``k_final``.
    """

    def retrieve_graphs(state):
        """Retrieve graphs for ``state``'s remaining questions (or its query).

        Reads ``state["remaining_questions"]`` (optional), ``state["query"]``
        and ``state["sources_input"]``. Questions may be plain strings or dicts
        with a ``"question"`` key.
        """
        print("---- Retrieving graphs ----")

        # Fall back to the original query when no sub-questions remain
        # (missing key, None, or an empty sequence).
        questions = state.get("remaining_questions") or [state["query"]]
        sources_input = state["sources_input"]

        # Source selection does not depend on the question: resolve it once.
        sources = _resolve_graph_sources(sources_input, _POSSIBLE_GRAPH_SOURCES)

        docs = []
        if not sources:
            print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources_input}. Sources available: {_POSSIBLE_GRAPH_SOURCES}.")
        else:
            # Per-question budget, only used when reranking by question.
            k_by_question = divide_into_parts(k_final, len(questions)) if rerank_by_question else None

            for i, q in enumerate(questions):
                question = q["question"] if isinstance(q, dict) else q
                print(f"Subquestion {i + 1}: {question}")

                docs_question = _retrieve_graphs_for_question(
                    question,
                    vectorstore,
                    reranker,
                    sources,
                    k_before_reranking,
                    threshold,
                )

                if rerank_by_question:
                    docs_question = docs_question[:k_by_question[i]]

                # Give each document its own copy so a downstream mutation of
                # one document's list cannot leak into the others.
                for doc in docs_question:
                    doc.metadata["sources_used"] = list(sources)

                print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
                docs.extend(docs_question)

        # The same graph can be found for several questions; this helper
        # (per its name) keeps the copy with the highest score.
        docs = remove_duplicates_keep_highest_score(docs)

        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
        docs = docs[:k_final]

        return {"recommended_content": docs}

    return retrieve_graphs