Spaces:
Sleeping
Sleeping
Commit
·
ca764d6
1
Parent(s):
ba1c7a0
Add real-time inference
Browse files
- .gitignore +5 -1
- README.md +1 -1
- data/kg_edge_types.csv +75 -0
- data/kg_node_types.csv +11 -0
- pages/about.py +4 -1
- pages/explore.py +4 -1
- pages/input.py +54 -1
- pages/validate.py +98 -1
- requirements.txt +2 -1
.gitignore
CHANGED
@@ -5,6 +5,10 @@
|
|
5 |
# Ignore python cache files
|
6 |
__pycache__/
|
7 |
|
|
|
|
|
|
|
8 |
# Ignore secrets
|
9 |
.streamlit/secrets.toml
|
10 |
-
.streamlit/gravity-user-db.json
|
|
|
|
5 |
# Ignore python cache files
|
6 |
__pycache__/
|
7 |
|
8 |
+
# Ignore model files
|
9 |
+
data/*.pt
|
10 |
+
|
11 |
# Ignore secrets
|
12 |
.streamlit/secrets.toml
|
13 |
+
.streamlit/gravity-user-db.json
|
14 |
+
test.ipynb
|
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 💻
|
4 |
colorFrom: red
|
5 |
colorTo: purple
|
|
|
1 |
---
|
2 |
+
title: GRAVITY
|
3 |
emoji: 💻
|
4 |
colorFrom: red
|
5 |
colorTo: purple
|
data/kg_edge_types.csv
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
x_type,relation,display_relation,y_type,direction,N
|
2 |
+
anatomy,anatomy_protein_present,expression present,gene/protein,forward,3831782
|
3 |
+
gene/protein,rev_anatomy_protein_present,expression present,anatomy,reverse,3831782
|
4 |
+
drug,drug_drug,synergistic interaction,drug,forward,1433261
|
5 |
+
drug,rev_drug_drug,synergistic interaction,drug,reverse,1433261
|
6 |
+
anatomy,anatomy_protein_absent,expression absent,gene/protein,forward,324186
|
7 |
+
gene/protein,rev_anatomy_protein_absent,expression absent,anatomy,reverse,324186
|
8 |
+
gene/protein,protein_protein,ppi,gene/protein,forward,321090
|
9 |
+
gene/protein,rev_protein_protein,ppi,gene/protein,reverse,321090
|
10 |
+
disease,disease_phenotype_positive,phenotype present,effect/phenotype,forward,200354
|
11 |
+
effect/phenotype,rev_disease_phenotype_positive,phenotype present,disease,reverse,200354
|
12 |
+
disease,disease_protein,associated with,gene/protein,forward,147984
|
13 |
+
gene/protein,rev_disease_protein,associated with,disease,reverse,147984
|
14 |
+
biological_process,bioprocess_protein,interacts with,gene/protein,forward,138297
|
15 |
+
gene/protein,rev_bioprocess_protein,interacts with,biological_process,reverse,138297
|
16 |
+
cellular_component,cellcomp_protein,interacts with,gene/protein,forward,83089
|
17 |
+
gene/protein,rev_cellcomp_protein,interacts with,cellular_component,reverse,83089
|
18 |
+
disease,disease_protein_negative,expression downregulated,gene/protein,forward,71135
|
19 |
+
gene/protein,rev_disease_protein_negative,expression downregulated,disease,reverse,71135
|
20 |
+
gene/protein,molfunc_protein,interacts with,molecular_function,forward,70291
|
21 |
+
molecular_function,rev_molfunc_protein,interacts with,gene/protein,reverse,70291
|
22 |
+
disease,disease_protein_positive,expression upregulated,gene/protein,forward,69488
|
23 |
+
gene/protein,rev_disease_protein_positive,expression upregulated,disease,reverse,69488
|
24 |
+
drug,drug_effect,side effect,effect/phenotype,forward,64249
|
25 |
+
effect/phenotype,rev_drug_effect,side effect,drug,reverse,64249
|
26 |
+
biological_process,bioprocess_bioprocess,parent-child,biological_process,forward,50232
|
27 |
+
biological_process,rev_bioprocess_bioprocess,parent-child,biological_process,reverse,50232
|
28 |
+
gene/protein,pathway_protein,interacts with,pathway,forward,44116
|
29 |
+
pathway,rev_pathway_protein,interacts with,gene/protein,reverse,44116
|
30 |
+
disease,disease_disease,parent-child,disease,forward,37808
|
31 |
+
disease,rev_disease_disease,parent-child,disease,reverse,37808
|
32 |
+
disease,contraindication,contraindication,drug,forward,26899
|
33 |
+
drug,rev_contraindication,contraindication,disease,reverse,26899
|
34 |
+
effect/phenotype,phenotype_phenotype,parent-child,effect/phenotype,forward,20183
|
35 |
+
effect/phenotype,rev_phenotype_phenotype,parent-child,effect/phenotype,reverse,20183
|
36 |
+
drug,drug_protein,target,gene/protein,forward,18513
|
37 |
+
gene/protein,rev_drug_protein,target,drug,reverse,18513
|
38 |
+
disease,weak_clinical_evidence,clinical candidate,drug,forward,16111
|
39 |
+
drug,rev_weak_clinical_evidence,clinical candidate,disease,reverse,16111
|
40 |
+
anatomy,anatomy_anatomy,parent-child,anatomy,forward,14383
|
41 |
+
anatomy,rev_anatomy_anatomy,parent-child,anatomy,reverse,14383
|
42 |
+
molecular_function,molfunc_molfunc,parent-child,molecular_function,forward,13735
|
43 |
+
molecular_function,rev_molfunc_molfunc,parent-child,molecular_function,reverse,13735
|
44 |
+
disease,indication,indication,drug,forward,12608
|
45 |
+
drug,rev_indication,indication,disease,reverse,12608
|
46 |
+
drug,drug_protein,enzyme,gene/protein,forward,5919
|
47 |
+
gene/protein,rev_drug_protein,enzyme,drug,reverse,5919
|
48 |
+
disease,strong_clinical_evidence,clinical candidate,drug,forward,5352
|
49 |
+
drug,rev_strong_clinical_evidence,clinical candidate,disease,reverse,5352
|
50 |
+
cellular_component,cellcomp_cellcomp,parent-child,cellular_component,forward,4683
|
51 |
+
cellular_component,rev_cellcomp_cellcomp,parent-child,cellular_component,reverse,4683
|
52 |
+
effect/phenotype,phenotype_protein,associated with,gene/protein,forward,4437
|
53 |
+
gene/protein,rev_phenotype_protein,associated with,effect/phenotype,reverse,4437
|
54 |
+
drug,drug_protein,transporter,gene/protein,forward,3349
|
55 |
+
gene/protein,rev_drug_protein,transporter,drug,reverse,3349
|
56 |
+
pathway,pathway_pathway,parent-child,pathway,forward,2647
|
57 |
+
pathway,rev_pathway_pathway,parent-child,pathway,reverse,2647
|
58 |
+
disease,exposure_disease,linked to,exposure,forward,2421
|
59 |
+
exposure,rev_exposure_disease,linked to,disease,reverse,2421
|
60 |
+
disease,off_label_use,off-label use,drug,forward,2370
|
61 |
+
drug,rev_off_label_use,off-label use,disease,reverse,2370
|
62 |
+
exposure,exposure_exposure,parent-child,exposure,forward,2263
|
63 |
+
exposure,rev_exposure_exposure,parent-child,exposure,reverse,2263
|
64 |
+
exposure,exposure_protein,interacts with,gene/protein,forward,2012
|
65 |
+
gene/protein,rev_exposure_protein,interacts with,exposure,reverse,2012
|
66 |
+
biological_process,exposure_bioprocess,interacts with,exposure,forward,1990
|
67 |
+
exposure,rev_exposure_bioprocess,interacts with,biological_process,reverse,1990
|
68 |
+
drug,drug_protein,carrier,gene/protein,forward,993
|
69 |
+
gene/protein,rev_drug_protein,carrier,drug,reverse,993
|
70 |
+
disease,disease_phenotype_negative,phenotype absent,effect/phenotype,forward,508
|
71 |
+
effect/phenotype,rev_disease_phenotype_negative,phenotype absent,disease,reverse,508
|
72 |
+
exposure,exposure_molfunc,interacts with,molecular_function,forward,45
|
73 |
+
molecular_function,rev_exposure_molfunc,interacts with,exposure,reverse,45
|
74 |
+
cellular_component,exposure_cellcomp,interacts with,exposure,forward,12
|
75 |
+
exposure,rev_exposure_cellcomp,interacts with,cellular_component,reverse,12
|
data/kg_node_types.csv
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
node_type,N
|
2 |
+
gene/protein,35198
|
3 |
+
biological_process,27668
|
4 |
+
disease,22201
|
5 |
+
effect/phenotype,16711
|
6 |
+
anatomy,14384
|
7 |
+
molecular_function,11228
|
8 |
+
drug,8160
|
9 |
+
cellular_component,4054
|
10 |
+
pathway,2629
|
11 |
+
exposure,860
|
pages/about.py
CHANGED
@@ -14,4 +14,7 @@ menu_with_redirect()
|
|
14 |
st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
|
15 |
|
16 |
# Main content
|
17 |
-
st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a
|
|
|
|
|
|
|
|
14 |
st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
|
15 |
|
16 |
# Main content
|
17 |
+
st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a **GR**aph **A**I **VI**sualization **T**ool to query and visualize knowledge graph-grounded biomedical AI models.")
|
18 |
+
|
19 |
+
# Subheader
|
20 |
+
st.subheader("About GRAVITY", divider = "grey")
|
pages/explore.py
CHANGED
@@ -14,4 +14,7 @@ menu_with_redirect()
|
|
14 |
st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
|
15 |
|
16 |
# Main content
|
17 |
-
st.markdown(f"Hello, {st.session_state.name}!")
|
|
|
|
|
|
|
|
14 |
st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
|
15 |
|
16 |
# Main content
|
17 |
+
# st.markdown(f"Hello, {st.session_state.name}!")
|
18 |
+
|
19 |
+
# Coming soon
|
20 |
+
st.write("Coming soon...")
|
pages/input.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
from menu import menu_with_redirect
|
3 |
|
|
|
|
|
|
|
|
|
4 |
# Path manipulation
|
5 |
from pathlib import Path
|
6 |
|
@@ -14,4 +18,53 @@ menu_with_redirect()
|
|
14 |
st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
|
15 |
|
16 |
# Main content
|
17 |
-
st.markdown(f"Hello, {st.session_state.name}!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from menu import menu_with_redirect
|
3 |
|
4 |
+
# Standard imports
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
# Path manipulation
|
9 |
from pathlib import Path
|
10 |
|
|
|
18 |
st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
|
19 |
|
20 |
# Main content
|
21 |
+
# st.markdown(f"Hello, {st.session_state.name}!")
|
22 |
+
|
23 |
+
st.subheader("Construct Query", divider = "red")
|
24 |
+
|
25 |
+
# Checkbox to allow reverse edges
|
26 |
+
allow_reverse_edges = st.checkbox("Reverse Edges", value = False)
|
27 |
+
|
28 |
+
with st.spinner('Loading knowledge graph...'):
|
29 |
+
kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
|
30 |
+
node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
|
31 |
+
edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
|
32 |
+
|
33 |
+
if not allow_reverse_edges:
|
34 |
+
edge_types = edge_types[edge_types.direction == 'forward']
|
35 |
+
|
36 |
+
# Select source node type
|
37 |
+
source_node_type = st.selectbox("Source Node Type", node_types['node_type'],
|
38 |
+
format_func = lambda x: x.replace("_", " "))
|
39 |
+
|
40 |
+
# Select source node
|
41 |
+
source_node = st.selectbox("Source Node", kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name'])
|
42 |
+
|
43 |
+
# Select target node type
|
44 |
+
target_node_type = st.selectbox("Target Node Type", edge_types[edge_types.x_type == source_node_type].y_type.unique(),
|
45 |
+
format_func = lambda x: x.replace("_", " "))
|
46 |
+
|
47 |
+
# Select relation
|
48 |
+
relation = st.selectbox("Edge Type", edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique(),
|
49 |
+
format_func = lambda x: x.replace("_", "-"))
|
50 |
+
|
51 |
+
# Button to submit query
|
52 |
+
if st.button("Submit Query"):
|
53 |
+
|
54 |
+
# Save query to session state
|
55 |
+
st.session_state.query = {
|
56 |
+
"source_node_type": source_node_type,
|
57 |
+
"source_node": source_node,
|
58 |
+
"target_node_type": target_node_type,
|
59 |
+
"relation": relation
|
60 |
+
}
|
61 |
+
|
62 |
+
# # Write query to console
|
63 |
+
# st.write("Current Query:")
|
64 |
+
# st.write(st.session_state.query)
|
65 |
+
st.write("Query submitted.")
|
66 |
+
|
67 |
+
st.subheader("Knowledge Graph", divider = "red")
|
68 |
+
display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
|
69 |
+
display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
|
70 |
+
st.dataframe(display_data, use_container_width = True)
|
pages/validate.py
CHANGED
@@ -1,8 +1,16 @@
|
|
1 |
import streamlit as st
|
2 |
from menu import menu_with_redirect
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
# Path manipulation
|
5 |
from pathlib import Path
|
|
|
6 |
|
7 |
# Custom and other imports
|
8 |
import project_config
|
@@ -14,4 +22,93 @@ menu_with_redirect()
|
|
14 |
st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
|
15 |
|
16 |
# Main content
|
17 |
-
st.markdown(f"Hello, {st.session_state.name}!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from menu import menu_with_redirect
|
3 |
|
4 |
+
# Standard imports
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
# Path manipulation
|
12 |
from pathlib import Path
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
|
15 |
# Custom and other imports
|
16 |
import project_config
|
|
|
22 |
st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
|
23 |
|
24 |
# Main content
|
25 |
+
# st.markdown(f"Hello, {st.session_state.name}!")
|
26 |
+
|
27 |
+
st.subheader("Model Predictions", divider = "green")
|
28 |
+
|
29 |
+
# Print current query
|
30 |
+
st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")
|
31 |
+
|
32 |
+
with st.spinner('Loading knowledge graph...'):
|
33 |
+
kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
|
34 |
+
|
35 |
+
# Get paths to embeddings, relation weights, and edge types
|
36 |
+
with st.spinner('Downloading AI model...'):
|
37 |
+
embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
38 |
+
filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
|
39 |
+
token=st.secrets["HF_TOKEN"])
|
40 |
+
relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
41 |
+
filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
|
42 |
+
token=st.secrets["HF_TOKEN"])
|
43 |
+
edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
44 |
+
filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
|
45 |
+
token=st.secrets["HF_TOKEN"])
|
46 |
+
|
47 |
+
# Load embeddings, relation weights, and edge types
|
48 |
+
with st.spinner('Loading AI model...'):
|
49 |
+
embeddings = torch.load(embed_path)
|
50 |
+
relation_weights = torch.load(relation_weights_path)
|
51 |
+
edge_types = torch.load(edge_types_path)
|
52 |
+
|
53 |
+
# # Print source node type
|
54 |
+
# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
|
55 |
+
|
56 |
+
# # Print source node
|
57 |
+
# st.write(f"Source Node: {st.session_state.query['source_node']}")
|
58 |
+
|
59 |
+
# # Print relation
|
60 |
+
# st.write(f"Edge Type: {st.session_state.query['relation']}")
|
61 |
+
|
62 |
+
# # Print target node type
|
63 |
+
# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
|
64 |
+
|
65 |
+
# Compute predictions
|
66 |
+
with st.spinner('Computing predictions...'):
|
67 |
+
|
68 |
+
source_node_type = st.session_state.query['source_node_type']
|
69 |
+
source_node = st.session_state.query['source_node']
|
70 |
+
relation = st.session_state.query['relation']
|
71 |
+
target_node_type = st.session_state.query['target_node_type']
|
72 |
+
|
73 |
+
# Get source node index
|
74 |
+
src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
|
75 |
+
|
76 |
+
# Get relation index
|
77 |
+
edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
|
78 |
+
|
79 |
+
# Get target nodes indices
|
80 |
+
target_nodes = kg_nodes[kg_nodes.node_type == target_node_type]
|
81 |
+
dst_indices = target_nodes.node_index.values
|
82 |
+
src_indices = np.repeat(src_index, len(dst_indices))
|
83 |
+
|
84 |
+
# Retrieve cached embeddings
|
85 |
+
src_embeddings = embeddings[src_indices]
|
86 |
+
dst_embeddings = embeddings[dst_indices]
|
87 |
+
|
88 |
+
# Apply activation function
|
89 |
+
src_embeddings = F.leaky_relu(src_embeddings)
|
90 |
+
dst_embeddings = F.leaky_relu(dst_embeddings)
|
91 |
+
|
92 |
+
# Get relation weights
|
93 |
+
rel_weights = relation_weights[edge_type_index]
|
94 |
+
|
95 |
+
# Compute weighted dot product
|
96 |
+
scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
|
97 |
+
scores = torch.sigmoid(scores)
|
98 |
+
|
99 |
+
# Add scores to dataframe
|
100 |
+
target_nodes['score'] = scores.detach().numpy()
|
101 |
+
|
102 |
+
# Rank target nodes by score
|
103 |
+
target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
|
104 |
+
|
105 |
+
# Add rank to dataframe
|
106 |
+
target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
|
107 |
+
|
108 |
+
# Show top ranked nodes
|
109 |
+
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], 50)
|
110 |
+
|
111 |
+
# Rename columns
|
112 |
+
display_data = target_nodes[['rank', 'node_id', 'node_name', 'node_source', 'score']].iloc[:top_k].copy()
|
113 |
+
display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'node_source': 'Database', 'score': 'Score'})
|
114 |
+
st.dataframe(display_data, use_container_width = True)
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ pathlib
|
|
7 |
torch
|
8 |
altair<5
|
9 |
gspread
|
10 |
-
oauth2client
|
|
|
|
7 |
torch
|
8 |
altair<5
|
9 |
gspread
|
10 |
+
oauth2client
|
11 |
+
huggingface_hub
|