arittrabag committed on
Commit
d04b4c3
·
verified ·
1 Parent(s): 3bb58c3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shap
2
+ import gradio as gr
3
+ import numpy as np
4
+ import shap
5
+ import joblib
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+
9
# Model artifact is loaded once at import time so every request reuses the
# same in-memory object.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only deploy model artifacts that come from a trusted source.
model = joblib.load('gdsc_xgboost_model.pkl')


def predict_ic50(AUC, Z_SCORE, DRUG_ID, TARGET, TARGET_PATHWAY, Growth_Properties_Suspension):
    """Predict LN_IC50 for a single sample and render a SHAP waterfall plot.

    The six arguments are packed into a one-row DataFrame, one-hot encoded
    with ``pd.get_dummies``, aligned to the feature names stored in the
    loaded model's booster, and passed to ``model.predict``. A SHAP
    explanation of that single row is then drawn and written to
    ``shap_plot.png`` in the current working directory.

    Args:
        AUC: Value stored under the ``'AUC'`` column.
        Z_SCORE: Value stored under the ``'Z_SCORE'`` column.
        DRUG_ID: Value stored under the ``'DRUG_ID'`` column.
        TARGET: Value stored under the ``'TARGET'`` column; text values are
            one-hot encoded.
        TARGET_PATHWAY: Value stored under the ``'TARGET_PATHWAY'`` column;
            text values are one-hot encoded.
        Growth_Properties_Suspension: Value stored under the
            ``'Growth Properties_Suspension'`` column.

    Returns:
        A ``(message, image_path)`` tuple: the prediction formatted to three
        decimals, and the path of the saved SHAP plot.
    """
    # Create DataFrame for input
    input_data = pd.DataFrame([{
        'AUC': AUC,
        'Z_SCORE': Z_SCORE,
        'DRUG_ID': DRUG_ID,
        'TARGET': TARGET,
        'TARGET_PATHWAY': TARGET_PATHWAY,
        'Growth Properties_Suspension': Growth_Properties_Suspension,
    }])

    # One-hot encode categorical features if necessary
    input_data = pd.get_dummies(input_data)

    # Align input with model features in a single step: columns the model
    # expects but this row lacks are added as 0, columns the model does not
    # know are dropped, and the model's column order is enforced.
    model_features = model.get_booster().feature_names
    input_data = input_data.reindex(columns=model_features, fill_value=0)

    # Predict IC50
    ic50_pred = model.predict(input_data)[0]

    # SHAP Explanation
    explainer = shap.Explainer(model)
    shap_values = explainer(input_data)

    # Plot SHAP explanation
    plt.figure(figsize=(10, 6))
    try:
        # show=False keeps the figure alive: with the default show=True,
        # SHAP calls plt.show() before the title is set and the file is
        # written, which can block or leave an empty figure to be saved.
        shap.plots.waterfall(shap_values[0], max_display=10, show=False)
        plt.title("SHAP Explanation for Prediction")
        # bbox_inches="tight" prevents long feature labels being clipped.
        plt.savefig("shap_plot.png", bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

    return f"Predicted LN_IC50: {ic50_pred:.3f}", "shap_plot.png"
48
+
49
# Input widgets, listed in the same order as predict_ic50's parameters.
inputs = [
    gr.Number(
        label="AUC (0.5 - 1.5)",
        value=0.85,
        info="Area Under Curve - Typically 0.5 to 1.5",
    ),
    gr.Number(
        label="Z_SCORE (-2 to 2)",
        value=0.45,
        info="Z-Score for dose-response curve",
    ),
    gr.Number(
        label="DRUG_ID (Numeric Code)",
        value=1003,
        info="Unique identifier for the drug",
    ),
    gr.Textbox(
        label="TARGET",
        value="MTORC1",
        placeholder="e.g., MTORC1",
        info="Gene or protein targeted by the drug",
    ),
    gr.Textbox(
        label="TARGET_PATHWAY",
        value="PI3K/MTOR signaling",
        placeholder="e.g., PI3K/MTOR signaling",
        info="Biological pathway affected",
    ),
    gr.Checkbox(
        label="Growth Properties - Suspension",
        value=False,
        info="Check if cells grow in suspension",
    ),
]

# Output widgets, matching the (text, image path) pair predict_ic50 returns.
outputs = [
    gr.Textbox(label="Predicted LN_IC50"),
    gr.Image(label="SHAP Explanation"),
]

# Build the interface, then start serving it as soon as the module runs.
demo = gr.Interface(
    fn=predict_ic50,
    inputs=inputs,
    outputs=outputs,
    title="GDSC Drug Sensitivity Predictor",
    description="Predict LN_IC50 for cancer drug response and visualize feature impact using SHAP. Please follow the input guidelines for accurate predictions.",
    theme="default",
)
demo.launch()