How to Drop Duplicates in PySpark


Key Insights

  • Use dropDuplicates() with a column subset for most deduplication tasks; unlike distinct(), it can deduplicate on key columns alone and is more explicit about intent
  • Window functions with row_number() give you precise control over which record to keep when duplicates have different values in non-key columns
  • Deduplication at scale is shuffle-heavy; filter rows early, and take advantage of data that is already partitioned by the deduplication key so dropDuplicates() can avoid an extra shuffle

Duplicate data is the silent killer of data pipelines. It inflates metrics, breaks joins, and corrupts downstream analytics. In distributed systems like PySpark, duplicates multiply fast—network retries, failed task re-executions, and overlapping data loads all contribute to the problem. Worse, deduplication at scale isn’t trivial. You’re shuffling potentially billions of rows across a cluster to compare them.

This guide covers every deduplication technique you’ll need in PySpark, from simple one-liners to advanced window function patterns. I’ll show you when to use each approach and how to avoid the performance pitfalls that trip up most engineers.

Understanding Duplicates in DataFrames

Before removing duplicates, you need to understand what kind you’re dealing with. There are two categories:

Exact duplicates are rows where every column value matches. These typically come from accidental double-writes or idempotency failures in your ingestion layer.

Partial duplicates share the same key columns but differ in other fields. These are trickier—you need to decide which version to keep. Common scenarios include multiple updates to the same record arriving out of order, or the same event being recorded with slightly different metadata.

Let’s create a sample DataFrame that demonstrates both scenarios:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
from datetime import datetime

spark = SparkSession.builder.appName("DeduplicationDemo").getOrCreate()

data = [
    (1, "alice@example.com", "Alice Smith", datetime(2024, 1, 15, 10, 30)),
    (1, "alice@example.com", "Alice Smith", datetime(2024, 1, 15, 10, 30)),  # Exact duplicate
    (2, "bob@example.com", "Bob Jones", datetime(2024, 1, 14, 9, 0)),
    (2, "bob@example.com", "Robert Jones", datetime(2024, 1, 16, 14, 20)),   # Partial duplicate (name changed)
    (3, "carol@example.com", "Carol White", datetime(2024, 1, 10, 8, 0)),
    (3, "carol@example.com", "Carol White", datetime(2024, 1, 12, 11, 45)),  # Partial duplicate (different timestamp)
    (4, "dave@example.com", "Dave Brown", datetime(2024, 1, 13, 16, 30)),
]

schema = StructType([
    StructField("user_id", IntegerType(), False),
    StructField("email", StringType(), False),
    StructField("name", StringType(), True),
    StructField("updated_at", TimestampType(), True),
])

df = spark.createDataFrame(data, schema)
df.show(truncate=False)
+-------+-----------------+------------+-------------------+
|user_id|email            |name        |updated_at         |
+-------+-----------------+------------+-------------------+
|1      |alice@example.com|Alice Smith |2024-01-15 10:30:00|
|1      |alice@example.com|Alice Smith |2024-01-15 10:30:00|
|2      |bob@example.com  |Bob Jones   |2024-01-14 09:00:00|
|2      |bob@example.com  |Robert Jones|2024-01-16 14:20:00|
|3      |carol@example.com|Carol White |2024-01-10 08:00:00|
|3      |carol@example.com|Carol White |2024-01-12 11:45:00|
|4      |dave@example.com |Dave Brown  |2024-01-13 16:30:00|
+-------+-----------------+------------+-------------------+

User 1 has an exact duplicate. Users 2 and 3 have partial duplicates—same key, different values. This is the messy reality of production data.

Using the dropDuplicates() Method

The dropDuplicates() method is your primary tool for removing exact duplicate rows. When called without arguments, it compares all columns and keeps only unique rows:

print(f"Original count: {df.count()}")

deduped_df = df.dropDuplicates()

print(f"After dropDuplicates(): {deduped_df.count()}")
deduped_df.show(truncate=False)
Original count: 7
After dropDuplicates(): 6
+-------+-----------------+------------+-------------------+
|user_id|email            |name        |updated_at         |
+-------+-----------------+------------+-------------------+
|1      |alice@example.com|Alice Smith |2024-01-15 10:30:00|
|2      |bob@example.com  |Bob Jones   |2024-01-14 09:00:00|
|2      |bob@example.com  |Robert Jones|2024-01-16 14:20:00|
|3      |carol@example.com|Carol White |2024-01-10 08:00:00|
|3      |carol@example.com|Carol White |2024-01-12 11:45:00|
|4      |dave@example.com |Dave Brown  |2024-01-13 16:30:00|
+-------+-----------------+------------+-------------------+

Only Alice’s exact duplicate was removed. Bob and Carol’s partial duplicates remain because they differ in at least one column. This is correct behavior—dropDuplicates() without arguments only removes rows that are identical across every field.

Dropping Duplicates on Specific Columns

Most real-world deduplication requires keeping one record per key, not per unique row combination. Pass a list of columns to dropDuplicates() to deduplicate based on those columns only:

deduped_by_user = df.dropDuplicates(["user_id"])
deduped_by_user.show(truncate=False)
+-------+-----------------+------------+-------------------+
|user_id|email            |name        |updated_at         |
+-------+-----------------+------------+-------------------+
|1      |alice@example.com|Alice Smith |2024-01-15 10:30:00|
|2      |bob@example.com  |Bob Jones   |2024-01-14 09:00:00|
|3      |carol@example.com|Carol White |2024-01-10 08:00:00|
|4      |dave@example.com |Dave Brown  |2024-01-13 16:30:00|
+-------+-----------------+------------+-------------------+

