In this example, we apply bagging and random forests to the Boston dataset from the ISLR2 library. The data has 506 rows and 13 variables:
crim: per capita crime rate by town.
zn: proportion of residential land zoned for lots over 25,000 sq.ft.
indus: proportion of non-retail business acres per town.
chas: Charles River dummy variable (= 1 if tract bounds river; 0 otherwise).
nox: nitrogen oxides concentration (parts per 10 million).
rm: average number of rooms per dwelling.
age: proportion of owner-occupied units built prior to 1940.
dis: weighted mean of distances to five Boston employment centres.
rad: index of accessibility to radial highways.
tax: full-value property-tax rate per $10,000.
ptratio: pupil-teacher ratio by town.
lstat: lower status of the population (percent).
medv: (Response variable) median value of owner-occupied homes in $1000s.
Load the required libraries and create training/test sets.
# Load the required packages
library(ISLR2) # Introduction to Statistical Learning, Second Edition
library(caret) # Classification And REgression Training
library(rpart.plot) # Plot 'rpart' Models
# Create 50% training set and 50% test set (hold-out-validation)
set.seed(1)
Index <- sample(1:nrow(Boston), nrow(Boston)/2)
# Alternatively, we may use the createDataPartition() function as below:
# Index <- createDataPartition(Boston[,"medv"], p = 0.5, list = FALSE)
train.data <- Boston[Index,] # 253 observations
test.data <- Boston[-Index,] # 253 observations
Fit the tree to the training data using the train() function in the caret library. The caret package calls the rpart() function from the rpart package and trains the model through cross-validation. For more information about these functions, type ?rpart and ?train. The important arguments of the train() function are given below:
form: a formula that links the response (target) variable to the independent (features) variables.
data: the data to be used for modeling.
trControl: specifies the control parameters for the resampling method and other settings used when training a machine learning model. It is important when performing cross-validation or bootstrapping to assess the model's performance. Its method sub-argument selects the resampling scheme: "cv" for k-fold cross-validation, "repeatedcv" for repeated k-fold cross-validation, "boot" for bootstrap resampling, and more; the number sub-argument sets the number of folds or resamples.
method: defines the algorithm specifying which classification or regression model to use. In the case of decision trees we use the method "rpart2", which tunes the model complexity through the maximum tree depth (maxdepth); the alternative method "rpart" tunes the complexity parameter cp instead.
preProcess: an optional argument used to mean-center and standardize the predictor variables by typing preProcess = c("center", "scale").
tuneLength: specifies the number of candidate values to consider during the hyperparameter tuning process. With method = "rpart2", it applies to the maxdepth parameter, which controls the complexity of the tree.
tuneGrid: defines an explicit grid of values for the tuning parameter, e.g., expand.grid(maxdepth = 1:10).
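As an illustration of the trControl argument, a repeated cross-validation setup (with illustrative settings; we use plain 5-fold CV below) might look like:

```r
library(caret)

# 5-fold cross-validation repeated 3 times (illustrative settings)
ctrl.rep <- trainControl(method = "repeatedcv", number = 5, repeats = 3)
```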
ctrl <- trainControl(method = "cv", number = 5)
fit1 <- train(medv ~.,
data = train.data,
method ="rpart2",
trControl = ctrl,
preProcess = c("center", "scale"),
#tuneGrid = expand.grid(maxdepth = 1:10), # explicit grid (alternative to tuneLength)
tuneLength = 10 # consider 10 candidate values of maxdepth
)
fit1
## CART
##
## 253 samples
## 12 predictor
##
## Pre-processing: centered (12), scaled (12)
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 201, 203, 202, 203, 203
## Resampling results across tuning parameters:
##
## maxdepth RMSE Rsquared MAE
## 1 6.075240 0.5164832 4.741275
## 2 4.865617 0.7042105 3.727730
## 3 4.544070 0.7259526 3.457168
## 4 4.155026 0.7778254 3.106068
## 5 3.864853 0.8089666 2.904045
## 6 3.863746 0.8086280 2.876111
## 7 3.836177 0.8112678 2.830364
## 8 3.836177 0.8112678 2.830364
## 9 3.836177 0.8112678 2.830364
## 10 3.836177 0.8112678 2.830364
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was maxdepth = 7.
plot(fit1)
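We can also confirm the depth selected by cross-validation programmatically; the train() object stores it in its bestTune component:

```r
# maxdepth value chosen by 5-fold cross-validation (smallest RMSE)
fit1$bestTune
```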
After examining the output and the fitted model plot, it becomes evident that the root mean squared error (RMSE) exhibits minimal variation once the maximum tree depth exceeds 7. Consequently, we set the maximum depth to 7 and retrain the model using the rpart() function from the rpart library as shown below:
fit2 <- rpart(medv ~.,
train.data,
method = "anova", #default method
maxdepth = 7)
fit2
## n= 253
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 253 19447.8700 21.78656
## 2) rm< 6.9595 222 6794.2920 19.35360
## 4) lstat>=14.405 87 1553.7870 14.46092
## 8) crim>=11.48635 26 302.7138 10.31538 *
## 9) crim< 11.48635 61 613.8026 16.22787
## 18) age>=93.95 31 164.1239 14.42903 *
## 19) age< 93.95 30 245.7147 18.08667 *
## 5) lstat< 14.405 135 1815.7240 22.50667
## 10) rm< 6.543 111 763.1337 21.37748 *
## 11) rm>=6.543 24 256.4696 27.72917 *
## 3) rm>=6.9595 31 1928.9870 39.20968
## 6) rm< 7.553 16 505.4900 33.42500 *
## 7) rm>=7.553 15 317.0040 45.38000 *
Note that the method argument is used to define the algorithm that we use to fit the model. It can be one of "anova", "poisson", "class", or "exp". In the case of regression trees, where the target variable is numeric, we use the default method "anova".
Notice that the output of the fitted model indicates that only four of the variables have been used in constructing the tree (rm, lstat, age, and crim). In the context of a regression tree, the deviance is simply the sum of squared errors (SSE) for the tree.
The output of the fitted model shows the steps of the tree's splits (root, branch, leaf). For example, we start with 253 observations at the root node and split the data first on the rm variable (first branch). That is, out of all the features, rm is the most predictive variable, optimizing the reduction in SSE.
We see that 222 observations with rm less than 6.9595 go to the left-hand side of the second branch (denoted by 2)) and 31 observations with rm greater than or equal to 6.9595 go to the right-hand side (denoted by 3)). The asterisks (*) indicate the leaf nodes associated with prediction values.
For example, out of the 31 observations where \(\text{rm} \geq 6.9595\), 16 observations follow the terminal node (leaf) \(6.9595 \leq \text{rm} < 7.553\), with a predicted median value of owner-occupied homes (on average) of \(\$33,425\); the SSE of these observations is 505.49. On the other hand, 15 observations follow the leaf \(\text{rm} \geq 7.553\), with a predicted median value of owner-occupied homes (on average) of \(\$45,380\); the SSE of these observations is 317.00.
For the case where we split \(\text{rm} < 6.9595\), we have another split. We split on lstat so that observations greater than or equal to 14.405 (87 observations) go to the left-hand side of a sub-branch and the others (135 observations) go to the other side. We continue in this manner, splitting again on crim. We predict (on average) the median value of owner-occupied homes of those with \(\text{lstat}\geq 14.405\) and \(\text{crim}\geq 11.48635\) to be \(\$10,315\). On the other hand, we predict the median value of owner-occupied homes of those with \(\text{lstat}\geq 14.405\), \(\text{crim}<11.48635\), and \(\text{age}\geq 93.95\) to be \(\$14,429\), and of those with \(\text{lstat}\geq 14.405\), \(\text{crim}<11.48635\), and \(\text{age}<93.95\) to be \(\$18,087\).
Similarly, we predict (on average) the median value of owner-occupied homes of those with \(\text{lstat} < 14.405\) and \(\text{rm}<6.543\) to be \(\$21,377\), whereas the predicted median value of owner-occupied homes (on average) of those with \(\text{lstat} < 14.405\) and \(\text{rm}\geq 6.543\) is \(\$27,729\).
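As a quick sanity check (using the train.data split created above), recall that an "anova" tree predicts each leaf by the mean response of the training observations falling in it, so we can reproduce a leaf prediction directly:

```r
# Leaf 7): training observations with rm >= 7.553.
# The mean of medv over these observations should match the yval
# reported for that leaf in the fit2 output (45.38, i.e., $45,380)
mean(train.data$medv[train.data$rm >= 7.553])
```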
We can easily visualize the model output by plotting the tree using the rpart.plot() function in the rpart.plot library using the following code.
rpart.plot(fit2, type = 2, digits = 3, fallen.leaves = FALSE)
Note that the rpart() function automatically applies a range of cost complexity values (the tuning parameter \(\alpha\)) to find the pruned tree with the lowest error. It performs a 10-fold cross-validation, comparing the error associated with each \(\alpha\) on the held-out validation data. This process helps identify the optimal \(\alpha\) and, consequently, the most suitable subtree to prevent overfitting the data. To gain insight into the internal process, we can use the printcp() function as demonstrated below. For instance, with no splits the cross-validation error (xerror) is 1.01828, whereas with 6 splits this error decreases to 0.23986.
# Cost complexity pruning
printcp(fit2)
##
## Regression tree:
## rpart(formula = medv ~ ., data = train.data, method = "anova",
## maxdepth = 7)
##
## Variables actually used in tree construction:
## [1] age crim lstat rm
##
## Root node error: 19448/253 = 76.869
##
## n= 253
##
## CP nsplit rel error xerror xstd
## 1 0.551453 0 1.00000 1.01828 0.128244
## 2 0.176101 1 0.44855 0.51236 0.059912
## 3 0.056895 2 0.27245 0.33807 0.052980
## 4 0.040936 3 0.21555 0.29998 0.041969
## 5 0.032768 4 0.17461 0.26613 0.040627
## 6 0.010488 5 0.14185 0.23932 0.039977
## 7 0.010000 6 0.13136 0.23986 0.040027
# Store the complexity parameter (CP) table of the fitted tree
cp <- fit2$cptable
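We can also pick a pruning value programmatically by taking the row of the CP table with the smallest cross-validation error:

```r
# CP value whose row has the smallest cross-validation error (xerror)
best.cp <- cp[which.min(cp[, "xerror"]), "CP"]
best.cp
```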
We may also use the plotcp() function to create a graph displaying the cross-validation errors (y-axis) against the cost complexity (CP) values (x-axis). In this example, we observe diminishing returns in the cross-validation error after reaching 7 terminal nodes (tree size \(= |T|\)). It is worth noting that the dashed line intersects the curve where \(|T| = 5\). Therefore, we can explore the potential for improved prediction accuracy by employing a pruned tree with 5 terminal nodes.
plotcp(fit2)
prune.fit <- prune(fit2, cp = cp[5, "CP"]) # CP value corresponding to the 5-leaf tree
rpart.plot(prune.fit, type = 2, digits = 3, fallen.leaves = FALSE)
Recall that a smaller mean squared error (MSE) indicates a better model. Therefore, we can assess the model's performance by computing the MSE using the test data:
# predictions on the test set using the unpruned tree
yhat <- predict(fit2, newdata = test.data)
# Mean squared error
mean((yhat - test.data$medv)^2)
## [1] 35.28688
In this example, the mean squared error is 35.29. Note that we don't know whether this value is small or large! We require an additional model to facilitate the comparison of MSE values across models.
To assess whether pruning the tree can enhance the performance of this model, we compute the MSE for the pruned tree model using the following approach:
# predictions on the test set using the pruned tree
yhat <- predict(prune.fit, newdata = test.data)
# Mean squared error
mean((yhat - test.data$medv)^2)
## [1] 35.90102
The test MSE associated with the pruned tree is 35.90, which is larger than the one obtained from the unpruned tree (35.29). Thus, we may conclude that the pruned tree model has slightly lower prediction accuracy.
Here we apply bagging and random forests, using the randomForest() function in the randomForest package. The argument mtry = 12 in the randomForest() function indicates that all 12 predictors should be considered for each split of the tree (i.e., bagging is used). The argument importance = TRUE indicates that the importance of predictors should be assessed.
library(randomForest)
set.seed(1)
bag.boston <- randomForest(medv ~ .,
data = Boston,
subset = Index,
mtry = 12, #all 12 predictors should be considered for bagging
importance = TRUE)
bag.boston
##
## Call:
## randomForest(formula = medv ~ ., data = Boston, mtry = 12, importance = TRUE, subset = Index)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 12
##
## Mean of squared residuals: 11.40162
## % Var explained: 85.17
To measure the performance of the bagged regression tree on the test set, we plot the predicted values versus the actual values. The plot helps us assess how well the model predicts the actual values. If the points on the plot are close to the 45-degree line (the line where predicted values equal true values), it suggests that the model is a good fit. If the points deviate significantly from this line, it indicates prediction errors.
yhat.bag <- predict(bag.boston, newdata = test.data)
plot(yhat.bag, test.data$medv)
abline(0, 1)
mean((yhat.bag - test.data$medv)^2)
## [1] 23.41916
The test set MSE associated with the bagged regression tree is 23.42, which is less than the one obtained from the unpruned tree (35.29). This indicates that the bagging algorithm for the regression tree yielded an improvement over the unpruned tree. We could change the number of trees grown by randomForest() using the ntree argument:
bag.boston <- randomForest(medv ~ .,
data = Boston,
subset = Index,
mtry = 12, #all 12 predictors should be considered for bagging
ntree = 25)
yhat.bag <- predict(bag.boston, newdata = test.data)
mean((yhat.bag - test.data$medv)^2)
## [1] 25.75055
Growing a random forest proceeds in exactly the same way as applying the bagged model, except that we use a smaller value of the mtry argument. By default, randomForest() uses \(p/3\) variables when building a random forest of regression trees, and \(\sqrt{p}\) variables when building a random forest of classification trees. Here we use mtry = 6.
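For reference, with \(p = 12\) predictors these defaults work out as follows (randomForest actually uses max(floor(p/3), 1) for regression and floor(sqrt(p)) for classification):

```r
p <- 12
max(floor(p / 3), 1) # default mtry for regression trees: 4
floor(sqrt(p)) # default mtry for classification trees: 3
```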
set.seed(1)
rf.boston <- randomForest(medv ~ .,
data = Boston,
subset = Index,
mtry = 6, #subset of 6 predictors - random forest algorithm
importance = TRUE)
yhat.rf <- predict(rf.boston, newdata = test.data)
mean((yhat.rf - test.data$medv)^2)
## [1] 20.06644
The test set MSE is 20.07; this indicates that random forests yielded an improvement over bagging in this case.
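To put the results side by side, we can recompute the test MSEs of the models fitted above (note that bag.boston now refers to the 25-tree bagged model from the previous step):

```r
# Test-set MSEs of the fitted models, smallest first
mse <- c(unpruned = mean((predict(fit2, newdata = test.data) - test.data$medv)^2),
         pruned   = mean((predict(prune.fit, newdata = test.data) - test.data$medv)^2),
         bagging  = mean((predict(bag.boston, newdata = test.data) - test.data$medv)^2),
         rf       = mean((predict(rf.boston, newdata = test.data) - test.data$medv)^2))
round(sort(mse), 2)
```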
Using the importance() function, we can view the importance of each variable.
importance(rf.boston)
## %IncMSE IncNodePurity
## crim 19.435587 1070.42307
## zn 3.091630 82.19257
## indus 6.140529 590.09536
## chas 1.370310 36.70356
## nox 13.263466 859.97091
## rm 35.094741 8270.33906
## age 15.144821 634.31220
## dis 9.163776 684.87953
## rad 4.793720 83.18719
## tax 4.410714 292.20949
## ptratio 8.612780 902.20190
## lstat 28.725343 5813.04833
Two measures of variable importance are reported: %IncMSE, the mean increase in prediction error on the out-of-bag samples when the values of a given variable are permuted, and IncNodePurity, the total decrease in node impurity that results from splits over that variable, averaged over all trees. In the case of regression trees, the node impurity is measured by the training RSS. Plots of these importance measures can be produced using the varImpPlot() function as follows:
varImpPlot(rf.boston, main = "Variable Importance as Measured by a Random Forest")
The results indicate that across all of the trees considered in the random forest, the wealth of the community (lstat) and the house size (rm) are by far the two most important variables.
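As a further sketch beyond the original analysis, we could examine how the test MSE varies with mtry by refitting the forest for each candidate value (this can take a little time):

```r
set.seed(1)
# Test MSE for each possible number of variables tried at each split
mse.by.mtry <- sapply(1:12, function(m) {
  fit <- randomForest(medv ~ ., data = Boston, subset = Index, mtry = m)
  mean((predict(fit, newdata = test.data) - test.data$medv)^2)
})
plot(1:12, mse.by.mtry, type = "b", xlab = "mtry", ylab = "Test set MSE")
```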