File size: 5,686 Bytes
5dc2016
 
 
ddff90b
7ce5b82
ddff90b
7ce5b82
ddff90b
3279179
 
 
4a37eb1
ddff90b
0416a61
bda3587
445b26a
8bf2961
f2852e3
44803cb
e7caceb
6d2e57c
e7caceb
 
6d2e57c
e7caceb
4a37eb1
e7caceb
 
4a37eb1
 
e7caceb
6d2e57c
e7caceb
 
 
6d2e57c
f2852e3
38efeba
847adc5
f2852e3
0416a61
f2852e3
ddff90b
cde5ff7
b102419
 
 
 
 
 
2360c00
388fbdd
81c44d2
 
 
 
 
971a385
81c44d2
 
8f768aa
 
4e45f70
31ca6c1
7780086
 
 
 
 
 
 
0d9531e
 
b28ab8e
0d9531e
b28ab8e
 
 
 
f435314
b28ab8e
 
 
 
f435314
 
 
b28ab8e
 
b424a32
0d9531e
8f768aa
b102419
 
80b9099
b102419
 
 
 
a8b6710
33ca54e
a8b6710
 
 
 
 
7ead1f4
6d2e57c
8e409e1
 
 
 
 
c6171a2
8e409e1
 
 
 
 
 
4a37eb1
 
c6171a2
d5e5338
f2852e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import nltk
nltk.download('stopwords')
nltk.download('punkt')
import pandas as pd
import classify_abs
import extract_abs
#pd.set_option('display.max_colwidth', None)
import streamlit as st
import spacy
import tensorflow as tf
import pickle
import plotly.graph_objects as go

########## Title for the Web App ##########
# NCATS logo rendered as raw HTML so its width can be fixed.
_NCATS_LOGO_HTML = '''<img src="https://huggingface.co/spaces/ncats/EpiPipeline4GARD/resolve/main/NCATS_logo.png" alt="National Center for Advancing Translational Sciences Logo" width=550>'''
st.markdown(_NCATS_LOGO_HTML, unsafe_allow_html=True)
st.title("Epidemiology Extraction Pipeline for Rare Diseases")

#### CHANGE SIDEBAR WIDTH ###
# Injected CSS: fixes the sidebar at 250px when expanded and shifts it 350px to
# the left when collapsed (presumably so that it is hidden completely).
_SIDEBAR_CSS = """
    <style>
    [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
        width: 250px;
    }
    [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
        width: 250px;
        margin-left: -350px;
    }
    </style>
    """
st.markdown(_SIDEBAR_CSS, unsafe_allow_html=True)

# Sidebar controls read by the extraction step further down.
# max_results caps how many PubMed IDs are retrieved BEFORE any filtering.
max_results = st.sidebar.number_input(
    "Maximum number of articles to find in PubMed",
    min_value=1,
    max_value=None,
    value=50,
)

filtering = st.sidebar.radio(
    "What type of filtering would you like?",
    ('Strict', 'Lenient', 'None'),
)

extract_diseases = st.sidebar.checkbox("Extract Rare Diseases", value=False)

@st.experimental_singleton(show_spinner=False)
def load_models_experimental():
    """Build every model and lookup table the pipeline needs.

    Decorated with ``st.experimental_singleton``, so Streamlit keeps the
    returned objects and reuses them across reruns instead of reloading.

    Returns:
        tuple: ``(classify_model_vars, NER_pipeline, entity_classes,
        GARD_dict, max_length)``. The first item comes from
        ``classify_abs.init_classify_model()``. The next two come from
        ``extract_abs.init_NER_pipeline()`` and the last two from
        ``extract_abs.load_GARD_diseases()``.
    """
    classifier_bundle = classify_abs.init_classify_model()
    ner_pipeline, ner_entity_classes = extract_abs.init_NER_pipeline()
    gard_lookup, gard_max_length = extract_abs.load_GARD_diseases()
    return (
        classifier_bundle,
        ner_pipeline,
        ner_entity_classes,
        gard_lookup,
        gard_max_length,
    )

