This post covers the broom package, which focuses on accessing and printing key information about a model object in a “tidy” fashion.
Setup
Packages
The following packages are required:
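The library() calls are folded in this post; based on the attached packages listed in the session information at the end, they are presumably the following (the broomstick package is installed and loaded later in the post):
library(broom)      # tidy(), glance(), augment()
library(rsample)    # data splitting
library(recipes)    # pre-processing recipes
library(parsnip)    # model specifications
library(dials)      # hyperparameter objects and grids
library(tune)       # hyperparameter tuning
library(workflows)  # bundling recipes and models
library(yardstick)  # performance metrics
library(dplyr)      # data manipulation and the pipe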
Data
I utilized the penguins data set from the modeldata package for these examples.
Note that the outcome variable is forced to be binary by assigning all the odd-numbered rows as “Chinstrap” and the even-numbered rows as “Adelie”. This is done to (1) ensure that the outcome is binary so certain functions and techniques can be used and (2) make the outcome a little more difficult to predict, to better demonstrate the difference in the construction of the two modeling algorithms used.
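The preparation code itself is not shown; a minimal sketch of that setup, assuming the object name penguins_tbl used in the splitting code below:
# A sketch only: force a binary outcome by row parity, as described above
penguins_tbl <- modeldata::penguins %>%
  mutate(
    species = factor(
      if_else(row_number() %% 2 == 1, "Chinstrap", "Adelie"),
      levels = c("Adelie", "Chinstrap")
    )
  )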
Model Building Process
Before demonstrating the broom package, a few models will be built. Following the previous posts in this series, these model objects are built following the tidymodels process.
Splitting & Processing Data
Referencing a previous post on the rsample package, a 70/30 train/test initial_split() of the data is taken. It is then passed to the training() and testing() functions to extract the training and testing data sets, respectively.
set.seed(1914)
penguins_split_obj <- initial_split(penguins_tbl, prop = 0.7)
penguins_train_tbl <- training(penguins_split_obj)
penguins_train_tbl
# A tibble: 240 × 7
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
<fct> <fct> <dbl> <dbl> <int> <int>
1 Chinstrap Biscoe 36.5 16.6 181 2850
2 Adelie Biscoe 52.5 15.6 221 5450
3 Chinstrap Dream 49.7 18.6 195 3600
4 Chinstrap Biscoe 40.6 18.6 183 3550
5 Chinstrap Biscoe 44.9 13.8 212 4750
6 Chinstrap Biscoe 39.6 17.7 186 3500
7 Chinstrap Biscoe 45.8 14.2 219 4700
8 Chinstrap Dream 45.9 17.1 190 3575
9 Chinstrap Biscoe 46.1 13.2 211 4500
10 Chinstrap Biscoe 37.7 16 183 3075
# ℹ 230 more rows
# ℹ 1 more variable: sex <fct>
penguins_test_tbl <- testing(penguins_split_obj)
penguins_test_tbl
# A tibble: 104 × 7
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
<fct> <fct> <dbl> <dbl> <int> <int>
1 Chinstrap Torgers… 40.3 18 195 3250
2 Chinstrap Torgers… 38.9 17.8 181 3625
3 Adelie Torgers… 37.8 17.3 180 3700
4 Chinstrap Torgers… 34.6 21.1 198 4400
5 Adelie Torgers… 36.6 17.8 185 3700
6 Chinstrap Torgers… 38.7 19 195 3450
7 Chinstrap Torgers… 34.4 18.4 184 3325
8 Adelie Biscoe 37.7 18.7 180 3600
9 Adelie Biscoe 40.5 18.9 180 3950
10 Chinstrap Dream 39.5 17.8 188 3300
# ℹ 94 more rows
# ℹ 1 more variable: sex <fct>
Referencing another previous post on the recipes package, pre-processing steps are performed on the training data set. The resulting object from the recipe() function is passed to the prep() function to prepare it for use. The result is then passed to the juice() function to extract the transformed training data set, and to the bake() function, along with the raw testing data set, to get the transformed testing data set.
recipe_obj <- recipe(species ~ ., data = penguins_train_tbl) %>%
step_naomit(all_predictors(), skip = FALSE) %>%
step_normalize(all_numeric_predictors())
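The statement producing the output below is folded in the original post; presumably it is the prep()/juice() pipeline just described:
# Folded in the original post: extract the prepared training data
penguins_train_prep_tbl <- recipe_obj %>%
  prep() %>%
  juice()
penguins_train_prep_tbl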
# A tibble: 231 × 7
island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex
<fct> <dbl> <dbl> <dbl> <dbl> <fct>
1 Biscoe -1.35 -0.225 -1.41 -1.67 female
2 Biscoe 1.54 -0.726 1.36 1.52 male
3 Dream 1.03 0.779 -0.440 -0.753 male
4 Biscoe -0.612 0.779 -1.27 -0.814 male
5 Biscoe 0.166 -1.63 0.740 0.658 female
6 Biscoe -0.793 0.327 -1.06 -0.876 female
7 Biscoe 0.328 -1.43 1.23 0.597 female
8 Dream 0.347 0.0263 -0.787 -0.784 female
9 Biscoe 0.383 -1.93 0.670 0.351 female
10 Biscoe -1.14 -0.526 -1.27 -1.40 female
# ℹ 221 more rows
# ℹ 1 more variable: species <fct>
penguins_test_prep_tbl <- recipe_obj %>%
prep() %>%
bake(new_data = penguins_test_tbl)
penguins_test_prep_tbl
# A tibble: 102 × 7
island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex
<fct> <dbl> <dbl> <dbl> <dbl> <fct>
1 Torgersen -0.667 0.478 -0.440 -1.18 female
2 Torgersen -0.920 0.377 -1.41 -0.722 female
3 Torgersen -1.70 2.03 -0.232 0.229 male
4 Torgersen -1.34 0.377 -1.13 -0.630 female
5 Torgersen -0.956 0.979 -0.440 -0.937 female
6 Torgersen -1.73 0.678 -1.20 -1.09 female
7 Biscoe -1.14 0.829 -1.48 -0.753 male
8 Biscoe -0.630 0.929 -1.48 -0.324 male
9 Dream -0.811 0.377 -0.926 -1.12 female
10 Dream -0.558 0.929 -1.20 -0.385 male
# ℹ 92 more rows
# ℹ 1 more variable: species <fct>
Specifying & Tuning Models
Referencing another previous post on the parsnip package, two model specifications are constructed. The first is a logistic regression, using the logistic_reg() function, and the second is a random forest, using the rand_forest() function.
mod_logistic_spec <- logistic_reg() %>%
set_engine("glm")
mod_logistic_fit <- mod_logistic_spec %>%
fit(species ~ ., data = penguins_train_prep_tbl)
mod_logistic_fit
parsnip model object
Call: stats::glm(formula = species ~ ., family = stats::binomial, data = data)
Coefficients:
(Intercept) islandDream islandTorgersen bill_length_mm
1.3929 -0.5952 -0.8689 -0.1673
bill_depth_mm flipper_length_mm body_mass_g sexmale
-0.9881 0.1984 -1.6473 -2.0012
Degrees of Freedom: 230 Total (i.e. Null); 223 Residual
Null Deviance: 320
Residual Deviance: 169.4 AIC: 185.4
mod_rf_spec <- rand_forest() %>%
set_mode("classification") %>%
set_engine("randomForest", importance = TRUE) %>%
set_args(
mtry = tune(),
min_n = tune(),
trees = tune()
)
mod_rf_spec
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
min_n = tune()
Engine-Specific Arguments:
importance = TRUE
Computational engine: randomForest
Another previous post on the workflows package is referenced to create two versions of each model specification: one base parsnip object (as created previously) and one where a workflow() is created with the model specification. The recipe is also added to the workflow.
mod_logistic_fit_wflw <- workflow() %>%
add_recipe(recipe_obj) %>%
add_model(mod_logistic_spec) %>%
fit(penguins_train_tbl)
mod_logistic_fit_wflw
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_naomit()
• step_normalize()
── Model ───────────────────────────────────────────────────────────────────────
Call: stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
Coefficients:
(Intercept) islandDream islandTorgersen bill_length_mm
1.3929 -0.5952 -0.8689 -0.1673
bill_depth_mm flipper_length_mm body_mass_g sexmale
-0.9881 0.1984 -1.6473 -2.0012
Degrees of Freedom: 230 Total (i.e. Null); 223 Residual
Null Deviance: 320
Residual Deviance: 169.4 AIC: 185.4
mod_rf_spec_wflw <- workflow() %>%
add_recipe(recipe_obj) %>%
add_model(mod_rf_spec)
mod_rf_spec_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_naomit()
• step_normalize()
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
min_n = tune()
Engine-Specific Arguments:
importance = TRUE
Computational engine: randomForest
Again referencing a previous post, this time on the dials and tune packages, the random forest, with hyperparameters marked as tune(), will undergo hyperparameter tuning using the tune_grid() function on a random table of hyperparameter values for mtry(), min_n(), and trees(), created using the grid_random() function.
set.seed(1914)
folds_tbl <- vfold_cv(penguins_train_prep_tbl, v = 10)
params <- parameters(
finalize(mtry(), penguins_train_prep_tbl[ ,-7]),
trees(),
min_n()
)
set.seed(1915)
grid_random_tbl <- grid_random(params, size = 10)
set.seed(1916)
mod_rf_tuned_tbl <- tune_grid(
mod_rf_spec, species ~ .,
resamples = folds_tbl,
grid = grid_random_tbl,
metrics = metric_set(roc_auc)
)
mod_rf_best_tbl <- select_best(mod_rf_tuned_tbl)
mod_rf_fit <- mod_rf_spec %>%
finalize_model(mod_rf_best_tbl) %>%
fit(species ~ ., data = penguins_train_prep_tbl)
mod_rf_fit
parsnip model object
Call:
randomForest(x = maybe_data_frame(x), y = y, ntree = ~705L, mtry = min_cols(~4L, x), nodesize = min_rows(~40L, x), importance = ~TRUE)
Type of random forest: classification
Number of trees: 705
No. of variables tried at each split: 4
OOB estimate of error rate: 14.72%
Confusion matrix:
Adelie Chinstrap class.error
Adelie 94 18 0.1607143
Chinstrap 16 103 0.1344538
set.seed(1917)
mod_rf_spec_vars_wflw <- mod_rf_spec_wflw %>%
remove_recipe() %>%
add_variables(outcomes = species, predictors = everything())
mod_rf_tuned_wflw <- tune_grid(
mod_rf_spec_vars_wflw,
resamples = folds_tbl,
grid = grid_random_tbl,
metrics = metric_set(roc_auc)
)
mod_rf_best_tbl <- select_best(mod_rf_tuned_wflw)
mod_rf_fit_wflw <- mod_rf_spec_wflw %>%
finalize_workflow(mod_rf_best_tbl) %>%
fit(penguins_train_tbl)
mod_rf_fit_wflw
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_naomit()
• step_normalize()
── Model ───────────────────────────────────────────────────────────────────────
Call:
randomForest(x = maybe_data_frame(x), y = y, ntree = ~705L, mtry = min_cols(~4L, x), nodesize = min_rows(~40L, x), importance = ~TRUE)
Type of random forest: classification
Number of trees: 705
No. of variables tried at each split: 4
OOB estimate of error rate: 14.72%
Confusion matrix:
Adelie Chinstrap class.error
Adelie 94 18 0.1607143
Chinstrap 16 103 0.1344538
Validating Final Models
Now that a few final model objects are built, information from another previous post can be used to validate them with the yardstick package. To show that each model algorithm produces the same result regardless of whether it is a base parsnip object or a workflow, each type is shown below. For each type, the in-sample performance on the training data and the out-of-sample performance on the testing data are shown. The performance metrics shown are a confusion matrix and the Area Under the ROC Curve (AUC).
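The metric code itself is folded in the blocks below; a sketch of what it presumably looks like, using the first block’s prediction tibbles as an example:
# A sketch of the folded yardstick calls (names match the tibbles built below)
mod_logistic_pred_class_train_tbl %>%
  conf_mat(truth = species, estimate = .pred_class)
mod_logistic_pred_prob_train_tbl %>%
  roc_auc(truth = species, .pred_Adelie)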
Logistic Regression
As shown below, the logistic regression model produces the exact same in-sample and out-of-sample metrics regardless of whether it is a parsnip or a workflows object.
Code
#> Use prepared data from `recipe`
mod_logistic_pred_class_train_tbl <- mod_logistic_fit %>%
predict(penguins_train_prep_tbl) %>%
bind_cols(penguins_train_prep_tbl)
mod_logistic_pred_prob_train_tbl <- mod_logistic_fit %>%
predict(penguins_train_prep_tbl, type = "prob") %>%
bind_cols(penguins_train_prep_tbl)
Truth
Prediction Adelie Chinstrap
Adelie 95 17
Chinstrap 17 102
Code
#> Use prepared data from `recipe`
mod_logistic_pred_class_test_tbl <- mod_logistic_fit %>%
predict(penguins_test_prep_tbl) %>%
bind_cols(penguins_test_prep_tbl)
mod_logistic_pred_prob_test_tbl <- mod_logistic_fit %>%
predict(penguins_test_prep_tbl, type = "prob") %>%
bind_cols(penguins_test_prep_tbl)
Truth
Prediction Adelie Chinstrap
Adelie 47 8
Chinstrap 8 39
Code
#> Use original data since `recipe` is in the `workflow`
#> The `recipe` will remove rows with missing values, so bind to prepared data
mod_logistic_pred_class_train_tbl <- mod_logistic_fit_wflw %>%
predict(penguins_train_tbl) %>%
bind_cols(penguins_train_prep_tbl)
mod_logistic_pred_prob_train_tbl <- mod_logistic_fit_wflw %>%
predict(penguins_train_tbl, type = "prob") %>%
bind_cols(penguins_train_prep_tbl)
Truth
Prediction Adelie Chinstrap
Adelie 95 17
Chinstrap 17 102
Code
#> Use original data since `recipe` is in the `workflow`
#> The `recipe` will remove rows with missing values, so bind to prepared data
mod_logistic_pred_class_test_tbl <- mod_logistic_fit_wflw %>%
predict(penguins_test_tbl) %>%
bind_cols(penguins_test_prep_tbl)
mod_logistic_pred_prob_test_tbl <- mod_logistic_fit_wflw %>%
predict(penguins_test_tbl, type = "prob") %>%
bind_cols(penguins_test_prep_tbl)
Truth
Prediction Adelie Chinstrap
Adelie 47 8
Chinstrap 8 39
Random Forest
The random forest model produces nearly identical in-sample and out-of-sample metrics regardless of whether it is a parsnip or a workflows object, as shown below (small differences arise because the two random forest fits are not identical).
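The prediction code for the parsnip version is folded; it presumably mirrors the logistic regression blocks above, e.g. for the training data:
# A sketch mirroring the folded code: predict with the parsnip fit on prepared data
mod_rf_pred_class_train_tbl <- mod_rf_fit %>%
  predict(penguins_train_prep_tbl) %>%
  bind_cols(penguins_train_prep_tbl)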
Truth
Prediction Adelie Chinstrap
Adelie 96 12
Chinstrap 16 107
Truth
Prediction Adelie Chinstrap
Adelie 45 8
Chinstrap 10 39
Code
#> Use original data since `recipe` is in the `workflow`
#> The `recipe` will remove rows with missing values, so bind to prepared data
mod_rf_pred_class_train_tbl <- mod_rf_fit_wflw %>%
predict(penguins_train_tbl) %>%
bind_cols(penguins_train_prep_tbl)
mod_rf_pred_prob_train_tbl <- mod_rf_fit_wflw %>%
predict(penguins_train_tbl, type = "prob") %>%
bind_cols(penguins_train_prep_tbl)
Truth
Prediction Adelie Chinstrap
Adelie 95 12
Chinstrap 17 107
Code
#> Use original data since `recipe` is in the `workflow`
#> The `recipe` will remove rows with missing values, so bind to prepared data
mod_rf_pred_class_test_tbl <- mod_rf_fit_wflw %>%
predict(penguins_test_tbl) %>%
bind_cols(penguins_test_prep_tbl)
mod_rf_pred_prob_test_tbl <- mod_rf_fit_wflw %>%
predict(penguins_test_tbl, type = "prob") %>%
bind_cols(penguins_test_prep_tbl)
Truth
Prediction Adelie Chinstrap
Adelie 45 8
Chinstrap 10 39
Examining Model Objects
Now that there are various model objects, the broom package can be used to examine them. There are three main functions that the broom package provides:
- tidy(): summarizes information about the components of a model in a tidy tibble
- glance(): constructs a single-row summary of a model
- augment(): adds information from a model to a data set
Logistic Regression
When printing out a fitted model object, especially a workflow, the amount of information displayed on the console can quickly get overwhelming, as shown below.
mod_logistic_fit
parsnip model object
Call: stats::glm(formula = species ~ ., family = stats::binomial, data = data)
Coefficients:
(Intercept) islandDream islandTorgersen bill_length_mm
1.3929 -0.5952 -0.8689 -0.1673
bill_depth_mm flipper_length_mm body_mass_g sexmale
-0.9881 0.1984 -1.6473 -2.0012
Degrees of Freedom: 230 Total (i.e. Null); 223 Residual
Null Deviance: 320
Residual Deviance: 169.4 AIC: 185.4
mod_logistic_fit_wflw
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_naomit()
• step_normalize()
── Model ───────────────────────────────────────────────────────────────────────
Call: stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
Coefficients:
(Intercept) islandDream islandTorgersen bill_length_mm
1.3929 -0.5952 -0.8689 -0.1673
bill_depth_mm flipper_length_mm body_mass_g sexmale
-0.9881 0.1984 -1.6473 -2.0012
Degrees of Freedom: 230 Total (i.e. Null); 223 Residual
Null Deviance: 320
Residual Deviance: 169.4 AIC: 185.4
This is where tidy() and the other functions from broom come into play. They allow for quick summaries of model information to be neatly printed to the console as a tidy tibble.
For a logistic regression parsnip object, tidy() prints out a tibble of the variables, the estimated coefficients, and other statistics related to the coefficients.
tidy(mod_logistic_fit)
# A tibble: 8 × 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 1.39 0.424 3.29 0.00102
2 islandDream -0.595 0.638 -0.933 0.351
3 islandTorgersen -0.869 0.736 -1.18 0.238
4 bill_length_mm -0.167 0.301 -0.556 0.578
5 bill_depth_mm -0.988 0.361 -2.73 0.00627
6 flipper_length_mm 0.198 0.501 0.396 0.692
7 body_mass_g -1.65 0.556 -2.96 0.00307
8 sexmale -2.00 0.496 -4.04 0.0000542
Additionally, for a logistic regression parsnip object, glance() prints out some common fit statistics for generalized linear models.
glance(mod_logistic_fit)
# A tibble: 1 × 8
null.deviance df.null logLik AIC BIC deviance df.residual nobs
<dbl> <int> <dbl> <dbl> <dbl> <dbl> <int> <int>
1 320. 230 -84.7 185. 213. 169. 223 231
Finally, for a logistic regression parsnip object, augment() adds the predicted class and the predicted probabilities for each class to a data set that is supplied to the function. Note that these columns are appended to the end of the data set, so the data set below was reduced down to the real outcome and the augmented columns.
mod_logistic_fit %>%
augment(penguins_test_prep_tbl) %>%
select(species, .pred_class, .pred_Adelie, .pred_Chinstrap)
# A tibble: 102 × 4
species .pred_class .pred_Adelie .pred_Chinstrap
<fct> <fct> <dbl> <dbl>
1 Chinstrap Chinstrap 0.117 0.883
2 Chinstrap Chinstrap 0.229 0.771
3 Chinstrap Adelie 0.974 0.0260
4 Adelie Chinstrap 0.234 0.766
5 Chinstrap Chinstrap 0.236 0.764
6 Chinstrap Chinstrap 0.154 0.846
7 Adelie Adelie 0.572 0.428
8 Adelie Adelie 0.765 0.235
9 Chinstrap Chinstrap 0.0977 0.902
10 Adelie Adelie 0.837 0.163
# ℹ 92 more rows
As with parsnip objects, workflow objects can also be passed to broom’s functions. For a logistic regression workflow object, the same information is printed using tidy() as with a parsnip object.
tidy(mod_logistic_fit_wflw)
# A tibble: 8 × 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 1.39 0.424 3.29 0.00102
2 islandDream -0.595 0.638 -0.933 0.351
3 islandTorgersen -0.869 0.736 -1.18 0.238
4 bill_length_mm -0.167 0.301 -0.556 0.578
5 bill_depth_mm -0.988 0.361 -2.73 0.00627
6 flipper_length_mm 0.198 0.501 0.396 0.692
7 body_mass_g -1.65 0.556 -2.96 0.00307
8 sexmale -2.00 0.496 -4.04 0.0000542
Similarly, for a logistic regression workflow object, glance() prints the same information as for a parsnip object.
glance(mod_logistic_fit_wflw)
# A tibble: 1 × 8
null.deviance df.null logLik AIC BIC deviance df.residual nobs
<dbl> <int> <dbl> <dbl> <dbl> <dbl> <int> <int>
1 320. 230 -84.7 185. 213. 169. 223 231
Where things go wrong is with augment(). The recipe attached to the workflow filters out rows with missing values, and the original test data has missing values. The predictions are therefore made on fewer rows than are in the original data that augment() tries to bind them to, leading to the following error.
mod_logistic_fit_wflw %>%
augment(penguins_test_tbl) %>%
select(species, .pred_class, .pred_Adelie, .pred_Chinstrap)
Error in `vctrs::vec_cbind()`:
! Can't recycle `..1` (size 104) to match `..2` (size 102).
Manually removing the rows with missing values solves this problem, and augment() works just as it did with the parsnip model object.
penguins_test_no_missing_tbl <- tidyr::drop_na(penguins_test_tbl)
mod_logistic_fit_wflw %>%
augment(penguins_test_no_missing_tbl) %>%
select(species, .pred_class, .pred_Adelie, .pred_Chinstrap)
# A tibble: 102 × 4
species .pred_class .pred_Adelie .pred_Chinstrap
<fct> <fct> <dbl> <dbl>
1 Chinstrap Chinstrap 0.117 0.883
2 Chinstrap Chinstrap 0.229 0.771
3 Chinstrap Adelie 0.974 0.0260
4 Adelie Chinstrap 0.234 0.766
5 Chinstrap Chinstrap 0.236 0.764
6 Chinstrap Chinstrap 0.154 0.846
7 Adelie Adelie 0.572 0.428
8 Adelie Adelie 0.765 0.235
9 Chinstrap Chinstrap 0.0977 0.902
10 Adelie Adelie 0.837 0.163
# ℹ 92 more rows
Random Forest
As with the logistic regression, printing a fitted random forest model object, especially a workflow, displays a large amount of information on the console.
mod_rf_fit
parsnip model object
Call:
randomForest(x = maybe_data_frame(x), y = y, ntree = ~705L, mtry = min_cols(~4L, x), nodesize = min_rows(~40L, x), importance = ~TRUE)
Type of random forest: classification
Number of trees: 705
No. of variables tried at each split: 4
OOB estimate of error rate: 14.72%
Confusion matrix:
Adelie Chinstrap class.error
Adelie 94 18 0.1607143
Chinstrap 16 103 0.1344538
mod_rf_fit_wflw
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_naomit()
• step_normalize()
── Model ───────────────────────────────────────────────────────────────────────
Call:
randomForest(x = maybe_data_frame(x), y = y, ntree = ~705L, mtry = min_cols(~4L, x), nodesize = min_rows(~40L, x), importance = ~TRUE)
Type of random forest: classification
Number of trees: 705
No. of variables tried at each split: 4
OOB estimate of error rate: 14.72%
Confusion matrix:
Adelie Chinstrap class.error
Adelie 94 18 0.1607143
Chinstrap 16 103 0.1344538
However, unlike with the logistic regression, when the random forest object is passed to the tidy() function, an error is thrown.
tidy(mod_rf_fit)
Error: No tidy method for objects of class randomForest
This is because of how the tidy() function (and many other functions in R) works. It is a generic function with different methods written for it based on the class of the object that is passed to it. A discussion of object-oriented programming (OOP) in R is outside the scope of this post, but a good summary can be found here.
The short version is that someone has to write a different implementation of the tidy() function for each type of object. The logistic regression has a version written for it in the broom package (specifically broom:::tidy.glm), but the random forest does not.
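As a toy illustration of this dispatch mechanism (the my_model class and its method below are hypothetical, not from any package):
# Hypothetical example: tidy() dispatches on the class of its argument
tidy.my_model <- function(x, ...) {
  tibble::tibble(term = names(x$coef), estimate = x$coef)
}
obj <- structure(list(coef = c(a = 1.5, b = -0.3)), class = "my_model")
tidy(obj)  # dispatches to tidy.my_model(); without a method, tidy() errors as above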
An implementation of the tidy() function does not have to come from the broom package. Any package can include a tidy() method for objects it creates or objects from other packages. For example, the recipes package contains a tidy() method for a recipe object and the rsample package contains a tidy() method for an rsplit object.
tidy(recipe_obj)
# A tibble: 2 × 6
number operation type trained skip id
<int> <chr> <chr> <lgl> <lgl> <chr>
1 1 step naomit FALSE FALSE naomit_6l6I0
2 2 step normalize FALSE FALSE normalize_gNX4l
tidy(penguins_split_obj)
# A tibble: 344 × 2
Row Data
<int> <chr>
1 1 Analysis
2 2 Analysis
3 4 Analysis
4 5 Analysis
5 6 Analysis
6 8 Analysis
7 9 Analysis
8 10 Analysis
9 11 Analysis
10 13 Analysis
# ℹ 334 more rows
There is a package on GitHub called broomstick that provides a tidy() method for a random forest object. It can be installed using the install_github() function from the remotes package.
remotes::install_github("njtierney/broomstick")
After loading the package, passing the random forest object to tidy() will now provide a summary of information about the model in a tibble.
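The load call itself is folded in the original post:
library(broomstick)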
For a random forest parsnip object, tidy() prints out a tibble of the variables and performance metrics such as Mean Decrease in Accuracy and Mean Decrease in Gini.
tidy(mod_rf_fit)
# A tibble: 6 × 5
term MeanDecreaseAccuracy MeanDecreaseGini MeanDecreaseAccuracy…¹
<chr> <dbl> <dbl> <dbl>
1 island 0.00137 0.771 0.000374
2 bill_length_mm 0.0239 4.58 0.00132
3 bill_depth_mm 0.0164 5.09 0.00126
4 flipper_length_mm 0.0259 4.91 0.00124
5 body_mass_g 0.0479 9.62 0.00188
6 sex 0.218 49.3 0.00335
# ℹ abbreviated name: ¹MeanDecreaseAccuracy_sd
# ℹ 1 more variable: classwise_importance <list>
Additionally, for a random forest parsnip object, glance() prints out some common fit statistics, such as precision, recall, and accuracy, for each outcome level.
glance(mod_rf_fit)
# A tibble: 2 × 5
class precision recall accuracy f_measure
<chr> <dbl> <dbl> <dbl> <dbl>
1 Adelie 0.855 0.839 0.853 0.847
2 Chinstrap 0.851 0.866 0.853 0.858
Finally, for a random forest parsnip object, augment() adds the predicted class and the predicted probabilities for each class to a data set that is supplied to the function. Note that these columns are appended to the end of the data set, so the data set below was reduced down to the real outcome and the augmented columns.
mod_rf_fit %>%
augment(penguins_test_prep_tbl) %>%
select(species, .pred_class, .pred_Adelie, .pred_Chinstrap)
# A tibble: 102 × 4
species .pred_class .pred_Adelie .pred_Chinstrap
<fct> <fct> <dbl> <dbl>
1 Chinstrap Chinstrap 0.0128 0.987
2 Chinstrap Chinstrap 0.0695 0.930
3 Chinstrap Adelie 0.721 0.279
4 Adelie Chinstrap 0.0723 0.928
5 Chinstrap Chinstrap 0.0936 0.906
6 Chinstrap Chinstrap 0.0227 0.977
7 Adelie Chinstrap 0.418 0.582
8 Adelie Chinstrap 0.494 0.506
9 Chinstrap Chinstrap 0.0128 0.987
10 Adelie Adelie 0.877 0.123
# ℹ 92 more rows
For a random forest workflow object, the same kind of information is printed using tidy() as with a parsnip object, though the values differ slightly because the two random forest fits are not identical.
tidy(mod_rf_fit_wflw)
# A tibble: 6 × 5
term MeanDecreaseAccuracy MeanDecreaseGini MeanDecreaseAccuracy…¹
<chr> <dbl> <dbl> <dbl>
1 island 0.00132 0.926 0.000377
2 bill_length_mm 0.0216 3.97 0.00124
3 bill_depth_mm 0.0141 5.07 0.00108
4 flipper_length_mm 0.0265 5.02 0.00126
5 body_mass_g 0.0481 9.56 0.00186
6 sex 0.221 49.4 0.00313
# ℹ abbreviated name: ¹MeanDecreaseAccuracy_sd
# ℹ 1 more variable: classwise_importance <list>
Similarly, for a random forest workflow object, glance() prints the same information as for a parsnip object.
glance(mod_rf_fit_wflw)
# A tibble: 2 × 5
class precision recall accuracy f_measure
<chr> <dbl> <dbl> <dbl> <dbl>
1 Adelie 0.855 0.839 0.853 0.847
2 Chinstrap 0.851 0.866 0.853 0.858
Once again, where things go wrong is with augment(). The recipe attached to the workflow filters out rows with missing values, and the original test data has missing values. The predictions are therefore made on fewer rows than are in the original data that augment() tries to bind them to, leading to the following error.
mod_rf_fit_wflw %>%
augment(penguins_test_tbl) %>%
select(species, .pred_class, .pred_Adelie, .pred_Chinstrap)
Error in `vctrs::vec_cbind()`:
! Can't recycle `..1` (size 104) to match `..2` (size 102).
Manually removing the rows with missing values solves this problem, and augment() works just as it did with the parsnip model object.
penguins_test_no_missing_tbl <- tidyr::drop_na(penguins_test_tbl)
mod_rf_fit_wflw %>%
augment(penguins_test_no_missing_tbl) %>%
select(species, .pred_class, .pred_Adelie, .pred_Chinstrap)
# A tibble: 102 × 4
species .pred_class .pred_Adelie .pred_Chinstrap
<fct> <fct> <dbl> <dbl>
1 Chinstrap Chinstrap 0.0128 0.987
2 Chinstrap Chinstrap 0.0695 0.930
3 Chinstrap Adelie 0.748 0.252
4 Adelie Chinstrap 0.0567 0.943
5 Chinstrap Chinstrap 0.0922 0.908
6 Chinstrap Chinstrap 0.0170 0.983
7 Adelie Chinstrap 0.417 0.583
8 Adelie Chinstrap 0.461 0.539
9 Chinstrap Chinstrap 0.0113 0.989
10 Adelie Adelie 0.847 0.153
# ℹ 92 more rows
Notes
This post is based on a presentation that was given on the date listed. It may be updated from time to time to fix errors, detail new functions, and/or remove deprecated functions, so the packages and R version will likely be newer than what was available at the time of the presentation.
The R session information used for this post:
R version 4.2.1 (2022-06-23)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS 14.1.1
Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics grDevices datasets utils methods base
other attached packages:
[1] broomstick_0.1.2.9200 workflows_1.1.2 yardstick_1.1.0
[4] tune_1.0.1 dials_1.1.0 scales_1.2.1
[7] parsnip_1.0.3 recipes_1.0.4 dplyr_1.1.3
[10] rsample_1.1.1 broom_1.0.5
loaded via a namespace (and not attached):
[1] Rcpp_1.0.9 lubridate_1.9.0 lattice_0.20-45
[4] tidyr_1.3.0 listenv_0.8.0 class_7.3-20
[7] foreach_1.5.2 digest_0.6.29 ipred_0.9-13
[10] utf8_1.2.3 parallelly_1.32.1 R6_2.5.1
[13] backports_1.4.1 hardhat_1.2.0 evaluate_0.16
[16] ggplot2_3.4.0 pillar_1.9.0 rlang_1.1.1
[19] rstudioapi_0.14 DiceDesign_1.9 furrr_0.3.1
[22] rpart_4.1.19 Matrix_1.4-1 rmarkdown_2.16
[25] splines_4.2.1 gower_1.0.1 stringr_1.5.0
[28] munsell_0.5.0 compiler_4.2.1 xfun_0.40
[31] pkgconfig_2.0.3 globals_0.16.2 htmltools_0.5.3
[34] nnet_7.3-17 tidyselect_1.2.0 tibble_3.2.1
[37] prodlim_2019.11.13 codetools_0.2-18 randomForest_4.7-1.1
[40] GPfit_1.0-8 fansi_1.0.4 future_1.29.0
[43] withr_2.5.1 MASS_7.3-57 grid_4.2.1
[46] jsonlite_1.8.0 gtable_0.3.1 lifecycle_1.0.3
[49] magrittr_2.0.3 future.apply_1.10.0 cli_3.6.1
[52] stringi_1.7.12 renv_0.16.0 timeDate_4022.108
[55] ellipsis_0.3.2 lhs_1.1.6 generics_0.1.3
[58] vctrs_0.6.3 lava_1.7.1 iterators_1.0.14
[61] tools_4.2.1 glue_1.6.2 purrr_1.0.2
[64] parallel_4.2.1 fastmap_1.1.0 survival_3.3-1
[67] yaml_2.3.5 timechange_0.1.1 colorspace_2.0-3
[70] knitr_1.40