Usar validación cruzada para un modelo de clasificación KNN en R

Jesse John 21 junio 2023
  1. Diferentes enfoques de validación cruzada
  2. Validación cruzada repetida de K-Fold para un modelo de clasificación de K-vecino más cercano
Usar validación cruzada para un modelo de clasificación KNN en R

La validación cruzada nos permite evaluar el rendimiento de un modelo en datos nuevos aunque solo tengamos el conjunto de datos de entrenamiento. Es una técnica general que puede aplicarse a modelos de regresión y clasificación.

Este artículo discutirá cómo realizar una validación cruzada repetida k-fold para un modelo de clasificación K-Nearest Neighbor (KNN). Emplearemos el paquete de intercalación para este propósito.

La K en KNN se refiere al número de vecinos de observación. Por otro lado, k en k-fold es el número de subconjuntos de los datos de entrenamiento.

Diferentes enfoques de validación cruzada

Hay diferentes enfoques para la validación cruzada.

La versión más básica utiliza un subconjunto de datos de entrenamiento para validar el modelo, denominado enfoque de conjunto de validación. El modelo se ajusta solo una vez y luego se prueba en el subconjunto.

El otro implica ajustar tantos modelos como observaciones y tomar la tasa de error promedio. Cada modelo está equipado con una observación omitida; el modelo luego se prueba en esa única observación.

Esto se denomina Validación cruzada Leave One Out (LOOCV).

El enfoque más útil implica:

  1. Dividir el conjunto de datos de entrenamiento en k pliegues (grupos),
  2. Ajuste del modelo k veces,
  3. Dejando fuera un pliegue, y
  4. Probar el modelo en eso.

Esto se llama la validación cruzada de k-fold. Normalmente, un valor k de 5 o 10 da buenos resultados.

Una mejora de la validación cruzada de k-pliegue implica ajustar el modelo de validación cruzada de k-pliegue varias veces con diferentes divisiones de los pliegues. Esto se llama la validación cruzada repetida de k-fold, que usaremos.

Validación cruzada repetida de K-Fold para un modelo de clasificación de K-vecino más cercano

Ahora veremos cómo ajustar un modelo de clasificación K-Nearest Neighbor (KNN) usando validación cruzada repetida de k-fold. Usaremos el paquete caret.

El paquete caret es muy versátil y se puede utilizar para construir varios tipos de modelos. Consulte su documentación en CRAN para obtener más detalles.

Por lo general, la validación cruzada k-fold solo nos dice qué tan preciso se espera que sea nuestro modelo en datos nuevos.

Por ejemplo, supongamos que ajustamos un modelo KNN K = 5 usando una validación cruzada repetida de k-pliegues usando 10 k-pliegues, repetidos 3 veces. El modelo se ajustará 10 veces para cada una de las 3 divisiones de datos diferentes, y obtendremos métricas de rendimiento para un solo modelo: el que tiene vecinos K = 5.

Además de lo anterior, el paquete caret nos permite ajustar modelos KNN para diferentes valores de K. Luego, la función informa el valor de K que da como resultado el mejor modelo y crea ese modelo para nosotros.

La función createDataPartition() crea una división aleatoria estratificada de un vector de factores. Usaremos esto para separar nuestros datos en subconjuntos de entrenamiento y prueba para verificar la precisión del modelo.

La función train() es la función main para crear un modelo, donde:

  1. x es el marco de datos con los predictores.
  2. y es el marco de datos o vector de resultados.
  3. El argumento método toma el tipo de modelo que queremos construir. Especificaremos knn.
  4. Para preproceso, especificaremos escala y centro.
  5. El argumento trControl nos permite especificar los detalles del procedimiento de validación cruzada.
  6. El argumento tuneGrid ayudará a crear y comparar múltiples modelos. Se necesita una trama de datos con el nombre del parámetro a sintonizar.

Debido a que estamos construyendo un modelo KNN, daremos k en minúsculas como parámetro de ajuste para tuneGrid. Proporcionaremos un vector de valores K del 1 al 12, para el cual queremos que la función cree y pruebe modelos.

Los detalles de la validación cruzada se pasan al argumento trControl usando la función trainControl().

  1. Para el argumento método, especificaremos repeatedcv porque queremos una validación cruzada repetida.
  2. Cuando el método es cv o repeatedcv, el argumento número especifica los pliegues, k. Usaremos 10.
  3. El argumento repeticiones especifica cuántas veces se debe repetir la división k-fold.

Código de ejemplo:

# Create the data vectors for the demonstration.
# We will create two numeric vectors as predictors.
# Each vector will have two distinct groups to suit our model.
# We will create a factor with two levels.
# The factor levels correspond to the groups in the predictor vectors.

set.seed(564)
vX1a = round(rnorm(100, 2,2))+4
set.seed(574)
vX2a = round(rnorm(100, 15,4))

