visit2sachin56 commited on
Commit
5c04edd
1 Parent(s): d3110d3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.svm import SVC
5
+ from sklearn.linear_model import LogisticRegression
6
+ from sklearn.ensemble import RandomForestClassifier
7
+ from sklearn.preprocessing import LabelEncoder
8
+ from sklearn.model_selection import train_test_split
9
+ from sklearn.metrics import plot_confusion_matrix, plot_roc_curve, plot_precision_recall_curve
10
+ from sklearn.metrics import precision_score, recall_score
11
+
12
+ def main():
13
+ st.title("Binary Classification Web App")
14
+ st.sidebar.title("Binary Classification Web App")
15
+ st.markdown("Are your mushroom is editable or poisionous? ")
16
+ st.sidebar.markdown("Are your mushroom is editable or poisionous? ")
17
+
18
+ def load_data():
19
+ data = pd.read_csv('/home/rhyme/Desktop/Project/mushrooms.csv')
20
+ label = LabelEncoder()
21
+ for col in data.columns:
22
+ data[col]= label.fit_transform(data[col])
23
+
24
+ return data
25
+
26
+
27
+ @st.cache(persist=True)
28
+ def split(df):
29
+ y = df.type
30
+ x = df.drop(columns=['type'])
31
+ x_train , x_test, y_train,y_test = train_test_split(x,y,test_size=0.3, random_state=0)
32
+ return x_train,x_test, y_train,y_test
33
+
34
+
35
+ def plot_metrics(metrics_list):
36
+ if 'Confusion Matrix' in metrics_list:
37
+ st.subheader("Confusion Matrix")
38
+ plot_confusion_matrix(model, x_test,y_test,display_labels=class_names)
39
+ st.pyplot()
40
+
41
+ if 'ROC Curve' in metrics_list:
42
+ st.subheader("ROC Curve")
43
+ plot_roc_curve(model,x_test,y_test)
44
+ st.pyplot()
45
+
46
+ if 'Precision-Recall Curve' in metrics_list:
47
+ st.subheader("Precision-Recall Curve")
48
+ plot_precision_recall_curve(model,x_test,y_test)
49
+ st.pyplot()
50
+
51
+
52
+
53
+ df = load_data()
54
+ x_train, x_test, y_train, y_test = split(df)
55
+ class_names = ['edible', 'poisionous']
56
+ st.sidebar.subheader("Chosse Classifiers")
57
+ classifier = st.sidebar.selectbox("Classifier", ("Support Vector Machine(SVM)", "Logostics Regression", "Random Forest"))
58
+
59
+ if classifier == "Support Vector Machine(SVM)":
60
+ st.sidebar.subheader("Model Hyperparameters")
61
+ C = st.sidebar.number_input("C (Regularization parameter)", 0.01,10.0,step=0.01,key='C')
62
+ kernel = st.sidebar.radio("kernel", ("rbf", "linear"), key='kernal')
63
+ gamma = st.sidebar.radio("Gamma (Kernel Coefficient)", ("scale","auto"),key='gamma')
64
+
65
+ metrics = st.sidebar.multiselect("What metrics to plot?", ('Confusion Matrix','ROC Curve','Precision-Recall Curve' ))
66
+
67
+ if st.sidebar.button("Classify", key = 'classify'):
68
+ st.subheader("Support Vector Machine (SVM)")
69
+ model = SVC(C=C,kernel=kernel, gamma=gamma)
70
+ model.fit(x_train,y_train)
71
+ accuracy = model.score(x_test,y_test)
72
+ y_pred = model.predict(x_test)
73
+ st.write("Accuracy: ",accuracy.round(2))
74
+ st.write("Precision : ", precision_score(y_test,y_pred, labels=class_names).round(2))
75
+ st.write("Recall: ", recall_score(y_test, y_pred, labels= class_names).round(2))
76
+ plot_metrics(metrics)
77
+
78
+
79
+ if classifier == "Logostics Regression":
80
+ st.sidebar.subheader("Model Hyperparameters")
81
+ C = st.sidebar.number_input("C (Regularization parameter)", 0.01,10.0,step=0.01,key='C_LR')
82
+ max_iter = st.sidebar.slider("Maximum number of iterations", 100, 500, key='max_iter')
83
+
84
+ # kernel = st.sidebar.radio("kernel", ("rbf", "linear"), key='kernal')
85
+ # gamma = st.sidebar.radio("Gamma (Kernel Coefficient)", ("scale","auto"),key='gamma')
86
+
87
+ metrics = st.sidebar.multiselect("What metrics to plot?", ('Confusion Matrix','ROC Curve','Precision-Recall Curve' ))
88
+
89
+ if st.sidebar.button("Classify", key = 'classify'):
90
+ st.subheader("Logistics Regression")
91
+ model = LogisticRegression(C=C,max_iter =max_iter)
92
+ model.fit(x_train,y_train)
93
+ accuracy = model.score(x_test,y_test)
94
+ y_pred = model.predict(x_test)
95
+ st.write("Accuracy: ",accuracy.round(2))
96
+ st.write("Precision : ", precision_score(y_test,y_pred, labels=class_names).round(2))
97
+ st.write("Recall: ", recall_score(y_test, y_pred, labels= class_names).round(2))
98
+ plot_metrics(metrics)
99
+
100
+ #Random Forest
101
+ if classifier == "Random Forest":
102
+ st.sidebar.subheader("Model Hyperparameters")
103
+ n_estimators = st.sidebar.number_input("The number of trees in the forest", 100, 500, step=10,key='n_estmators')
104
+ max_depth = st.sidebar.number_input("The maximum depth of the tree", 1, 20 , step=1, key='max_depth')
105
+ bootstrap = st.sidebar.radio("Bootstrap samples when builoding trees", ('True','False'),key='bootstrap')
106
+ # C = st.sidebar.number_input("C (Regularization parameter)", 0.01,10.0,step=0.01,key='C_LR')
107
+ # max_iter = st.sidebar.slider("Maximum number of iterations", 100, 500, key='max_iter')
108
+
109
+ # kernel = st.sidebar.radio("kernel", ("rbf", "linear"), key='kernal')
110
+ # gamma = st.sidebar.radio("Gamma (Kernel Coefficient)", ("scale","auto"),key='gamma')
111
+
112
+ metrics = st.sidebar.multiselect("What metrics to plot?", ('Confusion Matrix','ROC Curve','Precision-Recall Curve' ))
113
+
114
+ if st.sidebar.button("Classify", key = 'classify'):
115
+ st.subheader("Random Forest")
116
+ model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth,bootstrap=bootstrap)
117
+ model.fit(x_train,y_train)
118
+ accuracy = model.score(x_test,y_test)
119
+ y_pred = model.predict(x_test)
120
+ st.write("Accuracy: ",accuracy.round(2))
121
+ st.write("Precision : ", precision_score(y_test,y_pred, labels=class_names).round(2))
122
+ st.write("Recall: ", recall_score(y_test, y_pred, labels= class_names).round(2))
123
+ plot_metrics(metrics)
124
+
125
+ if st.sidebar.checkbox("show raw data",False):
126
+ st.subheader("Mushroom data Set (Classifications)")
127
+ st.write(df)
128
+
129
+
130
+ if __name__ == '__main__':
131
+ main()
132
+
133
+