How to Partition Data in PySpark

Key Insights

  • Partitioning controls how data is distributed across your cluster—get it wrong, and you’ll waste resources on idle executors or overwhelm individual nodes with skewed data.
  • Use repartition() when you need to increase partitions or distribute by specific columns; use coalesce() when reducing partitions to avoid expensive shuffles.
  • Write-time partitioning with partitionBy() is fundamentally different from in-memory partitioning—it creates directory structures that enable partition pruning during reads.

Introduction to Data Partitioning

Partitioning is how Spark divides your data into chunks that can be processed in parallel across your cluster. Each partition is a unit of work that gets assigned to a single task, which runs on a single core. If you have 100 partitions and 10 executor cores, Spark processes 10 partitions simultaneously, cycling through until all 100 are complete.

This relationship between partitions and parallelism is the foundation of Spark performance tuning. Too few partitions means cores sit idle while others do all the work. Too many partitions creates overhead from task scheduling and small file problems. Skewed partitions—where one partition has significantly more data than others—cause stragglers that delay your entire job.

Understanding partitioning isn’t optional if you’re working with PySpark at scale. It’s the difference between a job that runs in minutes and one that runs for hours.

Understanding Default Partitioning Behavior

Spark makes partitioning decisions automatically, but those defaults often aren’t optimal for your specific workload.

When reading files, Spark typically creates one partition per file block. For HDFS, that’s usually 128MB blocks. A 1GB file becomes roughly 8 partitions. When reading from sources like JDBC, the default might be a single partition unless you configure parallelism explicitly.

Two configuration parameters control default behavior:

  • spark.default.parallelism: Controls the default number of partitions for RDD transformations and parallelize operations. Defaults to the total number of cores across all executors.
  • spark.sql.shuffle.partitions: Controls the number of partitions after shuffle operations like groupBy(), join(), or distinct(). The default is 200, regardless of your cluster size or data volume.

That 200-partition default is notoriously problematic. Processing 10GB of data? 200 partitions might be reasonable. Processing 10TB? You’ll have 50GB partitions that cause memory issues. Processing 100MB? You’ll have 200 tiny partitions with more scheduling overhead than actual computation.

Check your current partition count to understand what you’re working with:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PartitionDemo").getOrCreate()

df = spark.read.parquet("/data/events/")

# Check partition count
num_partitions = df.rdd.getNumPartitions()
print(f"Current partitions: {num_partitions}")

# Check approximate partition sizes
total_rows = df.count()
print(f"Approximate rows per partition: {total_rows // num_partitions}")

For more detailed partition analysis:

from pyspark.sql.functions import spark_partition_id, count

# See distribution across partitions
df.withColumn("partition_id", spark_partition_id()) \
  .groupBy("partition_id") \
  .agg(count("*").alias("row_count")) \
  .orderBy("partition_id") \
  .show()

This reveals skew immediately. If partition 0 has 10 million rows and partition 1 has 100, you have a problem.

Repartitioning with repartition()

The repartition() method performs a full shuffle to redistribute data across a specified number of partitions. It’s expensive but sometimes necessary.

Use repartition() when you need to:

  • Increase the number of partitions (you can’t do this with coalesce())
  • Distribute data evenly across partitions
  • Partition by specific columns for downstream operations

Basic numeric repartitioning:

# Increase partitions for more parallelism
df_repartitioned = df.repartition(100)

# Verify the change
print(f"New partition count: {df_repartitioned.rdd.getNumPartitions()}")

Column-based repartitioning co-locates rows with the same key values in the same partition:

# Repartition by columns - all rows with same region/date land together
df_by_region = df.repartition("region", "date")

# Combine column-based with specific partition count
df_controlled = df.repartition(50, "region", "date")

Column-based repartitioning is powerful for optimizing joins and aggregations. When you join two DataFrames that were repartitioned on the same key with the same partition count, Spark can satisfy the join's distribution requirement without an additional shuffle.

# Optimize a join by pre-partitioning both sides
orders = orders.repartition(100, "customer_id")
customers = customers.repartition(100, "customer_id")

# This join now avoids an extra shuffle
result = orders.join(customers, "customer_id")

The cost of repartition() is a full shuffle—every record potentially moves across the network. Don’t call it repeatedly or unnecessarily. Plan your partition strategy upfront.

Optimizing with coalesce()

When you need fewer partitions, coalesce() is almost always the better choice. It reduces partitions without a full shuffle by combining existing partitions on the same executor.

# Reduce from 200 partitions to 10 without full shuffle
df_reduced = df.coalesce(10)

