PySpark - GroupBy on DataFrame with Examples

Key Insights

• GroupBy operations in PySpark enable distributed aggregation across massive datasets by partitioning data into groups based on column values, with automatic parallelization across cluster nodes

• Combining groupBy() with agg() allows multiple aggregation functions in a single pass, significantly reducing computation time compared to sequential operations

• Data skew in grouped columns can cripple performance; use salting techniques and appropriate partitioning strategies to distribute work evenly across executors

Introduction to GroupBy in PySpark

The groupBy() operation is fundamental to data aggregation in PySpark, enabling you to partition DataFrames based on column values and perform calculations on each group. Unlike pandas, whose groupby runs on a single machine, PySpark distributes these operations across cluster nodes, making it essential for processing datasets that don’t fit in memory.

Common use cases include calculating departmental statistics from employee records, aggregating sales metrics by region and time period, analyzing user behavior patterns, and generating summary reports from transaction logs. Understanding groupBy mechanics is crucial because improper usage can lead to severe performance bottlenecks, particularly when dealing with skewed data distributions.

Basic GroupBy Syntax and Single Column Grouping

The fundamental syntax follows a simple pattern: select your DataFrame, call groupBy() with the column name(s), then apply an aggregation function. Here’s the basic structure:

from pyspark.sql import SparkSession
from pyspark.sql.functions import count, sum, avg

spark = SparkSession.builder.appName("GroupByExamples").getOrCreate()

# Create sample employee data
employee_data = [
    ("John", "Engineering", 95000),
    ("Sarah", "Engineering", 98000),
    ("Mike", "Sales", 75000),
    ("Emma", "Sales", 72000),
    ("David", "Engineering", 102000),
    ("Lisa", "HR", 68000),
    ("Tom", "HR", 71000),
    ("Anna", "Sales", 78000)
]

df = spark.createDataFrame(employee_data, ["name", "department", "salary"])

# Group by department and count employees
dept_count = df.groupBy("department").count()
dept_count.show()

Output:

+-----------+-----+
| department|count|
+-----------+-----+
|Engineering|    3|
|      Sales|    3|
|         HR|    2|
+-----------+-----+

For sales aggregation, the pattern remains consistent:

sales_data = [
    ("North", "2024-01", 50000),
    ("North", "2024-02", 52000),
    ("South", "2024-01", 48000),
    ("South", "2024-02", 51000),
    ("East", "2024-01", 45000),
    ("East", "2024-02", 47000)
]

sales_df = spark.createDataFrame(sales_data, ["region", "month", "revenue"])

# Group by region and sum revenue
regional_sales = sales_df.groupBy("region").agg(sum("revenue").alias("total_revenue"))
regional_sales.show()

Output:

+------+-------------+
|region|total_revenue|
+------+-------------+
| North|       102000|
| South|        99000|
|  East|        92000|
+------+-------------+

Multiple Column GroupBy Operations

Grouping by multiple columns creates hierarchical aggregations, useful for drill-down analysis. Pass multiple column names to groupBy() as separate arguments:

from pyspark.sql.functions import avg, round

employee_detailed = [
    ("John", "Engineering", "Senior", 95000),
    ("Sarah", "Engineering", "Senior", 98000),
    ("Mike", "Sales", "Junior", 75000),
    ("Emma", "Sales", "Senior", 85000),
    ("David", "Engineering", "Junior", 72000),
    ("Lisa", "HR", "Senior", 78000),
    ("Tom", "HR", "Junior", 62000),
    ("Anna", "Sales", "Junior", 68000)
]

df_detailed = spark.createDataFrame(
    employee_detailed, 
    ["name", "department", "job_title", "salary"]
)

# Group by department AND job_title
avg_salary_by_dept_title = df_detailed.groupBy("department", "job_title") \
    .agg(round(avg("salary"), 2).alias("avg_salary"))

avg_salary_by_dept_title.orderBy("department", "job_title").show()

Output:

+-----------+---------+----------+
| department|job_title|avg_salary|
+-----------+---------+----------+
|Engineering|   Junior|   72000.0|
|Engineering|   Senior|   96500.0|
|         HR|   Junior|   62000.0|
|         HR|   Senior|   78000.0|
|      Sales|   Junior|   71500.0|
|      Sales|   Senior|   85000.0|
+-----------+---------+----------+

For time-based grouping, extract date components first:

from pyspark.sql.functions import year, quarter, to_date

time_series_data = [
    ("2023-01-15", 10000),
    ("2023-04-20", 12000),
    ("2023-07-10", 15000),
    ("2023-10-05", 13000),
    ("2024-01-12", 11000),
    ("2024-04-18", 14000)
]

ts_df = spark.createDataFrame(time_series_data, ["date", "amount"])
ts_df = ts_df.withColumn("date", to_date("date"))

# Group by year and quarter
quarterly_summary = ts_df \
    .withColumn("year", year("date")) \
    .withColumn("quarter", quarter("date")) \
    .groupBy("year", "quarter") \
    .agg(sum("amount").alias("total_amount"))

quarterly_summary.orderBy("year", "quarter").show()

Aggregation Functions with GroupBy

The agg() method allows multiple aggregations simultaneously, which is far more efficient than separate groupBy calls:

from pyspark.sql.functions import min, max, stddev

# Multiple aggregations in one operation
comprehensive_stats = df_detailed.groupBy("department").agg(
    count("name").alias("employee_count"),
    avg("salary").alias("avg_salary"),
    min("salary").alias("min_salary"),
    max("salary").alias("max_salary"),
    round(stddev("salary"), 2).alias("salary_stddev")
)

comprehensive_stats.show()

Output:

+-----------+--------------+-----------------+----------+----------+-------------+
| department|employee_count|       avg_salary|min_salary|max_salary|salary_stddev|
+-----------+--------------+-----------------+----------+----------+-------------+
|Engineering|             3|88333.33333333333|     72000|     98000|     14224.39|
|         HR|             2|          70000.0|     62000|     78000|     11313.71|
|      Sales|             3|          76000.0|     68000|     85000|       8544.0|
+-----------+--------------+-----------------+----------+----------+-------------+

Use collect_list() and collect_set() to gather grouped values into arrays:

from pyspark.sql.functions import collect_list, collect_set, concat_ws

# Collect employee names per department
dept_employees = df_detailed.groupBy("department").agg(
    collect_list("name").alias("all_employees"),
    collect_set("job_title").alias("unique_titles"),
    concat_ws(", ", collect_list("name")).alias("employee_names")
)

dept_employees.show(truncate=False)

Advanced GroupBy Techniques

Pivot operations transform grouped data from long to wide format, useful for creating cross-tabulation reports:

# Create pivot table: departments as rows, job titles as columns
pivot_salary = df_detailed.groupBy("department").pivot("job_title").agg(avg("salary"))
pivot_salary.show()

Output:

+-----------+-------+-------+
| department| Junior| Senior|
+-----------+-------+-------+
|Engineering|72000.0|96500.0|
|         HR|62000.0|78000.0|
|      Sales|71500.0|85000.0|
+-----------+-------+-------+

For many custom aggregations, no UDF is needed at all: compose built-in functions into an expression inside agg():

from pyspark.sql.functions import expr

# Calculate salary range (max - min) per department
salary_range = df_detailed.groupBy("department").agg(
    (max("salary") - min("salary")).alias("salary_range")
)

salary_range.show()

For logic that built-in functions cannot express, use pandas grouped-aggregate UDFs (the type-hint style shown here requires PySpark 3.0+; the older PandasUDFType.GROUPED_AGG form dates to 2.4):

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd

# The Series -> scalar type hints mark this as a grouped-aggregate UDF
@pandas_udf(DoubleType())
def coefficient_of_variation(salary: pd.Series) -> float:
    return (salary.std() / salary.mean()) * 100

cv_by_dept = df_detailed.groupBy("department").agg(
    coefficient_of_variation("salary").alias("salary_cv")
)

cv_by_dept.show()

Performance Optimization Tips

Repartitioning by the grouping key before groupBy helps most when the same DataFrame feeds several aggregations on that key: the shuffle happens once and the resulting partitioning is reused. For a single groupBy it adds little, since the aggregation already shuffles by the key:

# Repartition by grouping key before aggregation
optimized_groupby = df_detailed \
    .repartition("department") \
    .groupBy("department") \
    .agg(avg("salary"))

# Check partition distribution
print(f"Partitions: {optimized_groupby.rdd.getNumPartitions()}")

For skewed data where one group dominates, use salting:

from pyspark.sql.functions import rand, floor, concat, lit

# Add salt to skewed key
salted_df = df_detailed.withColumn(
    "salted_dept", 
    concat("department", lit("_"), floor(rand() * 10).cast("string"))
)

# Group by salted key, then aggregate again
result = salted_df.groupBy("salted_dept").agg(sum("salary").alias("partial_sum")) \
    .withColumn("department", expr("substring(salted_dept, 1, length(salted_dept)-2)")) \
    .groupBy("department").agg(sum("partial_sum").alias("total_salary"))

Enable adaptive query execution for automatic optimization:

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
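Note that skewJoin.enabled specifically targets skewed joins. For aggregations, the AQE knob that usually matters is post-shuffle partition coalescing, which is on by default when AQE is enabled in recent Spark releases but worth setting explicitly on older 3.x versions (a config fragment, assuming an existing `spark` session):

```python
# Merge small post-shuffle partitions after the aggregation (Spark 3.0+)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
```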

Common Pitfalls and Troubleshooting

Null values in grouping columns create a separate group. Handle them explicitly:

from pyspark.sql.functions import coalesce

# Data with nulls
data_with_nulls = [
    ("John", "Engineering", 95000),
    ("Sarah", None, 98000),
    ("Mike", "Sales", 75000),
    (None, "Engineering", 72000)
]

df_nulls = spark.createDataFrame(data_with_nulls, ["name", "department", "salary"])

# Replace nulls before grouping
cleaned_groupby = df_nulls \
    .withColumn("department", coalesce("department", lit("Unknown"))) \
    .groupBy("department") \
    .agg(avg("salary"))

cleaned_groupby.show()

Avoid collecting large grouped results to the driver:

# BAD: don't collect large grouped results to the driver
# large_groups = df.groupBy("key").agg(collect_list("value")).collect()

# GOOD: write to storage, or pull only a bounded sample
df.groupBy("key").agg(collect_list("value")) \
    .write.parquet("output/grouped_data")

# small_sample = df.groupBy("key").agg(collect_list("value")).take(10)

Monitor for out-of-memory (OOM) errors with large groups. Spark's aggregation spills to disk automatically, but shuffle parallelism and executor memory still need to match your data size:

spark.conf.set("spark.sql.shuffle.partitions", "200")  # Adjust based on data size

# Executor memory is a static setting: set it at launch, not via spark.conf.set
#   spark-submit --executor-memory 4g ...
#   or SparkSession.builder.config("spark.executor.memory", "4g")

GroupBy operations are powerful but require understanding of distributed computing principles. Always profile your queries using the Spark UI to identify bottlenecks, and remember that the best optimization is often reducing data volume before grouping through filtering or sampling.
