Apache Spark - Partitioning Strategies

Key Insights

  • Partition size directly impacts Spark job performance—aim for 128MB-1GB per partition to balance parallelism with task overhead, avoiding the common pitfalls of too many small partitions or too few large ones
  • Hash partitioning works for most distributed operations, but range partitioning prevents data skew in sorted operations while custom partitioners solve domain-specific problems like geographic or temporal data distribution
  • Repartition and coalesce serve different purposes: use repartition() for increasing partitions with full shuffle, coalesce() for reducing partitions without shuffle, and understand when each operation pays for itself in execution time

Understanding Spark Partitioning Fundamentals

Partitioning determines how Spark distributes data across the cluster. Each partition represents a logical chunk of data that a single executor core processes independently. Poor partitioning creates bottlenecks—too few partitions leave cores idle, too many partitions create excessive task scheduling overhead.

Spark creates initial partitions based on data source characteristics. Reading from HDFS uses one partition per block (typically 128MB). JDBC sources use connection parameters to determine partition count. In-memory collections use spark.default.parallelism, which defaults to the total number of cores across all executor nodes.
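For the JDBC case in particular, the partition-controlling options look roughly like this (the URL, table name, and bounds below are illustrative, not from the original):

```scala
// Hypothetical JDBC source: Spark issues numPartitions parallel range
// queries on partitionColumn between lowerBound and upperBound
val orders = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/shop")  // illustrative URL
  .option("dbtable", "orders")                           // illustrative table
  .option("partitionColumn", "order_id")  // must be numeric, date, or timestamp
  .option("lowerBound", "1")
  .option("upperBound", "10000000")
  .option("numPartitions", "16")          // 16 parallel range reads
  .load()
```

Without the last four options, Spark reads the whole table through a single connection into one partition, which is a common silent bottleneck.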

// Check current partition count
val df = spark.read.parquet("s3://bucket/data")
println(s"Partitions: ${df.rdd.getNumPartitions}")

// View partition distribution
df.rdd.glom().map(_.length).collect().foreach(println)

// Examine data distribution across partitions
df.rdd.mapPartitionsWithIndex { (idx, iter) => 
  Iterator((idx, iter.size))
}.collect().foreach { case (idx, size) => 
  println(s"Partition $idx: $size records")
}

Hash Partitioning for Distributed Operations

Hash partitioning distributes data using a hash function on specified columns. This ensures records with identical keys land in the same partition—critical for joins, groupBy, and aggregations. Spark applies hash partitioning automatically during shuffle operations.

import org.apache.spark.sql.functions._

val transactions = spark.read.parquet("transactions/")

// Explicit hash partitioning by customer_id
val partitioned = transactions.repartition(200, col("customer_id"))

// Multiple column partitioning
val multiKey = transactions.repartition(
  200, 
  col("customer_id"), 
  col("transaction_date")
)

// Hash partitioning before expensive operations
val customerStats = transactions
  .repartition(col("customer_id"))
  .groupBy("customer_id")
  .agg(
    sum("amount").as("total_spent"),
    count("*").as("transaction_count")
  )

Hash partitioning works well when key distribution is relatively uniform. With skewed data where certain keys appear far more frequently, some partitions become disproportionately large, creating stragglers that delay entire jobs.
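At the RDD level, the assignment rule behind this is just a non-negative modulo of the key's hashCode (DataFrame shuffles use a Murmur3 hash instead, but the same modulo principle applies). A minimal sketch:

```scala
// Sketch of the RDD HashPartitioner assignment rule:
// a non-negative modulo of the key's hashCode
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

def partitionFor(key: Any, numPartitions: Int): Int =
  if (key == null) 0 else nonNegativeMod(key.hashCode, numPartitions)

// Identical keys always land in the same partition,
// which is what makes shuffle-free joins on co-partitioned data possible
val p1 = partitionFor("customer-42", 200)
val p2 = partitionFor("customer-42", 200)
// p1 == p2
```

This also makes the skew failure mode obvious: every occurrence of a hot key maps to the same partition, no matter how many partitions exist.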

Range Partitioning for Sorted Data

Range partitioning divides data based on value ranges rather than hash values. This maintains sort order across partitions and prevents skew in sorted operations. Use range partitioning when working with time-series data, sequential IDs, or when downstream operations require sorted input.

val events = spark.read.parquet("events/")

// Range partition by timestamp
val rangePartitioned = events.repartitionByRange(100, col("timestamp"))

// Multiple columns with sort direction
val multiRange = events.repartitionByRange(
  100,
  col("year").asc,
  col("month").asc,
  col("day").asc
)

// Range partitioning + sortWithinPartitions yields globally
// ordered output files without a second full shuffle
rangePartitioned
  .sortWithinPartitions("timestamp", "event_id")
  .write
  .parquet("output/")

Range partitioning requires sampling data to determine range boundaries, adding upfront cost. However, this pays off when multiple operations benefit from pre-sorted data or when writing sorted output files.

Custom Partitioners for Domain Logic

Custom partitioners implement application-specific distribution logic. Create custom partitioners when data has natural boundaries that hash and range partitioning don’t capture—geographic regions, business hierarchies, or time windows.

import org.apache.spark.Partitioner

class GeoPartitioner(partitions: Int) extends Partitioner {
  override def numPartitions: Int = partitions
  
  override def getPartition(key: Any): Int = {
    val location = key.asInstanceOf[String]
    location match {
      case loc if loc.startsWith("US-") => 0
      case loc if loc.startsWith("EU-") => 1
      case loc if loc.startsWith("ASIA-") => 2
      case _ => 3
    }
  }
}

// Apply custom partitioner (requires RDD API)
val locationData = spark.read.parquet("locations/")
val rdd = locationData.rdd.map(row => 
  (row.getAs[String]("location"), row)
)

val customPartitioned = rdd.partitionBy(new GeoPartitioner(4))

// Convert back to DataFrame if needed
val df = spark.createDataFrame(
  customPartitioned.map(_._2),
  locationData.schema
)

Custom partitioners work exclusively with RDDs using key-value pairs. When working with DataFrames, implement partitioning logic through UDFs and repartition operations.
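One way to approximate the GeoPartitioner above while staying in the DataFrame API (reusing the locationData frame and region rules from the sketch above) is to derive a region column and repartition on it:

```scala
import org.apache.spark.sql.functions.{col, when}

// Derive a coarse region key with the same rules as GeoPartitioner,
// then hash-partition on it (locationData as read above)
val withRegion = locationData.withColumn("region",
  when(col("location").startsWith("US-"), "US")
    .when(col("location").startsWith("EU-"), "EU")
    .when(col("location").startsWith("ASIA-"), "ASIA")
    .otherwise("OTHER"))

val regionPartitioned = withRegion.repartition(4, col("region"))
```

Unlike the RDD partitioner, this relies on hashing, so two regions can collide into the same partition; it trades exact placement for avoiding the RDD round-trip.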

Repartition vs Coalesce

Understanding when to use repartition() versus coalesce() prevents unnecessary shuffles. Repartition performs a full shuffle to create the target partition count—expensive, but necessary when increasing partitions or redistributing data evenly. Coalesce reduces the partition count by merging existing partitions in place, without moving data across the network; it cannot increase the count.

val largeDataset = spark.read.parquet("large_data/")  // e.g. 5000 partitions from source

// Reduce partitions without shuffle (coalesce)
val reduced = largeDataset.coalesce(500)

// Increasing partitions requires a shuffle (repartition)
val increased = largeDataset.repartition(10000)

// Coalesce after filtering to reduce partition count
val filtered = largeDataset
  .filter(col("status") === "active")  // Reduces data by 90%
  .coalesce(500)  // Reduce partitions proportionally

// Write with optimal file count
filtered.write
  .option("maxRecordsPerFile", 1000000)
  .parquet("output/")

Coalesce creates uneven partitions when reducing partition count significantly. Some partitions may contain data from multiple source partitions while others remain untouched. For even distribution after coalesce, use repartition despite the shuffle cost.
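The imbalance is easy to observe with the same glom-based inspection used earlier, applied to the coalesced frame (largeDataset as defined above):

```scala
val coalesced = largeDataset.coalesce(500)

// Per-partition record counts; a wide min/max spread indicates imbalance
val sizes = coalesced.rdd.glom().map(_.length).collect()
println(s"min: ${sizes.min}, max: ${sizes.max}, avg: ${sizes.sum / sizes.length}")
```

If max is several times avg, the largest partitions will dominate stage runtime and repartition is likely worth its shuffle.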

Adaptive Query Execution

Spark 3.0+ includes Adaptive Query Execution (AQE) that dynamically optimizes partition counts during execution. AQE coalesces small partitions, splits large partitions, and optimizes skewed joins automatically.

// Enable AQE (enabled by default in Spark 3.2+)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "200")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")

// AQE handles partition optimization automatically
spark.read.parquet("input/")
  .filter(col("date") >= "2024-01-01")
  .groupBy("category")
  .agg(sum("amount"))
  .write
  .parquet("output/")

AQE reduces the need for manual partition tuning but doesn’t eliminate it. Initial partition counts still matter for early stages, and AQE can’t fix fundamentally poor partitioning strategies.
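The skew-join side of AQE has its own switches; the thresholds below are the defaults, shown explicitly:

```scala
// AQE skew-join handling: split oversized partitions at join time
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
// A partition counts as skewed if it is both skewedPartitionFactor times
// larger than the median and above skewedPartitionThresholdInBytes
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
```

When AQE splits a skewed partition this way, manual salting (covered below) is often unnecessary for plain joins—salting remains useful for aggregations and for Spark versions before 3.0.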

Partition Size Guidelines

Target 128MB-1GB per partition for optimal performance. Smaller partitions increase task scheduling overhead—with thousands of tiny tasks, scheduling can cost more time than the processing itself. Larger partitions reduce parallelism and increase memory pressure, risking spills to disk or out-of-memory errors.

// Calculate optimal partition count
val dataSize = 500 * 1024 * 1024 * 1024L  // 500GB
val targetPartitionSize = 256 * 1024 * 1024L  // 256MB
val optimalPartitions = (dataSize / targetPartitionSize).toInt

val optimized = spark.read.parquet("data/")
  .repartition(optimalPartitions)

// Approximate partition sizes (string-length proxy, not exact on-disk bytes)
optimized.rdd.mapPartitions { iter =>
  val size = iter.map(_.toString.getBytes.length.toLong).sum
  Iterator(size)
}.collect().foreach(bytes => println(s"${bytes / (1024 * 1024)} MB"))

Adjust partition size based on available memory and data characteristics. Wide transformations with many columns need larger executor memory and benefit from fewer, larger partitions. Narrow transformations with selective columns handle more, smaller partitions efficiently.
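A complementary rule of thumb sizes the partition count to the cluster as well as the data: roughly 2-4 task waves per core keeps every core busy while tolerating stragglers. A sketch combining both estimates (the cluster shape and data size here are illustrative):

```scala
// Illustrative cluster shape
val numExecutors = 10
val coresPerExecutor = 4
val tasksPerCore = 3  // rule of thumb: 2-4 waves of tasks per core

val byCores = numExecutors * coresPerExecutor * tasksPerCore  // 120

// Size-based estimate, as in the calculation above
val dataSizeBytes = 100L * 1024 * 1024 * 1024   // 100GB, illustrative
val targetPartitionBytes = 256L * 1024 * 1024   // 256MB
val bySize = (dataSizeBytes / targetPartitionBytes).toInt  // 400

// Take the larger estimate so cores never sit idle
// and partitions never exceed the target size
val partitions = math.max(byCores, bySize)
```

Here the size-based estimate wins (400 > 120); on a small dataset with a large cluster, the core-based floor wins instead.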

Handling Data Skew

Data skew occurs when certain partition keys contain disproportionate data volumes. One oversized partition becomes a bottleneck while other executors sit idle. Detect skew through partition size analysis and Spark UI stage metrics.

// Detect skew: inspect per-key record counts
val skewedData = spark.read.parquet("skewed/")

skewedData
  .groupBy("key")
  .count()
  .orderBy(col("count").desc)
  .show()

// Salting technique for skewed joins
import org.apache.spark.sql.functions._

val saltCount = 10
val leftDF = skewedData
  .withColumn("salt", (rand() * saltCount).cast("int"))
  .withColumn("salted_key", concat(col("key"), lit("_"), col("salt")))

val rightDF = otherData  // the other join side, loaded elsewhere
  .withColumn("salt", explode(array((0 until saltCount).map(lit): _*)))
  .withColumn("salted_key", concat(col("key"), lit("_"), col("salt")))

val joined = leftDF.join(rightDF, Seq("key", "salted_key"))
  .drop("salt", "salted_key")

Salting artificially increases cardinality of skewed keys by appending random values, distributing hot keys across multiple partitions. This technique adds complexity but solves severe skew problems that would otherwise make jobs impossible to complete.
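When the other side of a skewed join is small enough to fit in executor memory, broadcasting it sidesteps the shuffle—and therefore the skew—entirely, with no salting needed. A sketch (the dimension-table path is hypothetical):

```scala
import org.apache.spark.sql.functions.broadcast

val facts = spark.read.parquet("skewed/")  // the skewed side
val smallDim = spark.read.parquet("dim/")  // hypothetical small lookup table

// Every executor receives a full copy of smallDim, so the skewed
// side is joined in place with no shuffle of facts at all
val joined = facts.join(broadcast(smallDim), Seq("key"))
```

Prefer this whenever the small side is in the tens-of-megabytes range; fall back to salting only when both sides are too large to broadcast.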
