# Spaces: Running
import gradio as gr | |
import pickle | |
import pandas as pd | |
import shap | |
import matplotlib.pyplot as plt | |
# Load the pickled model. The relative file name is resolved against the
# current working directory of the process.
# NOTE(review): pickle executes arbitrary code on load; only ship a model
# file that comes from a trusted source.
filename = 'xgb_h_generation.pkl'
with open(filename, mode='rb') as model_file:
    loaded_model = pickle.load(model_file)

# Build the SHAP explainer once at import time so that every prediction
# request reuses the same object.
explainer = shap.Explainer(loaded_model)
# Generation mapping: radio-button label -> numeric code.
# Codes are assigned in listing order, starting at 1.
generation_mapping = {
    label: code
    for code, label in enumerate(
        (
            "Before 1927",
            "Silent Generation",
            "Baby Boomers",
            "Generation X",
            "Millennials",
            "Generation Z",
        ),
        start=1,
    )
}
# Preset employee profiles selectable from the dropdown.
# Each value lists six scores in slider order:
#   WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit
# NOTE(review): the emoji prefixes in the keys appear mis-encoded in this copy
# of the file; they are kept unchanged because the dropdown default refers to
# one of these keys verbatim.
employee_profiles = {
    "π₯ Default Dream Employee": [
        5.0, 5.0, 5.0,
        4.8, 4.8, 4.9,
    ],
    "π Leslie Knope": [
        4.716, 4.792, 4.864,
        4.588, 4.849, 4.601,
    ],
    "β οΈ Kevin Malone": [
        3.045, 3.122, 3.129,
        2.886, 3.113, 2.197,
    ],
    "π± Jim Halpert": [
        3.885, 3.992, 4.119,
        3.704, 4.090, 3.377,
    ],
}
# Prediction pipeline: private helpers first, then the Gradio callback.
# NOTE(review): the emoji prefixes inside the string literals below appear
# mis-encoded in this copy of the file; they are reproduced unchanged.


def _render_risk_html(stay_prob, leave_prob, contributions):
    """Build the HTML summary card (risk headline, probabilities, insights).

    Args:
        stay_prob: Probability of staying, in percent.
        leave_prob: Probability of leaving, in percent.
        contributions: Iterable of ``(feature_name, shap_value)`` pairs, one
            per feature to list under "Key Insights".

    Returns:
        The HTML fragment as a single string.
    """
    high_risk = leave_prob > 50
    risk_label = "π΄ High Risk of Turnover" if high_risk else "π’ Low Risk of Turnover"
    risk_color = "red" if high_risk else "green"

    # Collect fragments and join once instead of growing a string with +=.
    parts = [f"""
<div style='border: 1px solid black; padding: 15px; border-radius: 8px; display: flex;'>
<div style='width: 50%; padding-right: 15px;'>
<span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>
<ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>
<li>𧲠Likelihood of Staying: {stay_prob}%</li>
<li>πͺ Likelihood of Leaving: {leave_prob}%</li>
</ul>
</div>
<div style='width: 50%; border-left: 1px solid black; padding-left: 15px;'>
<b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>
"""]

    # NOTE(review): the raw SHAP value is printed with a '%' sign and worded
    # as a per-point effect; confirm that the explainer's output units
    # actually support that reading.
    for feature, shap_val in contributions:
        raises_risk = shap_val > 0
        impact = f"{abs(shap_val):.2f}"
        icon = "π" if raises_risk else "π"
        effect = "raises turnover risk" if raises_risk else "improves retention"
        parts.append(
            f"<p style='margin: 5px 0;'> {icon} <b>Each 1-point increase in {feature} {effect} by {impact}%.</b></p>"
        )

    # Closes the right-hand column and the outer flex container.
    parts.append("</div></div>")
    return "".join(parts)


