In this example, we apply a classification tree model to the Carseats dataset, sourced from the ISLR2 library. This dataset is a simulated collection of sales data for child car seats across 400 distinct stores:

  • Sales: (Response variable) Unit sales (in thousands) at each location.

  • CompPrice: Price charged by competitor at each location.

  • Income: Community income level (in thousands of dollars).

  • Advertising: Local advertising budget for company at each location (in thousands of dollars).

  • Population: Population size in region (in thousands).

  • Price: Price company charges for car seats at each site.

  • ShelveLoc: A factor with levels Bad, Good and Medium indicating the quality of the shelving location for the car seats at each site.

  • Age: Average age of the local population.

  • Education: Education level at each location.

  • Urban: A factor with levels No and Yes to indicate whether the store is in an urban or rural location.

  • US: A factor with levels No and Yes to indicate whether the store is in the US or not.

In this data, the response variable Sales is continuous. Thus, to fit a classification tree, we first recode it as a binary variable. We use the ifelse() function to create a variable called High, which takes the value Yes if Sales exceeds 8 and No otherwise.

# Load the packages
library(ISLR2)   # Load the package ISLR2
attach(Carseats) # Attach the data set
High <- factor(ifelse(Sales <= 8, "No", "Yes"))
# Merge High variable with the rest of the Carseats data
Carseats <- data.frame(Carseats, High) 
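
As a quick check (not part of the required workflow), you can tabulate the new variable to see how many stores fall into each class:

# Optional: class distribution of the new binary response
table(Carseats$High)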

Step 1: Create train/test set

First, we split the observations into a training set (say \(70\%\)) and a test set (say \(30\%\)). Here, we use the sample() function from base R with its default argument replace = FALSE. For more information about this function, type ?sample.

# Create 70% training set and 30% test set (hold-out-validation)
set.seed(123)
Index <- sample(1:nrow(Carseats), size = 0.7*nrow(Carseats))
train.data <- Carseats[Index,]  # 280 observations
test.data <- Carseats[-Index,]  # 120 observations 
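
Because the split is random rather than stratified, it is worth confirming that both subsets contain a reasonable mix of Yes and No stores. A minimal check:

# Optional: compare class proportions in the training and test sets
prop.table(table(train.data$High))
prop.table(table(test.data$High))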

Step 2: Build the model

In this step, we grow the tree on the training set using the rpart() function from the rpart library. For more information about this function, type ?rpart. We then generate predictions on the test data and assess the model's performance by computing the confusion matrix:

library(rpart) 
fit <- rpart(High ~.-Sales,  # Don't include the variable "Sales"
             data = train.data,
             method = "class") # classification trees
pred <- predict(fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat 
##      pred
##       No Yes
##   No  56  15
##   Yes 19  30
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7166667

The model correctly predicted 30 car seats with high sales and correctly classified 56 car seats as not having high sales. However, it misclassified 15 car seats as high sales when they were not, and 19 as not high sales when they actually were. The misclassification rate is \(\dfrac{15 + 19}{120}\approx 0.28\), so the model's accuracy is approximately \(72\%\).
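
Equivalently, the misclassification rate can be computed directly from the confusion matrix as one minus the accuracy obtained above:

# Misclassification (error) rate = 1 - accuracy
1 - sum(diag(table_mat))/sum(table_mat)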

Step 3: Visualize the model

Printing the fitted object displays the tree structure in text form. For a graphical display, you may use the rpart.plot() function in the rpart.plot library as follows:

fit
## n= 280 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 280 115 No (0.58928571 0.41071429)  
##    2) ShelveLoc=Bad,Medium 222  70 No (0.68468468 0.31531532)  
##      4) Price>=105.5 150  32 No (0.78666667 0.21333333)  
##        8) Advertising< 13.5 123  18 No (0.85365854 0.14634146)  
##         16) CompPrice< 143.5 101   9 No (0.91089109 0.08910891) *
##         17) CompPrice>=143.5 22   9 No (0.59090909 0.40909091)  
##           34) ShelveLoc=Bad 7   0 No (1.00000000 0.00000000) *
##           35) ShelveLoc=Medium 15   6 Yes (0.40000000 0.60000000) *
##        9) Advertising>=13.5 27  13 Yes (0.48148148 0.51851852)  
##         18) Age>=44.5 16   5 No (0.68750000 0.31250000) *
##         19) Age< 44.5 11   2 Yes (0.18181818 0.81818182) *
##      5) Price< 105.5 72  34 Yes (0.47222222 0.52777778)  
##       10) Advertising< 8.5 44  16 No (0.63636364 0.36363636)  
##         20) Education< 16.5 30   6 No (0.80000000 0.20000000) *
##         21) Education>=16.5 14   4 Yes (0.28571429 0.71428571) *
##       11) Advertising>=8.5 28   6 Yes (0.21428571 0.78571429)  
##         22) Age>=71 7   2 No (0.71428571 0.28571429) *
##         23) Age< 71 21   1 Yes (0.04761905 0.95238095) *
##    3) ShelveLoc=Good 58  13 Yes (0.22413793 0.77586207)  
##      6) Price>=144 7   1 No (0.85714286 0.14285714) *
##      7) Price< 144 51   7 Yes (0.13725490 0.86274510) *
library(rpart.plot) #Plot 'rpart' Models
rpart.plot(fit, type = 2, digits = 3, fallen.leaves = FALSE)
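
The type and digits arguments control how the nodes are labelled and rounded. If you want more detail in each node, the extra argument of rpart.plot() can be used; as a sketch, extra = 104 is one setting that adds the per-class probabilities and the percentage of observations in each node:

# Sketch: a more detailed plot (extra = 104 adds class probabilities
# and the percentage of observations falling in each node)
rpart.plot(fit, type = 2, extra = 104, digits = 3, fallen.leaves = FALSE)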

Step 4: Tune the hyper-parameters

The rpart() function for decision tree construction offers several parameters that govern different aspects of the model’s fitting process. These parameters can be managed and customized using the rpart.control() function.
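
Before adjusting these settings, it can help to inspect the complexity parameter (cp) table of the tree fitted in Step 2, which reports the error at each candidate cp value. A minimal way to print it:

# Display the complexity parameter table of the fitted tree
printcp(fit)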

We then extract the cost-complexity parameter (cp) value associated with the smallest relative error in the model's cp table:

(cp <- fit$cptable[which.min(fit$cptable[, "rel error"]), "CP"])
## [1] 0.01
control <- rpart.control(minsplit = 4, # minimum number of observations that must exist in a node for a split to be attempted
                         cp = cp)

tune_fit <- rpart(High ~.-Sales, 
             data = train.data,
             method = "class",
             control = control)

pred <- predict(tune_fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat 
##      pred
##       No Yes
##   No  62   9
##   Yes 16  33
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7916667

With this approach, the model achieves an accuracy of approximately \(79\%\), an improvement over the previous model.
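
Finally, the tuned tree can be visualized in the same way as in Step 3; a minimal sketch:

# Visualize the tuned tree with the same plotting call as before
rpart.plot(tune_fit, type = 2, digits = 3, fallen.leaves = FALSE)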