|
|
|
import gradio as gr |
|
import random |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
import shap |
|
import xgboost as xgb |
|
from datasets import load_dataset |
|
|
|
# --- Training data and model ------------------------------------------------
# Everything below runs once at import time; the fitted objects are shared by
# the prediction and explanation callbacks.
dataset = load_dataset("scikit-learn/adult-census-income")
X_train = dataset["train"].to_pandas()

# Columns that are deliberately not used as model features.
X_train = X_train.drop(columns=["fnlwgt", "race"])

# Binary target: 1 when the income bracket is ">50K", otherwise 0.
y_train = X_train.pop("income").eq(">50K").astype(int)

# String-valued features that XGBoost should treat as categorical.
categorical_columns = [
    "workclass",
    "education",
    "marital.status",
    "occupation",
    "relationship",
    "sex",
    "native.country",
]
X_train = X_train.astype(dict.fromkeys(categorical_columns, "category"))

# Fit a logistic-objective booster and build the SHAP explainer on top of it.
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
explainer = shap.TreeExplainer(model)
|
|
|
def predict(*args):
    """Return class probabilities for a single person.

    Args:
        *args: One value per feature, in the same order as ``X_train.columns``.

    Returns:
        A dict mapping ``">50K"`` and ``"<=50K"`` to their predicted
        probabilities (the two values sum to 1).
    """
    row = pd.DataFrame([args], columns=X_train.columns)
    # Cast with the *training* categorical dtypes. A plain "category" cast on a
    # one-row frame yields a single category per column (code 0 for every
    # value), which does not match the codes the model was trained on; XGBoost
    # releases without automatic category re-coding consume those raw codes.
    row = row.astype({col: X_train[col].dtype for col in categorical_columns})
    positive = float(model.predict(xgb.DMatrix(row, enable_categorical=True))[0])
    return {">50K": positive, "<=50K": 1 - positive}
|
|
|
def interpret(*args):
    """Plot the SHAP value of every feature for a single person.

    Args:
        *args: One value per feature, in the same order as ``X_train.columns``.

    Returns:
        A matplotlib figure with one horizontal bar per feature, ordered by
        ascending SHAP value.
    """
    row = pd.DataFrame([args], columns=X_train.columns)
    # Cast with the *training* categorical dtypes. A plain "category" cast on a
    # one-row frame yields a single category per column (code 0 for every
    # value), which does not match the codes the model was trained on.
    row = row.astype({col: X_train[col].dtype for col in categorical_columns})
    shap_values = explainer.shap_values(xgb.DMatrix(row, enable_categorical=True))

    # Tuples sort by SHAP value first, then by feature name on ties.
    ranked = sorted(zip(shap_values[0], X_train.columns))

    fig, ax = plt.subplots(tight_layout=True)
    ax.barh([name for _, name in ranked], [score for score, _ in ranked])
    ax.set_title("Feature Shap Values")
    # barh draws the categories on the y-axis and the bar lengths on the x-axis.
    ax.set_xlabel("Shap Value")
    ax.set_ylabel("Feature")
    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
|
|
|
def _sorted_choices(column):
    """Return the distinct values of one ``X_train`` column in sorted order."""
    return sorted(X_train[column].unique())


# Dropdown choices for each categorical feature, taken from the training data.
unique_class = _sorted_choices("workclass")
unique_education = _sorted_choices("education")
unique_marital_status = _sorted_choices("marital.status")
unique_occupation = _sorted_choices("occupation")
unique_relationship = _sorted_choices("relationship")
unique_sex = _sorted_choices("sex")
unique_country = _sorted_choices("native.country")
|
|
|
# --- User interface ---------------------------------------------------------
# Left column: one input widget per model feature.
# Right column: predicted label, SHAP plot and the two action buttons.
with gr.Blocks() as demo:
    gr.Markdown("""
    **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier that predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
    """)
    with gr.Row():
        with gr.Column():
            age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
            # Dropdown defaults are callables so a fresh random choice is
            # drawn each time the default value is evaluated.
            work_class = gr.Dropdown(
                label="Workclass",
                choices=unique_class,
                value=lambda: random.choice(unique_class),
            )
            education = gr.Dropdown(
                label="Education Level",
                choices=unique_education,
                value=lambda: random.choice(unique_education),
            )
            years = gr.Slider(
                label="Years of schooling",
                minimum=1,
                maximum=16,
                step=1,
                randomize=True,
            )
            marital_status = gr.Dropdown(
                label="Marital Status",
                choices=unique_marital_status,
                value=lambda: random.choice(unique_marital_status),
            )
            occupation = gr.Dropdown(
                label="Occupation",
                choices=unique_occupation,
                value=lambda: random.choice(unique_occupation),
            )
            relationship = gr.Dropdown(
                label="Relationship Status",
                choices=unique_relationship,
                value=lambda: random.choice(unique_relationship),
            )
            sex = gr.Dropdown(
                label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
            )
            capital_gain = gr.Slider(
                label="Capital Gain",
                minimum=0,
                maximum=100000,
                step=500,
                randomize=True,
            )
            capital_loss = gr.Slider(
                label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
            )
            hours_per_week = gr.Slider(
                label="Hours Per Week Worked", minimum=1, maximum=99, step=1
            )
            country = gr.Dropdown(
                label="Native Country",
                choices=unique_country,
                value=lambda: random.choice(unique_country),
            )
        with gr.Column():
            label = gr.Label()
            plot = gr.Plot()
            with gr.Row():
                predict_btn = gr.Button(value="Predict")
                interpret_btn = gr.Button(value="Explain")

            # Both callbacks build a DataFrame from their positional arguments
            # with columns=X_train.columns, so this order must match the
            # training column order. Defining it once keeps the two handlers
            # from drifting apart.
            feature_inputs = [
                age,
                work_class,
                education,
                years,
                marital_status,
                occupation,
                relationship,
                sex,
                capital_gain,
                capital_loss,
                hours_per_week,
                country,
            ]
            predict_btn.click(predict, inputs=feature_inputs, outputs=[label])
            interpret_btn.click(interpret, inputs=feature_inputs, outputs=[plot])

demo.launch()
|
|