eaglelandsonce commited on
Commit
283ce49
·
verified ·
1 Parent(s): c710cd9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from sklearn.datasets import load_iris
4
+ from sklearn.tree import DecisionTreeClassifier, export_graphviz
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import accuracy_score, confusion_matrix
7
+ import graphviz
8
+
9
+ def main():
10
+ st.title("Decision Tree Classifier Demo")
11
+ st.write("An example using the Iris dataset with a Streamlit interface.")
12
+
13
+ # Load the Iris dataset
14
+ iris = load_iris()
15
+ X = iris.data
16
+ y = iris.target
17
+ feature_names = iris.feature_names
18
+ target_names = iris.target_names
19
+
20
+ # Sidebar: Hyperparameters
21
+ st.sidebar.header("Model Hyperparameters")
22
+ criterion = st.sidebar.selectbox("Criterion", ("gini", "entropy", "log_loss"))
23
+ max_depth = st.sidebar.slider("Max Depth", 1, 10, 3)
24
+ min_samples_split = st.sidebar.slider("Min Samples Split", 2, 10, 2)
25
+
26
+ # Split data into train and test sets
27
+ test_size = st.sidebar.slider("Test Size (fraction)", 0.1, 0.9, 0.3, 0.1)
28
+ random_state = st.sidebar.number_input("Random State", 0, 100, 42, 1)
29
+ X_train, X_test, y_train, y_test = train_test_split(
30
+ X, y, test_size=test_size, random_state=random_state
31
+ )
32
+
33
+ # Train the Decision Tree model
34
+ clf = DecisionTreeClassifier(
35
+ criterion=criterion,
36
+ max_depth=max_depth,
37
+ min_samples_split=min_samples_split,
38
+ random_state=random_state
39
+ )
40
+ clf.fit(X_train, y_train)
41
+
42
+ # Make predictions
43
+ y_pred = clf.predict(X_test)
44
+
45
+ # Show metrics
46
+ accuracy = accuracy_score(y_test, y_pred)
47
+ cm = confusion_matrix(y_test, y_pred)
48
+
49
+ st.subheader("Model Performance")
50
+ st.write(f"**Accuracy**: {accuracy:.2f}")
51
+
52
+ st.write("**Confusion Matrix**:")
53
+ st.write(cm)
54
+
55
+ # Visualize the decision tree
56
+ st.subheader("Decision Tree Visualization")
57
+ dot_data = export_graphviz(
58
+ clf,
59
+ out_file=None,
60
+ feature_names=feature_names,
61
+ class_names=target_names,
62
+ filled=True,
63
+ rounded=True,
64
+ special_characters=True
65
+ )
66
+ st.graphviz_chart(dot_data)
67
+
68
+ # Predict on user input
69
+ st.subheader("Make a Prediction")
70
+ user_inputs = []
71
+ for i, feature in enumerate(feature_names):
72
+ val = st.number_input(f"{feature}", float(X[:, i].min()), float(X[:, i].max()), float(X[:, i].mean()))
73
+ user_inputs.append(val)
74
+
75
+ if st.button("Predict Class"):
76
+ user_pred = clf.predict([user_inputs])
77
+ st.write(f"**Predicted Class**: {target_names[user_pred][0]}")
78
+
79
+ if __name__ == "__main__":
80
+ main()