File size: 6,872 Bytes
caf1faa 99e91d8 caf1faa 99e91d8 caf1faa 99e91d8 caf1faa 99e91d8 caf1faa 99e91d8 caf1faa 99e91d8 caf1faa 99e91d8 72edd2d 99e91d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import os
from typing import List
from typing import Optional

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import pyalex
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from pydantic import Field
from pyvis.network import Network

from ..engine.utils import num_tokens_from_string
# Contact email used by pyalex for OpenAlex API requests. It can be overridden
# with the OPENALEX_EMAIL environment variable. NOTE(review): the fallback
# literal looks like an email-obfuscation placeholder rather than a real
# address -- replace it or set the environment variable.
pyalex.config.email = os.environ.get("OPENALEX_EMAIL", "[email protected]")
def replace_nan_with_empty_dict(x):
    """Return ``x`` unchanged, or ``{}`` when pandas reports it as missing.

    Used to normalise columns of dict-valued records so that ``.get`` can be
    called on every cell, including cells that were None/NaN.
    """
    is_present = pd.notna(x)
    if not is_present:
        return {}
    return x
class OpenAlex():
    """Wrapper around the OpenAlex Works search (through ``pyalex``).

    It also provides helpers to rerank results and to build and render their
    citation graph.
    """

    # Columns guaranteed on the DataFrame returned by ``search`` when the API
    # returns no usable works.
    _EMPTY_SEARCH_COLUMNS = ("id", "title", "abstract", "is_oa", "pdf_url", "url", "content", "num_tokens")

    def __init__(self):
        pass

    def search(self, keywords: str, n_results=100, after=None, before=None):
        """Search OpenAlex works and return the first result page as a DataFrame.

        Args:
            keywords: Free-text query passed to ``Works().search``.
            n_results: Page size requested from the API; only the first page
                is fetched. NOTE(review): OpenAlex documents a per-page maximum
                of 200, so larger values are likely rejected by the API.
            after: Optional int (> 1900). Keep only works whose
                ``publication_year`` is strictly greater than this year.
            before: Optional int. Keep only works published strictly before
                this year (publication date up to Dec 31 of ``before - 1``).

        Returns:
            DataFrame of works with the derived columns ``abstract``, ``is_oa``,
            ``pdf_url``, ``url``, ``content`` and ``num_tokens``. Rows without a
            title are dropped. If nothing usable is returned, the result is an
            empty DataFrame with at least the id/title/derived columns.

        Raises:
            TypeError: if ``keywords`` is not a string.
            AssertionError: if ``after``/``before`` fail validation.
        """
        if not isinstance(keywords, str):
            raise TypeError("Keywords must be a string")

        works = Works().search(keywords)
        if after is not None:
            assert isinstance(after, int), "after must be an integer"
            assert after > 1900, "after must be greater than 1900"
            works = works.filter(publication_year=f">{after}")
        if before is not None:
            assert isinstance(before, int), "before must be an integer"
            # Use a distinct filter key so the upper bound cannot clash with
            # the ``publication_year`` lower bound above.
            works = works.filter(to_publication_date=f"{before - 1}-12-31")

        # Only the first page is needed. Fall back to an empty list if the
        # paginator yields no page at all.
        page = next(iter(works.paginate(per_page=n_results)), [])
        return self._format_works(page)

    def _format_works(self, records):
        """Turn raw OpenAlex work records into the DataFrame returned by ``search``."""
        df_works = pd.DataFrame(records)
        if not df_works.empty:
            df_works = df_works.dropna(subset=["title"])
        if df_works.empty:
            # No records (or none with a title): the column-wise processing
            # below would fail on missing columns, so return an empty frame.
            return pd.DataFrame(columns=list(self._EMPTY_SEARCH_COLUMNS))

        df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
        df_works["abstract"] = df_works["abstract_inverted_index"].apply(self.get_abstract_from_inverted_index)
        df_works["is_oa"] = df_works["open_access"].map(lambda x: x.get("is_oa", False))
        df_works["pdf_url"] = df_works["primary_location"].map(lambda x: x.get("pdf_url", None))
        df_works["url"] = df_works["id"]
        df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(str.strip)
        df_works["num_tokens"] = df_works["content"].map(num_tokens_from_string)
        return df_works.drop(columns=["abstract_inverted_index"])

    def rerank(self, query, df, reranker):
        """Add a ``rerank_score`` column to ``df`` (in place) and return ``df``.

        ``reranker.rank(query, documents, top_k=...)`` must return a list of
        dicts with ``corpus_id`` (position in ``documents``) and ``score``
        keys. NOTE(review): this matches sentence-transformers'
        ``CrossEncoder.rank``; confirm against the reranker actually passed.
        """
        if df.empty:
            # Nothing to score; do not call the reranker with zero documents.
            df["rerank_score"] = []
            return df
        hits = reranker.rank(
            query,
            df["content"].tolist(),
            top_k=len(df),
        )
        # Hits may come back in ranking order; sort them by document position
        # so the scores line up with df's rows when assigned positionally.
        hits_in_row_order = sorted(hits, key=lambda hit: hit["corpus_id"])
        df["rerank_score"] = [hit["score"] for hit in hits_in_row_order]
        return df

    def make_network(self, df):
        """Build a directed citation graph restricted to the papers in ``df``.

        Each row becomes a node keyed by its ``id``, with every row value stored
        as a node attribute. For each entry of ``referenced_works`` that is
        itself a node, an edge ``citing -> cited`` with
        ``relationship="CITING"`` is added.
        """
        G = nx.DiGraph()
        papers = df.to_dict(orient="records")
        # Add every node before any edge. Otherwise citations to papers that
        # appear later in ``df`` would be dropped.
        for paper in papers:
            G.add_node(paper["id"], **paper)
        for paper in papers:
            references = paper.get("referenced_works")
            if references is None or isinstance(references, float):
                # A missing reference list arrives as None/NaN: nothing to link.
                continue
            for reference in references:
                if reference in G:
                    G.add_edge(paper["id"], reference, relationship="CITING")
        return G

    def show_network(self, G, height="750px", notebook=True, color_by="pagerank"):
        """Render ``G`` with pyvis.

        Node size is proportional to PageRank. Node colour follows
        ``color_by`` ("pagerank" or "rerank_score"), min-max normalised onto
        the RdBu_r colormap (blue = low, red = high).

        Returns:
            The result of ``net.show("network.html")`` when ``notebook`` is
            true, otherwise the pyvis ``Network`` object.

        Raises:
            ValueError: for an unknown ``color_by`` value.
        """
        net = Network(height=height, width="100%", bgcolor="#ffffff", font_color="black", notebook=notebook, directed=True, neighborhood_highlight=True)
        net.force_atlas_2based()

        pagerank = nx.pagerank(G)
        if color_by == "pagerank":
            color_scores = pagerank
        elif color_by == "rerank_score":
            color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}
        else:
            raise ValueError(f"Unknown color_by value: {color_by}")

        # Min-max normalise the colour scores to [0, 1]. When all scores are
        # equal (single node, missing rerank scores, empty graph), the range
        # is zero, so every node gets the neutral middle colour instead of
        # dividing by zero.
        min_score = min(color_scores.values(), default=0.0)
        max_score = max(color_scores.values(), default=0.0)
        score_range = max_score - min_score
        if score_range > 0:
            norm_color_scores = {node: (score - min_score) / score_range for node, score in color_scores.items()}
        else:
            norm_color_scores = {node: 0.5 for node in color_scores}

        def to_channel(x):
            # Convert a [0, 1] float channel to a clamped 0-255 int.
            return int(max(0, min(x * 255, 255)))

        for node in G.nodes:
            info = G.nodes[node]
            full_title = info["title"]
            # Only add an ellipsis when the title was actually truncated.
            label = full_title if len(full_title) <= 30 else full_title[:30] + " ..."
            tooltip = "\n".join([full_title, f"Year: {info['publication_year']}", f"ID: {info['id']}"])
            rgba = plt.cm.RdBu_r(norm_color_scores[node])  # blue (low) -> red (high)
            color = "#%02x%02x%02x" % tuple(to_channel(channel) for channel in rgba[:3])
            net.add_node(node, title=tooltip, size=pagerank[node] * 1000, label=label, color=color)

        for source, target in G.edges:
            net.add_edge(source, target, arrowStrikethrough=True, color="gray")

        if notebook:
            return net.show("network.html")
        return net

    def get_abstract_from_inverted_index(self, index):
        """Rebuild abstract text from an OpenAlex ``abstract_inverted_index``.

        ``index`` maps each token to the list of word positions where it
        occurs. Each token is written at its positions and the words are
        joined with single spaces; positions that no token claims remain empty
        strings. Returns ``""`` when ``index`` is missing (None, NaN or any
        non-dict), empty, or has no positions.
        """
        if not isinstance(index, dict) or not index:
            return ""
        all_positions = [position for positions in index.values() for position in positions]
        if not all_positions:
            return ""
        reconstructed = [""] * (max(all_positions) + 1)
        for token, positions in index.items():
            for position in positions:
                reconstructed[position] = token
        return " ".join(reconstructed)
class OpenAlexRetriever(BaseRetriever):
    """LangChain retriever that searches OpenAlex and returns works as Documents.

    Each returned Document holds a work's ``content`` (title + abstract) as
    ``page_content`` and the full result row as ``metadata``.
    """

    # Forwarded to ``OpenAlex.search`` as ``after`` (lower year bound).
    min_year: int = 1960
    # Forwarded to ``OpenAlex.search`` as ``before``; None means no upper bound.
    max_year: Optional[int] = None
    # Forwarded to ``OpenAlex.search`` as ``n_results``.
    k: int = 100
    # Works whose ``num_tokens`` lies outside [min_tokens, max_tokens] are skipped.
    min_tokens: int = 50
    max_tokens: int = 1000

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search OpenAlex for ``query`` and wrap suitably sized hits as Documents."""
        openalex = OpenAlex()
        df_docs = openalex.search(query, n_results=self.k, after=self.min_year, before=self.max_year)
        return [
            Document(
                page_content=row["content"],
                metadata=row.to_dict(),
            )
            for _, row in df_docs.iterrows()
            if self.min_tokens <= row["num_tokens"] <= self.max_tokens
        ]
|