Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.datasets import load_iris
|
4 |
+
from sklearn.tree import DecisionTreeClassifier, export_graphviz
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
from sklearn.metrics import accuracy_score, confusion_matrix
|
7 |
+
import graphviz
|
8 |
+
|
9 |
+
def main():
    """Interactive Streamlit demo: train a DecisionTreeClassifier on the Iris
    dataset with user-chosen hyperparameters, report accuracy / confusion
    matrix, render the fitted tree, and classify user-supplied measurements.
    """
    st.title("Decision Tree Classifier Demo")
    st.write("An example using the Iris dataset with a Streamlit interface.")

    # Load the Iris dataset (150 samples, 4 features, 3 classes).
    iris = load_iris()
    X = iris.data
    y = iris.target
    feature_names = iris.feature_names
    target_names = iris.target_names

    # Sidebar: model hyperparameters and split settings.
    criterion, max_depth, min_samples_split, test_size, random_state = _sidebar_controls()

    # Split data into train and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Train the Decision Tree model with the chosen hyperparameters.
    clf = DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)

    _show_performance(clf, X_test, y_test, target_names)
    _show_tree(clf, feature_names, target_names)
    _prediction_form(clf, X, feature_names, target_names)


def _sidebar_controls():
    """Render the sidebar widgets; return (criterion, max_depth,
    min_samples_split, test_size, random_state). Widget order matters:
    Streamlit renders widgets in call order."""
    st.sidebar.header("Model Hyperparameters")
    criterion = st.sidebar.selectbox("Criterion", ("gini", "entropy", "log_loss"))
    max_depth = st.sidebar.slider("Max Depth", 1, 10, 3)
    min_samples_split = st.sidebar.slider("Min Samples Split", 2, 10, 2)
    test_size = st.sidebar.slider("Test Size (fraction)", 0.1, 0.9, 0.3, 0.1)
    # number_input with all-int arguments returns an int, as sklearn requires.
    random_state = st.sidebar.number_input("Random State", 0, 100, 42, 1)
    return criterion, max_depth, min_samples_split, test_size, random_state


def _show_performance(clf, X_test, y_test, target_names):
    """Evaluate *clf* on the held-out test set and display accuracy plus a
    class-labeled confusion matrix."""
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred)

    st.subheader("Model Performance")
    st.write(f"**Accuracy**: {accuracy:.2f}")

    st.write("**Confusion Matrix**:")
    # Label rows (true class) and columns (predicted class) instead of
    # dumping a bare ndarray.
    st.write(pd.DataFrame(cm, index=target_names, columns=target_names))


def _show_tree(clf, feature_names, target_names):
    """Render the fitted tree as a Graphviz chart."""
    st.subheader("Decision Tree Visualization")
    # export_graphviz with out_file=None returns DOT source as a string,
    # which st.graphviz_chart accepts directly.
    dot_data = export_graphviz(
        clf,
        out_file=None,
        feature_names=feature_names,
        class_names=target_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    st.graphviz_chart(dot_data)


def _prediction_form(clf, X, feature_names, target_names):
    """Collect one measurement per feature (bounded by the observed data
    range, defaulting to the mean) and show the predicted class on demand."""
    st.subheader("Make a Prediction")
    user_inputs = []
    for i, feature in enumerate(feature_names):
        val = st.number_input(
            f"{feature}",
            float(X[:, i].min()),
            float(X[:, i].max()),
            float(X[:, i].mean()),
        )
        user_inputs.append(val)

    if st.button("Predict Class"):
        # predict expects a 2-D array: wrap the single sample in a list.
        user_pred = clf.predict([user_inputs])
        # Index with the scalar class id rather than the prediction array.
        st.write(f"**Predicted Class**: {target_names[user_pred[0]]}")
+
# Entry point: run the app when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()