How to Inner Join in PySpark
Key Insights
- PySpark’s join() method supports both simple column-name syntax (join(df2, "key")) and explicit condition syntax (join(df2, df1.key == df2.key, "inner")), but choosing the right one affects whether you get duplicate columns in your result.
- Inner joins are the default join type in PySpark, so how="inner" is optional—but being explicit improves code readability and prevents confusion during maintenance.
- For joins involving a small DataFrame (under ~10MB), always use broadcast() to avoid expensive shuffle operations and dramatically improve query performance.
Introduction
Joins are the backbone of relational data processing. Whether you’re building ETL pipelines, preparing features for machine learning, or generating reports, you’ll spend a significant portion of your time combining datasets. In distributed computing with PySpark, understanding how joins work isn’t just about getting correct results—it’s about getting them efficiently across a cluster of machines.
Inner joins are the most frequently used join type because they answer the most common question: “Give me records that exist in both datasets.” Unlike outer joins that preserve unmatched rows, inner joins return only the intersection—rows where the join key exists in both DataFrames. This makes them predictable, performant, and the right default choice for most data integration tasks.
This article covers everything you need to perform inner joins effectively in PySpark, from basic syntax to performance optimization techniques that matter in production.
Inner Join Fundamentals
An inner join combines two DataFrames based on a matching condition and returns only the rows where that condition is satisfied in both DataFrames. If a row in the left DataFrame has no matching row in the right DataFrame (or vice versa), it’s excluded from the result.
Conceptually, think of it as the intersection in a Venn diagram:
DataFrame A DataFrame B
┌─────────────┐ ┌─────────────┐
│ │ │ │
│ ┌──────┼─────┼──────┐ │
│ │ INNER JOIN RESULT │ │
│ └──────┼─────┼──────┘ │
│ │ │ │
└─────────────┘ └─────────────┘
Let’s create two sample DataFrames that we’ll use throughout this article:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
spark = SparkSession.builder.appName("InnerJoinDemo").getOrCreate()
# Employee data
employees_data = [
(1, "Alice", 101, "Engineering"),
(2, "Bob", 102, "Marketing"),
(3, "Charlie", 101, "Engineering"),
(4, "Diana", 103, "Sales"),
(5, "Eve", 999, "Unknown"), # Department doesn't exist
]
employees = spark.createDataFrame(
employees_data,
["emp_id", "name", "dept_id", "dept_name"]
)
# Department data
departments_data = [
(101, "Engineering", "Building A"),
(102, "Marketing", "Building B"),
(103, "Sales", "Building C"),
(104, "HR", "Building A"), # No employees
]
departments = spark.createDataFrame(
departments_data,
["dept_id", "dept_name", "location"]
)
Notice that employee Eve has dept_id=999, which doesn’t exist in departments, and department HR (104) has no employees. An inner join will exclude both from the result.
Basic Inner Join Syntax
PySpark provides two primary syntaxes for inner joins. The first uses a column name directly:
# Simple syntax - when join column has the same name in both DataFrames
result = employees.join(departments, "dept_id", "inner")
result.show()
Output:
+-------+------+-------+-----------+-----------+----------+
|dept_id|emp_id| name| dept_name| dept_name| location|
+-------+------+-------+-----------+-----------+----------+
| 101| 1| Alice|Engineering|Engineering|Building A|
| 101| 3|Charlie|Engineering|Engineering|Building A|
| 102| 2| Bob| Marketing| Marketing|Building B|
| 103| 4| Diana| Sales| Sales|Building C|
+-------+------+-------+-----------+-----------+----------+
Eve (dept_id=999) and the HR department (dept_id=104) are excluded because they have no matches.
The second syntax uses an explicit condition:
# Explicit condition syntax
result = employees.join(
departments,
employees.dept_id == departments.dept_id,
"inner"
)
result.show()
The key difference: the simple column-name syntax automatically deduplicates the join column, while the explicit condition syntax keeps both columns. This matters when you need to reference the join column later in your pipeline.
Since inner is the default join type, you can omit it:
# These are equivalent
result = employees.join(departments, "dept_id")
result = employees.join(departments, "dept_id", "inner")
However, I recommend always specifying the join type explicitly. When someone reads your code six months later, they shouldn’t have to remember PySpark’s defaults.
Joining on Multiple Columns
Real-world data often requires matching on composite keys. PySpark handles this with a list of column names or compound conditions.
# Sample data with composite keys
orders_data = [
("2024-01", "US", 1001, 500),
("2024-01", "UK", 1002, 300),
("2024-02", "US", 1003, 450),
]
order_details_data = [
("2024-01", "US", "Widget A", 100),
("2024-01", "US", "Widget B", 400),
("2024-01", "UK", "Gadget X", 300),
("2024-02", "CA", "Widget C", 200), # No matching order
]
orders = spark.createDataFrame(
orders_data,
["month", "region", "order_id", "total"]
)
order_details = spark.createDataFrame(
order_details_data,
["month", "region", "product", "amount"]
)
# Join on multiple columns using a list
result = orders.join(order_details, ["month", "region"], "inner")
result.show()
For more complex conditions, use the & operator:
# Explicit multi-column condition
result = orders.join(
order_details,
(orders.month == order_details.month) & (orders.region == order_details.region),
"inner"
)
When using & for multiple conditions, wrap each condition in parentheses. Python’s operator precedence will cause unexpected behavior otherwise.
Handling Column Name Conflicts
The most common frustration with PySpark joins is duplicate column names. When both DataFrames have columns with the same name (beyond the join key), you’ll get ambiguous references.
# Both DataFrames have 'dept_name' column
result = employees.join(departments, "dept_id", "inner")
# result.select("dept_name") # This would fail - ambiguous reference
Strategy 1: Use aliases
from pyspark.sql.functions import col
emp_aliased = employees.alias("emp")
dept_aliased = departments.alias("dept")
result = emp_aliased.join(
dept_aliased,
col("emp.dept_id") == col("dept.dept_id"),
"inner"
)
# Now you can reference columns unambiguously
result.select("emp.name", "dept.location", "dept.dept_name").show()
Strategy 2: Rename columns before joining
departments_renamed = departments.withColumnRenamed("dept_name", "official_dept_name")
result = employees.join(departments_renamed, "dept_id", "inner")
result.show()
Strategy 3: Drop duplicate columns after joining
result = employees.join(
departments,
employees.dept_id == departments.dept_id,
"inner"
).drop(departments.dept_id).drop(departments.dept_name)
result.show()
My preference is Strategy 2—renaming before the join. It makes downstream code clearer and avoids the gotchas of working with ambiguous DataFrame references.
Performance Considerations
Inner joins in PySpark can be expensive because they often require shuffling data across the cluster to co-locate matching keys. Here’s how to minimize that cost.
Broadcast joins for small tables
When one DataFrame is small enough to fit in memory on each executor (typically under 10MB, configurable up to ~8GB), broadcast it:
from pyspark.sql.functions import broadcast
# Departments table is small - broadcast it
result = employees.join(broadcast(departments), "dept_id", "inner")
result.explain()
The explain() output will show BroadcastHashJoin instead of SortMergeJoin, confirming the optimization took effect. This eliminates the shuffle for the small table entirely.
Partition alignment
If you’re joining large tables repeatedly, partition them on the join key:
# Repartition both DataFrames on the join key with the same partition count
employees_partitioned = employees.repartition(100, "dept_id")
departments_partitioned = departments.repartition(100, "dept_id")
# The repartition itself shuffles once, but subsequent joins on dept_id
# can reuse that partitioning instead of shuffling again
result = employees_partitioned.join(departments_partitioned, "dept_id", "inner")
Filter early
Apply filters before joins to reduce the data volume being shuffled:
# Bad: filter after join
result = employees.join(departments, "dept_id", "inner").filter(col("location") == "Building A")
# Good: filter before join when possible
filtered_depts = departments.filter(col("location") == "Building A")
result = employees.join(filtered_depts, "dept_id", "inner")
Spark’s optimizer often pushes predicates down automatically, but being explicit ensures the optimization happens and makes your intent clear.
Conclusion
Inner joins in PySpark are straightforward once you understand the syntax options and their tradeoffs. Use the simple column-name syntax (join(df, "key")) when you want automatic deduplication of the join column. Use the explicit condition syntax (join(df, df1.key == df2.key)) when you need more control or are joining on expressions.
Always handle column name conflicts proactively—rename columns before joining rather than dealing with ambiguous references later. And when performance matters, reach for broadcast() for small tables and ensure your large tables are partitioned on join keys.
For cases where you need to preserve unmatched rows, explore left joins, right joins, and full outer joins, which follow the same syntax patterns but with different how parameters.