Managing Workflowset Models

Using option_add to Tune Different Models in a Workflow Set

Categories: Machine Learning, Tuning
Published: June 11, 2024

Introduction

The tidymodels framework is a collection of R packages that provides a streamlined and consistent approach to modeling. Built on the tidyverse foundation, it offers a cohesive path from data wrangling to robust models. What makes tidymodels stand out is its consistent interface, which shortens the learning curve for data scientists and keeps code compatible across different modeling packages (Kuhn and Silge 2022).

Workflow

The workflows package is one of the standout components of tidymodels, making the iterative machine learning process in R more manageable. By bundling model fitting and data preprocessing steps into a single coherent object, workflows simplifies the machine learning pipeline and ensures each step is clearly defined and reproducible. This iterative machine learning process, as covered in “Tidy Modeling with R” (Kuhn and Silge 2022), is illustrated below:

Source: Tidy Modeling with R
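As a minimal sketch of that bundling, the example below pairs a recipe with a model specification in a single workflow and fits both in one call. The built-in mtcars data, the predictors wt and hp, and all object names are assumptions for illustration; they are not part of this post's analysis.

```r
library(tidymodels)

# Preprocessing half of the workflow: a recipe that normalizes predictors
car_rec <- recipe(mpg ~ wt + hp, data = mtcars) |>
  step_normalize(all_numeric_predictors())

# Modeling half of the workflow: an ordinary linear regression
lm_spec <- linear_reg() |>
  set_engine("lm")

# Bundle both pieces into one object
car_wf <- workflow() |>
  add_recipe(car_rec) |>
  add_model(lm_spec)

# Fitting the workflow estimates the recipe and the model together
car_fit <- fit(car_wf, data = mtcars)

# Predicting applies the stored preprocessing before the model
car_preds <- predict(car_fit, new_data = head(mtcars, 3))
car_preds
```

Because the recipe travels with the model, new data passed to predict() does not need to be preprocessed by hand.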

Workflowsets

The focus of this post, the workflowsets package, builds on the workflows package by extending its capabilities to handle multiple machine learning models. Since the best model for any given task is not predetermined, it’s crucial to test multiple models and compare their performances. workflowsets is designed to manage multiple workflows, making it easier to compare different modeling approaches and preprocessing strategies.

This blog post introduces the option_add function of the workflowsets package, which is used to control options for evaluating workflow set functions such as fit_resamples and tune_grid. For more information on this function, refer to the documentation with ?option_add.

We start by loading the packages used in this post:

library(pacman)
p_load(tidyverse, tidymodels, gt, finetune, bonsai)

For this post we’ll use the heart disease dataset from kaggle.com. A preview of the data is given in Table 1.

heart_disease <- read_csv("heart_disease_dataset.csv")

head(heart_disease) |> 
  gt() |> 
  tab_header(
    title = "Heart Diseases"
  ) |> 
  opt_stylize(
    style = 2, 
    color = "cyan"
  ) |> 
  as_raw_html()
Table 1: Data Preview
Heart Diseases
Age | Gender | Cholesterol | Blood Pressure | Heart Rate | Smoking | Alcohol Intake | Exercise Hours | Family History | Diabetes | Obesity | Stress Level | Blood Sugar | Exercise Induced Angina | Chest Pain Type | Heart Disease
75 | Female | 228 | 119 | 66 | Current | Heavy | 1 | No | No | Yes | 8 | 119 | Yes | Atypical Angina | 1
48 | Male | 204 | 165 | 62 | Current | None | 5 | No | No | No | 9 | 70 | Yes | Typical Angina | 0
53 | Male | 234 | 91 | 67 | Never | Heavy | 3 | Yes | No | Yes | 5 | 196 | Yes | Atypical Angina | 1
69 | Female | 192 | 90 | 72 | Current | None | 4 | No | Yes | No | 7 | 107 | Yes | Non-anginal Pain | 0
62 | Female | 172 | 163 | 93 | Never | None | 6 | No | Yes | No | 2 | 183 | Yes | Asymptomatic | 0
77 | Male | 309 | 110 | 73 | Never | None | 0 | No | Yes | Yes | 4 | 122 | Yes | Asymptomatic | 1

Short EDA

skimr::skim_without_charts(heart_disease) |> 
  gt() |> 
  tab_spanner(
    label = "Character",
    columns = character.min:character.whitespace
  ) |> 
  tab_spanner(
    label = "Numeric",
    columns = starts_with("numeric")
  ) |> 
  cols_label(
    skim_type ~ "Type",
    skim_variable ~"Variable",
    n_missing ~ "Missing?",
    complete_rate ~ "Complete?",
    character.min ~ "Min",
    character.max ~ "Max",
    character.empty ~ "Empty",
    character.n_unique ~ "Unique",
    character.whitespace ~ "Gap",
    numeric.mean ~ "Mean",
    numeric.sd ~ "SD",
    numeric.p0 ~ "Min",
    numeric.p25 ~ "25%",
    numeric.p50 ~ "Median",
    numeric.p75 ~ "75%",
    numeric.p100 ~ "Max"
  ) |> 
  cols_width(
    skim_type ~ px(80),
    everything() ~ px(70)
  ) |> 
  opt_stylize(
    style = 2,
    color = "cyan"
  ) |> 
  as_raw_html()
Table 2: Quick description of the data
Type | Variable | Missing? | Complete? | Character Min | Character Max | Character Empty | Character Unique | Character Gap | Numeric Mean | Numeric SD | Numeric Min | Numeric 25% | Numeric Median | Numeric 75% | Numeric Max
character | Gender | 0 | 1 | 4 | 6 | 0 | 2 | 0 | NA | NA | NA | NA | NA | NA | NA
character | Smoking | 0 | 1 | 5 | 7 | 0 | 3 | 0 | NA | NA | NA | NA | NA | NA | NA
character | Alcohol Intake | 0 | 1 | 4 | 8 | 0 | 3 | 0 | NA | NA | NA | NA | NA | NA | NA
character | Family History | 0 | 1 | 2 | 3 | 0 | 2 | 0 | NA | NA | NA | NA | NA | NA | NA
character | Diabetes | 0 | 1 | 2 | 3 | 0 | 2 | 0 | NA | NA | NA | NA | NA | NA | NA
character | Obesity | 0 | 1 | 2 | 3 | 0 | 2 | 0 | NA | NA | NA | NA | NA | NA | NA
character | Exercise Induced Angina | 0 | 1 | 2 | 3 | 0 | 2 | 0 | NA | NA | NA | NA | NA | NA | NA
character | Chest Pain Type | 0 | 1 | 12 | 16 | 0 | 4 | 0 | NA | NA | NA | NA | NA | NA | NA
numeric | Age | 0 | 1 | NA | NA | NA | NA | NA | 52.293 | 15.727126 | 25 | 39.00 | 52.0 | 66 | 79
numeric | Cholesterol | 0 | 1 | NA | NA | NA | NA | NA | 249.939 | 57.914673 | 150 | 200.00 | 248.0 | 299 | 349
numeric | Blood Pressure | 0 | 1 | NA | NA | NA | NA | NA | 135.281 | 26.388300 | 90 | 112.75 | 136.0 | 159 | 179
numeric | Heart Rate | 0 | 1 | NA | NA | NA | NA | NA | 79.204 | 11.486092 | 60 | 70.00 | 79.0 | 89 | 99
numeric | Exercise Hours | 0 | 1 | NA | NA | NA | NA | NA | 4.529 | 2.934241 | 0 | 2.00 | 4.5 | 7 | 9
numeric | Stress Level | 0 | 1 | NA | NA | NA | NA | NA | 5.646 | 2.831024 | 1 | 3.00 | 6.0 | 8 | 10
numeric | Blood Sugar | 0 | 1 | NA | NA | NA | NA | NA | 134.941 | 36.699624 | 70 | 104.00 | 135.0 | 167 | 199
numeric | Heart Disease | 0 | 1 | NA | NA | NA | NA | NA | 0.392 | 0.488441 | 0 | 0.00 | 0.0 | 1 | 1

Table 2 shows there are no missing values, so we can proceed with our analysis.

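Next, we convert all character variables to factors. We also treat exercise_hours and stress_level as factors and relabel the outcome heart_disease as “No” and “Yes”.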

heart_diseases <- heart_disease |> 
  janitor::clean_names() |> 
  mutate(
    across(where(is.character), factor),
    exercise_hours = factor(exercise_hours),
    stress_level = factor(stress_level),
    heart_disease = factor(
      heart_disease, 
      labels = c("No","Yes"),
      levels = c(0, 1)
    )
  )
GGally::ggscatmat(
  data = heart_diseases,
  columns = 1:ncol(heart_diseases),
  color = "heart_disease",
  alpha = .3
)
Figure 1: Scatterplot matrix of the variables
GGally::ggcorr(
  data = heart_diseases,
  columns = 1:ncol(heart_diseases),
  name = expression(rho),
  geom = "circle",
  size = 3,
  min_size = 5,
  max_size = 10,
  angle = -45
) +
  ggtitle("Correlation Plot of Numeric Variables")
Figure 2: Correlation plot of numeric variables
heart_diseases |> 
  ggplot(aes(heart_disease, fill = gender)) +
  geom_bar(position = "dodge") +
  labs(
    x = "Heart disease",
    y = "Frequency",
    title = "Heart disease is slightly more prevalent in males than in females"
  ) +
  ggthemes::scale_fill_fivethirtyeight()
Figure 3: Frequency of Heart Disease Outcome

We won’t spend more time on EDA and instead proceed with our modeling workflow.

Modeling

Data Splitting

We will split our data into 75% for training and 25% for testing, using the outcome variable (heart_disease) as the strata to ensure a balanced split. Additionally, we will create cross-validation folds from the training set (vfold_cv uses 10 folds by default) to evaluate the models.

set.seed(832)
hd_split <- initial_split(heart_diseases, prop = .75, strata = heart_disease)

hd_train <- training(hd_split)
hd_folds <- vfold_cv(hd_train)

head(hd_train) |> 
  gt() |> 
  opt_stylize(
    style = 2,
    color = "cyan"
  ) |> 
  as_raw_html()
age gender cholesterol blood_pressure heart_rate smoking alcohol_intake exercise_hours family_history diabetes obesity stress_level blood_sugar exercise_induced_angina chest_pain_type heart_disease
48 Male 204 165 62 Current None 5 No No No 9 70 Yes Typical Angina No
62 Female 172 163 93 Never None 6 No Yes No 2 183 Yes Asymptomatic No
37 Female 317 137 66 Current Heavy 3 No Yes Yes 5 114 No Non-anginal Pain No
43 Male 155 169 82 Current Heavy 8 Yes Yes No 2 163 No Typical Angina No
44 Female 250 111 66 Former None 6 Yes No Yes 3 121 Yes Non-anginal Pain No
43 Female 279 173 81 Current Moderate 9 Yes No No 7 150 No Asymptomatic No
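To see what stratification does, the sketch below checks class proportions after a stratified split. It uses a small synthetic data set (the toy tibble and its 60/40 class balance are assumptions for illustration), so it runs without the heart disease file.

```r
library(tidymodels)

set.seed(832)

# Hypothetical data: 600 "No" and 400 "Yes" outcomes
toy <- tibble(
  outcome = factor(rep(c("No", "Yes"), times = c(600, 400))),
  x = rnorm(1000)
)

# Same call pattern as hd_split above, applied to the toy data
toy_split <- initial_split(toy, prop = 0.75, strata = outcome)

# Share of each class in the training and testing sets
train_share <- prop.table(table(training(toy_split)$outcome))
test_share <- prop.table(table(testing(toy_split)$outcome))

train_share
test_share
```

Both sets should stay close to the original 60/40 balance. The same check can be run on training(hd_split) and testing(hd_split).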

Model Specification

We will use two models for our analysis:

  • K-nearest neighbors (KNN) model

  • Generalized linear model (GLM), fitted here as a logistic regression

knn_spec <- nearest_neighbor(
  neighbors = tune(),
  weight_func = tune(),
  dist_power = tune()
) |> 
  set_engine("kknn") |> 
  set_mode("classification")

glm_spec <- logistic_reg() |> 
  set_engine("glm", family = stats::binomial(link = "logit")) |> 
  set_mode("classification")

Below is the specification we have set for the KNN model:

knn_spec |> translate()
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune()
  weight_func = tune()
  dist_power = tune()

Computational engine: kknn 

Model fit template:
kknn::train.kknn(formula = missing_arg(), data = missing_arg(), 
    ks = min_rows(tune(), data, 5), kernel = tune(), distance = tune())

The KNN specification has three tuning parameters: neighbors, weight_func, and dist_power. For the GLM model we have the following:

glm_spec |> translate()
Logistic Regression Model Specification (classification)

Engine-Specific Arguments:
  family = stats::binomial(link = "logit")

Computational engine: glm 

Model fit template:
stats::glm(formula = missing_arg(), data = missing_arg(), weights = missing_arg(), 
    family = stats::binomial(link = "logit"))

The GLM specification has no tuning parameters.

As seen in the model specifications above, the formula is missing. In the next step we define the formula for all models, together with the preprocessing and feature engineering steps we want to include, using the recipes package.

Data Preprocessing

We have three preprocessing specifications. The first defines only the formula, the second adds normalization of all numeric predictors, and the third builds on the second by also creating dummy variables for our categorical predictors.

formula <- recipe(
  heart_disease ~ .,
  data = hd_train
)

normalize <- formula |> 
  step_normalize(all_numeric_predictors())

dummy <- normalize |> 
  step_dummy(all_factor_predictors())
normalize |> 
  prep() |> 
  juice() |> 
  head() |> 
  gt() |> 
  opt_stylize(
    style = 3,
    color = "cyan"
  )
Table 3: Preview of normalized preprocessed data
age gender cholesterol blood_pressure heart_rate smoking alcohol_intake exercise_hours family_history diabetes obesity stress_level blood_sugar exercise_induced_angina chest_pain_type heart_disease
-0.2983420 Male -0.784268085 1.13748976 -1.4962849 Current None 5 No No No 9 -1.7689353 Yes Typical Angina No
0.6020865 Female -1.332244271 1.06281212 1.2243394 Never None 6 No Yes No 2 1.3197224 Yes Asymptomatic No
-1.0058214 Female 1.150772824 0.09200285 -1.1452366 Current Heavy 3 No Yes Yes 5 -0.5662721 No Non-anginal Pain No
-0.6199235 Male -1.623356620 1.28684503 0.2589566 Current Heavy 8 Yes Yes No 2 0.7730573 No Typical Angina No
-0.5556072 Female 0.003447684 -0.87880642 -1.1452366 Former None 6 Yes No Yes 3 -0.3749394 Yes Non-anginal Pain No
-0.6199235 Female 0.500051103 1.43620030 0.1711946 Current Moderate 9 Yes No No 7 0.4177250 No Asymptomatic No

Table 3 previews the data after normalizing, which is the second feature engineering specification. Table 4 shows the data after additionally creating dummy variables for the categorical variables.

dummy |> 
  prep() |> 
  juice() |> 
  head() |> 
  gt() |> 
  opt_stylize(
    style = 2,
    color = "cyan"
  ) |> 
  as_raw_html()
Table 4: Preview of dummy + normalized preprocessed data
age cholesterol blood_pressure heart_rate blood_sugar heart_disease gender_Male smoking_Former smoking_Never alcohol_intake_Moderate alcohol_intake_None exercise_hours_X1 exercise_hours_X2 exercise_hours_X3 exercise_hours_X4 exercise_hours_X5 exercise_hours_X6 exercise_hours_X7 exercise_hours_X8 exercise_hours_X9 family_history_Yes diabetes_Yes obesity_Yes stress_level_X2 stress_level_X3 stress_level_X4 stress_level_X5 stress_level_X6 stress_level_X7 stress_level_X8 stress_level_X9 stress_level_X10 exercise_induced_angina_Yes chest_pain_type_Atypical.Angina chest_pain_type_Non.anginal.Pain chest_pain_type_Typical.Angina
-0.2983420 -0.784268085 1.13748976 -1.4962849 -1.7689353 No 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1
0.6020865 -1.332244271 1.06281212 1.2243394 1.3197224 No 0 0 1 0 1 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 1 0 0 0
-1.0058214 1.150772824 0.09200285 -1.1452366 -0.5662721 No 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 1 0
-0.6199235 -1.623356620 1.28684503 0.2589566 0.7730573 No 1 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 1
-0.5556072 0.003447684 -0.87880642 -1.1452366 -0.3749394 No 0 1 0 0 1 0 0 0 0 0 1 0 0 0 1 0 1 0 1 0 0 0 0 0 0 0 1 0 1 0
-0.6199235 0.500051103 1.43620030 0.1711946 0.4177250 No 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
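The previews above use juice(), which still works but is marked as superseded in recipes in favor of bake(new_data = NULL). The sketch below shows the equivalent call on a small made-up tibble and also applies the trained recipe to unseen rows. The toy and new_rows objects and their column names are assumptions for illustration.

```r
library(tidymodels)

# Hypothetical training data with one numeric and one factor predictor
toy <- tibble(
  y = factor(c("No", "Yes", "No", "Yes")),
  x = c(1, 2, 3, 4),
  g = factor(c("a", "b", "a", "b"))
)

# Same two steps as the `dummy` recipe above
toy_rec <- recipe(y ~ ., data = toy) |>
  step_normalize(all_numeric_predictors()) |>
  step_dummy(all_factor_predictors())

# prep() estimates the means, standard deviations, and factor levels
toy_prepped <- prep(toy_rec)

# Processed training data; the modern spelling of juice(toy_prepped)
train_baked <- bake(toy_prepped, new_data = NULL)
train_baked

# Unseen rows are transformed with the training-set statistics
new_rows <- tibble(
  x = c(2.5, 10),
  g = factor(c("a", "b"), levels = c("a", "b"))
)
new_baked <- bake(toy_prepped, new_data = new_rows)
new_baked
```

Inside a workflow these prep and bake steps happen automatically, so calling them by hand is mainly useful for inspecting what a recipe produces.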

Model Workflow Set

hd_wf_set <- workflow_set(
  preproc = list(
    form = formula,
    norm = normalize,
    dum = dummy
  ),
  models = list(
    glm = glm_spec,
    knn = knn_spec
  )
)

Tuning Parameter

Using the workflow_set function, we’ve crossed the three recipe objects with the two models, giving six workflows. Of these, only the K-nearest neighbor workflows need tuning, as mentioned earlier.

set.seed(34443)

knn_grid <- knn_spec |> 
  extract_parameter_set_dials() |> 
  grid_regular(levels = 6)

knn_latin <- knn_spec |> 
  extract_parameter_set_dials() |> 
  grid_latin_hypercube(size = 300)
Warning: `grid_latin_hypercube()` was deprecated in dials 1.3.0.
ℹ Please use `grid_space_filling()` instead.
grid_control <- control_race(
  save_pred = TRUE,
  save_workflow = TRUE
)

knn_grid |> 
  ggplot(aes(dist_power, neighbors, col = weight_func)) +
  geom_point() +
  ggthemes::scale_color_colorblind() +
  labs(
    x = "Minkowski distance",
    y = "Number of Neighbors",
    title = "k-NN Regular Grid"
  ) +
  facet_wrap(~weight_func) +
  theme(
    legend.position = "none"
  )
  
knn_latin |>
  ggplot(aes(dist_power, neighbors, col = weight_func)) +
  geom_point() +
  ggthemes::scale_color_tableau() +
  labs(
    x = "Minkowski distance",
    y = "Number of Neighbors",
    title = "k-NN Latin Hypercube Grid"
  ) +
  facet_wrap(~weight_func) +
  theme(
    legend.position = "none"
  )
(a) k-NN regular tuning grid
(b) k-NN Latin hypercube tuning grid
Figure 4: Tuning grids to be used for K-nearest neighbor model specification

We set the tuning grid for the model and use the option_add function to specify it. We will test two different grid structures as shown in Figure 4.
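The deprecation warning printed earlier points to grid_space_filling() as the replacement for grid_latin_hypercube(). A minimal sketch of the equivalent call follows, assuming dials 1.3.0 or later. It passes dials parameter objects directly, whose default ranges may differ from the ones parsnip assigns for the kknn engine, so treat the resulting grid as illustrative rather than a drop-in copy of knn_latin.

```r
library(tidymodels)

set.seed(34443)

# Space-filling grid over the three KNN tuning parameters,
# requesting a Latin hypercube design explicitly
knn_space <- grid_space_filling(
  neighbors(),
  weight_func(),
  dist_power(),
  size = 300,
  type = "latin_hypercube"
)

knn_space
```

To keep the ranges used in this post, the parameter set from knn_spec |> extract_parameter_set_dials() can be passed as the first argument instead of the individual parameter objects.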

Using option_add to Specify Model Grids

We can specify the grid to use for each model using the option_add function. Below is an image of the hd_wf_set object we defined earlier; we interpret its output next.

Defined workflowset output

The image above shows that the option column holds no options yet and that the result column is also empty.

hd_tune <- hd_wf_set |> 
  option_add(
    id = "norm_knn",
    grid = knn_grid,
    control = grid_control
  ) |> 
  option_add(
    id = "form_knn",
    grid = knn_grid,
    control = grid_control
  ) |> 
  option_add(
    id = "norm_knn",
    grid = knn_latin,
    control = grid_control
  ) |> 
  option_add(
    id = "form_knn",
    grid = knn_latin,
    control = grid_control
  ) |> 
  option_add(
    id = "dum_knn",
    grid = knn_grid,
    control = grid_control
  ) |> 
  option_add(
    id = "dum_knn",
    grid = knn_latin,
    control = grid_control
  )

Defined workflowset output after options are added

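After using the option_add function, we can see that each KNN workflow has two options attached: grid and control. Note that option_add was called twice per KNN workflow with the same option names. As far as we can tell from the option_add documentation, a later call replaces an existing option of the same name (with a warning), so the tuning below appears to run on the knn_latin grid rather than on both grids. We can now proceed to tune our models.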
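Calling option_add twice with the same option name for the same workflow deserves a closer look. The sketch below checks on a toy workflow set which value survives, and what strict = TRUE does. The mtcars data, the decision tree stand-in model, and the grid_a and grid_b names are assumptions for illustration; the expectation that the later value wins is our reading of the option_add documentation, not something this post verified on hd_tune.

```r
library(tidymodels)

# Stand-in tunable model (rpart ships with R, so nothing extra to install)
tree_spec <- decision_tree(cost_complexity = tune()) |>
  set_engine("rpart") |>
  set_mode("regression")

toy_set <- workflow_set(
  preproc = list(basic = mpg ~ wt + hp),
  models = list(tree = tree_spec)
)

# Two hypothetical grids for the same parameter
grid_a <- tibble(cost_complexity = c(0.001, 0.01))
grid_b <- tibble(cost_complexity = c(0.05, 0.1, 0.2))

# First call attaches grid_a
toy_set <- option_add(toy_set, id = "basic_tree", grid = grid_a)

# Second call uses the same option name; the warning about modifying an
# existing option is suppressed here to keep the output short
toy_set <- suppressWarnings(
  option_add(toy_set, id = "basic_tree", grid = grid_b)
)

# Look at what the workflow now carries
stored <- toy_set$option[[which(toy_set$wflow_id == "basic_tree")]]
names(stored)
identical(stored$grid, grid_b)

# With strict = TRUE, replacing an existing option should stop with an error
strict_result <- tryCatch(
  option_add(toy_set, id = "basic_tree", grid = grid_a, strict = TRUE),
  error = function(e) "error"
)
strict_result
```

If the same behaviour applies to hd_tune, only one grid per KNN workflow reaches the tuning function. Comparing both grids would then need separate workflow sets, or separate workflow_map runs, one per grid.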

doParallel::registerDoParallel(cores = 6)

hd_tune_res <- workflow_map(
  hd_tune,
  fn = "tune_race_anova",
  resamples = hd_folds,
  seed = 3343,
  verbose = TRUE
)
i   No tuning parameters. `fit_resamples()` will be attempted
i 1 of 6 resampling: form_glm
✔ 1 of 6 resampling: form_glm (853ms)
i 2 of 6 tuning:     form_knn
✔ 2 of 6 tuning:     form_knn (1m 40.2s)
i   No tuning parameters. `fit_resamples()` will be attempted
i 3 of 6 resampling: norm_glm
✔ 3 of 6 resampling: norm_glm (888ms)
i 4 of 6 tuning:     norm_knn
✔ 4 of 6 tuning:     norm_knn (1m 43.1s)
i   No tuning parameters. `fit_resamples()` will be attempted
i 5 of 6 resampling: dum_glm
✔ 5 of 6 resampling: dum_glm (1s)
i 6 of 6 tuning:     dum_knn
✔ 6 of 6 tuning:     dum_knn (2m 21.9s)

Tune Result

autoplot(hd_tune_res)
Figure 5
hd_tune_res |> 
  rank_results(rank_metric = "accuracy") |> 
  filter(.metric == "accuracy") |> 
  select(-c(.metric,  preprocessor, model, n)) |> 
  gt() |> 
  cols_label(
    wflow_id = "Model ID",
    .config = "Model Number"
  ) |> 
  opt_stylize(
    style = 2,
    color = "cyan"
  ) |> 
  as_raw_html()
Model ID Model Number mean std_err rank
form_knn Preprocessor1_Model236 0.8653333 0.013546445 1
norm_knn Preprocessor1_Model236 0.8653333 0.013546445 2
form_glm Preprocessor1_Model1 0.8613333 0.013799266 3
norm_glm Preprocessor1_Model1 0.8613333 0.013799266 4
dum_glm Preprocessor1_Model1 0.8613333 0.013799266 5
norm_knn Preprocessor1_Model134 0.8520000 0.012000000 6
form_knn Preprocessor1_Model134 0.8520000 0.012000000 7
dum_knn Preprocessor1_Model134 0.7426667 0.008899993 8
dum_knn Preprocessor1_Model137 0.7413333 0.009573626 9
dum_knn Preprocessor1_Model129 0.7386667 0.011958777 10
dum_knn Preprocessor1_Model103 0.7360000 0.010850272 11
dum_knn Preprocessor1_Model222 0.7333333 0.015267168 12
dum_knn Preprocessor1_Model106 0.7333333 0.008663817 13
dum_knn Preprocessor1_Model221 0.7226667 0.013303671 14

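Based on the results, the KNN model with the formula-only recipe (no additional preprocessing) ranks first with a mean accuracy of about 0.865. The normalized KNN workflow reaches the same accuracy, and the three GLM workflows follow closely at about 0.861, while KNN on the dummy-encoded predictors performs markedly worse at about 0.74.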
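The gap between the top KNN result and the GLM can be put in context with the standard errors reported in the ranking table. A small arithmetic check, using values copied from that table (the variable names are ours):

```r
# Values copied from the ranking table above
knn_mean <- 0.8653333
knn_se <- 0.013546445
glm_mean <- 0.8613333
glm_se <- 0.013799266

# Difference in mean accuracy between the best KNN and the GLM
accuracy_gap <- knn_mean - glm_mean
accuracy_gap

# Is the gap smaller than one standard error of either estimate?
accuracy_gap < min(knn_se, glm_se)
```

The gap of about 0.004 is well inside one standard error, so the resampling results do not clearly separate the two models.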

Conclusion

The KNN workflow without additional preprocessing came out on top, and the option_add function was central to getting there. With option_add we attached a tuning grid and control settings to each KNN workflow in the set, which let us explore hyperparameters systematically while the GLM workflows were simply resampled. Keeping these options inside the workflow set makes it straightforward to tune and compare many preprocessing and model combinations in a single pipeline.

References

Kuhn, M., and J. Silge. 2022. Tidy Modeling with R. O’Reilly Media. https://books.google.dk/books?id=98J6EAAAQBAJ.