How to Apply Jensen's Inequality
Key Insights
- Jensen’s inequality states that for convex functions, f(E[X]) ≤ E[f(X)], providing a fundamental tool for bounding expectations and proving that transforming then averaging differs from averaging then transforming
- Checking convexity via the second derivative test (f’’(x) ≥ 0) is essential before applying the inequality—applying it to concave functions reverses the inequality direction
- The inequality powers critical algorithms like Expectation-Maximization and variational inference by establishing lower bounds that make intractable optimization problems solvable
Introduction to Jensen’s Inequality
Jensen’s inequality is one of those mathematical results that seems abstract until you realize it’s everywhere in statistics and machine learning. The inequality states that for a convex function f and a random variable X:
f(E[X]) ≤ E[f(X)]
In plain terms: applying a convex function to an average gives you less than or equal to the average of applying the function. This asymmetry is powerful because it lets you establish bounds on expectations that are otherwise difficult to compute directly.
Why does this matter? In machine learning, we constantly deal with expectations we can’t compute analytically. Jensen’s inequality gives us tractable bounds. In optimization, it provides the theoretical foundation for iterative algorithms like EM. In finance and risk analysis, it explains why expected utility differs from utility of expected returns.
Let’s visualize this with the simplest convex function, f(x) = x²:
import numpy as np
import matplotlib.pyplot as plt
# Generate random samples
np.random.seed(42)
X = np.random.normal(loc=2, scale=1.5, size=1000)
# Calculate f(E[X]) and E[f(X)]
E_X = np.mean(X)
f_E_X = E_X**2
E_f_X = np.mean(X**2)
# Visualize
x_range = np.linspace(-2, 6, 100)
plt.figure(figsize=(10, 6))
plt.plot(x_range, x_range**2, 'b-', linewidth=2, label='f(x) = x²')
plt.scatter(X, X**2, alpha=0.3, s=10, label='Sample points')
plt.axvline(E_X, color='red', linestyle='--', label=f'E[X] = {E_X:.2f}')
plt.scatter([E_X], [f_E_X], color='red', s=200, zorder=5,
            label=f'f(E[X]) = {f_E_X:.2f}')
plt.axhline(E_f_X, color='green', linestyle='--',
            label=f'E[f(X)] = {E_f_X:.2f}')
plt.legend()
plt.xlabel('x')
plt.ylabel('f(x)')
plt.title("Jensen's Inequality: f(E[X]) ≤ E[f(X)]")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('jensens_inequality_basic.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"f(E[X]) = {f_E_X:.4f}")
print(f"E[f(X)] = {E_f_X:.4f}")
print(f"Gap: {E_f_X - f_E_X:.4f}")
The gap between E[f(X)] and f(E[X]) quantifies how much the curvature of f amplifies the spread of X. For f(x) = x², the gap is exactly the variance: E[X²] − (E[X])² = Var(X).
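A quick numerical check of this identity (a minimal standalone sketch with NumPy; the distribution parameters are arbitrary):

```python
import numpy as np

# For f(x) = x^2, Jensen's gap E[f(X)] - f(E[X]) equals Var(X) exactly
rng = np.random.default_rng(0)
X = rng.normal(loc=2, scale=1.5, size=100_000)

gap = np.mean(X**2) - np.mean(X)**2
print(f"Jensen gap:      {gap:.4f}")
print(f"Sample variance: {np.var(X):.4f}")  # identical up to floating-point error
```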
Understanding Convexity
Before applying Jensen’s inequality, you must verify that your function is convex. A function f is convex if, for any two points on its graph, the chord connecting them lies on or above the graph. Mathematically, f is convex if:
f(λx + (1-λ)y) ≤ λf(x) + (1-λ)f(y) for all λ ∈ [0,1]
The practical test: check whether f’’(x) ≥ 0 everywhere in the domain. If instead f’’(x) ≤ 0 everywhere, the function is concave and the inequality reverses: f(E[X]) ≥ E[f(X)].
import numpy as np
import matplotlib.pyplot as plt
def check_convexity(f, f_second_deriv, x_range):
    """Check and visualize convexity of a function."""
    x = np.linspace(x_range[0], x_range[1], 1000)
    y = f(x)
    y_second = f_second_deriv(x)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    # Plot function
    ax1.plot(x, y, 'b-', linewidth=2)
    ax1.set_xlabel('x')
    ax1.set_ylabel('f(x)')
    ax1.set_title('Function')
    ax1.grid(True, alpha=0.3)
    # Plot second derivative
    ax2.plot(x, y_second, 'r-', linewidth=2)
    ax2.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    ax2.set_xlabel('x')
    ax2.set_ylabel("f''(x)")
    ax2.set_title('Second Derivative (≥0 for convex)')
    ax2.grid(True, alpha=0.3)
    # Small tolerance for numerical errors; a function can be neither
    is_convex = np.all(y_second >= -1e-10)
    is_concave = np.all(y_second <= 1e-10)
    return is_convex, is_concave

# Test different functions
functions = {
    'x²': (lambda x: x**2, lambda x: 2*np.ones_like(x), (-3, 3)),
    'exp(x)': (lambda x: np.exp(x), lambda x: np.exp(x), (-2, 2)),
    'log(x)': (lambda x: np.log(x), lambda x: -1/x**2, (0.1, 5)),
    '-x²': (lambda x: -x**2, lambda x: -2*np.ones_like(x), (-3, 3))
}
for name, (f, f_pp, x_range) in functions.items():
    is_convex, is_concave = check_convexity(f, f_pp, x_range)
    convexity_type = "CONVEX" if is_convex else ("CONCAVE" if is_concave else "NEITHER")
    print(f"{name}: {convexity_type}")
Classic Applications in Statistics
Jensen’s inequality proves several fundamental statistical results.
Variance is non-negative: Since f(x) = x² is convex, E[X²] ≥ (E[X])², which rearranges to Var(X) = E[X²] - (E[X])² ≥ 0.
Arithmetic Mean ≥ Geometric Mean: The logarithm is concave, so the inequality reverses: log(E[X]) ≥ E[log X]. Exponentiating both sides gives E[X] ≥ exp(E[log X]), which is exactly the AM-GM inequality for positive random variables, a cornerstone of optimization.
import numpy as np
import pandas as pd

np.random.seed(42)
# Generate positive random samples with varying variance
sample_sizes = [10, 100, 1000]
variances = [0.1, 0.5, 1.0, 2.0]
results = []
for var in variances:
    for n in sample_sizes:
        # Generate log-normal samples (always positive)
        X = np.random.lognormal(mean=1, sigma=np.sqrt(var), size=n)
        AM = np.mean(X)                      # Arithmetic mean
        GM = np.exp(np.mean(np.log(X)))      # Geometric mean
        gap = AM - GM
        results.append({
            'variance': var,
            'n': n,
            'AM': AM,
            'GM': GM,
            'gap': gap,
            'relative_gap': gap / AM
        })

# Analyze how the gap depends on variance
df = pd.DataFrame(results)
# Group by variance and average over sample sizes
summary = df.groupby('variance').agg({
    'AM': 'mean',
    'GM': 'mean',
    'relative_gap': 'mean'
}).reset_index()
print("\nAM-GM Gap vs Variance:")
print(summary)
print("\nKey observation: The gap increases with variance,")
print("demonstrating Jensen's inequality in action.")
# Verify inequality holds
violations = df[df['AM'] < df['GM']]
print(f"\nInequality violations: {len(violations)} out of {len(df)}")
The AM-GM gap increases with variance because the concave log function “pulls down” the average more when the data is spread out.
Machine Learning Applications
Jensen’s inequality is the workhorse behind several ML algorithms. The Expectation-Maximization (EM) algorithm uses it to create a lower bound on the log-likelihood that’s easier to optimize.
For a Gaussian mixture model, the log-likelihood involves a log of a sum, which is intractable. Jensen’s inequality converts this into a tractable lower bound:
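The conversion is the standard variational step: multiply and divide inside the sum by an arbitrary distribution q over the K components, then apply Jensen's inequality to the concave logarithm:

```latex
\log p(x) = \log \sum_{k=1}^{K} \pi_k \,\mathcal{N}(x \mid \mu_k, \sigma_k^2)
          = \log \sum_{k=1}^{K} q_k \,\frac{\pi_k \,\mathcal{N}(x \mid \mu_k, \sigma_k^2)}{q_k}
          \;\ge\; \sum_{k=1}^{K} q_k \log \frac{\pi_k \,\mathcal{N}(x \mid \mu_k, \sigma_k^2)}{q_k}
```

The right-hand side moves the log inside the sum, where it splits into tractable per-component terms; choosing q_k to be the posterior over components makes the bound tight.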
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

class GaussianMixtureEM:
    def __init__(self, n_components=2):
        self.n_components = n_components
        self.means = None
        self.stds = None
        self.weights = None

    def initialize(self, X):
        """Initialize parameters from distinct random data points."""
        self.means = np.random.choice(X, self.n_components, replace=False)
        self.stds = np.ones(self.n_components)
        self.weights = np.ones(self.n_components) / self.n_components

    def e_step(self, X):
        """E-step: Compute responsibilities using current parameters."""
        n = len(X)
        responsibilities = np.zeros((n, self.n_components))
        for k in range(self.n_components):
            responsibilities[:, k] = self.weights[k] * \
                norm.pdf(X, self.means[k], self.stds[k])
        # Normalize so each row is a distribution over components
        responsibilities /= responsibilities.sum(axis=1, keepdims=True)
        return responsibilities

    def m_step(self, X, responsibilities):
        """M-step: Update parameters to maximize the lower bound."""
        n = len(X)
        for k in range(self.n_components):
            resp_k = responsibilities[:, k]
            n_k = resp_k.sum()
            self.weights[k] = n_k / n
            self.means[k] = (resp_k @ X) / n_k
            self.stds[k] = np.sqrt((resp_k @ (X - self.means[k])**2) / n_k)

    def lower_bound(self, X, responsibilities):
        """Compute the Jensen lower bound (ELBO), including the entropy of q."""
        elbo = 0.0
        for k in range(self.n_components):
            resp_k = responsibilities[:, k]
            elbo += resp_k @ (
                np.log(self.weights[k]) +
                norm.logpdf(X, self.means[k], self.stds[k]) -
                np.log(resp_k + 1e-12)   # entropy term of q
            )
        return elbo

    def fit(self, X, max_iters=20):
        """Fit the model using the EM algorithm."""
        self.initialize(X)
        lower_bounds = []
        for i in range(max_iters):
            responsibilities = self.e_step(X)            # E-step
            lb = self.lower_bound(X, responsibilities)   # bound is tight here
            lower_bounds.append(lb)
            self.m_step(X, responsibilities)             # M-step
            if i > 0 and abs(lower_bounds[-1] - lower_bounds[-2]) < 1e-6:
                break
        return lower_bounds

# Generate mixture data
np.random.seed(42)
X = np.concatenate([
    np.random.normal(-2, 0.5, 300),
    np.random.normal(2, 0.8, 200)
])

# Fit model
gmm = GaussianMixtureEM(n_components=2)
lower_bounds = gmm.fit(X)

# Plot lower bound progression
plt.figure(figsize=(10, 5))
plt.plot(lower_bounds, 'o-', linewidth=2, markersize=8)
plt.xlabel('Iteration')
plt.ylabel('Lower Bound (ELBO)')
plt.title('EM Algorithm: Lower Bound Increases Monotonically')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Final means: {gmm.means}")
print(f"Final weights: {gmm.weights}")
print("Lower bound increases monotonically due to Jensen's inequality")
The EM algorithm is guaranteed never to decrease the log-likelihood: the E-step chooses responsibilities that make the Jensen bound tight at the current parameters (equal to the log-likelihood there), and the M-step then maximizes that bound, so each iteration can only improve.
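The bound-and-tightness argument can be checked numerically at a single data point, independently of the EM loop above (a standalone sketch; the toy mixture parameters and the arbitrary q are made up for illustration):

```python
import numpy as np
from scipy.stats import norm

# Toy check of the Jensen bound behind EM, for one data point x:
#   log sum_k pi_k N_k  >=  sum_k q_k log(pi_k N_k / q_k)   for any distribution q
x = 0.5
weights = np.array([0.6, 0.4])
means = np.array([-2.0, 2.0])
stds = np.array([0.5, 0.8])

joint = weights * norm.pdf(x, means, stds)   # pi_k * N(x | mu_k, sigma_k)
log_likelihood = np.log(joint.sum())         # exact log p(x)

q = np.array([0.3, 0.7])                     # an arbitrary distribution over components
lower = np.sum(q * np.log(joint / q))        # Jensen lower bound

posterior = joint / joint.sum()              # the E-step choice of q
tight = np.sum(posterior * np.log(joint / posterior))

print(f"log p(x)            = {log_likelihood:.6f}")
print(f"bound (arbitrary q) = {lower:.6f}")
print(f"bound (posterior q) = {tight:.6f}")  # equals log p(x): the bound is tight
```

Any choice of q gives a valid lower bound; only the posterior makes it touch the true log-likelihood, which is exactly why the E-step computes responsibilities.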
Practical Implementation Patterns
When working with complex expectations, Jensen’s inequality provides bounds that guide algorithm design. Here’s a practical utility for checking convexity and applying the inequality:
import numpy as np

def second_derivative(f, x, dx=1e-5):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + dx) - 2 * f(x) + f(x - dx)) / dx**2

def check_convexity_numerical(f, x_range, n_points=100):
    """Numerically check if function is convex over a range."""
    x = np.linspace(x_range[0], x_range[1], n_points)
    second_derivs = np.array([second_derivative(f, xi) for xi in x])
    is_convex = np.all(second_derivs >= -1e-6)
    is_concave = np.all(second_derivs <= 1e-6)
    return {
        'convex': is_convex,
        'concave': is_concave,
        'min_second_deriv': second_derivs.min(),
        'max_second_deriv': second_derivs.max()
    }

def apply_jensens_bound(f, samples, convexity_check=True):
    """
    Apply Jensen's inequality to bound E[f(X)].
    Returns both sides of the inequality and validates convexity.
    """
    conv_info = {'convex': False, 'concave': False}
    if convexity_check:
        x_min, x_max = samples.min(), samples.max()
        margin = (x_max - x_min) * 0.1
        conv_info = check_convexity_numerical(f, (x_min - margin, x_max + margin))
        if not (conv_info['convex'] or conv_info['concave']):
            print("Warning: Function appears neither convex nor concave!")
            print(f"Second derivative range: [{conv_info['min_second_deriv']:.4f}, "
                  f"{conv_info['max_second_deriv']:.4f}]")
    # Compute both sides
    f_E_X = f(np.mean(samples))
    E_f_X = np.mean([f(x) for x in samples])
    result = {
        'f(E[X])': f_E_X,
        'E[f(X)]': E_f_X,
        'gap': abs(E_f_X - f_E_X),
        'inequality_satisfied': None
    }
    if conv_info['convex']:
        result['inequality_satisfied'] = f_E_X <= E_f_X + 1e-10
        result['inequality_type'] = 'f(E[X]) ≤ E[f(X)]'
    elif conv_info['concave']:
        result['inequality_satisfied'] = f_E_X >= E_f_X - 1e-10
        result['inequality_type'] = 'f(E[X]) ≥ E[f(X)]'
    return result

# Example usage
np.random.seed(42)
X = np.random.gamma(2, 2, 1000)

# Convex function: exponential
result_exp = apply_jensens_bound(lambda x: np.exp(x/10), X)
print("Exponential function (convex):")
print(f"  f(E[X]) = {result_exp['f(E[X])']:.4f}")
print(f"  E[f(X)] = {result_exp['E[f(X)]']:.4f}")
print(f"  Inequality satisfied: {result_exp['inequality_satisfied']}")
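For contrast, a concave function on the same kind of positive data flips the inequality. A minimal standalone check with the square root (concave for x > 0):

```python
import numpy as np

# Concave example: sqrt reverses Jensen's inequality, so f(E[X]) >= E[f(X)]
rng = np.random.default_rng(42)
X = rng.gamma(shape=2, scale=2, size=1000)

f_E_X = np.sqrt(np.mean(X))    # f(E[X])
E_f_X = np.mean(np.sqrt(X))    # E[f(X)]

print(f"f(E[X]) = {f_E_X:.4f}")
print(f"E[f(X)] = {E_f_X:.4f}")
print(f"Reversed inequality holds: {f_E_X >= E_f_X}")
```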
Real-World Case Study: Portfolio Risk Analysis
In finance, Jensen’s inequality explains why expected utility differs from utility of expected returns. Risk-averse investors have concave utility functions, so the utility of the expected return exceeds the expected utility of returns.
import numpy as np

# Risk-averse utility functions (all concave)
def log_utility(x):
    return np.log(x + 1)   # Shift so returns near zero stay in the domain

def power_utility(x, gamma=0.5):
    return (x + 1)**gamma

def exponential_utility(x, a=0.1):
    return -np.exp(-a * x)

# Simulate portfolio returns
np.random.seed(42)
n_simulations = 10000
expected_return = 0.07
volatility = 0.15
returns = np.random.normal(expected_return, volatility, n_simulations)

# Calculate for different utility functions
utilities = {
    'Logarithmic': log_utility,
    'Power (γ=0.5)': power_utility,
    'Exponential': exponential_utility
}
results = []
for name, utility_func in utilities.items():
    # E[U(X)]: expected utility over simulated returns
    expected_utility = np.mean(utility_func(returns))
    # U(E[X]): utility of the sample mean return
    utility_of_expected = utility_func(np.mean(returns))
    gap = utility_of_expected - expected_utility
    results.append({
        'Utility': name,
        'U(E[X])': utility_of_expected,
        'E[U(X)]': expected_utility,
        'Gap': gap
    })
    print(f"\n{name}:")
    print(f"  U(E[X]) = {utility_of_expected:.6f}")
    print(f"  E[U(X)] = {expected_utility:.6f}")
    print(f"  Gap = {gap:.6f} (positive confirms concavity)")

# Risk premium: how much expected return an investor would give up for certainty
# This is quantified by Jensen's inequality
print("\n" + "="*50)
print("Jensen's inequality shows that risk-averse investors")
print("value certainty over gambles with the same expected return.")
print("The gap measures the 'risk premium' they demand.")
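The risk premium can be made concrete via the certainty equivalent: the guaranteed return CE satisfying U(CE) = E[U(X)]. For log utility this inverts in closed form, CE = exp(E[log(1 + r)]) − 1. A standalone sketch under the same return assumptions as above (normal returns with mean 0.07 and volatility 0.15):

```python
import numpy as np

# Certainty equivalent under log utility U(r) = log(1 + r):
# CE solves log(1 + CE) = E[log(1 + r)]  =>  CE = exp(E[log(1 + r)]) - 1
rng = np.random.default_rng(42)
returns = rng.normal(0.07, 0.15, 10_000)

expected_utility = np.mean(np.log1p(returns))
certainty_equiv = np.expm1(expected_utility)
risk_premium = np.mean(returns) - certainty_equiv

print(f"Mean return:          {np.mean(returns):.4f}")
print(f"Certainty equivalent: {certainty_equiv:.4f}")
print(f"Risk premium:         {risk_premium:.4f}")  # positive by Jensen
```

Because log is concave, Jensen guarantees CE is below the mean return; the difference is the premium a risk-averse investor demands for holding the gamble.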
This demonstrates why rational investors demand higher expected returns for volatile assets—Jensen’s inequality quantifies the utility cost of risk.
Jensen’s inequality isn’t just theoretical mathematics. It’s a practical tool that bounds expectations, powers optimization algorithms, and explains fundamental phenomena in statistics, machine learning, and finance. Master convexity checking, understand when to apply the inequality, and you’ll have a powerful technique for reasoning about nonlinear transformations of random variables.