Commit · f18a5e1
Parent(s): ca764d6
Rename to predict and add external database links

- media/predict_header.svg +1 -0
- menu.py +2 -1
- pages/explore.py +0 -20
- pages/predict.py +154 -0
- pages/validate.py +3 -95
- utils.py +9 -0

media/predict_header.svg
ADDED

menu.py
CHANGED
@@ -45,8 +45,9 @@ def authenticated_menu():
     # st.sidebar.page_link("app.py", label="Switch Accounts", icon="π")
     st.sidebar.page_link("pages/about.py", label="About", icon="π")
     st.sidebar.page_link("pages/input.py", label="Input", icon="π‘")
+    st.sidebar.page_link("pages/predict.py", label="Predict", icon="π")
     st.sidebar.page_link("pages/validate.py", label="Validate", icon="β")
-    st.sidebar.page_link("pages/explore.py", label="Explore", icon="π")
+    # st.sidebar.page_link("pages/explore.py", label="Explore", icon="π")
     if st.session_state.role in ["admin"]:
         st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="π§")

pages/explore.py
DELETED
@@ -1,20 +0,0 @@
-import streamlit as st
-from menu import menu_with_redirect
-
-# Path manipulation
-from pathlib import Path
-
-# Custom and other imports
-import project_config
-
-# Redirect to app.py if not logged in, otherwise show the navigation menu
-menu_with_redirect()
-
-# Header
-st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
-
-# Main content
-# st.markdown(f"Hello, {st.session_state.name}!")
-
-# Coming soon
-st.write("Coming soon...")

pages/predict.py
ADDED
@@ -0,0 +1,154 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Path manipulation
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+
+# Custom and other imports
+import project_config
+from utils import capitalize_after_slash
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
+
+# Main content
+# st.markdown(f"Hello, {st.session_state.name}!")
+
+st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
+
+# Print current query
+st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")
+
+with st.spinner('Loading knowledge graph...'):
+    kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
+
+# Get paths to embeddings, relation weights, and edge types
+with st.spinner('Downloading AI model...'):
+    embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+                                 filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
+                                 token=st.secrets["HF_TOKEN"])
+    relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+                                            filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
+                                            token=st.secrets["HF_TOKEN"])
+    edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+                                      filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
+                                      token=st.secrets["HF_TOKEN"])
+
+# Load embeddings, relation weights, and edge types
+with st.spinner('Loading AI model...'):
+    embeddings = torch.load(embed_path)
+    relation_weights = torch.load(relation_weights_path)
+    edge_types = torch.load(edge_types_path)
+
+# # Print source node type
+# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
+
+# # Print source node
+# st.write(f"Source Node: {st.session_state.query['source_node']}")
+
+# # Print relation
+# st.write(f"Edge Type: {st.session_state.query['relation']}")
+
+# # Print target node type
+# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
+
+# Compute predictions
+with st.spinner('Computing predictions...'):
+
+    source_node_type = st.session_state.query['source_node_type']
+    source_node = st.session_state.query['source_node']
+    relation = st.session_state.query['relation']
+    target_node_type = st.session_state.query['target_node_type']
+
+    # Get source node index
+    src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
+
+    # Get relation index
+    edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
+
+    # Get target nodes indices
+    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type]
+    dst_indices = target_nodes.node_index.values
+    src_indices = np.repeat(src_index, len(dst_indices))
+
+    # Retrieve cached embeddings and apply activation function
+    src_embeddings = embeddings[src_indices]
+    dst_embeddings = embeddings[dst_indices]
+    src_embeddings = F.leaky_relu(src_embeddings)
+    dst_embeddings = F.leaky_relu(dst_embeddings)
+
+    # Get relation weights
+    rel_weights = relation_weights[edge_type_index]
+
+    # Compute weighted dot product
+    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
+    scores = torch.sigmoid(scores)
+
+    # Add scores to dataframe
+    target_nodes['score'] = scores.detach().numpy()
+    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
+    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
+
+    # Rename columns
+    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
+    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
+
+    # Define dictionary mapping node types to database URLs
+    map_dbs = {
+        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
+        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
+        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
+        'disease': lambda x: x, # MONDO
+        # pad with 0s to 7 digits
+        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
+        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
+        'anatomy': lambda x: x,
+    }
+
+    # Get name of database
+    display_database = display_data['Database'].values[0]
+
+    # Add URLs to database column
+    display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
+
+# Use multiselect to search for specific nodes
+selected_nodes = st.multiselect('Search for specific nodes.', display_data.Name)
+
+# Filter nodes
+if len(selected_nodes) > 0:
+    selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
+
+    # Show filtered nodes
+    if target_node_type not in ['disease', 'anatomy']:
+        st.dataframe(selected_display_data, use_container_width = True,
+                     column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                            help = "Click to visit external database.",
+                                                                            display_text = display_database)})
+    else:
+        st.dataframe(selected_display_data, use_container_width = True)
+
+# Show top ranked nodes
+st.subheader("Model Predictions", divider = "blue")
+top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
+
+if target_node_type not in ['disease', 'anatomy']:
+    st.dataframe(display_data.iloc[:top_k], use_container_width = True,
+                 column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                        help = "Click to visit external database.",
+                                                                        display_text = display_database)})
+else:
+    st.dataframe(display_data.iloc[:top_k], use_container_width = True)

pages/validate.py
CHANGED
@@ -1,16 +1,8 @@
 import streamlit as st
 from menu import menu_with_redirect
 
-# Standard imports
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
 # Path manipulation
 from pathlib import Path
-from huggingface_hub import hf_hub_download
 
 # Custom and other imports
 import project_config
@@ -24,91 +16,7 @@ st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
 # Main content
 # st.markdown(f"Hello, {st.session_state.name}!")
 
-st.subheader("
-
-# Print current query
-st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")
-
-with st.spinner('Loading knowledge graph...'):
-    kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
-
-# Get paths to embeddings, relation weights, and edge types
-with st.spinner('Downloading AI model...'):
-    embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
-                                 filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
-                                 token=st.secrets["HF_TOKEN"])
-    relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
-                                            filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
-                                            token=st.secrets["HF_TOKEN"])
-    edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
-                                      filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
-                                      token=st.secrets["HF_TOKEN"])
-
-# Load embeddings, relation weights, and edge types
-with st.spinner('Loading AI model...'):
-    embeddings = torch.load(embed_path)
-    relation_weights = torch.load(relation_weights_path)
-    edge_types = torch.load(edge_types_path)
-
-# # Print source node type
-# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
-
-# # Print source node
-# st.write(f"Source Node: {st.session_state.query['source_node']}")
-
-# # Print relation
-# st.write(f"Edge Type: {st.session_state.query['relation']}")
-
-# # Print target node type
-# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
-
-# Compute predictions
-with st.spinner('Computing predictions...'):
-
-    source_node_type = st.session_state.query['source_node_type']
-    source_node = st.session_state.query['source_node']
-    relation = st.session_state.query['relation']
-    target_node_type = st.session_state.query['target_node_type']
-
-    # Get source node index
-    src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
-
-    # Get relation index
-    edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
-
-    # Get target nodes indices
-    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type]
-    dst_indices = target_nodes.node_index.values
-    src_indices = np.repeat(src_index, len(dst_indices))
-
-    # Retrieve cached embeddings
-    src_embeddings = embeddings[src_indices]
-    dst_embeddings = embeddings[dst_indices]
-
-    # Apply activation function
-    src_embeddings = F.leaky_relu(src_embeddings)
-    dst_embeddings = F.leaky_relu(dst_embeddings)
-
-    # Get relation weights
-    rel_weights = relation_weights[edge_type_index]
-
-    # Compute weighted dot product
-    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
-    scores = torch.sigmoid(scores)
-
-    # Add scores to dataframe
-    target_nodes['score'] = scores.detach().numpy()
-
-    # Rank target nodes by score
-    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
-
-    # Add rank to dataframe
-    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
-
-    # Show top ranked nodes
-    top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], 50)
+st.subheader("Validate Predictions", divider = "green")
 
-
-
-display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'node_source': 'Database', 'score': 'Score'})
-st.dataframe(display_data, use_container_width = True)
+# Coming soon
+st.write("Coming soon...")

utils.py
CHANGED
@@ -1,6 +1,15 @@
 import base64
 import streamlit as st
 
+def capitalize_after_slash(s):
+    # Split the string by slashes first
+    parts = s.split('/')
+    # Capitalize each part separately
+    capitalized_parts = [part.title() for part in parts]
+    # Rejoin the parts with slashes
+    capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
+    return capitalized_string
+
 # From https://stackoverflow.com/questions/73251012/put-logo-and-title-above-on-top-of-page-navigation-in-sidebar-of-streamlit-multi
 # See also https://arnaudmiribel.github.io/streamlit-extras/extras/app_logo/
 @st.cache_data()
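
For reference, a quick usage sketch of the new capitalize_after_slash helper; the example strings mirror the node-type labels used on the Predict page.

from utils import capitalize_after_slash

print(capitalize_after_slash("gene/protein"))        # -> Gene/Protein
print(capitalize_after_slash("effect/phenotype"))    # -> Effect/Phenotype
print(capitalize_after_slash("biological_process"))  # -> Biological Process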