In this example, we apply a classification tree model to the Carseats dataset, sourced from the ISLR2 library. This dataset is a simulated collection of sales data for child car seats across 400 distinct stores:

  • Sales: (Response variable) Unit sales (in thousands) at each location.

  • CompPrice: Price charged by competitor at each location.

  • Income: Community income level (in thousands of dollars).

  • Advertising: Local advertising budget for company at each location (in thousands of dollars).

  • Population: Population size in region (in thousands).

  • Price: Price company charges for car seats at each site.

  • ShelveLoc: A factor with levels Bad, Good and Medium indicating the quality of the shelving location for the car seats at each site.

  • Age: Average age of the local population.

  • Education: Education level at each location.

  • Urban: A factor with levels No and Yes to indicate whether the store is in an urban or rural location.

  • US: A factor with levels No and Yes to indicate whether the store is in the US or not.

In this data set, the response variable Sales is continuous. Thus, to use a classification tree, we recode it as a binary variable. We use the ifelse() function to create a variable, called High, which takes the value Yes if Sales exceeds 8 and the value No otherwise.

# Load the packages
library(ISLR2)   # Load the package ISLR2
attach(Carseats) # Attach the data set
High <- factor(ifelse(Sales <= 8, "No", "Yes"))
# Merge High variable with the rest of the Carseats data
Carseats <- data.frame(Carseats, High) 

In this example, we split the data with the caret package and fit a classification tree on the training data using the rpart() function from the rpart package. The caret package can also train the same model through cross-validation via its train() function, which calls rpart() internally; a sketch of such a call is given below. For more information about these functions, type ?rpart and ?train.
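
As a point of reference only, a cross-validated fit through caret's train() could look roughly like this; the 10-fold cross-validation settings and the tuneLength value are illustrative assumptions and not part of the original analysis.

# Illustrative sketch: cross-validated classification tree via caret::train()
library(caret)
set.seed(123)
cv_fit <- train(High ~ . - Sales,      # exclude Sales, since High is derived from it
                data = Carseats,
                method = "rpart",      # caret calls rpart() internally
                trControl = trainControl(method = "cv", number = 10), # assumed 10-fold CV
                tuneLength = 10)       # assumed: evaluate 10 candidate cp values
cv_fit$bestTune                        # cp value selected by cross-validation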

Step 1: Create train/test set

First, we split the observations into a training set (say \(70\%\)) and a test set (say \(30\%\)).

# Create 70% training set and 30% test set (hold-out-validation)
library(caret) #Load the package
set.seed(123)
Index <- createDataPartition(Carseats[,"High"], p = 0.7, list = FALSE)
train.data <- Carseats[Index,]  # 281 observations
test.data <- Carseats[-Index,]  # 119 observations 
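
Because createDataPartition() creates a stratified split on High, a quick check (not part of the original code) confirms that the proportion of Yes stores is similar in the two sets:

# Optional check: the class balance should be similar in both sets
prop.table(table(train.data$High))
prop.table(table(test.data$High))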

Step 2: Build the model

In this stage, we construct the tree using the training set. Following this, we generate predictions and assess its performance on the test data by computing the confusion matrix:

library(rpart) 
fit <- rpart(High ~ . - Sales, 
             data = train.data,
             method = "class")
pred <- predict(fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat 
##      pred
##       No Yes
##   No  59  11
##   Yes 21  28
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7310924

The model correctly predicted 28 child car seats with high sales and correctly classified 59 car seats as not high in sales. However, it also misclassified 21 car seats as not high when their sales were actually high, and 11 car seats as high when they were not. The misclassification rate is \(\dfrac{21 + 11}{119}\approx 0.27\), meaning that the model’s accuracy is approximately \(73\%\).
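
If you prefer a fuller summary, the confusionMatrix() function in caret reports the same table together with accuracy, sensitivity, and specificity in a single call; treating Yes as the positive class below is our own assumption, not part of the original code.

# Optional: fuller summary of the same predictions via caret
confusionMatrix(pred, test.data[, "High"], positive = "Yes")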

Step 3: Visualize the model

To examine the fitted model, you can first print its text representation and then visualize the tree with the rpart.plot() function in the rpart.plot library as follows:

fit
## n= 281 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 281 115 No (0.5907473 0.4092527)  
##    2) ShelveLoc=Bad,Medium 218  65 No (0.7018349 0.2981651)  
##      4) Price>=106.5 139  23 No (0.8345324 0.1654676)  
##        8) Advertising< 13.5 113  12 No (0.8938053 0.1061947) *
##        9) Advertising>=13.5 26  11 No (0.5769231 0.4230769)  
##         18) Price>=127 14   3 No (0.7857143 0.2142857) *
##         19) Price< 127 12   4 Yes (0.3333333 0.6666667) *
##      5) Price< 106.5 79  37 Yes (0.4683544 0.5316456)  
##       10) Advertising< 7.5 40  13 No (0.6750000 0.3250000)  
##         20) CompPrice< 125.5 30   6 No (0.8000000 0.2000000) *
##         21) CompPrice>=125.5 10   3 Yes (0.3000000 0.7000000) *
##       11) Advertising>=7.5 39  10 Yes (0.2564103 0.7435897)  
##         22) Age>=71 9   3 No (0.6666667 0.3333333) *
##         23) Age< 71 30   4 Yes (0.1333333 0.8666667) *
##    3) ShelveLoc=Good 63  13 Yes (0.2063492 0.7936508) *
library(rpart.plot) #Plot 'rpart' Models
rpart.plot(fit, type = 2, digits = 3, fallen.leaves = FALSE)
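
The type, digits, and fallen.leaves arguments control the layout of the plot. As an illustrative variation (not shown in the original), the extra = 104 option adds the fitted class probabilities and the percentage of observations in each node:

# Illustrative variation: display class probabilities and node percentages
rpart.plot(fit, type = 2, extra = 104, digits = 3, fallen.leaves = FALSE)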

Step 4: Tune the hyper-parameters

The rpart() function for decision tree construction offers several parameters that govern different aspects of the model’s fitting process. These parameters can be managed and customized using the rpart.control() function.
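
For reference, the sketch below lists some commonly adjusted rpart.control() arguments; the values shown are arbitrary illustrations rather than recommendations.

# Illustrative rpart.control() settings (values are examples only)
example_control <- rpart.control(minsplit = 20,  # min. observations in a node to attempt a split
                                 minbucket = 7,  # min. observations in any terminal node
                                 maxdepth = 5,   # max. depth of any node (root counts as depth 0)
                                 cp = 0.01)      # complexity threshold for attempting a split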

The cost complexity (cp) value associated with the minimum relative error is:

(cp <- fit$cptable[which.min(fit$cptable[, "rel error"]), "CP"])
## [1] 0.01
control <- rpart.control(minsplit = 4, # minimum number of observations in a node before the algorithm attempts a split
                         cp = cp)

tune_fit <- rpart(High ~ . - Sales, 
                  data = train.data,
                  method = "class",
                  control = control)

pred <- predict(tune_fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat 
##      pred
##       No Yes
##   No  59  11
##   Yes 18  31
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7563025

Using this approach, we achieve an accuracy of approximately \(76\%\), surpassing the performance of the previous model.
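
To see how the error changes across the whole sequence of cp values, you can also inspect the complexity table of the tuned tree with the standard rpart utilities below (not part of the original analysis):

# Optional inspection of the complexity-parameter table
printcp(tune_fit)  # cp values, number of splits, relative and cross-validated errors
plotcp(tune_fit)   # cross-validated error plotted against cp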