PySpark - SQL GROUP BY with Examples
Key Insights
• PySpark GROUP BY operations trigger shuffle operations across your cluster—understanding partition distribution and data skew is critical for performance at scale, unlike pandas where everything fits in memory.
• The DataFrame API (groupBy().agg()) and Spark SQL syntax produce identical execution plans, so choose based on your team’s preferences and existing codebase patterns rather than performance concerns.
• Advanced aggregation functions like collect_list() and collect_set() are powerful but dangerous—collecting too many values into arrays can cause executor memory issues and should be used with explicit limits or filters.
Introduction to GROUP BY in PySpark
GROUP BY operations in PySpark differ fundamentally from pandas because they execute across distributed data partitions. When you group data in PySpark, Spark shuffles records with matching keys to the same executor, which can become a bottleneck with skewed data or poor partitioning strategies.
You should use PySpark GROUP BY when working with datasets that don’t fit in memory, when you need horizontal scalability, or when your data already lives in a distributed system like HDFS or S3. For datasets under a few gigabytes that fit comfortably in RAM, pandas is simpler and often faster.
PySpark offers two equivalent approaches: the DataFrame API and Spark SQL. Both compile to the same execution plan, so your choice depends on readability and team familiarity.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, sum, avg, min, max
# Initialize SparkSession
spark = SparkSession.builder \
.appName("GroupByExamples") \
.getOrCreate()
# Create sample sales data
data = [
("2024-01-15", "Electronics", "Laptop", "North", 1200, 2),
("2024-01-16", "Electronics", "Mouse", "North", 25, 5),
("2024-01-16", "Clothing", "Shirt", "South", 45, 3),
("2024-01-17", "Electronics", "Laptop", "South", 1200, 1),
("2024-01-17", "Clothing", "Pants", "North", 60, 2),
("2024-01-18", "Electronics", "Keyboard", "North", 75, 4),
("2024-01-18", "Clothing", "Shirt", "South", 45, 6),
]
columns = ["date", "category", "product", "region", "price", "quantity"]
df = spark.createDataFrame(data, columns)
df.show()
Basic GROUP BY Operations
The groupBy() method creates a GroupedData object that you then aggregate using agg(). This two-step process gives you flexibility to apply multiple aggregation functions simultaneously.
# Single column grouping with count
category_counts = df.groupBy("category").count()
category_counts.show()
# Output:
# +-----------+-----+
# | category|count|
# +-----------+-----+
# |Electronics| 4|
# | Clothing| 3|
# +-----------+-----+
# Multiple aggregations with aliases
from pyspark.sql.functions import sum, avg
category_stats = df.groupBy("category").agg(
count("*").alias("total_transactions"),
sum(col("price") * col("quantity")).alias("total_revenue"),
avg("price").alias("avg_price"),
sum("quantity").alias("total_quantity")
)
category_stats.show()
# Output:
# +-----------+------------------+-------------+---------+--------------+
# | category|total_transactions|total_revenue|avg_price|total_quantity|
# +-----------+------------------+-------------+---------+--------------+
# |Electronics|                 4|         4025|    625.0|            12|
# | Clothing| 3| 525| 50.0| 11|
# +-----------+------------------+-------------+---------+--------------+
Always use alias() for aggregated columns to create meaningful names. Without aliases, Spark generates names like sum(price) which become cumbersome in downstream operations.
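The aggregates above are easy to sanity-check outside Spark. The sketch below recomputes the per-category totals over the same sample rows in plain Python; it is a verification aid under the assumption that you want to confirm the arithmetic locally, not part of any Spark pipeline.

```python
from collections import defaultdict

# Same sample rows as the DataFrame above
rows = [
    ("2024-01-15", "Electronics", "Laptop", "North", 1200, 2),
    ("2024-01-16", "Electronics", "Mouse", "North", 25, 5),
    ("2024-01-16", "Clothing", "Shirt", "South", 45, 3),
    ("2024-01-17", "Electronics", "Laptop", "South", 1200, 1),
    ("2024-01-17", "Clothing", "Pants", "North", 60, 2),
    ("2024-01-18", "Electronics", "Keyboard", "North", 75, 4),
    ("2024-01-18", "Clothing", "Shirt", "South", 45, 6),
]

# Accumulate the same statistics groupBy().agg() computes
stats = defaultdict(lambda: {"transactions": 0, "revenue": 0, "prices": [], "quantity": 0})
for _, category, _, _, price, qty in rows:
    s = stats[category]
    s["transactions"] += 1
    s["revenue"] += price * qty
    s["prices"].append(price)
    s["quantity"] += qty

for category, s in sorted(stats.items()):
    avg_price = sum(s["prices"]) / len(s["prices"])
    print(category, s["transactions"], s["revenue"], avg_price, s["quantity"])
# Clothing 3 525 50.0 11
# Electronics 4 4025 625.0 12
```

This kind of local cross-check is useful when validating a new aggregation pipeline against a small sample before running it at scale.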
Multi-Column Grouping
Grouping by multiple columns creates hierarchical aggregations. Each unique combination of the grouped columns becomes a separate group.
# Group by region and category
regional_category_sales = df.groupBy("region", "category").agg(
count("*").alias("transactions"),
sum(col("price") * col("quantity")).alias("revenue")
).orderBy("region", "category")
regional_category_sales.show()
# Output:
# +------+-----------+------------+-------+
# |region| category|transactions|revenue|
# +------+-----------+------------+-------+
# | North| Clothing| 1| 120|
# | North|Electronics|           3|   2825|
# | South| Clothing| 2| 405|
# | South|Electronics| 1| 1200|
# +------+-----------+------------+-------+
# Three-column grouping for time-series analysis
from pyspark.sql.functions import to_date, year, month
df_with_date = df.withColumn("date_parsed", to_date(col("date")))
monthly_product_sales = df_with_date.groupBy(
"category",
"product",
month("date_parsed").alias("month")
).agg(
sum("quantity").alias("units_sold"),
sum(col("price") * col("quantity")).alias("revenue")
).orderBy("category", "product", "month")
monthly_product_sales.show()
The orderBy() clause doesn’t affect the grouping logic but makes results more readable. For large datasets, consider whether you need sorted output—sorting requires additional shuffle operations.
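As with the single-column case, the multi-column revenue figures can be cross-checked in plain Python by grouping on a (region, category) tuple, a rough stand-in for how Spark treats a composite group key. The rows below mirror the sample DataFrame.

```python
rows = [  # (date, category, product, region, price, quantity), as above
    ("2024-01-15", "Electronics", "Laptop", "North", 1200, 2),
    ("2024-01-16", "Electronics", "Mouse", "North", 25, 5),
    ("2024-01-16", "Clothing", "Shirt", "South", 45, 3),
    ("2024-01-17", "Electronics", "Laptop", "South", 1200, 1),
    ("2024-01-17", "Clothing", "Pants", "North", 60, 2),
    ("2024-01-18", "Electronics", "Keyboard", "North", 75, 4),
    ("2024-01-18", "Clothing", "Shirt", "South", 45, 6),
]

revenue = {}
for _, category, _, region, price, qty in rows:
    key = (region, category)  # composite group key, one entry per unique combination
    revenue[key] = revenue.get(key, 0) + price * qty

for key in sorted(revenue):
    print(key, revenue[key])
# ('North', 'Clothing') 120
# ('North', 'Electronics') 2825
# ('South', 'Clothing') 405
# ('South', 'Electronics') 1200
```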
Advanced Aggregation Functions
Beyond basic aggregations, PySpark provides specialized functions for complex data collection and analysis.
from pyspark.sql.functions import collect_list, collect_set, countDistinct, first, last
# Collect all products sold in each region
regional_products = df.groupBy("region").agg(
collect_list("product").alias("all_products"),
collect_set("product").alias("unique_products"),
countDistinct("product").alias("distinct_product_count")
)
regional_products.show(truncate=False)
# Output:
# +------+--------------------------------+------------------------+-----------------------+
# |region|all_products |unique_products |distinct_product_count|
# +------+--------------------------------+------------------------+-----------------------+
# |North |[Laptop, Mouse, Pants, Keyboard]|[Keyboard, Pants, ...] |4 |
# |South |[Shirt, Laptop, Shirt] |[Laptop, Shirt] |2 |
# +------+--------------------------------+------------------------+-----------------------+
# Using first() and last() - caution: Spark does not guarantee that the
# orderBy() below survives the groupBy() shuffle, so first()/last() can be
# nondeterministic; prefer min()/max() or window functions for reliable results
category_price_range = df.orderBy("price").groupBy("category").agg(
first("product").alias("cheapest_product"),
first("price").alias("min_price"),
last("product").alias("most_expensive_product"),
last("price").alias("max_price")
)
category_price_range.show()
Warning: collect_list() and collect_set() gather all values into memory on a single executor. If a group contains millions of items, you’ll encounter out-of-memory errors. Always filter or limit data before using collection functions, or consider alternative approaches like window functions with row limits.
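The "explicit limits" advice can be illustrated without Spark. The sketch below caps each group's collected list at a fixed size while counting how many values were dropped; MAX_PER_GROUP and the capped_collect helper are illustrative, not a Spark API. In PySpark you would typically get the same effect by filtering on a window function's row_number() before calling collect_list().

```python
from collections import defaultdict

MAX_PER_GROUP = 2  # illustrative cap; tune to what your executors can hold

def capped_collect(pairs, cap):
    """Collect values per key, keeping at most `cap` values per group."""
    collected = defaultdict(list)
    dropped = defaultdict(int)
    for key, value in pairs:
        if len(collected[key]) < cap:
            collected[key].append(value)
        else:
            dropped[key] += 1  # track overflow instead of blowing up memory
    return collected, dropped

pairs = [("North", "Laptop"), ("North", "Mouse"), ("North", "Pants"),
         ("North", "Keyboard"), ("South", "Shirt"), ("South", "Laptop")]
collected, dropped = capped_collect(pairs, MAX_PER_GROUP)
print(dict(collected))  # {'North': ['Laptop', 'Mouse'], 'South': ['Shirt', 'Laptop']}
print(dict(dropped))    # {'North': 2}
```

Tracking the dropped count, as above, is a useful habit: it tells you whether the cap is actually truncating data in production.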
Using Spark SQL Syntax
If your team prefers SQL or you’re migrating from traditional databases, Spark SQL provides identical functionality with familiar syntax.
# Register DataFrame as temporary view
df.createOrReplaceTempView("sales")
# Basic GROUP BY with SQL
sql_category_stats = spark.sql("""
SELECT
category,
COUNT(*) as total_transactions,
SUM(price * quantity) as total_revenue,
AVG(price) as avg_price,
SUM(quantity) as total_quantity
FROM sales
GROUP BY category
""")
sql_category_stats.show()
# Multi-column grouping with HAVING clause
high_revenue_regions = spark.sql("""
SELECT
region,
category,
SUM(price * quantity) as revenue,
COUNT(*) as transactions
FROM sales
GROUP BY region, category
HAVING SUM(price * quantity) > 500
ORDER BY revenue DESC
""")
high_revenue_regions.show()
# Output:
# +------+-----------+-------+------------+
# |region| category|revenue|transactions|
# +------+-----------+-------+------------+
# | North|Electronics|   2825|           3|
# | South|Electronics| 1200| 1|
# +------+-----------+-------+------------+
The HAVING clause filters groups after aggregation, while WHERE filters rows before grouping. This distinction is crucial for query performance—filter as much data as possible with WHERE before the expensive shuffle operation.
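The WHERE-before, HAVING-after distinction can be modeled in a few lines of plain Python: filter rows first, aggregate, then filter groups. The region and revenue thresholds below are illustrative.

```python
rows = [  # (region, category, price, quantity), trimmed from the sample data
    ("North", "Electronics", 1200, 2), ("North", "Electronics", 25, 5),
    ("South", "Clothing", 45, 3), ("South", "Electronics", 1200, 1),
    ("North", "Clothing", 60, 2), ("North", "Electronics", 75, 4),
    ("South", "Clothing", 45, 6),
]

# WHERE: filter rows BEFORE grouping (cheap, shrinks the shuffle)
north_rows = [r for r in rows if r[0] == "North"]

# GROUP BY on the surviving rows
revenue = {}
for region, category, price, qty in north_rows:
    revenue[category] = revenue.get(category, 0) + price * qty

# HAVING: filter groups AFTER aggregation
high_revenue = {k: v for k, v in revenue.items() if v > 500}
print(high_revenue)  # {'Electronics': 2825}
```

In a real Spark job the row filter runs before the exchange, so less data crosses the network; the group filter can only run after every group's total exists.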
Filtering Grouped Data and Window Functions
Post-aggregation filtering with the DataFrame API uses filter() or where() after the aggregation step.
# Filter aggregated results
high_volume_categories = df.groupBy("category").agg(
sum("quantity").alias("total_units")
).filter(col("total_units") > 10)
high_volume_categories.show()
# Window functions for running totals (alternative to GROUP BY)
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank
# Running total by category (preserves individual rows)
window_spec = Window.partitionBy("category").orderBy("date")
df_with_running_total = df.withColumn(
"running_quantity",
sum("quantity").over(window_spec)
)
df_with_running_total.select("date", "category", "product", "quantity", "running_quantity").show()
Window functions shine when you need both detail rows and aggregations, or when calculating rankings and running totals. Use GROUP BY when you only need summary statistics and want to reduce data volume.
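Window semantics can be mimicked in plain Python to see what running_quantity produces: bucket rows by category (partitionBy), sort each bucket by date (orderBy), and accumulate. This mirrors the cumulative sum above; note that with duplicate order keys Spark's default RANGE frame gives tied rows the same running total, which a simple row-by-row accumulation does not reproduce.

```python
from collections import defaultdict
from itertools import accumulate

rows = [  # (date, category, quantity), from the sample data
    ("2024-01-15", "Electronics", 2), ("2024-01-16", "Electronics", 5),
    ("2024-01-16", "Clothing", 3), ("2024-01-17", "Electronics", 1),
    ("2024-01-17", "Clothing", 2), ("2024-01-18", "Electronics", 4),
    ("2024-01-18", "Clothing", 6),
]

# partitionBy("category"): bucket rows per category
partitions = defaultdict(list)
for date, category, qty in rows:
    partitions[category].append((date, qty))

# orderBy("date") + cumulative sum inside each partition
running = {}
for category, part in partitions.items():
    part.sort(key=lambda r: r[0])
    running[category] = list(accumulate(q for _, q in part))

print(running["Electronics"])  # [2, 7, 8, 12]
print(running["Clothing"])     # [3, 5, 11]
```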
Performance Considerations
GROUP BY operations trigger shuffle operations—data redistribution across executors. Understanding and optimizing shuffles is critical for production performance.
# Check execution plan
df.groupBy("category").agg(sum("quantity")).explain()
# Output shows Exchange (shuffle) operations:
# == Physical Plan ==
# *(2) HashAggregate(keys=[category#1], functions=[sum(quantity#5)])
# +- Exchange hashpartitioning(category#1, 200)
# +- *(1) HashAggregate(keys=[category#1], functions=[partial_sum(quantity#5)])
# Optimize with repartitioning before GROUP BY
optimized_df = df.repartition(10, "category")
optimized_df.groupBy("category").agg(sum("quantity")).explain()
# For small dimension tables, use broadcast joins
from pyspark.sql.functions import broadcast
# If joining grouped results with small lookup table
# grouped_df.join(broadcast(small_lookup_df), "category")
Key optimization strategies:
- Partition before grouping: If you know your group key distribution, repartition on that key to minimize shuffle
- Increase parallelism: Default shuffle partitions (200) may be too low for large datasets—configure spark.sql.shuffle.partitions
- Handle skew: If one category has 90% of records, use salting techniques to distribute load
- Broadcast small tables: When joining aggregated results with dimension tables under 10MB, broadcast to avoid shuffles
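The salting technique from the list above can be sketched end to end in plain Python: append a random suffix to the hot key so partial aggregates spread across many buckets, then strip the salt and combine. NUM_SALTS and the seed are illustrative; in PySpark you would build the salted key with concat() and rand(), and the two phases correspond to two groupBy steps.

```python
import random
from collections import defaultdict

random.seed(42)
NUM_SALTS = 4  # illustrative; pick based on observed skew

# Heavily skewed input: one key dominates
rows = [("hot_key", 1)] * 1000 + [("rare_key", 1)] * 10

# Phase 1: salt the key, then do partial aggregation per salted key
partial = defaultdict(int)
for key, value in rows:
    salted = f"{key}_{random.randrange(NUM_SALTS)}"
    partial[salted] += value

# 'hot_key' is now spread over up to NUM_SALTS partial buckets
assert sum(1 for k in partial if k.startswith("hot_key_")) > 1

# Phase 2: strip the salt and combine partials into the final result
final = defaultdict(int)
for salted, value in partial.items():
    original = salted.rsplit("_", 1)[0]
    final[original] += value

print(dict(final))  # {'hot_key': 1000, 'rare_key': 10}
```

The final totals are identical to an unsalted aggregation; the win is that no single reducer has to process all 1000 hot-key records in phase 1.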
Monitor your Spark UI to identify bottlenecks. Look for stages with long durations and uneven task completion times—these indicate skew or suboptimal partitioning.
PySpark GROUP BY operations are powerful but require understanding distributed computing fundamentals. Start with the DataFrame API for better IDE support and type safety, use SQL when integrating with BI tools or migrating existing queries, and always profile with explain() before deploying to production. The key to performance is minimizing shuffle operations through smart partitioning and filtering data as early as possible in your transformation pipeline.