File size: 6,872 Bytes
caf1faa
 
 
 
 
99e91d8
 
 
 
 
 
 
 
 
 
 
caf1faa
 
 
 
 
99e91d8
 
 
 
caf1faa
 
 
 
 
99e91d8
caf1faa
 
 
 
 
 
 
 
 
 
 
 
99e91d8
 
 
caf1faa
 
99e91d8
 
 
 
 
 
 
 
caf1faa
99e91d8
caf1faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e91d8
 
 
 
 
 
 
72edd2d
99e91d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from typing import List
from typing import Optional

import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from pyvis.network import Network

from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

from pydantic import Field

from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
import pyalex

from ..engine.utils import num_tokens_from_string

# Contact email that pyalex attaches to its OpenAlex requests (per the pyalex
# docs this opts the client into OpenAlex's "polite pool").
# NOTE(review): the value below looks like a web-scrape / email-obfuscation
# artifact rather than a real address — replace with a real contact email
# (TODO confirm the intended value).
pyalex.config.email = "[email protected]"


def replace_nan_with_empty_dict(x):
    """Return ``x`` unchanged, or ``{}`` when pandas treats ``x`` as missing.

    Lets nullable dict-valued columns (such as ``primary_location``) be
    normalised so that ``.get(...)`` can be called on every entry.
    """
    if pd.notna(x):
        return x
    # None / NaN placeholder -> empty mapping
    return {}

class OpenAlex():
    """Thin wrapper around the ``pyalex`` client for OpenAlex works.

    Provides a keyword search returning a cleaned ``pandas.DataFrame``,
    reranking of those results with an external reranker, and helpers to build
    and render a citation graph among the returned works.
    """

    def __init__(self):
        pass

    def search(self, keywords: str, n_results=100, after=None, before=None):
        """Search OpenAlex works matching ``keywords`` and return one page as a DataFrame.

        Args:
            keywords: Free-text query passed to ``Works().search``.
            n_results: Page size requested from the API. Only the first page is
                fetched, so at most this many rows are returned.
            after: If given, keep only works whose ``publication_year`` is
                strictly greater than this year (integer > 1900).
            before: If given, keep only works published strictly before this
                year (integer > 1900).

        Returns:
            DataFrame with the raw OpenAlex columns (minus
            ``abstract_inverted_index``) plus ``abstract``, ``is_oa``,
            ``pdf_url``, ``url``, ``content`` (title + abstract) and
            ``num_tokens``. Rows without a title are dropped. An empty frame
            is returned when nothing matches.

        Raises:
            Exception: if ``keywords`` is not a string.
            AssertionError: if ``after``/``before`` is not an integer > 1900.
        """
        if not isinstance(keywords, str):
            raise Exception("Keywords must be a string")

        works = Works().search(keywords)
        if after is not None:
            assert isinstance(after, int), "after must be an integer"
            assert after > 1900, "after must be greater than 1900"
            works = works.filter(publication_year=f">{after}")
        if before is not None:
            # ``before`` used to be accepted but silently ignored.
            assert isinstance(before, int), "before must be an integer"
            assert before > 1900, "before must be greater than 1900"
            # pyalex keys filters by name, so a second ``publication_year``
            # filter would presumably replace the ``after`` bound; use a
            # distinct filter key for the (exclusive) upper bound instead.
            works = works.filter(to_publication_date=f"{before - 1}-12-31")

        # Only the first page is used. ``next`` with a default replaces the old
        # ``for ...: break`` idiom, which left ``page`` unbound on no results.
        page = next(iter(works.paginate(per_page=n_results)), [])
        df_works = pd.DataFrame(page)
        if df_works.empty or "title" not in df_works.columns:
            # Nothing matched: still expose the columns this method derives.
            return pd.DataFrame(
                columns=["id", "title", "abstract", "is_oa", "pdf_url", "url", "content", "num_tokens"]
            )

        # .copy() so the column assignments below act on an independent frame.
        df_works = df_works.dropna(subset=["title"]).copy()
        df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
        df_works["abstract"] = (
            df_works["abstract_inverted_index"].apply(self.get_abstract_from_inverted_index).fillna("")
        )
        df_works["is_oa"] = df_works["open_access"].map(
            lambda x: replace_nan_with_empty_dict(x).get("is_oa", False)
        )
        df_works["pdf_url"] = df_works["primary_location"].map(lambda x: x.get("pdf_url", None))
        df_works["url"] = df_works["id"]
        df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(str.strip)
        df_works["num_tokens"] = df_works["content"].map(num_tokens_from_string)

        return df_works.drop(columns=["abstract_inverted_index"])

    def rerank(self, query, df, reranker):
        """Score every row of ``df`` against ``query`` into a ``rerank_score`` column.

        ``reranker.rank(query, documents, top_k=...)`` must return a list of
        dicts with ``corpus_id`` (index into ``documents``) and ``score`` keys.
        This matches e.g. sentence-transformers' ``CrossEncoder.rank``; confirm
        against the reranker actually used. Results are re-ordered by
        ``corpus_id`` so scores line up with the rows of ``df``.

        ``df`` is modified in place and also returned.
        """
        if df.empty:
            # Nothing to score; avoid calling the reranker with no documents.
            df["rerank_score"] = pd.Series(dtype=float, index=df.index)
            return df

        results = reranker.rank(
            query,
            df["content"].tolist(),
            top_k=len(df),
        )
        ordered = sorted(results, key=lambda r: r["corpus_id"])
        df["rerank_score"] = [r["score"] for r in ordered]
        return df

    def make_network(self, df):
        """Build a directed citation graph among the works in ``df``.

        Every row becomes a node keyed by its ``id``, with all row fields stored
        as node attributes. An edge ``paper -> reference`` tagged
        ``relationship="CITING"`` is added for each entry of
        ``referenced_works`` that is itself a row of ``df``. References to works
        outside ``df`` are ignored.
        """
        G = nx.DiGraph()
        papers = [row.to_dict() for _, row in df.iterrows()]

        # Add every node before any edge. Previously a citation was kept only
        # when the cited paper appeared *earlier* in ``df``, so edges depended
        # on row order.
        for paper in papers:
            G.add_node(paper["id"], **paper)

        for paper in papers:
            references = paper.get("referenced_works")
            if not isinstance(references, (list, tuple, set)):
                # Missing / NaN reference list: nothing to link.
                continue
            for reference in references:
                if reference in G:
                    G.add_edge(paper["id"], reference, relationship="CITING")
        return G

    def show_network(self, G, height="750px", notebook=True, color_by="pagerank"):
        """Render ``G`` as an interactive pyvis network.

        Node size is proportional to PageRank. Node color maps ``color_by``
        onto matplotlib's ``RdBu_r`` colormap (blue = low, red = high).
        ``color_by`` is either ``"pagerank"`` or ``"rerank_score"``; for the
        latter the node attribute is used, defaulting to 0.

        Returns:
            The result of ``net.show("network.html")`` when ``notebook`` is
            True, otherwise the pyvis ``Network`` object.

        Raises:
            ValueError: if ``color_by`` is not a supported value.
        """
        net = Network(height=height, width="100%", bgcolor="#ffffff", font_color="black", notebook=notebook, directed=True, neighborhood_highlight=True)
        net.force_atlas_2based()

        # PageRank drives node size (and color when color_by == "pagerank").
        pagerank = nx.pagerank(G)

        if color_by == "pagerank":
            color_scores = pagerank
        elif color_by == "rerank_score":
            color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}
        else:
            raise ValueError(f"Unknown color_by value: {color_by}")

        # Min-max normalise the color scores to [0, 1]. When every score is
        # equal (single node, or no rerank scores at all) the range is 0;
        # map everything to 0 instead of dividing by zero.
        values = list(color_scores.values())
        low = min(values, default=0.0)
        span = max(values, default=0.0) - low
        norm_color_scores = {
            node: (score - low) / span if span > 0 else 0.0
            for node, score in color_scores.items()
        }

        def to_byte(channel):
            # Colormap channels are floats in [0, 1]; convert to a clamped 0-255 int.
            return int(max(0, min(channel * 255, 255)))

        for node, info in G.nodes(data=True):
            full_title = info["title"]
            # Only mark the label as truncated when it actually was.
            label = full_title if len(full_title) <= 30 else full_title[:30] + " ..."

            hover = "\n".join([full_title, f"Year: {info['publication_year']}", f"ID: {info['id']}"])

            rgba = plt.cm.RdBu_r(norm_color_scores[node])  # RdBu_r: blue (low) -> red (high)
            color = '#%02x%02x%02x' % tuple(to_byte(c) for c in rgba[:3])

            net.add_node(node, title=hover, size=pagerank[node] * 1000, label=label, color=color)

        for source, target in G.edges:
            net.add_edge(source, target, arrowStrikethrough=True, color="gray")

        if notebook:
            return net.show("network.html")
        return net

    def get_abstract_from_inverted_index(self, index):
        """Rebuild abstract text from an OpenAlex ``abstract_inverted_index``.

        ``index`` maps each token to the list of word positions where it
        occurs. Tokens are placed at their positions and joined with single
        spaces. Positions no token claims become empty strings, so gaps show up
        as doubled spaces. If two tokens claim the same position, the one
        iterated last wins.

        Returns:
            The reconstructed text, or ``""`` when ``index`` is missing
            (None/NaN), is not a dict, or contains no positions.
        """
        if not isinstance(index, dict):
            return ""

        placements = [
            (position, token)
            for token, positions in index.items()
            for position in positions
        ]
        if not placements:
            return ""

        reconstructed = [''] * (max(position for position, _ in placements) + 1)
        for position, token in placements:
            reconstructed[position] = token
        return ' '.join(reconstructed)
        


