PySpark - Describe/Summary Statistics of DataFrame
Key Insights
- PySpark’s describe() and summary() methods provide quick statistical overviews, but summary() includes additional percentiles (25th, 50th, 75th) that describe() omits by default.
- For large-scale datasets, use approxQuantile() instead of exact percentile calculations—the performance difference can be orders of magnitude, with negligible accuracy loss even at relative errors of 0.01-0.05.
- Combining groupBy() with agg() and multiple statistical functions in a single pass is far more efficient than running separate queries, especially on partitioned data where Spark can parallelize aggregations.
Introduction to DataFrame Statistics in PySpark
When working with large-scale datasets in PySpark, understanding your data’s statistical properties is the first step toward meaningful analysis. Summary statistics reveal data distributions, identify outliers, expose quality issues, and guide feature engineering decisions. Unlike pandas, where statistics compute on a single machine, PySpark distributes these calculations across a cluster, making it possible to profile datasets that would never fit in memory.
PySpark provides multiple approaches for statistical analysis, from high-level convenience methods like describe() to granular functions for specific metrics. Choosing the right approach depends on your needs: quick exploration versus detailed analysis, exact versus approximate calculations, and single-column versus multi-column operations.
Let’s start with a representative dataset:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.sql.functions import *
spark = SparkSession.builder.appName("StatisticsDemo").getOrCreate()
# Create sample sales data
data = [
("Electronics", "Laptop", 1200.00, 15, "North"),
("Electronics", "Mouse", 25.00, 150, "South"),
("Furniture", "Desk", 350.00, 45, "East"),
("Electronics", "Keyboard", 75.00, 89, "North"),
("Furniture", "Chair", 200.00, 120, "West"),
("Electronics", "Monitor", 300.00, 67, "South"),
("Furniture", "Lamp", 45.00, 200, "East"),
("Electronics", "Tablet", 450.00, 34, "North"),
("Furniture", "Bookshelf", 180.00, 28, "West"),
("Electronics", "Headphones", 85.00, 175, "South")
]
schema = StructType([
StructField("category", StringType(), True),
StructField("product", StringType(), True),
StructField("price", DoubleType(), True),
StructField("units_sold", IntegerType(), True),
StructField("region", StringType(), True)
])
df = spark.createDataFrame(data, schema)
The describe() Method
The describe() method is your first stop for exploratory data analysis. It computes count, mean, standard deviation, minimum, and maximum for all numeric columns in a single operation:
# Basic describe - all numeric columns
df.describe().show()
# Output:
# +-------+------------------+------------------+
# |summary| price| units_sold|
# +-------+------------------+------------------+
# | count| 10| 10|
# | mean| 291.0| 92.3|
# | stddev|352.92...| 65.47...|
# | min| 45.0| 15|
# | max| 1200.0| 200|
# +-------+------------------+------------------+
You can target specific columns to reduce computation and focus your analysis:
# Describe specific columns only
df.describe(['price']).show()
# Include categorical columns for count statistics
df.describe(['category', 'price', 'units_sold']).show()
# Output now includes category counts:
# +-------+-----------+------------------+------------------+
# |summary| category| price| units_sold|
# +-------+-----------+------------------+------------------+
# | count| 10| 10| 10|
# | mean| null| 291.0| 92.3|
# | stddev| null|352.92...| 65.47...|
# | min|Electronics| 45.0| 15|
# | max| Furniture| 1200.0| 200|
# +-------+-----------+------------------+------------------+
Notice that categorical columns show count, min (alphabetically first), and max (alphabetically last), but mean and stddev are null. This behavior makes describe() useful for mixed-type DataFrames.
The summary() Method
The summary() method extends describe() by including quartile information—the 25th, 50th (median), and 75th percentiles:
# summary() provides more percentile details
df.summary().show()
# Output:
# +-------+------------------+------------------+
# |summary| price| units_sold|
# +-------+------------------+------------------+
# | count| 10| 10|
# | mean| 291.0| 92.3|
# | stddev|352.92...| 65.47...|
# | min| 45.0| 15|
# | 25%| 77.5| 32.5|
# | 50%| 240.0| 78.0|
# | 75%| 343.75| 146.25|
# | max| 1200.0| 200|
# +-------+------------------+------------------+
The quartile data is invaluable for understanding distribution shape. A large gap between the 75th percentile and max (like price: 343.75 vs 1200.0) signals potential outliers or a right-skewed distribution.
You can also request custom percentiles:
# Custom percentiles
df.summary("count", "mean", "10%", "90%").show()
Column-Specific Statistical Functions
For more control, use PySpark’s column-level statistical functions with agg(). This approach is more efficient when you need specific metrics rather than the full describe/summary output:
from pyspark.sql.functions import count, mean, stddev, min, max, variance, kurtosis, skewness
# Single aggregation pass with multiple statistics
df.agg(
count("price").alias("count"),
mean("price").alias("avg_price"),
stddev("price").alias("std_price"),
min("price").alias("min_price"),
max("price").alias("max_price"),
variance("price").alias("var_price")
).show()
# Output:
# +-----+------------------+------------------+---------+---------+------------------+
# |count| avg_price| std_price|min_price|max_price| var_price|
# +-----+------------------+------------------+---------+---------+------------------+
# | 10| 291.0|352.92...| 45.0| 1200.0| 124555.55...|
# +-----+------------------+------------------+---------+---------+------------------+
For percentiles, use approxQuantile(), which trades perfect accuracy for massive performance gains on large datasets:
# Calculate approximate quantiles
quantiles = df.approxQuantile("price", [0.25, 0.5, 0.75, 0.95], 0.01)
print(f"25th percentile: {quantiles[0]}")
print(f"Median: {quantiles[1]}")
print(f"75th percentile: {quantiles[2]}")
print(f"95th percentile: {quantiles[3]}")
# The third parameter (0.01) is relative error
# Lower values = more accurate but slower
# 0.0 = exact calculation (avoid on huge datasets)
Advanced Statistics with pyspark.sql.functions
Beyond univariate statistics, PySpark supports correlation and covariance analysis:
# Correlation between price and units_sold
correlation = df.stat.corr("price", "units_sold")
print(f"Price-Units Correlation: {correlation:.4f}")
# Covariance
covariance = df.stat.cov("price", "units_sold")
print(f"Price-Units Covariance: {covariance:.4f}")
Grouped statistics reveal patterns across categories:
# Statistics by category
df.groupBy("category").agg(
count("product").alias("product_count"),
mean("price").alias("avg_price"),
sum("units_sold").alias("total_units"),
stddev("price").alias("price_stddev")
).show()
# Output:
# +-----------+-------------+------------------+-----------+------------------+
# | category|product_count| avg_price|total_units| price_stddev|
# +-----------+-------------+------------------+-----------+------------------+
# | Furniture| 4| 193.75| 393| 122.87...|
# |Electronics| 6| 355.83...| 530| 423.44...|
# +-----------+-------------+------------------+-----------+------------------+
Window functions enable running statistics and rankings:
from pyspark.sql.window import Window
# Running average and rank within category
window_spec = Window.partitionBy("category").orderBy("price")
df.withColumn("running_avg", avg("price").over(window_spec)) \
.withColumn("price_rank", rank().over(window_spec)) \
.select("category", "product", "price", "running_avg", "price_rank") \
.show()
Performance Considerations and Best Practices
Statistical operations in PySpark trigger full dataset scans, making performance critical at scale. Here are key optimization strategies:
Cache when running multiple statistics:
# Cache the DataFrame if you'll compute multiple stats
df.cache()
df.describe().show()
df.summary().show()
df.stat.corr("price", "units_sold")
df.unpersist() # Free memory when done
Use approximate functions for percentiles:
import time
# Exact percentile calculation (slow on large data)
start = time.time()
exact = df.approxQuantile("price", [0.5], 0.0) # relativeError=0.0
exact_time = time.time() - start
# Approximate percentile (much faster)
start = time.time()
approx = df.approxQuantile("price", [0.5], 0.05) # relativeError=0.05
approx_time = time.time() - start
print(f"Exact: {exact[0]:.2f} in {exact_time:.4f}s")
print(f"Approx: {approx[0]:.2f} in {approx_time:.4f}s")
Combine aggregations to minimize passes:
# BAD: Multiple passes over data
mean_price = df.agg(mean("price")).collect()[0][0]
max_price = df.agg(max("price")).collect()[0][0]
min_price = df.agg(min("price")).collect()[0][0]
# GOOD: Single pass
stats = df.agg(
mean("price").alias("mean"),
max("price").alias("max"),
min("price").alias("min")
).collect()[0]
Practical Use Case: Sales Data Profiling
Here’s a complete workflow for profiling a sales dataset:
# Comprehensive statistical profile
def profile_dataframe(df, numeric_cols, categorical_cols):
    """Generate a complete statistical profile of a DataFrame."""
    print("=== BASIC STATISTICS ===")
    df.summary().show()

    print("\n=== CATEGORICAL DISTRIBUTIONS ===")
    for cat_col in categorical_cols:
        df.groupBy(cat_col).count().orderBy(desc("count")).show()

    print("\n=== CORRELATIONS ===")
    for i, col1 in enumerate(numeric_cols):
        for col2 in numeric_cols[i + 1:]:
            corr = df.stat.corr(col1, col2)
            print(f"{col1} <-> {col2}: {corr:.4f}")

    print("\n=== CATEGORY-WISE STATISTICS ===")
    for cat_col in categorical_cols:
        df.groupBy(cat_col).agg(
            *[mean(num_col).alias(f"avg_{num_col}") for num_col in numeric_cols]
        ).show()

    print("\n=== OUTLIER DETECTION (IQR Method) ===")
    for num_col in numeric_cols:
        q1, q3 = df.approxQuantile(num_col, [0.25, 0.75], 0.01)
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        outlier_count = df.filter(
            (df[num_col] < lower_bound) | (df[num_col] > upper_bound)
        ).count()
        print(f"{num_col}: {outlier_count} outliers "
              f"(bounds: {lower_bound:.2f} - {upper_bound:.2f})")
# Execute profiling
profile_dataframe(
df,
numeric_cols=["price", "units_sold"],
categorical_cols=["category", "region"]
)
This function provides a comprehensive statistical overview: basic statistics, categorical distributions, correlations, grouped statistics, and outlier detection—everything you need for initial data assessment.
Summary statistics in PySpark are powerful tools for understanding distributed datasets. Master describe() and summary() for quick exploration, leverage agg() for efficient custom calculations, and always prefer approximate functions when exact precision isn’t required. With these techniques, you can profile terabyte-scale datasets as easily as small samples.