Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
from pathlib import Path | |
import ast | |
# Checkpoints of the SHEPHERD models whose precomputed outputs are loaded below.
#
# Causal Gene Discovery Model:
#   /home/ema30/zaklab/rare_disease_dx/checkpoints/aligner/04_30_22:13:29:55_lr_1e-05_val_simulated_pats.disease_split_val_sim_pats_kg_8.9.21_kg_losstype_gene_multisimilarity/all_udn_patients_kg_8.9.21_kgsolved_manual_baylor_nobgm_distractor_genes_5_candidates_mapped_only_genes
# Patients-Like-Me Model:
#   /home/ema30/zaklab/rare_disease_dx/checkpoints/patient_NCA/04_26_22:17:38:30_lr_5e-05_val_simulated_pats.disease_split_val_sim_pats_kg_8.9.21_kg_losstype_patient_patient_NCA/mygene2_all_sim_all_udn_patients_kg_8.9.21_kgsolved_with_phenotypes
# Disease Characterization Model:
#   /home/ema30/zaklab/rare_disease_dx/checkpoints/patient_NCA/05_13_22:08:00:32_lr_1e-05_val_simulated_pats.disease_split_val_sim_pats_kg_8.9.21_kg_losstype_pd_NCA/mygene2_all_sim_all_udn_patients_kg_8.9.21_kgsolved_with_phenotypes

# Precomputed per-patient scores and phenotype-attention tables for each tab.
gene_scores_df = pd.read_csv('gene_discovery_scores.csv')
exomiser_gene_scores_df = pd.read_csv('exomiser_gene_discovery_scores.csv')
patient_scores_df = pd.read_csv('patients_like_me_scores.csv')
dx_scores_df = pd.read_csv('dx_characterization_scores.csv')
plm_attn_df = pd.read_csv('patients_like_me_scores_attn.csv')
dx_attn_df = pd.read_csv('dx_characterization_scores_attn.csv')
gene_attn_df = pd.read_csv('gene_discovery_scores_attn.csv')
exomiser_gene_attn_df = pd.read_csv('exomiser_gene_discovery_scores_attn.csv')

# Hand-curated diagnoses shown in the patient summary. get_patient falls back
# to the score tables for any patient not listed here.
diseases_map = {
    'UDN-P1': 'POLR3-related leukodystrophy',
    'UDN-P2': 'Novel Syndrome',
    'UDN-P3': 'Coffin-Lowry syndrome',
    'UDN-P4': 'autosomal recessive spastic paraplegia type 76',
    'UDN-P5': 'atypical presentation of familial cold autoinflammatory syndrome',
    'UDN-P6': '*GATAD2B*-associated syndrome',
    'UDN-P7': 'AR limb-girdle muscular dystrophy type 2D',
    'UDN-P8': '*ATP5PO*-related Leigh syndrome',
    'UDN-P9': 'Spondyloepimetaphyseal dysplasia, Isidor-Toutain type',
}
genes_map = {
    'UDN-P3': 'RPS6KA3',
    'UDN-P4': 'CAPN1',
    'UDN-P5': 'NLRP12, RAPGEFL1',
    'UDN-P6': 'GATAD2B',
    'UDN-P7': 'SGCA',
    'UDN-P8': 'ATP5PO',
    'UDN-P9': 'RPL13',
}
def _join_non_null(values):
    """Join the non-missing entries of a pandas Series into a comma-separated string."""
    # Missing cells would otherwise make str.join raise TypeError.
    return ', '.join(values.dropna().astype(str))


