ayushnoori committed on
Commit
f18a5e1
·
1 Parent(s): ca764d6

Rename to predict and add external database links

Browse files
Files changed (6) hide show
  1. media/predict_header.svg +1 -0
  2. menu.py +2 -1
  3. pages/explore.py +0 -20
  4. pages/predict.py +154 -0
  5. pages/validate.py +3 -95
  6. utils.py +9 -0
media/predict_header.svg ADDED
menu.py CHANGED
@@ -45,8 +45,9 @@ def authenticated_menu():
45
  # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
46
  st.sidebar.page_link("pages/about.py", label="About", icon="📖")
47
  st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
 
48
  st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅")
49
- st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
50
  if st.session_state.role in ["admin"]:
51
  st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
52
 
 
45
  # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
46
  st.sidebar.page_link("pages/about.py", label="About", icon="📖")
47
  st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
48
+ st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍")
49
  st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅")
50
+ # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
51
  if st.session_state.role in ["admin"]:
52
  st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
53
 
pages/explore.py DELETED
@@ -1,20 +0,0 @@
1
- import streamlit as st
2
- from menu import menu_with_redirect
3
-
4
- # Path manipulation
5
- from pathlib import Path
6
-
7
- # Custom and other imports
8
- import project_config
9
-
10
- # Redirect to app.py if not logged in, otherwise show the navigation menu
11
- menu_with_redirect()
12
-
13
- # Header
14
- st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
15
-
16
- # Main content
17
- # st.markdown(f"Hello, {st.session_state.name}!")
18
-
19
- # Coming soon
20
- st.write("Coming soon...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/predict.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ # Path manipulation
12
+ from pathlib import Path
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # Custom and other imports
16
+ import project_config
17
+ from utils import capitalize_after_slash
18
+
19
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
20
+ menu_with_redirect()
21
+
22
+ # Header
23
+ st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
24
+
25
+ # Main content
26
+ # st.markdown(f"Hello, {st.session_state.name}!")
27
+
28
+ st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
29
+
30
+ # Print current query
31
+ st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")
32
+
33
+ with st.spinner('Loading knowledge graph...'):
34
+ kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
35
+
36
+ # Get paths to embeddings, relation weights, and edge types
37
+ with st.spinner('Downloading AI model...'):
38
+ embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
39
+ filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
40
+ token=st.secrets["HF_TOKEN"])
41
+ relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
42
+ filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
43
+ token=st.secrets["HF_TOKEN"])
44
+ edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
45
+ filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
46
+ token=st.secrets["HF_TOKEN"])
47
+
48
+ # Load embeddings, relation weights, and edge types
49
+ with st.spinner('Loading AI model...'):
50
+ embeddings = torch.load(embed_path)
51
+ relation_weights = torch.load(relation_weights_path)
52
+ edge_types = torch.load(edge_types_path)
53
+
54
+ # # Print source node type
55
+ # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
56
+
57
+ # # Print source node
58
+ # st.write(f"Source Node: {st.session_state.query['source_node']}")
59
+
60
+ # # Print relation
61
+ # st.write(f"Edge Type: {st.session_state.query['relation']}")
62
+
63
+ # # Print target node type
64
+ # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
65
+
66
+ # Compute predictions
67
+ with st.spinner('Computing predictions...'):
68
+
69
+ source_node_type = st.session_state.query['source_node_type']
70
+ source_node = st.session_state.query['source_node']
71
+ relation = st.session_state.query['relation']
72
+ target_node_type = st.session_state.query['target_node_type']
73
+
74
+ # Get source node index
75
+ src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
76
+
77
+ # Get relation index
78
+ edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
79
+
80
+ # Get target nodes indices
81
+ target_nodes = kg_nodes[kg_nodes.node_type == target_node_type]
82
+ dst_indices = target_nodes.node_index.values
83
+ src_indices = np.repeat(src_index, len(dst_indices))
84
+
85
+ # Retrieve cached embeddings and apply activation function
86
+ src_embeddings = embeddings[src_indices]
87
+ dst_embeddings = embeddings[dst_indices]
88
+ src_embeddings = F.leaky_relu(src_embeddings)
89
+ dst_embeddings = F.leaky_relu(dst_embeddings)
90
+
91
+ # Get relation weights
92
+ rel_weights = relation_weights[edge_type_index]
93
+
94
+ # Compute weighted dot product
95
+ scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
96
+ scores = torch.sigmoid(scores)
97
+
98
+ # Add scores to dataframe
99
+ target_nodes['score'] = scores.detach().numpy()
100
+ target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
101
+ target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
102
+
103
+ # Rename columns
104
+ display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
105
+ display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
106
+
107
+ # Define dictionary mapping node types to database URLs
108
+ map_dbs = {
109
+ 'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
110
+ 'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
111
+ 'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
112
+ 'disease': lambda x: x, # MONDO
113
+ # pad with 0s to 7 digits
114
+ 'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
115
+ 'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
116
+ 'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
117
+ 'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
118
+ 'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
119
+ 'anatomy': lambda x: x,
120
+ }
121
+
122
+ # Get name of database
123
+ display_database = display_data['Database'].values[0]
124
+
125
+ # Add URLs to database column
126
+ display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
127
+
128
+ # Use multiselect to search for specific nodes
129
+ selected_nodes = st.multiselect('Search for specific nodes.', display_data.Name)
130
+
131
+ # Filter nodes
132
+ if len(selected_nodes) > 0:
133
+ selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
134
+
135
+ # Show filtered nodes
136
+ if target_node_type not in ['disease', 'anatomy']:
137
+ st.dataframe(selected_display_data, use_container_width = True,
138
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
139
+ help = "Click to visit external database.",
140
+ display_text = display_database)})
141
+ else:
142
+ st.dataframe(selected_display_data, use_container_width = True)
143
+
144
+ # Show top ranked nodes
145
+ st.subheader("Model Predictions", divider = "blue")
146
+ top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
147
+
148
+ if target_node_type not in ['disease', 'anatomy']:
149
+ st.dataframe(display_data.iloc[:top_k], use_container_width = True,
150
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
151
+ help = "Click to visit external database.",
152
+ display_text = display_database)})
153
+ else:
154
+ st.dataframe(display_data.iloc[:top_k], use_container_width = True)
pages/validate.py CHANGED
@@ -1,16 +1,8 @@
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
4
- # Standard imports
5
- import numpy as np
6
- import pandas as pd
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
  # Path manipulation
