HR Analytics: Predicting Employee Churn with a Decision Tree
  • AI Chat
  • Code
  • Report
  • Beta
    Spinner

    HR Analytics: Predicting Employee Churn with a Decision Tree

    # Modules needed
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # Global seaborn theme in a single call: notebook context, white style,
    # and a default figure size of 10x5 inches
    sns.set(context="notebook", style="white", rc={"figure.figsize": (10, 5)})

    # Read the turnover CSV directly from its remote URL
    url = "https://assets.datacamp.com/production/repositories/1765/datasets/ae888d00f9b36dd7d50a4afbc112761e2db766d2/turnover.csv"
    data = pd.read_csv(url)

    # Show the first rows, then the (n_rows, n_columns) shape
    display(data.head(), data.shape)

    Data Validation

    # Data type of every column
    data.dtypes
    # Number of missing entries per column
    data.isna().sum()
    # Print each column with its number of distinct values (NaN counted as a
    # value); columns reporting 2 are the binary ones
    for column in data.columns:
        print(column, data[column].nunique(dropna=False))

    Exploratory Data Analysis

    # Descriptive statistics for every column except the binary ones
    binary_columns = ["work_accident", "churn", "promotion"]
    data.drop(columns=binary_columns).describe()
    # Relative frequency of each churn value, rounded to two decimals
    data["churn"].value_counts(normalize=True).round(2)
    # Bar chart of the mean churn per department, highest rate first
    dept_churn_rate = data.groupby("department")["churn"].mean()
    dept_churn_rate.sort_values(ascending=False).plot.bar(title="Churn Rate For Department")
    plt.xticks(rotation=45);
    # Department dimension: number of rows per department, largest first
    dep_dimension = data.department.value_counts()

    # Mean churn per department
    churn_mean = data.groupby("department").churn.mean()

    # Combine both into one frame, one row per department, in value_counts order.
    # The frame is assembled explicitly rather than via pd.merge on index arrays:
    # pandas >= 2.0 names the value_counts() result "count" instead of
    # "department", so the merged column could no longer be reached as
    # merged.department. Column names are kept as before for compatibility:
    #   key_0      -> department label
    #   department -> number of rows in that department
    #   churn      -> mean churn of that department
    merged = pd.DataFrame(
        {
            "key_0": dep_dimension.index.to_numpy(),
            "department": dep_dimension.to_numpy(),
            # reindex aligns the churn rates with the value_counts ordering
            "churn": churn_mean.reindex(dep_dimension.index).to_numpy(),
        }
    )
    display(merged.head())

    # Dual-axis chart: bars for department size (left axis) and a dashed red
    # line for the churn rate (right axis)
    fig, ax = plt.subplots()
    ax.bar(x=merged.key_0, height=merged.department)
    ax.set_ylabel("Dep. Dimension")
    # Rotate the existing tick labels instead of re-assigning them with
    # set_xticklabels, which should only be used after fixing tick positions
    ax.tick_params(axis="x", labelrotation=45)
    ax2 = ax.twinx()
    ax2.plot(merged.key_0, merged.churn, marker="o", linestyle="--", color="red", label="Churn rate")
    ax2.set_ylabel("Churn rate")
    # Bind legend and title to the twin axis explicitly rather than relying on
    # pyplot's implicit "current axes"
    ax2.legend(loc="upper right", prop={"size": 8})
    ax2.set_title("Dep. Dimension and Churn Rate");
    # Bar chart of the mean churn per salary category, highest rate first
    salary_churn_rate = data.groupby("salary")["churn"].mean()
    salary_churn_rate.sort_values(ascending=False).plot.bar(title="Churn Rate For Salary")
    # Keep tick labels horizontal (a 360-degree rotation is equivalent to 0)
    plt.xticks(rotation=0);
    # Pivot table of the churn rate by department (rows) and salary category
    # (columns), rounded to two decimals, with columns ordered high -> low
    pivot = (
        data.pivot_table(values="churn", index="department", columns="salary", aggfunc="mean")
        .round(2)
        .reindex(columns=["high", "medium", "low"])
    )

    # Light-to-dark colormap built from the base color #8B0001
    pal = sns.light_palette("#8B0001", as_cmap=True)

    # Annotated heatmap of the pivot table, drawn without a colorbar
    sns.heatmap(data=pivot, annot=True, cmap=pal, cbar=False)
    plt.title("Churn Rate for Department and Salary");