File size: 5,875 Bytes
5c33331 907fc70 |
|
# type: ignore
import gradio as gr
import random
import matplotlib.pyplot as plt
import pandas as pd
import shap
import xgboost as xgb
from datasets import load_dataset
# Load the census-income data and split it into features and a binary label.
dataset = load_dataset("scikit-learn/adult-census-income")
X_train = dataset["train"].to_pandas()
# Label: 1 when income is ">50K", otherwise 0.
y_train = (X_train["income"] == ">50K").astype(int)
# "fnlwgt" and "race" are excluded from the features; "income" is the label.
X_train = X_train.drop(columns=["fnlwgt", "race", "income"])

# Text-valued columns that XGBoost consumes as native pandas categoricals.
categorical_columns = [
    "workclass",
    "education",
    "marital.status",
    "occupation",
    "relationship",
    "sex",
    "native.country",
]
X_train = X_train.astype(dict.fromkeys(categorical_columns, "category"))

# Train a binary-logistic booster and build a SHAP tree explainer for it.
booster_params = {"objective": "binary:logistic"}
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
model = xgb.train(params=booster_params, dtrain=data)
explainer = shap.TreeExplainer(model)
def predict(*args):
    """Return the model's class probabilities for one set of UI inputs.

    Args:
        *args: One value per column of ``X_train``, given positionally in the
            same order as ``X_train.columns``.

    Returns:
        A dict mapping ``">50K"`` and ``"<=50K"`` to their probabilities.
    """
    df = pd.DataFrame([args], columns=X_train.columns)
    # Reuse the training CategoricalDtype of each categorical column. A bare
    # astype("category") on a one-row frame creates a single-level category,
    # so every value would get code 0. That encoding does not match the one
    # the model was trained on; older XGBoost releases consume these raw codes.
    df = df.astype({col: X_train[col].dtype for col in categorical_columns})
    pos_prob = float(model.predict(xgb.DMatrix(df, enable_categorical=True))[0])
    return {">50K": pos_prob, "<=50K": 1 - pos_prob}
def interpret(*args):
    """Plot per-feature SHAP values for one set of UI inputs.

    Args:
        *args: One value per column of ``X_train``, given positionally in the
            same order as ``X_train.columns``.

    Returns:
        A matplotlib Figure with one horizontal bar per feature.
    """
    df = pd.DataFrame([args], columns=X_train.columns)
    # Reuse the training CategoricalDtype so category codes match training.
    # A bare astype("category") on one row would map every value to code 0.
    df = df.astype({col: X_train[col].dtype for col in categorical_columns})
    shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
    # Ascending order: barh draws the first bar at the bottom, so the largest
    # value ends up at the top of the chart.
    scores = sorted(zip(shap_values[0], X_train.columns), key=lambda s: s[0])
    # Draw on explicit Axes instead of pyplot's global "current figure" state.
    fig, ax = plt.subplots(tight_layout=True)
    ax.barh([name for _, name in scores], [value for value, _ in scores])
    ax.set_title("Feature Shap Values")
    # barh puts the features on the y axis and the values on the x axis.
    ax.set_xlabel("Shap Value")
    ax.set_ylabel("Feature")
    # Remove the figure from pyplot's registry so repeated calls do not leave
    # figures open. The returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
def _sorted_levels(column):
    """Return the distinct values of ``X_train[column]`` as a sorted list."""
    return sorted(X_train[column].unique())


# Sorted option lists used as dropdown choices in the UI.
unique_class = _sorted_levels("workclass")
unique_education = _sorted_levels("education")
unique_marital_status = _sorted_levels("marital.status")
unique_relationship = _sorted_levels("relationship")
unique_occupation = _sorted_levels("occupation")
unique_sex = _sorted_levels("sex")
unique_country = _sorted_levels("native.country")
# Build the Gradio UI: feature inputs on the left, and the prediction label,
# SHAP plot and action buttons on the right. Then launch the app.
with gr.Blocks() as demo:
    gr.Markdown("""
**Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier that predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
""")
    with gr.Row():
        # One widget per model feature. Most defaults are randomized via
        # randomize=True or random.choice over the dropdown's options.
        with gr.Column():
            age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
            work_class = gr.Dropdown(
                label="Workclass",
                choices=unique_class,
                value=lambda: random.choice(unique_class),
            )
            education = gr.Dropdown(
                label="Education Level",
                choices=unique_education,
                value=lambda: random.choice(unique_education),
            )
            years = gr.Slider(
                label="Years of schooling",
                minimum=1,
                maximum=16,
                step=1,
                randomize=True,
            )
            marital_status = gr.Dropdown(
                label="Marital Status",
                choices=unique_marital_status,
                value=lambda: random.choice(unique_marital_status),
            )
            occupation = gr.Dropdown(
                label="Occupation",
                choices=unique_occupation,
                value=lambda: random.choice(unique_occupation),
            )
            relationship = gr.Dropdown(
                label="Relationship Status",
                choices=unique_relationship,
                value=lambda: random.choice(unique_relationship),
            )
            sex = gr.Dropdown(
                label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
            )
            capital_gain = gr.Slider(
                label="Capital Gain",
                minimum=0,
                maximum=100000,
                step=500,
                randomize=True,
            )
            capital_loss = gr.Slider(
                label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
            )
            hours_per_week = gr.Slider(
                label="Hours Per Week Worked", minimum=1, maximum=99, step=1
            )
            country = gr.Dropdown(
                label="Native Country",
                choices=unique_country,
                value=lambda: random.choice(unique_country),
            )
        with gr.Column():
            label = gr.Label()
            plot = gr.Plot()
            with gr.Row():
                predict_btn = gr.Button(value="Predict")
                interpret_btn = gr.Button(value="Explain")

    # Shared by both handlers. predict/interpret map these values positionally
    # onto X_train.columns, so this order must match the training column order.
    feature_inputs = [
        age,
        work_class,
        education,
        years,
        marital_status,
        occupation,
        relationship,
        sex,
        capital_gain,
        capital_loss,
        hours_per_week,
        country,
    ]
    predict_btn.click(predict, inputs=feature_inputs, outputs=[label])
    interpret_btn.click(interpret, inputs=feature_inputs, outputs=[plot])
demo.launch()
|