import streamlit as st
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
# NOTE(review): `graphviz` is not referenced below -- st.graphviz_chart accepts
# the DOT source string directly. Kept so the dependency set is unchanged.
import graphviz


def _confusion_matrix_frame(y_true, y_pred, class_names):
    """Return the confusion matrix as a DataFrame labelled with class names.

    Rows are the actual classes and columns the predicted classes. Explicit
    ``labels`` keep the matrix square (one row/column per class) even when a
    class is missing from both ``y_true`` and ``y_pred`` -- which can happen
    with small test splits -- so the row/column labels always line up.

    Assumes targets are integer codes ``0 .. len(class_names) - 1`` that index
    ``class_names`` (as with ``load_iris().target`` / ``target_names``).
    """
    labels = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    return pd.DataFrame(
        cm,
        index=pd.Index(class_names, name="Actual"),
        columns=pd.Index(class_names, name="Predicted"),
    )


def main():
    """Run the Streamlit demo: train a decision tree on Iris and explore it.

    The sidebar selects tree hyperparameters and the train/test split. The
    main page shows test accuracy, a labelled confusion matrix, the fitted
    tree rendered with Graphviz, and a form that classifies a sample entered
    by the user.
    """
    st.title("Decision Tree Classifier Demo")
    st.write("An example using the Iris dataset with a Streamlit interface.")

    # Load the Iris dataset
    iris = load_iris()
    X = iris.data
    y = iris.target
    feature_names = iris.feature_names
    target_names = iris.target_names

    # Sidebar: Hyperparameters
    st.sidebar.header("Model Hyperparameters")
    # "log_loss" as a criterion requires scikit-learn >= 1.1.
    criterion = st.sidebar.selectbox("Criterion", ("gini", "entropy", "log_loss"))
    max_depth = st.sidebar.slider("Max Depth", 1, 10, 3)
    min_samples_split = st.sidebar.slider("Min Samples Split", 2, 10, 2)

    # Split data into train and test sets
    test_size = st.sidebar.slider("Test Size (fraction)", 0.1, 0.9, 0.3, 0.1)
    random_state = st.sidebar.number_input("Random State", 0, 100, 42, 1)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Train the Decision Tree model
    clf = DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)

    # Make predictions
    y_pred = clf.predict(X_test)

    # Show metrics
    accuracy = accuracy_score(y_test, y_pred)
    st.subheader("Model Performance")
    st.write(f"**Accuracy**: {accuracy:.2f}")
    st.write("**Confusion Matrix**:")
    st.write(_confusion_matrix_frame(y_test, y_pred, target_names))

    # Visualize the decision tree
    st.subheader("Decision Tree Visualization")
    dot_data = export_graphviz(
        clf,
        out_file=None,  # return the DOT source as a string instead of writing a file
        feature_names=feature_names,
        class_names=target_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    st.graphviz_chart(dot_data)

    # Predict on user input: one numeric field per feature, bounded by the
    # observed range of that feature and defaulting to its mean.
    st.subheader("Make a Prediction")
    user_inputs = []
    for i, feature in enumerate(feature_names):
        column = X[:, i]
        val = st.number_input(
            f"{feature}",
            float(column.min()),
            float(column.max()),
            float(column.mean()),
        )
        user_inputs.append(val)

    if st.button("Predict Class"):
        # predict expects a 2-D input (one row per sample) and returns one
        # integer class code per row.
        user_pred = clf.predict([user_inputs])
        st.write(f"**Predicted Class**: {target_names[user_pred[0]]}")


if __name__ == "__main__":
    main()