PySpark - SQL Aggregate Functions
Key Insights
- PySpark aggregate functions distribute computations across clusters, making them fundamentally different from Pandas operations that run on single machines; understanding when to use groupBy().agg() versus simple agg() is critical for correct results
- Advanced functions like collect_list() and approx_count_distinct() can drastically reduce memory overhead and processing time, with approximate functions trading a small, tunable error (bounded by a configurable relative standard deviation) for a fraction of the computational cost
- Window functions enable running aggregations and rankings without collapsing your dataset, allowing you to maintain row-level detail while computing partition-level statistics in a single pass
Introduction to PySpark Aggregations
PySpark aggregate functions are the workhorses of big data analytics. Unlike Pandas, which loads entire datasets into memory on a single machine, PySpark distributes data across multiple nodes and performs aggregations in parallel. This fundamental difference means you need to think differently about how aggregations execute.
When you run an aggregation in PySpark, the framework shuffles data across the cluster, grouping related records on the same executor before computing results. This shuffle operation is expensive, so understanding how to minimize it is crucial for performance.
Let’s start with a basic example:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("AggregationDemo").getOrCreate()
# Create sample sales data
data = [
    ("2024-01-01", "Electronics", "Laptop", 1200, 2),
    ("2024-01-01", "Electronics", "Mouse", 25, 5),
    ("2024-01-02", "Clothing", "Shirt", 40, 3),
    ("2024-01-02", "Electronics", "Keyboard", 75, 4),
    ("2024-01-03", "Clothing", "Pants", 60, 2)
]
df = spark.createDataFrame(data, ["date", "category", "product", "price", "quantity"])
# Simple aggregations
total_revenue = df.select(F.sum(F.col("price") * F.col("quantity"))).first()[0]
total_items = df.select(F.sum("quantity")).first()[0]
print(f"Total Revenue: ${total_revenue}")
print(f"Total Items Sold: {total_items}")
Basic Aggregate Functions
PySpark provides all the standard SQL aggregate functions: count(), sum(), avg(), min(), max(), and mean() (an alias for avg()). You can use them through the DataFrame API or SQL expressions; both compile to the same execution plan.
Here’s how to apply multiple aggregations simultaneously:
# Using agg() for multiple aggregations at once
summary = df.agg(
    F.count("*").alias("total_transactions"),
    F.sum(F.col("price") * F.col("quantity")).alias("total_revenue"),
    F.avg("price").alias("avg_price"),
    F.min("price").alias("min_price"),
    F.max("price").alias("max_price")
)
summary.show()
# Alternative: SQL-style approach
df.createOrReplaceTempView("sales")
spark.sql("""
    SELECT
        COUNT(*) AS total_transactions,
        SUM(price * quantity) AS total_revenue,
        AVG(price) AS avg_price,
        MIN(price) AS min_price,
        MAX(price) AS max_price
    FROM sales
""").show()
The key difference between df.agg() and df.groupBy().agg() is that the former aggregates the entire DataFrame into a single row, while the latter creates groups first:
# Without groupBy - single result row
df.agg(F.sum("quantity")).show()
# With groupBy - one row per category
df.groupBy("category").agg(F.sum("quantity").alias("total_quantity")).show()
GroupBy and Multiple Aggregations
Real-world analytics typically require grouping by one or more dimensions. The groupBy() method creates a GroupedData object that you then aggregate:
# Group by single column
category_stats = df.groupBy("category").agg(
    F.count("*").alias("num_transactions"),
    F.sum(F.col("price") * F.col("quantity")).alias("revenue"),
    F.avg("price").alias("avg_price"),
    F.max("quantity").alias("max_quantity")
)
category_stats.show()
# Group by multiple columns
daily_category_stats = df.groupBy("date", "category").agg(
    F.sum(F.col("price") * F.col("quantity")).alias("daily_revenue"),
    F.count("product").alias("products_sold")
).orderBy("date", "category")
daily_category_stats.show()
You can also pass a dictionary mapping column names to aggregation functions, which is more compact for simple cases:
# Dictionary syntax - less verbose, but the output columns get
# auto-generated names like "avg(price)"
df.groupBy("category").agg({
    "price": "avg",
    "quantity": "sum"
}).show()
# But explicit functions give you more control
df.groupBy("category").agg(
    F.avg("price").alias("avg_price"),
    F.sum("quantity").alias("total_quantity"),
    F.expr("sum(price * quantity)").alias("revenue")  # Complex expressions
).show()
Advanced Aggregate Functions
Beyond basic statistics, PySpark offers powerful functions for collecting values, counting distinct elements, and computing statistical measures.
# Create dataset with duplicate values
events_data = [
    ("user1", "2024-01-01", "login"),
    ("user1", "2024-01-01", "view_page"),
    ("user1", "2024-01-02", "login"),
    ("user2", "2024-01-01", "login"),
    ("user2", "2024-01-01", "purchase"),
]
events_df = spark.createDataFrame(events_data, ["user_id", "date", "event_type"])
# Collect all values vs unique values
user_activity = events_df.groupBy("user_id").agg(
    F.collect_list("event_type").alias("all_events"),
    F.collect_set("event_type").alias("unique_events"),
    F.countDistinct("event_type").alias("distinct_event_count")
)
user_activity.show(truncate=False)
For very large datasets with high cardinality, approx_count_distinct() provides significant performance benefits:
# Generate large dataset
large_df = spark.range(0, 10000000).select(
    (F.rand() * 1000000).cast("int").alias("user_id"),
    (F.rand() * 100).cast("int").alias("product_id")
)
# Exact count - slower but precise
exact_count = large_df.select(F.countDistinct("user_id")).first()[0]
# Approximate count - much faster; error is bounded by a configurable
# relative standard deviation (the default rsd is 5%)
approx_count = large_df.select(F.approx_count_distinct("user_id")).first()[0]
print(f"Exact: {exact_count}, Approximate: {approx_count}")
print(f"Error: {abs(exact_count - approx_count) / exact_count * 100:.2f}%")
Statistical aggregations are essential for data quality monitoring:
# Sensor data example
sensor_data = [(i, float(50 + (i % 20) - 10)) for i in range(100)]
sensor_df = spark.createDataFrame(sensor_data, ["reading_id", "temperature"])
stats = sensor_df.agg(
    F.mean("temperature").alias("mean_temp"),
    F.stddev("temperature").alias("stddev_temp"),
    F.variance("temperature").alias("variance_temp"),
    F.expr("percentile_approx(temperature, 0.5)").alias("median_temp")
)
stats.show()
Window Functions and Cumulative Aggregations
Window functions are game-changers when you need aggregations without losing row-level detail. They compute values over a “window” of rows related to the current row:
from pyspark.sql.window import Window
# Time series data
ts_data = [
    ("2024-01-01", "ProductA", 100),
    ("2024-01-02", "ProductA", 150),
    ("2024-01-03", "ProductA", 120),
    ("2024-01-01", "ProductB", 200),
    ("2024-01-02", "ProductB", 180),
    ("2024-01-03", "ProductB", 220),
]
ts_df = spark.createDataFrame(ts_data, ["date", "product", "sales"])
# Define window specification
window_spec = Window.partitionBy("product").orderBy("date")
# Running totals and moving averages
ts_with_metrics = ts_df.withColumn(
    "running_total", F.sum("sales").over(window_spec)
).withColumn(
    "moving_avg_3day", F.avg("sales").over(
        window_spec.rowsBetween(-2, 0)
    )
).withColumn(
    "rank", F.rank().over(Window.partitionBy("product").orderBy(F.desc("sales")))
)
ts_with_metrics.orderBy("product", "date").show()
Lag and lead operations let you compare current values with previous or future rows:
# Compare with previous day
ts_with_comparison = ts_df.withColumn(
    "previous_day_sales", F.lag("sales", 1).over(window_spec)
).withColumn(
    "sales_change", F.col("sales") - F.lag("sales", 1).over(window_spec)
).withColumn(
    "pct_change",
    (F.col("sales") - F.lag("sales", 1).over(window_spec)) /
    F.lag("sales", 1).over(window_spec) * 100
)
ts_with_comparison.orderBy("product", "date").show()
Custom Aggregations with UDAFs
When built-in functions don’t meet your needs, create custom aggregations. Pandas UDFs are the modern, performant approach:
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd
# Weighted average using a Pandas UDF; the Series -> float type hints
# tell Spark this is a grouped-aggregate UDF
@pandas_udf(DoubleType())
def weighted_average(values: pd.Series, weights: pd.Series) -> float:
    return (values * weights).sum() / weights.sum()

