hsandvik00 sjalbright commited on
Commit
84f4e93
Β·
verified Β·
1 Parent(s): df4ed14

Update app.py (#10)

Browse files

- Update app.py (8432b4cb6b02bc3be3b379814d92657c0886a931)


Co-authored-by: Sarah Albright <[email protected]>

Files changed (1) hide show
  1. app.py +73 -62
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # with cluster profiles
2
-
3
  import gradio as gr
4
  import pickle
5
  import pandas as pd
@@ -7,23 +5,37 @@ import shap
7
  import matplotlib.pyplot as plt
8
 
9
  # Load model
10
- filename = 'xgb_h_new.pkl'
11
  with open(filename, 'rb') as f:
12
  loaded_model = pickle.load(f)
13
 
14
  # Setup SHAP
15
  explainer = shap.Explainer(loaded_model)
16
 
17
- # Employee Profiles (From SPSS 3-Cluster Solution)
 
 
 
 
 
 
 
 
 
 
18
  employee_profiles = {
19
- "πŸ† Leslie Knope": [4.716, 4.792, 4.864, 4.588, 4.849, 4.601], # Cluster group 1 averages - high engagement, strong support, high workload
20
- "⚠️ Kevin Malone": [3.045, 3.122, 3.129, 2.886, 3.113, 2.197], # Cluster group 2 averages - disengaged, low recognition, weak support
21
- "🌱 Jim Halpert": [3.885, 3.992, 4.119, 3.704, 4.090, 3.377] # Cluster group 3 averages - Moderately engaged, could be more recognized - room to grow
 
22
  }
23
 
24
  # Define the prediction function
25
- def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit):
 
 
26
  new_row = pd.DataFrame({
 
27
  'WellBeing': [WellBeing],
28
  'SupportiveGM': [SupportiveGM],
29
  'Engagement': [Engagement],
@@ -40,43 +52,34 @@ def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Me
40
  stay_prob = round((1 - float(prob[0][0])) * 100, 2)
41
  leave_prob = round(float(prob[0][0]) * 100, 2)
42
 
43
- # Dynamic risk label: Changes color & text based on probability
44
  risk_label = "πŸ”΄ High Risk of Turnover" if leave_prob > 50 else "🟒 Low Risk of Turnover"
45
  risk_color = "red" if leave_prob > 50 else "green"
46
 
47
  risk_html = f"""
48
- <div style='padding: 15px; border-radius: 8px;'>
49
- <span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>
50
- <ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>
51
- <li>🧲 Likelihood of Staying: {stay_prob}%</li>
52
- <li>πŸšͺ Likelihood of Leaving: {leave_prob}%</li>
53
- </ul>
54
- </div>
 
 
 
55
  """
56
 
57
- # Key Insights (Updated for 0.1-point increments)
58
- insights_html = "<div style='font-size: 18px;'>"
59
- for feature, shap_val in dict(zip(new_row.columns, shap_values.values[0])).items():
60
- impact = round(shap_val * 10, 2) # Scaling impact for 0.1 changes
 
 
61
  icon = "πŸ“ˆ" if shap_val > 0 else "πŸ“‰"
62
  effect = "raises turnover risk" if shap_val > 0 else "improves retention"
63
- insights_html += f"<p style='margin: 5px 0;'> {icon} <b>Each 0.1-point increase in {feature} {effect} by {abs(impact)}%.</b></p>"
64
- insights_html += "</div>"
65
-
66
- # Final Layout (Risk + Key Insights)
67
- final_layout = f"""
68
- <table style='width:100%; border-collapse: collapse; margin-top: 10px; background-color: #FFFFFF;'>
69
- <tr>
70
- <td style='width: 33%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>
71
- {risk_html}
72
- </td>
73
- <td style='width: 67%; padding: 15px; background-color: #FFFFFF; border-radius: 8px; vertical-align: top;'>
74
- <b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>
75
- {insights_html}
76
- </td>
77
- </tr>
78
- </table>
79
- """
80
 
81
  # Retention vs. Turnover Chart
82
  fig, ax = plt.subplots()
@@ -93,16 +96,23 @@ def main_func(WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Me
93
  plt.savefig(prob_chart_path, transparent=True)
94
  plt.close()
95
 
96
- # SHAP Chart
97
  fig, ax = plt.subplots()
98
- shap.plots.bar(shap_values[0], max_display=6, show=False)
 
99
  ax.set_title("Key Drivers of Turnover Risk")
100
  plt.tight_layout()
101
  shap_plot_path = "shap_plot.png"
102
  plt.savefig(shap_plot_path, transparent=True)
103
  plt.close()
104
 
105
- return final_layout, prob_chart_path, shap_plot_path
 
 
 
 
 
 
106
 
107
  # UI Setup
108
  with gr.Blocks() as demo:
@@ -117,41 +127,42 @@ with gr.Blocks() as demo:
117
  gr.Markdown("""
118
  <div style='font-size: 20px; color: #0057B8;'>
119
  ✨ <b>Welcome to Hilton’s Employee Retention Predictor</b><br>
120
- This tool helps <b>Sales & Marketing leaders and Front Office Operations teams</b>β€”The Face of Hiltonβ€”
121
- assess <b>team member engagement</b> and predict <b>turnover risk</b> using AI-powered insights.<br>
 
 
122
  πŸ” <b>Understand what drives retention and make data-driven decisions to keep top talent.</b>
123
  </div>
124
- """)
125
-
126
- # Dropdown for Employee Profiles
127
- profile_dropdown = gr.Dropdown(choices=list(employee_profiles.keys()), label="Select Employee Profile")
128
 
129
- # Sliders for input features
 
 
 
130
  with gr.Row():
131
- WellBeing = gr.Slider(label="WellBeing Score", minimum=1, maximum=5, value=4, step=0.1)
132
- SupportiveGM = gr.Slider(label="Supportive GM Score", minimum=1, maximum=5, value=4, step=0.1)
133
- Engagement = gr.Slider(label="Engagement Score", minimum=1, maximum=5, value=4, step=0.1)
134
  with gr.Row():
135
- Workload = gr.Slider(label="Workload Score", minimum=1, maximum=5, value=4, step=0.1)
136
- WorkEnvironment = gr.Slider(label="Work Environment Score", minimum=1, maximum=5, value=4, step=0.1)
137
- Merit = gr.Slider(label="Merit Score", minimum=1, maximum=5, value=4, step=0.1)
138
 
139
  submit_btn = gr.Button("πŸ”Ž Click Here to Analyze Retention")
140
 
141
- # Output elements
142
  prediction = gr.HTML()
 
 
143
  with gr.Row():
144
  prob_chart = gr.Image(label="Retention vs. Turnover Probability", type="filepath")
145
  shap_plot = gr.Image(label="Key Drivers of Turnover Risk", type="filepath")
146
 
147
- # Allow profile selection to update sliders
148
- def update_sliders(profile):
149
- if profile in employee_profiles:
150
- return employee_profiles[profile]
151
- return [4, 4, 4, 4, 4, 4]
152
-
153
  profile_dropdown.change(update_sliders, inputs=[profile_dropdown], outputs=[WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit])
154
 
155
- submit_btn.click(main_func, [WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit], [prediction, prob_chart, shap_plot])
 
156
 
157
- demo.launch()
 
 
 
1
  import gradio as gr
2
  import pickle
3
  import pandas as pd
 
5
  import matplotlib.pyplot as plt
6
 
7
  # Load model
8
+ filename = 'xgb_h_generation.pkl'
9
  with open(filename, 'rb') as f:
10
  loaded_model = pickle.load(f)
11
 
12
  # Setup SHAP
13
  explainer = shap.Explainer(loaded_model)
14
 
15
+ # Generation Mapping (Radio Button Labels β†’ Numeric Values)
16
+ generation_mapping = {
17
+ "Before 1927": 1,
18
+ "Silent Generation": 2,
19
+ "Baby Boomers": 3,
20
+ "Generation X": 4,
21
+ "Millennials": 5,
22
+ "Generation Z": 6
23
+ }
24
+
25
+ # Employee Profiles (Updated Default Dream Employee Values)
26
  employee_profiles = {
27
+ "πŸ₯‡ Default Dream Employee": [5.0, 5.0, 5.0, 4.8, 4.8, 4.9],
28
+ "πŸ† Leslie Knope": [4.716, 4.792, 4.864, 4.588, 4.849, 4.601],
29
+ "⚠️ Kevin Malone": [3.045, 3.122, 3.129, 2.886, 3.113, 2.197],
30
+ "🌱 Jim Halpert": [3.885, 3.992, 4.119, 3.704, 4.090, 3.377]
31
  }
32
 
33
  # Define the prediction function
34
+ def main_func(Generation_Label, WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit):
35
+ Generation = generation_mapping.get(Generation_Label, 5) # Convert label to numeric value
36
+
37
  new_row = pd.DataFrame({
38
+ 'Generation': [Generation],
39
  'WellBeing': [WellBeing],
40
  'SupportiveGM': [SupportiveGM],
41
  'Engagement': [Engagement],
 
52
  stay_prob = round((1 - float(prob[0][0])) * 100, 2)
53
  leave_prob = round(float(prob[0][0]) * 100, 2)
54
 
55
+ # Dynamic risk label
56
  risk_label = "πŸ”΄ High Risk of Turnover" if leave_prob > 50 else "🟒 Low Risk of Turnover"
57
  risk_color = "red" if leave_prob > 50 else "green"
58
 
59
  risk_html = f"""
60
+ <div style='border: 1px solid black; padding: 15px; border-radius: 8px; display: flex;'>
61
+ <div style='width: 50%; padding-right: 15px;'>
62
+ <span style='color: {risk_color}; font-size: 26px; font-weight: bold;'>{risk_label}</span>
63
+ <ul style='list-style-type: none; padding-left: 0; font-size: 20px; font-weight: bold; color: #0057B8;'>
64
+ <li>🧲 Likelihood of Staying: {stay_prob}%</li>
65
+ <li>πŸšͺ Likelihood of Leaving: {leave_prob}%</li>
66
+ </ul>
67
+ </div>
68
+ <div style='width: 50%; border-left: 1px solid black; padding-left: 15px;'>
69
+ <b style='color: #0057B8; font-size: 22px;'>Key Insights:</b>
70
  """
71
 
72
+ # Key Insights (excluding Generation)
73
+ shap_values_df = pd.DataFrame(shap_values.values, columns=new_row.columns)
74
+ shap_values_df = shap_values_df.drop(columns=["Generation"]) # Drop Generation
75
+ for feature in shap_values_df.columns:
76
+ shap_val = shap_values_df[feature].values[0]
77
+ impact = round(shap_val * 10, 2)
78
  icon = "πŸ“ˆ" if shap_val > 0 else "πŸ“‰"
79
  effect = "raises turnover risk" if shap_val > 0 else "improves retention"
80
+ risk_html += f"<p style='margin: 5px 0;'> {icon} <b>Each 0.1-point increase in {feature} {effect} by {abs(impact)}%.</b></p>"
81
+
82
+ risk_html += "</div></div>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Retention vs. Turnover Chart
85
  fig, ax = plt.subplots()
 
96
  plt.savefig(prob_chart_path, transparent=True)
97
  plt.close()
98
 
99
+ # SHAP Chart (excluding Generation)
100
  fig, ax = plt.subplots()
101
+ shap_values_filtered = shap_values[:, 1:] # Remove Generation from SHAP values
102
+ shap.plots.bar(shap_values_filtered[0], max_display=6, show=False) # Adjust max_display if needed
103
  ax.set_title("Key Drivers of Turnover Risk")
104
  plt.tight_layout()
105
  shap_plot_path = "shap_plot.png"
106
  plt.savefig(shap_plot_path, transparent=True)
107
  plt.close()
108
 
109
+ return risk_html, prob_chart_path, shap_plot_path
110
+
111
+ # Function to update sliders based on selected profile
112
+ def update_sliders(profile):
113
+ if profile in employee_profiles:
114
+ return employee_profiles[profile]
115
+ return [5.0, 5.0, 5.0, 4.8, 4.8, 4.9]
116
 
117
  # UI Setup
118
  with gr.Blocks() as demo:
 
127
  gr.Markdown("""
128
  <div style='font-size: 20px; color: #0057B8;'>
129
  ✨ <b>Welcome to Hilton’s Employee Retention Predictor</b><br>
130
+ This tool helps <b>HR and People Analytics professionals</b> assess
131
+ <b>Sales, Marketing, and Front Office Operations teams</b>β€”<span style='color: #0057B8;'>The Face of Hilton</span>β€”
132
+ by analyzing <b>team member engagement</b> and predicting <b>turnover risk</b> using
133
+ <span style='color: #0057B8;'>AI-powered insights</span>.<br>
134
  πŸ” <b>Understand what drives retention and make data-driven decisions to keep top talent.</b>
135
  </div>
136
+ """)
137
+
138
+ # Generation Filter (Radio Button - Independent)
139
+ generation_filter = gr.Radio(choices=list(generation_mapping.keys()), label="Select Generation", value="Millennials")
140
 
141
+ # Dropdown for Employee Profiles (Updates Sliders)
142
+ profile_dropdown = gr.Dropdown(choices=list(employee_profiles.keys()), label="Select Employee Profile", value="πŸ₯‡ Default Dream Employee")
143
+
144
+ # Sliders for input features (Updated Order)
145
  with gr.Row():
146
+ WellBeing = gr.Slider(label="WellBeing Score", minimum=1, maximum=5, value=5.0, step=0.1)
147
+ SupportiveGM = gr.Slider(label="Supportive GM Score", minimum=1, maximum=5, value=5.0, step=0.1)
148
+ Engagement = gr.Slider(label="Engagement Score", minimum=1, maximum=5, value=5.0, step=0.1)
149
  with gr.Row():
150
+ Workload = gr.Slider(label="Workload Score", minimum=1, maximum=5, value=4.8, step=0.1)
151
+ WorkEnvironment = gr.Slider(label="Work Environment Score", minimum=1, maximum=5, value=4.8, step=0.1)
152
+ Merit = gr.Slider(label="Merit Score", minimum=1, maximum=5, value=4.9, step=0.1)
153
 
154
  submit_btn = gr.Button("πŸ”Ž Click Here to Analyze Retention")
155
 
 
156
  prediction = gr.HTML()
157
+
158
+ # Charts Side by Side
159
  with gr.Row():
160
  prob_chart = gr.Image(label="Retention vs. Turnover Probability", type="filepath")
161
  shap_plot = gr.Image(label="Key Drivers of Turnover Risk", type="filepath")
162
 
 
 
 
 
 
 
163
  profile_dropdown.change(update_sliders, inputs=[profile_dropdown], outputs=[WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit])
164
 
165
+ submit_btn.click(main_func, [generation_filter, WellBeing, SupportiveGM, Engagement, Workload, WorkEnvironment, Merit],
166
+ [prediction, prob_chart, shap_plot])
167
 
168
+ demo.launch()