|
from typing import List, Optional

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import pyalex
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from pydantic import Field
from pyvis.network import Network

from ..engine.utils import num_tokens_from_string
|
|
|
pyalex.config.email = "[email protected]" |
|
|
|
|
|
def replace_nan_with_empty_dict(x):
    """Return ``x`` unchanged, or a fresh empty dict when ``x`` is missing.

    "Missing" is whatever ``pd.notna`` reports as not-available for a scalar
    (``None``, ``NaN``, ``pd.NA``, ``NaT``).  Used so that later ``.get(...)``
    calls on a column of dicts do not fail on absent entries.
    """
    if not pd.notna(x):
        return {}
    return x
|
|
|
class OpenAlex():
    """Helper around the OpenAlex API (through ``pyalex``).

    Provides keyword search returning a ``pandas.DataFrame`` of works,
    re-ranking of that frame, and construction / display of the citation
    graph between the works of a result set.
    """

    def __init__(self):
        pass

    def search(self, keywords: str, n_results=100, after=None, before=None):
        """Search OpenAlex works by keywords and return them as a DataFrame.

        Only the first result page (of size ``n_results``) is fetched.
        NOTE(review): OpenAlex caps the page size server-side (200 according
        to its documentation -- confirm), so larger values are rejected.

        Args:
            keywords: Free-text query passed to ``Works().search``.
            n_results: Page size requested from the API.
            after: If given, keep only works with ``publication_year > after``.
            before: If given, keep only works published strictly before the
                year ``before`` (i.e. up to 31 December of ``before - 1``).

        Returns:
            A DataFrame with the raw OpenAlex fields (minus
            ``abstract_inverted_index``) plus the derived columns
            ``abstract``, ``is_oa``, ``pdf_url``, ``url``, ``content``
            (title + abstract) and ``num_tokens``.  When the API returns no
            work, an empty DataFrame carrying the derived columns is returned.

        Raises:
            TypeError: If ``keywords`` is not a string, or a year bound is not
                an integer.
            ValueError: If a year bound is not greater than 1900.
        """
        if not isinstance(keywords, str):
            raise TypeError("Keywords must be a string")

        works = Works().search(keywords)

        if after is not None:
            self._validate_year(after, "after")
            works = works.filter(publication_year=f">{after}")

        if before is not None:
            self._validate_year(before, "before")
            # A distinct filter key is used so it cannot clash with the
            # ``publication_year`` filter set for ``after``.
            works = works.filter(to_publication_date=f"{before - 1}-12-31")

        # Only the first page is needed; ``None`` if the paginator is exhausted.
        page = next(iter(works.paginate(per_page=n_results)), None)
        if not page:
            return pd.DataFrame(
                columns=[
                    "id", "title", "abstract", "is_oa", "pdf_url",
                    "url", "content", "num_tokens",
                ]
            )

        df_works = pd.DataFrame(page)
        df_works = df_works.dropna(subset=["title"])

        # Missing nested objects become {} so the ``.get`` calls below are safe.
        df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
        open_access = df_works["open_access"].map(replace_nan_with_empty_dict)

        df_works["abstract"] = df_works["abstract_inverted_index"].apply(
            self.get_abstract_from_inverted_index
        )
        df_works["is_oa"] = open_access.map(lambda x: x.get("is_oa", False))
        df_works["pdf_url"] = df_works["primary_location"].map(lambda x: x.get("pdf_url", None))
        df_works["url"] = df_works["id"]
        df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x: x.strip())
        df_works["num_tokens"] = df_works["content"].map(num_tokens_from_string)

        df_works = df_works.drop(columns=["abstract_inverted_index"])
        return df_works

    @staticmethod
    def _validate_year(year, name):
        """Raise if ``year`` is not an integer greater than 1900."""
        if not isinstance(year, int):
            raise TypeError(f"{name} must be an integer")
        if year <= 1900:
            raise ValueError(f"{name} must be greater than 1900")

    def rerank(self, query, df, reranker):
        """Add a ``rerank_score`` column to ``df`` and return it.

        ``reranker.rank(query, documents, top_k=...)`` must return a list of
        dicts with ``corpus_id`` (position of the document in the input list)
        and ``score``.  Scores are written back in row order.

        ``df`` is modified in place (and also returned).
        """
        if len(df) == 0:
            # Nothing to score; avoid calling the reranker with no documents.
            df["rerank_score"] = pd.Series(dtype="float64")
            return df

        hits = reranker.rank(
            query,
            df["content"].tolist(),
            top_k=len(df),
        )
        # Restore input order: hits come back sorted by relevance.
        hits_in_row_order = sorted(hits, key=lambda hit: hit["corpus_id"])
        df["rerank_score"] = [hit["score"] for hit in hits_in_row_order]
        return df

    def make_network(self, df):
        """Build the citation graph between the works contained in ``df``.

        Every row becomes a node keyed by its ``id`` with all row fields as
        node attributes.  A directed edge ``citing -> cited`` (attribute
        ``relationship="CITING"``) is added for every entry of
        ``referenced_works`` that is itself a work of ``df``; references to
        works outside the result set are ignored.
        """
        G = nx.DiGraph()
        papers = [row.to_dict() for _, row in df.iterrows()]

        # First pass: register every node, so that edge creation below does
        # not depend on the order of the rows.
        for paper in papers:
            G.add_node(paper['id'], **paper)

        # Second pass: link papers that cite each other.
        for paper in papers:
            try:
                references = list(paper.get('referenced_works'))
            except TypeError:
                # Missing value (None / NaN): this paper has no known references.
                continue
            for reference in references:
                if reference in G:
                    G.add_edge(paper['id'], reference, relationship="CITING")
        return G

    def show_network(self, G, height="750px", notebook=True, color_by="pagerank"):
        """Render the citation graph ``G`` with pyvis.

        Node size is proportional to PageRank; node colour follows the
        ``RdBu_r`` colormap applied to the min-max normalised ``color_by``
        score.

        Args:
            G: Graph whose nodes carry ``title``, ``publication_year`` and
                ``id`` attributes (as produced by ``make_network``).
            height: Height passed to the pyvis ``Network``.
            notebook: If true, write ``network.html`` and return the result of
                ``net.show``; otherwise return the pyvis ``Network`` object.
            color_by: ``"pagerank"`` or ``"rerank_score"``.

        Raises:
            ValueError: If ``color_by`` is not one of the supported values.
        """
        if color_by not in ("pagerank", "rerank_score"):
            raise ValueError(f"Unknown color_by value: {color_by}")

        net = Network(
            height=height,
            width="100%",
            bgcolor="#ffffff",
            font_color="black",
            notebook=notebook,
            directed=True,
            neighborhood_highlight=True,
        )
        net.force_atlas_2based()

        pagerank = nx.pagerank(G)

        if color_by == "pagerank":
            color_scores = pagerank
        else:
            color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}

        norm_color_scores = self._normalize_scores(color_scores)

        for node in G.nodes:
            info = G.nodes[node]
            title = info["title"]
            label = title[:30] + " ..."
            tooltip = "\n".join([title, f"Year: {info['publication_year']}", f"ID: {info['id']}"])
            color = self._score_to_hex(norm_color_scores[node])
            net.add_node(node, title=tooltip, size=pagerank[node] * 1000, label=label, color=color)

        for edge in G.edges:
            net.add_edge(edge[0], edge[1], arrowStrikethrough=True, color="gray")

        if notebook:
            return net.show("network.html")
        return net

    @staticmethod
    def _normalize_scores(scores):
        """Min-max normalise a ``{node: score}`` mapping to the range [0, 1].

        Returns ``{}`` for an empty mapping and maps every node to ``0.5``
        when all scores are equal (a zero span would otherwise divide by zero).
        """
        if not scores:
            return {}
        low = min(scores.values())
        high = max(scores.values())
        span = high - low
        if span == 0:
            return {node: 0.5 for node in scores}
        return {node: (value - low) / span for node, value in scores.items()}

    @staticmethod
    def _score_to_hex(value):
        """Convert a normalised score to a ``#rrggbb`` colour via ``RdBu_r``."""
        rgb = plt.cm.RdBu_r(value)[:3]
        channels = tuple(int(max(0, min(channel * 255, 255))) for channel in rgb)
        return '#%02x%02x%02x' % channels

    def get_abstract_from_inverted_index(self, index):
        """Rebuild the abstract text from an inverted index.

        Args:
            index: Mapping ``token -> list of word positions``, or a missing
                value (``None`` / ``NaN``).

        Returns:
            The tokens joined by single spaces in position order.  Positions
            not covered by any token yield empty words.  Returns ``""`` when
            the index is missing, not a dict, or holds no position at all.
        """
        if not isinstance(index, dict) or not index:
            return ""

        all_positions = [position for positions in index.values() for position in positions]
        if not all_positions:
            return ""

        reconstructed = [''] * (max(all_positions) + 1)
        for token, positions in index.items():
            for position in positions:
                reconstructed[position] = token

        return ' '.join(reconstructed)