def get_patient(patient_id, attn_df):
    """Render a Markdown summary of one patient for the UI.

    The causal gene and disease come from the curated ``genes_map`` /
    ``diseases_map`` when the patient is listed there. Otherwise they are the
    rows flagged as correct (label == 1) in ``gene_scores_df`` /
    ``dx_scores_df``. Phenotypes are this patient's ``phenotypes`` entries in
    ``attn_df``.

    Args:
        patient_id: Patient identifier, e.g. ``'UDN-P1'``.
        attn_df: Attention table with ``patient_id`` and ``phenotypes`` columns.

    Returns:
        A single Markdown string listing the patient id, causal gene, disease
        and phenotypes.
    """
    if patient_id in genes_map:
        gene = genes_map[patient_id]
    else:
        gene_rows = gene_scores_df.loc[gene_scores_df['patient_id'] == patient_id]
        gene = _join_non_null(gene_rows.loc[gene_rows['correct_gene_label'] == 1, 'genes'])

    if patient_id in diseases_map:
        disease = diseases_map[patient_id]
    else:
        dx_rows = dx_scores_df.loc[dx_scores_df['patient_id'] == patient_id]
        disease = _join_non_null(dx_rows.loc[dx_rows['correct_label'] == 1, 'diseases'])

    phenotype_rows = attn_df.loc[attn_df['patient_id'] == patient_id]
    phenotypes = _join_non_null(phenotype_rows['phenotypes'])

    return (
        '\n'
        f'**Selected Patient:** {patient_id}<br>\n'
        f'**Causal Gene:** *{gene}*<br>\n'
        f'**Disease:** {disease}<br>\n'
        f'**Phenotypes:** {phenotypes}<br><br>\n'
    )
def read_file(filename, encoding='utf-8'):
    """Return the full text contents of *filename*.

    Args:
        filename: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the host's locale.

    Returns:
        The decoded file contents as a single string.
    """
    with open(filename, 'r', encoding=encoding) as file:
        return file.read()
def causal_gene_discovery(patient_id, prioritization_type):
    """Assemble the outputs of the "Causal Gene Discovery" tab for one patient.

    Args:
        patient_id: Patient identifier selected in the dropdown.
        prioritization_type: ``'Variant Filtered'`` selects the Exomiser-filtered
            score and attention tables; any other value selects the
            expert-curated ones.

    Returns:
        Tuple of (patient summary Markdown, candidate genes ranked by score,
        phenotype attention ranked by weight, iframe HTML showing the patient's
        knowledge-graph neighborhood).
    """
    variant_filtered = prioritization_type == 'Variant Filtered'
    gene_table = exomiser_gene_scores_df if variant_filtered else gene_scores_df
    attention_table = exomiser_gene_attn_df if variant_filtered else gene_attn_df

    def emphasize(value):
        return f'<span style="color:green">**{value}**</span>'

    def genecards_link(symbol):
        return f'<u>[{symbol}](https://www.genecards.org/cgi-bin/carddisp.pl?gene={symbol})</u>'

    # Candidate genes, highest score first, rendered as Markdown links.
    ranked = gene_table[gene_table['patient_id'] == patient_id].sort_values("similarities", ascending=False)
    ranked['similarities'] = ranked['similarities'].round(3).astype(str)
    ranked['genes'] = ranked['genes'].apply(genecards_link)

    # The known causal gene is shown in bold green.
    causal = ranked['correct_gene_label'] == 1
    for column in ('similarities', 'genes'):
        ranked.loc[causal, column] = ranked.loc[causal, column].apply(emphasize)

    ranked = ranked.drop(columns=['patient_id', 'correct_gene_label'])
    ranked = ranked.rename(columns={'similarities': 'SHEPHERD Score', 'genes': 'Candidate Genes'})

    # Phenotype attention, strongest first.
    attention = attention_table[attention_table['patient_id'] == patient_id].sort_values("attention", ascending=False)
    attention['attention'] = attention['attention'].round(4)
    attention = attention.drop(columns=['patient_id', 'degrees'])

    # Pre-rendered knowledge-graph neighborhood page for this patient.
    neighborhood_url = f'https://michellemli.github.io/test_html/{patient_id}.html'
    kg_html = (
        '<iframe id="igraph" scrolling="no" style="border:none; width: 100%; height: 600px" '
        f'seamless="seamless" src="{neighborhood_url}"></iframe>'
    )

    summary = get_patient(patient_id, gene_attn_df)
    return summary, ranked, attention, kg_html
