In this example, we apply a classification tree model to the Carseats dataset, sourced from the ISLR2 library. This dataset is a simulated collection of sales data for child car seats across 400 distinct stores:
Sales: (Response variable) Unit sales (in thousands) at each location.
CompPrice: Price charged by competitor at each location.
Income: Community income level (in thousands of dollars).
Advertising: Local advertising budget for company at each location (in thousands of dollars).
Population: Population size in region (in thousands).
Price: Price company charges for car seats at each site.
ShelveLoc: A factor with levels Bad, Good and Medium indicating the quality of the shelving location for the car seats at each site.
Age: Average age of the local population.
Education: Education level at each location.
Urban: A factor with levels No and Yes to indicate whether the store is in an urban or rural location.
US: A factor with levels No and Yes to indicate whether the store is in the US or not.
In this dataset, the response variable Sales is continuous. Thus, to use a classification tree, we recode it as a binary variable. We use the ifelse() function to create a variable, called High, which takes on a value of Yes if the Sales variable exceeds 8, and a value of No otherwise.
# Load the packages
library(ISLR2) # Load the package ISLR2
attach(Carseats) # Attach the data set
High <- factor(ifelse(Sales <= 8, "No", "Yes"))
# Merge High variable with the rest of the Carseats data
Carseats <- data.frame(Carseats, High)
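Before splitting the data, it can be useful to check how balanced the two classes of the new response are. A minimal check, using only base R, is:
# Class counts and proportions for the new response High
table(Carseats$High)
prop.table(table(Carseats$High))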
Step 1: Create train/test set
First, we split the observations into a training set (say \(70\%\)) and a test set (say \(30\%\)). Here, we use the sample() function in the base library with its default argument replace = FALSE. For more information about this function, type ?sample.
# Create a 70% training set and a 30% test set (hold-out validation)
set.seed(123)
Index <- sample(1:nrow(Carseats), size = 0.7*nrow(Carseats))
train.data <- Carseats[Index,] # 280 observations
test.data <- Carseats[-Index,] # 120 observations
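As a quick sanity check on the split (nothing beyond base R assumed), the dimensions of the two sets can be verified directly:
# Confirm the 70/30 split sizes
dim(train.data) # 280 rows
dim(test.data)  # 120 rows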
Step 2: Build the model
In this stage, we construct the tree using the training set. We use the rpart() function in the rpart library. For more information about this function, type ?rpart. Following this, we generate predictions and assess the model's performance on the test data by computing the confusion matrix:
library(rpart)
## Warning: package 'rpart' was built under R version 4.3.3
fit <- rpart(High ~.-Sales, # Don't include the variable "Sales"
data = train.data,
method = "class") # classification trees
pred <- predict(fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat
## pred
## No Yes
## No 56 15
## Yes 19 30
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7166667
The model correctly predicted 30 child car seats with high sales and correctly classified 56 car seats as not high in sales. However, it misclassified 19 car seats as not high sales when they were actually high, and 15 as high sales when they were not. The misclassification rate is \(\dfrac{15 + 19}{120}\approx 0.28\), meaning that the model's accuracy stands at approximately \(72\%\).
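These quantities can also be read off the confusion matrix in code. The short sketch below reuses the table_mat object computed above (rows are the observed classes, columns the predicted classes); the per-class rates are one possible way to summarize the two error types:
# Misclassification rate: off-diagonal counts over the total
1 - sum(diag(table_mat))/sum(table_mat)          # (15 + 19)/120 = 0.2833
# Sensitivity: proportion of truly high-sales stores predicted as "Yes"
table_mat["Yes", "Yes"]/sum(table_mat["Yes", ])  # 30/49
# Specificity: proportion of truly not-high stores predicted as "No"
table_mat["No", "No"]/sum(table_mat["No", ])     # 56/71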
Step 3: Visualize the model
To visualize the model output, you may use the rpart.plot() function in the rpart.plot library as follows:
fit
## n= 280
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 280 115 No (0.58928571 0.41071429)
## 2) ShelveLoc=Bad,Medium 222 70 No (0.68468468 0.31531532)
## 4) Price>=105.5 150 32 No (0.78666667 0.21333333)
## 8) Advertising< 13.5 123 18 No (0.85365854 0.14634146)
## 16) CompPrice< 143.5 101 9 No (0.91089109 0.08910891) *
## 17) CompPrice>=143.5 22 9 No (0.59090909 0.40909091)
## 34) ShelveLoc=Bad 7 0 No (1.00000000 0.00000000) *
## 35) ShelveLoc=Medium 15 6 Yes (0.40000000 0.60000000) *
## 9) Advertising>=13.5 27 13 Yes (0.48148148 0.51851852)
## 18) Age>=44.5 16 5 No (0.68750000 0.31250000) *
## 19) Age< 44.5 11 2 Yes (0.18181818 0.81818182) *
## 5) Price< 105.5 72 34 Yes (0.47222222 0.52777778)
## 10) Advertising< 8.5 44 16 No (0.63636364 0.36363636)
## 20) Education< 16.5 30 6 No (0.80000000 0.20000000) *
## 21) Education>=16.5 14 4 Yes (0.28571429 0.71428571) *
## 11) Advertising>=8.5 28 6 Yes (0.21428571 0.78571429)
## 22) Age>=71 7 2 No (0.71428571 0.28571429) *
## 23) Age< 71 21 1 Yes (0.04761905 0.95238095) *
## 3) ShelveLoc=Good 58 13 Yes (0.22413793 0.77586207)
## 6) Price>=144 7 1 No (0.85714286 0.14285714) *
## 7) Price< 144 51 7 Yes (0.13725490 0.86274510) *
library(rpart.plot) #Plot 'rpart' Models
## Warning: package 'rpart.plot' was built under R version 4.3.3
rpart.plot(fit, type = 2, digits = 3, fallen.leaves = FALSE)
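The appearance of the plot can be adjusted through the arguments of rpart.plot(); for instance, the extra argument controls what is printed inside each node. A possible variant (extra = 104 adds the class probabilities and the percentage of observations falling in each node) is:
# Show predicted class, class probabilities and node percentages
rpart.plot(fit, type = 2, extra = 104, digits = 3, fallen.leaves = FALSE)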
Step 4: Tune the hyper-parameters
The rpart() function for decision tree construction offers several parameters that govern different aspects of the model’s fitting process. These parameters can be managed and customized using the rpart.control() function.
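The tuning below adjusts only minsplit and cp, but other settings can be controlled in the same way. The following sketch illustrates a control object with a few commonly adjusted parameters; the values shown are purely illustrative, not recommendations:
# Illustrative rpart.control() settings
ctrl <- rpart.control(minsplit = 4,  # minimum observations in a node before a split is attempted
                      minbucket = 2, # minimum observations allowed in any terminal node
                      maxdepth = 10, # maximum depth of any node in the tree
                      cp = 0.01)     # complexity parameter: minimum improvement required for a split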
The cost-complexity value (cp) associated with the minimum relative error is:
(cp <- fit$cptable[which.min(fit$cptable[, "rel error"]), "CP"])
## [1] 0.01
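The full complexity-parameter table, including the cross-validated error (xerror), can also be inspected directly; one way is:
# Display the cp table and plot the cross-validated error against cp
printcp(fit)
plotcp(fit)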
control <- rpart.control(minsplit = 4, # minimum number of observations in a node before the algorithm performs a split
                         cp = cp)
tune_fit <- rpart(High ~.-Sales,
data = train.data,
method = "class",
control = control)
pred <- predict(tune_fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat
## pred
## No Yes
## No 62 9
## Yes 16 33
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7916667
Using this approach, we achieve a success rate of approximately \(79\%\), surpassing the performance of the previous model.