How to Implement K-Means Clustering in R
Key Insights
- K-means clustering requires scaled data and careful selection of k using the elbow method or silhouette analysis—choosing the wrong k leads to meaningless segments
- Always run k-means with multiple random starts (nstart=25 or higher) since the algorithm is sensitive to initial centroid placement and can converge to local optima
- Cluster validation metrics like within-cluster sum of squares and silhouette scores are essential for evaluating quality, but domain knowledge should guide final interpretation
Introduction to K-Means Clustering
K-means clustering partitions data into k distinct groups by iteratively assigning points to the nearest centroid and recalculating centroids based on cluster membership. The algorithm minimizes within-cluster variance, making it ideal for discovering natural groupings in unlabeled data.
The process works through these steps: initialize k random centroids, assign each point to its nearest centroid, recalculate centroids as the mean of assigned points, and repeat until convergence. This simplicity makes k-means computationally efficient for large datasets.
Common applications include customer segmentation for targeted marketing, image compression by reducing color palettes, anomaly detection in network traffic, and document clustering for topic modeling. The algorithm works best with spherical clusters of similar sizes and requires numeric features.
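The iterative procedure described above can be sketched directly in R. This is a teaching sketch, not a replacement for the built-in kmeans(): the function name simple_kmeans is illustrative, and edge cases such as empty clusters are not handled.

```r
# Minimal k-means loop for illustration only -- use kmeans() in practice.
# Note: empty clusters are not handled here.
simple_kmeans <- function(x, k, max_iter = 100) {
  x <- as.matrix(x)
  # Step 1: initialize centroids by sampling k distinct data points
  centroids <- x[sample(nrow(x), k), , drop = FALSE]
  assignment <- integer(nrow(x))
  for (iter in seq_len(max_iter)) {
    # Step 2: assign each point to its nearest centroid (Euclidean distance)
    d <- as.matrix(dist(rbind(centroids, x)))[-(1:k), 1:k, drop = FALSE]
    new_assignment <- max.col(-d, ties.method = "first")
    if (all(new_assignment == assignment)) break  # converged
    assignment <- new_assignment
    # Step 3: recompute each centroid as the mean of its assigned points
    for (j in seq_len(k)) {
      centroids[j, ] <- colMeans(x[assignment == j, , drop = FALSE])
    }
  }
  list(cluster = assignment, centers = centroids)
}
```

The built-in kmeans() with a high nstart is faster and more robust; this sketch only makes the three steps concrete.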
Preparing Your Data
Data preparation determines clustering success. K-means uses Euclidean distance, so features with larger scales dominate the distance calculation. Always scale your data.
# Load required libraries
library(tidyverse)
library(cluster)
library(factoextra)
# Load and explore data
data(iris)
df <- iris %>% select(-Species) # Remove labels for unsupervised learning
# Check data structure and missing values
glimpse(df)
sum(is.na(df))
# Scale features to mean=0, sd=1
df_scaled <- scale(df)
head(df_scaled)
# Verify scaling
colMeans(df_scaled) # Should be ~0
apply(df_scaled, 2, sd) # Should be 1
Scaling transforms features to comparable ranges. Without it, a feature measured in thousands would dominate one measured in decimals, producing biased clusters. The scale() function standardizes each column to have zero mean and unit variance.
Check for missing values before scaling. K-means cannot handle NAs, so either impute them or remove affected rows. For categorical variables, convert them to numeric dummy variables or use alternative clustering algorithms like k-modes.
# Handling missing values example
df_clean <- df %>%
  drop_na() %>%  # Remove rows with any NA
  as.data.frame()
# For imputation alternative:
# df_imputed <- df %>%
#   mutate(across(everything(), ~ifelse(is.na(.), median(., na.rm=TRUE), .)))
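The dummy-variable conversion mentioned above can be done with base R's model.matrix(). Species is used here purely to illustrate the mechanics; elsewhere in this tutorial it is held out as a label.

```r
# One indicator (0/1) column per factor level; the -1 drops the intercept
dummies <- model.matrix(~ Species - 1, data = iris)
head(dummies)
# Combine with the numeric features and scale everything together
df_with_dummies <- scale(cbind(iris[, 1:4], dummies))
```

Keep in mind that 0/1 dummies under Euclidean distance are an imperfect fit; k-modes or Gower distance handle categorical features more gracefully.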
Determining the Optimal Number of Clusters
Choosing k requires balancing model complexity against cluster cohesion. The elbow method plots total within-cluster sum of squares (WSS) against k values. Look for the “elbow” where WSS decrease slows significantly.
# Calculate WSS for k=1 to k=10
wss <- sapply(1:10, function(k) {
  kmeans(df_scaled, centers=k, nstart=25)$tot.withinss
})
# Create elbow plot
elbow_data <- data.frame(k=1:10, wss=wss)
ggplot(elbow_data, aes(x=k, y=wss)) +
  geom_line(linewidth=1) +
  geom_point(size=3) +
  scale_x_continuous(breaks=1:10) +
  labs(title="Elbow Method for Optimal k",
       x="Number of Clusters (k)",
       y="Total Within-Cluster Sum of Squares") +
  theme_minimal()
The elbow is often ambiguous. Silhouette analysis provides quantitative validation by measuring how similar each point is to its own cluster compared with other clusters. Scores range from -1 to 1, with higher values indicating better-defined clusters.
# Calculate average silhouette width for k=2 to k=10
silhouette_scores <- sapply(2:10, function(k) {
  km <- kmeans(df_scaled, centers=k, nstart=25)
  ss <- silhouette(km$cluster, dist(df_scaled))
  mean(ss[, 3])
})
# Plot silhouette scores
sil_data <- data.frame(k=2:10, silhouette=silhouette_scores)
ggplot(sil_data, aes(x=k, y=silhouette)) +
  geom_line(linewidth=1) +
  geom_point(size=3) +
  scale_x_continuous(breaks=2:10) +
  labs(title="Average Silhouette Width by k",
       x="Number of Clusters (k)",
       y="Average Silhouette Width") +
  theme_minimal()
Choose k where silhouette width is maximized. For iris data, k=2 or k=3 typically performs best. Combine both methods—if the elbow suggests k=3 and silhouette confirms it, you have strong evidence.
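Since factoextra is already loaded, its fviz_nbclust() helper wraps both diagnostics in a single call and makes a convenient cross-check of the hand-rolled plots (df_scaled is recreated here so the snippet runs standalone):

```r
library(factoextra)
df_scaled <- scale(iris[, 1:4])  # recreated so the snippet stands alone
fviz_nbclust(df_scaled, kmeans, method = "wss")         # elbow view
fviz_nbclust(df_scaled, kmeans, method = "silhouette")  # silhouette view
```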
Implementing K-Means with Base R
R’s kmeans() function provides an efficient implementation with a few critical parameters. The nstart parameter runs the algorithm multiple times with different random initializations and returns the best result.
# Run k-means with k=3
set.seed(123) # Reproducibility
km_result <- kmeans(df_scaled,
                    centers=3,     # Number of clusters
                    nstart=25,     # Multiple random starts
                    iter.max=100)  # Maximum iterations
# Examine results
print(km_result)
# Extract cluster assignments
clusters <- km_result$cluster
table(clusters)
# Extract cluster centers (in scaled space)
centers <- km_result$centers
print(centers)
# Get cluster sizes
km_result$size
# Total within-cluster sum of squares
km_result$tot.withinss
# Between-cluster sum of squares
km_result$betweenss
Always set nstart to at least 25. K-means is sensitive to initial centroid placement and can converge to local optima. Multiple starts greatly improve the odds of landing at or near the global optimum.
The iter.max parameter prevents infinite loops if convergence is slow. The default of 10 is often too low for complex data. Set it to 100 or higher for production code.
# Compare different nstart values
set.seed(123)
km_low <- kmeans(df_scaled, centers=3, nstart=1)
set.seed(123)  # Same seed so the single start matches the first of the 50
km_high <- kmeans(df_scaled, centers=3, nstart=50)
# Lower tot.withinss is better
km_low$tot.withinss
km_high$tot.withinss
Visualizing Clustering Results
Visualization reveals cluster separation and overlap. For datasets with more than two dimensions, use PCA to project into 2D space while preserving variance.
# Add cluster assignments to original data
df_clustered <- iris %>%
  select(-Species) %>%
  mutate(Cluster = as.factor(km_result$cluster))
# 2D scatter plot
ggplot(df_clustered, aes(x=Sepal.Length, y=Sepal.Width, color=Cluster)) +
  geom_point(size=3, alpha=0.6) +
  labs(title="K-Means Clustering Results",
       x="Sepal Length",
       y="Sepal Width") +
  theme_minimal() +
  scale_color_brewer(palette="Set1")
# Using factoextra for PCA-based visualization
fviz_cluster(km_result, data=df_scaled,
             palette="jco",
             ggtheme=theme_minimal(),
             main="Cluster Plot with PCA")
The fviz_cluster() function automatically performs PCA and plots the first two principal components. The percentages on axes show variance explained by each component.
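That projection can also be reproduced by hand with prcomp(), which makes the variance-explained percentages explicit. The objects are recreated here so the snippet runs standalone.

```r
library(ggplot2)
df_scaled <- scale(iris[, 1:4])
set.seed(123)
km_result <- kmeans(df_scaled, centers = 3, nstart = 25)
pca <- prcomp(df_scaled)                       # PCA on the scaled features
var_explained <- pca$sdev^2 / sum(pca$sdev^2)  # variance explained per PC
scores <- data.frame(pca$x[, 1:2], Cluster = factor(km_result$cluster))
ggplot(scores, aes(PC1, PC2, color = Cluster)) +
  geom_point(size = 3, alpha = 0.6) +
  labs(title = "Clusters in PCA Space",
       x = sprintf("PC1 (%.1f%%)", 100 * var_explained[1]),
       y = sprintf("PC2 (%.1f%%)", 100 * var_explained[2])) +
  theme_minimal()
```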
For three-dimensional visualization, use plotly for interactive exploration:
library(plotly)
# 3D scatter plot
plot_ly(df_clustered,
        x=~Sepal.Length,
        y=~Sepal.Width,
        z=~Petal.Length,
        color=~Cluster,
        type="scatter3d",
        mode="markers") %>%
  layout(title="3D Cluster Visualization")
Evaluating Cluster Quality
Quantitative metrics validate clustering quality. Within-cluster variance measures compactness; between-cluster variance measures separation.
# Calculate within-cluster sum of squares for each cluster
within_ss <- km_result$withinss
names(within_ss) <- paste("Cluster", 1:3)
print(within_ss)
# Between-cluster sum of squares (higher is better)
between_ss <- km_result$betweenss
print(between_ss)
# Ratio of between-SS to total-SS (higher is better)
ratio <- between_ss / km_result$totss
print(paste("Between-SS / Total-SS:", round(ratio, 3)))
# Detailed silhouette analysis
sil <- silhouette(km_result$cluster, dist(df_scaled))
summary(sil)
# Visualize silhouette plot
fviz_silhouette(sil)
A good clustering has high between-cluster variance and low within-cluster variance. The ratio of between-SS to total-SS indicates how much variance is explained by clustering. Values above 0.5 suggest meaningful separation.
Silhouette plots show individual point quality. Points with negative silhouette widths are likely misclassified. If many points have negative values, reconsider k or examine outliers.
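Those borderline points can be pulled out directly from the silhouette matrix, whose third column holds each point's silhouette width. The objects are recreated here so the snippet runs standalone.

```r
library(cluster)
df_scaled <- scale(iris[, 1:4])
set.seed(123)
km_result <- kmeans(df_scaled, centers = 3, nstart = 25)
sil <- silhouette(km_result$cluster, dist(df_scaled))
negative_idx <- which(sil[, "sil_width"] < 0)  # likely misassigned points
length(negative_idx)
iris[negative_idx, ]  # inspect the original rows
```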
# Compare multiple k values
comparison <- data.frame(
  k = 2:6,
  tot_withinss = sapply(2:6, function(k) {
    kmeans(df_scaled, centers=k, nstart=25)$tot.withinss
  }),
  avg_silhouette = sapply(2:6, function(k) {
    km <- kmeans(df_scaled, centers=k, nstart=25)
    mean(silhouette(km$cluster, dist(df_scaled))[, 3])
  })
)
print(comparison)
Practical Example: Customer Segmentation
Here’s a complete workflow for segmenting customers based on purchasing behavior:
# Create synthetic customer data
set.seed(456)
customers <- data.frame(
  customer_id = 1:200,
  annual_spending = rnorm(200, 5000, 2000),
  visit_frequency = rpois(200, 12),
  avg_transaction = rnorm(200, 150, 50),
  tenure_months = sample(1:60, 200, replace=TRUE)
) %>%
  mutate(across(where(is.numeric) & !customer_id, ~pmax(., 0))) # No negative values
# Scale features
customer_features <- customers %>%
  select(-customer_id) %>%
  scale()
# Determine optimal k
wss_customers <- sapply(1:8, function(k) {
  kmeans(customer_features, centers=k, nstart=25)$tot.withinss
})
plot(1:8, wss_customers, type="b", xlab="k", ylab="Total WSS") # Inspect the elbow before picking k
# Run k-means with k=4 (assuming the elbow bends there)
km_customers <- kmeans(customer_features, centers=4, nstart=50, iter.max=100)
# Add segments to original data
customers$segment <- km_customers$cluster
# Analyze segment characteristics
segment_summary <- customers %>%
  group_by(segment) %>%
  summarise(
    count = n(),
    avg_spending = mean(annual_spending),
    avg_visits = mean(visit_frequency),
    avg_transaction = mean(avg_transaction),
    avg_tenure = mean(tenure_months)
  ) %>%
  arrange(desc(avg_spending))
print(segment_summary)
# Visualize segments
ggplot(customers, aes(x=annual_spending, y=visit_frequency, color=factor(segment))) +
  geom_point(size=3, alpha=0.6) +
  labs(title="Customer Segments",
       x="Annual Spending ($)",
       y="Visit Frequency",
       color="Segment") +
  theme_minimal() +
  scale_color_brewer(palette="Set1")
Interpret segments in business terms. Segment 1 might be “high-value frequent buyers,” Segment 2 “occasional big spenders,” Segment 3 “regular low-value customers,” and Segment 4 “new or inactive customers.” Use these insights for targeted marketing, personalized offers, and retention strategies.
K-means clustering transforms raw data into actionable customer intelligence. Scale your features, validate your k selection, run multiple initializations, and always interpret results through domain expertise. The algorithm provides the structure; your business knowledge provides the meaning.