This post covers the stacks
package, which focuses on combining the outputs of many models into a single model, a process called ensembling.
Setup
Packages
The following packages are required:
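A setup block along the following lines, inferred from the packages attached in the session information at the end of this post, loads everything used here:
library(dplyr)      # data manipulation
library(rsample)    # data splitting and resampling
library(recipes)    # pre-processing
library(parsnip)    # model specifications
library(dials)      # hyperparameter objects
library(tune)       # hyperparameter tuning
library(workflows)  # bundling recipes and models
library(yardstick)  # performance metrics
library(broom)      # augment() and tidy outputs
library(stacks)     # model ensembling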
Data
I utilized the penguins data set from the modeldata package for these examples.
Note that the outcome variable is forced to be binary by assigning all the odd-numbered rows as “Chinstrap” and the even-numbered rows as “Adelie”. This is done to (1) ensure that the outcome is binary so certain functions and techniques can be used and (2) make the outcome a little more difficult to predict, to better demonstrate the differences among the modeling algorithms used. In addition, rows with missing values are removed.
data("penguins", package = "modeldata")
penguins_tbl <- penguins %>%
  mutate(
    # Need a binary outcome
    species = ifelse(row_number() %% 2 == 0, "Adelie", "Chinstrap"),
    species = factor(species, levels = c("Adelie", "Chinstrap"))
  ) %>%
  filter(!if_any(everything(), \(x) is.na(x)))
penguins_tbl %>%
count(species)
# A tibble: 2 × 2
species n
<fct> <int>
1 Adelie 167
2 Chinstrap 166
Model Building Process
Before demonstrating the stacks
package, a few models will be built. Following the previous posts in this series, these model objects are built following the tidymodels
process.
Data Splitting
Referencing a previous post on the rsample
package, a 70/30 train/test initial_split()
of the data is taken. It is then passed to the training()
and testing()
functions to extract the training and testing data sets, respectively.
set.seed(1914)
penguins_split_obj <- initial_split(penguins_tbl, prop = 0.7)
penguins_train_tbl <- training(penguins_split_obj)
penguins_train_tbl
# A tibble: 233 × 7
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
<fct> <fct> <dbl> <dbl> <int> <int>
1 Chinstrap Biscoe 36.4 17.1 184 2850
2 Adelie Biscoe 52.1 17 230 5550
3 Adelie Dream 50.8 18.5 201 4450
4 Adelie Dream 37.2 18.1 178 3900
5 Chinstrap Biscoe 45.5 14.5 212 4750
6 Chinstrap Biscoe 39 17.5 186 3550
7 Chinstrap Biscoe 48.2 15.6 221 5100
8 Adelie Dream 52 19 197 4150
9 Chinstrap Biscoe 45.4 14.6 211 4800
10 Chinstrap Biscoe 38.1 17 181 3175
# ℹ 223 more rows
# ℹ 1 more variable: sex <fct>
penguins_test_tbl <- testing(penguins_split_obj)
penguins_test_tbl
# A tibble: 100 × 7
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
<fct> <fct> <dbl> <dbl> <int> <int>
1 Chinstrap Torgers… 40.3 18 195 3250
2 Adelie Torgers… 39.2 19.6 195 4675
3 Chinstrap Torgers… 38.7 19 195 3450
4 Adelie Torgers… 46 21.5 194 4200
5 Chinstrap Biscoe 37.8 18.3 174 3400
6 Adelie Biscoe 37.7 18.7 180 3600
7 Adelie Biscoe 38.2 18.1 185 3950
8 Chinstrap Biscoe 40.6 18.6 183 3550
9 Chinstrap Dream 36.4 17 195 3325
10 Adelie Dream 42.2 18.5 180 3550
# ℹ 90 more rows
# ℹ 1 more variable: sex <fct>
Data Processing
Referencing another previous post on the recipes
package, pre-processing steps are performed on the training data set. The resulting recipe()
object is then passed to the prep()
function to prepare it for use. The result is then passed to the juice()
function to extract the transformed training data set.
Note that, in preparation for ensembling the models, a different recipe()
is created for each modeling algorithm that will be used in the next section:
- Logistic Regression
- Random Forest
- K-Nearest Neighbors
For this recipe()
, an object is created that only contains the model formula and the training data set. This may seem needless, as passing this recipe()
to the juice()
function just returns the original data set, but it is required for the later ensembling process.
recipe_logreg_obj <- recipe(species ~ ., data = penguins_train_tbl)
recipe_logreg_obj %>%
prep() %>%
juice()
# A tibble: 233 × 7
island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex
<fct> <dbl> <dbl> <int> <int> <fct>
1 Biscoe 36.4 17.1 184 2850 female
2 Biscoe 52.1 17 230 5550 male
3 Dream 50.8 18.5 201 4450 male
4 Dream 37.2 18.1 178 3900 male
5 Biscoe 45.5 14.5 212 4750 female
6 Biscoe 39 17.5 186 3550 female
7 Biscoe 48.2 15.6 221 5100 male
8 Dream 52 19 197 4150 male
9 Biscoe 45.4 14.6 211 4800 female
10 Biscoe 38.1 17 181 3175 female
# ℹ 223 more rows
# ℹ 1 more variable: species <fct>
The random forest recipe()
builds on the logistic regression recipe() by adding a step that one-hot encodes all categorical (nominal) predictors.
recipe_rf_obj <- recipe_logreg_obj %>%
step_dummy(all_nominal_predictors(), one_hot = TRUE)
recipe_rf_obj %>%
prep() %>%
juice()
# A tibble: 233 × 10
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g species
<dbl> <dbl> <int> <int> <fct>
1 36.4 17.1 184 2850 Chinstrap
2 52.1 17 230 5550 Adelie
3 50.8 18.5 201 4450 Adelie
4 37.2 18.1 178 3900 Adelie
5 45.5 14.5 212 4750 Chinstrap
6 39 17.5 186 3550 Chinstrap
7 48.2 15.6 221 5100 Chinstrap
8 52 19 197 4150 Adelie
9 45.4 14.6 211 4800 Chinstrap
10 38.1 17 181 3175 Chinstrap
# ℹ 223 more rows
# ℹ 5 more variables: island_Biscoe <dbl>, island_Dream <dbl>,
# island_Torgersen <dbl>, sex_female <dbl>, sex_male <dbl>
The k-nearest neighbors recipe()
builds further on the random forest recipe() by adding a step that normalizes all numeric predictors.
recipe_knn_obj <- recipe_rf_obj %>%
step_normalize(all_numeric_predictors())
recipe_knn_obj %>%
prep() %>%
juice()
# A tibble: 233 × 10
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g species
<dbl> <dbl> <dbl> <dbl> <fct>
1 -1.41 -0.00476 -1.24 -1.72 Chinstrap
2 1.37 -0.0552 2.02 1.63 Adelie
3 1.14 0.701 -0.0360 0.266 Adelie
4 -1.26 0.499 -1.67 -0.415 Adelie
5 0.204 -1.32 0.745 0.637 Chinstrap
6 -0.945 0.197 -1.10 -0.849 Chinstrap
7 0.682 -0.761 1.38 1.07 Chinstrap
8 1.35 0.953 -0.320 -0.106 Adelie
9 0.186 -1.26 0.674 0.699 Chinstrap
10 -1.10 -0.0552 -1.46 -1.31 Chinstrap
# ℹ 223 more rows
# ℹ 5 more variables: island_Biscoe <dbl>, island_Dream <dbl>,
# island_Torgersen <dbl>, sex_female <dbl>, sex_male <dbl>
Model Specifications
Referencing another previous post on the parsnip
package, the three model specifications mentioned in the previous section are constructed. The first is a logistic regression, using the logistic_reg()
function; the second is a random forest, using the rand_forest()
function; the third is a k-nearest neighbors model, using the nearest_neighbor()
function.
mod_logreg_spec <- logistic_reg() %>%
set_engine("glm")
mod_logreg_spec
Logistic Regression Model Specification (classification)
Computational engine: glm
mod_rf_spec <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("randomForest", importance = TRUE) %>%
  set_args(
    trees = tune(),
    mtry = tune(),
    min_n = tune()
  )
mod_rf_spec
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
min_n = tune()
Engine-Specific Arguments:
importance = TRUE
Computational engine: randomForest
mod_knn_spec <- nearest_neighbor() %>%
set_mode("classification") %>%
set_engine("kknn") %>%
set_args(neighbors = tune())
mod_knn_spec
K-Nearest Neighbor Model Specification (classification)
Main Arguments:
neighbors = tune()
Computational engine: kknn
Model Hyperparameters
Notice that two of the model specifications, the random forest and the k-nearest neighbors, had their arguments set to tune()
. As such, their hyperparameters need to be tuned. Instructions for doing so can be found in another previous post, this time on the dials
and tune
packages. The procedure needs to be altered slightly for use with the stacks
package.
First, the tuning process needs to be altered using the control_stack_grid()
and control_stack_resamples()
functions. According to the documentation, these functions slightly alter the usual control_*
functions from the tune
package such that they “return the appropriate control grid to ensure that assessment set predictions and information on model specifications and preprocessors, is supplied in the resampling results object”. This information is required for ensembling the models.
In addition, 10-fold cross-validation is set up with the vfold_cv()
function from the rsample
package, and the metric to be optimized during tuning is the area under the ROC curve, via the roc_auc()
function from the yardstick
package. More information on the yardstick
package can be found in a previous post.
ctrl_grid <- control_stack_grid()
ctrl_res <- control_stack_resamples()
set.seed(1914)
folds_tbl <- vfold_cv(penguins_train_tbl, v = 10)
metric_obj <- metric_set(roc_auc)
Finally, to keep everything nice and organized, another previous post on the workflows
package is referenced to wrap up each modeling algorithm and pre-processing recipe into its own workflow()
.
The logistic regression model does not have any hyperparameters to tune. As such, its workflow is passed to the fit_resamples()
function (with the appropriate resamples
, metrics
, and control
arguments).
The output contains the results across the ten folds, with the predictions appended as required for stacks()
.
mod_logreg_wflw <- workflow() %>%
add_recipe(recipe_logreg_obj) %>%
add_model(mod_logreg_spec)
mod_logreg_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()
── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)
Computational engine: glm
set.seed(1915)
mod_logreg_res <- mod_logreg_wflw %>%
  fit_resamples(
    resamples = folds_tbl,
    metrics = metric_obj,
    control = ctrl_res
  )
mod_logreg_res
# Resampling results
# 10-fold cross-validation
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [209/24]> Fold01 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [24 × 5]>
2 <split [209/24]> Fold02 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [24 × 5]>
3 <split [209/24]> Fold03 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [24 × 5]>
4 <split [210/23]> Fold04 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
5 <split [210/23]> Fold05 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
6 <split [210/23]> Fold06 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
7 <split [210/23]> Fold07 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
8 <split [210/23]> Fold08 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
9 <split [210/23]> Fold09 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
10 <split [210/23]> Fold10 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
The random forest has three hyperparameters that need to be tuned. As such, its workflow is passed to the tune_grid()
function (again with resamples
, metrics
, and control
appropriately defined) with grid = 10
set to try 10 candidate combinations of the hyperparameters.
The output again contains the results across the ten folds, but this time with an expanded .metrics
column to account for the different hyperparameter combinations. Again, the predictions are appended as required for stacks()
.
mod_rf_wflw <- workflow() %>%
add_recipe(recipe_rf_obj) %>%
add_model(mod_rf_spec)
mod_rf_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
1 Recipe Step
• step_dummy()
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
min_n = tune()
Engine-Specific Arguments:
importance = TRUE
Computational engine: randomForest
set.seed(1915)
mod_rf_res <- mod_rf_wflw %>%
  tune_grid(
    grid = 10,
    resamples = folds_tbl,
    metrics = metric_obj,
    control = ctrl_grid
  )
mod_rf_res
# Tuning results
# 10-fold cross-validation
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [209/24]> Fold01 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [240 × 8]>
2 <split [209/24]> Fold02 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [240 × 8]>
3 <split [209/24]> Fold03 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [240 × 8]>
4 <split [210/23]> Fold04 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
5 <split [210/23]> Fold05 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
6 <split [210/23]> Fold06 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
7 <split [210/23]> Fold07 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
8 <split [210/23]> Fold08 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
9 <split [210/23]> Fold09 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
10 <split [210/23]> Fold10 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
Finally, the k-nearest neighbors has one hyperparameter that needs to be tuned. It is also passed to the tune_grid()
function with the same arguments as the random forest.
The output again contains the results across the ten folds, with the .metrics
column showing the candidate values of the hyperparameter and the predictions appended as required for stacks()
.
mod_knn_wflw <- workflow() %>%
add_recipe(recipe_knn_obj) %>%
add_model(mod_knn_spec)
mod_knn_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_dummy()
• step_normalize()
── Model ───────────────────────────────────────────────────────────────────────
K-Nearest Neighbor Model Specification (classification)
Main Arguments:
neighbors = tune()
Computational engine: kknn
set.seed(1915)
mod_knn_res <- mod_knn_wflw %>%
  tune_grid(
    grid = 10,
    resamples = folds_tbl,
    metrics = metric_obj,
    control = ctrl_grid
  )
mod_knn_res
# Tuning results
# 10-fold cross-validation
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [209/24]> Fold01 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [240 × 6]>
2 <split [209/24]> Fold02 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [240 × 6]>
3 <split [209/24]> Fold03 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [240 × 6]>
4 <split [210/23]> Fold04 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
5 <split [210/23]> Fold05 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
6 <split [210/23]> Fold06 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
7 <split [210/23]> Fold07 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
8 <split [210/23]> Fold08 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
9 <split [210/23]> Fold09 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
10 <split [210/23]> Fold10 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
Model Stacking Process
Now that there are three separate models, they can be stacked together to produce a single output.
Stacking Model Members
To start a model stack
, the stacks()
function is used to create an empty stack
. A stack
reports the number of model definitions and candidate members.
stacks()
# A data stack with 0 model definitions and 0 candidate members.
From that empty stack
, the tuning and resampling results for each model are added one at a time using the add_candidates()
function. It can now be seen that the stack
contains 3 model definitions (the logistic regression, random forest, and k-nearest neighbor models) and 21 candidate members.
mod_spec_stck <- stacks() %>%
add_candidates(mod_logreg_res) %>%
add_candidates(mod_rf_res) %>%
add_candidates(mod_knn_res)
mod_spec_stck
# A data stack with 3 model definitions and 21 candidate members:
# mod_logreg_res: 1 model configuration
# mod_rf_res: 10 model configurations
# mod_knn_res: 10 model configurations
# Outcome: species (factor)
When added to a stack
, models are given a default name in the object. These names can be customized using the name
argument in the add_candidates()
function.
mod_spec_stck <- stacks() %>%
add_candidates(mod_logreg_res, name = "logistic_regression") %>%
add_candidates(mod_rf_res, name = "random_forest") %>%
add_candidates(mod_knn_res, name = "k_nearest_neighbor")
mod_spec_stck
# A data stack with 3 model definitions and 21 candidate members:
# logistic_regression: 1 model configuration
# random_forest: 10 model configurations
# k_nearest_neighbor: 10 model configurations
# Outcome: species (factor)
The 21 candidate members come from the different hyperparameter combinations that were tried during tuning. No final combination of hyperparameters has been chosen yet, so all 21 candidates are included in the stack
, consisting of:
- Logistic Regression: 1 combination (no hyperparameters)
- Random Forest: 10 combinations
- K-Nearest Neighbors: 10 combinations
The collect_parameters()
function can be used to see the specifics of these candidate members.
mod_spec_stck %>%
collect_parameters("logistic_regression")
# A tibble: 1 × 1
member
<chr>
1 logistic_regression_1_1
mod_spec_stck %>%
collect_parameters("random_forest")
# A tibble: 10 × 4
member mtry trees min_n
<chr> <int> <int> <int>
1 random_forest_1_01 5 597 27
2 random_forest_1_02 4 786 16
3 random_forest_1_03 3 32 7
4 random_forest_1_04 1 340 20
5 random_forest_1_05 7 1930 29
6 random_forest_1_06 6 1419 13
7 random_forest_1_07 2 1350 24
8 random_forest_1_08 9 1678 33
9 random_forest_1_09 5 865 4
10 random_forest_1_10 8 1116 37
mod_spec_stck %>%
collect_parameters("k_nearest_neighbor")
# A tibble: 10 × 2
member neighbors
<chr> <int>
1 k_nearest_neighbor_1_01 1
2 k_nearest_neighbor_1_02 3
3 k_nearest_neighbor_1_03 5
4 k_nearest_neighbor_1_04 6
5 k_nearest_neighbor_1_05 7
6 k_nearest_neighbor_1_06 9
7 k_nearest_neighbor_1_07 10
8 k_nearest_neighbor_1_08 11
9 k_nearest_neighbor_1_09 13
10 k_nearest_neighbor_1_10 14
Once all the models have been added to the stack
, the blend_predictions()
function can be used to create the model ensemble. The blending process works by regressing the outcome on the candidates' out-of-sample predictions with an elastic net-type procedure (a conceptual sketch follows the output below). Candidates that come through the procedure with a non-zero “stacking coefficient” are kept in the ensemble and the rest are discarded.
set.seed(1916)
mod_blend_stck <- mod_spec_stck %>%
blend_predictions()
mod_blend_stck
# A tibble: 2 × 3
member type weight
<chr> <chr> <dbl>
1 .pred_Chinstrap_random_forest_1_07 rand_forest 2.00
2 .pred_Chinstrap_random_forest_1_05 rand_forest 0.828
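To make the blending step more concrete, below is a rough conceptual sketch of the kind of meta-model involved, not the stacks internals: the outcome is regressed on the candidates' out-of-sample predictions with an elastic net whose coefficients are constrained to be non-negative (as blend_predictions() does by default), here using the glmnet package and a made-up candidate prediction matrix purely for illustration.
# Conceptual sketch only -- stacks does this internally on the real
# assessment-set predictions collected during tuning.
library(glmnet)
set.seed(1)
n <- 233
# made-up out-of-sample probabilities from three hypothetical candidates
candidate_preds <- cbind(
  logreg_1 = runif(n),
  rf_05    = runif(n),
  knn_01   = runif(n)
)
outcome <- factor(sample(c("Adelie", "Chinstrap"), n, replace = TRUE))
meta_fit <- glmnet(
  x = candidate_preds,
  y = outcome,
  family = "binomial",
  alpha = 1,         # mixture: 1 corresponds to a pure LASSO penalty
  lower.limits = 0   # stacking coefficients constrained to be non-negative
)
# candidates with non-zero coefficients at a given penalty stay in the ensemble
coef(meta_fit, s = 0.01)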
By default, the mixture
parameter of the elastic net is 1, which results in a pure LASSO regression. The mixture
, as well as the penalty
parameter, can be altered as shown below. Moreover, the times
parameter (the number of bootstrap resamples used to calculate the stacking coefficients) can also be adjusted in the blend_predictions()
function.
set.seed(1916)
mod_blend_stck <- mod_spec_stck %>%
  blend_predictions(
    penalty = 10 ^ (-7:-3),
    mixture = 0.5,
    times = 50
  )
mod_blend_stck
# A tibble: 9 × 3
member type weight
<chr> <chr> <dbl>
1 .pred_Chinstrap_random_forest_1_04 rand_forest 1.12
2 .pred_Chinstrap_random_forest_1_07 rand_forest 0.949
3 .pred_Chinstrap_random_forest_1_08 rand_forest 0.619
4 .pred_Chinstrap_random_forest_1_06 rand_forest 0.592
5 .pred_Chinstrap_random_forest_1_05 rand_forest 0.559
6 .pred_Chinstrap_logistic_regression_1_1 logistic_reg 0.334
7 .pred_Chinstrap_k_nearest_neighbor_1_01 nearest_neighbor 0.329
8 .pred_Chinstrap_random_forest_1_10 rand_forest 0.321
9 .pred_Chinstrap_k_nearest_neighbor_1_02 nearest_neighbor 0.0153
With these updated arguments, at least one candidate member from each modeling algorithm was retained. These members can then be fit to the training data using the fit_members()
function.
mod_fit_stck <- mod_blend_stck %>%
fit_members()
mod_fit_stck
# A tibble: 9 × 3
member type weight
<chr> <chr> <dbl>
1 .pred_Chinstrap_random_forest_1_04 rand_forest 1.12
2 .pred_Chinstrap_random_forest_1_07 rand_forest 0.949
3 .pred_Chinstrap_random_forest_1_08 rand_forest 0.619
4 .pred_Chinstrap_random_forest_1_06 rand_forest 0.592
5 .pred_Chinstrap_random_forest_1_05 rand_forest 0.559
6 .pred_Chinstrap_logistic_regression_1_1 logistic_reg 0.334
7 .pred_Chinstrap_k_nearest_neighbor_1_01 nearest_neighbor 0.329
8 .pred_Chinstrap_random_forest_1_10 rand_forest 0.321
9 .pred_Chinstrap_k_nearest_neighbor_1_02 nearest_neighbor 0.0153
Plots
The stacks
package provides a few ggplot2
plots to examine the resulting ensemble via the autoplot()
function.
Passing the fitted stack
object to autoplot()
shows how the performance metrics changed over the values of penalty
. For a classification model, metrics such as accuracy and ROC AUC should be maximized; the dotted lines in the plot below mark where that is achieved.
autoplot(mod_fit_stck)
Specifying type = "members"
will show how the performance metrics change with the number of members included in the ensemble. As before, metrics such as accuracy and ROC AUC should be maximized.
autoplot(mod_fit_stck, type = "members")
Specifying type = "weights"
provides a visual representation of the stacking coefficients (how heavily each candidate member is weighted in the final ensemble), color-coded by modeling algorithm.
autoplot(mod_fit_stck, type = "weights")
Stacked Model Performance
Now that a final ensemble has been chosen and examined, it can be used for predictions. Since the model was built using a categorical outcome, there are two predictions that can be calculated:
- Class Predictions
- Probability Predictions
The predicted classes for the testing data can be calculated by passing the stack
and the testing data to the predict()
function.
preds_class_tbl <- mod_fit_stck %>%
predict(penguins_test_tbl) %>%
bind_cols(penguins_test_tbl)
preds_class_tbl
# A tibble: 100 × 8
.pred_class species island bill_length_mm bill_depth_mm flipper_length_mm
<fct> <fct> <fct> <dbl> <dbl> <int>
1 Chinstrap Chinstrap Torgers… 40.3 18 195
2 Adelie Adelie Torgers… 39.2 19.6 195
3 Chinstrap Chinstrap Torgers… 38.7 19 195
4 Adelie Adelie Torgers… 46 21.5 194
5 Chinstrap Chinstrap Biscoe 37.8 18.3 174
6 Adelie Adelie Biscoe 37.7 18.7 180
7 Adelie Adelie Biscoe 38.2 18.1 185
8 Adelie Chinstrap Biscoe 40.6 18.6 183
9 Chinstrap Chinstrap Dream 36.4 17 195
10 Chinstrap Adelie Dream 42.2 18.5 180
# ℹ 90 more rows
# ℹ 2 more variables: body_mass_g <int>, sex <fct>
To get the predicted probabilities for each class, specify type = "prob"
in the predict()
function.
preds_prob_tbl <- mod_fit_stck %>%
predict(penguins_test_tbl, type = "prob") %>%
bind_cols(penguins_test_tbl)
preds_prob_tbl
# A tibble: 100 × 9
.pred_Adelie .pred_Chinstrap species island bill_length_mm bill_depth_mm
<dbl> <dbl> <fct> <fct> <dbl> <dbl>
1 0.112 0.888 Chinstrap Torgersen 40.3 18
2 0.901 0.0993 Adelie Torgersen 39.2 19.6
3 0.137 0.863 Chinstrap Torgersen 38.7 19
4 0.900 0.100 Adelie Torgersen 46 21.5
5 0.115 0.885 Chinstrap Biscoe 37.8 18.3
6 0.697 0.303 Adelie Biscoe 37.7 18.7
7 0.827 0.173 Adelie Biscoe 38.2 18.1
8 0.880 0.120 Chinstrap Biscoe 40.6 18.6
9 0.0962 0.904 Chinstrap Dream 36.4 17
10 0.205 0.795 Adelie Dream 42.2 18.5
# ℹ 90 more rows
# ℹ 3 more variables: flipper_length_mm <int>, body_mass_g <int>, sex <fct>
Another method of getting predictions, without having to use bind_cols()
, is to use the augment()
function. More information on the augment()
function can be found in a previous post on the broom
package. Passing the stack
and the training data to augment()
produces the predicted outcome class.
mod_fit_stck %>%
augment(penguins_train_tbl) %>%
relocate(.pred_class, .before = everything())
# A tibble: 233 × 8
.pred_class species island bill_length_mm bill_depth_mm flipper_length_mm
<fct> <fct> <fct> <dbl> <dbl> <int>
1 Chinstrap Chinstrap Biscoe 36.4 17.1 184
2 Adelie Adelie Biscoe 52.1 17 230
3 Adelie Adelie Dream 50.8 18.5 201
4 Adelie Adelie Dream 37.2 18.1 178
5 Chinstrap Chinstrap Biscoe 45.5 14.5 212
6 Chinstrap Chinstrap Biscoe 39 17.5 186
7 Adelie Chinstrap Biscoe 48.2 15.6 221
8 Adelie Adelie Dream 52 19 197
9 Chinstrap Chinstrap Biscoe 45.4 14.6 211
10 Chinstrap Chinstrap Biscoe 38.1 17 181
# ℹ 223 more rows
# ℹ 2 more variables: body_mass_g <int>, sex <fct>
Specifying type = "prob"
in augment produces the predicted probabilities for each outcome class.
mod_fit_stck %>%
augment(penguins_train_tbl, type = "prob") %>%
relocate(contains(".pred"), .before = everything())
# A tibble: 233 × 9
.pred_Adelie .pred_Chinstrap species island bill_length_mm bill_depth_mm
<dbl> <dbl> <fct> <fct> <dbl> <dbl>
1 0.0951 0.905 Chinstrap Biscoe 36.4 17.1
2 0.919 0.0812 Adelie Biscoe 52.1 17
3 0.914 0.0860 Adelie Dream 50.8 18.5
4 0.751 0.249 Adelie Dream 37.2 18.1
5 0.0985 0.901 Chinstrap Biscoe 45.5 14.5
6 0.127 0.873 Chinstrap Biscoe 39 17.5
7 0.837 0.163 Chinstrap Biscoe 48.2 15.6
8 0.916 0.0845 Adelie Dream 52 19
9 0.103 0.897 Chinstrap Biscoe 45.4 14.6
10 0.0970 0.903 Chinstrap Biscoe 38.1 17
# ℹ 223 more rows
# ℹ 3 more variables: flipper_length_mm <int>, body_mass_g <int>, sex <fct>
As with parsnip
models and workflows
, the predictions from a fitted stack
can be passed to functions from the yardstick
package to calculate performance metrics. For this ensemble, conf_mat()
and roc_auc()
are used below to show the confusion matrix and area under the ROC curve, respectively.
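The calls look roughly like the following (a sketch using the prediction tibbles created above; by default, yardstick treats the first factor level, Adelie, as the event of interest), with the resulting confusion matrix printed below.
preds_prob_tbl %>%
  roc_auc(truth = species, .pred_Adelie)
preds_class_tbl %>%
  conf_mat(truth = species, estimate = .pred_class)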
Truth
Prediction Adelie Chinstrap
Adelie 38 11
Chinstrap 8 43
Notes
This post is based on a presentation that was given on the date listed. It may be updated from time to time to fix errors, detail new functions, and/or remove deprecated functions so the packages and R version will likely be newer than what was available at the time.
The R session information used for this post:
R version 4.2.1 (2022-06-23)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS 14.0
Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics grDevices datasets utils methods base
other attached packages:
[1] glmnet_4.1-6 Matrix_1.4-1 broom_1.0.5 workflows_1.1.2
[5] yardstick_1.1.0 tune_1.0.1 dials_1.1.0 scales_1.2.1
[9] parsnip_1.0.3 recipes_1.0.4 dplyr_1.1.3 rsample_1.1.1
[13] stacks_1.0.2
loaded via a namespace (and not attached):
[1] tidyr_1.3.0 kknn_1.3.1 jsonlite_1.8.0
[4] splines_4.2.1 foreach_1.5.2 prodlim_2019.11.13
[7] GPfit_1.0-8 renv_0.16.0 yaml_2.3.5
[10] globals_0.16.2 ipred_0.9-13 pillar_1.9.0
[13] backports_1.4.1 lattice_0.20-45 glue_1.6.2
[16] digest_0.6.29 randomForest_4.7-1.1 hardhat_1.2.0
[19] colorspace_2.0-3 htmltools_0.5.3 timeDate_4022.108
[22] pkgconfig_2.0.3 lhs_1.1.6 DiceDesign_1.9
[25] listenv_0.8.0 purrr_1.0.2 gower_1.0.1
[28] lava_1.7.1 timechange_0.1.1 tibble_3.2.1
[31] farver_2.1.1 generics_0.1.3 ggplot2_3.4.0
[34] ellipsis_0.3.2 withr_2.5.1 furrr_0.3.1
[37] nnet_7.3-17 cli_3.6.1 survival_3.3-1
[40] magrittr_2.0.3 evaluate_0.16 future_1.29.0
[43] fansi_1.0.4 parallelly_1.32.1 MASS_7.3-57
[46] class_7.3-20 tools_4.2.1 lifecycle_1.0.3
[49] stringr_1.5.0 munsell_0.5.0 compiler_4.2.1
[52] rlang_1.1.1 grid_4.2.1 iterators_1.0.14
[55] rstudioapi_0.14 igraph_1.5.1 labeling_0.4.2
[58] rmarkdown_2.16 gtable_0.3.1 codetools_0.2-18
[61] R6_2.5.1 lubridate_1.9.0 knitr_1.40
[64] fastmap_1.1.0 future.apply_1.10.0 utf8_1.2.3
[67] butcher_0.3.3 shape_1.4.6 stringi_1.7.12
[70] parallel_4.2.1 Rcpp_1.0.9 vctrs_0.6.3
[73] rpart_4.1.19 tidyselect_1.2.0 xfun_0.40