def patients_like_me(patient_id, k=10):
    """Assemble the outputs of the "Patients Like Me" tab for one patient.

    Args:
        patient_id: Patient identifier selected in the dropdown.
        k: Number of most similar candidate patients to keep (default 10).

    Returns:
        Tuple of (patient summary Markdown, top-``k`` similar patients with
        their gene and disease as Markdown links, phenotype attention ranked by
        weight).
    """
    orphanet_prefix = 'https://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=en&Expert='

    def emphasize(value):
        return f'<span style="color:green">**{value}**</span>'

    ranked = patient_scores_df[patient_scores_df['patient_id'] == patient_id].sort_values("similarities", ascending=False)

    # Disease name linked to its Orphanet page; gene linked to GeneCards.
    ranked['diseases'] = [
        f'<u>[{name}]({orphanet_prefix}{disease_id})</u>'
        for name, disease_id in zip(ranked['diseases'], ranked['disease_ids'])
    ]
    ranked['genes'] = ranked['genes'].apply(
        lambda symbol: f'<u>[{symbol}](https://www.genecards.org/cgi-bin/carddisp.pl?gene={symbol})</u>'
    )

    # Candidates sharing the patient's causal gene are shown in bold green.
    same_gene = ranked['correct_label'] == 1
    for column in ('candidate_patients', 'genes', 'diseases'):
        ranked.loc[same_gene, column] = ranked.loc[same_gene, column].apply(emphasize)

    display_names = {
        'candidate_patients': 'Candidate Patient',
        'genes': 'Candidate Patient\'s Gene',
        'diseases': 'Candidate Patient\'s Disease',
    }
    ranked = (
        ranked.drop(columns=['patient_id', 'similarities', 'correct_label', 'disease_ids'])
        .rename(columns=display_names)
        .head(k)
    )

    # Phenotype attention, strongest first.
    attention = plm_attn_df[plm_attn_df['patient_id'] == patient_id].sort_values("attention", ascending=False)
    attention['attention'] = attention['attention'].round(4)
    attention = attention.drop(columns=['patient_id', 'degrees'])

    summary = get_patient(patient_id, plm_attn_df)
    return summary, ranked, attention
