The stacks Package

Categories: Data Science, R, Modeling

Author: Robert Lankford

Published: August 22, 2023

This post covers the stacks package, which focuses on combining the outputs of many models into a single model, a process called ensembling.

Setup

Packages

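The following packages are required (modeldata, randomForest, and kknn also need to be installed for the data set and model engines used below):

library(stacks)
library(rsample)
library(dplyr)
library(recipes)
library(parsnip)
library(dials)
library(tune)
library(yardstick)
library(workflows)
library(broom)
library(glmnet)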

Data

I used the penguins data set from the modeldata package for these examples.

Note that the outcome variable is forced to be binary by assigning all odd-numbered rows as “Chinstrap” and all even-numbered rows as “Adelie”. This is done to (1) ensure that the outcome is binary so that certain functions and techniques can be used and (2) make the outcome a little more difficult to predict, to better demonstrate the differences between the modeling algorithms used. In addition, rows with missing values are removed.

data("penguins", package = "modeldata")

penguins_tbl <- penguins %>% 
  mutate(
    # Need a binary outcome
    species = ifelse(row_number() %% 2 == 0, "Adelie", "Chinstrap"),
    species = factor(species, levels = c("Adelie", "Chinstrap"))
  ) %>% 
  filter(!if_any(everything(), \(x) is.na(x)))

penguins_tbl %>% 
  count(species)
# A tibble: 2 × 2
  species       n
  <fct>     <int>
1 Adelie      167
2 Chinstrap   166

Model Building Process

Before demonstrating the stacks package, a few models will be built. As in the previous posts in this series, these models are built following the tidymodels process.

Data Splitting

Referencing a previous post on the rsample package, a 70/30 train/test split of the data is created with initial_split(). The split object is then passed to the training() and testing() functions to extract the training and testing data sets, respectively.

set.seed(1914)
penguins_split_obj <- initial_split(penguins_tbl, prop = 0.7)
penguins_train_tbl <- training(penguins_split_obj)

penguins_train_tbl
# A tibble: 233 × 7
   species   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
   <fct>     <fct>           <dbl>         <dbl>             <int>       <int>
 1 Chinstrap Biscoe           36.4          17.1               184        2850
 2 Adelie    Biscoe           52.1          17                 230        5550
 3 Adelie    Dream            50.8          18.5               201        4450
 4 Adelie    Dream            37.2          18.1               178        3900
 5 Chinstrap Biscoe           45.5          14.5               212        4750
 6 Chinstrap Biscoe           39            17.5               186        3550
 7 Chinstrap Biscoe           48.2          15.6               221        5100
 8 Adelie    Dream            52            19                 197        4150
 9 Chinstrap Biscoe           45.4          14.6               211        4800
10 Chinstrap Biscoe           38.1          17                 181        3175
# ℹ 223 more rows
# ℹ 1 more variable: sex <fct>

penguins_test_tbl <- testing(penguins_split_obj)

penguins_test_tbl
# A tibble: 100 × 7
   species   island   bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
   <fct>     <fct>             <dbl>         <dbl>             <int>       <int>
 1 Chinstrap Torgers…           40.3          18                 195        3250
 2 Adelie    Torgers…           39.2          19.6               195        4675
 3 Chinstrap Torgers…           38.7          19                 195        3450
 4 Adelie    Torgers…           46            21.5               194        4200
 5 Chinstrap Biscoe             37.8          18.3               174        3400
 6 Adelie    Biscoe             37.7          18.7               180        3600
 7 Adelie    Biscoe             38.2          18.1               185        3950
 8 Chinstrap Biscoe             40.6          18.6               183        3550
 9 Chinstrap Dream              36.4          17                 195        3325
10 Adelie    Dream              42.2          18.5               180        3550
# ℹ 90 more rows
# ℹ 1 more variable: sex <fct>
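Conceptually, a 70/30 initial_split() randomly assigns about 70% of the rows to training and the rest to testing. The base R sketch below illustrates that idea with a hypothetical helper, split_rows(); it does not use rsample, will not reproduce the exact rows selected above, and assumes the training size is floor(0.7 * n), which gives the 233/100 split seen here for 333 rows.

```r
# Hypothetical helper: a base-R illustration of a simple random train/test split,
# not rsample's implementation (no stratification, different random draws).
split_rows <- function(n, prop = 0.7, seed = 1914) {
  set.seed(seed)
  n_train   <- floor(n * prop)                  # assumed rounding rule
  train_idx <- sort(sample.int(n, size = n_train))
  test_idx  <- setdiff(seq_len(n), train_idx)   # every remaining row goes to testing
  list(train = train_idx, test = test_idx)
}

idx <- split_rows(n = 333)  # 333 complete rows in penguins_tbl
length(idx$train)           # 233
length(idx$test)            # 100
```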

Data Processing

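Referencing another previous post on the recipes package, pre-processing steps are defined on the training data set. Each resulting recipe() object is passed to the prep() function, which estimates any needed quantities (such as means and standard deviations) from the training data, and the result is passed to the juice() function to extract the transformed training data set. (In newer versions of recipes, juice() is superseded by bake() with new_data = NULL.)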

Note that, in preparation for ensembling the models, a different recipe() is created for each modeling algorithm that will be used in the next section:

  • Logistic Regression
  • Random Forest
  • K-Nearest Neighbors

For the logistic regression recipe(), an object is created that contains only the model formula and the training data set. This may seem needless, since passing this recipe() through prep() and juice() just returns the original data set (with the outcome moved to the last column), but it is required for the later ensembling process.

recipe_logreg_obj <- recipe(species ~ ., data = penguins_train_tbl)

recipe_logreg_obj %>% 
  prep() %>% 
  juice()
# A tibble: 233 × 7
   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex   
   <fct>           <dbl>         <dbl>             <int>       <int> <fct> 
 1 Biscoe           36.4          17.1               184        2850 female
 2 Biscoe           52.1          17                 230        5550 male  
 3 Dream            50.8          18.5               201        4450 male  
 4 Dream            37.2          18.1               178        3900 male  
 5 Biscoe           45.5          14.5               212        4750 female
 6 Biscoe           39            17.5               186        3550 female
 7 Biscoe           48.2          15.6               221        5100 male  
 8 Dream            52            19                 197        4150 male  
 9 Biscoe           45.4          14.6               211        4800 female
10 Biscoe           38.1          17                 181        3175 female
# ℹ 223 more rows
# ℹ 1 more variable: species <fct>

The random forest recipe() builds on the logistic regression recipe() by adding a step that one-hot encodes all nominal (categorical) predictors.

recipe_rf_obj <- recipe_logreg_obj %>% 
  step_dummy(all_nominal_predictors(), one_hot = TRUE)

recipe_rf_obj %>% 
  prep() %>% 
  juice()
# A tibble: 233 × 10
   bill_length_mm bill_depth_mm flipper_length_mm body_mass_g species  
            <dbl>         <dbl>             <int>       <int> <fct>    
 1           36.4          17.1               184        2850 Chinstrap
 2           52.1          17                 230        5550 Adelie   
 3           50.8          18.5               201        4450 Adelie   
 4           37.2          18.1               178        3900 Adelie   
 5           45.5          14.5               212        4750 Chinstrap
 6           39            17.5               186        3550 Chinstrap
 7           48.2          15.6               221        5100 Chinstrap
 8           52            19                 197        4150 Adelie   
 9           45.4          14.6               211        4800 Chinstrap
10           38.1          17                 181        3175 Chinstrap
# ℹ 223 more rows
# ℹ 5 more variables: island_Biscoe <dbl>, island_Dream <dbl>,
#   island_Torgersen <dbl>, sex_female <dbl>, sex_male <dbl>
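One-hot encoding replaces each factor with one 0/1 indicator column per level, which is the idea behind step_dummy(one_hot = TRUE). The base R sketch below uses a hypothetical helper, one_hot(), on a few values copied from the training rows above. It mimics the column naming seen in the output but is not the recipes implementation (for example, it does not handle new or missing levels).

```r
# Hypothetical helper: one 0/1 indicator column per factor level, named like
# step_dummy(one_hot = TRUE) output (e.g., island_Biscoe). Not the recipes code.
one_hot <- function(df) {
  is_fac <- vapply(df, is.factor, logical(1))
  out <- df[!is_fac]                      # keep the non-factor columns as-is
  for (col in names(df)[is_fac]) {
    for (lvl in levels(df[[col]])) {
      # unused levels (Torgersen here) still get an all-zero column
      out[[paste(col, lvl, sep = "_")]] <- as.numeric(df[[col]] == lvl)
    }
  }
  out
}

# A few values copied from the training rows above
toy <- data.frame(
  bill_length_mm = c(36.4, 52.1, 50.8),
  island = factor(c("Biscoe", "Biscoe", "Dream"),
                  levels = c("Biscoe", "Dream", "Torgersen")),
  sex    = factor(c("female", "male", "male"))
)

encoded <- one_hot(toy)
names(encoded)
# "bill_length_mm" "island_Biscoe" "island_Dream" "island_Torgersen" "sex_female" "sex_male"
```

Each row has exactly one 1 among the island_* columns, and exactly one among the sex_* columns.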

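The k-nearest neighbors recipe() builds on the random forest recipe() by adding a step that centers and scales all numeric predictors. Because this step comes after step_dummy(), the new dummy columns are normalized as well.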

recipe_knn_obj <- recipe_rf_obj %>% 
  step_normalize(all_numeric_predictors())

recipe_knn_obj %>% 
  prep() %>% 
  juice()
# A tibble: 233 × 10
   bill_length_mm bill_depth_mm flipper_length_mm body_mass_g species  
            <dbl>         <dbl>             <dbl>       <dbl> <fct>    
 1         -1.41       -0.00476           -1.24        -1.72  Chinstrap
 2          1.37       -0.0552             2.02         1.63  Adelie   
 3          1.14        0.701             -0.0360       0.266 Adelie   
 4         -1.26        0.499             -1.67        -0.415 Adelie   
 5          0.204      -1.32               0.745        0.637 Chinstrap
 6         -0.945       0.197             -1.10        -0.849 Chinstrap
 7          0.682      -0.761              1.38         1.07  Chinstrap
 8          1.35        0.953             -0.320       -0.106 Adelie   
 9          0.186      -1.26               0.674        0.699 Chinstrap
10         -1.10       -0.0552            -1.46        -1.31  Chinstrap
# ℹ 223 more rows
# ℹ 5 more variables: island_Biscoe <dbl>, island_Dream <dbl>,
#   island_Torgersen <dbl>, sex_female <dbl>, sex_male <dbl>
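step_normalize() centers and scales each numeric predictor using the mean and standard deviation estimated when the recipe is prepped on the training data, and it reuses those stored statistics for any new data. The base R sketch below separates those two phases with hypothetical helpers, prep_normalize() and bake_normalize(); the toy values are copied from the training rows above. Because island_Biscoe is numeric, it is normalized too, just like the dummy columns in the recipe.

```r
# Hypothetical helpers illustrating the prep/bake split for normalization.
prep_normalize <- function(train) {
  cols <- names(train)[vapply(train, is.numeric, logical(1))]
  list(
    cols  = cols,
    means = vapply(train[cols], mean, numeric(1)),  # learned from training data only
    sds   = vapply(train[cols], sd,   numeric(1))
  )
}

bake_normalize <- function(stats, new_data) {
  for (col in stats$cols) {
    # reuse the training statistics, whatever data set is being transformed
    new_data[[col]] <- (new_data[[col]] - stats$means[[col]]) / stats$sds[[col]]
  }
  new_data
}

train <- data.frame(
  bill_length_mm = c(36.4, 52.1, 50.8, 37.2, 45.5),
  island_Biscoe  = c(1, 1, 0, 0, 1),
  species        = factor(c("Chinstrap", "Adelie", "Adelie", "Adelie", "Chinstrap"))
)

stats      <- prep_normalize(train)
train_norm <- bake_normalize(stats, train)
sapply(train_norm[stats$cols], function(x) c(mean = mean(x), sd = sd(x)))  # means ~0, SDs 1
```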

Model Specifications

Referencing another previous post on the parsnip package, the three model specifications mentioned in the previous section are constructed. The first is a logistic regression, using the logistic_reg() function; the second is a random forest, using the rand_forest() function; and the third is a k-nearest neighbors model, using the nearest_neighbor() function.

mod_logreg_spec <- logistic_reg() %>% 
  set_engine("glm")

mod_logreg_spec
Logistic Regression Model Specification (classification)

Computational engine: glm

mod_rf_spec <- rand_forest() %>%
  set_mode("classification") %>% 
  set_engine("randomForest", importance = TRUE) %>% 
  set_args(
    trees = tune(),
    mtry  = tune(),
    min_n = tune()
  )

mod_rf_spec
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()
  min_n = tune()

Engine-Specific Arguments:
  importance = TRUE

Computational engine: randomForest

mod_knn_spec <- nearest_neighbor() %>% 
  set_mode("classification") %>% 
  set_engine("kknn") %>% 
  set_args(neighbors = tune())

mod_knn_spec
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune()

Computational engine: kknn 

Model Hyperparameters

Notice that two of the model specifications, the random forest and the k-nearest neighbors, had their arguments set to tune(). As such, their hyperparameters need to be tuned. Instructions for doing so can be found in another previous post, this time on the dials and tune packages. The procedure needs to be altered slightly for use with the stacks package.

First, the tuning process needs to be altered using the control_stack_grid() and control_stack_resamples() functions. According to the documentation, these functions slightly alter the usual control_* functions from the tune package such that they “return the appropriate control grid to ensure that assessment set predictions and information on model specifications and preprocessors, is supplied in the resampling results object”. This information is required for ensembling the models.

In addition, 10-fold cross-validation is performed with the vfold_cv() function from the rsample package, and the metric optimized during tuning is the area under the ROC curve, computed with the roc_auc() function from the yardstick package. More information on the yardstick package can be found in a previous post.

ctrl_grid <- control_stack_grid()
ctrl_res  <- control_stack_resamples()

set.seed(1914)
folds_tbl <- vfold_cv(penguins_train_tbl, v = 10)

metric_obj <- metric_set(roc_auc)
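vfold_cv() randomly partitions the training rows into v groups of nearly equal size; each group is held out once as the assessment set while the other v - 1 groups form the analysis set. The base R sketch below uses a hypothetical helper, assign_folds() (not rsample's code and not its random draws), to show why 233 rows produce three folds of 24 and seven folds of 23, matching the 209/24 and 210/23 splits printed in the resampling results.

```r
# Hypothetical helper: give each row a fold id so fold sizes differ by at most one.
assign_folds <- function(n, v = 10, seed = 1914) {
  set.seed(seed)
  sample(rep(seq_len(v), length.out = n))  # shuffle a balanced vector of fold ids
}

fold_id <- assign_folds(n = 233)
table(fold_id)  # folds 1-3 have 24 rows, folds 4-10 have 23

# Fold 1 as the assessment set, everything else as the analysis set
assessment_rows <- which(fold_id == 1)
analysis_rows   <- which(fold_id != 1)
c(analysis = length(analysis_rows), assessment = length(assessment_rows))  # 209, 24
```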

Finally, to keep everything nice and organized, another previous post on the workflows package is referenced to wrap up each modeling algorithm and pre-processing recipe into its own workflow().

The logistic regression model does not have any hyperparameters to tune. As such, its workflow is passed to the fit_resamples() function (with the appropriate resamples, metrics, and control arguments).

The output contains the results for each of the ten folds, with the out-of-fold predictions appended as required by stacks().

mod_logreg_wflw <- workflow() %>% 
  add_recipe(recipe_logreg_obj) %>% 
  add_model(mod_logreg_spec)

mod_logreg_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Computational engine: glm

set.seed(1915)

mod_logreg_res <- mod_logreg_wflw %>% 
  fit_resamples(
    resamples = folds_tbl,
    metrics   = metric_obj,
    control   = ctrl_res
  )

mod_logreg_res
# Resampling results
# 10-fold cross-validation 
# A tibble: 10 × 5
   splits           id     .metrics         .notes           .predictions     
   <list>           <chr>  <list>           <list>           <list>           
 1 <split [209/24]> Fold01 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [24 × 5]>
 2 <split [209/24]> Fold02 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [24 × 5]>
 3 <split [209/24]> Fold03 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [24 × 5]>
 4 <split [210/23]> Fold04 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
 5 <split [210/23]> Fold05 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
 6 <split [210/23]> Fold06 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
 7 <split [210/23]> Fold07 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
 8 <split [210/23]> Fold08 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
 9 <split [210/23]> Fold09 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>
10 <split [210/23]> Fold10 <tibble [1 × 4]> <tibble [0 × 3]> <tibble [23 × 5]>

The random forest has three hyperparameters that need to be tuned. As such, its workflow is passed to the tune_grid() function (again with resamples, metrics, and control appropriately defined) with grid = 10, which asks tune_grid() to generate 10 candidate combinations of the three hyperparameters (rather than 10 values of each).

The output again has one row per fold, but the .metrics column now holds 10 rows per fold, one for each hyperparameter combination. Again, the predictions are appended as required by stacks().

mod_rf_wflw <- workflow() %>%
  add_recipe(recipe_rf_obj) %>% 
  add_model(mod_rf_spec)

mod_rf_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
1 Recipe Step

• step_dummy()

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()
  min_n = tune()

Engine-Specific Arguments:
  importance = TRUE

Computational engine: randomForest

set.seed(1915)

mod_rf_res <- mod_rf_wflw %>% 
  tune_grid(
    grid      = 10,
    resamples = folds_tbl,
    metrics   = metric_obj,
    control   = ctrl_grid
  )

mod_rf_res
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 × 5
   splits           id     .metrics          .notes           .predictions      
   <list>           <chr>  <list>            <list>           <list>            
 1 <split [209/24]> Fold01 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [240 × 8]>
 2 <split [209/24]> Fold02 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [240 × 8]>
 3 <split [209/24]> Fold03 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [240 × 8]>
 4 <split [210/23]> Fold04 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
 5 <split [210/23]> Fold05 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
 6 <split [210/23]> Fold06 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
 7 <split [210/23]> Fold07 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
 8 <split [210/23]> Fold08 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
 9 <split [210/23]> Fold09 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
10 <split [210/23]> Fold10 <tibble [10 × 7]> <tibble [0 × 3]> <tibble [230 × 8]>
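With grid = 10, tune_grid() builds a space-filling design of 10 candidate combinations rather than a full factorial grid. A Latin hypercube is one such design: each parameter's range is cut into 10 equal strata, one value is drawn inside every stratum, and the strata are shuffled independently for each parameter. The base R sketch below uses a hypothetical helper, latin_hypercube(), and is not tune's code; the parameter ranges are illustrative assumptions chosen to cover the values shown later by collect_parameters().

```r
# Hypothetical helper: a simple Latin hypercube over integer-valued parameters.
latin_hypercube <- function(ranges, size = 10, seed = 1915) {
  set.seed(seed)
  cols <- lapply(ranges, function(r) {
    strata <- sample(seq_len(size))         # shuffled stratum order
    u <- (strata - 1 + runif(size)) / size  # one uniform draw inside each stratum
    round(r[1] + u * (r[2] - r[1]))         # rescale to the parameter's range
  })
  as.data.frame(cols)
}

# Assumed ranges (illustrative only)
rf_grid <- latin_hypercube(list(mtry = c(1, 9), trees = c(1, 2000), min_n = c(2, 40)))
rf_grid  # 10 rows: 10 combinations, not 10 values of each parameter
```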

Finally, the k-nearest neighbors model has one hyperparameter that needs to be tuned. Its workflow is also passed to the tune_grid() function with the same arguments as the random forest.

The output again contains the results for each of the ten folds, with the .metrics column covering the 10 candidate values of neighbors and the predictions appended as required by stacks().

mod_knn_wflw <- workflow() %>% 
  add_recipe(recipe_knn_obj) %>% 
  add_model(mod_knn_spec)

mod_knn_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune()

Computational engine: kknn

set.seed(1915)

mod_knn_res <- mod_knn_wflw %>% 
  tune_grid(
    grid      = 10,
    resamples = folds_tbl,
    metrics   = metric_obj,
    control   = ctrl_grid
  )

mod_knn_res
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 × 5
   splits           id     .metrics          .notes           .predictions      
   <list>           <chr>  <list>            <list>           <list>            
 1 <split [209/24]> Fold01 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [240 × 6]>
 2 <split [209/24]> Fold02 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [240 × 6]>
 3 <split [209/24]> Fold03 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [240 × 6]>
 4 <split [210/23]> Fold04 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
 5 <split [210/23]> Fold05 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
 6 <split [210/23]> Fold06 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
 7 <split [210/23]> Fold07 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
 8 <split [210/23]> Fold08 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
 9 <split [210/23]> Fold09 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>
10 <split [210/23]> Fold10 <tibble [10 × 5]> <tibble [0 × 3]> <tibble [230 × 6]>

Model Stacking Process

Now that there are three separate models, they can be stacked together to produce a single output.

Stacking Model Members

To start a model stack, the stacks() function is used to create an empty stack. A stack reports the number of model definitions and candidate members.

stacks()
# A data stack with 0 model definitions and 0 candidate members.

From that empty stack, models are added one at a time using the add_candidates() function. It can now be seen that the stack contains 3 model definitions (the logistic regression, random forest, and k-nearest neighbor models) and 21 candidate members.

mod_spec_stck <- stacks() %>% 
  add_candidates(mod_logreg_res) %>% 
  add_candidates(mod_rf_res) %>% 
  add_candidates(mod_knn_res)

mod_spec_stck
# A data stack with 3 model definitions and 21 candidate members:
#   mod_logreg_res: 1 model configuration
#   mod_rf_res: 10 model configurations
#   mod_knn_res: 10 model configurations
# Outcome: species (factor)

When added to a stack, models are given a default name in the object. These names can be customized using the name argument in the add_candidates() function.

mod_spec_stck <- stacks() %>% 
  add_candidates(mod_logreg_res, name = "logistic_regression") %>% 
  add_candidates(mod_rf_res, name = "random_forest") %>% 
  add_candidates(mod_knn_res, name = "k_nearest_neighbor")

mod_spec_stck
# A data stack with 3 model definitions and 21 candidate members:
#   logistic_regression: 1 model configuration
#   random_forest: 10 model configurations
#   k_nearest_neighbor: 10 model configurations
# Outcome: species (factor)

The 21 candidate members come from the different hyperparameter combinations used during tuning. No final combination of hyperparameters has been chosen yet, so all 21 candidates are included in the stack, consisting of:

  • Logistic Regression: 1 combination (no hyperparameters)
  • Random Forest: 10 combinations
  • K-Nearest Neighbors: 10 combinations

The collect_parameters() function can be used to see the specifics of these candidate members.

mod_spec_stck %>% 
  collect_parameters("logistic_regression")
# A tibble: 1 × 1
  member                 
  <chr>                  
1 logistic_regression_1_1

mod_spec_stck %>% 
  collect_parameters("random_forest")
# A tibble: 10 × 4
   member              mtry trees min_n
   <chr>              <int> <int> <int>
 1 random_forest_1_01     5   597    27
 2 random_forest_1_02     4   786    16
 3 random_forest_1_03     3    32     7
 4 random_forest_1_04     1   340    20
 5 random_forest_1_05     7  1930    29
 6 random_forest_1_06     6  1419    13
 7 random_forest_1_07     2  1350    24
 8 random_forest_1_08     9  1678    33
 9 random_forest_1_09     5   865     4
10 random_forest_1_10     8  1116    37

mod_spec_stck %>% 
  collect_parameters("k_nearest_neighbor")
# A tibble: 10 × 2
   member                  neighbors
   <chr>                       <int>
 1 k_nearest_neighbor_1_01         1
 2 k_nearest_neighbor_1_02         3
 3 k_nearest_neighbor_1_03         5
 4 k_nearest_neighbor_1_04         6
 5 k_nearest_neighbor_1_05         7
 6 k_nearest_neighbor_1_06         9
 7 k_nearest_neighbor_1_07        10
 8 k_nearest_neighbor_1_08        11
 9 k_nearest_neighbor_1_09        13
10 k_nearest_neighbor_1_10        14

Once all the models have been added to the stack, the blend_predictions() function is used to create the model ensemble. Blending fits a regularized (elastic net) model that uses the candidate members' out-of-fold predictions as predictors of the true outcome. Candidate members that end up with a non-zero coefficient (the “stacking coefficient”) are kept in the ensemble, and the rest are discarded. By default, these coefficients are also constrained to be non-negative.

set.seed(1916)

mod_blend_stck <- mod_spec_stck %>% 
  blend_predictions()

mod_blend_stck
# A tibble: 2 × 3
  member                             type        weight
  <chr>                              <chr>        <dbl>
1 .pred_Chinstrap_random_forest_1_07 rand_forest  2.00 
2 .pred_Chinstrap_random_forest_1_05 rand_forest  0.828
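A stacking coefficient acts as a weight on a member's out-of-fold predicted probability in the blended model. The base R sketch below assumes a logistic (binomial) link, as in a penalized logistic regression, and uses the two weights printed above. The intercept and the member probabilities are hypothetical because they are not shown in the printout, and stack_prob() is a made-up helper rather than a stacks function.

```r
# Hypothetical helper: combine member probabilities with stacking coefficients
# through an assumed logistic link.
stack_prob <- function(member_probs, weights, intercept) {
  eta <- intercept + sum(weights * member_probs)  # linear predictor
  plogis(eta)                                     # inverse logit -> probability
}

weights   <- c(random_forest_1_07 = 2.00, random_forest_1_05 = 0.828)  # from the output above
intercept <- -1.4                                                      # hypothetical value

stack_prob(c(0.9, 0.8), weights, intercept)  # members lean Chinstrap -> higher probability
stack_prob(c(0.1, 0.2), weights, intercept)  # members lean Adelie -> lower probability
```

A member whose coefficient is zero contributes nothing to eta, which is why such members can be dropped from the ensemble.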

By default, the mixture parameter of the elastic net is 1, which results in a pure LASSO regression; a mixture of 0 gives a pure ridge regression, and values in between blend the two. The mixture, as well as the penalty parameter, can be altered as shown below. The times parameter (the number of bootstrap resamples used when tuning the model that determines the stacking coefficients) can also be adjusted in the blend_predictions() function.

set.seed(1916)

mod_blend_stck <- mod_spec_stck %>% 
  blend_predictions(
    penalty = 10 ^ (-7:-3),
    mixture = 0.5,
    times = 50
  )

mod_blend_stck
# A tibble: 9 × 3
  member                                  type             weight
  <chr>                                   <chr>             <dbl>
1 .pred_Chinstrap_random_forest_1_04      rand_forest      1.12  
2 .pred_Chinstrap_random_forest_1_07      rand_forest      0.949 
3 .pred_Chinstrap_random_forest_1_08      rand_forest      0.619 
4 .pred_Chinstrap_random_forest_1_06      rand_forest      0.592 
5 .pred_Chinstrap_random_forest_1_05      rand_forest      0.559 
6 .pred_Chinstrap_logistic_regression_1_1 logistic_reg     0.334 
7 .pred_Chinstrap_k_nearest_neighbor_1_01 nearest_neighbor 0.329 
8 .pred_Chinstrap_random_forest_1_10      rand_forest      0.321 
9 .pred_Chinstrap_k_nearest_neighbor_1_02 nearest_neighbor 0.0153
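The penalty and mixture arguments follow glmnet's parameterization of the elastic net penalty, penalty * ((1 - mixture) / 2 * sum(beta^2) + mixture * sum(abs(beta))). The base R sketch below uses a hypothetical helper, enet_penalty(), on made-up coefficients to show how mixture = 1 gives the pure LASSO (L1) penalty, which can shrink coefficients exactly to zero and thereby drop members, while mixture = 0 gives the pure ridge (L2) penalty and mixture = 0.5 blends the two.

```r
# Hypothetical helper: glmnet-style elastic net penalty for a coefficient vector.
enet_penalty <- function(beta, penalty, mixture) {
  penalty * ((1 - mixture) / 2 * sum(beta^2) + mixture * sum(abs(beta)))
}

beta <- c(2, -1, 0)  # made-up coefficients
enet_penalty(beta, penalty = 1, mixture = 1)    # 3    (LASSO: sum of |beta|)
enet_penalty(beta, penalty = 1, mixture = 0)    # 2.5  (ridge: half the sum of beta^2)
enet_penalty(beta, penalty = 1, mixture = 0.5)  # 2.75 (a blend of the two)
```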

With these updated arguments, at least one candidate member from each modeling algorithm is kept in the ensemble. The fit_members() function then fits each of these members (the candidates with non-zero stacking coefficients) on the full training set to produce the final model.

mod_fit_stck <- mod_blend_stck %>% 
  fit_members()

mod_fit_stck
# A tibble: 9 × 3
  member                                  type             weight
  <chr>                                   <chr>             <dbl>
1 .pred_Chinstrap_random_forest_1_04      rand_forest      1.12  
2 .pred_Chinstrap_random_forest_1_07      rand_forest      0.949 
3 .pred_Chinstrap_random_forest_1_08      rand_forest      0.619 
4 .pred_Chinstrap_random_forest_1_06      rand_forest      0.592 
5 .pred_Chinstrap_random_forest_1_05      rand_forest      0.559 
6 .pred_Chinstrap_logistic_regression_1_1 logistic_reg     0.334 
7 .pred_Chinstrap_k_nearest_neighbor_1_01 nearest_neighbor 0.329 
8 .pred_Chinstrap_random_forest_1_10      rand_forest      0.321 
9 .pred_Chinstrap_k_nearest_neighbor_1_02 nearest_neighbor 0.0153

Plots

The stacks package provides a few ggplot2 plots to examine the resulting ensemble via the autoplot() function.

Passing the fitted stack object to autoplot() shows how the performance metrics change across the values of penalty. For a classification model, the goal is to maximize metrics such as accuracy and AUC, and the dotted lines in the plot below show where that is achieved.

autoplot(mod_fit_stck)

Specifying type = "members" shows the values of the performance metrics across different numbers of members in the ensemble model. As before, higher accuracy and AUC are better.

autoplot(mod_fit_stck, type = "members")

Specifying type = "weights" provides a visual representation of the stacking coefficients (how much each model is weighted in the final ensemble) for each model in the ensemble, color-coded by the model algorithm.

autoplot(mod_fit_stck, type = "weights")

Stacked Model Performance

Now that a final ensemble has been chosen and examined, it can be used for predictions. Since the model was built using a categorical outcome, two types of predictions can be calculated:

  1. Class Predictions
  2. Probability Predictions

The predicted classes for the testing data can be calculated by passing the stack and the testing data to the predict() function.

preds_class_tbl <- mod_fit_stck %>% 
  predict(penguins_test_tbl) %>% 
  bind_cols(penguins_test_tbl)

preds_class_tbl
# A tibble: 100 × 8
   .pred_class species   island   bill_length_mm bill_depth_mm flipper_length_mm
   <fct>       <fct>     <fct>             <dbl>         <dbl>             <int>
 1 Chinstrap   Chinstrap Torgers…           40.3          18                 195
 2 Adelie      Adelie    Torgers…           39.2          19.6               195
 3 Chinstrap   Chinstrap Torgers…           38.7          19                 195
 4 Adelie      Adelie    Torgers…           46            21.5               194
 5 Chinstrap   Chinstrap Biscoe             37.8          18.3               174
 6 Adelie      Adelie    Biscoe             37.7          18.7               180
 7 Adelie      Adelie    Biscoe             38.2          18.1               185
 8 Adelie      Chinstrap Biscoe             40.6          18.6               183
 9 Chinstrap   Chinstrap Dream              36.4          17                 195
10 Chinstrap   Adelie    Dream              42.2          18.5               180
# ℹ 90 more rows
# ℹ 2 more variables: body_mass_g <int>, sex <fct>

To get the predicted probabilities for each class, specify type = "prob" in the predict() function.

preds_prob_tbl <- mod_fit_stck %>% 
  predict(penguins_test_tbl, type = "prob") %>% 
  bind_cols(penguins_test_tbl)

preds_prob_tbl
# A tibble: 100 × 9
   .pred_Adelie .pred_Chinstrap species   island    bill_length_mm bill_depth_mm
          <dbl>           <dbl> <fct>     <fct>              <dbl>         <dbl>
 1       0.112           0.888  Chinstrap Torgersen           40.3          18  
 2       0.901           0.0993 Adelie    Torgersen           39.2          19.6
 3       0.137           0.863  Chinstrap Torgersen           38.7          19  
 4       0.900           0.100  Adelie    Torgersen           46            21.5
 5       0.115           0.885  Chinstrap Biscoe              37.8          18.3
 6       0.697           0.303  Adelie    Biscoe              37.7          18.7
 7       0.827           0.173  Adelie    Biscoe              38.2          18.1
 8       0.880           0.120  Chinstrap Biscoe              40.6          18.6
 9       0.0962          0.904  Chinstrap Dream               36.4          17  
10       0.205           0.795  Adelie    Dream               42.2          18.5
# ℹ 90 more rows
# ℹ 3 more variables: flipper_length_mm <int>, body_mass_g <int>, sex <fct>
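For a two-class outcome, the class predictions above are consistent with picking the class with the larger predicted probability, which for two classes amounts to a 0.5 cutoff on .pred_Chinstrap. The base R check below copies .pred_Chinstrap for the first 10 test rows and reproduces the .pred_class values printed earlier; it is a sanity check on these rows, not a description of the stacks internals.

```r
# .pred_Chinstrap for the first 10 rows of preds_prob_tbl (copied from above)
pred_chinstrap <- c(0.888, 0.0993, 0.863, 0.100, 0.885, 0.303, 0.173, 0.120, 0.904, 0.795)

# Pick the more probable class; with two classes this is a 0.5 cutoff
pred_class <- ifelse(pred_chinstrap > 0.5, "Chinstrap", "Adelie")

# .pred_class for the same rows, copied from preds_class_tbl above
printed_class <- c("Chinstrap", "Adelie", "Chinstrap", "Adelie", "Chinstrap",
                   "Adelie", "Adelie", "Adelie", "Chinstrap", "Chinstrap")
identical(pred_class, printed_class)  # TRUE
```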

Another way to get predictions, without having to use bind_cols(), is the augment() function. More information on augment() can be found in a previous post on the broom package. Passing the stack and the training data to augment() adds the predicted class to the data; keep in mind that these are predictions on the same data used to fit the ensemble.

mod_fit_stck %>% 
  augment(penguins_train_tbl) %>% 
  relocate(.pred_class, .before = everything())
# A tibble: 233 × 8
   .pred_class species   island bill_length_mm bill_depth_mm flipper_length_mm
   <fct>       <fct>     <fct>           <dbl>         <dbl>             <int>
 1 Chinstrap   Chinstrap Biscoe           36.4          17.1               184
 2 Adelie      Adelie    Biscoe           52.1          17                 230
 3 Adelie      Adelie    Dream            50.8          18.5               201
 4 Adelie      Adelie    Dream            37.2          18.1               178
 5 Chinstrap   Chinstrap Biscoe           45.5          14.5               212
 6 Chinstrap   Chinstrap Biscoe           39            17.5               186
 7 Adelie      Chinstrap Biscoe           48.2          15.6               221
 8 Adelie      Adelie    Dream            52            19                 197
 9 Chinstrap   Chinstrap Biscoe           45.4          14.6               211
10 Chinstrap   Chinstrap Biscoe           38.1          17                 181
# ℹ 223 more rows
# ℹ 2 more variables: body_mass_g <int>, sex <fct>

Specifying type = "prob" in augment() adds the predicted probabilities for each outcome class.

mod_fit_stck %>% 
  augment(penguins_train_tbl, type = "prob") %>% 
  relocate(contains(".pred"), .before = everything())
# A tibble: 233 × 9
   .pred_Adelie .pred_Chinstrap species   island bill_length_mm bill_depth_mm
          <dbl>           <dbl> <fct>     <fct>           <dbl>         <dbl>
 1       0.0951          0.905  Chinstrap Biscoe           36.4          17.1
 2       0.919           0.0812 Adelie    Biscoe           52.1          17  
 3       0.914           0.0860 Adelie    Dream            50.8          18.5
 4       0.751           0.249  Adelie    Dream            37.2          18.1
 5       0.0985          0.901  Chinstrap Biscoe           45.5          14.5
 6       0.127           0.873  Chinstrap Biscoe           39            17.5
 7       0.837           0.163  Chinstrap Biscoe           48.2          15.6
 8       0.916           0.0845 Adelie    Dream            52            19  
 9       0.103           0.897  Chinstrap Biscoe           45.4          14.6
10       0.0970          0.903  Chinstrap Biscoe           38.1          17  
# ℹ 223 more rows
# ℹ 3 more variables: flipper_length_mm <int>, body_mass_g <int>, sex <fct>

As with parsnip models and workflows, the predictions from a stack can be passed to functions from the yardstick package to calculate performance metrics. For this ensemble, conf_mat() and roc_auc() are used below to compute the confusion matrix and the area under the ROC curve, respectively.

preds_class_tbl %>% 
  conf_mat(truth = species, estimate = .pred_class)
           Truth
Prediction  Adelie Chinstrap
  Adelie        38        11
  Chinstrap      8        43
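The usual summary metrics follow directly from the confusion matrix. The base R arithmetic below copies the counts printed above and treats Adelie, the first factor level, as the event, which matches yardstick's default event_level = "first"; yardstick's accuracy(), sens(), and spec() functions would be the idiomatic way to compute these from preds_class_tbl.

```r
# Counts from conf_mat() above: rows are predictions, columns are the truth
cm <- matrix(c(38, 8, 11, 43), nrow = 2,
             dimnames = list(Prediction = c("Adelie", "Chinstrap"),
                             Truth      = c("Adelie", "Chinstrap")))

acc_val  <- sum(diag(cm)) / sum(cm)                                  # (38 + 43) / 100
sens_val <- cm["Adelie", "Adelie"] / sum(cm[, "Adelie"])             # 38 / 46
spec_val <- cm["Chinstrap", "Chinstrap"] / sum(cm[, "Chinstrap"])    # 43 / 54

round(c(accuracy = acc_val, sensitivity = sens_val, specificity = spec_val), 3)
# accuracy 0.81, sensitivity 0.826, specificity 0.796
```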
preds_prob_tbl %>% 
  # Adelie is the first factor level, which yardstick treats as the event by default
  roc_auc(truth = species, .pred_Adelie)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.891
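The area under the ROC curve equals the probability that a randomly chosen Adelie row receives a higher .pred_Adelie than a randomly chosen Chinstrap row, with ties counted as one half. The base R sketch below uses a hypothetical helper, auc_pairwise(), to apply that definition to the first 10 test rows copied from preds_prob_tbl; because it uses only those rows, it gives 0.88 rather than the 0.891 that roc_auc() reports on all 100 test rows.

```r
# Hypothetical helper: pairwise (Mann-Whitney) form of ROC AUC.
auc_pairwise <- function(truth, prob_event, event) {
  pos <- prob_event[truth == event]  # scores of event rows
  neg <- prob_event[truth != event]  # scores of non-event rows
  diffs <- outer(pos, neg, "-")      # every event/non-event pair
  mean((diffs > 0) + 0.5 * (diffs == 0))
}

# First 10 rows of preds_prob_tbl (copied from above)
truth <- c("Chinstrap", "Adelie", "Chinstrap", "Adelie", "Chinstrap",
           "Adelie", "Adelie", "Chinstrap", "Chinstrap", "Adelie")
pred_adelie <- c(0.112, 0.901, 0.137, 0.900, 0.115, 0.697, 0.827, 0.880, 0.0962, 0.205)

auc_pairwise(truth, pred_adelie, event = "Adelie")  # 0.88 (22 of 25 pairs ranked correctly)
```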

Notes

This post is based on a presentation that was given on the date listed. It may be updated from time to time to fix errors, detail new functions, and/or remove deprecated functions, so the package and R versions will likely be newer than what was available at the time.

The R session information used for this post:

R version 4.2.1 (2022-06-23)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS 14.0

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices datasets  utils     methods   base     

other attached packages:
 [1] glmnet_4.1-6    Matrix_1.4-1    broom_1.0.5     workflows_1.1.2
 [5] yardstick_1.1.0 tune_1.0.1      dials_1.1.0     scales_1.2.1   
 [9] parsnip_1.0.3   recipes_1.0.4   dplyr_1.1.3     rsample_1.1.1  
[13] stacks_1.0.2   

loaded via a namespace (and not attached):
 [1] tidyr_1.3.0          kknn_1.3.1           jsonlite_1.8.0      
 [4] splines_4.2.1        foreach_1.5.2        prodlim_2019.11.13  
 [7] GPfit_1.0-8          renv_0.16.0          yaml_2.3.5          
[10] globals_0.16.2       ipred_0.9-13         pillar_1.9.0        
[13] backports_1.4.1      lattice_0.20-45      glue_1.6.2          
[16] digest_0.6.29        randomForest_4.7-1.1 hardhat_1.2.0       
[19] colorspace_2.0-3     htmltools_0.5.3      timeDate_4022.108   
[22] pkgconfig_2.0.3      lhs_1.1.6            DiceDesign_1.9      
[25] listenv_0.8.0        purrr_1.0.2          gower_1.0.1         
[28] lava_1.7.1           timechange_0.1.1     tibble_3.2.1        
[31] farver_2.1.1         generics_0.1.3       ggplot2_3.4.0       
[34] ellipsis_0.3.2       withr_2.5.1          furrr_0.3.1         
[37] nnet_7.3-17          cli_3.6.1            survival_3.3-1      
[40] magrittr_2.0.3       evaluate_0.16        future_1.29.0       
[43] fansi_1.0.4          parallelly_1.32.1    MASS_7.3-57         
[46] class_7.3-20         tools_4.2.1          lifecycle_1.0.3     
[49] stringr_1.5.0        munsell_0.5.0        compiler_4.2.1      
[52] rlang_1.1.1          grid_4.2.1           iterators_1.0.14    
[55] rstudioapi_0.14      igraph_1.5.1         labeling_0.4.2      
[58] rmarkdown_2.16       gtable_0.3.1         codetools_0.2-18    
[61] R6_2.5.1             lubridate_1.9.0      knitr_1.40          
[64] fastmap_1.1.0        future.apply_1.10.0  utf8_1.2.3          
[67] butcher_0.3.3        shape_1.4.6          stringi_1.7.12      
[70] parallel_4.2.1       Rcpp_1.0.9           vctrs_0.6.3         
[73] rpart_4.1.19         tidyselect_1.2.0     xfun_0.40