Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
from sklearn.datasets import load_iris | |
from sklearn.tree import DecisionTreeClassifier, export_graphviz | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import accuracy_score, confusion_matrix | |
import graphviz | |
def _read_hyperparameters():
    """Render the sidebar controls and return the user's choices.

    Returns:
        tuple: ``(criterion, max_depth, min_samples_split, test_size,
        random_state)`` exactly as selected in the sidebar widgets.
    """
    st.sidebar.header("Model Hyperparameters")
    # "log_loss" is only accepted by scikit-learn >= 1.1 (where it is
    # equivalent to "entropy"); older releases do not recognise it.
    criterion = st.sidebar.selectbox("Criterion", ("gini", "entropy", "log_loss"))
    max_depth = st.sidebar.slider("Max Depth", 1, 10, 3)
    min_samples_split = st.sidebar.slider("Min Samples Split", 2, 10, 2)
    test_size = st.sidebar.slider("Test Size (fraction)", 0.1, 0.9, 0.3, 0.1)
    random_state = st.sidebar.number_input("Random State", 0, 100, 42, 1)
    return criterion, max_depth, min_samples_split, test_size, random_state


def _show_performance(y_test, y_pred, target_names):
    """Display test-set accuracy and a confusion matrix labelled by class name.

    Args:
        y_test: True integer class labels of the test split.
        y_pred: Labels predicted by the model for the same samples.
        target_names: Class names, indexed by integer class label.
    """
    accuracy = accuracy_score(y_test, y_pred)
    # Pin the label set so the matrix is always n_classes x n_classes. The
    # split is not stratified, so a small test set can contain a class that
    # is absent from both y_test and y_pred; without ``labels`` sklearn would
    # drop that row/column and the rows would no longer match the class names.
    labels = list(range(len(target_names)))
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    cm_frame = pd.DataFrame(
        cm,
        index=[f"true: {name}" for name in target_names],
        columns=[f"pred: {name}" for name in target_names],
    )
    st.subheader("Model Performance")
    st.write(f"**Accuracy**: {accuracy:.2f}")
    st.write("**Confusion Matrix**:")
    st.write(cm_frame)


def _show_tree(clf, feature_names, target_names):
    """Render the fitted decision tree as a Graphviz diagram."""
    st.subheader("Decision Tree Visualization")
    dot_data = export_graphviz(
        clf,
        out_file=None,  # return the DOT source as a string instead of writing a file
        feature_names=feature_names,
        class_names=target_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    st.graphviz_chart(dot_data)


def _prediction_form(clf, X, feature_names, target_names):
    """Collect one value per feature and show the model's prediction on request.

    Each input is bounded by the minimum/maximum of that feature column in
    ``X`` and defaults to the column mean.
    """
    st.subheader("Make a Prediction")
    user_inputs = []
    for i, feature in enumerate(feature_names):
        column = X[:, i]
        val = st.number_input(
            f"{feature}",
            float(column.min()),
            float(column.max()),
            float(column.mean()),
        )
        user_inputs.append(val)
    if st.button("Predict Class"):
        # predict() takes a batch of samples, so wrap the single sample in a
        # list and take the first (only) predicted label back out.
        predicted_label = clf.predict([user_inputs])[0]
        st.write(f"**Predicted Class**: {target_names[predicted_label]}")


def main():
    """Render the Streamlit decision-tree demo on the Iris dataset.

    Nothing is cached: every run of the script reloads the data, re-splits
    it, retrains the tree from the sidebar hyperparameters, and redraws the
    metrics, the tree diagram and the prediction form.
    """
    st.title("Decision Tree Classifier Demo")
    st.write("An example using the Iris dataset with a Streamlit interface.")

    # Load the Iris dataset
    iris = load_iris()
    X = iris.data
    y = iris.target
    feature_names = iris.feature_names
    target_names = iris.target_names

    (
        criterion,
        max_depth,
        min_samples_split,
        test_size,
        random_state,
    ) = _read_hyperparameters()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    # The same seed drives both the split and the tree, so a given sidebar
    # configuration reproduces the same split and the same model.
    clf = DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    _show_performance(y_test, y_pred, target_names)
    _show_tree(clf, feature_names, target_names)
    _prediction_form(clf, X, feature_names, target_names)
if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not when it
    # is imported as a module.
    main()