The key limitation: coalesce() can only reduce partitions, not increase them. Calling df.coalesce(1000) when you have 100 partitions is a no-op; you still have 100 partitions.

Common use cases for coalesce():

Before writing to reduce file count:

from pyspark.sql.functions import col

# After filtering, you might have many small partitions
filtered_df = large_df.filter(col("status") == "active")

# Consolidate before writing to avoid small file problem
filtered_df.coalesce(10).write.parquet("/output/active_users/")

After wide transformations that inflate partition count:

from pyspark.sql.functions import sum as spark_sum

# Aggregations use spark.sql.shuffle.partitions (default 200)
aggregated = df.groupBy("category").agg(spark_sum("amount"))

# If you know the result is small, coalesce before continuing
aggregated.coalesce(4).write.parquet("/output/category_totals/")

Choose coalesce() over repartition() when:

  • You’re reducing partition count
  • You don’t need even data distribution (some skew is acceptable)
  • You want to minimize shuffle overhead

Choose repartition() when:

  • You’re increasing partition count
  • You need even distribution to address skew
  • You’re partitioning by specific columns for join optimization

Partition-Aware Writing with partitionBy()

The partitionBy() method on DataFrameWriter is conceptually different from in-memory partitioning. It creates a directory structure on disk that enables partition pruning during reads.

# Write with directory-based partitioning
df.write \
  .partitionBy("year", "month") \
  .parquet("/data/events/")

This creates a structure like:

/data/events/
  year=2024/
    month=01/
      part-00000.parquet
      part-00001.parquet
    month=02/
      ...
  year=2025/
    ...

When you later read with a filter, Spark skips irrelevant directories entirely:

from pyspark.sql.functions import col

# Only reads from the year=2024/month=6 directory
spark.read.parquet("/data/events/") \
  .filter((col("year") == 2024) & (col("month") == 6))

Best practices for choosing partition columns:

  1. Choose low-cardinality columns. Partitioning by user_id with millions of users creates millions of directories. Partition by date, region, or category instead.

  2. Order columns by query patterns. Put the most frequently filtered column first. If you usually query by year then month, use .partitionBy("year", "month").

  3. Avoid over-partitioning. Each partition should contain meaningful data—ideally hundreds of megabytes, not kilobytes.

Combine in-memory and write-time partitioning for optimal output:

# Control both file count AND directory structure
df.repartition("year", "month") \
  .write \
  .partitionBy("year", "month") \
  .parquet("/data/events/")

This ensures one file per directory partition, avoiding the small file problem.

Common Partitioning Strategies and Pitfalls

The 128MB rule of thumb: Aim for partitions around 128MB of data. This balances parallelism with overhead. Calculate your target partition count:

# Estimate data size and calculate partitions
data_size_mb = 50000  # 50GB
target_partition_size_mb = 128
optimal_partitions = data_size_mb // target_partition_size_mb  # ~390 partitions

Handling data skew: When certain keys have disproportionate data, add a salt to distribute the load:

from pyspark.sql.functions import col, concat, lit, floor, rand

# Add random salt to spread hot keys across partitions
salted_df = df.withColumn(
    "salted_key", 
    concat(col("region"), lit("_"), floor(rand() * 10))
)

# Repartition on salted key
salted_df = salted_df.repartition(100, "salted_key")

Avoiding the small file problem: After filters or aggregations, always check if you need to consolidate:

from pyspark.sql.functions import col

# Bad: creates potentially thousands of tiny files
df.filter(col("rare_condition")).write.parquet("/output/")

# Good: consolidate first
df.filter(col("rare_condition")).coalesce(1).write.parquet("/output/")

Setting shuffle partitions dynamically:

# Let adaptive query execution pick shuffle partition counts at runtime (Spark 3.0+)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# Or calculate manually
estimated_shuffle_size_mb = 10000
spark.conf.set("spark.sql.shuffle.partitions", str(estimated_shuffle_size_mb // 128))

Summary and Best Practices

Use this checklist when making partitioning decisions:

  • Need more parallelism: repartition(n) with a higher n
  • Reducing partitions before a write: coalesce(n)
  • Optimizing joins on the same key: repartition(n, "key_col") on both DataFrames
  • Enabling partition pruning on reads: write.partitionBy("col")
  • Fixing data skew: salt keys before repartitioning
  • Default shuffle partitions too high/low: set spark.sql.shuffle.partitions

Always verify your partitioning with df.rdd.getNumPartitions() and check for skew with spark_partition_id(). The few minutes spent understanding your data distribution will save hours of debugging slow jobs.
