Karthik001291546 committed on
Commit
75d0fc8
1 Parent(s): 0cd0026

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import numpy as np
import pandas as pd
from joblib import load
4
+
5
# Persisted models, one per algorithm offered in the UI.
rf = load('best_random_forest_model.joblib')
dt = load('best_decision_tree_model.joblib')
mlp = load('best_MLP_classifier_model.joblib')
knn = load('best_knn_model.joblib')

# Display labels for the output classes; position i is paired with
# probability column i when predictions are reported.
class_names = [
    "Low Therapeutic Dose of Warfarin Required",
    "High Therapeutic Dose of Warfarin Required",
]

# The training CSV is read only to recover the column layout that the
# encoded input row is aligned to.
# NOTE(review): '/content/...' looks like a notebook-local path -- confirm it
# exists wherever this script is deployed.
training_data = pd.read_csv('/content/dataset_train.csv')

# Discard the leftover 'Unnamed: 0' index column when the CSV contains one.
training_data = training_data.drop(columns=['Unnamed: 0'], errors='ignore')

expected_feature_names = list(training_data.columns)
22
+
23
+ # Define the prediction function
24
def predict_warfarin_dose(gender, race, age, height, weight, diabetes, simvastatin, amiodarone, genotype, inr, algorithm):
    """Predict the warfarin dose class for one patient with the chosen model.

    Args:
        gender: Encoded gender; 1 means "Male" and 0 means "Female".
        race: Integer code, a key of ``race_dict_inverse``.
        age: Integer code, a key of ``age_dict_inverse``.
        height: Placed unchanged in the feature row.
        weight: Placed unchanged in the feature row.
        diabetes: Indicator value; one-hot encoded.
        simvastatin: Indicator value; one-hot encoded.
        amiodarone: Indicator value; one-hot encoded.
        genotype: Integer code, a key of ``genotype_dict_inverse``.
        inr: Placed unchanged in the feature row.
        algorithm: One of 'Random Forest', 'Decision Tree', 'MLP' or 'KNN'.

    Returns:
        A ``(name, preds_dict)`` tuple: the label of the most probable class
        and a dict mapping every entry of ``class_names`` to its probability.

    Raises:
        ValueError: If ``gender`` is not 0 or 1, or ``algorithm`` is unknown.
        KeyError: If ``race``, ``age`` or ``genotype`` is not a known code.
    """
    # Anything other than 1 (including a missing selection, i.e. None) used to
    # be silently treated as "Female"; refuse to predict on a guessed value.
    if gender not in (0, 1):
        raise ValueError("Invalid gender selected.")

    # Decode the integer codes back to the labels used as category values.
    gender = "Male" if gender == 1 else "Female"
    race = race_dict_inverse[race]
    age = age_dict_inverse[age]
    genotype = genotype_dict_inverse[genotype]

    # Single-row frame so pandas can one-hot encode the categorical fields.
    input_data = pd.DataFrame(
        [[gender, race, age, height, weight, diabetes, simvastatin, amiodarone, genotype, inr]],
        columns=['gender', 'race', 'age', 'height', 'weight', 'diabetes', 'simvastatin',
                 'amiodarone', 'genotype', 'inr'],
    )

    # NOTE(review): 'age' is decoded to its label above but is not in this
    # one-hot list -- confirm against the training columns that this is intended.
    input_data_encoded = pd.get_dummies(
        input_data,
        columns=['gender', 'race', 'diabetes', 'simvastatin', 'amiodarone', 'genotype'],
    )

    # Align to the training layout: columns missing from this row become 0,
    # columns not in the expected list are dropped.
    input_data_encoded = input_data_encoded.reindex(columns=expected_feature_names, fill_value=0)

    # Pick the model for the requested algorithm.
    models = {
        'Random Forest': rf,
        'Decision Tree': dt,
        'MLP': mlp,
        'KNN': knn,
    }
    try:
        model = models[algorithm]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are just as invalid.
        raise ValueError("Invalid algorithm selected.") from None

    y_prob = model.predict_proba(input_data_encoded)
    # Exactly one row was submitted, so work with that row explicitly.
    row_prob = y_prob[0]
    class_idx = int(np.argmax(row_prob))

    preds_dict = {label: float(row_prob[i]) for i, label in enumerate(class_names)}
    name = class_names[class_idx]

    return name, preds_dict
61
+
62
# Race label -> integer code; the code is the label's position in this list.
race_dict = {
    label: code
    for code, label in enumerate([
        "African-American",
        "Asian",
        "Black",
        "Black African",
        "Black Caribbean",
        "Black or African American",
        "Black other",
        "Caucasian",
        "Chinese",
        "Han Chinese",
        "Hispanic",
        "Indian",
        "Intermediate",
        "Japanese",
        "Korean",
        "Malay",
        "Other",
        "Other (Black British)",
        "Other (Hungarian)",
        "Other Mixed Race",
        "White",
    ])
}

# Age-bracket label -> integer code, in ascending bracket order.
age_dict = {
    label: code
    for code, label in enumerate([
        "10-19",
        "20-29",
        "30-39",
        "40-49",
        "50-59",
        "60-69",
        "70-79",
        "80-89",
        "90+",
    ])
}

# Genotype label -> integer code.
genotype_dict = dict(zip(("A/A", "A/G", "G/G"), range(3)))

# Reverse lookups (integer code -> label) used to decode submitted values.
genotype_dict_inverse = dict(zip(genotype_dict.values(), genotype_dict))
race_dict_inverse = dict(zip(race_dict.values(), race_dict))
age_dict_inverse = dict(zip(age_dict.values(), age_dict))
83
+
84
# Build the Gradio form. This section needs the `gr` name (gradio) to be
# imported at the top of the file.
# Dropdown choices are given as (label, value) pairs; the prediction function
# compares and decodes the integer values, not the labels.
gender_choices = [("Male", 1), ("Female", 0)]
gender_module = gr.Dropdown(choices=gender_choices, label="Gender")

# Race, age-group and genotype choices are taken directly from the encoding
# dictionaries, so the form always offers exactly the codes that can be decoded.
race_module = gr.Dropdown(choices=list(race_dict.items()), label="Race")
age_module = gr.Dropdown(choices=list(age_dict.items()), label="Age Group")
genotype_module = gr.Dropdown(choices=list(genotype_dict.items()), label="Genotype")

height_module = gr.Number(label="Height")
weight_module = gr.Number(label="Weight")
# NOTE(review): diabetes is one-hot encoded like the two drug flags below but
# is collected as a free number rather than a 0/1 radio -- confirm this is
# intended and matches the training columns.
diabetes_module = gr.Number(label="Diabetes")
simvastatin_module = gr.Radio(choices=[0, 1], label="Simvastatin")
amiodarone_module = gr.Radio(choices=[0, 1], label="Amiodarone")
inr_module = gr.Number(label="INR Reported")
algorithm_module = gr.Dropdown(choices=["Random Forest", "Decision Tree", "MLP", "KNN"], label="Algorithm")

# Outputs: the winning class label and the per-class probabilities.
output_module1 = gr.Textbox(label="Predicted Class")
output_module2 = gr.Label(label="Predicted Probability")

# The input order must match the positional parameters of predict_warfarin_dose.
iface = gr.Interface(
    fn=predict_warfarin_dose,
    inputs=[gender_module, race_module, age_module, height_module, weight_module, diabetes_module,
            simvastatin_module, amiodarone_module, genotype_module, inr_module, algorithm_module],
    outputs=[output_module1, output_module2],
)

iface.launch(debug=True)