hsandvik00 committed on
Commit
cfbc198
·
verified ·
1 Parent(s): de6f85e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -289
app.py CHANGED
@@ -1,290 +1,154 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "68bddf84-439e-461b-b7f9-4c9212ca81f5",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "%%capture --no-display\n",
11
- "pip install gradio\n"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 2,
17
- "id": "d692b9ae-c7d4-40b6-9777-992a4303055b",
18
- "metadata": {},
19
- "outputs": [],
20
- "source": [
21
- "import gradio as gr"
22
- ]
23
- },
24
- {
25
- "cell_type": "code",
26
- "execution_count": 3,
27
- "id": "9742e1cb-7e40-4af9-8c59-6a341c9d9b38",
28
- "metadata": {},
29
- "outputs": [],
30
- "source": [
31
- "import pickle\n",
32
- "import pandas as pd\n",
33
- "import shap\n",
34
- "from shap.plots._force_matplotlib import draw_additive_plot\n",
35
- "import numpy as np\n",
36
- "import matplotlib.pyplot as plt\n",
37
- "from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler # Importing function for scaling the data"
38
- ]
39
- },
40
- {
41
- "cell_type": "markdown",
42
- "id": "7ab0387d-3a31-4164-a46e-b1bb57dbf33a",
43
- "metadata": {},
44
- "source": [
45
- "# App Code"
46
- ]
47
- },
48
- {
49
- "cell_type": "code",
50
- "execution_count": 4,
51
- "id": "b168e539-83f7-4bb6-a112-fe8de8fdab52",
52
- "metadata": {},
53
- "outputs": [
54
- {
55
- "data": {
56
- "text/plain": [
57
- "('WellBeing',\n",
58
- " 'SupportiveGM',\n",
59
- " 'Engagement',\n",
60
- " 'Workload',\n",
61
- " 'WorkEnvironment',\n",
62
- " 'Merit')"
63
- ]
64
- },
65
- "execution_count": 4,
66
- "metadata": {},
67
- "output_type": "execute_result"
68
- }
69
- ],
70
- "source": [
71
- "'WellBeing', 'SupportiveGM', 'Engagement', 'Workload', 'WorkEnvironment', 'Merit'"
72
- ]
73
- },
74
- {
75
- "cell_type": "code",
76
- "execution_count": 5,
77
- "id": "fc97600d-3c32-4d66-a60e-b416c170293b",
78
- "metadata": {},
79
- "outputs": [
80
- {
81
- "name": "stdout",
82
- "output_type": "stream",
83
- "text": [
84
- "* Running on local URL: http://127.0.0.1:7860\n",
85
- "* Running on public URL: https://35c8546aedfc1b28e8.gradio.live\n",
86
- "\n",
87
- "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
88
- ]
89
- },
90
- {
91
- "data": {
92
- "text/html": [
93
- "<div><iframe src=\"https://35c8546aedfc1b28e8.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
94
- ],
95
- "text/plain": [
96
- "<IPython.core.display.HTML object>"
97
- ]
98
- },
99
- "metadata": {},
100
- "output_type": "display_data"
101
- },
102
- {
103
- "data": {
104
- "text/plain": []
105
- },
106
- "execution_count": 5,
107
- "metadata": {},
108
- "output_type": "execute_result"
109
- }
110
- ],
111
- "source": [
112
- "import gradio as gr\n",
113
- "import pickle\n",
114
- "import pandas as pd\n",
115
- "import shap\n",
116
- "import matplotlib.pyplot as plt\n",
117
- "\n",
118
- "# Load model\n",
119
- "filename = 'xgb_h_new.pkl'\n",
120
- "with open(filename, 'rb') as f:\n",
121
- " loaded_model = pickle.load(f)\n",
122
- "\n",
123
- "# Setup SHAP\n",
124
- "explainer = shap.Explainer(loaded_model)\n",
125
- "\n",
126
- "# Employee Profiles (Adjusted Top Performer)\n",
127
- "employee_profiles = {\n",
128
- " \"📈 High Potential Employee\": [4, 5, 5, 3, 4, 5],\n",
129
- " \"🏆 Top Performer\": [5, 5, 5, 3, 5, 5], # Reduced workload\n",
130
- " \"⚠️ At-Risk Employee\": [2, 2, 2, 4, 2, 2],\n",
131
- " \"🔥 Burnt-Out Employee\": [1, 2, 2, 5, 1, 1]\n",
132
- "}\n",
133
- "\n",
134
- "# Define the prediction function\n",
135
- "def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit):\n",
136
- " new_row = pd.DataFrame({\n",
137
- " 'WellBeing': [WellBeing],\n",
138
- " 'SupportiveGM': [SupportiveGM],\n",
139
- " 'Engagement': [Engagement],\n",
140
- " 'Workload': [Workload],\n",
141
- " 'WorkEnvironment': [WorkEnvironment],\n",
142
- " 'Merit': [Merit]\n",
143
- " })\n",
144
- "\n",
145
- " # Predict probability\n",
146
- " prob = loaded_model.predict_proba(new_row)\n",
147
- " shap_values = explainer(new_row)\n",
148
- "\n",
149
- " # Calculate probability values\n",
150
- " stay_prob = round((1 - float(prob[0][0])) * 100, 2)\n",
151
- " leave_prob = round(float(prob[0][0]) * 100, 2)\n",
152
- "\n",
153
- " # Dynamic risk label: Changes color & text based on probability\n",
154
- " risk_label = \"🔴 High Risk of Turnover\" if leave_prob > 50 else \"🟢 Low Risk of Turnover\"\n",
155
- " risk_color = \"red\" if leave_prob > 50 else \"green\"\n",
156
- "\n",
157
- " risk_html = f\"\"\"\n",
158
- " <div style='padding: 15px; border-radius: 8px;'>\n",
159
- " <span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>\n",
160
- " <ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>\n",
161
- " <li>🧲 Likelihood of Staying: {stay_prob}%</li>\n",
161
- " <li>🚪 Likelihood of Leaving: {leave_prob}%</li>\n",
163
- " </ul>\n",
164
- " </div>\n",
165
- " \"\"\"\n",
166
- "\n",
167
- " # Key Insights (Updated for 0.1-point increments)\n",
168
- " insights_html = \"<div style='font-size: 18px;'>\"\n",
169
- " for feature, shap_val in dict(zip(new_row.columns, shap_values.values[0])).items():\n",
170
- " impact = round(shap_val * 10, 2) # Scaling impact for 0.1 changes\n",
171
- " icon = \"📈\" if shap_val > 0 else \"📉\"\n",
172
- " effect = \"raises turnover risk\" if shap_val > 0 else \"improves retention\"\n",
173
- " insights_html += f\"<p style='margin: 5px 0;'> {icon} <b>Each 0.1-point increase in {feature} {effect} by {abs(impact)}%.</b></p>\"\n",
174
- " insights_html += \"</div>\"\n",
175
- "\n",
176
- " # Final Layout (Risk + Key Insights)\n",
177
- " final_layout = f\"\"\"\n",
178
- " <table style='width:100%; border-collapse: collapse; margin-top: 10px; background-color: #FFFFFF;'>\n",
179
- " <tr>\n",
180
- " <td style='width: 33%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n",
181
- " {risk_html}\n",
182
- " </td>\n",
183
- " <td style='width: 67%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n",
184
- " <b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>\n",
185
- " {insights_html}\n",
186
- " </td>\n",
187
- " </tr>\n",
188
- " </table>\n",
189
- " \"\"\"\n",
190
- "\n",
191
- " # Retention vs. Turnover Chart\n",
192
- " fig, ax = plt.subplots()\n",
193
- " categories = [\"Stay\", \"Leave\"]\n",
194
- " values = [stay_prob, leave_prob]\n",
195
- " colors = [\"#0057B8\", \"#D43F00\"]\n",
196
- " ax.barh(categories, values, color=colors)\n",
197
- " for i, v in enumerate(values):\n",
198
- " ax.text(v + 2, i, f\"{v:.2f}%\", va='center', fontweight='bold', fontsize=12)\n",
199
- " ax.set_xlabel(\"Probability (%)\")\n",
200
- " ax.set_title(\"Retention vs. Turnover Probability\")\n",
201
- " plt.tight_layout()\n",
202
- " prob_chart_path = \"prob_chart.png\"\n",
203
- " plt.savefig(prob_chart_path, transparent=True)\n",
204
- " plt.close()\n",
205
- "\n",
206
- " # SHAP Chart\n",
207
- " fig, ax = plt.subplots()\n",
208
- " shap.plots.bar(shap_values[0], max_display=6, show=False)\n",
209
- " ax.set_title(\"Key Drivers of Turnover Risk\")\n",
210
- " plt.tight_layout()\n",
211
- " shap_plot_path = \"shap_plot.png\"\n",
212
- " plt.savefig(shap_plot_path, transparent=True)\n",
213
- " plt.close()\n",
214
- "\n",
215
- " return final_layout, prob_chart_path, shap_plot_path\n",
216
- "\n",
217
- "# UI Setup\n",
218
- "with gr.Blocks() as demo:\n",
219
- " gr.Markdown(\"\"\"\n",
220
- " <div style=\"display: flex; justify-content: center; align-items: center;\">\n",
221
- " <img src=\"https://logos-world.net/wp-content/uploads/2021/02/Hilton-Logo.png\" width=\"250px\">\n",
222
- " </div>\n",
223
- " \"\"\")\n",
224
- " gr.Markdown(\"<h1 style='color: #0057B8;'>Hilton Team Member Retention Predictor</h1>\")\n",
225
- " gr.Markdown(\"\"\"\n",
226
- " <div style='font-size: 20px; color: #0057B8;'>\n",
227
- " ✨ <b>Welcome to Hilton’s Employee Retention Predictor</b><br>\n",
228
- " This tool helps <b>HR leaders & managers</b> assess <b>team member engagement</b> \n",
229
- " and predict <b>turnover risk</b> using AI-powered insights.<br> \n",
230
- " 🔍 <b>See what factors drive retention & make data-driven decisions.</b> \n",
231
- " </div>\n",
232
- " \"\"\")\n",
233
- "\n",
234
- " # Dropdown for Employee Profiles\n",
235
- " profile_dropdown = gr.Dropdown(choices=list(employee_profiles.keys()), label=\"Select Employee Profile\")\n",
236
- "\n",
237
- " # Sliders for input features\n",
238
- " with gr.Row():\n",
239
- " WellBeing = gr.Slider(label=\"WellBeing Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
240
- " SupportiveGM = gr.Slider(label=\"Supportive GM Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
241
- " Engagement = gr.Slider(label=\"Engagement Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
242
- " with gr.Row():\n",
243
- " Workload = gr.Slider(label=\"Workload Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
244
- " WorkEnvironment = gr.Slider(label=\"Work Environment Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
245
- " Merit = gr.Slider(label=\"Merit Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
246
- "\n",
247
- " submit_btn = gr.Button(\"🔎 Click Here to Analyze Retention\")\n",
248
- "\n",
249
- " # Output elements\n",
250
- " prediction = gr.HTML()\n",
251
- " with gr.Row():\n",
252
- " prob_chart = gr.Image(label=\"Retention vs. Turnover Probability\", type=\"filepath\")\n",
253
- " shap_plot = gr.Image(label=\"Key Drivers of Turnover Risk\", type=\"filepath\")\n",
254
- "\n",
255
- " # Allow profile selection to update sliders\n",
256
- " def update_sliders(profile):\n",
257
- " if profile in employee_profiles:\n",
258
- " return employee_profiles[profile]\n",
259
- " return [4, 4, 4, 4, 4, 4]\n",
260
- "\n",
261
- " profile_dropdown.change(update_sliders, inputs=[profile_dropdown], outputs=[WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit])\n",
262
- "\n",
263
- " submit_btn.click(main_func, [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit], [prediction, prob_chart, shap_plot])\n",
264
- "\n",
265
- "demo.launch(share=True)\n"
266
- ]
267
- }
268
- ],
269
- "metadata": {
270
- "kernelspec": {
271
- "display_name": "Python 3 (ipykernel)",
272
- "language": "python",
273
- "name": "python3"
274
- },
275
- "language_info": {
276
- "codemirror_mode": {
277
- "name": "ipython",
278
- "version": 3
279
- },
280
- "file_extension": ".py",
281
- "mimetype": "text/x-python",
282
- "name": "python",
283
- "nbconvert_exporter": "python",
284
- "pygments_lexer": "ipython3",
285
- "version": "3.11.9"
286
- }
287
- },
288
- "nbformat": 4,
289
- "nbformat_minor": 5
290
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pickle
3
+ import pandas as pd
4
+ import shap
5
+ import matplotlib.pyplot as plt
6
+
7
+ # Load model
8
+ filename = 'xgb_h_new.pkl'
9
+ with open(filename, 'rb') as f:
10
+ loaded_model = pickle.load(f)
11
+
12
+ # Setup SHAP
13
+ explainer = shap.Explainer(loaded_model)
14
+
15
+ # Employee Profiles (Adjusted Top Performer)
16
+ employee_profiles = {
17
+ "📈 High Potential Employee": [4, 5, 5, 3, 4, 5],
18
+ "🏆 Top Performer": [5, 5, 5, 3, 5, 5], # Reduced workload
19
+ "⚠️ At-Risk Employee": [2, 2, 2, 4, 2, 2],
20
+ "🔥 Burnt-Out Employee": [1, 2, 2, 5, 1, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  }
22
+
23
+ # Define the prediction function
24
+ def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit):
25
+ new_row = pd.DataFrame({
26
+ 'WellBeing': [WellBeing],
27
+ 'SupportiveGM': [SupportiveGM],
28
+ 'Engagement': [Engagement],
29
+ 'Workload': [Workload],
30
+ 'WorkEnvironment': [WorkEnvironment],
31
+ 'Merit': [Merit]
32
+ })
33
+
34
+ # Predict probability
35
+ prob = loaded_model.predict_proba(new_row)
36
+ shap_values = explainer(new_row)
37
+
38
+ # Calculate probability values
39
+ stay_prob = round((1 - float(prob[0][0])) * 100, 2)
40
+ leave_prob = round(float(prob[0][0]) * 100, 2)
41
+
42
+ # Dynamic risk label: Changes color & text based on probability
43
+ risk_label = "🔴 High Risk of Turnover" if leave_prob > 50 else "🟢 Low Risk of Turnover"
44
+ risk_color = "red" if leave_prob > 50 else "green"
45
+
46
+ risk_html = f"""
47
+ <div style='padding: 15px; border-radius: 8px;'>
48
+ <span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>
49
+ <ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>
50
+ <li>🧲 Likelihood of Staying: {stay_prob}%</li>
51
+ <li>🚪 Likelihood of Leaving: {leave_prob}%</li>
52
+ </ul>
53
+ </div>
54
+ """
55
+
56
+ # Key Insights (Updated for 0.1-point increments)
57
+ insights_html = "<div style='font-size: 18px;'>"
58
+ for feature, shap_val in dict(zip(new_row.columns, shap_values.values[0])).items():
59
+ impact = round(shap_val * 10, 2) # Scaling impact for 0.1 changes
60
+ icon = "📈" if shap_val > 0 else "📉"
61
+ effect = "raises turnover risk" if shap_val > 0 else "improves retention"
62
+ insights_html += f"<p style='margin: 5px 0;'> {icon} <b>Each 0.1-point increase in {feature} {effect} by {abs(impact)}%.</b></p>"
63
+ insights_html += "</div>"
64
+
65
+ # Final Layout (Risk + Key Insights)
66
+ final_layout = f"""
67
+ <table style='width:100%; border-collapse: collapse; margin-top: 10px; background-color: #FFFFFF;'>
68
+ <tr>
69
+ <td style='width: 33%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>
70
+ {risk_html}
71
+ </td>
72
+ <td style='width: 67%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>
73
+ <b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>
74
+ {insights_html}
75
+ </td>
76
+ </tr>
77
+ </table>
78
+ """
79
+
80
+ # Retention vs. Turnover Chart
81
+ fig, ax = plt.subplots()
82
+ categories = ["Stay", "Leave"]
83
+ values = [stay_prob, leave_prob]
84
+ colors = ["#0057B8", "#D43F00"]
85
+ ax.barh(categories, values, color=colors)
86
+ for i, v in enumerate(values):
87
+ ax.text(v + 2, i, f"{v:.2f}%", va='center', fontweight='bold', fontsize=12)
88
+ ax.set_xlabel("Probability (%)")
89
+ ax.set_title("Retention vs. Turnover Probability")
90
+ plt.tight_layout()
91
+ prob_chart_path = "prob_chart.png"
92
+ plt.savefig(prob_chart_path, transparent=True)
93
+ plt.close()
94
+
95
+ # SHAP Chart
96
+ fig, ax = plt.subplots()
97
+ shap.plots.bar(shap_values[0], max_display=6, show=False)
98
+ ax.set_title("Key Drivers of Turnover Risk")
99
+ plt.tight_layout()
100
+ shap_plot_path = "shap_plot.png"
101
+ plt.savefig(shap_plot_path, transparent=True)
102
+ plt.close()
103
+
104
+ return final_layout, prob_chart_path, shap_plot_path
105
+
106
+ # UI Setup
107
+ with gr.Blocks() as demo:
108
+ gr.Markdown("""
109
+ <div style="display: flex; justify-content: center; align-items: center;">
110
+ <img src="https://logos-world.net/wp-content/uploads/2021/02/Hilton-Logo.png" width="250px">
111
+ </div>
112
+ """)
113
+ gr.Markdown("<h1 style='color: #0057B8;'>Hilton Team Member Retention Predictor</h1>")
114
+ gr.Markdown("""
115
+ <div style='font-size: 20px; color: #0057B8;'>
116
+ ✨ <b>Welcome to Hilton’s Employee Retention Predictor</b><br>
117
+ This tool helps <b>HR leaders & managers</b> assess <b>team member engagement</b>
118
+ and predict <b>turnover risk</b> using AI-powered insights.<br>
119
+ 🔍 <b>See what factors drive retention & make data-driven decisions.</b>
120
+ </div>
121
+ """)
122
+
123
+ # Dropdown for Employee Profiles
124
+ profile_dropdown = gr.Dropdown(choices=list(employee_profiles.keys()), label="Select Employee Profile")
125
+
126
+ # Sliders for input features
127
+ with gr.Row():
128
+ WellBeing = gr.Slider(label="WellBeing Score", minimum=1, maximum=5, value=4, step=0.1)
129
+ SupportiveGM = gr.Slider(label="Supportive GM Score", minimum=1, maximum=5, value=4, step=0.1)
130
+ Engagement = gr.Slider(label="Engagement Score", minimum=1, maximum=5, value=4, step=0.1)
131
+ with gr.Row():
132
+ Workload = gr.Slider(label="Workload Score", minimum=1, maximum=5, value=4, step=0.1)
133
+ WorkEnvironment = gr.Slider(label="Work Environment Score", minimum=1, maximum=5, value=4, step=0.1)
134
+ Merit = gr.Slider(label="Merit Score", minimum=1, maximum=5, value=4, step=0.1)
135
+
136
+ submit_btn = gr.Button("🔎 Click Here to Analyze Retention")
137
+
138
+ # Output elements
139
+ prediction = gr.HTML()
140
+ with gr.Row():
141
+ prob_chart = gr.Image(label="Retention vs. Turnover Probability", type="filepath")
142
+ shap_plot = gr.Image(label="Key Drivers of Turnover Risk", type="filepath")
143
+
144
+ # Allow profile selection to update sliders
145
+ def update_sliders(profile):
146
+ if profile in employee_profiles:
147
+ return employee_profiles[profile]
148
+ return [4, 4, 4, 4, 4, 4]
149
+
150
+ profile_dropdown.change(update_sliders, inputs=[profile_dropdown], outputs=[WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit])
151
+
152
+ submit_btn.click(main_func, [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit], [prediction, prob_chart, shap_plot])
153
+
154
+ demo.launch()