ayushnoori committed on
Commit
ca764d6
·
1 Parent(s): ba1c7a0

Add real-time inference

Browse files
.gitignore CHANGED
@@ -5,6 +5,10 @@
5
  # Ignore python cache files
6
  __pycache__/
7
 
 
 
 
8
  # Ignore secrets
9
  .streamlit/secrets.toml
10
- .streamlit/gravity-user-db.json
 
 
5
  # Ignore python cache files
6
  __pycache__/
7
 
8
+ # Ignore model files
9
+ data/*.pt
10
+
11
  # Ignore secrets
12
  .streamlit/secrets.toml
13
+ .streamlit/gravity-user-db.json
14
+ test.ipynb
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Gravity
3
  emoji: 💻
4
  colorFrom: red
5
  colorTo: purple
 
1
  ---
2
+ title: GRAVITY
3
  emoji: 💻
4
  colorFrom: red
5
  colorTo: purple
data/kg_edge_types.csv ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ x_type,relation,display_relation,y_type,direction,N
2
+ anatomy,anatomy_protein_present,expression present,gene/protein,forward,3831782
3
+ gene/protein,rev_anatomy_protein_present,expression present,anatomy,reverse,3831782
4
+ drug,drug_drug,synergistic interaction,drug,forward,1433261
5
+ drug,rev_drug_drug,synergistic interaction,drug,reverse,1433261
6
+ anatomy,anatomy_protein_absent,expression absent,gene/protein,forward,324186
7
+ gene/protein,rev_anatomy_protein_absent,expression absent,anatomy,reverse,324186
8
+ gene/protein,protein_protein,ppi,gene/protein,forward,321090
9
+ gene/protein,rev_protein_protein,ppi,gene/protein,reverse,321090
10
+ disease,disease_phenotype_positive,phenotype present,effect/phenotype,forward,200354
11
+ effect/phenotype,rev_disease_phenotype_positive,phenotype present,disease,reverse,200354
12
+ disease,disease_protein,associated with,gene/protein,forward,147984
13
+ gene/protein,rev_disease_protein,associated with,disease,reverse,147984
14
+ biological_process,bioprocess_protein,interacts with,gene/protein,forward,138297
15
+ gene/protein,rev_bioprocess_protein,interacts with,biological_process,reverse,138297
16
+ cellular_component,cellcomp_protein,interacts with,gene/protein,forward,83089
17
+ gene/protein,rev_cellcomp_protein,interacts with,cellular_component,reverse,83089
18
+ disease,disease_protein_negative,expression downregulated,gene/protein,forward,71135
19
+ gene/protein,rev_disease_protein_negative,expression downregulated,disease,reverse,71135
20
+ gene/protein,molfunc_protein,interacts with,molecular_function,forward,70291
21
+ molecular_function,rev_molfunc_protein,interacts with,gene/protein,reverse,70291
22
+ disease,disease_protein_positive,expression upregulated,gene/protein,forward,69488
23
+ gene/protein,rev_disease_protein_positive,expression upregulated,disease,reverse,69488
24
+ drug,drug_effect,side effect,effect/phenotype,forward,64249
25
+ effect/phenotype,rev_drug_effect,side effect,drug,reverse,64249
26
+ biological_process,bioprocess_bioprocess,parent-child,biological_process,forward,50232
27
+ biological_process,rev_bioprocess_bioprocess,parent-child,biological_process,reverse,50232
28
+ gene/protein,pathway_protein,interacts with,pathway,forward,44116
29
+ pathway,rev_pathway_protein,interacts with,gene/protein,reverse,44116
30
+ disease,disease_disease,parent-child,disease,forward,37808
31
+ disease,rev_disease_disease,parent-child,disease,reverse,37808
32
+ disease,contraindication,contraindication,drug,forward,26899
33
+ drug,rev_contraindication,contraindication,disease,reverse,26899
34
+ effect/phenotype,phenotype_phenotype,parent-child,effect/phenotype,forward,20183
35
+ effect/phenotype,rev_phenotype_phenotype,parent-child,effect/phenotype,reverse,20183
36
+ drug,drug_protein,target,gene/protein,forward,18513
37
+ gene/protein,rev_drug_protein,target,drug,reverse,18513
38
+ disease,weak_clinical_evidence,clinical candidate,drug,forward,16111
39
+ drug,rev_weak_clinical_evidence,clinical candidate,disease,reverse,16111
40
+ anatomy,anatomy_anatomy,parent-child,anatomy,forward,14383
41
+ anatomy,rev_anatomy_anatomy,parent-child,anatomy,reverse,14383
42
+ molecular_function,molfunc_molfunc,parent-child,molecular_function,forward,13735
43
+ molecular_function,rev_molfunc_molfunc,parent-child,molecular_function,reverse,13735
44
+ disease,indication,indication,drug,forward,12608
45
+ drug,rev_indication,indication,disease,reverse,12608
46
+ drug,drug_protein,enzyme,gene/protein,forward,5919
47
+ gene/protein,rev_drug_protein,enzyme,drug,reverse,5919
48
+ disease,strong_clinical_evidence,clinical candidate,drug,forward,5352
49
+ drug,rev_strong_clinical_evidence,clinical candidate,disease,reverse,5352
50
+ cellular_component,cellcomp_cellcomp,parent-child,cellular_component,forward,4683
51
+ cellular_component,rev_cellcomp_cellcomp,parent-child,cellular_component,reverse,4683
52
+ effect/phenotype,phenotype_protein,associated with,gene/protein,forward,4437
53
+ gene/protein,rev_phenotype_protein,associated with,effect/phenotype,reverse,4437
54
+ drug,drug_protein,transporter,gene/protein,forward,3349
55
+ gene/protein,rev_drug_protein,transporter,drug,reverse,3349
56
+ pathway,pathway_pathway,parent-child,pathway,forward,2647
57
+ pathway,rev_pathway_pathway,parent-child,pathway,reverse,2647
58
+ disease,exposure_disease,linked to,exposure,forward,2421
59
+ exposure,rev_exposure_disease,linked to,disease,reverse,2421
60
+ disease,off_label_use,off-label use,drug,forward,2370
61
+ drug,rev_off_label_use,off-label use,disease,reverse,2370
62
+ exposure,exposure_exposure,parent-child,exposure,forward,2263
63
+ exposure,rev_exposure_exposure,parent-child,exposure,reverse,2263
64
+ exposure,exposure_protein,interacts with,gene/protein,forward,2012
65
+ gene/protein,rev_exposure_protein,interacts with,exposure,reverse,2012
66
+ biological_process,exposure_bioprocess,interacts with,exposure,forward,1990
67
+ exposure,rev_exposure_bioprocess,interacts with,biological_process,reverse,1990
68
+ drug,drug_protein,carrier,gene/protein,forward,993
69
+ gene/protein,rev_drug_protein,carrier,drug,reverse,993
70
+ disease,disease_phenotype_negative,phenotype absent,effect/phenotype,forward,508
71
+ effect/phenotype,rev_disease_phenotype_negative,phenotype absent,disease,reverse,508
72
+ exposure,exposure_molfunc,interacts with,molecular_function,forward,45
73
+ molecular_function,rev_exposure_molfunc,interacts with,exposure,reverse,45
74
+ cellular_component,exposure_cellcomp,interacts with,exposure,forward,12
75
+ exposure,rev_exposure_cellcomp,interacts with,cellular_component,reverse,12
data/kg_node_types.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ node_type,N
2
+ gene/protein,35198
3
+ biological_process,27668
4
+ disease,22201
5
+ effect/phenotype,16711
6
+ anatomy,14384
7
+ molecular_function,11228
8
+ drug,8160
9
+ cellular_component,4054
10
+ pathway,2629
11
+ exposure,860
pages/about.py CHANGED
@@ -14,4 +14,7 @@ menu_with_redirect()
14
  st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
15
 
16
  # Main content
17
- st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a GRaph AI VISualization Tool to query and visualize knowledge graph-grounded biomedical AI models.")
 
 
 
 
14
  st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
15
 
16
  # Main content
17
+ st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a **GR**aph **A**I **VI**sualization **T**ool to query and visualize knowledge graph-grounded biomedical AI models.")
18
+
19
+ # Subheader
20
+ st.subheader("About GRAVITY", divider = "grey")
pages/explore.py CHANGED
@@ -14,4 +14,7 @@ menu_with_redirect()
14
  st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
15
 
16
  # Main content
17
- st.markdown(f"Hello, {st.session_state.name}!")
 
 
 
 
14
  st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
15
 
16
  # Main content
17
+ # st.markdown(f"Hello, {st.session_state.name}!")
18
+
19
+ # Coming soon
20
+ st.write("Coming soon...")
pages/input.py CHANGED
@@ -1,6 +1,10 @@
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
 
 
 
 
4
  # Path manipulation
5
  from pathlib import Path
6
 
@@ -14,4 +18,53 @@ menu_with_redirect()
14
  st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
15
 
16
  # Main content
17
- st.markdown(f"Hello, {st.session_state.name}!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
  # Path manipulation
9
  from pathlib import Path
10
 
 
18
  st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
19
 
# Main content
@st.cache_data(show_spinner = False)
def load_kg_tables(data_dir):
    """Read the knowledge graph tables used to build a query.

    Streamlit re-executes this script on every widget interaction, so the
    CSV files are parsed once and served from the cache afterwards.

    Args:
        data_dir: Directory (as a string) containing ``kg_nodes.csv``,
            ``kg_node_types.csv`` and ``kg_edge_types.csv``.

    Returns:
        A ``(kg_nodes, node_types, edge_types)`` tuple of DataFrames.
    """
    data_dir = Path(data_dir)
    kg_nodes = pd.read_csv(data_dir / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
    node_types = pd.read_csv(data_dir / 'kg_node_types.csv')
    edge_types = pd.read_csv(data_dir / 'kg_edge_types.csv')
    return kg_nodes, node_types, edge_types


st.subheader("Construct Query", divider = "red")

# Checkbox to allow reverse edges
allow_reverse_edges = st.checkbox("Reverse Edges", value = False)

with st.spinner('Loading knowledge graph...'):
    # The path is passed as a string so that it is trivially hashable as a cache key
    kg_nodes, node_types, edge_types = load_kg_tables(str(project_config.DATA_DIR))

# Unless requested, only offer edges in their canonical (forward) direction
if not allow_reverse_edges:
    edge_types = edge_types[edge_types.direction == 'forward']

# Select source node type
source_node_type = st.selectbox("Source Node Type", node_types['node_type'],
                                format_func = lambda x: x.replace("_", " "))

# Select source node
source_node = st.selectbox("Source Node", kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name'])

# Select target node type, restricted to types reachable from the source type
outgoing_edges = edge_types[edge_types.x_type == source_node_type]
target_node_type = st.selectbox("Target Node Type", outgoing_edges.y_type.unique(),
                                format_func = lambda x: x.replace("_", " "))

# Select relation, restricted to relations linking the chosen source and target types
relation = st.selectbox("Edge Type", outgoing_edges[outgoing_edges.y_type == target_node_type].relation.unique(),
                        format_func = lambda x: x.replace("_", "-"))

# Button to submit query
if st.button("Submit Query"):

    query = {
        "source_node_type": source_node_type,
        "source_node": source_node,
        "target_node_type": target_node_type,
        "relation": relation
    }

    # A selectbox with no options returns None; do not store an unusable query
    if any(value is None for value in query.values()):
        st.error("Incomplete query: please select a value for every field.")
    else:
        # Save query to session state
        st.session_state.query = query
        st.write("Query submitted.")

st.subheader("Knowledge Graph", divider = "red")
display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
st.dataframe(display_data, use_container_width = True)
pages/validate.py CHANGED
@@ -1,8 +1,16 @@
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
 
 
 
 
 
 
 
4
  # Path manipulation
5
  from pathlib import Path
 
6
 
7
  # Custom and other imports
8
  import project_config
@@ -14,4 +22,93 @@ menu_with_redirect()
14
  st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
15
 
16
  # Main content
17
- st.markdown(f"Hello, {st.session_state.name}!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
  # Path manipulation
12
  from pathlib import Path
13
+ from huggingface_hub import hf_hub_download
14
 
15
  # Custom and other imports
16
  import project_config
 
22
  st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
23
 
# Main content
@st.cache_data(show_spinner = False)
def load_kg_nodes(data_dir):
    """Read ``kg_nodes.csv`` from ``data_dir`` (a string path).

    Cached so the CSV is not re-parsed on every Streamlit rerun.
    """
    return pd.read_csv(Path(data_dir) / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)


@st.cache_resource(show_spinner = False)
def load_model_artifacts(embed_path, relation_weights_path, edge_types_path):
    """Load the embeddings, relation weights and edge types from disk.

    The loaded objects are shared across reruns and sessions, so callers
    must treat them as read-only.

    Args:
        embed_path: Path to the saved node embeddings.
        relation_weights_path: Path to the saved relation weights.
        edge_types_path: Path to the saved edge type list.

    Returns:
        An ``(embeddings, relation_weights, edge_types)`` tuple.
    """
    # map_location keeps loading working on CPU-only hosts even if the
    # tensors were saved from a GPU.
    # NOTE(review): torch.load unpickles the file; only load trusted artifacts.
    embeddings = torch.load(embed_path, map_location = 'cpu')
    relation_weights = torch.load(relation_weights_path, map_location = 'cpu')
    edge_types = torch.load(edge_types_path, map_location = 'cpu')
    return embeddings, relation_weights, edge_types


st.subheader("Model Predictions", divider = "green")

# The query is created on the input page; without it there is nothing to score
if 'query' not in st.session_state:
    st.error("No query found. Please submit a query on the input page first.")
    st.stop()

query = st.session_state.query
source_node_type = query['source_node_type']
source_node = query['source_node']
relation = query['relation']
target_node_type = query['target_node_type']

# Print current query
st.markdown(f"**Query:** {source_node} ➡️ {relation} ➡️ {target_node_type}")

with st.spinner('Loading knowledge graph...'):
    kg_nodes = load_kg_nodes(str(project_config.DATA_DIR))

# Get paths to embeddings, relation weights, and edge types
# (hf_hub_download keeps its own on-disk cache, so repeated calls do not re-download)
with st.spinner('Downloading AI model...'):
    embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                 filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
                                 token=st.secrets["HF_TOKEN"])
    relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                            filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
                                            token=st.secrets["HF_TOKEN"])
    edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                      filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
                                      token=st.secrets["HF_TOKEN"])

# Load embeddings, relation weights, and edge types
with st.spinner('Loading AI model...'):
    embeddings, relation_weights, edge_types = load_model_artifacts(embed_path, relation_weights_path, edge_types_path)

# Compute predictions
with st.spinner('Computing predictions...'):

    # Get source node index
    source_rows = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)]
    if source_rows.empty:
        st.error("The source node of the query was not found in the knowledge graph.")
        st.stop()
    src_index = source_rows.node_index.values[0]

    # Get relation index
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next((i for i, etype in enumerate(edge_types) if etype == query_edge_type), None)
    if edge_type_index is None:
        st.error("The model has no weights for the edge type of the query.")
        st.stop()

    # Get target nodes indices; copy so that adding columns below does not
    # write into a slice of kg_nodes
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    dst_indices = target_nodes.node_index.values
    src_indices = np.repeat(src_index, len(dst_indices))

    # Retrieve cached embeddings
    src_embeddings = embeddings[src_indices]
    dst_embeddings = embeddings[dst_indices]

    # Apply activation function
    src_embeddings = F.leaky_relu(src_embeddings)
    dst_embeddings = F.leaky_relu(dst_embeddings)

    # Get relation weights
    rel_weights = relation_weights[edge_type_index]

    # Compute weighted dot product
    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
    scores = torch.sigmoid(scores)

    # Add scores to dataframe
    target_nodes['score'] = scores.detach().cpu().numpy()

    # Rank target nodes by score
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)

    # Add rank to dataframe
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

# Show top ranked nodes
n_targets = target_nodes.shape[0]
if n_targets > 1:
    # Clamp the default so it never exceeds the number of available nodes
    top_k = st.slider('Select number of top ranked nodes to show.', 1, n_targets, min(50, n_targets))
else:
    # A slider needs a non-empty range; with 0 or 1 rows just show what exists
    top_k = n_targets

# Rename columns
display_data = target_nodes[['rank', 'node_id', 'node_name', 'node_source', 'score']].iloc[:top_k].copy()
display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'node_source': 'Database', 'score': 'Score'})
st.dataframe(display_data, use_container_width = True)
requirements.txt CHANGED
@@ -7,4 +7,5 @@ pathlib
7
  torch
8
  altair<5
9
  gspread
10
- oauth2client
 
 
7
  torch
8
  altair<5
9
  gspread
10
+ oauth2client
11
+ huggingface_hub