Beta
# HR Analytics: Predicting Employee Churn with a Decision Tree
# Modules needed
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Global plotting theme: "notebook" scaling, "white" axes style and a
# 10x5 inch default figure size, configured in a single call.
sns.set(
    context="notebook",
    style="white",
    rc={"figure.figsize": (10, 5)},
)
# Load the employee turnover dataset straight from its hosted CSV file
url = (
    "https://assets.datacamp.com/production/repositories/1765/datasets/"
    "ae888d00f9b36dd7d50a4afbc112761e2db766d2/turnover.csv"
)
data = pd.read_csv(url)

# Show the first rows, then the (rows, columns) shape of the table
for preview in (data.head(), data.shape):
    display(preview)
# Data Validation
# Check the data type of every column
data.dtypes

# Count missing values per column
data.isnull().sum()

# Print the number of distinct values per column; a column that reports
# 2 holds binary values. NaN is counted as a value (dropna=False), which
# matches the result of len(Series.unique()).
for column in data.columns:
    print(column, data[column].nunique(dropna=False))
# Exploratory Data Analysis
# Descriptive statistics for every column except the three flag columns
# (work_accident, churn, promotion), which are dropped first
data.drop(columns=["work_accident", "churn", "promotion"]).describe()

# Share of each churn value, rounded to two decimals
data["churn"].value_counts(normalize=True).round(2)
# Mean churn per department, highest first, drawn as a bar chart
churn_by_department = data.groupby("department")["churn"].mean()
churn_by_department = churn_by_department.sort_values(ascending=False)
churn_by_department.plot(kind="bar", title="Churn Rate For Department")
# Tilt the department names so they do not overlap
plt.xticks(rotation=45);
# Department size (number of rows per department), largest first
dep_dimension = data.department.value_counts()
# Churn rate per department
churn_mean = data.groupby("department").churn.mean()

# Combine both measures into one table with explicit column names.
# Merging the two Series on their index arrays relied on the name pandas
# assigns to the value_counts() result; pandas 2.0 changed that name to
# "count", so the `department` column used below would be missing.
# Building the frame directly keeps the same three columns on any version:
#   key_0      - department name (ordered as in dep_dimension)
#   department - department size
#   churn      - department churn rate
merged = pd.DataFrame(
    {
        "key_0": dep_dimension.index,
        "department": dep_dimension.to_numpy(),
        "churn": churn_mean.reindex(dep_dimension.index).to_numpy(),
    }
)
display(merged.head())

# Bars: department size on the left y-axis
fig, ax = plt.subplots()
ax.bar(x=merged.key_0, height=merged.department)
ax.set_ylabel("Dep. Dimension")
# Rotate the existing tick labels instead of overwriting them with
# set_xticklabels, which expects tick positions to be fixed beforehand.
ax.tick_params(axis="x", labelrotation=45)

# Line: churn rate on a secondary y-axis that shares the same x-axis
ax2 = ax.twinx()
ax2.plot(merged.key_0, merged.churn, marker="o", linestyle="--", color="red")
ax2.set_ylabel("Churn rate")
# Address the secondary axes explicitly rather than via pyplot's
# implicit "current axes".
ax2.legend(["Churn rate"], loc="upper right", prop={"size": 8})
ax2.set_title("Dep. Dimension and Churn Rate");
# Mean churn per salary category, highest first, drawn as a bar chart
churn_by_salary = data.groupby("salary")["churn"].mean()
churn_by_salary = churn_by_salary.sort_values(ascending=False)
churn_by_salary.plot(kind="bar", title="Churn Rate For Salary")
# A 360-degree rotation is a full turn, i.e. the labels are drawn horizontally
plt.xticks(rotation=360);
# Pivot table of the mean churn rate (two decimals) with one row per
# department and one column per salary category, columns ordered
# high -> medium -> low
pivot = (
    data.pivot_table(
        values="churn",
        index="department",
        columns="salary",
        aggfunc="mean",
    )
    .round(2)
    .reindex(columns=["high", "medium", "low"])
)

# Light-to-dark red colormap for the heatmap
pal = sns.light_palette("#8B0001", as_cmap=True)

# Annotated heatmap of the pivot table, without a colorbar
sns.heatmap(pivot, cmap=pal, annot=True, cbar=False)
plt.title("Churn Rate for Department and Salary");