def disease_characterization(patient_id, k=10):
    """Assemble the outputs of the "Disease Characterization" tab for one patient.

    Args:
        patient_id: Patient identifier selected in the dropdown.
        k: Number of highest-scoring diseases to keep (default 10).

    Returns:
        Tuple of (patient summary Markdown, top-``k`` diseases rendered as
        Markdown links under a ``'Disease'`` column, phenotype attention ranked
        by weight).
    """
    orphanet_prefix = 'https://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=en&Expert='
    # Diseases whose default Orphanet link is unusable get a hand-written link,
    # keyed by the disease ID (as a string) that identifies them.
    manual_links = {
        '33657': '<u>[leukodystrophy, hypomyelinating, 20](https://www.omim.org/entry/619071)</u>',
        '2654': '<u>[Multiple epiphyseal dysplasia](https://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&Expert=251)</u>',
        '2812': '<u>[Coxa vara](https://omim.org/entry/122750)</u>',
    }

    scores_df = dx_scores_df.loc[dx_scores_df['patient_id'] == patient_id]
    # .copy() so the assignments below modify an independent frame rather than
    # a slice of dx_scores_df (avoids SettingWithCopyWarning).
    scores_df = scores_df.sort_values("similarities", ascending=False).head(k).copy()

    # Some rows hold a disease name instead of an ID literal; map them to the
    # IDs used as keys in manual_links.
    is_coxa_vara = scores_df['disease_ids'].str.contains('Coxa vara', regex=False, na=False)
    scores_df.loc[is_coxa_vara, 'disease_ids'] = '2812'
    is_med = scores_df['disease_ids'].str.contains('Multiple epiphyseal dysplasia', regex=False, na=False)
    scores_df.loc[is_med, 'disease_ids'] = '2654'

    def primary_id(raw_id):
        # Cells are Python literals: a single ID or a list of IDs; keep the first.
        parsed = ast.literal_eval(raw_id)
        return parsed[0] if isinstance(parsed, list) else parsed

    # Normalise to str so IDs parsed as int and as str compare alike below.
    disease_ids = scores_df['disease_ids'].apply(primary_id).astype(str)

    # Disease name linked to its Orphanet page.
    scores_df['diseases'] = [
        f'<u>[{name}]({orphanet_prefix}{disease_id})</u>'
        for name, disease_id in zip(scores_df['diseases'], disease_ids)
    ]
    for special_id, link in manual_links.items():
        # Exact match: a substring test would also rewrite unrelated IDs that
        # merely contain these digits (e.g. 12654 or 28120).
        scores_df.loc[disease_ids == special_id, 'diseases'] = link

    scores_df = scores_df.drop(columns=['patient_id', 'similarities', 'correct_label', 'disease_ids']).rename(columns={'diseases': 'Disease'})

    # Phenotype attention, strongest first.
    attn_df = dx_attn_df.loc[dx_attn_df['patient_id'] == patient_id]
    attn_df = attn_df.sort_values("attention", ascending=False)
    attn_df['attention'] = attn_df['attention'].round(4)
    attn_df = attn_df.drop(columns=['patient_id', 'degrees'])

    patient = get_patient(patient_id, dx_attn_df)
    return patient, scores_df, attn_df
def get_umap(umap_type):
    """Return an HTML ``<embed>`` snippet showing a precomputed UMAP page.

    Args:
        umap_type: ``'disease'`` for the disease-characterization UMAP or
            ``'patient'`` for the patients-like-me UMAP.

    Returns:
        HTML string embedding the hosted UMAP page.

    Raises:
        NotImplementedError: If *umap_type* is not one of the supported values.
    """
    umap_pages = {
        'disease': 'https://michellemli.github.io/test_html/shepherd_disease_characterization_umap.html',
        'patient': 'https://michellemli.github.io/test_html/shepherd_patient_umap.html',
    }
    try:
        html_file = umap_pages[umap_type]
    except (KeyError, TypeError):  # TypeError: unhashable umap_type
        raise NotImplementedError(
            f'Unsupported UMAP type {umap_type!r}; expected one of {sorted(umap_pages)}'
        ) from None
    return f'<embed style="border: none;" src="{html_file}" dpi="300" width="100%" height="750px" />'
SELECT_PATIENT_PROMPT = '<center><h2>Select a patient to view SHEPHERD\'s predictions</h2></center>'

