How to Implement the train() Function in R
- Syntax of the train Function
- Install and Load Packages to Use the train Function
- Implement the train() Function in R for Basic Model Training for Regression
- Implement the train() Function in R for Classification With Random Forests
- Implement the train() Function in R for Model Tuning With Cross-Validation
- Conclusion
Machine learning has become an indispensable tool in data analysis, enabling data scientists to uncover patterns, make predictions, and derive insights from complex datasets. Among the machine learning libraries in R, the caret package stands out with its comprehensive set of tools for model training, tuning, and evaluation.
At the heart of this package lies the train function, a versatile tool that simplifies the process of implementing and using various machine learning algorithms. In this article, we'll delve into the intricacies of the train function, exploring its capabilities and providing examples to illustrate its usage.
Syntax of the train Function
The train function is a central component of the caret package, whose name is short for Classification And REgression Training. The primary objective of train is to streamline the process of building and evaluating predictive models.
It is particularly useful when dealing with a variety of machine learning algorithms, as it provides a consistent interface for model training across different methods.
The basic syntax is as follows:
train(formula, data, method, ...)
Here’s a breakdown of the main parameters:
- formula: A symbolic description of the model, specifying the relationship between the predictors and the response variable.
- data: The dataset containing the variables specified in the formula.
- method: The modeling method or algorithm to be used for training the model. This can be any algorithm supported by the caret package, such as lm for linear regression, rf for random forests, etc.
- ...: Additional arguments that depend on the chosen modeling method. These allow you to customize the model training process, such as specifying hyperparameters or control parameters.
It's important to note that the train function is quite flexible and can be customized based on the specific needs of your modeling task. The exact arguments and their interpretation may vary depending on the modeling method chosen.
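As a brief sketch of this flexibility (assuming the randomForest package is installed alongside caret), method-specific arguments such as ntree can be forwarded through the ... argument, while tuneLength controls how many candidate values of the method's tuning parameter caret evaluates:

```r
library(caret)
data(iris)

# ntree is forwarded to the underlying randomForest() call via "...";
# tuneLength = 3 asks caret to try 3 candidate values of mtry.
rf_fit <- train(Species ~ ., data = iris, method = "rf",
                ntree = 100, tuneLength = 3)

# The tuning value chosen by resampling:
print(rf_fit$bestTune)
```

The exact set of pass-through arguments depends on the underlying modeling function for the chosen method.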
Install and Load Packages to Use the train Function
Before diving into the train function, ensure you have the required package installed. Open your R console and execute the following command:
install.packages("caret", dependencies = c("Depends", "Suggests"))
Once installed, load the package into your workspace:
library(caret)
Implement the train() Function in R for Basic Model Training for Regression
Let's start with a basic example of using the train function for linear regression.
To start, we load the BostonHousing dataset into our R environment using the data function. Note that this dataset ships with the mlbench package rather than base R, so mlbench must be loaded first. The dataset contains housing-related information for various neighborhoods, making it suitable for regression analysis.
library(mlbench)
data(BostonHousing)
With the dataset in place, we define our regression formula using the ~ operator. Here, we want to predict the median value of owner-occupied homes (medv) based on all other available features in the dataset. This is expressed as medv ~ ., where the dot (.) indicates the use of all other variables for prediction.
formula <- medv ~ .
Now, we are ready to use the train function to train our linear regression model. We pass in the formula, the dataset (BostonHousing), and specify the modeling method as lm (linear regression).
lm_model <- train(formula, data = BostonHousing, method = "lm")
The result is stored in the lm_model object, and we can print it to the console for a comprehensive overview of the trained model.
print(lm_model)
Complete Code Example:
library(caret)
library(mlbench)
data(BostonHousing)
formula <- medv ~ .
lm_model <- train(formula, data = BostonHousing, method = "lm")
print(lm_model)
When you run this code, you should see output similar to the following:
Linear Regression
506 samples
13 predictor
No pre-processing
Resampling: Bootstrapped (25 reps)
Summary of sample sizes: 506, 506, 506, 506, 506, 506, ...
Resampling results:
RMSE Rsquared MAE
4.843917 0.7243338 3.45056
Tuning parameter 'intercept' was held constant at a value of TRUE
This output provides valuable information about the trained linear regression model, including the Root Mean Squared Error (RMSE), R-squared, and Mean Absolute Error (MAE). These metrics offer insights into how well the model is performing on the training data.
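Once trained, the model object can also be used for prediction via caret's predict method. A minimal sketch, re-creating the model above (for illustration we predict back on the training data; in practice you would hold out a test set before calling train):

```r
library(caret)
library(mlbench)  # provides the BostonHousing dataset
data(BostonHousing)

lm_model <- train(medv ~ ., data = BostonHousing, method = "lm")

# Generate predicted median home values for new (here: training) data
preds <- predict(lm_model, newdata = BostonHousing)
head(preds)
```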
Implement the train() Function in R for Classification With Random Forests
Now, let's explore classification using the famous iris dataset and a random forest model.
Our first step is to load the iris dataset into our R environment. This dataset contains measurements of sepal and petal dimensions for three different species of iris flowers.
data(iris)
Next, we define our classification formula using the ~ operator. Here, our goal is to predict the species of iris flowers (Species) based on their sepal and petal dimensions. We express this as Species ~ ., indicating the use of all other available variables for prediction.
formula <- Species ~ .
Now, we employ the train function to train our random forest classification model. We provide the formula, the dataset (iris), and specify the modeling method as rf (random forest). Note that the rf method requires the randomForest package to be installed.
rf_model <- train(formula, data = iris, method = "rf")
Once the model is trained, we can inspect the details by printing the rf_model object. This provides valuable information about the random forest model, including the resampling performance for each candidate value of the mtry tuning parameter and the value selected for the final model.
print(rf_model)
Complete Code Example:
library(caret)
data(iris)
formula <- Species ~ .
rf_model <- train(formula, data = iris, method = "rf")
print(rf_model)
You should observe an output similar to the following:
Random Forest
150 samples
4 predictor
3 classes: 'setosa', 'versicolor', 'virginica'
No pre-processing
Resampling: Bootstrapped (25 reps)
Summary of sample sizes: 150, 150, 150, 150, 150, 150, ...
Resampling results across tuning parameters:
mtry Accuracy Kappa
2 0.9508815 0.9254858
3 0.9513779 0.9262463
4 0.9505901 0.9251102
Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 3.
This output offers valuable information about the trained Random Forest classification model. It includes accuracy, kappa, and the optimal tuning parameter values determined during the training process.
Understanding these metrics is essential for assessing the model’s performance on the classification task.
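Beyond the resampling summary, caret's varImp function can extract variable importance scores from the fitted random forest, showing which predictors drive the classification. A short sketch, re-creating the model above:

```r
library(caret)
data(iris)

rf_model <- train(Species ~ ., data = iris, method = "rf")

# Rank predictors by their importance in the fitted forest
importance <- varImp(rf_model)
print(importance)

# Optional: visualize the importance scores
plot(importance)
```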
Implement the train() Function in R for Model Tuning With Cross-Validation
One of the strengths of the train function is its ability to perform hyperparameter tuning with cross-validation. Let's demonstrate this using the mtcars dataset and a support vector machine (SVM).
Firstly, we load the mtcars dataset into our R environment. This dataset contains information about various car models, making it suitable for regression analysis.
data(mtcars)
We define our regression formula, specifying that we want to predict miles per gallon (mpg) based on all other available variables in the dataset.
formula <- mpg ~ .
To perform model tuning with cross-validation, we need to set up control parameters using the trainControl function. In this example, we choose k-fold cross-validation with 5 folds.
ctrl <- trainControl(method = "cv", number = 5)
Now, we use the train function to train our SVM model with a radial kernel. We provide the regression formula, the dataset (mtcars), the modeling method svmRadial (which relies on the kernlab package), and the previously defined control parameters.
svm_model <- train(formula, data = mtcars, method = "svmRadial", trControl = ctrl)
Upon completion, we print the trained SVM model to the console. This output includes information about the tuning parameters, such as the cost (C) and kernel width (sigma) values, as well as performance metrics averaged across the cross-validation folds.
print(svm_model)
Complete Code Example:
library(caret)
data(mtcars)
formula <- mpg ~ .
ctrl <- trainControl(method = "cv", number = 5)
svm_model <- train(formula, data = mtcars, method = "svmRadial", trControl = ctrl)
print(svm_model)
When you run this code, you should observe output similar to the following:
Support Vector Machines with Radial Basis Function Kernel
32 samples
10 predictors
No pre-processing
Resampling: Cross-Validated (5 fold)
Summary of sample sizes: 25, 25, 24, 28, 26
Resampling results across tuning parameters:
C RMSE Rsquared MAE
0.25 4.053680 0.6800637 3.107744
0.50 3.502690 0.7234818 2.817316
1.00 3.233459 0.7529081 2.742852
Tuning parameter 'sigma' was held constant at a value of 0.1737787
RMSE was used to select the optimal model using the smallest value.
The final values used for the model were sigma = 0.1737787 and C = 1.
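After tuning, the fitted object exposes the selected hyperparameters and per-fold results, and can be used directly for prediction. A minimal sketch, re-creating the model above:

```r
library(caret)
data(mtcars)

ctrl <- trainControl(method = "cv", number = 5)
svm_model <- train(mpg ~ ., data = mtcars,
                   method = "svmRadial", trControl = ctrl)

# The hyperparameter combination chosen by cross-validation
print(svm_model$bestTune)

# Per-fold resampling results for the candidate models
head(svm_model$resample)

# Predictions from the final, tuned model
preds <- predict(svm_model, newdata = mtcars)
head(preds)
```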
Conclusion
The train function in R, part of the caret package, empowers data scientists and analysts in machine learning and predictive modeling. It streamlines the model training process, offering a unified interface for various algorithms and facilitating seamless customization.
Through the examples provided, we've seen how the train function can be employed for both regression and classification tasks, accommodating different modeling methods such as linear regression, random forests, and support vector machines.
The train function excels not only in model training but also in hyperparameter tuning, cross-validation, and performance evaluation. Its integration with the trainControl function allows users to fine-tune the training process, making informed decisions about model complexity and generalization.
The ability to specify resampling methods, control parameters, and pre-processing steps adds a layer of flexibility, making it adaptable to diverse datasets and modeling scenarios.
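As one sketch of combining these options (the specific sigma and C values here are illustrative, not recommendations), pre-processing and an explicit tuning grid can be supplied alongside the resampling control:

```r
library(caret)
data(mtcars)

ctrl <- trainControl(method = "cv", number = 5)

# preProcess centers and scales the predictors within each resampling
# fold; tuneGrid pins down the exact candidate hyperparameter values.
svm_scaled <- train(mpg ~ ., data = mtcars, method = "svmRadial",
                    trControl = ctrl,
                    preProcess = c("center", "scale"),
                    tuneGrid = expand.grid(sigma = 0.1,
                                           C = c(0.5, 1, 2)))
print(svm_scaled)
```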
Whether you are a beginner or an experienced practitioner, the train function proves invaluable in simplifying complex workflows. It facilitates model comparison, ensemble learning, and the exploration of variable importance. The diagnostic outputs, such as performance metrics and tuning parameter summaries, aid in model interpretation and selection.
Sheeraz is a Doctorate fellow in Computer Science at Northwestern Polytechnical University, Xian, China. He has 7 years of Software Development experience in AI, Web, Database, and Desktop technologies. He writes tutorials in Java, PHP, Python, GoLang, R, etc., to help beginners learn the field of Computer Science.