
How to do nested cross-validation with LASSO in caret or tidymodels?

Nested cross-validation is a robust technique used for hyperparameter tuning and model selection. When working with complex models like LASSO (Least Absolute Shrinkage and Selection Operator), it becomes essential to understand how to implement nested cross-validation efficiently. In this article, we’ll explore the concept of nested cross-validation and how to implement it with LASSO using popular R packages, Caret and Tidymodels.

Understanding Nested Cross-Validation

Nested cross-validation is a technique for evaluating and tuning machine learning models that helps prevent overfitting and provides a more realistic estimate of a model’s performance on unseen data. It consists of two levels of cross-validation:

Outer loop: splits the data into training and test folds and estimates how well the final, tuned model generalizes to unseen data.
Inner loop: runs within each outer training fold and is used to tune the model's hyperparameters (for LASSO, the penalty lambda).



Nested cross-validation is particularly useful when you have a limited dataset or need to optimize hyperparameters to ensure model generalization.
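To make the two levels concrete, the rsample package (part of tidymodels) can declare both loops explicitly with nested_cv(). The following is only a minimal sketch of the resampling structure, assuming 5 outer and 3 inner folds on the mtcars data used later in this article; the worked examples below use caret's trainControl() and tidymodels' vfold_cv() instead.

# Minimal sketch of nested resamples (assumed: 5 outer folds, 3 inner folds)
library(rsample)
set.seed(123)
nested_folds <- nested_cv(
  mtcars,
  outside = vfold_cv(v = 5),  # outer loop: estimates generalization performance
  inside  = vfold_cv(v = 3)   # inner loop: used for hyperparameter tuning
)
nested_folds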

LASSO Regression

Lasso regression is a regularisation method. It is often preferred over ordinary regression techniques because it can produce more accurate predictions. The model relies on shrinkage: coefficient estimates are shrunk towards zero. Because the lasso technique encourages models with fewer parameters, the resulting models are simple and sparse. This type of regression is well suited when a model exhibits a high degree of multicollinearity, or when you wish to automate parts of the model selection process, such as variable selection and parameter elimination.



Lasso regression uses the L1 regularisation technique. Because it performs feature selection automatically, it is especially useful when a dataset contains many features.

Here’s a step-by-step view of how LASSO regression works:

1. Start from the ordinary least-squares objective, the residual sum of squares.
2. Add an L1 penalty: lambda multiplied by the sum of the absolute values of the coefficients.
3. Minimise the penalised objective; as lambda increases, coefficients shrink and some become exactly zero, removing those features from the model.
4. Choose lambda, typically by cross-validation, to balance model fit against sparsity.
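For reference, the objective that LASSO minimises can be written as follows, where $y_i$ are the observed responses, $x_{ij}$ the predictor values, $\beta_j$ the coefficients, and $\lambda \ge 0$ the penalty that controls how strongly coefficients are shrunk ($\lambda = 0$ recovers ordinary least squares):

$$\hat{\beta} = \underset{\beta}{\arg\min} \left\{ \sum_{i=1}^{n} \left( y_i - \beta_0 - \sum_{j=1}^{p} x_{ij}\beta_j \right)^2 + \lambda \sum_{j=1}^{p} \lvert \beta_j \rvert \right\}$$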

Why Use LASSO?

LASSO is a linear regression technique that adds a penalty term to the linear regression cost function. This penalty encourages the model to shrink some coefficients to exactly zero, effectively performing feature selection. LASSO is valuable when dealing with datasets with many features or when you suspect that some features are irrelevant.

Pre-Requisites

Before diving into nested cross-validation, make sure you have R installed along with the Caret and Tidymodels packages. You can install them using the following commands:
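# Install the packages used in this article
install.packages("caret")
install.packages("tidymodels")
install.packages("glmnet")   # engine that fits the LASSO model
install.packages("mlbench")  # provides the Sonar dataset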





Loading Libraries
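We load caret for model training, mlbench for the Sonar dataset, and glmnet, which supplies the LASSO implementation that caret calls:

# Load the required libraries
library(caret)    # model training and cross-validation
library(mlbench)  # Sonar dataset
library(glmnet)   # LASSO / elastic-net engine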





Load the dataset




# Load the Sonar dataset from the mlbench package
data(Sonar)

The dataset contains patterns obtained by bouncing sonar signals off objects at various angles and under various conditions. There are 208 patterns in all: 111 obtained by bouncing sonar signals off a metal cylinder and 97 obtained by bouncing signals off rocks. Each pattern is a set of 60 numbers (variables) taking values between 0 and 1, and the target variable Class records whether the object was the metal cylinder (M) or a rock (R).

Implementing Nested Cross-Validation with Caret

Here’s how you can perform nested cross-validation with LASSO using Caret:




# Define your control parameters for outer CV
ctrl <- trainControl(
  method = "cv",
  number = 5,
  summaryFunction = twoClassSummary,
  classProbs = TRUE,
  search = "grid"
)
 
# Define a hyperparameter grid for LASSO (alpha = 1)
grid <- expand.grid(
  alpha = 1,
  lambda = seq(0.001, 1, length = 10)
)
 
# Perform nested cross-validation
set.seed(123)
model <- train(
  Class ~ .,
  data = Sonar,
  method = "glmnet",
  metric = "ROC",
  trControl = ctrl,
  tuneGrid = grid
)
 
# Print the best hyperparameters
print(model$bestTune)

Output:

  alpha lambda
1     1  0.001

This code essentially performs nested cross-validation to find the best hyperparameters for a LASSO logistic regression model, using predefined control parameters and a hyperparameter grid. The goal is to identify the hyperparameters that yield the best classification performance.

Nested Cross-Validation with Tidymodels on mtcars Dataset

Tidymodels is another popular package for modeling and machine learning. Here’s how to perform nested cross-validation with LASSO using Tidymodels:

Loading Libraries

We start by loading the necessary libraries, primarily tidymodels, which provides a framework for modeling and machine learning.
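# Load the tidymodels framework (brings in recipes, parsnip, rsample, tune and yardstick)
library(tidymodels)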





Loading the Dataset




# Load the built-in mtcars dataset
data(mtcars)

Summary of the Dataset




# View summary statistics for the mtcars dataset
summary(mtcars)

Output:

      mpg             cyl             disp             hp             drat             wt             qsec      
 Min.   :10.40   Min.   :4.000   Min.   : 71.1   Min.   : 52.0   Min.   :2.760   Min.   :1.513   Min.   :14.50  
 1st Qu.:15.43   1st Qu.:4.000   1st Qu.:120.8   1st Qu.: 96.5   1st Qu.:3.080   1st Qu.:2.581   1st Qu.:16.89  
 Median :19.20   Median :6.000   Median :196.3   Median :123.0   Median :3.695   Median :3.325   Median :17.71  
 Mean   :20.09   Mean   :6.188   Mean   :230.7   Mean   :146.7   Mean   :3.597   Mean   :3.217   Mean   :17.85  
 3rd Qu.:22.80   3rd Qu.:8.000   3rd Qu.:326.0   3rd Qu.:180.0   3rd Qu.:3.920   3rd Qu.:3.610   3rd Qu.:18.90  
 Max.   :33.90   Max.   :8.000   Max.   :472.0   Max.   :335.0   Max.   :4.930   Max.   :5.424   Max.   :22.90  
       vs               am              gear            carb      
 Min.   :0.0000   Min.   :0.0000   Min.   :3.000   Min.   :1.000  
 1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:3.000   1st Qu.:2.000  
 Median :0.0000   Median :0.0000   Median :4.000   Median :2.000  
 Mean   :0.4375   Mean   :0.4062   Mean   :3.688   Mean   :2.812  
 3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:4.000   3rd Qu.:4.000  
 Max.   :1.0000   Max.   :1.0000   Max.   :5.000   Max.   :8.000  

Pre-processing the dataset
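The recipe object used by the workflow in the next step is not shown on this page; a minimal sketch that defines a `preprocess` recipe consistent with that code (mpg as the outcome, all other columns as predictors, numeric predictors centred and scaled) might look like this:

# Assumed preprocessing recipe: predict mpg from all other columns,
# normalising the numeric predictors
preprocess <- recipe(mpg ~ ., data = mtcars) %>%
  step_normalize(all_numeric_predictors())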





Creating Lasso object and workflow




# Create a model specification for Lasso regression
lasso <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
 
# Create a workflow that includes the recipe and model specification
lasso_workflow <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(lasso)

Implementing Nested Cross-Validation with Tidymodels




# Define a grid of hyperparameters to search
hypergrid <- expand.grid(
  penalty = seq(0.001, 1, length.out = 10)
)
 
# Set up a cross-validation resampling method
# 10-fold cross-validation
cv <- vfold_cv(mtcars, v = 10) 
 
# Tune the Lasso model using cross-validation
lasso_t <- tune_grid(
  object = lasso_workflow,
  resamples = cv,
  grid = hypergrid,
  metrics = metric_set(rmse, mae, rsq)
)
 
# Get the best Lasso model (lowest RMSE)
lasso_best <- select_best(lasso_t, metric = "rmse")
 
# Print the best Lasso model
lasso_best

Output:

# A tibble: 1 × 2
  penalty .config              
    <dbl> <chr>                
1   0.667 Preprocessor1_Model07

Conclusion

Implementing nested cross-validation with LASSO in R using Caret or Tidymodels can help you build robust and accurate predictive models. It ensures that your model generalizes well to unseen data and helps you select the best hyperparameters for your LASSO model.

By following the steps outlined in this article, you can confidently apply nested cross-validation to your machine learning projects, ultimately improving the reliability and performance of your models.