with gr.Blocks() as demo:
    gr.Markdown('<center><h1>AI-assisted Rare Disease Diagnosis with SHEPHERD</h1></center>')
    with gr.Tabs():
        # Tab 1: rank a patient's candidate genes.
        with gr.TabItem("Causal Gene Discovery"):
            with gr.Column():
                gr.Markdown(SELECT_PATIENT_PROMPT)
                gene_dropdown = gr.Dropdown(choices=['UDN-P1', 'UDN-P2'], label='Rare Disease Patients', type='value')
                gene_radio = gr.Radio(choices=['Expert Curated', 'Variant Filtered'], value='Expert Curated', label='Type of Gene List')
                patient_info = gr.Markdown()
                with gr.Accordion(label='SHEPHERD\'s Ranking of Patient\'s Candidate Genes', open=True, elem_id='gene_accordion'):
                    gr.Markdown('The patient\'s causal gene (i.e. gene harboring a variant that explains the patient\'s symptoms) is colored in green.')
                    gene_dataframe = gr.Dataframe(max_rows=5, elem_id="gene_df", datatype='markdown', headers=['Candidate Genes', 'SHEPHERD Score'], overflow_row_behaviour='paginate')
                with gr.Accordion(label='SHEPHERD\'s Attention to Patient\'s Phenotypes', open=False, elem_id='gene_attn_accordion'):
                    gene_attn_dataframe = gr.Dataframe(max_rows=5, elem_id="gene_attn_df", headers=['Phenotypes', 'Attention'], overflow_row_behaviour='paginate')
                with gr.Accordion(label='Visualization of Patient\'s Neighborhood in the Knowledge Graph', open=False, elem_id='kg_neigh_accordion'):
                    kg_neighborhood_image = gr.HTML(elem_id='kg_patient_neighborhood')

        # Tab 2: retrieve the most similar patients.
        with gr.TabItem("Patients Like Me"):
            gr.HTML(get_umap('patient'))
            gr.Markdown(SELECT_PATIENT_PROMPT)
            patient_dropdown = gr.Dropdown(choices=['UDN-P3', 'UDN-P4', 'UDN-P5', 'UDN-P6'], label='Rare Disease Patients', type='value')
            p_patient_info = gr.Markdown()
            with gr.Accordion(label='Top 10 Most Similar Patients according to SHEPHERD', open=True, elem_id='pt_accordion'):
                patient_dataframe = gr.Dataframe(max_rows=10, datatype='markdown', show_label=False, elem_id="pat_df", headers=['Candidate Patient', 'Candidate Patient\'s Gene', 'Candidate Patient\'s Disease'])
            with gr.Accordion(label='SHEPHERD\'s Attention to Patient\'s Phenotypes', open=False, elem_id='pt_attn_accordion'):
                pt_attn_dataframe = gr.Dataframe(max_rows=5, elem_id="pt_attn_df", headers=['Phenotypes', 'Attention'], overflow_row_behaviour='paginate')

        # Tab 3: retrieve the most similar diseases.
        with gr.TabItem("Disease Characterization"):
            gr.HTML(get_umap('disease'))
            gr.Markdown(SELECT_PATIENT_PROMPT)
            dx_dropdown = gr.Dropdown(choices=['UDN-P7', 'UDN-P8', 'UDN-P9', 'UDN-P2'], label='Rare Disease Patients', type='value')
            dx_patient_info = gr.Markdown()
            # Own elem_id: 'pt_accordion' is already used by the Patients Like Me tab.
            with gr.Accordion(label='Top 10 Most Similar Diseases according to SHEPHERD', open=True, elem_id='dx_accordion'):
                # Header matches the 'Disease' column returned by disease_characterization.
                dx_dataframe = gr.Dataframe(max_rows=10, datatype='markdown', show_label=False, elem_id="dx_df", headers=['Disease'])
            with gr.Accordion(label='SHEPHERD\'s Attention to Patient\'s Phenotypes', open=False, elem_id='dx_attn_accordion'):
                dx_attn_dataframe = gr.Dataframe(max_rows=5, elem_id="dx_attn_df", headers=['Phenotypes', 'Attention'], overflow_row_behaviour='paginate')

    # Changing either the patient or the gene-list type refreshes tab 1.
    gene_inputs = [gene_dropdown, gene_radio]
    gene_outputs = [patient_info, gene_dataframe, gene_attn_dataframe, kg_neighborhood_image]
    gene_dropdown.change(causal_gene_discovery, inputs=gene_inputs, outputs=gene_outputs)
    gene_radio.change(causal_gene_discovery, inputs=gene_inputs, outputs=gene_outputs)
    patient_dropdown.change(patients_like_me, inputs=patient_dropdown, outputs=[p_patient_info, patient_dataframe, pt_attn_dataframe])
    dx_dropdown.change(disease_characterization, inputs=dx_dropdown, outputs=[dx_patient_info, dx_dataframe, dx_attn_dataframe])

demo.launch()