How to Calculate Summary Statistics in PySpark
Key Insights
- PySpark’s describe() gives you the basics, but summary() adds percentiles—know when to use each for quick exploratory analysis versus detailed statistical profiling.
- Always use agg() with multiple aggregation functions in a single pass rather than chaining separate queries; this minimizes Spark job overhead and data shuffling.
- Approximate methods like approxQuantile() are often good enough for big data and can be orders of magnitude faster than exact calculations—embrace the tradeoff.
Why Summary Statistics in PySpark?
When your dataset fits in memory, pandas is the obvious choice. But once you’re dealing with billions of rows across distributed storage, you need a tool that can parallelize statistical computations across a cluster. PySpark fills that gap.
Summary statistics—mean, standard deviation, percentiles, correlations—form the foundation of any data analysis workflow. They help you understand distributions, detect outliers, and validate data quality before building models or dashboards. Getting these calculations right in a distributed environment requires understanding both the available APIs and their performance characteristics.
Setting Up Your PySpark Environment
Before diving into statistics, let’s establish a working environment. For local development, you can run PySpark in standalone mode. For production, you’ll connect to an existing cluster.
from pyspark.sql import SparkSession
# Local development setup
spark = SparkSession.builder \
    .appName("SummaryStatistics") \
    .master("local[*]") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()
# For cluster deployment, omit master() and let the cluster manager handle it
# spark = SparkSession.builder \
#     .appName("SummaryStatistics") \
#     .getOrCreate()
# Load sample data
df = spark.read.csv("sales_data.csv", header=True, inferSchema=True)
# Or from Parquet (preferred for production)
df = spark.read.parquet("s3://your-bucket/sales_data/")
# Quick look at the schema
df.printSchema()
For this article, assume we’re working with a sales dataset containing columns like order_id, product_category, region, quantity, unit_price, total_amount, and order_date.
Basic Descriptive Statistics with describe() and summary()
PySpark provides two built-in methods for quick statistical overviews. Understanding their differences saves you time.
# describe() - the basics
df.describe().show()
Output:
+-------+------------------+------------------+------------------+
|summary| quantity| unit_price| total_amount|
+-------+------------------+------------------+------------------+
| count| 1000000| 1000000| 1000000|
| mean| 15.23| 49.87| 759.52|
| stddev| 8.41| 28.34| 523.18|
| min| 1| 5.00| 5.00|
| max| 50| 199.99| 9999.50|
+-------+------------------+------------------+------------------+
The describe() method gives you count, mean, standard deviation, min, and max. It’s fast and covers the essentials.
# summary() - more comprehensive
df.summary().show()
Output:
+-------+------------------+------------------+------------------+
|summary| quantity| unit_price| total_amount|
+-------+------------------+------------------+------------------+
| count| 1000000| 1000000| 1000000|
| mean| 15.23| 49.87| 759.52|
| stddev| 8.41| 28.34| 523.18|
| min| 1| 5.00| 5.00|
| 25%| 8| 25.00| 312.50|
| 50%| 15| 50.00| 725.00|
| 75%| 23| 75.00| 1150.00|
| max| 50| 199.99| 9999.50|
+-------+------------------+------------------+------------------+
The summary() method adds quartiles (25th, 50th, 75th percentiles). You can also request specific statistics:
# Custom summary with specific statistics
df.summary("count", "min", "max", "50%", "75%", "90%").show()
Use describe() for quick sanity checks. Use summary() when you need distribution insights.
Column-Level Statistics with Built-in Functions
For more control, use pyspark.sql.functions directly. This approach lets you compute exactly what you need in a single query.
from pyspark.sql import functions as F
# Single aggregation call with multiple statistics
stats = df.select(
    F.count("total_amount").alias("count"),
    F.mean("total_amount").alias("mean"),
    F.stddev("total_amount").alias("stddev"),
    F.stddev_pop("total_amount").alias("stddev_pop"),  # Population stddev
    F.variance("total_amount").alias("variance"),
    F.min("total_amount").alias("min"),
    F.max("total_amount").alias("max"),
    F.sum("total_amount").alias("total"),
    F.countDistinct("product_category").alias("unique_categories")
)
stats.show()
The agg() method provides equivalent functionality with slightly different syntax:
# Using agg() - same result, different style
stats = df.agg(
    F.count("total_amount").alias("count"),
    F.mean("total_amount").alias("mean"),
    F.stddev("total_amount").alias("stddev"),
    F.skewness("total_amount").alias("skewness"),
    F.kurtosis("total_amount").alias("kurtosis")
)
stats.show()
Note the distinction between stddev() (sample standard deviation) and stddev_pop() (population standard deviation). Use sample when your data represents a sample of a larger population; use population when you have the complete dataset.
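The difference is only in the divisor: sample variance divides the sum of squared deviations by n − 1 (Bessel's correction), population variance by n. Plain Python makes the contrast easy to verify against the statistics module:

```python
import math
import statistics

data = [4, 8, 6, 5, 3]
n = len(data)
mean = sum(data) / n
ss = sum((x - mean) ** 2 for x in data)  # sum of squared deviations

sample_sd = math.sqrt(ss / (n - 1))   # what stddev() computes
population_sd = math.sqrt(ss / n)     # what stddev_pop() computes

# The library functions agree with the hand-rolled formulas
assert math.isclose(sample_sd, statistics.stdev(data))
assert math.isclose(population_sd, statistics.pstdev(data))
```

The sample estimate is always the larger of the two; the gap shrinks as n grows, which is why the choice matters most on small groups.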
Grouped Summary Statistics
Real-world analysis rarely looks at global statistics alone. You need breakdowns by category, region, time period, or other dimensions.
# Statistics by single group
region_stats = df.groupBy("region").agg(
    F.count("order_id").alias("order_count"),
    F.sum("total_amount").alias("total_revenue"),
    F.mean("total_amount").alias("avg_order_value"),
    F.stddev("total_amount").alias("order_value_stddev")
).orderBy(F.desc("total_revenue"))
region_stats.show()
For multi-dimensional analysis, chain multiple columns in groupBy():
# Statistics by region and product category
detailed_stats = df.groupBy("region", "product_category").agg(
    F.count("order_id").alias("order_count"),
    F.sum("quantity").alias("units_sold"),
    F.sum("total_amount").alias("revenue"),
    F.mean("unit_price").alias("avg_unit_price"),
    F.min("order_date").alias("first_order"),
    F.max("order_date").alias("last_order")
).orderBy("region", F.desc("revenue"))
detailed_stats.show(20)
You can also compute multiple statistics for multiple columns simultaneously:
# Multiple aggregations on multiple columns
multi_stats = df.groupBy("region").agg(
    F.mean("quantity").alias("avg_quantity"),
    F.mean("unit_price").alias("avg_price"),
    F.mean("total_amount").alias("avg_total"),
    F.stddev("quantity").alias("stddev_quantity"),
    F.stddev("unit_price").alias("stddev_price"),
    F.stddev("total_amount").alias("stddev_total")
)
multi_stats.show()
Advanced Statistics: Correlation, Covariance, and Percentiles
Beyond basic descriptive statistics, PySpark supports correlation analysis and custom percentile calculations.
# Correlation between two columns
correlation = df.stat.corr("quantity", "total_amount")
print(f"Correlation between quantity and total_amount: {correlation:.4f}")
# Covariance
covariance = df.stat.cov("quantity", "total_amount")
print(f"Covariance: {covariance:.4f}")
For a correlation matrix across multiple columns, you’ll need to iterate:
# Build a correlation matrix
numeric_cols = ["quantity", "unit_price", "total_amount"]
# Create correlation matrix
corr_matrix = []
for col1 in numeric_cols:
    row = []
    for col2 in numeric_cols:
        corr = df.stat.corr(col1, col2)
        row.append(round(corr, 4))
    corr_matrix.append(row)
# Display as DataFrame
import pandas as pd
corr_df = pd.DataFrame(corr_matrix, index=numeric_cols, columns=numeric_cols)
print(corr_df)
Percentile calculations use approxQuantile(), which provides approximate results for better performance:
# Approximate percentiles (much faster for large datasets)
percentiles = df.stat.approxQuantile(
    "total_amount",
    [0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
    0.01  # Relative error - lower = more accurate but slower
)
print("Percentiles for total_amount:")
for p, v in zip([10, 25, 50, 75, 90, 95, 99], percentiles):
    print(f"  {p}th: {v:.2f}")
The third parameter controls accuracy. A relative error of 0.01 means the returned value's rank is guaranteed to be within 1% of the requested percentile's rank (the guarantee is on rank, not on the value itself). Passing 0.0 forces an exact computation, which can be very expensive. For most analytical purposes, a small relative error is sufficient and dramatically faster.
# Percentiles for multiple columns
cols = ["quantity", "unit_price", "total_amount"]
percentile_points = [0.25, 0.5, 0.75]
for col in cols:
    percentiles = df.stat.approxQuantile(col, percentile_points, 0.01)
    print(f"{col}: Q1={percentiles[0]:.2f}, Median={percentiles[1]:.2f}, Q3={percentiles[2]:.2f}")
Performance Considerations and Best Practices
Calculating summary statistics on large datasets can be expensive. Here’s how to optimize.
Cache when computing multiple statistics:
# Without caching - each aggregation triggers a full scan
stats1 = df.groupBy("region").agg(F.mean("total_amount")).collect()
stats2 = df.groupBy("region").agg(F.stddev("total_amount")).collect()
stats3 = df.groupBy("region").agg(F.sum("total_amount")).collect()
# Three full scans of the data
# With caching - compute once, reuse
df.cache()
df.count() # Trigger caching
stats1 = df.groupBy("region").agg(F.mean("total_amount")).collect()
stats2 = df.groupBy("region").agg(F.stddev("total_amount")).collect()
stats3 = df.groupBy("region").agg(F.sum("total_amount")).collect()
# Data read from cache
df.unpersist() # Release memory when done
Better yet, combine aggregations:
# Best approach - single pass with all aggregations
all_stats = df.groupBy("region").agg(
    F.mean("total_amount").alias("mean"),
    F.stddev("total_amount").alias("stddev"),
    F.sum("total_amount").alias("total")
).collect()
# One scan, all statistics
Avoid collecting large results to the driver:
# Dangerous for large cardinality groupings
# all_data = df.groupBy("customer_id").agg(...).collect() # Don't do this
# Instead, write results to storage
df.groupBy("customer_id").agg(
    F.count("order_id").alias("order_count"),
    F.sum("total_amount").alias("lifetime_value")
).write.parquet("customer_stats/")
Filter before aggregating when possible:
# Compute statistics only for relevant subset
recent_stats = df.filter(F.col("order_date") >= "2024-01-01") \
    .groupBy("region") \
    .agg(F.mean("total_amount").alias("avg_recent_order"))
Summary statistics in PySpark follow familiar patterns but require awareness of distributed computing tradeoffs. Use built-in methods for exploration, explicit aggregations for production pipelines, and always consider whether you need exact or approximate results. Your cluster will thank you.