Class_Test / app.py
eaglelandsonce's picture
Create app.py
283ce49 verified
raw
history blame
2.55 kB
import streamlit as st
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import graphviz
def main():
    """Render a Streamlit demo that trains and inspects a decision tree on Iris.

    Page layout, top to bottom:

    * sidebar controls for the tree hyperparameters and the train/test split;
    * test-set accuracy and a labelled confusion matrix;
    * a Graphviz rendering of the fitted tree;
    * a form that classifies user-entered feature values.

    Streamlit reruns the script on widget interaction, so the split and the
    model are rebuilt on each run (cheap for a dataset as small as Iris).
    """
    st.title("Decision Tree Classifier Demo")
    st.write("An example using the Iris dataset with a Streamlit interface.")

    # Load the Iris dataset.
    iris = load_iris()
    X = iris.data
    y = iris.target
    feature_names = iris.feature_names
    target_names = iris.target_names
    # Class ids are the integer codes in `y`; they index into `target_names`.
    class_ids = list(range(len(target_names)))

    # Sidebar: model hyperparameters.
    st.sidebar.header("Model Hyperparameters")
    criterion = st.sidebar.selectbox("Criterion", ("gini", "entropy", "log_loss"))
    max_depth = st.sidebar.slider("Max Depth", 1, 10, 3)
    min_samples_split = st.sidebar.slider("Min Samples Split", 2, 10, 2)

    # Sidebar: train/test split settings. The same seed also seeds the tree.
    test_size = st.sidebar.slider("Test Size (fraction)", 0.1, 0.9, 0.3, 0.1)
    random_state = st.sidebar.number_input("Random State", 0, 100, 42, 1)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Train the decision tree.
    clf = DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)

    # Evaluate on the held-out split.
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Pass every class id explicitly: without `labels`, the matrix only covers
    # classes present in y_test / y_pred, so a split that misses a class would
    # shrink it and misalign the class-name labels below.
    cm = confusion_matrix(y_test, y_pred, labels=class_ids)
    # Rows are true classes, columns are predicted classes.
    cm_table = pd.DataFrame(
        cm,
        index=[f"actual {name}" for name in target_names],
        columns=[f"predicted {name}" for name in target_names],
    )

    st.subheader("Model Performance")
    st.write(f"**Accuracy**: {accuracy:.2f}")
    st.write("**Confusion Matrix**:")
    st.write(cm_table)

    # Visualize the fitted tree.
    st.subheader("Decision Tree Visualization")
    dot_data = export_graphviz(
        clf,
        out_file=None,
        feature_names=feature_names,
        class_names=target_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    st.graphviz_chart(dot_data)

    # Classify a user-supplied sample.
    st.subheader("Make a Prediction")
    user_inputs = []
    for i, feature in enumerate(feature_names):
        column = X[:, i]
        # Bound each input to the observed range; default to the feature mean.
        val = st.number_input(
            feature, float(column.min()), float(column.max()), float(column.mean())
        )
        user_inputs.append(val)

    if st.button("Predict Class"):
        # predict() takes a 2-D batch and returns one class id per row.
        predicted_id = clf.predict([user_inputs])[0]
        st.write(f"**Predicted Class**: {target_names[predicted_id]}")
# Script entry point: build the page only when this module is executed as the
# main script (as `streamlit run` does), not when it is imported.
if __name__ == "__main__":
    main()