How to Create a 3D Scatter Plot in Matplotlib


Key Insights

  • 3D scatter plots excel at revealing patterns in multi-dimensional data that would be invisible in 2D projections, making them invaluable for clustering analysis and scientific visualization
  • Matplotlib’s mpl_toolkits.mplot3d module provides straightforward 3D plotting capabilities, but requires understanding the Axes3D object and projection system
  • Color mapping and point sizing can effectively represent fourth and fifth dimensions in 3D space, though overuse of visual encodings can reduce plot clarity

Introduction & Setup

3D scatter plots are essential tools for visualizing relationships between three continuous variables simultaneously. Unlike 2D plots that force you to choose which dimensions to display, 3D visualizations let you examine all three dimensions at once, revealing clusters, outliers, and patterns that might otherwise remain hidden.

Common use cases include visualizing clustering algorithm results (like K-means or DBSCAN), exploring scientific datasets with three measured variables, analyzing principal component analysis results, and displaying spatial data with an additional measurement dimension.

Before creating 3D plots, ensure you have Matplotlib and NumPy installed. If not, install them via pip:

pip install matplotlib numpy

Here are the necessary imports:

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

The Axes3D import registers the '3d' projection with Matplotlib. Since Matplotlib 3.2, the projection is registered automatically and the import is no longer required; keeping it is harmless and preserves compatibility with older versions.
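A minimal sketch to confirm that your installation registers the projection on its own, assuming Matplotlib 3.2 or later. On older versions the add_subplot call should fail with an unknown-projection error unless Axes3D is imported first.

```python
import matplotlib
import matplotlib.pyplot as plt

# No "from mpl_toolkits.mplot3d import Axes3D" here: on Matplotlib 3.2+
# the '3d' projection is registered automatically.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# Axes3D identifies itself with the projection name '3d'
print(matplotlib.__version__, ax.name)
```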

Creating a Basic 3D Scatter Plot

Creating a 3D scatter plot requires setting up a figure with a 3D projection. The process differs slightly from standard 2D plotting because you must explicitly specify the projection type.

Here’s a minimal working example:

# Generate synthetic data
np.random.seed(42)
n_points = 100
x = np.random.randn(n_points)
y = np.random.randn(n_points)
z = np.random.randn(n_points)

# Create figure and 3D axes
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Create scatter plot
ax.scatter(x, y, z)

# Display
plt.show()

The key line is fig.add_subplot(111, projection='3d'), which creates an Axes3D object instead of a standard 2D axes. The 111 is shorthand for a grid of 1 row and 1 column, first subplot. The ax.scatter() method then accepts three positional arguments (x, y, z), one array per spatial dimension.
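Because the projection is set per subplot, a 3D axes can sit next to an ordinary 2D one. The sketch below places a 3D scatter beside its x-y projection; it uses NumPy's Generator API for random data as a substitution for the seeded calls above.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
x, y, z = rng.standard_normal((3, 100))

fig = plt.figure(figsize=(12, 5))

# 121 = 1 row, 2 columns, first cell; projection applies only to this subplot
ax3d = fig.add_subplot(121, projection="3d")
ax3d.scatter(x, y, z)
ax3d.set_title("3D view")

# 122 = second cell; no projection argument, so this is an ordinary 2D axes
ax2d = fig.add_subplot(122)
ax2d.scatter(x, y)
ax2d.set_title("x-y projection")

print(ax3d.name, ax2d.name)  # 3d rectilinear
```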

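Alternatively, you can omit the grid specification; add_subplot() with no positional arguments defaults to a single subplot (Matplotlib 3.1 and later):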

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection='3d')
ax.scatter(x, y, z)
plt.show()

This approach is cleaner when you only need a single subplot.
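If you prefer plt.subplots(), which returns the figure and axes together, the projection can be passed through subplot_kw, which is forwarded to every axes it creates. A sketch with random data:

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

rng = np.random.default_rng(42)
x, y, z = rng.standard_normal((3, 100))

# Single 3D axes: subplot_kw is passed on to add_subplot
fig, ax = plt.subplots(figsize=(10, 8), subplot_kw={"projection": "3d"})
ax.scatter(x, y, z)

# The same keyword makes every axes in a grid three-dimensional
fig2, axes = plt.subplots(1, 2, figsize=(12, 5), subplot_kw={"projection": "3d"})
for a in axes:
    a.scatter(x, y, z, s=10)
```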

Customizing Plot Appearance

Basic scatter plots are functional but often lack visual appeal and clarity. Matplotlib provides extensive customization options to make your plots more informative and visually distinctive.

# Generate data with varying characteristics
np.random.seed(42)
n_points = 200
x = np.random.randn(n_points)
y = np.random.randn(n_points)
z = np.random.randn(n_points)

# Create a fourth dimension for color mapping
colors = np.sqrt(x**2 + y**2)  # Distance from origin in xy-plane

# Create varying sizes based on z-values
sizes = 100 * (z - z.min()) / (z.max() - z.min()) + 20

# Create the plot
fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(projection='3d')

scatter = ax.scatter(x, y, z, 
                     c=colors,           # Color by fourth dimension
                     s=sizes,            # Vary point sizes
                     cmap='viridis',     # Color map
                     alpha=0.6,          # Transparency
                     marker='o',         # Marker style
                     edgecolors='black', # Edge color
                     linewidth=0.5)      # Edge width

plt.show()

Key customization parameters:

  • c: Controls point colors (can be single color, array of colors, or values mapped to a colormap)
  • s: Marker area in points squared (scalar or array)
  • cmap: Colormap name (viridis, plasma, coolwarm, etc.)
  • alpha: Transparency (0=transparent, 1=opaque)
  • marker: Point style ('o', '^', 's', 'D', etc.)
  • edgecolors: Border color around points
  • linewidth: Border thickness

Using color and size to represent additional dimensions effectively creates a 5D visualization (x, y, z, color, size), though you should use this sparingly to avoid overwhelming viewers.
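The sizes line in the customization example is a min-max rescaling onto the range 20 to 120. As a sketch, a hypothetical helper, scale_sizes, wraps the same formula and adds a guard for constant input, where max minus min is zero and the original expression would divide by zero:

```python
import numpy as np

def scale_sizes(values, min_size=20.0, max_size=120.0):
    """Linearly map values onto marker areas in [min_size, max_size] (points^2).

    Hypothetical helper: same min-max formula as the sizes line above,
    plus a guard for constant input.
    """
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        # All values equal: give every point the midpoint size
        return np.full(values.shape, (min_size + max_size) / 2)
    return min_size + (max_size - min_size) * (values - values.min()) / span

print(scale_sizes([0.0, 5.0, 10.0]))  # values 20, 70, 120
```

Passing s=scale_sizes(z) to ax.scatter() reproduces the sizes from the example above.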

Adding Labels, Titles, and Legends

Professional visualizations require clear labeling. 3D plots need special attention since the viewing angle can obscure labels.

# Create sample data
np.random.seed(42)
n_points = 150
x = np.random.randn(n_points)
y = np.random.randn(n_points)
z = np.random.randn(n_points)
colors = np.sqrt(x**2 + y**2)  # Distance from origin in xy-plane

# Create plot
fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(projection='3d')

scatter = ax.scatter(x, y, z, c=colors, cmap='plasma', 
                     s=50, alpha=0.6, edgecolors='black', linewidth=0.5)

# Add labels and title
ax.set_xlabel('X Axis Label', fontsize=12, labelpad=10)
ax.set_ylabel('Y Axis Label', fontsize=12, labelpad=10)
ax.set_zlabel('Z Axis Label', fontsize=12, labelpad=10)
ax.set_title('3D Scatter Plot with Complete Labeling', fontsize=14, pad=20)

# Add colorbar
cbar = plt.colorbar(scatter, ax=ax, pad=0.1, shrink=0.8)
cbar.set_label('Distance from Origin (XY)', fontsize=10)

# Turn on the grid (Axes3D.grid() only toggles it; style kwargs such as alpha are ignored)
ax.grid(True)

# Adjust viewing angle for better label visibility
ax.view_init(elev=20, azim=45)

plt.tight_layout()
plt.show()

The labelpad parameter adds spacing between each axis label and its tick labels, preventing overlap. The pad parameter in set_title() adds space between the title and the plot. For the colorbar, shrink scales its length relative to the axes, and pad sets the gap between the colorbar and the axes (as a fraction of the original axes size).

Interactive Features & Rotation

One advantage of 3D plots is the ability to rotate them and view from different angles. The view_init() method controls the viewing angle programmatically.

# Create sample data
np.random.seed(42)
theta = np.linspace(0, 4*np.pi, 100)
x = np.sin(theta)
y = np.cos(theta)
z = theta

# Create multiple views
fig = plt.figure(figsize=(15, 5))

# View 1: Angled view (Matplotlib's default is elev=30, azim=-60)
ax1 = fig.add_subplot(131, projection='3d')
ax1.scatter(x, y, z, c=z, cmap='viridis', s=30)
ax1.set_title('Angled View')
ax1.view_init(elev=30, azim=45)

# View 2: Top-down view
ax2 = fig.add_subplot(132, projection='3d')
ax2.scatter(x, y, z, c=z, cmap='viridis', s=30)
ax2.set_title('Top-Down View')
ax2.view_init(elev=90, azim=0)

# View 3: Side view
ax3 = fig.add_subplot(133, projection='3d')
ax3.scatter(x, y, z, c=z, cmap='viridis', s=30)
ax3.set_title('Side View')
ax3.view_init(elev=0, azim=0)

plt.tight_layout()
plt.show()

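The elev parameter sets the elevation angle in degrees: how far the camera sits above (positive) or below (negative) the xy-plane, with 90 looking straight down. The azim parameter sets the azimuth angle in degrees, the rotation of the camera around the z-axis. Matplotlib's default view is elev=30, azim=-60.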
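Because the three panels differ only in their angles, the multi-view example can be condensed into a loop. This sketch uses Matplotlib's default view (elev=30, azim=-60) for the first panel; view_init() stores the angles on the axes as ax.elev and ax.azim.

```python
import matplotlib.pyplot as plt
import numpy as np

theta = np.linspace(0, 4 * np.pi, 100)
x, y, z = np.sin(theta), np.cos(theta), theta

# (title, elev, azim) for each panel
views = [("Default view", 30, -60), ("Top-down", 90, -90), ("Side (y-z plane)", 0, 0)]

fig, axes = plt.subplots(1, len(views), figsize=(15, 5),
                         subplot_kw={"projection": "3d"})
for ax, (title, elev, azim) in zip(axes, views):
    ax.scatter(x, y, z, c=z, cmap="viridis", s=30)
    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f"{title}\nelev={elev}, azim={azim}")

fig.tight_layout()
```

Adding a view is then a one-line change to the views list.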

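In Jupyter notebooks, the default inline backend renders static images. To rotate plots interactively, switch to an interactive backend with %matplotlib widget (requires the ipympl package) or, in the classic Notebook, %matplotlib notebook. For standalone scripts, plt.show() opens an interactive window, provided a GUI backend is available, where you can click and drag to rotate.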
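Where click-and-drag rotation is unavailable, one option (a substitution, not part of the workflow above) is to animate the azimuth with matplotlib.animation.FuncAnimation. A sketch with random data:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

rng = np.random.default_rng(0)
x, y, z = rng.standard_normal((3, 200))

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection="3d")
ax.scatter(x, y, z, c=z, cmap="viridis", s=20)

def rotate(azim):
    # Keep elevation fixed and step the azimuth for each frame
    ax.view_init(elev=20, azim=azim)
    return []

# One frame per 2 degrees; blit=False because the whole 3D scene changes
anim = FuncAnimation(fig, rotate, frames=np.arange(0, 360, 2), interval=50, blit=False)

# Display with plt.show(), or write a GIF with the Pillow writer:
# anim.save("rotation.gif", writer="pillow", fps=20)
```

Keep a reference to anim (as above); if the animation object is garbage-collected, it stops.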

Real-World Application

Let’s apply these techniques to the classic Iris dataset, visualizing three flower measurements and species classification:

from sklearn.datasets import load_iris  # requires scikit-learn (pip install scikit-learn)
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Load iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Select three features: sepal length, sepal width, petal length
x_data = X[:, 0]  # Sepal length
y_data = X[:, 1]  # Sepal width
z_data = X[:, 2]  # Petal length

# Create plot
fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(projection='3d')

# Create scatter plot with species as colors
scatter = ax.scatter(x_data, y_data, z_data, 
                     c=y, 
                     cmap='viridis', 
                     s=100, 
                     alpha=0.7,
                     edgecolors='black',
                     linewidth=0.5)

# Labels
ax.set_xlabel('Sepal Length (cm)', fontsize=11, labelpad=10)
ax.set_ylabel('Sepal Width (cm)', fontsize=11, labelpad=10)
ax.set_zlabel('Petal Length (cm)', fontsize=11, labelpad=10)
ax.set_title('Iris Dataset: 3D Feature Visualization', fontsize=14, pad=20)

# Create custom legend
species_names = iris.target_names
legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                              markerfacecolor=plt.cm.viridis(i/2), 
                              markersize=10, label=species_names[i])
                   for i in range(3)]
ax.legend(handles=legend_elements, loc='upper left', fontsize=10)

# Optimize viewing angle
ax.view_init(elev=20, azim=135)
ax.grid(True)  # Axes3D.grid() ignores style kwargs such as alpha

plt.tight_layout()
plt.show()

This visualization shows how the three Iris species cluster in 3D space based on their physical measurements. Setosa (label 0, drawn in dark purple at the low end of viridis) separates distinctly, while versicolor and virginica overlap somewhat.
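Instead of building legend proxies by hand with plt.cm.viridis(i/2), the collection returned by ax.scatter() can generate them through its legend_elements() method, which colors each handle with the scatter's own colormap and normalization. The sketch below uses synthetic three-class data as a stand-in for the Iris arrays; with Iris you would pass c=y and iris.target_names instead of the hypothetical labels and class_names.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in for the Iris arrays: 150 points in three labeled groups
rng = np.random.default_rng(1)
labels = np.repeat([0, 1, 2], 50)
points = rng.standard_normal((150, 3)) + labels[:, None] * 3.0
class_names = ["class 0", "class 1", "class 2"]  # hypothetical names

fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(projection="3d")
scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                     c=labels, cmap="viridis", s=100, alpha=0.7,
                     edgecolors="black", linewidth=0.5)

# legend_elements() returns one proxy handle per unique value in c;
# the default numeric labels are replaced by the class names
handles, _ = scatter.legend_elements()
ax.legend(handles, class_names, loc="upper left", title="Class")
```

This keeps legend colors in sync with the plot if you later change the colormap.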

Conclusion & Best Practices

3D scatter plots are powerful visualization tools when used appropriately. Follow these best practices:

When to use 3D plots: Use them when all three dimensions are essential to understanding the data. If two dimensions suffice, stick with 2D plots—they’re easier to read and interpret.

Performance considerations: For datasets exceeding 10,000 points, consider downsampling or using alpha transparency to prevent visual clutter. Matplotlib can handle hundreds of thousands of points, but rendering becomes sluggish and the plot becomes difficult to interpret.
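Downsampling can be as simple as drawing one random index and applying it to every coordinate array. A sketch with a hypothetical helper, downsample; note that a uniform random sample preserves overall density but can thin out small clusters.

```python
import numpy as np

def downsample(*arrays, max_points=10_000, seed=0):
    """Randomly keep at most max_points rows, using the same index for every array.

    Hypothetical helper for illustration.
    """
    n = len(arrays[0])
    if n <= max_points:
        return arrays
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=max_points, replace=False)
    return tuple(a[idx] for a in arrays)

rng = np.random.default_rng(42)
x, y, z = rng.standard_normal((3, 250_000))
xs, ys, zs = downsample(x, y, z, max_points=10_000)
print(len(xs))  # 10000
```

The reduced arrays can then be plotted with small markers and low alpha, for example ax.scatter(xs, ys, zs, s=2, alpha=0.3).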

Color and size encoding: Use color to represent a fourth dimension sparingly. Ensure color scales are intuitive and include a colorbar. Avoid using both color and size for different dimensions unless absolutely necessary.

Viewing angles: Always test multiple viewing angles to ensure your chosen perspective doesn’t hide important patterns. Consider providing multiple views or interactive plots for presentations.

Export considerations: When saving 3D plots for publications, save multiple views at different angles. Static images lose the interactive rotation capability, so comprehensive documentation requires multiple perspectives.
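Saving several perspectives only requires re-running view_init() before each savefig() call on the same figure. A sketch with random data; the output folder name and the list of angles are illustrative assumptions.

```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
x, y, z = rng.standard_normal((3, 200))

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection="3d")
ax.scatter(x, y, z, c=z, cmap="viridis", s=20)

out_dir = Path("figures")  # hypothetical output folder
out_dir.mkdir(exist_ok=True)

# Re-render the same axes from several angles and save each one
for elev, azim in [(30, -60), (20, 45), (90, -90), (0, 0)]:
    ax.view_init(elev=elev, azim=azim)
    fig.savefig(out_dir / f"scatter_elev{elev}_azim{azim}.png",
                dpi=300, bbox_inches="tight")

plt.close(fig)
```

Encoding the angles in each filename makes it easy to state in a caption which perspective a figure shows.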

3D scatter plots shine when exploring multi-dimensional relationships, but they require thoughtful design to communicate effectively. Master the basics first, then progressively add complexity only when it serves your analytical goals.
