How to Create a Correlation Matrix Heatmap in Seaborn
Key Insights
- Correlation heatmaps visualize relationships between multiple variables simultaneously, making them essential for feature engineering and identifying multicollinearity in machine learning pipelines
- Seaborn’s heatmap() function combined with pandas’ corr() method creates publication-ready visualizations in just two lines of code, but customization dramatically improves readability
- Masking half the matrix eliminates redundant information since correlation matrices are symmetrical, while centering the colorbar at zero makes positive and negative correlations instantly distinguishable
Introduction & Setup
Correlation matrices are your first line of defense against redundant features and hidden relationships in datasets. Before building any predictive model, you need to understand how your variables interact. A correlation heatmap transforms a grid of numbers into an intuitive visual representation where patterns jump out immediately.
Seaborn makes this process straightforward. Install the required libraries if you haven’t already:
pip install seaborn pandas numpy matplotlib
Import the necessary packages:
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Set style for better-looking plots
sns.set_theme(style="white")
The sns.set_theme() call isn’t mandatory, but it removes the default gray background and makes your heatmaps cleaner.
Preparing Your Data
Correlation matrices only work with numerical data. You can’t calculate correlations for categorical variables without encoding them first. Let’s start with a real dataset that ships with seaborn:
# Load the tips dataset
df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv')
# Select only numerical columns
numerical_df = df.select_dtypes(include=[np.number])
print(numerical_df.head())
print(f"\nShape: {numerical_df.shape}")
This gives you a DataFrame with columns like total_bill, tip, and size. For custom data, structure it similarly:
# Example with custom data
data = {
    'feature_1': np.random.randn(100),
    'feature_2': np.random.randn(100),
    'feature_3': np.random.randn(100),
    'target': np.random.randn(100)
}
custom_df = pd.DataFrame(data)
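As noted above, categorical columns can't enter a correlation matrix until they're encoded. A minimal sketch of one approach, one-hot encoding with pd.get_dummies (the small frame here is hypothetical, standing in for the tips dataset's mixed types):

```python
import pandas as pd

# Hypothetical frame mixing numerical and categorical columns
df = pd.DataFrame({
    'total_bill': [16.99, 10.34, 21.01, 23.68],
    'tip': [1.01, 1.66, 3.50, 3.31],
    'smoker': ['No', 'No', 'Yes', 'No'],
})

# One-hot encode the categorical column, cast to int so .corr() treats it numerically
encoded = pd.get_dummies(df, columns=['smoker'], drop_first=True)
encoded['smoker_Yes'] = encoded['smoker_Yes'].astype(int)

corr = encoded.corr()
print(corr.columns.tolist())  # smoker_Yes now participates in the matrix
```

With drop_first=True, a binary category becomes a single 0/1 column, which keeps the matrix compact and avoids perfectly anti-correlated dummy pairs.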
Calculate the correlation matrix using pandas’ built-in method:
# Generate correlation matrix
correlation_matrix = numerical_df.corr()
print(correlation_matrix)
The .corr() method defaults to Pearson correlation coefficients, which measure linear relationships. You can specify method='spearman' or method='kendall' to capture monotonic but non-linear relationships, though Pearson suffices for most exploratory analysis.
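To see why the method choice matters, compare the two on a monotonic but non-linear relationship. Spearman ranks the values first, so it reports a perfect correlation, while Pearson does not (an illustrative sketch with synthetic data):

```python
import numpy as np
import pandas as pd

x = np.arange(1, 21, dtype=float)
df = pd.DataFrame({'x': x, 'y': np.exp(x / 4)})  # monotonic, but clearly non-linear

pearson = df.corr(method='pearson').loc['x', 'y']
spearman = df.corr(method='spearman').loc['x', 'y']
print(f"Pearson: {pearson:.3f}, Spearman: {spearman:.3f}")  # Spearman is exactly 1.0
```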
Creating a Basic Heatmap
With your correlation matrix ready, creating a heatmap requires one function call:
# Create basic heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(correlation_matrix)
plt.title('Correlation Matrix Heatmap')
plt.tight_layout()
plt.show()
This produces a functional heatmap, but it’s not particularly useful. The default color scheme makes it hard to distinguish correlation strengths, and without annotations, you’re just looking at colored squares.
Customizing the Heatmap
Here’s where seaborn shines. A few parameters transform that basic heatmap into something actually readable:
# Enhanced heatmap with annotations
plt.figure(figsize=(10, 8))
sns.heatmap(
    correlation_matrix,
    annot=True,               # Show correlation values
    cmap='coolwarm',          # Color palette
    center=0,                 # Center colorbar at 0
    square=True,              # Make cells square-shaped
    linewidths=1,             # Add gridlines
    cbar_kws={"shrink": 0.8}  # Shrink colorbar slightly
)
plt.title('Enhanced Correlation Heatmap', fontsize=16, pad=20)
plt.tight_layout()
plt.show()
Let’s break down these parameters:
- annot=True displays the actual correlation coefficients in each cell
- cmap='coolwarm' uses a diverging color palette where blue represents negative correlations and red represents positive ones
- center=0 ensures zero correlation is white/neutral, making patterns more obvious
- square=True prevents distorted cells when you have different numbers of features
- linewidths=1 adds separation between cells for clarity
The color palette matters more than you’d think. Avoid sequential colormaps like ‘viridis’ for correlation matrices—they make negative correlations harder to spot. Stick with diverging palettes: ‘coolwarm’, ‘RdBu_r’, or ‘seismic’.
Adjust figure size based on your feature count. For 5-10 features, use figsize=(10, 8). For 20+ features, scale up to (16, 14) or annotations become unreadable.
Advanced Techniques
The correlation matrix is symmetrical—the correlation between A and B equals the correlation between B and A. Displaying both halves wastes space. Mask the upper triangle:
# Create mask for upper triangle
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
# Apply mask to heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(
    correlation_matrix,
    mask=mask,
    annot=True,
    fmt='.2f',
    cmap='coolwarm',
    center=0,
    square=True,
    linewidths=1,
    cbar_kws={"shrink": 0.8}
)
plt.title('Correlation Heatmap (Lower Triangle)', fontsize=16, pad=20)
plt.tight_layout()
plt.show()
The np.triu() function creates a boolean array matching your correlation matrix shape, with True values in the upper triangle. Seaborn hides those cells.
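On a small 3×3 example, the mask looks like this (True cells are the ones seaborn hides):

```python
import numpy as np

# Boolean upper-triangle mask, same shape as a 3-feature correlation matrix
mask = np.triu(np.ones((3, 3), dtype=bool))
print(mask)
# [[ True  True  True]
#  [False  True  True]
#  [False False  True]]
```

Note that np.triu() includes the main diagonal by default, so the diagonal of 1.0s is hidden as well; pass np.triu(..., k=1) if you'd rather keep it visible.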
The fmt='.2f' parameter controls annotation formatting. Two decimal places balance precision with readability. For very small correlations, consider fmt='.3f'.
Handle extreme correlations by setting vmin and vmax:
# Control color scale limits
sns.heatmap(
    correlation_matrix,
    mask=mask,
    annot=True,
    fmt='.2f',
    cmap='coolwarm',
    vmin=-1,
    vmax=1,
    center=0,
    square=True,
    linewidths=1
)
This ensures the color scale always spans from -1 to 1, even if your actual correlations range from -0.3 to 0.5. It makes comparisons across different datasets easier.
Practical Use Cases & Interpretation
Here’s a complete example using the housing dataset to detect multicollinearity before building a regression model:
# Load housing data
from sklearn.datasets import fetch_california_housing
housing = fetch_california_housing()
housing_df = pd.DataFrame(housing.data, columns=housing.feature_names)
housing_df['MedHouseVal'] = housing.target
# Calculate correlations
corr = housing_df.corr()
# Create mask
mask = np.triu(np.ones_like(corr, dtype=bool))
# Create comprehensive heatmap
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    corr,
    mask=mask,
    annot=True,
    fmt='.2f',
    cmap='RdBu_r',
    center=0,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.75, "label": "Correlation Coefficient"},
    ax=ax
)
plt.title('California Housing Dataset - Correlation Analysis',
          fontsize=16, pad=20, weight='bold')
plt.xlabel('')
plt.ylabel('')
plt.tight_layout()
plt.show()
# Identify strong correlations with target
target_corr = corr['MedHouseVal'].sort_values(ascending=False)
print("\nCorrelations with Median House Value:")
print(target_corr)
# Find multicollinearity issues (high correlation between features)
threshold = 0.8
high_corr = []
for i in range(len(corr.columns)):
    for j in range(i + 1, len(corr.columns)):
        if abs(corr.iloc[i, j]) > threshold:
            high_corr.append((corr.columns[i], corr.columns[j], corr.iloc[i, j]))
if high_corr:
    print(f"\nFeature pairs with correlation > {threshold}:")
    for feat1, feat2, value in high_corr:
        print(f"{feat1} <-> {feat2}: {value:.3f}")
else:
    print(f"\nNo feature pairs with correlation > {threshold}")
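The nested loop can also be written with pandas alone. A sketch that masks everything except the lower triangle and stacks the remainder into a Series of feature pairs (the small hand-built matrix here is hypothetical, standing in for housing_df.corr()):

```python
import numpy as np
import pandas as pd

# Hypothetical correlation matrix with one highly correlated pair (a, b)
corr = pd.DataFrame(
    [[1.0, 0.85, 0.1],
     [0.85, 1.0, 0.2],
     [0.1, 0.2, 1.0]],
    index=['a', 'b', 'c'], columns=['a', 'b', 'c'],
)

threshold = 0.8
# Keep only the strict lower triangle (k=-1 excludes the diagonal), then stack;
# stack() drops the NaN cells, leaving one entry per unique feature pair
lower = corr.where(np.tril(np.ones(corr.shape, dtype=bool), k=-1))
pairs = lower.stack()
high = pairs[pairs.abs() > threshold]
print(high)  # only the (b, a) pair survives
```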
When interpreting your heatmap, look for:
Strong correlations with your target variable (> 0.5 or < -0.5): These features likely have predictive power. In the housing example, median income shows strong positive correlation with house values.
High inter-feature correlations (> 0.8 or < -0.8): This indicates multicollinearity. Consider removing one feature from each pair, using PCA, or applying regularization in your model.
Unexpected relationships: Sometimes you’ll spot correlations that don’t make domain sense. Investigate these—they might reveal data quality issues or spurious correlations.
Near-zero correlations: Features uncorrelated with your target probably won’t help your model. Consider dropping them during feature selection.
The heatmap doesn’t tell you about causation or non-linear relationships. A correlation of 0.1 doesn’t necessarily mean the relationship is weak; it might be strong but non-linear. For those cases, scatter plots and domain knowledge remain essential.
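A quick sanity check of that caveat: a perfectly deterministic quadratic relationship can produce a Pearson correlation of essentially zero, because Pearson only measures the linear component:

```python
import numpy as np
import pandas as pd

x = np.linspace(-1, 1, 201)
df = pd.DataFrame({'x': x, 'y': x ** 2})  # y is completely determined by x

r = df.corr().loc['x', 'y']
print(f"Pearson r = {r:.3f}")  # essentially zero despite a perfect relationship
```

A heatmap built on this matrix would suggest x and y are unrelated, which is exactly why a scatter plot is the right follow-up for suspicious near-zero cells.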
Use correlation heatmaps early in your analysis pipeline, right after data cleaning and before feature engineering. They’ll save you from building models with redundant features and help you understand your data’s structure at a glance.