Now we have exactly one row per user_id. But there’s a critical caveat: which row gets kept is non-deterministic. PySpark doesn’t guarantee it will keep the first row it encounters, the last row, or any particular row. The result depends on data partitioning and execution order.

For Bob, we got the older record. For Carol, we also got the older one. But run this on a different cluster configuration, and you might get different results. If you need deterministic behavior—like always keeping the most recent record—you need window functions.

You can also deduplicate on multiple columns for composite keys:

# Deduplicate on email domain and name combination
from pyspark.sql.functions import split

df_with_domain = df.withColumn("domain", split("email", "@")[1])
deduped_composite = df_with_domain.dropDuplicates(["domain", "name"])
deduped_composite.show(truncate=False)

Using distinct() vs. dropDuplicates()

PySpark offers two methods that seem similar: distinct() and dropDuplicates(). Here’s the difference:

# These produce identical results for full-row deduplication
distinct_df = df.distinct()
drop_dup_df = df.dropDuplicates()

print(f"distinct() count: {distinct_df.count()}")
print(f"dropDuplicates() count: {drop_dup_df.count()}")

# Verify they're equivalent
print(f"Results match: {distinct_df.subtract(drop_dup_df).count() == 0}")
distinct() count: 6
dropDuplicates() count: 6
Results match: True

Functionally, distinct() is equivalent to dropDuplicates() with no arguments. Both remove exact duplicate rows across all columns. The key difference is that dropDuplicates() accepts a column subset, while distinct() does not.

My recommendation: Always use dropDuplicates(). It’s more explicit about intent, and when requirements change to deduplicate on specific columns, you just add the subset parameter instead of rewriting the logic.

Performance-wise, both methods compile to the same aggregate-based physical plan and trigger a shuffle to group identical rows. There's no meaningful difference in execution cost.

Advanced Deduplication with Window Functions

When you need to control exactly which duplicate to keep, window functions are the answer. The pattern uses row_number() to assign a rank within each group, then filters to keep only rank 1:

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col, desc

# Define window: partition by user_id, order by updated_at descending
window_spec = Window.partitionBy("user_id").orderBy(desc("updated_at"))

# Add row number within each partition
df_ranked = df.withColumn("row_num", row_number().over(window_spec))
df_ranked.show(truncate=False)
+-------+-----------------+------------+-------------------+-------+
|user_id|email            |name        |updated_at         |row_num|
+-------+-----------------+------------+-------------------+-------+
|1      |alice@example.com|Alice Smith |2024-01-15 10:30:00|1      |
|1      |alice@example.com|Alice Smith |2024-01-15 10:30:00|2      |
|2      |bob@example.com  |Robert Jones|2024-01-16 14:20:00|1      |
|2      |bob@example.com  |Bob Jones   |2024-01-14 09:00:00|2      |
|3      |carol@example.com|Carol White |2024-01-12 11:45:00|1      |
|3      |carol@example.com|Carol White |2024-01-10 08:00:00|2      |
|4      |dave@example.com |Dave Brown  |2024-01-13 16:30:00|1      |
+-------+-----------------+------------+-------------------+-------+

# Keep only the most recent record per user
latest_records = df_ranked.filter(col("row_num") == 1).drop("row_num")
latest_records.show(truncate=False)
+-------+-----------------+------------+-------------------+
|user_id|email            |name        |updated_at         |
+-------+-----------------+------------+-------------------+
|1      |alice@example.com|Alice Smith |2024-01-15 10:30:00|
|2      |bob@example.com  |Robert Jones|2024-01-16 14:20:00|
|3      |carol@example.com|Carol White |2024-01-12 11:45:00|
|4      |dave@example.com |Dave Brown  |2024-01-13 16:30:00|
+-------+-----------------+------------+-------------------+

Now Bob correctly shows “Robert Jones” (the newer name), and Carol shows the later timestamp. This is deterministic and reproducible.

You can make the ordering as complex as needed. Want to keep the record with the most complete data? Order by a completeness score. Want to prefer certain data sources? Add source priority to the ordering:

# Complex ordering example
window_complex = Window.partitionBy("user_id").orderBy(
    desc("updated_at"),
    desc("data_quality_score"),
    "source_priority"
)

Performance Considerations

Deduplication is expensive because it requires comparing rows across partitions. Here’s how to minimize the cost:

Take advantage of partitioning by your deduplication key. If your data is already partitioned by user_id and you're deduplicating by user_id, the operation stays local to each partition. Be aware that repartition() is itself a full shuffle, so repartitioning explicitly only pays off when that partitioning is reused by later keyed operations:

# Repartition by the deduplication key; the subsequent dropDuplicates()
# can reuse this partitioning instead of shuffling again
df_partitioned = df.repartition(200, "user_id")
deduped = df_partitioned.dropDuplicates(["user_id"])

Filter early. Remove obviously invalid or irrelevant rows before deduplication to reduce the data volume:

# Filter before deduplicating
df_filtered = df.filter(col("updated_at").isNotNull())
deduped = df_filtered.dropDuplicates(["user_id"])

Cache intermediate results if reused. If you’re running multiple operations on deduplicated data, cache it to avoid recomputation:

deduped = df.dropDuplicates(["user_id"]).cache()
deduped.count()  # Trigger caching

# Now use deduped multiple times without recomputation

Consider approximate deduplication for very large datasets. If exact deduplication is too expensive and you can tolerate some duplicates, techniques like Bloom filters or other probabilistic data structures can help. But this is rarely necessary—PySpark handles billions of rows fine with proper partitioning.

The right deduplication strategy depends on your data characteristics and business requirements. Start with dropDuplicates() for simple cases, graduate to window functions when you need control, and always verify your results with row counts and spot checks.
