TheoLvs's picture
feature/add_agents (#14)
48e003d verified
raw
history blame
6.87 kB
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from pyvis.network import Network
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from ..engine.utils import num_tokens_from_string
from typing import List
from pydantic import Field
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
import pyalex
pyalex.config.email = "[email protected]"
def replace_nan_with_empty_dict(x):
return x if pd.notna(x) else {}
class OpenAlex():
def __init__(self):
pass
def search(self,keywords:str,n_results = 100,after = None,before = None):
if isinstance(keywords,str):
works = Works().search(keywords)
if after is not None:
assert isinstance(after,int), "after must be an integer"
assert after > 1900, "after must be greater than 1900"
works = works.filter(publication_year=f">{after}")
for page in works.paginate(per_page=n_results):
break
df_works = pd.DataFrame(page)
df_works = df_works.dropna(subset = ["title"])
df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
df_works["url"] = df_works["id"]
df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x : x.strip())
df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
df_works = df_works.drop(columns = ["abstract_inverted_index"])
# df_works["subtitle"] = df_works["title"] + " - " + df_works["primary_location"]["source"]["display_name"] + " - " + df_works["publication_year"]
return df_works
else:
raise Exception("Keywords must be a string")
def rerank(self,query,df,reranker):
scores = reranker.rank(
query,
df["content"].tolist(),
top_k = len(df),
)
scores.sort(key = lambda x : x["corpus_id"])
scores = [x["score"] for x in scores]
df["rerank_score"] = scores
return df
def make_network(self,df):
# Initialize your graph
G = nx.DiGraph()
for i,row in df.iterrows():
paper = row.to_dict()
G.add_node(paper['id'], **paper)
for reference in paper['referenced_works']:
if reference not in G:
pass
else:
# G.add_node(reference, id=reference, title="", reference_works=[], original=False)
G.add_edge(paper['id'], reference, relationship="CITING")
return G
def show_network(self,G,height = "750px",notebook = True,color_by = "pagerank"):
net = Network(height=height, width="100%", bgcolor="#ffffff", font_color="black",notebook = notebook,directed = True,neighborhood_highlight = True)
net.force_atlas_2based()
# Add nodes with size reflecting the PageRank to highlight importance
pagerank = nx.pagerank(G)
if color_by == "pagerank":
color_scores = pagerank
elif color_by == "rerank_score":
color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}
else:
raise ValueError(f"Unknown color_by value: {color_by}")
# Normalize PageRank values to [0, 1] for color mapping
min_score = min(color_scores.values())
max_score = max(color_scores.values())
norm_color_scores = {node: (color_scores[node] - min_score) / (max_score - min_score) for node in color_scores}
for node in G.nodes:
info = G.nodes[node]
title = info["title"]
label = title[:30] + " ..."
title = [title,f"Year: {info['publication_year']}",f"ID: {info['id']}"]
title = "\n".join(title)
color_value = norm_color_scores[node]
# Generating a color from blue (low) to red (high)
color = plt.cm.RdBu_r(color_value) # coolwarm is a matplotlib colormap from blue to red
def clamp(x):
return int(max(0, min(x*255, 255)))
color = tuple([clamp(x) for x in color[:3]])
color = '#%02x%02x%02x' % color
net.add_node(node, title=title,size = pagerank[node]*1000,label = label,color = color)
# Add edges
for edge in G.edges:
net.add_edge(edge[0], edge[1],arrowStrikethrough=True,color = "gray")
# Show the network
if notebook:
return net.show("network.html")
else:
return net
def get_abstract_from_inverted_index(self,index):
if index is None:
return ""
else:
# Determine the maximum index to know the length of the reconstructed array
max_index = max([max(positions) for positions in index.values()])
# Initialize a list with placeholders for all positions
reconstructed = [''] * (max_index + 1)
# Iterate through the inverted index and place each token at its respective position(s)
for token, positions in index.items():
for position in positions:
reconstructed[position] = token
# Join the tokens to form the reconstructed sentence(s)
return ' '.join(reconstructed)
class OpenAlexRetriever(BaseRetriever):
min_year:int = 1960
max_year:int = None
k:int = 100
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
openalex = OpenAlex()
# Search for documents
df_docs = openalex.search(query,n_results=self.k,after = self.min_year,before = self.max_year)
docs = []
for i,row in df_docs.iterrows():
num_tokens = row["num_tokens"]
if num_tokens < 50 or num_tokens > 1000:
continue
doc = Document(
page_content = row["content"],
metadata = row.to_dict()
)
docs.append(doc)
return docs