File size: 2,548 Bytes
283ce49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import streamlit as st
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import graphviz

def main():
    """Render a Streamlit demo that trains and inspects a decision tree on Iris.

    The sidebar exposes the tree hyperparameters and the train/test split.
    The main pane shows test-set accuracy, a labelled confusion matrix, the
    fitted tree rendered via Graphviz, and a form for classifying a
    user-entered sample. The model is (re)trained with the current widget
    values each time this function runs.
    """
    st.title("Decision Tree Classifier Demo")
    st.write("An example using the Iris dataset with a Streamlit interface.")

    # Load the Iris dataset. Targets are integer labels that index into
    # target_names (relied on for the lookups further down).
    iris = load_iris()
    X = iris.data
    y = iris.target
    feature_names = iris.feature_names
    target_names = iris.target_names

    # Sidebar: hyperparameters forwarded unchanged to DecisionTreeClassifier.
    st.sidebar.header("Model Hyperparameters")
    criterion = st.sidebar.selectbox("Criterion", ("gini", "entropy", "log_loss"))
    max_depth = st.sidebar.slider("Max Depth", 1, 10, 3)
    min_samples_split = st.sidebar.slider("Min Samples Split", 2, 10, 2)

    # Split data into train and test sets. Stratifying on y keeps the class
    # proportions the same in both sets, so even at the extreme test sizes the
    # slider allows no class drops out of the training data (which would
    # shrink clf.classes_ and skew the metrics and tree labels below).
    test_size = st.sidebar.slider("Test Size (fraction)", 0.1, 0.9, 0.3, 0.1)
    random_state = st.sidebar.number_input("Random State", 0, 100, 42, 1)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )

    # Train the Decision Tree model. Reusing the split seed also makes the
    # tree's random choice between equally good splits reproducible.
    clf = DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)

    # Evaluate on the held-out test set.
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)

    # Pin the label order so the matrix is always n_classes x n_classes and
    # its rows (actual) / columns (predicted) line up with target_names.
    class_labels = list(range(len(target_names)))
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    cm_table = pd.DataFrame(cm, index=target_names, columns=target_names)
    cm_table.index.name = "Actual"
    cm_table.columns.name = "Predicted"

    st.subheader("Model Performance")
    st.write(f"**Accuracy**: {accuracy:.2f}")

    st.write("**Confusion Matrix**:")
    st.write(cm_table)

    # Visualize the decision tree. export_graphviz expects class_names in the
    # order of clf.classes_ (ascending label order), so select the names for
    # exactly the classes the model was fitted on.
    st.subheader("Decision Tree Visualization")
    dot_data = export_graphviz(
        clf,
        out_file=None,
        feature_names=feature_names,
        class_names=target_names[clf.classes_],
        filled=True,
        rounded=True,
        special_characters=True,
    )
    st.graphviz_chart(dot_data)

    # Predict on user input. Each field is bounded by the feature's observed
    # range in the dataset and defaults to its mean.
    st.subheader("Make a Prediction")
    user_inputs = []
    for i, feature in enumerate(feature_names):
        column = X[:, i]
        val = st.number_input(
            f"{feature}",
            min_value=float(column.min()),
            max_value=float(column.max()),
            value=float(column.mean()),
        )
        user_inputs.append(val)

    if st.button("Predict Class"):
        # predict() takes a 2-D batch; take the single label it returns and
        # map it straight to its class name.
        user_pred = clf.predict([user_inputs])[0]
        st.write(f"**Predicted Class**: {target_names[user_pred]}")

if __name__ == "__main__":  # build the app only when run as a script, not on import
    main()