How to Apply Jensen's Inequality

Key Insights

  • Jensen’s inequality states that for convex functions, f(E[X]) ≤ E[f(X)]: transforming then averaging is not the same as averaging then transforming, which makes it a fundamental tool for bounding expectations
  • Checking convexity via the second derivative test (f’’(x) ≥ 0) is essential before applying the inequality—applying it to concave functions reverses the inequality direction
  • The inequality powers critical algorithms like Expectation-Maximization and variational inference by establishing lower bounds that make intractable optimization problems solvable

Introduction to Jensen’s Inequality

Jensen’s inequality is one of those mathematical results that seems abstract until you realize it’s everywhere in statistics and machine learning. The inequality states that for a convex function f and a random variable X:

f(E[X]) ≤ E[f(X)]

In plain terms: applying a convex function to an average gives you less than or equal to the average of applying the function. This asymmetry is powerful because it lets you establish bounds on expectations that are otherwise difficult to compute directly.

Why does this matter? In machine learning, we constantly deal with expectations we can’t compute analytically. Jensen’s inequality gives us tractable bounds. In optimization, it provides the theoretical foundation for iterative algorithms like EM. In finance and risk analysis, it explains why expected utility differs from utility of expected returns.

Let’s visualize this with the simplest convex function, f(x) = x²:

import numpy as np
import matplotlib.pyplot as plt

# Generate random samples
np.random.seed(42)
X = np.random.normal(loc=2, scale=1.5, size=1000)

# Calculate f(E[X]) and E[f(X)]
E_X = np.mean(X)
f_E_X = E_X**2
E_f_X = np.mean(X**2)

# Visualize
x_range = np.linspace(-2, 6, 100)
plt.figure(figsize=(10, 6))
plt.plot(x_range, x_range**2, 'b-', linewidth=2, label='f(x) = x²')
plt.scatter(X, X**2, alpha=0.3, s=10, label='Sample points')
plt.axvline(E_X, color='red', linestyle='--', label=f'E[X] = {E_X:.2f}')
plt.scatter([E_X], [f_E_X], color='red', s=200, zorder=5, 
            label=f'f(E[X]) = {f_E_X:.2f}')
plt.axhline(E_f_X, color='green', linestyle='--', 
            label=f'E[f(X)] = {E_f_X:.2f}')
plt.legend()
plt.xlabel('x')
plt.ylabel('f(x)')
plt.title("Jensen's Inequality: f(E[X]) ≤ E[f(X)]")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('jensens_inequality_basic.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"f(E[X]) = {f_E_X:.4f}")
print(f"E[f(X)] = {E_f_X:.4f}")
print(f"Gap: {E_f_X - f_E_X:.4f}")

The gap between E[f(X)] and f(E[X]) quantifies how much the nonlinearity of f amplifies the variance in X.
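For f(x) = x² specifically, that gap is not just related to the variance, it *is* the variance: E[X²] - (E[X])² = Var(X). A quick standalone check (any sample works; the distribution parameters here just mirror the plot above):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(loc=2, scale=1.5, size=1000)

gap = np.mean(X**2) - np.mean(X)**2  # E[f(X)] - f(E[X]) for f(x) = x²
var = np.var(X)                      # population variance (ddof=0)

# The Jensen gap and the variance agree to floating-point precision
print(gap, var)
```

This is exactly why the next section's "variance is non-negative" proof works.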

Understanding Convexity

Before applying Jensen’s inequality, you must verify that your function is convex. A function f is convex if, for any two points on its graph, the line segment connecting them lies on or above the graph. Mathematically, f is convex if:

f(λx + (1-λ)y) ≤ λf(x) + (1-λ)f(y) for all λ ∈ [0,1]
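This chord condition can be checked directly by brute force. A minimal sketch for the known-convex f(x) = x², sampling random point pairs and mixing weights λ:

```python
import numpy as np

f = lambda x: x**2  # known convex

rng = np.random.default_rng(0)
x, y = rng.uniform(-5, 5, size=(2, 10_000))  # random point pairs
lam = rng.uniform(0, 1, size=10_000)         # random mixing weights

lhs = f(lam * x + (1 - lam) * y)      # f(λx + (1-λ)y)
rhs = lam * f(x) + (1 - lam) * f(y)   # λf(x) + (1-λ)f(y)

# Every chord lies on or above the function (small float tolerance)
print(np.all(lhs <= rhs + 1e-12))
```

For x², the slack rhs - lhs works out to λ(1-λ)(x-y)², which is visibly non-negative; the second-derivative test below is usually the faster route in practice.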

The practical test: check if f’’(x) ≥ 0 everywhere in the domain. If f’’(x) ≤ 0, the function is concave, and the inequality reverses: f(E[X]) ≥ E[f(X)].

import numpy as np
import matplotlib.pyplot as plt

def check_convexity(f, f_second_deriv, x_range):
    """Check and visualize convexity of a function."""
    x = np.linspace(x_range[0], x_range[1], 1000)
    y = f(x)
    y_second = f_second_deriv(x)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot function
    ax1.plot(x, y, 'b-', linewidth=2)
    ax1.set_xlabel('x')
    ax1.set_ylabel('f(x)')
    ax1.set_title('Function')
    ax1.grid(True, alpha=0.3)
    
    # Plot second derivative
    ax2.plot(x, y_second, 'r-', linewidth=2)
    ax2.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    ax2.set_xlabel('x')
    ax2.set_ylabel("f''(x)")
    ax2.set_title('Second Derivative (≥0 for convex)')
    ax2.grid(True, alpha=0.3)
    
    plt.show()

    is_convex = np.all(y_second >= -1e-10)   # Small tolerance for numerical errors
    is_concave = np.all(y_second <= 1e-10)
    return is_convex, is_concave

# Test different functions
functions = {
    'x²': (lambda x: x**2, lambda x: 2*np.ones_like(x), (-3, 3)),
    'exp(x)': (lambda x: np.exp(x), lambda x: np.exp(x), (-2, 2)),
    'log(x)': (lambda x: np.log(x), lambda x: -1/x**2, (0.1, 5)),
    '-x²': (lambda x: -x**2, lambda x: -2*np.ones_like(x), (-3, 3))
}

for name, (f, f_pp, x_range) in functions.items():
    is_convex, is_concave = check_convexity(f, f_pp, x_range)
    convexity_type = "CONVEX" if is_convex else ("CONCAVE" if is_concave else "NEITHER")
    print(f"{name}: {convexity_type}")

Classic Applications in Statistics

Jensen’s inequality proves several fundamental statistical results.

Variance is non-negative: Since f(x) = x² is convex, E[X²] ≥ (E[X])², which rearranges to Var(X) = E[X²] - (E[X])² ≥ 0.

Arithmetic Mean ≥ Geometric Mean: The logarithm is concave, so the inequality reverses. This gives us the AM-GM inequality, a cornerstone of optimization.

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

# Generate positive random samples with varying variance
sample_sizes = [10, 100, 1000]
variances = [0.1, 0.5, 1.0, 2.0]

results = []

for var in variances:
    for n in sample_sizes:
        # Generate log-normal samples (always positive)
        X = np.random.lognormal(mean=1, sigma=np.sqrt(var), size=n)
        
        # Arithmetic mean
        AM = np.mean(X)
        
        # Geometric mean
        GM = np.exp(np.mean(np.log(X)))
        
        gap = AM - GM
        relative_gap = gap / AM
        
        results.append({
            'variance': var,
            'n': n,
            'AM': AM,
            'GM': GM,
            'gap': gap,
            'relative_gap': relative_gap
        })

# Analyze how gap depends on variance
import pandas as pd
df = pd.DataFrame(results)

# Group by variance and average over sample sizes
summary = df.groupby('variance').agg({
    'AM': 'mean',
    'GM': 'mean',
    'relative_gap': 'mean'
}).reset_index()

print("\nAM-GM Gap vs Variance:")
print(summary)
print("\nKey observation: The gap increases with variance,")
print("demonstrating Jensen's inequality in action.")

# Verify inequality holds
violations = df[df['AM'] < df['GM']]
print(f"\nInequality violations: {len(violations)} out of {len(df)}")

The AM-GM gap increases with variance because the concave log function “pulls down” the average more when the data is spread out.

Machine Learning Applications

Jensen’s inequality is the workhorse behind several ML algorithms. The Expectation-Maximization (EM) algorithm uses it to create a lower bound on the log-likelihood that’s easier to optimize.

For a Gaussian mixture model, the log-likelihood involves a log of a sum, which is intractable. Jensen’s inequality converts this into a tractable lower bound:

import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

class GaussianMixtureEM:
    def __init__(self, n_components=2):
        self.n_components = n_components
        self.means = None
        self.stds = None
        self.weights = None
        
    def initialize(self, X):
        """Initialize parameters randomly."""
        # Pick distinct data points as initial means (replace=False avoids duplicates)
        self.means = np.random.choice(X, self.n_components, replace=False)
        self.stds = np.ones(self.n_components)
        self.weights = np.ones(self.n_components) / self.n_components
        
    def e_step(self, X):
        """E-step: Compute responsibilities using current parameters."""
        n = len(X)
        responsibilities = np.zeros((n, self.n_components))
        
        for k in range(self.n_components):
            responsibilities[:, k] = self.weights[k] * \
                norm.pdf(X, self.means[k], self.stds[k])
        
        # Normalize
        responsibilities /= responsibilities.sum(axis=1, keepdims=True)
        return responsibilities
    
    def m_step(self, X, responsibilities):
        """M-step: Update parameters to maximize lower bound."""
        n = len(X)
        
        for k in range(self.n_components):
            resp_k = responsibilities[:, k]
            n_k = resp_k.sum()
            
            self.weights[k] = n_k / n
            self.means[k] = (resp_k @ X) / n_k
            self.stds[k] = np.sqrt((resp_k @ (X - self.means[k])**2) / n_k)
    
    def lower_bound(self, X, responsibilities):
        """Compute the ELBO, Jensen's lower bound on the log-likelihood.

        Includes the entropy of the responsibilities; without that term
        the quantity is not a valid lower bound and need not increase
        across iterations.
        """
        elbo = 0.0
        for k in range(self.n_components):
            resp_k = responsibilities[:, k]
            log_joint = np.log(self.weights[k]) + \
                norm.logpdf(X, self.means[k], self.stds[k])
            # Subtract log r_nk (the entropy term); guard against log(0)
            elbo += resp_k @ (log_joint - np.log(resp_k + 1e-300))
        return elbo
    
    def fit(self, X, max_iters=20):
        """Fit the model using EM algorithm."""
        self.initialize(X)
        lower_bounds = []
        
        for i in range(max_iters):
            # E-step
            responsibilities = self.e_step(X)
            
            # Compute lower bound
            lb = self.lower_bound(X, responsibilities)
            lower_bounds.append(lb)
            
            # M-step
            self.m_step(X, responsibilities)
            
            if i > 0 and abs(lower_bounds[-1] - lower_bounds[-2]) < 1e-6:
                break
                
        return lower_bounds

# Generate mixture data
np.random.seed(42)
X = np.concatenate([
    np.random.normal(-2, 0.5, 300),
    np.random.normal(2, 0.8, 200)
])

# Fit model
gmm = GaussianMixtureEM(n_components=2)
lower_bounds = gmm.fit(X)

# Plot lower bound progression
plt.figure(figsize=(10, 5))
plt.plot(lower_bounds, 'o-', linewidth=2, markersize=8)
plt.xlabel('Iteration')
plt.ylabel('Lower Bound (ELBO)')
plt.title('EM Algorithm: Lower Bound Increases Monotonically')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Final means: {gmm.means}")
print(f"Final weights: {gmm.weights}")
print("Lower bound increases monotonically due to Jensen's inequality")

At each E-step, Jensen’s inequality holds with equality for the current parameters (the responsibilities are exactly the choice that makes the bound touch the log-likelihood), and the M-step then maximizes that bound. Each iteration can therefore only push the log-likelihood up, never down.
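The Jensen step inside EM bounds a log of a sum: for any distribution q over components, log Σₖ πₖ pₖ(x) ≥ Σₖ qₖ log(πₖ pₖ(x) / qₖ), with equality exactly when qₖ ∝ πₖ pₖ(x), i.e. the E-step responsibilities. A standalone numeric check of both claims (component values here are arbitrary, just positive numbers standing in for densities at one point x):

```python
import numpy as np

rng = np.random.default_rng(0)
K = 3
pi = rng.dirichlet(np.ones(K))       # mixture weights
p = rng.uniform(0.01, 1.0, size=K)   # stand-ins for component densities p_k(x)

log_sum = np.log(np.sum(pi * p))     # exact log-likelihood term

# Any distribution q gives a lower bound (Jensen with concave log)
q = rng.dirichlet(np.ones(K))
bound_q = np.sum(q * np.log(pi * p / q))

# The E-step choice r_k ∝ π_k p_k(x) makes the bound exact
r = pi * p / np.sum(pi * p)
bound_r = np.sum(r * np.log(pi * p / r))

print(bound_q <= log_sum + 1e-12)    # True: lower bound for arbitrary q
print(np.isclose(bound_r, log_sum))  # True: tight at the responsibilities
```

Tightness at the responsibilities is what makes the ratio πₖ pₖ(x)/rₖ constant in k, so the log of the average equals the average of the logs.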

Practical Implementation Patterns

When working with complex expectations, Jensen’s inequality provides bounds that guide algorithm design. Here’s a practical utility for checking convexity and applying the inequality:

import numpy as np

def second_derivative(f, x, h=1e-4):
    """Central-difference estimate of f''(x).

    Computed directly because scipy.misc.derivative was removed
    in SciPy 1.12.
    """
    return (f(x + h) - 2 * f(x) + f(x - h)) / h**2

def check_convexity_numerical(f, x_range, n_points=100):
    """Numerically check if function is convex over a range."""
    x = np.linspace(x_range[0], x_range[1], n_points)

    # Estimate the second derivative at each grid point
    second_derivs = np.array([second_derivative(f, xi) for xi in x])
    
    is_convex = np.all(second_derivs >= -1e-6)
    is_concave = np.all(second_derivs <= 1e-6)
    
    return {
        'convex': is_convex,
        'concave': is_concave,
        'min_second_deriv': second_derivs.min(),
        'max_second_deriv': second_derivs.max()
    }

def apply_jensens_bound(f, samples, convexity_check=True):
    """
    Apply Jensen's inequality to bound E[f(X)].
    
    Returns both sides of the inequality and validates convexity.
    """
    # Default so conv_info is defined even when convexity_check=False
    conv_info = {'convex': False, 'concave': False}
    if convexity_check:
        x_min, x_max = samples.min(), samples.max()
        margin = (x_max - x_min) * 0.1
        conv_info = check_convexity_numerical(f, (x_min - margin, x_max + margin))
        
        if not (conv_info['convex'] or conv_info['concave']):
            print("Warning: Function appears neither convex nor concave!")
            print(f"Second derivative range: [{conv_info['min_second_deriv']:.4f}, "
                  f"{conv_info['max_second_deriv']:.4f}]")
    
    # Compute both sides
    f_E_X = f(np.mean(samples))
    E_f_X = np.mean([f(x) for x in samples])
    
    result = {
        'f(E[X])': f_E_X,
        'E[f(X)]': E_f_X,
        'gap': abs(E_f_X - f_E_X),
        'inequality_satisfied': None
    }
    
    if conv_info['convex']:
        result['inequality_satisfied'] = f_E_X <= E_f_X + 1e-10
        result['inequality_type'] = 'f(E[X]) ≤ E[f(X)]'
    elif conv_info['concave']:
        result['inequality_satisfied'] = f_E_X >= E_f_X - 1e-10
        result['inequality_type'] = 'f(E[X]) ≥ E[f(X)]'
    
    return result

# Example usage
np.random.seed(42)
X = np.random.gamma(2, 2, 1000)

# Convex function: exponential
result_exp = apply_jensens_bound(lambda x: np.exp(x/10), X)
print("Exponential function (convex):")
print(f"  f(E[X]) = {result_exp['f(E[X])']:.4f}")
print(f"  E[f(X)] = {result_exp['E[f(X)]']:.4f}")
print(f"  Inequality satisfied: {result_exp['inequality_satisfied']}")
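The example above only exercises the convex direction. For a concave function the inequality flips, which a self-contained check makes concrete; here with log (concave) on fresh gamma samples (strictly positive, so log is defined):

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.gamma(2, 2, 1000)  # strictly positive samples

f_E_X = np.log(np.mean(X))   # f(E[X])
E_f_X = np.mean(np.log(X))   # E[f(X)]

# Concave f: f(E[X]) ≥ E[f(X)], the reverse of the convex case
print(f"f(E[X]) = {f_E_X:.4f}")
print(f"E[f(X)] = {E_f_X:.4f}")
print(f"Reversed inequality holds: {f_E_X >= E_f_X}")
```

Passing `np.log` to `apply_jensens_bound` would report the same thing via the `'f(E[X]) ≥ E[f(X)]'` branch.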

Real-World Case Study: Portfolio Risk Analysis

In finance, Jensen’s inequality explains why expected utility differs from utility of expected returns. Risk-averse investors have concave utility functions, so the utility of the expected return exceeds the expected utility of returns.

import numpy as np
import matplotlib.pyplot as plt

# Risk-averse utility functions (all concave)
def log_utility(x):
    return np.log(x + 1)  # Shift so any return above -100% stays in log's domain

def power_utility(x, gamma=0.5):
    return (x + 1)**gamma

def exponential_utility(x, a=0.1):
    return -np.exp(-a * x)

# Simulate portfolio returns
np.random.seed(42)
n_simulations = 10000
expected_return = 0.07
volatility = 0.15
returns = np.random.normal(expected_return, volatility, n_simulations)

# Calculate for different utility functions
utilities = {
    'Logarithmic': log_utility,
    'Power (γ=0.5)': power_utility,
    'Exponential': exponential_utility
}

results = []

for name, utility_func in utilities.items():
    # E[U(X)] - Expected utility
    expected_utility = np.mean([utility_func(r) for r in returns])
    
    # U(E[X]) - Utility of the sample-mean return (matches E[U(X)] above)
    utility_of_expected = utility_func(np.mean(returns))
    
    # Certainty equivalent: CE such that U(CE) = E[U(X)]
    # For log utility: CE = exp(E[U(X)]) - 1
    
    gap = utility_of_expected - expected_utility
    
    results.append({
        'Utility': name,
        'U(E[X])': utility_of_expected,
        'E[U(X)]': expected_utility,
        'Gap': gap
    })
    
    print(f"\n{name}:")
    print(f"  U(E[X]) = {utility_of_expected:.6f}")
    print(f"  E[U(X)] = {expected_utility:.6f}")
    print(f"  Gap = {gap:.6f} (positive confirms concavity)")

# Risk premium: how much expected return investor would give up for certainty
# This is quantified by Jensen's inequality
print("\n" + "="*50)
print("Jensen's inequality shows that risk-averse investors")
print("value certainty over gambles with the same expected return.")
print("The gap measures the 'risk premium' they demand.")

This demonstrates why rational investors demand higher expected returns for volatile assets—Jensen’s inequality quantifies the utility cost of risk.
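The risk premium can be made concrete via the certainty equivalent that the code above mentions but never computes. For log utility U(x) = log(1 + x), solving U(CE) = E[U(R)] gives CE = exp(E[log(1 + R)]) - 1, and the risk premium is E[R] - CE. A sketch under the same simulation assumptions as the case study:

```python
import numpy as np

rng = np.random.default_rng(42)
returns = rng.normal(0.07, 0.15, 10_000)  # same expected return and volatility

# Certainty equivalent for U(x) = log(1 + x): invert the utility at E[U(R)]
expected_utility = np.mean(np.log(1 + returns))
CE = np.exp(expected_utility) - 1

# Return the investor would give up for certainty; positive by Jensen
risk_premium = np.mean(returns) - CE

print(f"E[R]         = {np.mean(returns):.4f}")
print(f"CE           = {CE:.4f}")
print(f"Risk premium = {risk_premium:.4f}")
```

Jensen guarantees CE ≤ E[R] for any concave utility, and the gap widens with volatility, which is the "utility cost of risk" in numbers.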

Jensen’s inequality isn’t just theoretical mathematics. It’s a practical tool that bounds expectations, powers optimization algorithms, and explains fundamental phenomena in statistics, machine learning, and finance. Master convexity checking, understand when to apply the inequality, and you’ll have a powerful technique for reasoning about nonlinear transformations of random variables.
