How to Implement KNN in R
Key Insights
- KNN is a distance-based algorithm that classifies data points based on their nearest neighbors, making it ideal for pattern recognition tasks but requiring proper feature scaling
- The choice of K value dramatically affects model performance—too small leads to overfitting, too large causes underfitting, and cross-validation helps find the optimal balance
- Feature normalization is non-negotiable for KNN since the algorithm relies on distance calculations, and variables with larger scales will dominate the classification
Introduction to K-Nearest Neighbors
K-Nearest Neighbors (KNN) is one of the simplest yet most effective supervised learning algorithms. Unlike other machine learning methods that build explicit models during training, KNN is a lazy learner—it stores all training data and makes decisions only when predicting new instances.
The algorithm works by calculating the distance between a new data point and all existing points in the training set. It then identifies the K closest neighbors and assigns the new point to the most common class among those neighbors. For regression tasks, it averages the values instead.
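The distance-and-vote procedure described above can be sketched in a few lines of base R. This is a minimal illustration, not the implementation used later in the article; the new flower's measurements are hypothetical.

```r
# Minimal KNN by hand: classify one new point by majority vote
# among its k nearest training points (Euclidean distance).
data(iris)
train_x <- as.matrix(iris[, 1:4])
train_y <- iris$Species

new_point <- c(5.0, 3.4, 1.5, 0.2)  # hypothetical flower measurements
k <- 3

# Euclidean distance from the new point to every training observation
dists <- sqrt(rowSums(sweep(train_x, 2, new_point)^2))

# Majority vote among the k closest neighbors
neighbors <- train_y[order(dists)[1:k]]
predicted <- names(which.max(table(neighbors)))
predicted  # "setosa" for these petal/sepal values
```

The `class` package's `knn()` function, used throughout this article, performs exactly this computation vectorized over many test points.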
KNN excels in pattern recognition, recommendation systems, and classification problems where decision boundaries are irregular. However, it struggles with high-dimensional data (curse of dimensionality) and requires significant memory for large datasets since it stores all training examples.
Setting Up Your R Environment
Before implementing KNN, install the necessary packages. The class package provides the core KNN functionality, while caret offers advanced tools for model training and evaluation.
# Install required packages (run once)
install.packages(c("class", "caret", "ggplot2"))
# Load libraries
library(class)
library(caret)
library(ggplot2)
# Load the iris dataset
data(iris)
# Examine the structure
str(iris)
head(iris)
The iris dataset contains 150 observations of iris flowers with four features (sepal length, sepal width, petal length, petal width) and three species classifications. This makes it perfect for demonstrating KNN classification.
Data Preparation and Exploration
Proper data preparation is critical for KNN success. Since the algorithm relies on distance calculations, features must be on comparable scales. A feature ranging from 0-1000 will dominate one ranging from 0-1, regardless of actual importance.
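A quick numeric illustration makes the point concrete. The two features here (height and income) are hypothetical, chosen only to show how the squared-distance contributions split between a large-scale and a small-scale variable.

```r
# Two hypothetical people: a 20 cm height difference versus a
# 100-unit income difference.
a <- c(height_cm = 180, income = 50000)
b <- c(height_cm = 160, income = 50100)

d_sq <- (a - b)^2
share <- d_sq / sum(d_sq)
round(share, 3)
# income accounts for ~96% of the squared distance, even though the
# 20 cm height gap is arguably the more meaningful difference
```

After standardization, each feature contributes in proportion to how unusual the difference is relative to that feature's spread, not its raw units.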
# Check for missing values
sum(is.na(iris))
# Separate features and labels
iris_features <- iris[, 1:4]
iris_labels <- iris[, 5]
# Normalize features using scale()
iris_normalized <- as.data.frame(scale(iris_features))
# Add the labels back
iris_normalized$Species <- iris_labels
# Create train-test split (70-30)
set.seed(123)
train_index <- createDataPartition(iris_normalized$Species,
p = 0.7,
list = FALSE)
train_data <- iris_normalized[train_index, ]
test_data <- iris_normalized[-train_index, ]
# Separate features and labels for modeling
train_features <- train_data[, 1:4]
train_labels <- train_data[, 5]
test_features <- test_data[, 1:4]
test_labels <- test_data[, 5]
The scale() function standardizes features to have mean 0 and standard deviation 1. Note that, strictly speaking, the scaling parameters should be computed from the training set only and then applied to the test set, to avoid leaking test information into preprocessing; for a small, well-behaved dataset like iris the difference is negligible. The createDataPartition() function from caret ensures balanced class distribution in both training and test sets, which is crucial for classification problems.
Building the KNN Model
The knn() function from the class package requires training data, test data, training labels, and the K value. Let’s start with K=3, a common default choice.
# Build KNN model with K=3
knn_pred_k3 <- knn(train = train_features,
test = test_features,
cl = train_labels,
k = 3)
# View predictions
table(knn_pred_k3)
# Compare predictions to actual values
comparison <- data.frame(Predicted = knn_pred_k3,
Actual = test_labels)
head(comparison, 10)
Unlike many machine learning algorithms, KNN doesn’t have a separate training phase. The knn() function performs classification immediately, making it fast for small datasets but potentially slow for large ones.
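Because prediction happens on demand, classifying a single new observation is just another call to knn(). One subtlety worth a sketch: the new point must be normalized with the training data's statistics, which scale() conveniently stores as attributes on its result. The flower measurements below are hypothetical.

```r
library(class)
data(iris)

scaled  <- scale(iris[, 1:4])
centers <- attr(scaled, "scaled:center")
sds     <- attr(scaled, "scaled:scale")

# Hypothetical new flower (sepal length/width, petal length/width in cm)
new_flower <- c(6.1, 2.9, 4.6, 1.4)
new_scaled <- (new_flower - centers) / sds

pred <- knn(train = scaled, test = rbind(new_scaled),
            cl = iris$Species, k = 5)
pred  # versicolor for these measurements
```

Forgetting this step, and scaling the new point against itself or against different statistics, silently degrades predictions.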
Model Evaluation
Accuracy alone doesn’t tell the full story. Use confusion matrices to understand which classes the model confuses and calculate precision and recall for each class.
# Create confusion matrix
conf_matrix <- confusionMatrix(knn_pred_k3, test_labels)
print(conf_matrix)
# Extract key metrics
accuracy <- conf_matrix$overall['Accuracy']
precision <- conf_matrix$byClass[, 'Precision']
recall <- conf_matrix$byClass[, 'Recall']
f1_score <- conf_matrix$byClass[, 'F1']
# Display results
cat("Overall Accuracy:", round(accuracy, 4), "\n")
print(data.frame(Class = levels(test_labels),
Precision = round(precision, 4),
Recall = round(recall, 4),
F1 = round(f1_score, 4)))
The confusion matrix reveals patterns in misclassification. For iris data, you’ll typically see high accuracy (95%+) with occasional confusion between versicolor and virginica species, which have overlapping features.
Optimizing K Value
Choosing the right K value is crucial. Small K values (1-3) create complex decision boundaries that may overfit. Large K values oversimplify and may underfit. Test multiple K values to find the sweet spot.
# Test K values from 1 to 20
k_values <- 1:20
accuracy_values <- numeric(length(k_values))
for(i in seq_along(k_values)) {
knn_pred <- knn(train = train_features,
test = test_features,
cl = train_labels,
k = k_values[i])
accuracy_values[i] <- sum(knn_pred == test_labels) / length(test_labels)
}
# Create results dataframe
k_results <- data.frame(K = k_values, Accuracy = accuracy_values)
# Find optimal K
optimal_k <- k_results$K[which.max(k_results$Accuracy)]
cat("Optimal K value:", optimal_k, "\n")
cat("Maximum Accuracy:", round(max(k_results$Accuracy), 4), "\n")
# Plot accuracy vs K
ggplot(k_results, aes(x = K, y = Accuracy)) +
geom_line(color = "blue", linewidth = 1) +
geom_point(color = "red", size = 2) +
geom_vline(xintercept = optimal_k, linetype = "dashed", color = "green") +
labs(title = "KNN Accuracy vs K Value",
x = "Number of Neighbors (K)",
y = "Accuracy") +
theme_minimal() +
scale_x_continuous(breaks = k_values)
This visualization helps identify the elbow point where accuracy stabilizes. For iris data, you’ll typically see optimal performance around K=5-7. Odd K values are preferred for binary classification because they prevent voting ties; with three or more classes ties can still occur, and knn() breaks them at random.
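Tuning K on a single held-out test set, as the loop above does, can be noisy. A sketch of a more robust alternative uses caret's train() to pick K by 10-fold cross-validation, with centering and scaling handled per fold via preProcess (this refits the model many times, so it is slower):

```r
library(caret)
set.seed(123)
data(iris)

cv_control <- trainControl(method = "cv", number = 10)
knn_cv <- train(Species ~ ., data = iris,
                method = "knn",
                trControl = cv_control,
                preProcess = c("center", "scale"),
                tuneGrid = data.frame(k = 1:20))

knn_cv$bestTune   # K with the highest cross-validated accuracy
plot(knn_cv)      # accuracy across the K grid
```

Because each fold is preprocessed inside the resampling loop, this also avoids the train/test leakage that scaling the full dataset up front can introduce.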
Practical Example: Complete Implementation
Here’s a complete workflow implementing KNN for a customer classification scenario. This example demonstrates all steps in a reusable, modular format.
# Complete KNN Implementation Pipeline
# 1. Load and prepare data
set.seed(42)
data(iris)
# 2. Preprocessing function
preprocess_data <- function(data, target_col) {
features <- data[, -target_col]
labels <- data[, target_col]
# Normalize features
normalized_features <- as.data.frame(scale(features))
normalized_features$target <- labels
return(normalized_features)
}
# 3. Train-test split function
split_data <- function(data, train_ratio = 0.7) {
set.seed(123)
train_idx <- createDataPartition(data$target, p = train_ratio, list = FALSE)
train <- data[train_idx, ]
test <- data[-train_idx, ]
return(list(train = train, test = test))
}
# 4. KNN training and evaluation function
evaluate_knn <- function(train_data, test_data, k_value) {
train_features <- train_data[, -ncol(train_data)]
train_labels <- train_data$target
test_features <- test_data[, -ncol(test_data)]
test_labels <- test_data$target
# Make predictions
predictions <- knn(train = train_features,
test = test_features,
cl = train_labels,
k = k_value)
# Calculate accuracy
accuracy <- sum(predictions == test_labels) / length(test_labels)
# Confusion matrix
conf_mat <- confusionMatrix(predictions, test_labels)
return(list(predictions = predictions,
accuracy = accuracy,
confusion_matrix = conf_mat))
}
# 5. Execute pipeline
processed_data <- preprocess_data(iris, 5)
split <- split_data(processed_data, 0.7)
# Find optimal K
k_range <- 1:15
accuracies <- sapply(k_range, function(k) {
result <- evaluate_knn(split$train, split$test, k)
return(result$accuracy)
})
optimal_k <- k_range[which.max(accuracies)]
# Final model with optimal K
final_model <- evaluate_knn(split$train, split$test, optimal_k)
cat("\n=== Final KNN Model Results ===\n")
cat("Optimal K:", optimal_k, "\n")
cat("Test Accuracy:", round(final_model$accuracy, 4), "\n\n")
print(final_model$confusion_matrix)
This implementation encapsulates each step in reusable functions, making it easy to apply KNN to different datasets. The modular approach lets you swap datasets, adjust parameters, and maintain clean, testable code.
KNN’s simplicity makes it an excellent baseline algorithm. While it may not always achieve the highest accuracy compared to ensemble methods or deep learning, its interpretability and ease of implementation make it invaluable for quick prototyping and establishing performance benchmarks. Remember to always normalize your features, validate your K value through cross-validation, and consider computational costs for large datasets.