Mod3_Team1_2025 / app.py
hsandvik00's picture
Upload app.py
1c2b61a verified
raw
history blame
11.1 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "68bddf84-439e-461b-b7f9-4c9212ca81f5",
"metadata": {},
"outputs": [],
"source": [
"%%capture --no-display\n",
"# Explicit %pip magic installs into this kernel's environment without relying on automagic.\n",
"%pip install gradio\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d692b9ae-c7d4-40b6-9777-992a4303055b",
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9742e1cb-7e40-4af9-8c59-6a341c9d9b38",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import pandas as pd\n",
"import shap\n",
"from shap.plots._force_matplotlib import draw_additive_plot\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler # Importing function for scaling the data"
]
},
{
"cell_type": "markdown",
"id": "7ab0387d-3a31-4164-a46e-b1bb57dbf33a",
"metadata": {},
"source": [
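"# App Code\n",
"\n",
"This section loads the pickled turnover model (`xgb_h_new.pkl`), explains each prediction with SHAP, and serves the Hilton Team Member Retention Predictor as a Gradio app."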
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b168e539-83f7-4bb6-a112-fe8de8fdab52",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('WellBeing',\n",
" 'SupportiveGM',\n",
" 'Engagement',\n",
" 'Workload',\n",
" 'WorkEnvironment',\n",
" 'Merit')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Feature columns passed to the model, in order (the sliders and profile presets use the same order).\n",
"(\"WellBeing\", \"SupportiveGM\", \"Engagement\", \"Workload\", \"WorkEnvironment\", \"Merit\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fc97600d-3c32-4d66-a60e-b416c170293b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on local URL: http://127.0.0.1:7860\n",
"* Running on public URL: https://35c8546aedfc1b28e8.gradio.live\n",
"\n",
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"https://35c8546aedfc1b28e8.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gradio as gr\n",
"import pickle\n",
"import pandas as pd\n",
"import shap\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load the trained model and build its SHAP explainer once, at start-up,\n",
"# so every prediction request reuses the same objects.\n",
"# NOTE(review): pickle.load can execute arbitrary code; only load a model file from a trusted source.\n",
"filename = 'xgb_h_new.pkl'  # resolved relative to the current working directory\n",
"with open(filename, mode='rb') as f:\n",
"    loaded_model = pickle.load(f)\n",
"\n",
"explainer = shap.Explainer(loaded_model)\n",
"\n",
"# Preset slider values for the demo personas, in slider/feature order:\n",
"# (WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit).\n",
"# Insertion order is preserved, so this is also the order of the dropdown choices.\n",
"employee_profiles = dict([\n",
"    (\"πŸ“ˆ High Potential Employee\", [4, 5, 5, 3, 4, 5]),\n",
"    (\"πŸ† Top Performer\", [5, 5, 5, 3, 5, 5]),  # workload score deliberately reduced\n",
"    (\"⚠️ At-Risk Employee\", [2, 2, 2, 4, 2, 2]),\n",
"    (\"πŸ”₯ Burnt-Out Employee\", [1, 2, 2, 5, 1, 1]),\n",
"])\n",
"\n",
"# Define the prediction function\n",
"\n",
"# Column index of the \"leave\" class in loaded_model.predict_proba(...).\n",
"# 0 keeps the probabilities this app has always displayed. TODO: confirm it\n",
"# against the label encoding used when xgb_h_new.pkl was trained.\n",
"LEAVE_CLASS_INDEX = 0\n",
"\n",
"\n",
"def _save_figure(fig):\n",
"    \"\"\"Save fig as a transparent PNG at a unique temp path, close it, and return the path.\n",
"\n",
"    A fresh file per call means concurrent requests cannot overwrite each other's\n",
"    chart, which a single fixed filename in the working directory allowed.\n",
"    \"\"\"\n",
"    import tempfile\n",
"\n",
"    try:\n",
"        with tempfile.NamedTemporaryFile(suffix=\".png\", delete=False) as tmp:\n",
"            fig.savefig(tmp, format=\"png\", transparent=True)\n",
"    finally:\n",
"        plt.close(fig)\n",
"    return tmp.name\n",
"\n",
"\n",
"def _risk_summary_html(stay_prob, leave_prob):\n",
"    \"\"\"Return the headline badge: red 'High Risk' when leave_prob > 50, else green 'Low Risk'.\"\"\"\n",
"    is_high_risk = leave_prob > 50\n",
"    risk_label = \"πŸ”΄ High Risk of Turnover\" if is_high_risk else \"🟒 Low Risk of Turnover\"\n",
"    risk_color = \"red\" if is_high_risk else \"green\"\n",
"    return f\"\"\"\n",
" <div style='padding: 15px; border-radius: 8px;'>\n",
" <span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>\n",
" <ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>\n",
" <li>🧲 Likelihood of Staying: {stay_prob}%</li>\n",
" <li>πŸšͺ Likelihood of Leaving: {leave_prob}%</li>\n",
" </ul>\n",
" </div>\n",
" \"\"\"\n",
"\n",
"\n",
"def _insights_html(feature_names, risk_contributions):\n",
"    \"\"\"Return one line per feature saying whether it pushes toward or away from turnover.\n",
"\n",
"    risk_contributions are SHAP values oriented so that positive = toward leaving.\n",
"    They are additive contributions (in the model's raw output units, presumably\n",
"    log-odds here) for the employee's current scores. They are not per-0.1-point\n",
"    slopes or percentages, so they are reported as-is.\n",
"    \"\"\"\n",
"    lines = []\n",
"    for feature, contribution in zip(feature_names, risk_contributions):\n",
"        icon = \"πŸ“ˆ\" if contribution > 0 else \"πŸ“‰\"\n",
"        effect = \"raises turnover risk\" if contribution > 0 else \"improves retention\"\n",
"        lines.append(\n",
"            f\"<p style='margin: 5px 0;'> {icon} <b>{feature} {effect} \"\n",
"            f\"(SHAP impact: {abs(contribution):.2f})</b></p>\"\n",
"        )\n",
"    return \"<div style='font-size: 18px;'>\" + \"\".join(lines) + \"</div>\"\n",
"\n",
"\n",
"def _save_probability_chart(stay_prob, leave_prob):\n",
"    \"\"\"Draw the Stay/Leave horizontal bar chart and return the saved PNG path.\"\"\"\n",
"    fig, ax = plt.subplots()\n",
"    values = [stay_prob, leave_prob]\n",
"    ax.barh([\"Stay\", \"Leave\"], values, color=[\"#0057B8\", \"#D43F00\"])\n",
"    for i, v in enumerate(values):\n",
"        ax.text(v + 2, i, f\"{v:.2f}%\", va='center', fontweight='bold', fontsize=12)\n",
"    # Value labels sit just right of each bar: widen the axis so a ~100% label is\n",
"    # not clipped, but keep the ticks inside the 0-100 probability range.\n",
"    ax.set_xlim(0, 115)\n",
"    ax.set_xticks(range(0, 101, 20))\n",
"    ax.set_xlabel(\"Probability (%)\")\n",
"    ax.set_title(\"Retention vs. Turnover Probability\")\n",
"    fig.tight_layout()\n",
"    return _save_figure(fig)\n",
"\n",
"\n",
"def _save_shap_chart(risk_explanation):\n",
"    \"\"\"Draw the SHAP bar chart for one prediction and return the saved PNG path.\"\"\"\n",
"    fig, ax = plt.subplots()\n",
"    # With show=False, shap.plots.bar draws onto the current axes (ax) instead of displaying.\n",
"    shap.plots.bar(risk_explanation, max_display=6, show=False)\n",
"    ax.set_title(\"Key Drivers of Turnover Risk\")\n",
"    fig.tight_layout()\n",
"    return _save_figure(fig)\n",
"\n",
"\n",
"def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit):\n",
"    \"\"\"Score one employee and build the three dashboard outputs.\n",
"\n",
"    The six arguments are the slider scores. They are placed in a one-row\n",
"    DataFrame with the column names and order passed to loaded_model.\n",
"\n",
"    Returns:\n",
"        (HTML with the risk badge and key insights,\n",
"         path to the stay/leave probability chart PNG,\n",
"         path to the SHAP driver chart PNG)\n",
"    \"\"\"\n",
"    new_row = pd.DataFrame({\n",
"        'WellBeing': [WellBeing],\n",
"        'SupportiveGM': [SupportiveGM],\n",
"        'Engagement': [Engagement],\n",
"        'Workload': [Workload],\n",
"        'WorkEnvironment': [WorkEnvironment],\n",
"        'Merit': [Merit]\n",
"    })\n",
"\n",
"    # Stay is the complement of the leave-class probability.\n",
"    leave_fraction = float(loaded_model.predict_proba(new_row)[0][LEAVE_CLASS_INDEX])\n",
"    stay_prob = round((1 - leave_fraction) * 100, 2)\n",
"    leave_prob = round(leave_fraction * 100, 2)\n",
"\n",
"    # For a binary XGBoost classifier, shap.Explainer attributes the raw output of\n",
"    # class 1. Flip the sign when \"leave\" is class 0, so that positive always means\n",
"    # \"pushes toward turnover\" in both the insight text and the chart, consistent\n",
"    # with leave_prob above.\n",
"    direction = 1 if LEAVE_CLASS_INDEX == 1 else -1\n",
"    row_explanation = explainer(new_row)[0]\n",
"    risk_explanation = shap.Explanation(\n",
"        values=direction * row_explanation.values,\n",
"        base_values=direction * row_explanation.base_values,\n",
"        data=row_explanation.data,\n",
"        feature_names=row_explanation.feature_names,\n",
"    )\n",
"\n",
"    risk_html = _risk_summary_html(stay_prob, leave_prob)\n",
"    insights_html = _insights_html(new_row.columns, risk_explanation.values)\n",
"\n",
"    # Final layout: risk badge on the left, key insights on the right.\n",
"    final_layout = f\"\"\"\n",
" <table style='width:100%; border-collapse: collapse; margin-top: 10px; background-color: #FFFFFF;'>\n",
" <tr>\n",
" <td style='width: 33%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n",
" {risk_html}\n",
" </td>\n",
" <td style='width: 67%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n",
" <b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>\n",
" {insights_html}\n",
" </td>\n",
" </tr>\n",
" </table>\n",
" \"\"\"\n",
"\n",
"    prob_chart_path = _save_probability_chart(stay_prob, leave_prob)\n",
"    shap_plot_path = _save_shap_chart(risk_explanation)\n",
"\n",
"    return final_layout, prob_chart_path, shap_plot_path\n",
"\n",
"# UI Setup: header, persona picker, six score sliders, analyze button and outputs\n",
"# (HTML summary plus two chart images), then launch with a public share link.\n",
"with gr.Blocks() as demo:\n",
"    # Header: centered logo, page title and welcome blurb.\n",
"    gr.Markdown(\"\"\"\n",
" <div style=\"display: flex; justify-content: center; align-items: center;\">\n",
" <img src=\"https://logos-world.net/wp-content/uploads/2021/02/Hilton-Logo.png\" width=\"250px\">\n",
" </div>\n",
" \"\"\")\n",
"    gr.Markdown(\"<h1 style='color: #0057B8;'>Hilton Team Member Retention Predictor</h1>\")\n",
"    gr.Markdown(\"\"\"\n",
" <div style='font-size: 20px; color: #0057B8;'>\n",
" ✨ <b>Welcome to Hilton’s Employee Retention Predictor</b><br>\n",
" This tool helps <b>HR leaders & managers</b> assess <b>team member engagement</b> \n",
" and predict <b>turnover risk</b> using AI-powered insights.<br> \n",
" πŸ” <b>See what factors drive retention & make data-driven decisions.</b> \n",
" </div>\n",
" \"\"\")\n",
"\n",
"    # Persona picker; choosing one pre-fills the sliders via update_sliders.\n",
"    profile_dropdown = gr.Dropdown(choices=list(employee_profiles), label=\"Select Employee Profile\")\n",
"\n",
"    # Six 1-5 score sliders sharing one range/default/step, laid out in two rows of three.\n",
"    score_slider_opts = dict(minimum=1, maximum=5, value=4, step=0.1)\n",
"    with gr.Row():\n",
"        WellBeing = gr.Slider(label=\"WellBeing Score\", **score_slider_opts)\n",
"        SupportiveGM = gr.Slider(label=\"Supportive GM Score\", **score_slider_opts)\n",
"        Engagement = gr.Slider(label=\"Engagement Score\", **score_slider_opts)\n",
"    with gr.Row():\n",
"        Workload = gr.Slider(label=\"Workload Score\", **score_slider_opts)\n",
"        WorkEnvironment = gr.Slider(label=\"Work Environment Score\", **score_slider_opts)\n",
"        Merit = gr.Slider(label=\"Merit Score\", **score_slider_opts)\n",
"    # Same order as main_func's parameters and each preset list in employee_profiles.\n",
"    score_sliders = [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit]\n",
"\n",
"    submit_btn = gr.Button(\"πŸ”Ž Click Here to Analyze Retention\")\n",
"\n",
"    # Outputs: HTML summary on top, then the two charts side by side.\n",
"    prediction = gr.HTML()\n",
"    with gr.Row():\n",
"        prob_chart = gr.Image(label=\"Retention vs. Turnover Probability\", type=\"filepath\")\n",
"        shap_plot = gr.Image(label=\"Key Drivers of Turnover Risk\", type=\"filepath\")\n",
"\n",
"    def update_sliders(profile):\n",
"        \"\"\"Return the preset scores for profile, or all 4s when it is not a known preset.\"\"\"\n",
"        return employee_profiles.get(profile, [4, 4, 4, 4, 4, 4])\n",
"\n",
"    profile_dropdown.change(fn=update_sliders, inputs=[profile_dropdown], outputs=score_sliders)\n",
"    submit_btn.click(fn=main_func, inputs=score_sliders, outputs=[prediction, prob_chart, shap_plot])\n",
"\n",
"demo.launch(share=True)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}