hsandvik00 commited on
Commit
1c2b61a
Β·
verified Β·
1 Parent(s): 518b8bd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +290 -0
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "68bddf84-439e-461b-b7f9-4c9212ca81f5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%%capture --no-display\n",
11
+ "pip install gradio\n"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "id": "d692b9ae-c7d4-40b6-9777-992a4303055b",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "import gradio as gr"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "id": "9742e1cb-7e40-4af9-8c59-6a341c9d9b38",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "import pickle\n",
32
+ "import pandas as pd\n",
33
+ "import shap\n",
34
+ "from shap.plots._force_matplotlib import draw_additive_plot\n",
35
+ "import numpy as np\n",
36
+ "import matplotlib.pyplot as plt\n",
37
+ "from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler # Importing function for scaling the data"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "id": "7ab0387d-3a31-4164-a46e-b1bb57dbf33a",
43
+ "metadata": {},
44
+ "source": [
45
+ "# App Code"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 4,
51
+ "id": "b168e539-83f7-4bb6-a112-fe8de8fdab52",
52
+ "metadata": {},
53
+ "outputs": [
54
+ {
55
+ "data": {
56
+ "text/plain": [
57
+ "('WellBeing',\n",
58
+ " 'SupportiveGM',\n",
59
+ " 'Engagement',\n",
60
+ " 'Workload',\n",
61
+ " 'WorkEnvironment',\n",
62
+ " 'Merit')"
63
+ ]
64
+ },
65
+ "execution_count": 4,
66
+ "metadata": {},
67
+ "output_type": "execute_result"
68
+ }
69
+ ],
70
+ "source": [
71
+ "'WellBeing', 'SupportiveGM', 'Engagement', 'Workload', 'WorkEnvironment', 'Merit'"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 5,
77
+ "id": "fc97600d-3c32-4d66-a60e-b416c170293b",
78
+ "metadata": {},
79
+ "outputs": [
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "* Running on local URL: http://127.0.0.1:7860\n",
85
+ "* Running on public URL: https://35c8546aedfc1b28e8.gradio.live\n",
86
+ "\n",
87
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
88
+ ]
89
+ },
90
+ {
91
+ "data": {
92
+ "text/html": [
93
+ "<div><iframe src=\"https://35c8546aedfc1b28e8.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
94
+ ],
95
+ "text/plain": [
96
+ "<IPython.core.display.HTML object>"
97
+ ]
98
+ },
99
+ "metadata": {},
100
+ "output_type": "display_data"
101
+ },
102
+ {
103
+ "data": {
104
+ "text/plain": []
105
+ },
106
+ "execution_count": 5,
107
+ "metadata": {},
108
+ "output_type": "execute_result"
109
+ }
110
+ ],
111
+ "source": [
112
+ "import gradio as gr\n",
113
+ "import pickle\n",
114
+ "import pandas as pd\n",
115
+ "import shap\n",
116
+ "import matplotlib.pyplot as plt\n",
117
+ "\n",
118
+ "# Load model\n",
119
+ "filename = 'xgb_h_new.pkl'\n",
120
+ "with open(filename, 'rb') as f:\n",
121
+ " loaded_model = pickle.load(f)\n",
122
+ "\n",
123
+ "# Setup SHAP\n",
124
+ "explainer = shap.Explainer(loaded_model)\n",
125
+ "\n",
126
+ "# Employee Profiles (Adjusted Top Performer)\n",
127
+ "employee_profiles = {\n",
128
+ " \"πŸ“ˆ High Potential Employee\": [4, 5, 5, 3, 4, 5],\n",
129
+ " \"πŸ† Top Performer\": [5, 5, 5, 3, 5, 5], # Reduced workload\n",
130
+ " \"⚠️ At-Risk Employee\": [2, 2, 2, 4, 2, 2],\n",
131
+ " \"πŸ”₯ Burnt-Out Employee\": [1, 2, 2, 5, 1, 1]\n",
132
+ "}\n",
133
+ "\n",
134
+ "# Define the prediction function\n",
135
+ "def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit):\n",
136
+ " new_row = pd.DataFrame({\n",
137
+ " 'WellBeing': [WellBeing],\n",
138
+ " 'SupportiveGM': [SupportiveGM],\n",
139
+ " 'Engagement': [Engagement],\n",
140
+ " 'Workload': [Workload],\n",
141
+ " 'WorkEnvironment': [WorkEnvironment],\n",
142
+ " 'Merit': [Merit]\n",
143
+ " })\n",
144
+ "\n",
145
+ " # Predict probability\n",
146
+ " prob = loaded_model.predict_proba(new_row)\n",
147
+ " shap_values = explainer(new_row)\n",
148
+ "\n",
149
+ " # Calculate probability values\n",
150
+ " stay_prob = round((1 - float(prob[0][0])) * 100, 2)\n",
151
+ " leave_prob = round(float(prob[0][0]) * 100, 2)\n",
152
+ "\n",
153
+ " # Dynamic risk label: Changes color & text based on probability\n",
154
+ " risk_label = \"πŸ”΄ High Risk of Turnover\" if leave_prob > 50 else \"🟒 Low Risk of Turnover\"\n",
155
+ " risk_color = \"red\" if leave_prob > 50 else \"green\"\n",
156
+ "\n",
157
+ " risk_html = f\"\"\"\n",
158
+ " <div style='padding: 15px; border-radius: 8px;'>\n",
159
+ " <span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>\n",
160
+ " <ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>\n",
161
+ " <li>🧲 Likelihood of Staying: {stay_prob}%</li>\n",
162
+ " <li>πŸšͺ Likelihood of Leaving: {leave_prob}%</li>\n",
163
+ " </ul>\n",
164
+ " </div>\n",
165
+ " \"\"\"\n",
166
+ "\n",
167
+ " # Key Insights (Updated for 0.1-point increments)\n",
168
+ " insights_html = \"<div style='font-size: 18px;'>\"\n",
169
+ " for feature, shap_val in dict(zip(new_row.columns, shap_values.values[0])).items():\n",
170
+ " impact = round(shap_val * 10, 2) # Scaling impact for 0.1 changes\n",
171
+ " icon = \"πŸ“ˆ\" if shap_val > 0 else \"πŸ“‰\"\n",
172
+ " effect = \"raises turnover risk\" if shap_val > 0 else \"improves retention\"\n",
173
+ " insights_html += f\"<p style='margin: 5px 0;'> {icon} <b>Each 0.1-point increase in {feature} {effect} by {abs(impact)}%.</b></p>\"\n",
174
+ " insights_html += \"</div>\"\n",
175
+ "\n",
176
+ " # Final Layout (Risk + Key Insights)\n",
177
+ " final_layout = f\"\"\"\n",
178
+ " <table style='width:100%; border-collapse: collapse; margin-top: 10px; background-color: #FFFFFF;'>\n",
179
+ " <tr>\n",
180
+ " <td style='width: 33%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n",
181
+ " {risk_html}\n",
182
+ " </td>\n",
183
+ " <td style='width: 67%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>\n",
184
+ " <b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>\n",
185
+ " {insights_html}\n",
186
+ " </td>\n",
187
+ " </tr>\n",
188
+ " </table>\n",
189
+ " \"\"\"\n",
190
+ "\n",
191
+ " # Retention vs. Turnover Chart\n",
192
+ " fig, ax = plt.subplots()\n",
193
+ " categories = [\"Stay\", \"Leave\"]\n",
194
+ " values = [stay_prob, leave_prob]\n",
195
+ " colors = [\"#0057B8\", \"#D43F00\"]\n",
196
+ " ax.barh(categories, values, color=colors)\n",
197
+ " for i, v in enumerate(values):\n",
198
+ " ax.text(v + 2, i, f\"{v:.2f}%\", va='center', fontweight='bold', fontsize=12)\n",
199
+ " ax.set_xlabel(\"Probability (%)\")\n",
200
+ " ax.set_title(\"Retention vs. Turnover Probability\")\n",
201
+ " plt.tight_layout()\n",
202
+ " prob_chart_path = \"prob_chart.png\"\n",
203
+ " plt.savefig(prob_chart_path, transparent=True)\n",
204
+ " plt.close()\n",
205
+ "\n",
206
+ " # SHAP Chart\n",
207
+ " fig, ax = plt.subplots()\n",
208
+ " shap.plots.bar(shap_values[0], max_display=6, show=False)\n",
209
+ " ax.set_title(\"Key Drivers of Turnover Risk\")\n",
210
+ " plt.tight_layout()\n",
211
+ " shap_plot_path = \"shap_plot.png\"\n",
212
+ " plt.savefig(shap_plot_path, transparent=True)\n",
213
+ " plt.close()\n",
214
+ "\n",
215
+ " return final_layout, prob_chart_path, shap_plot_path\n",
216
+ "\n",
217
+ "# UI Setup\n",
218
+ "with gr.Blocks() as demo:\n",
219
+ " gr.Markdown(\"\"\"\n",
220
+ " <div style=\"display: flex; justify-content: center; align-items: center;\">\n",
221
+ " <img src=\"https://logos-world.net/wp-content/uploads/2021/02/Hilton-Logo.png\" width=\"250px\">\n",
222
+ " </div>\n",
223
+ " \"\"\")\n",
224
+ " gr.Markdown(\"<h1 style='color: #0057B8;'>Hilton Team Member Retention Predictor</h1>\")\n",
225
+ " gr.Markdown(\"\"\"\n",
226
+ " <div style='font-size: 20px; color: #0057B8;'>\n",
227
+ " ✨ <b>Welcome to Hilton’s Employee Retention Predictor</b><br>\n",
228
+ " This tool helps <b>HR leaders & managers</b> assess <b>team member engagement</b> \n",
229
+ " and predict <b>turnover risk</b> using AI-powered insights.<br> \n",
230
+ " πŸ” <b>See what factors drive retention & make data-driven decisions.</b> \n",
231
+ " </div>\n",
232
+ " \"\"\")\n",
233
+ "\n",
234
+ " # Dropdown for Employee Profiles\n",
235
+ " profile_dropdown = gr.Dropdown(choices=list(employee_profiles.keys()), label=\"Select Employee Profile\")\n",
236
+ "\n",
237
+ " # Sliders for input features\n",
238
+ " with gr.Row():\n",
239
+ " WellBeing = gr.Slider(label=\"WellBeing Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
240
+ " SupportiveGM = gr.Slider(label=\"Supportive GM Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
241
+ " Engagement = gr.Slider(label=\"Engagement Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
242
+ " with gr.Row():\n",
243
+ " Workload = gr.Slider(label=\"Workload Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
244
+ " WorkEnvironment = gr.Slider(label=\"Work Environment Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
245
+ " Merit = gr.Slider(label=\"Merit Score\", minimum=1, maximum=5, value=4, step=0.1)\n",
246
+ "\n",
247
+ " submit_btn = gr.Button(\"πŸ”Ž Click Here to Analyze Retention\")\n",
248
+ "\n",
249
+ " # Output elements\n",
250
+ " prediction = gr.HTML()\n",
251
+ " with gr.Row():\n",
252
+ " prob_chart = gr.Image(label=\"Retention vs. Turnover Probability\", type=\"filepath\")\n",
253
+ " shap_plot = gr.Image(label=\"Key Drivers of Turnover Risk\", type=\"filepath\")\n",
254
+ "\n",
255
+ " # Allow profile selection to update sliders\n",
256
+ " def update_sliders(profile):\n",
257
+ " if profile in employee_profiles:\n",
258
+ " return employee_profiles[profile]\n",
259
+ " return [4, 4, 4, 4, 4, 4]\n",
260
+ "\n",
261
+ " profile_dropdown.change(update_sliders, inputs=[profile_dropdown], outputs=[WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit])\n",
262
+ "\n",
263
+ " submit_btn.click(main_func, [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit], [prediction, prob_chart, shap_plot])\n",
264
+ "\n",
265
+ "demo.launch(share=True)\n"
266
+ ]
267
+ }
268
+ ],
269
+ "metadata": {
270
+ "kernelspec": {
271
+ "display_name": "Python 3 (ipykernel)",
272
+ "language": "python",
273
+ "name": "python3"
274
+ },
275
+ "language_info": {
276
+ "codemirror_mode": {
277
+ "name": "ipython",
278
+ "version": 3
279
+ },
280
+ "file_extension": ".py",
281
+ "mimetype": "text/x-python",
282
+ "name": "python",
283
+ "nbconvert_exporter": "python",
284
+ "pygments_lexer": "ipython3",
285
+ "version": "3.11.9"
286
+ }
287
+ },
288
+ "nbformat": 4,
289
+ "nbformat_minor": 5
290
+ }