12
  from pathlib import Path
13
- from huggingface_hub import hf_hub_download
14
 
15
  # Custom and other imports
16
  import project_config
@@ -24,91 +16,7 @@ st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width
24
  # Main content
25
  # st.markdown(f"Hello, {st.session_state.name}!")
26
 
27
- st.subheader("Model Predictions", divider = "green")
28
-
29
- # Print current query
30
- st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")
31
-
32
- with st.spinner('Loading knowledge graph...'):
33
- kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
34
-
35
- # Get paths to embeddings, relation weights, and edge types
36
- with st.spinner('Downloading AI model...'):
37
- embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
38
- filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
39
- token=st.secrets["HF_TOKEN"])
40
- relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
41
- filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
42
- token=st.secrets["HF_TOKEN"])
43
- edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
44
- filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
45
- token=st.secrets["HF_TOKEN"])
46
-
47
- # Load embeddings, relation weights, and edge types
48
- with st.spinner('Loading AI model...'):
49
- embeddings = torch.load(embed_path)
50
- relation_weights = torch.load(relation_weights_path)
51
- edge_types = torch.load(edge_types_path)
52
-
53
- # # Print source node type
54
- # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
55
-
56
- # # Print source node
57
- # st.write(f"Source Node: {st.session_state.query['source_node']}")
58
-
59
- # # Print relation
60
- # st.write(f"Edge Type: {st.session_state.query['relation']}")
61
-
62
- # # Print target node type
63
- # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
64
-
65
- # Compute predictions
66
- with st.spinner('Computing predictions...'):
67
-
68
- source_node_type = st.session_state.query['source_node_type']
69
- source_node = st.session_state.query['source_node']
70
- relation = st.session_state.query['relation']
71
- target_node_type = st.session_state.query['target_node_type']
72
-
73
- # Get source node index
74
- src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
75
-
76
- # Get relation index
77
- edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
78
-
79
- # Get target nodes indices
80
- target_nodes = kg_nodes[kg_nodes.node_type == target_node_type]
81
- dst_indices = target_nodes.node_index.values
82
- src_indices = np.repeat(src_index, len(dst_indices))
83
-
84
- # Retrieve cached embeddings
85
- src_embeddings = embeddings[src_indices]
86
- dst_embeddings = embeddings[dst_indices]
87
-
88
- # Apply activation function
89
- src_embeddings = F.leaky_relu(src_embeddings)
90
- dst_embeddings = F.leaky_relu(dst_embeddings)
91
-
92
- # Get relation weights
93
- rel_weights = relation_weights[edge_type_index]
94
-
95
- # Compute weighted dot product
96
- scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
97
- scores = torch.sigmoid(scores)
98
-
99
- # Add scores to dataframe
100
- target_nodes['score'] = scores.detach().numpy()
101
-
102
- # Rank target nodes by score
103
- target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
104
-
105
- # Add rank to dataframe
106
- target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
107
-
108
- # Show top ranked nodes
109
- top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], 50)
110
 
111
- # Rename columns
112
- display_data = target_nodes[['rank', 'node_id', 'node_name', 'node_source', 'score']].iloc[:top_k].copy()
113
- display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'node_source': 'Database', 'score': 'Score'})
114
- st.dataframe(display_data, use_container_width = True)
 
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
 
 
 
 
 
 
 
4
  # Path manipulation
5
  from pathlib import Path
 
6
 
7
  # Custom and other imports
8
  import project_config
 
16
  # Main content
17
  # st.markdown(f"Hello, {st.session_state.name}!")
18
 
19
+ st.subheader("Validate Predictions", divider = "green")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Coming soon
22
+ st.write("Coming soon...")
 
 
utils.py CHANGED
@@ -1,6 +1,15 @@
1
  import base64
2
  import streamlit as st
3
 
 
 
 
 
 
 
 
 
 
4
  # From https://stackoverflow.com/questions/73251012/put-logo-and-title-above-on-top-of-page-navigation-in-sidebar-of-streamlit-multi
5
  # See also https://arnaudmiribel.github.io/streamlit-extras/extras/app_logo/
6
  @st.cache_data()
 
1
  import base64
2
  import streamlit as st
3
 
4
+ def capitalize_after_slash(s):
5
+ # Split the string by slashes first
6
+ parts = s.split('/')
7
+ # Capitalize each part separately
8
+ capitalized_parts = [part.title() for part in parts]
9
+ # Rejoin the parts with slashes
10
+ capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
11
+ return capitalized_string
12
+
13
  # From https://stackoverflow.com/questions/73251012/put-logo-and-title-above-on-top-of-page-navigation-in-sidebar-of-streamlit-multi
14
  # See also https://arnaudmiribel.github.io/streamlit-extras/extras/app_logo/
15
  @st.cache_data()