def _save_probability_chart(stay_prob, leave_prob, path):
    """Draw the stay/leave horizontal bar chart and save it to *path*."""
    values = [stay_prob, leave_prob]
    fig, ax = plt.subplots()
    try:
        ax.barh(["Stay", "Leave"], values, color=["#0057B8", "#D43F00"])
        for row, value in enumerate(values):
            ax.text(value + 2, row, f"{value:.2f}%", va='center', fontweight='bold', fontsize=12)
        # Leave headroom past 100% so a label drawn at value + 2 is not cut
        # off at the figure edge when a probability is close to 100.
        ax.set_xlim(0, 120)
        ax.set_xlabel("Probability (%)")
        ax.set_title("Retention vs. Turnover Probability")
        fig.tight_layout()
        fig.savefig(path, transparent=True)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def _save_shap_chart(shap_values, path):
    """Draw the SHAP bar chart without 'Generation' and save it to *path*."""
    fig, ax = plt.subplots()
    try:
        # Column 0 is 'Generation' (first column of the input frame), so the
        # positional slice drops exactly that feature.
        shap_values_filtered = shap_values[:, 1:]
        # shap draws through pyplot, so the pyplot-level calls are kept here.
        shap.plots.bar(shap_values_filtered[0], max_display=6, show=False)
        ax.set_title("Key Drivers of Turnover Risk")
        plt.tight_layout()
        plt.savefig(path, transparent=True)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def main_func(Generation_Label: str, WellBeing: float, SupportiveGM: float,
              Engagement: float, Workload: float, WorkEnvironment: float,
              Merit: float) -> tuple:
    """Predict turnover risk for one employee and render the result views.

    Used as the click callback of the analyze button.

    Args:
        Generation_Label: Generation label; translated to a numeric code via
            ``generation_mapping``.
        WellBeing: WellBeing score.
        SupportiveGM: Supportive GM score.
        Engagement: Engagement score.
        Workload: Workload score.
        WorkEnvironment: Work environment score.
        Merit: Merit score.

    Returns:
        A 3-tuple ``(risk_html, prob_chart_path, shap_plot_path)``: the HTML
        summary card and the paths of the two PNG charts. Both charts are
        written to fixed relative file names and overwritten on every call.
    """
    # Unknown labels fall back to code 5 (the "Millennials" code).
    Generation = generation_mapping.get(Generation_Label, 5)
    new_row = pd.DataFrame({
        'Generation': [Generation],
        'WellBeing': [WellBeing],
        'SupportiveGM': [SupportiveGM],
        'Engagement': [Engagement],
        'Workload': [Workload],
        'WorkEnvironment': [WorkEnvironment],
        'Merit': [Merit]
    })

    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # NOTE(review): prob[0][0] is the probability of the model's first class
    # and is treated as "leave" here, while positive SHAP values are reported
    # as raising turnover risk. Confirm the label encoding so both agree.
    leave_fraction = float(prob[0][0])
    stay_prob = round((1 - leave_fraction) * 100, 2)
    leave_prob = round(leave_fraction * 100, 2)

    # Per-feature SHAP values for the insights list, excluding Generation.
    shap_values_df = pd.DataFrame(shap_values.values, columns=new_row.columns)
    shap_values_df = shap_values_df.drop(columns=["Generation"])
    contributions = [
        (feature, shap_values_df[feature].values[0])
        for feature in shap_values_df.columns
    ]
    risk_html = _render_risk_html(stay_prob, leave_prob, contributions)

    prob_chart_path = "prob_chart.png"
    _save_probability_chart(stay_prob, leave_prob, prob_chart_path)

    shap_plot_path = "shap_plot.png"
    _save_shap_chart(shap_values, shap_plot_path)

    return risk_html, prob_chart_path, shap_plot_path
# Dropdown callback: map a profile name to the six slider values.
def update_sliders(profile):
    """Return the six slider scores stored for *profile*.

    A name that is not a key of ``employee_profiles`` yields a fixed
    fallback score list instead.
    """
    fallback_scores = [5.0, 5.0, 5.0, 4.8, 4.8, 4.9]
    return employee_profiles.get(profile, fallback_scores)
# UI Setup
# NOTE(review): several emoji and punctuation characters in the string
# literals below appear mis-encoded in this copy of the file; they are
# reproduced unchanged.


def _score_slider(label, value):
    """Create one score slider (range 1 to 5, step 0.1) with the given default."""
    return gr.Slider(label=label, minimum=1, maximum=5, value=value, step=0.1)


with gr.Blocks() as demo:
    # Header: logo image, centered logo markup, title, and welcome text.
    gr.Image("HiltonLogoSmall.jpg")
    gr.Markdown("""
<div style="display: flex; justify-content: center; align-items: center;">
<img src="file=assets/HiltonLogoSmall.jpg" alt="Hilton Logo" width="250px">
</div>
""")
    gr.Markdown("<h1 style='color: #0057B8;'>Hilton Team Member Retention Predictor</h1>")
    gr.Markdown("""
<div style='font-size: 20px; color: #0057B8;'>
β¨ <b>Welcome to Hiltonβs Employee Retention Predictor</b><br>
This tool helps <b>HR and People Analytics professionals</b> assess
<b>Sales, Marketing, and Front Office Operations teams</b>β<span style='color: #0057B8;'>The Face of Hilton</span>β
by analyzing <b>team member engagement</b> and predicting <b>turnover risk</b> using
<span style='color: #0057B8;'>AI-powered insights</span>.<br>
π <b>Understand what drives retention and make data-driven decisions to keep top talent.</b>
</div>
""")

    # Generation selector; it is not touched by the profile dropdown.
    generation_filter = gr.Radio(
        choices=list(generation_mapping),
        label="Select Generation",
        value="Millennials",
    )

    # Profile selector; changing it overwrites the six sliders below.
    profile_dropdown = gr.Dropdown(
        choices=list(employee_profiles),
        label="Select Employee Profile",
        value="π₯ Default Dream Employee",
    )

    # Score sliders, three per row.
    with gr.Row():
        WellBeing = _score_slider("WellBeing Score", 5.0)
        SupportiveGM = _score_slider("Supportive GM Score", 5.0)
        Engagement = _score_slider("Engagement Score", 5.0)
    with gr.Row():
        Workload = _score_slider("Workload Score", 4.8)
        WorkEnvironment = _score_slider("Work Environment Score", 4.8)
        Merit = _score_slider("Merit Score", 4.9)

    submit_btn = gr.Button("π Click Here to Analyze Retention")
    prediction = gr.HTML()

    # Charts side by side.
    with gr.Row():
        prob_chart = gr.Image(label="Retention vs. Turnover Probability", type="filepath")
        shap_plot = gr.Image(label="Key Drivers of Turnover Risk", type="filepath")

    # The slider order here must match the score order returned by
    # update_sliders and the parameter order of main_func.
    score_sliders = [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit]

    profile_dropdown.change(
        fn=update_sliders,
        inputs=[profile_dropdown],
        outputs=score_sliders,
    )
    submit_btn.click(
        fn=main_func,
        inputs=[generation_filter, *score_sliders],
        outputs=[prediction, prob_chart, shap_plot],
    )

demo.launch()