|
|
|
|
|
|
|
class OpenAlexRetriever(BaseRetriever):
    """LangChain retriever that returns OpenAlex works as ``Document`` objects.

    Each query is sent to ``OpenAlex.search``; every returned work whose
    token count lies within ``[min_tokens, max_tokens]`` becomes a
    ``Document`` with the work's ``content`` as page content and the full
    row as metadata.
    """

    # Forwarded to ``OpenAlex.search`` as ``after``.
    min_year: int = 1960
    # Forwarded to ``OpenAlex.search`` as ``before``; ``None`` means no bound.
    max_year: Optional[int] = None
    # Number of results requested from OpenAlex (``n_results``).
    k: int = 100
    # Works with fewer tokens than this are skipped.
    min_tokens: int = 50
    # Works with more tokens than this are skipped.
    max_tokens: int = 1000

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search OpenAlex for ``query`` and wrap the kept works as Documents."""
        openalex = OpenAlex()

        df_docs = openalex.search(
            query,
            n_results=self.k,
            after=self.min_year,
            before=self.max_year,
        )

        docs = []
        for _, row in df_docs.iterrows():
            num_tokens = row["num_tokens"]

            # Drop works whose text is too short or too long.
            if num_tokens < self.min_tokens or num_tokens > self.max_tokens:
                continue

            docs.append(
                Document(
                    page_content=row["content"],
                    metadata=row.to_dict(),
                )
            )
        return docs
|
|
|
|
|
|