
Cross validation in R without caret package

Last Updated : 30 Jan, 2023

Cross-validation is a technique for evaluating the performance of a machine learning model by repeatedly training it on one part of the data and evaluating it on the held-out remainder, then averaging the results. It is useful for estimating a model's performance when you don't have a separate test set, or when you want a more reliable estimate than a single train/test split would provide.

There are several ways to perform cross-validation in R:

  1. Manually splitting the data into folds: Write your own code to split the row indices into the desired number of folds (base R's sample() and split() are enough). You can then loop through the folds, using each one as the test set and the remaining folds as the training set, and evaluate the model on each test set.
  2. Using the train function from the caret package: The caret package provides a convenient function, train, that can be used to train and evaluate machine learning models using a variety of cross-validation methods (configured through trainControl). It also provides functions for generating the folds, such as createFolds and createMultiFolds.
  3. Using the cv.glm function from the boot package: The boot package provides a function, cv.glm, that performs K-fold cross-validation for generalized linear models fitted with glm. It lets you specify the number of folds (K) and a cost function used to measure prediction error; the observations are assigned to folds at random, without stratification.
  4. Cross-validating linear models with the boot package: The boot package has no cv.lm function, but a model fitted with glm() and its default gaussian family is an ordinary linear regression, so it can be passed to cv.glm in the same way.
  5. Using the resample function from the mlr package: The mlr package has no kfoldCV function. Instead, a resampling strategy is described with makeResampleDesc (for example, makeResampleDesc("CV", iters = 10, stratify = TRUE) for stratified 10-fold cross-validation) and passed, together with a learner and a task, to the resample function.
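As a sketch of items 3 and 4, the code below fits a linear regression with glm() on R's built-in mtcars data and estimates its prediction error with cv.glm. The formula mpg ~ wt + hp and the choice of K = 8 (which splits the 32 rows into 8 folds of 4) are illustrative assumptions, not part of the original example.

```r
# Sketch: K-fold cross-validation of a linear model with boot::cv.glm
library(boot)

set.seed(1)  # cv.glm assigns observations to folds at random

# glm() with the default gaussian family fits an ordinary linear model
fit <- glm(mpg ~ wt + hp, data = mtcars)

# K = 8 splits the 32 rows into 8 folds of 4 rows each;
# the default cost function is the mean squared error
cv_res <- cv.glm(mtcars, fit, K = 8)

# delta[1]: raw cross-validation estimate of prediction error
# delta[2]: bias-adjusted cross-validation estimate
cv_res$delta
```

A lower delta means a smaller estimated prediction error, so the same call can be used to compare alternative formulas on the same data.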

Cross-validation in R without caret package

There are several ways to perform cross-validation on datasets in the R programming language. In this example, we write the cross-validation loop ourselves and use the e1071 library to train a support vector machine (SVM) with its svm() function. The tidyverse library is not necessary for cross-validation: the folds can be created with base R alone.

R

# Load the required library: e1071 provides svm()
library(e1071)


Randomly split the rows of the data into 10 folds using base R. Shuffling the row indices with sample() and then dealing them out to fold labels 1 to 10 with split() returns a list of 10 vectors, where each vector contains the indices of the rows that belong to one fold. Setting a seed makes the split reproducible. Initialize an empty vector to store the results of the cross-validation.

R

# Make the random fold assignment reproducible
set.seed(123)

# Shuffle the row indices, then deal them out to folds 1, 2, ..., 10;
# the result is a list of 10 integer vectors of row indices
n <- nrow(iris)
folds <- split(sample(seq_len(n)), rep(1:10, length.out = n))

# Initialize a vector to store the accuracy of each fold
results <- c()
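Splitting shuffled row indices at random ignores the class labels, so a fold can end up with more of one species than another. caret's createFolds() balances the classes across folds (stratification) when given a factor; the following is a base R sketch of the same idea. make_stratified_folds is a hypothetical helper name, not a function from any package. It returns a list of row-index vectors in the same format as folds, so it can be used in its place.

```r
# Hypothetical helper (not from any package): stratified k-fold assignment
make_stratified_folds <- function(y, k = 10) {
  fold_id <- integer(length(y))
  # Row indices grouped by class label
  by_class <- split(seq_along(y), y)
  for (idx in by_class) {
    # Cycle through the fold labels (in a random order) for this class...
    labels <- rep(sample.int(k), length.out = length(idx))
    # ...then shuffle them so that rows are assigned to random folds
    fold_id[idx] <- labels[sample.int(length(labels))]
  }
  # One vector of row indices per fold
  split(seq_along(y), fold_id)
}

set.seed(123)
strat_folds <- make_stratified_folds(iris$Species, k = 10)

# Each of the 10 folds holds 5 rows of each of the 3 species
sapply(strat_folds, function(f) table(iris$Species[f]))
```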


Loop through the folds. In each iteration, split the data into training and test sets using the indices in the current fold, train an SVM on the training data with the svm() function from e1071, and make predictions on the test data with predict(). Calculate the accuracy by comparing the predictions to the true labels and taking the mean of the resulting logical vector, which gives the share of correct predictions. Append the accuracy to the results vector.

R

# Loop through the folds: each fold is the test set exactly once
for (i in seq_along(folds)) {
  # Split the data into training and test sets
  train <- iris[-folds[[i]], ]
  test <- iris[folds[[i]], ]
  # Train the model on the training data
  model <- svm(Species ~ ., data = train)
  # Make predictions on the test data
  predictions <- predict(model, test)
  # Calculate the accuracy (share of correct predictions)
  accuracy <- mean(predictions == test$Species)
  # Store the result
  results <- c(results, accuracy)
}
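The same loop can be wrapped in a function so that it is easy to rerun on identical folds with different model settings. The sketch below assumes a hypothetical helper, cv_svm_accuracy(), that returns one accuracy per fold, and compares a few values of svm()'s cost parameter (0.1, 1 and 10 are arbitrary illustrative choices). Because every setting is scored on the same folds, differences come from the models rather than from the split.

```r
library(e1071)

# Hypothetical helper: per-fold accuracy of an SVM for a given cost value.
# Assumes the outcome column is named Species, as in iris.
cv_svm_accuracy <- function(data, folds, cost = 1) {
  vapply(folds, function(test_idx) {
    train <- data[-test_idx, ]
    test <- data[test_idx, ]
    model <- svm(Species ~ ., data = train, cost = cost)
    mean(predict(model, test) == test$Species)
  }, numeric(1))
}

# Random 10-fold split of the row indices, as in the main example
set.seed(123)
n <- nrow(iris)
folds <- split(sample(seq_len(n)), rep(1:10, length.out = n))

# Mean cross-validated accuracy for each cost value, on the same folds
costs <- c(0.1, 1, 10)
mean_acc <- sapply(costs, function(C) mean(cv_svm_accuracy(iris, folds, cost = C)))
names(mean_acc) <- costs
mean_acc
```

If you pick the best cost this way, its cross-validated accuracy is slightly optimistic, because the same folds were used both to choose the setting and to score it; nested cross-validation avoids this.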


After the loop finishes, calculate the mean accuracy across the folds with mean().

R

# Average accuracy across the 10 folds
mean(results)


Output (the exact value depends on the random assignment of rows to folds):

0.933333333333333 

The output is the mean accuracy across the 10 folds. Here it is about 0.933, which indicates that the model correctly classified roughly 93% of the test samples on average across all folds. To see how much the accuracy varies from fold to fold, you can also compute the standard deviation with sd(results).
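For a compact report, the per-fold accuracies can be reduced to their mean, standard deviation and a rough standard error in one call. summarise_folds is a hypothetical helper name; after the loop it would be called as summarise_folds(results). The standard error treats the folds as independent, which they are not quite (the training sets overlap), so read it as a rough indication only.

```r
# Hypothetical helper: summarise a vector of per-fold accuracies
summarise_folds <- function(acc) {
  c(mean = mean(acc),
    sd = sd(acc),                      # spread of accuracy across folds
    se = sd(acc) / sqrt(length(acc)))  # rough standard error of the mean
}
```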

This code splits the iris data into 10 folds. It then trains a support vector machine (SVM) model 10 times, each time using 9 folds as the training data and evaluating the model on the remaining fold as the test data. Finally, it averages the accuracy scores obtained on the folds to get an estimate of the model's performance.