set.seed(584)
vX1b = round(rnorm(100, 10,3))+5
set.seed(594)
vX2b = round(rnorm(100, 5,4))

vYa = rep("Blue", 100)
vYb = rep("Red", 100)

vX1 = c(vX1a, vX1b)
vX2 = c(vX2a, vX2b)
vY = c(vYa, vYb)

# Dummy column for ordering rows.
set.seed(528)
R = sample(1:200,200)

# Temporary data frame.
temp_df = data.frame(X1 = vX1, X2 = vX2, Y = as.factor(vY), R)

# Packages that we will use.
library(ggplot2)
library(dplyr)

# See the sample data.
temp_df %>% ggplot(aes(x=X1, y = X2, colour = Y)) + geom_point()

# Re-order the rows, just to see that the KNN model works with the rows jumbled up.
# Final data frame.
# Notice that the outputs are a factor vector.
fin_df = temp_df %>% arrange(R) %>% select(X1, X2, Y)
head(fin_df)
str(fin_df)

# Install the caret package if it is not already installed.
# To install, uncomment the next line and run it.
# install.packages("caret")

# Load the caret package.
library(caret)

# Split the data frame into a training set and test set.
# Create a list of row numbers in the training set.
# This function creates a stratified random sample of all the outcome classes.
set.seed(365)
training_row_index = createDataPartition(fin_df[,3], p=0.75, list=FALSE)

# Create training sets of the predictors and the corresponding outcomes.
trg_data = fin_df[training_row_index,1:2]
trg_class = fin_df[training_row_index,3]

# Create the test set of predictors and the outcomes that we will later use.
tst_data = fin_df[-training_row_index,1:2]
tst_class = fin_df[-training_row_index,3]

# Let us check if the sample is stratified:
table(tst_class)
# Obviously, the training sample will complement these numbers of the totals.

# We will build a K-Nearest neighbors model using repeated k-fold cross-validation.
# The arguments are described in the article.
mod_knn = train(x = trg_data,
                y = trg_class,
                method = "knn",
                preProcess = c("center", "scale"),
                tuneGrid = data.frame(k = c(1:12)),
                trControl = trainControl(method = "repeatedcv",
                                         number = 10,
                                         repeats = 3)
                )

# View the fitted model.
mod_knn

Producción :

> head(fin_df)
  X1 X2    Y
1 15  6  Red
2 15  2  Red
3 14  3  Red
4  4 22 Blue
5 20 -3  Red
6  4 22 Blue

> str(fin_df)
'data.frame':	200 obs. of  3 variables:
 $ X1: num  15 15 14 4 20 4 15 2 20 13 ...
 $ X2: num  6 2 3 22 -3 22 7 16 9 -6 ...
 $ Y : Factor w/ 2 levels "Blue","Red": 2 2 2 1 2 1 2 1 2 2 ...

> # Let us check if the sample is really stratified:
> table(tst_class)
tst_class
Blue  Red
  25   25

> # View the fitted model.
> mod_knn
k-Nearest Neighbors

150 samples
  2 predictor
  2 classes: 'Blue', 'Red'

Pre-processing: centered (2), scaled (2)
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 135, 136, 135, 135, 135, 134, ...
Resampling results across tuning parameters:

  k   Accuracy   Kappa
   1  0.9710317  0.9420024
   2  0.9736905  0.9473602
   3  0.9753373  0.9505141
   4  0.9842460  0.9683719
   5  0.9864683  0.9728764
   6  0.9843849  0.9687098
   7  0.9843849  0.9687098
   8  0.9800794  0.9600386
   9  0.9800794  0.9600386
  10  0.9800794  0.9600386
  11  0.9800794  0.9600386
  12  0.9800794  0.9600386

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was k = 5.

Encontramos que el mejor modelo usa K = 5.

Usemos ahora la función base R predict() y una matriz de confusión creada usando la función table() para verificar la precisión del modelo en los datos de prueba.

Código de ejemplo:

# Use model to predict classes for the test set.
pred_cls = predict(mod_knn, tst_data)

# Check the accuracy of the predictions by computing the confusion matrix.
table(Actl = tst_class, Pred = pred_cls)

Producción :

> table(Actl = tst_class, Pred = pred_cls)
      Pred
Actl   Blue Red
  Blue   25   0
  Red     0  25

Encontramos que el modelo predijo la clase de datos de prueba con total precisión. Esto fue posible porque los datos estaban bien separados en el marco de datos de muestra.

En la práctica, la precisión será menor. Sin embargo, para cada modelo, el procedimiento repetido de validación cruzada de k-fold nos da una buena idea de la precisión que podemos esperar en nuevos datos similares a los datos de entrenamiento.

Autor: Jesse John
Jesse John avatar Jesse John avatar

Jesse is passionate about data analysis and visualization. He uses the R statistical programming language for all aspects of his work.