# jaleesahmed — app (revision 72ac741)
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import LabelEncoder
import seaborn as sns
import gradio as gr
# Use the non-interactive Agg backend: figures are returned to Gradio, never shown in a window.
plt.switch_backend('Agg')
# Raise pandas' console display limits for printed DataFrames.
pd.set_option('display.max_columns', 25)
pd.set_option('display.max_rows', 300)
def outbreak(plot_type):
df = pd.read_csv('emp_experience_data.csv')
data_encoded = df.copy(deep=True)
categorical_column = ['Attrition', 'Gender', 'BusinessTravel', 'Education', 'EmployeeExperience', 'EmployeeFeedbackSentiments', 'Designation',
'SalarySatisfaction', 'HealthBenefitsSatisfaction', 'UHGDiscountProgramUsage', 'HealthConscious', 'CareerPathSatisfaction', 'Region']
label_encoding = LabelEncoder()
for col in categorical_column:
data_encoded[col] = label_encoding.fit_transform(data_encoded[col])
if plot_type == "Find Data Correlation":
fig = plt.figure()
data_correlation = data_encoded.corr()
sns.heatmap(data_correlation, xticklabels = data_correlation.columns, yticklabels = data_correlation.columns)
return fig
if plot_type == "Filter Correlation Data":
fig = plt.figure()
filtered_df = df[['EmployeeExperience', 'EmployeeFeedbackSentiments', 'Age', 'SalarySatisfaction', 'BusinessTravel', 'HealthBenefitsSatisfaction']]
correlation_filter_data = filtered_df.corr()
sns.heatmap(correlation_filter_data, xticklabels = filtered_df.columns, yticklabels = filtered_df.columns)
return fig
if plot_type == "Age vs Attrition":
fig = plt.figure()
plt.hist(data_encoded['Age'], bins=np.arange(0,80,10), alpha=0.8, rwidth=0.9, color='red')
plt.xlabel("Age")
plt.ylabel("Count")
plt.title("Age vs Attrition")
return fig
if plot_type == "Business Travel vs Attrition":
fig = plt.figure()
ax = sns.countplot(x="BusinessTravel", hue="Attrition", data=data_encoded)
for p in ax.patches:
ax.annotate('{}'.format(p.get_height()), (p.get_x(), p.get_height()+1))
return fig
if plot_type == "Employee Experience vs Attrition":
fig = plt.figure()
ax = sns.countplot(x="EmployeeExperience", hue="Attrition", data=data_encoded)
for p in ax.patches:
ax.annotate('{}'.format(p.get_height()), (p.get_x(), p.get_height()+1))
return fig
# A single dropdown selects which chart outbreak() renders.
inputs = [
    gr.Dropdown(
        [
            "Find Data Correlation",
            "Filter Correlation Data",
            "Business Travel vs Attrition",
            "Employee Experience vs Attrition",
            "Age vs Attrition",
        ],
        label="Data Correlation and Visualization",
    )
]
# outbreak() returns a matplotlib Figure, which gr.Plot displays.
outputs = gr.Plot()
demo = gr.Interface(
    fn=outbreak,
    inputs=inputs,
    outputs=outputs,
    title="Employee-Experience: Data Correlation and Pattern Visualization",
    # Interface documents allow_flagging as a string ('never', 'auto', 'manual');
    # 'never' hides the Flag button, which is what False was meant to do.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()