How to Repartition a DataFrame in PySpark
Key Insights
- Use repartition() when you need to increase partitions or redistribute data evenly (triggers a full shuffle), and coalesce() when decreasing partitions (avoids a shuffle but can create uneven distribution)
- Column-based repartitioning with df.repartition(n, "column") co-locates related data on the same partitions, dramatically improving join and aggregation performance
- Target partition sizes of 128-256MB and aim for 2-4x your cluster’s total cores; too few partitions underutilize parallelism, too many create excessive scheduling overhead
Introduction to DataFrame Partitioning
Partitions are the fundamental unit of parallelism in Spark. When you create a DataFrame, Spark splits the data across multiple partitions, and each partition gets processed independently by a separate task on your cluster. This is how Spark achieves distributed processing—more partitions mean more potential parallelism.
The problem is that Spark’s default partitioning often doesn’t match your workload. A single shuffle on a small dataset can leave you with 200 partitions (the default value of spark.sql.shuffle.partitions), wasting resources on task scheduling overhead. Or you might filter a large dataset down to a fraction of its size, leaving you with mostly empty partitions. Repartitioning gives you control over this distribution.
You’ll need to repartition when: preparing data for writes (to control output file count), optimizing joins between large tables, recovering from skewed data after transformations, or right-sizing partitions after significant filtering.
Understanding Your Current Partition State
Before changing partitions, understand what you’re working with. The simplest check is getNumPartitions():
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("PartitionDemo").getOrCreate()
# Create a sample DataFrame
df = spark.range(0, 1000000)
# Check partition count
print(f"Number of partitions: {df.rdd.getNumPartitions()}")
This tells you the count, but not the distribution. Uneven partitions cause stragglers—tasks that take much longer than others because they’re processing more data. To inspect partition sizes:
from pyspark.sql.functions import spark_partition_id, count
# See row distribution across partitions
df.withColumn("partition_id", spark_partition_id()) \
    .groupBy("partition_id") \
    .agg(count("*").alias("row_count")) \
    .orderBy("partition_id") \
    .show()
This query adds a partition ID column, then aggregates to show rows per partition. If you see wildly different counts—say, one partition with 900,000 rows and others with 10,000—you have skew that needs addressing.
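Once you have those per-partition counts, a plain-Python helper can turn them into a single skew number. This is a sketch of my own, not a Spark API; the function name and the idea of comparing the largest partition to the mean are assumptions for illustration:

```python
def skew_ratio(row_counts):
    """Ratio of the largest partition to the mean partition size.

    1.0 means perfectly even; the further above 1.0, the worse the skew.
    """
    mean = sum(row_counts) / len(row_counts)
    return max(row_counts) / mean

# The skewed layout described above: one hot partition dwarfs the rest
counts = [900_000] + [10_000] * 9
print(round(skew_ratio(counts), 1))  # 9.1
```

Feed it the `row_count` column collected from the query above; anything well above 1.0 signals stragglers in the making.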
For a quick size estimate, you can also use:
# Rough size check: sums the length of each row's string form,
# which is only a crude proxy for true in-memory size
df.cache()
df.count()  # Force materialization, then check the Storage tab in the Spark UI
print(f"Approximate size: {df.rdd.map(lambda x: len(str(x))).sum()} bytes")
repartition() vs coalesce(): Key Differences
Spark provides two methods for changing partition counts, and choosing the wrong one costs you either unnecessary shuffle overhead or an uneven distribution.
repartition(n) performs a full shuffle. Every record gets redistributed across all new partitions using a round-robin or hash-based approach. This is expensive but guarantees even distribution.
coalesce(n) uses a narrow dependency—it combines existing partitions without shuffling data across the network. This is cheaper but only works for reducing partitions, and it can create uneven results.
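The distribution difference is easy to see with a plain-Python analogy. This is a sketch of the mechanics under simplifying assumptions (Spark's actual coalesce groups partitions by locality, not strictly by adjacency), but it captures why coalesce inherits unevenness while a round-robin repartition rebalances everything:

```python
def coalesce_counts(partition_sizes, n):
    """Merge groups of adjacent partitions into n buckets, like a narrow coalesce."""
    buckets = [0] * n
    per_bucket = len(partition_sizes) / n
    for i, size in enumerate(partition_sizes):
        buckets[int(i / per_bucket)] += size  # no data moves between rows, only whole partitions merge
    return buckets

def repartition_counts(partition_sizes, n):
    """Round-robin redistribution: every bucket ends up near-equal."""
    total = sum(partition_sizes)
    base, extra = divmod(total, n)
    return [base + (1 if i < extra else 0) for i in range(n)]

sizes = [100, 100, 100, 5000, 100, 100]  # one oversized partition
print(coalesce_counts(sizes, 2))     # [300, 5200] - the unevenness survives the merge
print(repartition_counts(sizes, 2))  # [2750, 2750] - fully rebalanced
```

The oversized partition dominates whichever bucket it lands in after a coalesce; only the shuffle in repartition can split it up.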
# Starting point: 100 partitions
df = spark.range(0, 10000000).repartition(100)
print(f"Starting partitions: {df.rdd.getNumPartitions()}") # 100
# Reduce to 10 partitions - two approaches
df_repartitioned = df.repartition(10)
df_coalesced = df.coalesce(10)
print(f"After repartition(10): {df_repartitioned.rdd.getNumPartitions()}") # 10
print(f"After coalesce(10): {df_coalesced.rdd.getNumPartitions()}") # 10
# Check distribution difference
def show_distribution(dataframe, name):
    print(f"\n{name} distribution:")
    dataframe.withColumn("partition_id", spark_partition_id()) \
        .groupBy("partition_id") \
        .count() \
        .orderBy("partition_id") \
        .show()
show_distribution(df_repartitioned, "repartition()")
show_distribution(df_coalesced, "coalesce()")
The repartition() output shows roughly equal counts per partition. The coalesce() output often shows variation because it’s merging adjacent partitions from the original distribution.
Use repartition() when:
- Increasing partition count (coalesce can’t do this)
- You need guaranteed even distribution
- Preparing for operations sensitive to skew
Use coalesce() when:
- Decreasing partitions and distribution doesn’t matter much
- Writing output files where slight size variation is acceptable
- You want to avoid shuffle costs
Repartitioning by Column
Column-based repartitioning is where the real optimization happens. Instead of random distribution, Spark hashes the specified column values and assigns rows with the same hash to the same partition.
from pyspark.sql.functions import col
# Sample data with user transactions
transactions = spark.createDataFrame([
(1, "purchase", 100.0),
(2, "purchase", 50.0),
(1, "refund", 25.0),
(3, "purchase", 200.0),
(2, "purchase", 75.0),
(1, "purchase", 150.0),
], ["user_id", "type", "amount"])
# Repartition by user_id
df_by_user = transactions.repartition(4, "user_id")
# All records for the same user_id are now on the same partition
df_by_user.withColumn("partition_id", spark_partition_id()) \
    .orderBy("user_id", "type") \
    .show()
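The assignment rule behind this is simple enough to sketch in plain Python. Spark actually uses Murmur3 hashing, so the exact bucket ids differ from what Python's built-in hash() produces; the point is only that equal keys always map to the same bucket:

```python
def assign_partition(key, num_partitions):
    """Hash-partitioning rule: rows with equal keys land in the same bucket.

    Plain-Python sketch - Spark uses Murmur3, so real bucket ids differ.
    """
    return hash(key) % num_partitions

# Every transaction for user_id 1 maps to one and the same partition
rows = [(1, "purchase"), (1, "refund"), (1, "purchase")]
buckets = {assign_partition(user_id, 4) for user_id, _ in rows}
print(len(buckets))  # 1
```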
This co-location is powerful. When you join two DataFrames repartitioned on the same key with the same partition count, Spark can reuse the existing partitioning instead of inserting another shuffle, because matching keys already live on the same executors.
# Two DataFrames to join
users = spark.createDataFrame([
(1, "Alice"), (2, "Bob"), (3, "Charlie")
], ["user_id", "name"])
# Repartition both on join key
users_partitioned = users.repartition(8, "user_id")
transactions_partitioned = transactions.repartition(8, "user_id")
# Join is now optimized - data is co-located
result = transactions_partitioned.join(users_partitioned, "user_id")
result.explain() # Shows optimized join plan
You can also repartition by multiple columns:
# Partition by composite key
df.repartition(16, "year", "month", "day")
Choosing the Right Number of Partitions
The magic number depends on your data size and cluster resources. Here are the guidelines:
Target partition size: 128-256MB. Smaller partitions increase scheduling overhead; larger ones reduce parallelism and can cause memory pressure.
Partition count: 2-4x total cores. If you have 10 executors with 4 cores each (40 cores), aim for 80-160 partitions. This ensures all cores stay busy even if some tasks finish faster.
def calculate_optimal_partitions(data_size_gb, target_partition_mb=200, cluster_cores=40):
    """Calculate recommended partition count."""
    data_size_mb = data_size_gb * 1024
    # Based on data size
    size_based = int(data_size_mb / target_partition_mb)
    # Based on parallelism (2-4x cores)
    parallelism_min = cluster_cores * 2
    parallelism_max = cluster_cores * 4
    # Respect the parallelism floor; cap at the ceiling only when
    # the data itself doesn't demand more partitions
    recommended = max(size_based, parallelism_min)
    if size_based < parallelism_max:
        recommended = min(recommended, parallelism_max)
    print(f"Data size: {data_size_gb}GB ({data_size_mb}MB)")
    print(f"Size-based partitions: {size_based}")
    print(f"Parallelism range: {parallelism_min}-{parallelism_max}")
    print(f"Recommended: {recommended}")
    return recommended
# Example: 50GB dataset, 40 cores
optimal = calculate_optimal_partitions(50, cluster_cores=40)
Common pitfalls:
- Too few partitions: Underutilized cluster, potential OOM errors, long task times
- Too many partitions: Excessive scheduling overhead, small files on write, driver memory pressure from tracking tasks
Performance Considerations and Best Practices
Every repartition() triggers a shuffle—data serialization, network transfer, and deserialization. This is expensive. Use explain() to understand the impact:
df = spark.range(0, 1000000)
# Without repartition
df.groupBy(spark_partition_id()).count().explain()
# With repartition
df.repartition(10).groupBy(spark_partition_id()).count().explain()
The second plan shows an additional Exchange (shuffle) stage. On large datasets, this can add minutes to your job.
Best practices:
- Repartition early, not often. If you need specific partitioning, do it once near the start of your pipeline.
- Cache after expensive repartitions. If you’ll reuse the repartitioned DataFrame, cache it to avoid repeating the shuffle.
df_optimized = df.repartition(100, "key_column").cache()
df_optimized.count()  # Materialize cache
# Now multiple operations benefit from the partitioning
result1 = df_optimized.groupBy("key_column").sum()
result2 = df_optimized.filter(col("value") > 100)
- Let Spark handle it when possible. Operations like groupBy and join trigger their own shuffles. Adding a manual repartition beforehand often just adds an extra shuffle.
- Check AQE settings. Spark 3.0+ has Adaptive Query Execution, which can automatically coalesce shuffle partitions. Enable it with spark.sql.adaptive.enabled=true.
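If you prefer cluster-wide configuration over per-session code, the same switch lives in spark-defaults.conf. Shown here with its partition-coalescing companion setting; both are real Spark 3.x configuration keys, and enabling them together is a suggested starting point rather than a universal prescription:

```
# spark-defaults.conf - enable Adaptive Query Execution (Spark 3.0+)
spark.sql.adaptive.enabled                     true
# Let AQE merge undersized shuffle partitions after each stage
spark.sql.adaptive.coalescePartitions.enabled  true
```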
Common Use Cases and Patterns
Before writing files: Control output file count to avoid the small-files problem.
# Write exactly 10 parquet files
df.repartition(10).write.mode("overwrite").parquet("/output/path")
# Write partitioned by date; each date hashes to a single partition,
# so expect roughly one file per date directory
df.repartition(4, "date") \
    .write \
    .partitionBy("date") \
    .mode("overwrite") \
    .parquet("/output/partitioned")
After heavy filtering: Reclaim empty partitions.
# Original: 200 partitions with 10M rows
df_filtered = df.filter(col("status") == "active") # Now only 100K rows
# Coalesce to appropriate size
df_filtered.coalesce(10).write.parquet("/output/filtered")
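Picking that coalesce target can follow the same size rule from earlier. A plain-Python sketch: the 128 MB target is the guideline from above, and the size estimate is an input you would supply yourself (from the Spark UI or a rough rows-times-bytes calculation):

```python
import math

def partitions_after_filter(estimated_mb, target_partition_mb=128):
    """Partition count for a shrunken DataFrame, never below 1."""
    return max(1, math.ceil(estimated_mb / target_partition_mb))

# 100K rows at roughly 1 KB each is about 100 MB -> one partition suffices
print(partitions_after_filter(100))   # 1
print(partitions_after_filter(1280))  # 10
```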
Pre-join optimization: Co-locate data for large-large joins.
# Both tables repartitioned on join key
table_a = large_df_a.repartition(200, "join_key")
table_b = large_df_b.repartition(200, "join_key")
# Join benefits from co-location
result = table_a.join(table_b, "join_key")
Repartitioning is a powerful lever for Spark performance, but it’s not free. Measure before and after, understand your data distribution, and let the guidelines—not guesswork—drive your partition counts.