Spaces:
Running
Running
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "68bddf84-439e-461b-b7f9-4c9212ca81f5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%capture --no-display\n", | |
"pip install gradio\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "d692b9ae-c7d4-40b6-9777-992a4303055b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import gradio as gr" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "9742e1cb-7e40-4af9-8c59-6a341c9d9b38", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pickle\n", | |
"import pandas as pd\n", | |
"import shap\n", | |
"from shap.plots._force_matplotlib import draw_additive_plot\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler # Importing function for scaling the data" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7ab0387d-3a31-4164-a46e-b1bb57dbf33a", | |
"metadata": {}, | |
"source": [ | |
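"# App Code\n", | |
"\n", | |
"Loads the pickled model (`xgb_h_new.pkl`), builds a SHAP explainer for it, and launches a Gradio app that turns six 1–5 scores into stay/leave probabilities plus the SHAP drivers behind each prediction." | |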
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "b168e539-83f7-4bb6-a112-fe8de8fdab52", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('WellBeing',\n", | |
" 'SupportiveGM',\n", | |
" 'Engagement',\n", | |
" 'Workload',\n", | |
" 'WorkEnvironment',\n", | |
" 'Merit')" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Input feature columns, in the order main_func builds its DataFrame row\n", | |
"FEATURE_COLUMNS = ('WellBeing', 'SupportiveGM', 'Engagement', 'Workload', 'WorkEnvironment', 'Merit')\n", | |
"FEATURE_COLUMNS" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "fc97600d-3c32-4d66-a60e-b416c170293b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"* Running on local URL: http://127.0.0.1:7860\n", | |
"* Running on public URL: https://35c8546aedfc1b28e8.gradio.live\n", | |
"\n", | |
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div><iframe src=\"https://35c8546aedfc1b28e8.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import gradio as gr\n", | |
"import pickle\n", | |
"import pandas as pd\n", | |
"import shap\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"# Load the trained model.\n", | |
"# NOTE: pickle.load can execute arbitrary code -- only load a model file you trust.\n", | |
"filename = 'xgb_h_new.pkl'\n", | |
"with open(filename, 'rb') as f:\n", | |
"    loaded_model = pickle.load(f)\n", | |
"\n", | |
"# SHAP explainer for the loaded model; main_func calls it once per prediction.\n", | |
"explainer = shap.Explainer(loaded_model)\n", | |
"\n", | |
"# Preset employee profiles. Each list holds the six slider values in UI order:\n", | |
"# [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit].\n", | |
"# The emoji labels were restored from mis-encoded text: ⚠️ and 🔥 are exact.\n", | |
"# NOTE(review): 🌟 and 🏆 are best guesses -- confirm the intended icons.\n", | |
"employee_profiles = {\n", | |
"    \"🌟 High Potential Employee\": [4, 5, 5, 3, 4, 5],\n", | |
"    \"🏆 Top Performer\": [5, 5, 5, 3, 5, 5],  # Workload lowered to 3 (adjusted profile)\n", | |
"    \"⚠️ At-Risk Employee\": [2, 2, 2, 4, 2, 2],\n", | |
"    \"🔥 Burnt-Out Employee\": [1, 2, 2, 5, 1, 1]\n", | |
"}\n", | |
"\n", | |
"# Define the prediction function\n", | |
"def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit):\n", | |
"    \"\"\"Score one employee profile and build the app's three outputs.\n", | |
"\n", | |
"    Each argument is one slider value (1-5 in the UI).\n", | |
"\n", | |
"    Returns:\n", | |
"        (final_layout, prob_chart_path, shap_plot_path): the risk + insights HTML\n", | |
"        and the file paths of two freshly written PNG charts.\n", | |
"    \"\"\"\n", | |
"    import tempfile\n", | |
"\n", | |
"    # Column of predict_proba treated as \"leaves\" (the original app used column 0).\n", | |
"    # Probabilities AND SHAP signs are derived from this one index so they always agree.\n", | |
"    # NOTE(review): confirm against loaded_model.classes_ / the training labels.\n", | |
"    leave_class_index = 0\n", | |
"\n", | |
"    def temp_png(prefix):\n", | |
"        # Unique file per call: a fixed filename is shared by every request, so\n", | |
"        # overlapping requests could overwrite each other's charts.\n", | |
"        with tempfile.NamedTemporaryFile(prefix=prefix, suffix=\".png\", delete=False) as tmp:\n", | |
"            return tmp.name\n", | |
"\n", | |
"    new_row = pd.DataFrame({\n", | |
"        'WellBeing': [WellBeing],\n", | |
"        'SupportiveGM': [SupportiveGM],\n", | |
"        'Engagement': [Engagement],\n", | |
"        'Workload': [Workload],\n", | |
"        'WorkEnvironment': [WorkEnvironment],\n", | |
"        'Merit': [Merit]\n", | |
"    })\n", | |
"\n", | |
"    # Predict probability and explain the prediction\n", | |
"    prob = loaded_model.predict_proba(new_row)\n", | |
"    shap_values = explainer(new_row)\n", | |
"\n", | |
"    leave_fraction = float(prob[0][leave_class_index])\n", | |
"    stay_prob = round((1 - leave_fraction) * 100, 2)\n", | |
"    leave_prob = round(leave_fraction * 100, 2)\n", | |
"\n", | |
"    # Re-express the SHAP values in the turnover-risk direction. This assumes a\n", | |
"    # single-output explanation of the class-1 margin (as for a binary XGBoost\n", | |
"    # classifier; the per-feature loop below already needs one value per feature).\n", | |
"    # If \"leave\" is class 0 the sign is flipped, so positive always means higher risk.\n", | |
"    sign = 1.0 if leave_class_index == 1 else -1.0\n", | |
"    leave_explanation = shap.Explanation(\n", | |
"        values=shap_values.values[0] * sign,\n", | |
"        base_values=shap_values.base_values[0] * sign,\n", | |
"        data=shap_values.data[0],\n", | |
"        feature_names=list(new_row.columns),\n", | |
"    )\n", | |
"\n", | |
"    # Dynamic risk label: changes color & text based on probability\n", | |
"    high_risk = leave_prob > 50\n", | |
"    risk_label = \"🔴 High Risk of Turnover\" if high_risk else \"🟢 Low Risk of Turnover\"\n", | |
"    risk_color = \"red\" if high_risk else \"green\"\n", | |
"\n", | |
"    risk_html = f\"\"\"\n", | |
"    <div style='padding: 15px; border-radius: 8px;'>\n", | |
"        <span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>\n", | |
"        <ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>\n", | |
"            <li>🧲 Likelihood of Staying: {stay_prob}%</li>\n", | |
"            <li>🚪 Likelihood of Leaving: {leave_prob}%</li>\n", | |
"        </ul>\n", | |
"    </div>\n", | |
"    \"\"\"\n", | |
"\n", | |
"    # Key Insights, strongest driver first. A SHAP value is the feature's additive\n", | |
"    # contribution to the model's raw output (e.g. log-odds) relative to the baseline\n", | |
"    # prediction; it is NOT the effect of a 0.1-point change, so it is shown as a\n", | |
"    # signed impact instead of a \"% per 0.1 point\" claim.\n", | |
"    drivers = sorted(\n", | |
"        zip(new_row.columns, leave_explanation.values, new_row.iloc[0].tolist()),\n", | |
"        key=lambda item: abs(item[1]),\n", | |
"        reverse=True,\n", | |
"    )\n", | |
"    insights_html = \"<div style='font-size: 18px;'>\"\n", | |
"    for feature, impact, score in drivers:\n", | |
"        icon = \"📈\" if impact > 0 else \"📉\"\n", | |
"        effect = \"raises turnover risk\" if impact > 0 else \"improves retention\"\n", | |
"        insights_html += (\n", | |
"            f\"<p style='margin: 5px 0;'> {icon} <b>{feature} = {float(score):g} {effect}\"\n", | |
"            f\" (SHAP impact: {abs(impact):.3f})</b></p>\"\n", | |
"        )\n", | |
"    insights_html += \"</div>\"\n", | |
"\n", | |
"    # Final Layout (Risk + Key Insights)\n", | |
"    final_layout = f\"\"\"\n", | |
"    <table style='width:100%; border-collapse: collapse; margin-top: 10px; background-color: #FFFFFF;'>\n", | |
"        <tr>\n", | |
"            <td style='width: 33%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n", | |
"                {risk_html}\n", | |
"            </td>\n", | |
"            <td style='width: 67%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n", | |
"                <b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>\n", | |
"                {insights_html}\n", | |
"            </td>\n", | |
"        </tr>\n", | |
"    </table>\n", | |
"    \"\"\"\n", | |
"\n", | |
"    # Retention vs. Turnover Chart\n", | |
"    fig, ax = plt.subplots()\n", | |
"    categories = [\"Stay\", \"Leave\"]\n", | |
"    values = [stay_prob, leave_prob]\n", | |
"    colors = [\"#0057B8\", \"#D43F00\"]\n", | |
"    ax.barh(categories, values, color=colors)\n", | |
"    for i, v in enumerate(values):\n", | |
"        ax.text(v + 2, i, f\"{v:.2f}%\", va='center', fontweight='bold', fontsize=12)\n", | |
"    ax.set_xlabel(\"Probability (%)\")\n", | |
"    ax.set_title(\"Retention vs. Turnover Probability\")\n", | |
"    fig.tight_layout()\n", | |
"    prob_chart_path = temp_png(\"prob_chart_\")\n", | |
"    fig.savefig(prob_chart_path, transparent=True)\n", | |
"    plt.close(fig)\n", | |
"\n", | |
"    # SHAP Chart (positive bars push turnover risk up, per the re-signing above)\n", | |
"    fig, ax = plt.subplots()\n", | |
"    shap.plots.bar(leave_explanation, max_display=6, show=False)\n", | |
"    ax.set_title(\"Key Drivers of Turnover Risk\")\n", | |
"    shap_fig = plt.gcf()  # save whichever figure pyplot considers current after shap draws\n", | |
"    shap_fig.tight_layout()\n", | |
"    shap_plot_path = temp_png(\"shap_plot_\")\n", | |
"    shap_fig.savefig(shap_plot_path, transparent=True)\n", | |
"    plt.close(shap_fig)\n", | |
"    plt.close(fig)  # no-op when shap drew on fig itself\n", | |
"\n", | |
"    return final_layout, prob_chart_path, shap_plot_path\n", | |
"\n", | |
"# UI Setup\n", | |
"with gr.Blocks() as demo:\n", | |
"    gr.Markdown(\"\"\"\n", | |
"    <div style=\"display: flex; justify-content: center; align-items: center;\">\n", | |
"        <img src=\"https://logos-world.net/wp-content/uploads/2021/02/Hilton-Logo.png\" width=\"250px\">\n", | |
"    </div>\n", | |
"    \"\"\")\n", | |
"    gr.Markdown(\"<h1 style='color: #0057B8;'>Hilton Team Member Retention Predictor</h1>\")\n", | |
"    # Emoji in the intro and button were restored from mis-encoded text: ✨ and ’ are\n", | |
"    # exact. NOTE(review): 📊 and 🔍 are best guesses -- confirm the intended icons.\n", | |
"    gr.Markdown(\"\"\"\n", | |
"    <div style='font-size: 20px; color: #0057B8;'>\n", | |
"        ✨ <b>Welcome to Hilton’s Employee Retention Predictor</b><br>\n", | |
"        This tool helps <b>HR leaders & managers</b> assess <b>team member engagement</b> \n", | |
"        and predict <b>turnover risk</b> using AI-powered insights.<br> \n", | |
"        📊 <b>See what factors drive retention & make data-driven decisions.</b> \n", | |
"    </div>\n", | |
"    \"\"\")\n", | |
"\n", | |
"    # Dropdown for Employee Profiles\n", | |
"    profile_dropdown = gr.Dropdown(choices=list(employee_profiles.keys()), label=\"Select Employee Profile\")\n", | |
"\n", | |
"    # Sliders for input features (1-5 in 0.1 steps), in main_func's argument order\n", | |
"    with gr.Row():\n", | |
"        WellBeing = gr.Slider(label=\"WellBeing Score\", minimum=1, maximum=5, value=4, step=0.1)\n", | |
"        SupportiveGM = gr.Slider(label=\"Supportive GM Score\", minimum=1, maximum=5, value=4, step=0.1)\n", | |
"        Engagement = gr.Slider(label=\"Engagement Score\", minimum=1, maximum=5, value=4, step=0.1)\n", | |
"    with gr.Row():\n", | |
"        Workload = gr.Slider(label=\"Workload Score\", minimum=1, maximum=5, value=4, step=0.1)\n", | |
"        WorkEnvironment = gr.Slider(label=\"Work Environment Score\", minimum=1, maximum=5, value=4, step=0.1)\n", | |
"        Merit = gr.Slider(label=\"Merit Score\", minimum=1, maximum=5, value=4, step=0.1)\n", | |
"    # Single source of truth for the slider order, shared by both event bindings below.\n", | |
"    feature_sliders = [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit]\n", | |
"\n", | |
"    submit_btn = gr.Button(\"🔍 Click Here to Analyze Retention\")\n", | |
"\n", | |
"    # Output elements\n", | |
"    prediction = gr.HTML()\n", | |
"    with gr.Row():\n", | |
"        prob_chart = gr.Image(label=\"Retention vs. Turnover Probability\", type=\"filepath\")\n", | |
"        shap_plot = gr.Image(label=\"Key Drivers of Turnover Risk\", type=\"filepath\")\n", | |
"\n", | |
"    # Allow profile selection to update sliders\n", | |
"    def update_sliders(profile):\n", | |
"        \"\"\"Return the six slider values for a preset profile (all 4s if none/unknown).\"\"\"\n", | |
"        # Return a copy so the preset list stored in employee_profiles is never shared.\n", | |
"        return list(employee_profiles.get(profile, [4, 4, 4, 4, 4, 4]))\n", | |
"\n", | |
"    profile_dropdown.change(update_sliders, inputs=[profile_dropdown], outputs=feature_sliders)\n", | |
"\n", | |
"    submit_btn.click(main_func, feature_sliders, [prediction, prob_chart, shap_plot])\n", | |
"\n", | |
"demo.launch(share=True)\n" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} | |