FusionGDA / gda_api.py
ZhaohanM
FusionGDA
7c46397
raw
history blame
3.09 kB
# -*- coding: utf-8 -*-
import os
import re
import subprocess

import gradio as gr
import pandas as pd
def predict_top_100_genes(disease_id):
    """Return up to 100 genes predicted to be associated with a disease.

    The first time a disease is requested, this builds a labelled input CSV
    from the DisGeNET dump and runs ``model.sh`` on it. The model's output CSV
    is written under ``/data/downstream`` and reused on later requests for the
    same disease, so only new diseases pay for a model run.

    Args:
        disease_id: DisGeNET disease identifier (UMLS CUI, e.g. ``C0002395``),
            typically as typed by the user. Surrounding whitespace is ignored.

    Returns:
        A DataFrame with columns ``Gene`` and ``Score`` (renamed from the
        model output's ``geneSymbol`` / ``Prediction_score``), holding at most
        the first 100 rows of the model output file.

    Raises:
        ValueError: if ``disease_id`` is empty, contains characters other than
            ASCII letters, digits, ``_`` or ``-``, or (when no cached output
            exists) has no rows with a protein sequence in the DisGeNET data.
        subprocess.CalledProcessError: if ``model.sh`` exits with non-zero status.
    """
    disease_id = (disease_id or "").strip()
    # The ID is interpolated into file paths, so reject anything that could
    # escape /data/downstream (e.g. "../") before touching the filesystem.
    if not re.fullmatch(r"[A-Za-z0-9_-]+", disease_id):
        raise ValueError(
            "Invalid Disease ID {!r}: expected a DisGeNET/UMLS ID such as "
            "C0002395.".format(disease_id))

    input_csv_path = '/data/downstream/{}_disease.csv'.format(disease_id)
    output_csv_path = '/data/downstream/{}_top100.csv'.format(disease_id)

    # A previous run's output acts as a cache; only run the model when missing.
    if not os.path.exists(output_csv_path):
        df = pd.read_csv('/data/pretrain/disgenet_latest.csv')
        df = df[df['proteinSeq'].notna()]

        matches = df.loc[df['diseaseId'] == disease_id, 'diseaseDes']
        if matches.empty:
            raise ValueError(
                "Disease ID {!r} was not found in the DisGeNET data (or has no "
                "associated protein sequences).".format(disease_id))
        desired_diseaseDes = matches.iloc[0]

        # Positive label: the protein sequence is linked to any disease row with
        # the same description (matched by description, not only by ID).
        related_proteins = df.loc[
            df['diseaseDes'] == desired_diseaseDes, 'proteinSeq'].unique()

        # One row per (gene, protein) pairing the queried disease with every
        # protein in the dump; computed directly here rather than assigning a
        # new column onto the filtered slice (avoids SettingWithCopyWarning).
        new_df = pd.DataFrame({
            'diseaseId': disease_id,
            'diseaseDes': desired_diseaseDes,
            'geneSymbol': df['geneSymbol'],
            'proteinSeq': df['proteinSeq'],
            'score': df['proteinSeq'].isin(related_proteins).astype(int),
        }).drop_duplicates().reset_index(drop=True)

        os.makedirs(os.path.dirname(input_csv_path), exist_ok=True)
        new_df.to_csv(input_csv_path, index=False)

        # List-form argv (no shell), so the validated ID is passed verbatim.
        script_path = 'model.sh'
        subprocess.run(['bash', script_path, input_csv_path, output_csv_path],
                       check=True)

    output_df = pd.read_csv(output_csv_path)
    # NOTE(review): head(100) keeps the first 100 rows as written by model.sh,
    # i.e. it assumes the output is already sorted by descending score — confirm.
    result_df = (
        output_df[['geneSymbol', 'Prediction_score']]
        .rename(columns={'geneSymbol': 'Gene', 'Prediction_score': 'Score'})
        .head(100)
    )
    return result_df
# Gradio UI: a single text box for the DisGeNET disease ID and a table showing
# the genes and scores returned by predict_top_100_genes.
iface = gr.Interface(
    fn=predict_top_100_genes,
    inputs=gr.Textbox(lines=1, placeholder="Enter Disease ID Here...", label="Disease ID"),
    outputs=gr.Dataframe(label="Predicted Top 100 Related Genes"),
    title="Gene Disease Association Prediction",
    description=(
        "This AI model predicts the top 100 genes associated with a given disease based on 16,733 genes."
        " To get started, you need a Disease ID (UMLS CUI), which can be obtained from the DisGeNET database. "
        "\n\n**Steps to Obtain a Disease ID from DisGeNET:**\n"
        "1. Visit the DisGeNET website: [https://www.disgenet.org/search](https://www.disgenet.org/search).\n"
        "2. Use the search bar to enter your disease of interest. For instance, if you're interested in 'Alzheimer's Disease', type 'Alzheimer's Disease' into the search bar.\n"
        "3. From the search results, identify the disease you're researching. The Disease ID (UMLS CUI) is listed alongside each disease name, e.g. C0002395.\n"
        "4. Enter the Disease ID into the input box below and submit.\n\n"
        "DisGeNET is a large collection of gene-disease associations and their supporting evidence. It can also be searched by gene to find the diseases associated with that gene.\n"
        "\n**Running the model for a new disease takes about 18 minutes; results for previously queried diseases are reused and return much faster.**\n"
    ),
)
# NOTE(review): share=True asks Gradio for a public share link. A new disease
# triggers a long-running model call, so depending on the Gradio version the
# request queue (iface.queue()) may be needed to avoid timeouts — confirm.
iface.launch(share=True)