In this example, we apply a classification tree model to the Carseats dataset from the ISLR2 library. This dataset is a simulated collection of sales data for child car seats across 400 distinct stores, with the following variables:
Sales: (Response variable) Unit sales (in thousands) at each location.
CompPrice: Price charged by competitor at each location.
Income: Community income level (in thousands of dollars).
Advertising: Local advertising budget for company at each location (in thousands of dollars).
Population: Population size in region (in thousands).
Price: Price company charges for car seats at each site.
ShelveLoc: A factor with levels Bad, Good and Medium indicating the quality of the shelving location for the car seats at each site.
Age: Average age of the local population.
Education: Education level at each location.
Urban: A factor with levels No and Yes to indicate whether the store is in an urban or rural location.
US: A factor with levels No and Yes to indicate whether the store is in the US or not.
In this data, the response variable Sales is continuous. To fit a classification tree, we therefore recode it as a binary variable. We use the ifelse() function to create a variable called High, which takes the value Yes if Sales exceeds 8 (that is, more than 8,000 units) and No otherwise.
# Load the ISLR2 package, which provides the Carseats data
library(ISLR2)
# Recode Sales as a two-level factor: "No" if Sales <= 8, "Yes" otherwise
High <- factor(ifelse(Carseats$Sales <= 8, "No", "Yes"))
# Merge the High variable with the rest of the Carseats data
Carseats <- data.frame(Carseats, High)
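The recoding rule can be checked on a handful of values. The sketch below uses a small made-up vector (sales_demo is hypothetical and not part of Carseats) to show that a store with sales of exactly 8 falls in the No class, since only values strictly above 8 count as high.

```r
# Hypothetical sales figures (in thousands), chosen to straddle the cutoff of 8
sales_demo <- c(4.2, 8.0, 8.01, 11.5)

# Same rule as above: "No" when sales are at most 8, "Yes" otherwise
high_demo <- factor(ifelse(sales_demo <= 8, "No", "Yes"))

# Tabulate the two classes
table(high_demo)
```

Wrapping the result in factor() matters here: rpart() with method = "class" expects a categorical response, and the factor levels (No, Yes) determine the order of the class probabilities in the printed tree.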
In this example, we fit a classification tree on the training data. The caret package supplies createDataPartition() for the train/test split, as well as a train() function that can call rpart() from the rpart package and tune the model through cross-validation; the steps below call rpart() directly. For more information about these functions, type ?rpart and ?train.
Step 1: Create train/test set
First, we split the observations into a training set (say \(70\%\)) and a test set (say \(30\%\)).
# Create 70% training set and 30% test set (hold-out-validation)
library(caret) #Load the package
set.seed(123)
Index <- createDataPartition(Carseats[,"High"], p = 0.7, list = FALSE)
train.data <- Carseats[Index,] # 281 observations
test.data <- Carseats[-Index,] # 119 observations
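The split above is stratified on High, so both subsets keep roughly the same share of Yes and No stores. The following base-R sketch illustrates the idea only; it is not caret's implementation. The outcome vector y is hypothetical, built with the same class sizes as High (236 No, 164 Yes), and the rounding rule (ceiling) is an assumption of this sketch.

```r
set.seed(123)

# Hypothetical two-class outcome with the same class sizes as High
y <- factor(rep(c("No", "Yes"), times = c(236, 164)))

# Within each class, sample about 70% of the row indices for training
train_idx <- unlist(lapply(split(seq_along(y), y), function(rows) {
  sample(rows, size = ceiling(0.7 * length(rows)))
}))

# Class proportions in the full data, the training rows, and the held-out rows
prop.table(table(y))
prop.table(table(y[train_idx]))
prop.table(table(y[-train_idx]))
```

Sampling within each class, rather than from all rows at once, prevents an unlucky draw from leaving the test set with too few stores from one class.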
Step 2: Build the model
In this stage, we construct the tree using the training set. Following this, we generate predictions and assess its performance on the test data by computing the confusion matrix:
library(rpart)
fit <- rpart(High ~ . - Sales,
             data = train.data,
             method = "class")
pred <- predict(fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat
##      pred
##       No Yes
##   No  59  11
##   Yes 21  28
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7310924
In this table, rows are the actual classes and columns are the predicted classes. The model correctly identified 28 stores with high sales and correctly classified 59 stores as not high in sales. However, it misclassified 21 high-sales stores as not high and 11 not-high stores as high. The misclassification rate is \(\dfrac{21 + 11}{119}\approx 0.27\), meaning that the model’s accuracy stands at approximately \(73\%\).
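Accuracy alone hides how the errors are distributed between the two classes. As a small sketch, the confusion matrix printed above can be re-entered by hand to compute the per-class rates; the object names below are illustrative, and the counts are copied from the output rather than recomputed from the model.

```r
# Confusion matrix copied from the output above:
# rows are actual classes, columns are predicted classes
cm <- matrix(c(59, 21, 11, 28), nrow = 2,
             dimnames = list(actual = c("No", "Yes"), pred = c("No", "Yes")))

accuracy    <- sum(diag(cm)) / sum(cm)               # (59 + 28) / 119
error_rate  <- 1 - accuracy                          # (21 + 11) / 119
sensitivity <- cm["Yes", "Yes"] / sum(cm["Yes", ])   # high-sales stores found: 28 / 49
specificity <- cm["No", "No"] / sum(cm["No", ])      # not-high stores found: 59 / 70

round(c(accuracy = accuracy, error_rate = error_rate,
        sensitivity = sensitivity, specificity = specificity), 3)
```

The tree recovers about 84% of the not-high stores but only about 57% of the high-sales stores, so most of its errors are missed high-sales stores.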
Step 3: Visualize the model
Printing the fitted object lists every node and split as text. To draw the tree, you may use the rpart.plot() function in the rpart.plot library as follows:
fit
## n= 281
##
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
##
##  1) root 281 115 No (0.5907473 0.4092527)
##    2) ShelveLoc=Bad,Medium 218 65 No (0.7018349 0.2981651)
##      4) Price>=106.5 139 23 No (0.8345324 0.1654676)
##        8) Advertising< 13.5 113 12 No (0.8938053 0.1061947) *
##        9) Advertising>=13.5 26 11 No (0.5769231 0.4230769)
##         18) Price>=127 14 3 No (0.7857143 0.2142857) *
##         19) Price< 127 12 4 Yes (0.3333333 0.6666667) *
##      5) Price< 106.5 79 37 Yes (0.4683544 0.5316456)
##       10) Advertising< 7.5 40 13 No (0.6750000 0.3250000)
##         20) CompPrice< 125.5 30 6 No (0.8000000 0.2000000) *
##         21) CompPrice>=125.5 10 3 Yes (0.3000000 0.7000000) *
##       11) Advertising>=7.5 39 10 Yes (0.2564103 0.7435897)
##         22) Age>=71 9 3 No (0.6666667 0.3333333) *
##         23) Age< 71 30 4 Yes (0.1333333 0.8666667) *
##    3) ShelveLoc=Good 63 13 Yes (0.2063492 0.7936508) *
library(rpart.plot) #Plot 'rpart' Models
rpart.plot(fit, type = 2, digits = 3, fallen.leaves = FALSE)
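Each line of the printed tree gives the node number, the split, the number of training observations n, the number misclassified (loss), the predicted class, and the class probabilities. As a quick check on how to read these columns, the probabilities can be recomputed from n and loss. The sketch below does this for node 3 (ShelveLoc=Good), with the values typed in from the printout; the variable names are illustrative.

```r
# Values read from node 3 of the printed tree: ShelveLoc=Good, n = 63, loss = 13, yval = Yes
n_node    <- 63
loss_node <- 13

# loss counts the training stores in the node that are not in the predicted class (Yes),
# so loss / n is the share of No and 1 - loss / n is the share of Yes
p_no  <- loss_node / n_node
p_yes <- 1 - p_no

round(c(No = p_no, Yes = p_yes), 7)
```

The result agrees with the (0.2063492 0.7936508) shown for that node: among training stores with a good shelving location, about 79% have high sales.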
Step 4: Tune the hyper-parameters
The rpart() function for decision tree construction offers several parameters that govern different aspects of the model’s fitting process. These parameters can be managed and customized using the rpart.control() function.
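To see which settings are available and what their defaults are, it can help to inspect the list that rpart.control() returns. The short sketch below prints a few commonly tuned entries; the object names are illustrative.

```r
library(rpart)

# Default control settings (a plain list)
default_ctrl <- rpart.control()
default_ctrl[c("minsplit", "minbucket", "cp", "maxdepth", "xval")]

# Lowering minsplit also lowers minbucket, which defaults to round(minsplit / 3)
small_ctrl <- rpart.control(minsplit = 4)
small_ctrl[c("minsplit", "minbucket")]
```

By default a node needs at least 20 observations before a split is attempted, each leaf must hold at least 7, and a split must improve the fit by a factor of at least cp = 0.01 to be kept.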
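The cptable component of the fitted tree lists the candidate cost-complexity values (cp). The cp value with the smallest relative error is extracted below. Note that the rel error column is measured on the training data and keeps decreasing as the tree grows, so this rule always returns the smallest cp in the table, here the default value of 0.01; the cross-validated xerror column is the usual basis for choosing cp.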
(cp <- fit$cptable[which.min(fit$cptable[, "rel error"]), "CP"])
## [1] 0.01
control <- rpart.control(minsplit = 4, # minimum number of observations a node must have before a split is attempted
                         cp = cp)
tune_fit <- rpart(High ~ . - Sales,
                  data = train.data,
                  method = "class",
                  control = control)
pred <- predict(tune_fit, newdata = test.data, type = "class")
# Compute accuracy
table_mat <- table(test.data[, "High"], pred)
table_mat
##      pred
##       No Yes
##   No  59  11
##   Yes 18  31
sum(diag(table_mat))/sum(table_mat)
## [1] 0.7563025
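With these settings, the test accuracy rises to approximately \(76\%\) (90 of 119 stores), compared with \(73\%\) for the previous model. Since cp = 0.01 is also the rpart default, the gain comes from lowering minsplit from its default of 20 to 4.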
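A common alternative for choosing cp is to grow a deliberately large tree and then prune it back to the cp value with the lowest cross-validated error (the xerror column of cptable). The sketch below illustrates this on the kyphosis data that ships with rpart, so it runs without the Carseats objects; it is one possible workflow, not the procedure used above, and the object names are illustrative.

```r
library(rpart)

set.seed(123)  # the xerror column depends on random cross-validation folds

# cp = 0 grows a large tree so that the cp table has several rows
big_fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                 method = "class",
                 control = rpart.control(cp = 0, minsplit = 4))

cp_table <- big_fit$cptable

# "rel error" is computed on the training data and shrinks as the tree grows,
# so its minimum sits in the last row; "xerror" is the cross-validated estimate
cp_train <- cp_table[which.min(cp_table[, "rel error"]), "CP"]
cp_cv    <- cp_table[which.min(cp_table[, "xerror"]), "CP"]

# Prune the large tree back to the cross-validated choice
pruned_fit <- prune(big_fit, cp = cp_cv)
printcp(pruned_fit)
```

The same pattern could be applied to the Carseats fit by replacing the formula and data arguments; the selected cp will vary with the seed because the folds are random.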