# Apply custom aggregation
weighted_data = [
    ("A", 10, 2),
    ("A", 20, 3),
    ("B", 15, 1),
    ("B", 25, 4),
]
weighted_df = spark.createDataFrame(weighted_data, ["group", "value", "weight"])
result = weighted_df.groupBy("group").agg(
    weighted_average(F.col("value"), F.col("weight")).alias("weighted_avg"),
    F.avg("value").alias("simple_avg")
)
result.show()
Performance Optimization Tips
Aggregations can be expensive. Here’s how to optimize them:
1. Cache intermediate results when reusing aggregations:
# Cache before running several aggregations; cache() is lazy, so the
# first action actually materializes it
df.cache()
result1 = df.groupBy("category").agg(F.sum(F.col("price") * F.col("quantity")).alias("revenue"))
result2 = df.groupBy("category").agg(F.avg("price").alias("avg_price"))
df.unpersist()  # Clean up when done
2. Use explain() to understand execution plans:
# Check if your aggregation causes unnecessary shuffles
df.groupBy("category").agg(F.sum(F.col("price") * F.col("quantity"))).explain()
# Look for "Exchange" operations - these are shuffles
3. Partition your data appropriately:
# Repartitioning by the group key is itself a shuffle, so this pays off
# mainly when several later operations reuse the same key
df.repartition("category").groupBy("category").agg(
    F.sum(F.col("price") * F.col("quantity")).alias("revenue")
).show()
4. Use broadcast joins when joining aggregated results with small tables:
from pyspark.sql.functions import broadcast
# Small lookup table
categories = spark.createDataFrame(
    [("Electronics", "Tech"), ("Clothing", "Fashion")],
    ["category", "department"]
)
# Broadcast the small table to avoid a shuffle join
aggregated = df.groupBy("category").agg(
    F.sum(F.col("price") * F.col("quantity")).alias("total_revenue")
)
result = aggregated.join(broadcast(categories), "category")
PySpark aggregations are powerful but require understanding distributed computing principles. Start with simple aggregations, use explain() to verify execution plans, and progressively optimize based on your data size and cluster resources. The functions covered here handle 95% of real-world analytics scenarios—master these before reaching for custom UDAFs.