class OpenAlexRetriever(BaseRetriever):
    """LangChain retriever that searches OpenAlex and returns works as Documents.

    Each returned ``Document`` has the work's ``content`` (title + abstract)
    as ``page_content`` and the whole result row as ``metadata``. Works whose
    token count falls outside ``[min_tokens, max_tokens]`` are skipped.
    """

    # Forwarded to OpenAlex.search as ``after`` (lower year bound).
    min_year: int = 1960
    # Forwarded to OpenAlex.search as ``before``; None means no upper bound.
    # Optional so that an explicit ``max_year=None`` passes pydantic validation.
    max_year: Optional[int] = None
    # Number of results requested from OpenAlex (a single page).
    k: int = 100
    # Inclusive bounds on a work's ``num_tokens``; defaults match the previous
    # hard-coded 50 / 1000 limits.
    min_tokens: int = 50
    max_tokens: int = 1000

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search OpenAlex for ``query`` and wrap each sufficiently-sized hit in a Document."""
        openalex = OpenAlex()

        df_docs = openalex.search(query, n_results=self.k, after=self.min_year, before=self.max_year)

        docs = []
        for _, row in df_docs.iterrows():
            num_tokens = row["num_tokens"]
            # Too short (e.g. title only) or too long for the downstream context.
            if num_tokens < self.min_tokens or num_tokens > self.max_tokens:
                continue
            docs.append(
                Document(
                    page_content=row["content"],
                    metadata=row.to_dict(),
                )
            )
        return docs