@st.cache(allow_output_mutation=True)
def load_models():
    """Load the classifier pieces by hand, plus the NER pipeline and GARD data.

    Older alternative to ``load_models_experimental``. Its only call site in
    the visible script is commented out. Instead of calling
    ``classify_abs.init_classify_model()``, it loads the tokenizer from
    ``tokenizer.pickle`` and the Keras model from ``LSTM_RNN_Model`` directly.

    Returns:
        tuple: ``(classify_tokenizer, classify_model, NER_pipeline,
        entity_classes, GARD_dict, max_length)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code. This is only
    # acceptable if tokenizer.pickle is a trusted, bundled artifact (confirm).
    with open('tokenizer.pickle', 'rb') as tokenizer_file:
        tokenizer = pickle.load(tokenizer_file)

    lstm_model = tf.keras.models.load_model("LSTM_RNN_Model")

    ner_pipeline, ner_entity_classes = extract_abs.init_NER_pipeline()
    gard_lookup, gard_max_length = extract_abs.load_GARD_diseases()

    return (
        tokenizer,
        lstm_model,
        ner_pipeline,
        ner_entity_classes,
        gard_lookup,
        gard_max_length,
    )

@st.cache
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes, e.g. for ``st.download_button``.

    The conversion is cached with ``st.cache`` so it is not recomputed on every
    Streamlit rerun while *df* is unchanged.

    Args:
        df: the DataFrame to export (written with pandas' default ``to_csv``
            options, index included).

    Returns:
        bytes: the CSV text encoded as UTF-8.
    """
    return df.to_csv().encode('utf-8')

#@st.experimental_memo
@st.cache(allow_output_mutation=True)
def epi_sankey(sankey_data):
    """Build a Plotly Sankey diagram of how abstracts flow through the filters.

    Args:
        sankey_data: a 3-item sequence ``(gathered, relevant, epidemiologic)``
            of counts.

    Returns:
        plotly.graph_objects.Figure: ``gathered`` splits into relevant and
        irrelevant abstracts. The relevant ones split into epidemiologic and
        non-epidemiologic abstracts.
    """
    gathered, relevant, epidemiologic = sankey_data

    # The index in this list is the node id used by the links below.
    node_labels = [
        "PubMed IDs Gathered",          # 0
        "Irrelevant Abstracts",         # 1
        "Relevant Abstracts Gathered",  # 2
        "Epidemiologic Abstracts",      # 3
        "Not Epidemiologic",            # 4
    ]

    # Each flow is a (source node, target node, count) triple.
    flows = [
        (0, 2, relevant),
        (0, 1, gathered - relevant),
        (2, 3, epidemiologic),
        (2, 4, relevant - epidemiologic),
    ]
    sources, targets, values = (list(column) for column in zip(*flows))

    node_style = dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=node_labels,
        color="blue",
    )
    link_spec = dict(source=sources, target=targets, value=values)

    return go.Figure(data=[go.Sankey(node=node_style, link=link_spec)])

# Models are loaded through load_models_experimental, which is wrapped in
# st.experimental_singleton, so Streamlit reuses them across reruns.
with st.spinner('Loading Epidemiology Models and Dependencies...'):
    (
        classify_model_vars,
        NER_pipeline,
        entity_classes,
        GARD_dict,
        max_length,
    ) = load_models_experimental()
    # Historical note: spaCy models were once loaded here, outside any cache,
    # because caching them raised a hash function error.

ready_banner = st.success('All Models and Dependencies Loaded!')
disease_or_gard_id = st.text_input("Input a rare disease term or GARD ID.")
# Remove the banner again once the input box has been rendered.
ready_banner.empty()

st.markdown("Examples of rare diseases include [**Fellman syndrome**](https://rarediseases.info.nih.gov/diseases/1/gracile-syndrome), [**Classic Homocystinuria**](https://rarediseases.info.nih.gov/diseases/6667/classic-homocystinuria), [**phenylketonuria**](https://rarediseases.info.nih.gov/diseases/7383/phenylketonuria), and [GARD:0009941](https://rarediseases.info.nih.gov/diseases/9941/fshmd1a).")

st.markdown("A full list of rare diseases tracked by GARD can be found [here](https://rarediseases.info.nih.gov/diseases/browse-by-first-letter).")

if disease_or_gard_id:
    # Run the extraction pipeline for the user's term / GARD ID.
    # It returns the results table and the (gathered, relevant, epidemiologic)
    # counts that epi_sankey unpacks.
    df, sankey_data = extract_abs.streamlit_extraction(
        disease_or_gard_id, max_results, filtering,
        NER_pipeline, entity_classes,
        extract_diseases, GARD_dict, max_length,
        classify_model_vars,
    )
    st.dataframe(df, height=100)
    st.download_button(
        label="Download epidemiology results for " + disease_or_gard_id + " as CSV",
        # Use the cached helper rather than re-serializing on every rerun.
        data=convert_df(df),
        file_name=disease_or_gard_id + '.csv',
        mime='text/csv',
    )

    # Only build the figure when the user actually asks for it.
    # NOTE(review): clicking a Streamlit button reruns the whole script, so
    # streamlit_extraction presumably runs again unless it caches internally.
    # Confirm, and consider caching it if that is slow.
    if st.button('Display Sankey Diagram'):
        st.plotly_chart(epi_sankey(sankey_data), use_container_width=True)