Matrix Chain Multiplication: Optimal Parenthesization

Key Insights

  • Matrix multiplication is associative, so you can parenthesize a chain any way you like, but different parenthesizations can differ in computational cost by orders of magnitude
  • The optimal parenthesization problem exhibits both optimal substructure and overlapping subproblems, making it a textbook application of dynamic programming with O(n³) time complexity
  • Real-world applications include machine learning frameworks and graphics pipelines, where choosing the right multiplication order can sharply reduce the number of scalar operations; database query optimizers apply the same DP structure to join ordering

Introduction: Why Multiplication Order Matters

Matrix multiplication is associative: (AB)C = A(BC). This mathematical property might seem like a trivial detail, but it has profound computational implications. While the result is identical regardless of how you parenthesize the multiplication, the number of scalar operations required can vary dramatically.

Consider a chain of four matrices with the following dimensions:

  • A₁: 10 × 30
  • A₂: 30 × 5
  • A₃: 5 × 60
  • A₄: 60 × 10

Let’s calculate the cost of different parenthesizations:

def multiplication_cost(p: int, q: int, r: int) -> int:
    """Cost of multiplying a (p×q) matrix with a (q×r) matrix."""
    return p * q * r

# Dimensions: A1(10×30), A2(30×5), A3(5×60), A4(60×10)
dims = [10, 30, 5, 60, 10]

# Parenthesization 1: ((A1 × A2) × A3) × A4
cost1 = (
    multiplication_cost(10, 30, 5) +   # A1 × A2 -> 10×5, cost: 1500
    multiplication_cost(10, 5, 60) +   # result × A3 -> 10×60, cost: 3000
    multiplication_cost(10, 60, 10)    # result × A4 -> 10×10, cost: 6000
)
print(f"((A1 × A2) × A3) × A4: {cost1} operations")  # 10,500

# Parenthesization 2: (A1 × (A2 × A3)) × A4
cost2 = (
    multiplication_cost(30, 5, 60) +   # A2 × A3 -> 30×60, cost: 9000
    multiplication_cost(10, 30, 60) +  # A1 × result -> 10×60, cost: 18000
    multiplication_cost(10, 60, 10)    # result × A4 -> 10×10, cost: 6000
)
print(f"(A1 × (A2 × A3)) × A4: {cost2} operations")  # 33,000

# Parenthesization 3: (A1 × A2) × (A3 × A4)
cost3 = (
    multiplication_cost(10, 30, 5) +   # A1 × A2 -> 10×5, cost: 1500
    multiplication_cost(5, 60, 10) +   # A3 × A4 -> 5×10, cost: 3000
    multiplication_cost(10, 5, 10)     # results together, cost: 500
)
print(f"(A1 × A2) × (A3 × A4): {cost3} operations")  # 5,000

The difference is striking: the worst of the five possible parenthesizations requires 33,000 scalar multiplications, while the best needs only 5,000, a 6.6× difference. For longer chains of larger matrices, the gap can grow by orders of magnitude.
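Four matrices admit five parenthesizations in total, and the code above shows three of them. A small brute-force enumerator can list every one with its cost. This is a sketch for tiny chains only (its output grows with the Catalan numbers), and `all_parenthesizations` is a hypothetical helper, not a library function.

```python
from functools import lru_cache
from typing import List, Tuple


def all_parenthesizations(dims: List[int]) -> List[Tuple[str, int]]:
    """Enumerate every parenthesization of the chain together with its scalar-multiplication cost.

    Matrix i has shape dims[i-1] × dims[i]. Exponential output size: tiny chains only.
    """
    n = len(dims) - 1

    @lru_cache(maxsize=None)
    def build(i: int, j: int) -> Tuple[Tuple[str, int], ...]:
        # A single matrix is already "computed": no multiplication needed.
        if i == j:
            return ((f"A{i}", 0),)
        options = []
        for k in range(i, j):  # the final multiplication splits after matrix k
            for left, left_cost in build(i, k):
                for right, right_cost in build(k + 1, j):
                    cost = left_cost + right_cost + dims[i - 1] * dims[k] * dims[j]
                    options.append((f"({left} × {right})", cost))
        return tuple(options)

    return list(build(1, n))


dims = [10, 30, 5, 60, 10]
for expr, cost in sorted(all_parenthesizations(dims), key=lambda t: t[1]):
    print(f"{expr}: {cost}")
```

For the four-matrix example it prints five lines with costs 5000, 7500, 10500, 30000, and 33000, so the two orderings singled out above really are the cheapest and the most expensive.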

Problem Definition and Cost Analysis

The matrix chain multiplication problem is formally defined as follows: given a chain of n matrices A₁, A₂, …, Aₙ where matrix Aᵢ has dimensions pᵢ₋₁ × pᵢ, determine the parenthesization that minimizes the total number of scalar multiplications.

The cost of multiplying two matrices of dimensions (p × q) and (q × r) is p × q × r scalar multiplications. This formula drives everything that follows.
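The p × q × r figure comes from the classical triple-loop algorithm: the result has p × r entries, and each one is a dot product of length q. The sketch below counts those multiplications directly; `naive_matmul_count` is a hypothetical helper written for illustration, and the cost model throughout assumes this classical algorithm.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def naive_matmul_count(a: Matrix, b: Matrix) -> Tuple[Matrix, int]:
    """Multiply a (p×q) matrix by a (q×r) matrix with the textbook triple loop,
    returning the product and the number of scalar multiplications performed."""
    p, q, r = len(a), len(b), len(b[0])
    assert len(a[0]) == q, "inner dimensions must agree"
    c = [[0.0] * r for _ in range(p)]
    mults = 0
    for i in range(p):          # each of the p rows of the result
        for j in range(r):      # each of the r columns of the result
            for t in range(q):  # dot product of length q
                c[i][j] += a[i][t] * b[t][j]
                mults += 1
    return c, mults


a = [[1.0] * 30 for _ in range(10)]  # 10×30
b = [[1.0] * 5 for _ in range(30)]   # 30×5
_, count = naive_matmul_count(a, b)
print(count)  # 1500, i.e. 10 × 30 × 5
```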

Why can’t we simply try all possible parenthesizations? The number of ways to parenthesize n matrices is given by the (n-1)th Catalan number:

from functools import lru_cache

@lru_cache(maxsize=None)
def catalan(n: int) -> int:
    """Calculate the nth Catalan number (memoized to avoid exponential recursion)."""
    if n <= 1:
        return 1
    result = 0
    for i in range(n):
        result += catalan(i) * catalan(n - 1 - i)
    return result

# Number of parenthesizations for n matrices
for n in range(1, 12):
    print(f"{n} matrices: {catalan(n - 1)} parenthesizations")

The Catalan numbers grow exponentially: C₁₀ = 16,796 (11 matrices) and C₂₀ = 6,564,120,420 (21 matrices). Brute force is not an option beyond the shortest chains.
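The recursive definition is not the only way to get these values. A sketch using the standard closed form Cₙ = (2n choose n) / (n + 1), via Python's `math.comb` (Python 3.8+), confirms them directly; asymptotically Cₙ grows like 4ⁿ / (n^(3/2) √π). `catalan_closed` is a hypothetical name chosen for this example.

```python
from math import comb


def catalan_closed(n: int) -> int:
    """nth Catalan number from the closed form C(2n, n) / (n + 1); the division is exact."""
    return comb(2 * n, n) // (n + 1)


# A chain of k matrices has catalan_closed(k - 1) parenthesizations.
for n_matrices in (4, 11, 21):
    print(f"{n_matrices} matrices: {catalan_closed(n_matrices - 1):,} parenthesizations")
```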

Dynamic Programming Formulation

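The key insight is that this problem has optimal substructure. If the optimal solution splits the chain A₁…Aₙ at position k, then the sub-solutions for A₁…Aₖ and Aₖ₊₁…Aₙ must also be optimal: if either could be computed more cheaply, substituting the cheaper sub-solution would lower the total cost, contradicting optimality. The subproblems also overlap, because the same subchain Aᵢ…Aⱼ appears inside many larger chains, so each one should be solved only once.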

We define m[i,j] as the minimum cost to compute the product Aᵢ × Aᵢ₊₁ × … × Aⱼ. The recurrence relation is:

  • m[i,i] = 0 (base case: single matrix, no multiplication needed)
  • m[i,j] = min over i ≤ k < j of {m[i,k] + m[k+1,j] + pᵢ₋₁ × pₖ × pⱼ} (for i < j)
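Applying the recurrence by hand to the four-matrix example from the introduction (p₀ … p₄ = 10, 30, 5, 60, 10) reproduces the 5,000 found earlier. The minimum at k = 2 corresponds to the parenthesization (A₁A₂)(A₃A₄).

```latex
\begin{aligned}
&m[1,2] = 10 \cdot 30 \cdot 5 = 1500, \quad m[2,3] = 30 \cdot 5 \cdot 60 = 9000, \quad m[3,4] = 5 \cdot 60 \cdot 10 = 3000 \\
&m[1,3] = \min\{\, 0 + 9000 + 10 \cdot 30 \cdot 60,\ 1500 + 0 + 10 \cdot 5 \cdot 60 \,\} = \min\{27000,\ 4500\} = 4500 \\
&m[2,4] = \min\{\, 0 + 3000 + 30 \cdot 5 \cdot 10,\ 9000 + 0 + 30 \cdot 60 \cdot 10 \,\} = \min\{4500,\ 27000\} = 4500 \\
&m[1,4] = \min\{\, \underbrace{0 + 4500 + 3000}_{k=1},\ \underbrace{1500 + 3000 + 500}_{k=2},\ \underbrace{4500 + 0 + 6000}_{k=3} \,\} = \min\{7500,\ 5000,\ 10500\} = 5000
\end{aligned}
```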

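In the recurrence, the term pᵢ₋₁ × pₖ × pⱼ is the cost of the final multiplication: the left subproduct Aᵢ…Aₖ is a pᵢ₋₁ × pₖ matrix and the right subproduct Aₖ₊₁…Aⱼ is a pₖ × pⱼ matrix. Here’s a memoized recursive implementation: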

from functools import lru_cache
from typing import List, Tuple

def matrix_chain_memoized(dims: List[int]) -> Tuple[int, dict]:
    """
    Find minimum multiplication cost using memoization.
    dims[i-1] × dims[i] gives dimensions of matrix i.
    """
    n = len(dims) - 1  # number of matrices
    split_points = {}
    
    @lru_cache(maxsize=None)
    def dp(i: int, j: int) -> int:
        if i == j:
            return 0
        
        min_cost = float('inf')
        best_k = i
        
        for k in range(i, j):
            cost = (
                dp(i, k) + 
                dp(k + 1, j) + 
                dims[i - 1] * dims[k] * dims[j]
            )
            if cost < min_cost:
                min_cost = cost
                best_k = k
        
        split_points[(i, j)] = best_k
        return min_cost
    
    result = dp(1, n)
    return result, split_points

# Example usage
dims = [10, 30, 5, 60, 10]
cost, splits = matrix_chain_memoized(dims)
print(f"Minimum cost: {cost}")  # 5000
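Memoization pays off because the subproblems overlap. The sketch below runs the same recurrence with no cache and counts how often each (i, j) pair is visited; `naive_cost` is a hypothetical helper for illustration. For a chain of L matrices the plain recursion makes 3^(L−1) calls, so the six-matrix example used in the next section triggers 243 calls to cover only 21 distinct subproblems.

```python
from collections import Counter
from typing import List


def naive_cost(dims: List[int], i: int, j: int, calls: Counter) -> int:
    """Evaluate the recurrence by plain recursion (no caching), recording every (i, j) visit."""
    calls[(i, j)] += 1
    if i == j:
        return 0  # single matrix: nothing to multiply
    return min(
        naive_cost(dims, i, k, calls)
        + naive_cost(dims, k + 1, j, calls)
        + dims[i - 1] * dims[k] * dims[j]
        for k in range(i, j)
    )


dims = [30, 35, 15, 5, 10, 20, 25]  # six matrices
calls = Counter()
print(naive_cost(dims, 1, len(dims) - 1, calls))  # 15125
print(f"total calls: {sum(calls.values())}, distinct subproblems: {len(calls)}")  # 243, 21
```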

Bottom-Up DP Implementation

The bottom-up approach fills the table systematically by chain length. We start with chains of length 1 (cost 0), then length 2, and so on:

from typing import List, Tuple
import numpy as np

def matrix_chain_dp(dims: List[int]) -> Tuple[int, np.ndarray, np.ndarray]:
    """
    Bottom-up DP solution for matrix chain multiplication.
    
    Args:
        dims: List where dims[i-1] × dims[i] are dimensions of matrix i
        
    Returns:
        Tuple of (minimum cost, cost table, split table)
    """
    n = len(dims) - 1  # number of matrices
    
    # m[i,j] = minimum cost to multiply matrices i through j
    # s[i,j] = optimal split point for matrices i through j
    m = np.zeros((n + 1, n + 1), dtype=np.int64)
    s = np.zeros((n + 1, n + 1), dtype=np.int32)
    
    # Base case: single matrices have zero cost
    # (already initialized to 0)
    
    # Fill table by increasing chain length
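    for chain_len in range(2, n + 1):  # chain length: 2, 3, ..., n
        for i in range(1, n - chain_len + 2):  # start index
            j = i + chain_len - 1  # end index
            # float('inf') cannot be stored in an int64 array; use the largest int64 as "infinity"
            m[i, j] = np.iinfo(np.int64).max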
            
            # Try all possible split points
            for k in range(i, j):
                cost = (
                    m[i, k] + 
                    m[k + 1, j] + 
                    dims[i - 1] * dims[k] * dims[j]
                )
                if cost < m[i, j]:
                    m[i, j] = cost
                    s[i, j] = k
    
    return m[1, n], m, s

# Example with detailed output
dims = [30, 35, 15, 5, 10, 20, 25]
min_cost, cost_table, split_table = matrix_chain_dp(dims)

print(f"Dimensions: {dims}")
print(f"Number of matrices: {len(dims) - 1}")
print(f"Minimum scalar multiplications: {min_cost}")
print(f"\nCost table (upper triangle):")
print(cost_table[1:, 1:])
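The loop bounds above walk the table one diagonal (one chain length) at a time. A stripped-down sketch that records only the visited cells makes this order visible for four matrices: every entry that m[i,j] reads covers a shorter chain and has therefore already been filled. `fill_order` is a hypothetical helper for illustration.

```python
from typing import List, Tuple


def fill_order(n: int) -> List[Tuple[int, int]]:
    """Return the (i, j) cells in the order the bottom-up loops visit them."""
    order = []
    for chain_len in range(2, n + 1):             # shorter chains first
        for i in range(1, n - chain_len + 2):     # every valid start position
            order.append((i, i + chain_len - 1))  # j = end index
    return order


print(fill_order(4))
# [(1, 2), (2, 3), (3, 4), (1, 3), (2, 4), (1, 4)]
```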

The time complexity is O(n³): we have O(n²) subproblems, and each requires O(n) work to find the optimal split point. Space complexity is O(n²) for the two tables.
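To make the cubic bound concrete, the sketch below counts how many times the innermost split-point loop runs. The total is exactly (n³ − n)/6, so doubling the number of matrices multiplies the work by roughly eight. `inner_iterations` is a hypothetical helper that mirrors the loop structure of `matrix_chain_dp` without doing any arithmetic on costs.

```python
def inner_iterations(n: int) -> int:
    """Count executions of the innermost (split-point) loop body for n matrices."""
    count = 0
    for chain_len in range(2, n + 1):
        for i in range(1, n - chain_len + 2):
            j = i + chain_len - 1
            count += j - i  # one iteration per split point k in [i, j)
    return count


for n in (6, 50, 100, 200):
    print(n, inner_iterations(n), (n**3 - n) // 6)
```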

Reconstructing the Optimal Parenthesization

The split table tells us where to divide each subchain. We reconstruct the solution recursively:

import numpy as np
from typing import List, Optional

def print_optimal_parens(s: np.ndarray, i: int, j: int, names: Optional[List[str]] = None) -> str:
    """
    Recursively construct the optimal parenthesization string.

    Args:
        s: Split point table from DP solution
        i: Start index of chain
        j: End index of chain
        names: Optional list of matrix names (defaults to A1, A2, ...)

    Returns:
        The parenthesized expression for matrices i through j.
    """
    if names is None:
        names = [f"A{k}" for k in range(1, s.shape[0])]
    
    if i == j:
        return names[i - 1]
    
    k = s[i, j]
    left = print_optimal_parens(s, i, k, names)
    right = print_optimal_parens(s, k + 1, j, names)
    
    return f"({left} × {right})"

# Complete example
dims = [30, 35, 15, 5, 10, 20, 25]
min_cost, _, split_table = matrix_chain_dp(dims)

n = len(dims) - 1
result = print_optimal_parens(split_table, 1, n)
print(f"Optimal parenthesization: {result}")
print(f"Minimum cost: {min_cost}")

# Output: ((A1 × (A2 × A3)) × ((A4 × A5) × A6))
# Cost: 15125
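The same split table can drive execution instead of printing. A post-order walk yields the multiplications in a valid evaluation order, since each product's operands are computed before it. The sketch below is self-contained: it recomputes the splits with a small pure-Python DP instead of the NumPy version, and `split_table`, `execution_plan`, and `label` are hypothetical names chosen for illustration.

```python
from typing import Dict, List, Tuple

Range = Tuple[int, int]


def split_table(dims: List[int]) -> Dict[Range, int]:
    """Pure-Python bottom-up DP that returns only the optimal split points s[(i, j)]."""
    n = len(dims) - 1
    m: Dict[Range, int] = {(i, i): 0 for i in range(1, n + 1)}
    s: Dict[Range, int] = {}
    for chain_len in range(2, n + 1):
        for i in range(1, n - chain_len + 2):
            j = i + chain_len - 1
            # min over (cost, k) pairs; ties go to the smaller k, like the strict < above
            m[(i, j)], s[(i, j)] = min(
                (m[(i, k)] + m[(k + 1, j)] + dims[i - 1] * dims[k] * dims[j], k)
                for k in range(i, j)
            )
    return s


def execution_plan(s: Dict[Range, int], i: int, j: int) -> List[Tuple[Range, Range]]:
    """Post-order walk: each step multiplies two already-computed subproducts."""
    if i == j:
        return []  # a single matrix needs no multiplication
    k = s[(i, j)]
    return execution_plan(s, i, k) + execution_plan(s, k + 1, j) + [((i, k), (k + 1, j))]


def label(r: Range) -> str:
    return f"A{r[0]}" if r[0] == r[1] else f"A{r[0]}..A{r[1]}"


dims = [30, 35, 15, 5, 10, 20, 25]
plan = execution_plan(split_table(dims), 1, len(dims) - 1)
for step, (left, right) in enumerate(plan, 1):
    print(f"step {step}: {label(left)} × {label(right)}")
```

For the six-matrix example this prints five steps: A2 × A3, then A1 × A2..A3, A4 × A5, A4..A5 × A6, and finally A1..A3 × A4..A6, matching the parenthesization above.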

Practical Considerations and Optimizations

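In production systems, you’ll rarely implement matrix chain multiplication from scratch. NumPy’s numpy.linalg.multi_dot and PyTorch’s torch.linalg.multi_dot choose an efficient order for you, but a plain expression such as A @ B @ C is evaluated left to right. Understanding the algorithm helps you recognize when to reach for those functions or restructure a computation yourself.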

Here’s how you might integrate optimal ordering with NumPy:

import numpy as np
from typing import List
from functools import reduce

def optimal_multi_dot(matrices: List[np.ndarray]) -> np.ndarray:
    """
    Multiply matrices in the optimal order.
    Uses our DP solution (matrix_chain_dp) to choose the split points,
    then multiplies recursively with np.matmul.
    """
    if len(matrices) <= 2:
        return reduce(np.matmul, matrices)
    
    # Extract dimensions
    dims = [matrices[0].shape[0]]
    dims.extend(m.shape[1] for m in matrices)
    
    # Get optimal ordering
    _, _, split_table = matrix_chain_dp(dims)
    
    def multiply_optimal(i: int, j: int) -> np.ndarray:
        if i == j:
            return matrices[i - 1]
        k = split_table[i, j]
        left = multiply_optimal(i, k)
        right = multiply_optimal(k + 1, j)
        return np.matmul(left, right)
    
    return multiply_optimal(1, len(matrices))

# Demonstration
np.random.seed(42)
matrices = [
    np.random.randn(100, 2),
    np.random.randn(2, 500),
    np.random.randn(500, 3),
    np.random.randn(3, 200),
]

# Compare with numpy's built-in (which also optimizes order)
result_optimal = optimal_multi_dot(matrices)
result_numpy = np.linalg.multi_dot(matrices)

print(f"Results match: {np.allclose(result_optimal, result_numpy)}")
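How much does the order matter for these demo shapes? Because Python’s @ operator is left-associative, writing the chain as one expression multiplies strictly left to right. The sketch below compares that cost with the DP optimum using only the dimensions. `left_to_right_cost` and `optimal_cost` are hypothetical helpers written for this comparison.

```python
from typing import List


def left_to_right_cost(dims: List[int]) -> int:
    """Scalar multiplications for A1 @ A2 @ ... @ An evaluated strictly left to right."""
    cost = 0
    for k in range(2, len(dims)):
        # running product is dims[0] × dims[k-1]; multiply it by matrix k (dims[k-1] × dims[k])
        cost += dims[0] * dims[k - 1] * dims[k]
    return cost


def optimal_cost(dims: List[int]) -> int:
    """Minimum scalar multiplications via the bottom-up recurrence, in pure Python."""
    n = len(dims) - 1
    m = [[0] * (n + 1) for _ in range(n + 1)]
    for chain_len in range(2, n + 1):
        for i in range(1, n - chain_len + 2):
            j = i + chain_len - 1
            m[i][j] = min(m[i][k] + m[k + 1][j] + dims[i - 1] * dims[k] * dims[j]
                          for k in range(i, j))
    return m[1][n]


shapes = [100, 2, 500, 3, 200]  # the four demo matrices above
print(left_to_right_cost(shapes), optimal_cost(shapes))  # 310000 44200
```

For these shapes the optimal order needs about seven times fewer scalar multiplications than the left-to-right default.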

For very long chains, the Hu-Shing algorithm achieves O(n log n) time by exploiting geometric properties of the problem. However, the O(n³) DP is far simpler to implement and is fast enough for the short chains that most applications produce.

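Reducing space is harder than it first appears. Computing m[i,j] reads m[i,k] and m[k+1,j] for every split point k, and those entries lie on all shorter diagonals, not just the previous one. The straightforward algorithm therefore needs the full O(n²) table, plus the split table if you want to reconstruct the parenthesization.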

Matrix chain multiplication is a canonical dynamic programming problem that demonstrates several key principles: identifying optimal substructure, defining the right subproblems, and reconstructing solutions from auxiliary data structures.

The same DP pattern appears in related problems. Optimal binary search tree construction uses a closely related recurrence over contiguous key ranges. Minimum-cost triangulation of a convex polygon, where each triangle costs the product of the weights on its three vertices, is equivalent to matrix chain multiplication. Database query optimizers use similar dynamic programming to choose join orders.
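The polygon correspondence is easiest to see in code. Label the vertices v₀ … vₙ of a convex polygon with the dimensions p₀ … pₙ and let each triangle cost the product of its three vertex weights (this weighting is the assumption that makes the equivalence exact). The triangle sitting on side (v₀, vₙ) plays the role of the final multiplication, and the minimum triangulation cost matches the matrix chain optimum. `min_triangulation` is a hypothetical helper for illustration.

```python
from functools import lru_cache
from typing import List


def min_triangulation(weights: List[int]) -> int:
    """Minimum-cost triangulation of a convex polygon with weighted vertices 0..n.

    A triangle (a, b, c) costs weights[a] * weights[b] * weights[c].
    """
    n = len(weights) - 1

    @lru_cache(maxsize=None)
    def t(i: int, j: int) -> int:
        # No vertices strictly between i and j: nothing left to triangulate.
        if j - i < 2:
            return 0
        # Choose the apex k of the triangle built on edge (i, j).
        return min(t(i, k) + t(k, j) + weights[i] * weights[k] * weights[j]
                   for k in range(i + 1, j))

    return t(0, n)


print(min_triangulation([10, 30, 5, 60, 10]))          # 5000
print(min_triangulation([30, 35, 15, 5, 10, 20, 25]))  # 15125
```

Here t(i, j) corresponds to m[i+1, j] in the matrix chain table, which is why the two recurrences produce identical costs.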

When you encounter a problem involving sequential decisions where the cost depends on how you partition or combine elements, consider whether matrix chain multiplication’s structure applies. The O(n³) pattern of trying all split points within nested loops over subproblem sizes is a powerful template for many optimization problems.
