PySpark - Intersect Two DataFrames

Key Insights

  • PySpark’s intersect() returns distinct common rows between DataFrames, while intersectAll() preserves duplicates—choose based on whether you need to count duplicate occurrences
  • intersect() matches columns by position, not by name, so both DataFrames need the same number of columns in the same order with compatible types; use select() and cast() to align mismatched schemas before intersecting
  • For large-scale intersections, check the plan with explain() for shuffle (Exchange) operators; when one DataFrame is small, a broadcast join can avoid shuffling the larger one

Introduction to DataFrame Intersection

Finding common rows between two DataFrames is a fundamental operation in data engineering. In PySpark, intersection operations identify records that exist in both DataFrames, comparing entire rows for equality. This is particularly valuable for data reconciliation tasks, where you need to verify that two data sources contain the same records, or for data quality checks where you’re validating that expected records appear in your processed datasets.

Common use cases include comparing production and staging environments to find matching records, identifying overlapping customer segments across different marketing campaigns, validating data migrations by ensuring critical records transferred correctly, and performing incremental processing by finding unchanged records between snapshots.

Let’s start with a simple example to visualize what intersection means:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.appName("IntersectExample").getOrCreate()

# First DataFrame - Customer records from System A
df1 = spark.createDataFrame([
    (1, "Alice", "Premium"),
    (2, "Bob", "Standard"),
    (3, "Charlie", "Premium"),
    (4, "David", "Basic")
], ["customer_id", "name", "tier"])

# Second DataFrame - Customer records from System B
df2 = spark.createDataFrame([
    (2, "Bob", "Standard"),
    (3, "Charlie", "Premium"),
    (5, "Eve", "Premium"),
    (6, "Frank", "Basic")
], ["customer_id", "name", "tier"])

df1.show()
df2.show()

In this example, rows for Bob and Charlie appear in both DataFrames, making them candidates for intersection.

The intersect() Method

The intersect() method is PySpark’s primary tool for finding common rows. It performs a row-by-row comparison across all columns and returns only the distinct rows that appear in both DataFrames. The syntax is straightforward:

result = df1.intersect(df2)
result.show()

Output:

+-----------+-------+--------+
|customer_id|   name|    tier|
+-----------+-------+--------+
|          2|    Bob|Standard|
|          3|Charlie| Premium|
+-----------+-------+--------+

The intersect() method compares entire rows, meaning all column values must match exactly. It automatically removes duplicates from the result, similar to SQL’s INTERSECT operator. This behavior is crucial to understand: even if a row appears multiple times in both DataFrames, it will appear only once in the result.

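Behind the scenes, Spark’s optimizer typically rewrites intersect() as a left semi join on all columns, using null-safe equality, followed by a distinct aggregation. On a cluster this usually means shuffling data so that identical rows end up on the same executor. The method returns a new DataFrame, leaving the original DataFrames unchanged.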

intersect() vs intersectAll()

PySpark provides two intersection methods with different handling of duplicate rows. Understanding when to use each is essential for correct results.

# Create DataFrames with duplicate rows
df_with_dupes1 = spark.createDataFrame([
    (1, "Product A"),
    (1, "Product A"),  # duplicate
    (2, "Product B"),
    (2, "Product B"),  # duplicate
    (3, "Product C")
], ["id", "product"])

df_with_dupes2 = spark.createDataFrame([
    (1, "Product A"),
    (2, "Product B"),
    (2, "Product B"),  # duplicate
    (2, "Product B"),  # another duplicate
    (4, "Product D")
], ["id", "product"])

print("Using intersect() - returns distinct rows:")
df_with_dupes1.intersect(df_with_dupes2).show()

print("Using intersectAll() - preserves duplicates:")
df_with_dupes1.intersectAll(df_with_dupes2).show()

Output for intersect():

+---+---------+
| id|  product|
+---+---------+
|  1|Product A|
|  2|Product B|
+---+---------+

Output for intersectAll():

+---+---------+
| id|  product|
+---+---------+
|  1|Product A|
|  2|Product B|
|  2|Product B|
+---+---------+

With intersectAll(), the result contains two rows for Product B because both DataFrames have at least two occurrences of that row. The method takes the minimum count of duplicates from each DataFrame. Use intersect() when you need set-based operations and intersectAll() when duplicate counts matter for your analysis.
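As a mental model only (this is plain Python, not how Spark executes the operation), intersectAll() behaves like multiset intersection. Python’s `Counter & Counter` keeps each row with the smaller of its two counts, and dropping the counts gives the intersect() result. The lists mirror the two DataFrames above.

```python
from collections import Counter

left = [(1, "Product A"), (1, "Product A"), (2, "Product B"), (2, "Product B"), (3, "Product C")]
right = [(1, "Product A"), (2, "Product B"), (2, "Product B"), (2, "Product B"), (4, "Product D")]

# & keeps only rows present in both lists, each with min(count_left, count_right)
common_counts = Counter(left) & Counter(right)

intersect_all_rows = sorted(common_counts.elements())  # analogous to intersectAll()
intersect_rows = sorted(common_counts)                  # analogous to intersect()

print(intersect_rows)      # [(1, 'Product A'), (2, 'Product B')]
print(intersect_all_rows)  # [(1, 'Product A'), (2, 'Product B'), (2, 'Product B')]
```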

Handling Schema Differences

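A critical requirement for intersection operations is schema compatibility. Both DataFrames must have the same number of columns, and intersect() pairs columns by position rather than by name, so the columns must appear in the same order with compatible data types. A column-count mismatch raises an AnalysisException; a type mismatch is either resolved by implicit type widening or rejected, depending on the types involved and your Spark settings.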

# DataFrames with different column orders
df_ordered1 = spark.createDataFrame([
    (1, "Alice", 25),
    (2, "Bob", 30)
], ["id", "name", "age"])

df_ordered2 = spark.createDataFrame([
    ("Alice", 25, 1),
    ("Bob", 30, 2)
], ["name", "age", "id"])

# Positional matching pairs id with name, name with age, and age with id,
# so this either fails type resolution or silently returns no rows:
# df_ordered1.intersect(df_ordered2)

# Fix by reordering columns
df_ordered2_fixed = df_ordered2.select("id", "name", "age")
result = df_ordered1.intersect(df_ordered2_fixed)
result.show()

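When data types differ, Spark may implicitly widen them to a common type (for example, comparing an integer column with a string column as strings when ANSI mode is off), which can hide problems. Casting explicitly with cast() keeps the comparison type under your control: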

from pyspark.sql.functions import col

# DataFrame with age as string
df_string_type = spark.createDataFrame([
    (1, "Alice", "25"),
    (2, "Bob", "30")
], ["id", "name", "age"])

# DataFrame with age as integer
df_int_type = spark.createDataFrame([
    (1, "Alice", 25),
    (2, "Bob", 30)
], ["id", "name", "age"])

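# Cast to match types before intersection. createDataFrame infers Python ints
# as LongType, so reuse df_int_type's column type instead of hard-coding one
df_aligned = df_string_type.select(
    col("id"),
    col("name"),
    col("age").cast(df_int_type.schema["age"].dataType)
)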

result = df_aligned.intersect(df_int_type)
result.show()

Always verify schemas using printSchema() before performing intersections on unfamiliar DataFrames.

Performance Considerations

Intersection operations can be expensive, particularly with large DataFrames. Comparing full rows and deduplicating the result usually requires shuffling both inputs across the cluster, unless Spark can broadcast the smaller side.

# Examine the execution plan
df1.intersect(df2).explain()

The explain output reveals the shuffle (Exchange) operators and stages involved. For large datasets, the following strategies are worth testing, though each should be confirmed against the plan:

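# Repartitioning is itself a shuffle, and because intersect() compares all columns,
# Spark may still add its own; keep it only if explain() and timings show a benefit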
df1_partitioned = df1.repartition(10, "customer_id")
df2_partitioned = df2.repartition(10, "customer_id")

result = df1_partitioned.intersect(df2_partitioned)

# Cache if you'll reuse the result
result.cache()
result.count()  # Trigger caching

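When the DataFrame passed to intersect() is much smaller than the other, Spark can plan the comparison as a broadcast join, which avoids shuffling the larger side. Spark considers this automatically when its size estimate for that side is under spark.sql.autoBroadcastJoinThreshold (10 MB by default); otherwise, an explicit join with a broadcast() hint is an alternative.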

Practical Use Cases

Here’s a complete real-world example: comparing production and staging databases to find records that successfully migrated:

# Production data
prod_customers = spark.createDataFrame([
    (101, "Acme Corp", "active", "2024-01-15"),
    (102, "TechStart", "active", "2024-01-20"),
    (103, "DataCo", "inactive", "2023-12-01"),
    (104, "CloudSys", "active", "2024-02-01")
], ["customer_id", "company_name", "status", "signup_date"])

# Staging data after migration
staging_customers = spark.createDataFrame([
    (101, "Acme Corp", "active", "2024-01-15"),
    (102, "TechStart", "active", "2024-01-20"),
    (105, "NewClient", "active", "2024-03-01")
], ["customer_id", "company_name", "status", "signup_date"])

# Find successfully migrated records
migrated_records = prod_customers.intersect(staging_customers)

print("Successfully migrated records:")
migrated_records.show()

# Find records that failed to migrate
failed_migration = prod_customers.subtract(staging_customers)
print(f"Failed migrations: {failed_migration.count()}")

This pattern is invaluable for data validation, allowing you to quickly identify discrepancies between environments.

Alternative Approaches

While intersect() is convenient, you can reproduce it with a join on every column followed by distinct(). This approach offers more flexibility for complex scenarios:

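# Using an inner join as an alternative to intersect()
# eqNullSafe() matches NULLs the way intersect() does; plain == would drop rows with NULLs
join_condition = [df1[col_name].eqNullSafe(df2[col_name]) for col_name in df1.columns]

# Select through df1's column references: plain names would be ambiguous after the join
result_via_join = (
    df1.join(df2, join_condition, "inner")
    .select([df1[col_name] for col_name in df1.columns])
    .distinct()
)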

result_via_join.show()

You can also use SQL syntax if you prefer:

df1.createOrReplaceTempView("table1")
df2.createOrReplaceTempView("table2")

result_sql = spark.sql("""
    SELECT * FROM table1
    INTERSECT
    SELECT * FROM table2
""")

result_sql.show()

The SQL approach can be more readable for teams familiar with SQL. Both forms go through the same Catalyst optimizer and produce the same plan, so performance matches the DataFrame API.

Choose intersect() for simplicity and clarity when finding common rows. Use join-based approaches when you need additional logic, such as comparing only specific columns or performing transformations during the comparison. For production systems processing large volumes, always profile your approach using explain() and monitor Spark UI metrics to ensure optimal performance.
