SQL Subqueries in PySpark
Key Insights
- PySpark supports SQL-style subqueries through both the DataFrame API and Spark SQL, enabling complex nested queries in distributed environments; Spark's Catalyst optimizer automatically rewrites many subqueries as efficient joins.
- Correlated subqueries that reference outer query columns can create performance bottlenecks in distributed systems—understanding when to use EXISTS vs IN vs explicit joins is critical for scalable data processing.
- Subqueries in the FROM clause (derived tables) are often the cleanest approach for multi-level aggregations, but materializing intermediate results with caching can dramatically improve performance when the subquery is reused.
Introduction to Subqueries in PySpark
Subqueries are nested SELECT statements embedded within a larger query, allowing you to break complex data transformations into logical steps. In traditional SQL databases, subqueries are common for filtering, aggregation, and creating derived datasets. PySpark brings this capability to distributed computing, supporting subqueries through both native Spark SQL and the DataFrame API.
The beauty of PySpark’s implementation lies in the Catalyst optimizer. When you write a subquery, Spark doesn’t naively execute it as a separate operation—it analyzes the entire query plan and often rewrites subqueries as joins or other optimized operations. This means you can write readable, SQL-like code without sacrificing performance in most cases.
However, distributed computing introduces unique challenges. A subquery that performs well on a single-node database might create data shuffling nightmares across a Spark cluster. Understanding these nuances separates competent PySpark developers from those who write code that grinds to a halt in production.
Setting Up the Environment
Let’s create a realistic dataset to work with throughout this article. We’ll model a company with employees, departments, and sales data.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, sum, count  # note: these shadow Python's built-in sum/count

# Initialize SparkSession
spark = SparkSession.builder \
    .appName("SubqueryExamples") \
    .config("spark.sql.adaptive.enabled", "true") \
    .getOrCreate()
# Create employees DataFrame
employees_data = [
    (1, "Alice", "Engineering", 95000, "Senior"),
    (2, "Bob", "Engineering", 75000, "Mid"),
    (3, "Charlie", "Sales", 65000, "Mid"),
    (4, "Diana", "Sales", 85000, "Senior"),
    (5, "Eve", "Marketing", 70000, "Mid"),
    (6, "Frank", "Engineering", 120000, "Senior"),
    (7, "Grace", "Sales", 55000, "Junior"),
    (8, "Henry", "Marketing", 90000, "Senior")
]

employees = spark.createDataFrame(
    employees_data,
    ["emp_id", "name", "department", "salary", "level"]
)
# Create departments DataFrame
departments_data = [
    ("Engineering", 500000, 10),
    ("Sales", 300000, 8),
    ("Marketing", 200000, 5)
]

departments = spark.createDataFrame(
    departments_data,
    ["dept_name", "budget", "headcount_target"]
)
# Create sales DataFrame
sales_data = [
    (1, 1, 15000), (2, 1, 22000), (3, 3, 18000),
    (4, 3, 25000), (5, 4, 30000), (6, 4, 28000),
    (7, 7, 12000), (8, 7, 14000)
]

sales = spark.createDataFrame(sales_data, ["sale_id", "emp_id", "amount"])
# Register as temporary views for SQL queries
employees.createOrReplaceTempView("employees")
departments.createOrReplaceTempView("departments")
sales.createOrReplaceTempView("sales")
Scalar Subqueries
Scalar subqueries return a single value and can appear in SELECT, WHERE, or HAVING clauses. They’re perfect for comparing individual records against aggregate metrics.
# Find employees earning above average salary using SQL
above_avg_sql = spark.sql("""
    SELECT name, department, salary
    FROM employees
    WHERE salary > (SELECT AVG(salary) FROM employees)
    ORDER BY salary DESC
""")
above_avg_sql.show()
# Equivalent using DataFrame API
avg_salary = employees.agg(avg("salary")).first()[0]

above_avg_df = employees.filter(col("salary") > avg_salary) \
    .select("name", "department", "salary") \
    .orderBy(col("salary").desc())
Here’s a more complex example comparing department metrics against company averages:
# Calculate each department's average salary vs company average
dept_comparison = spark.sql("""
    SELECT
        department,
        AVG(salary) as dept_avg_salary,
        (SELECT AVG(salary) FROM employees) as company_avg_salary,
        AVG(salary) - (SELECT AVG(salary) FROM employees) as difference
    FROM employees
    GROUP BY department
    ORDER BY difference DESC
""")
dept_comparison.show()
Correlated vs Non-Correlated Subqueries
Non-correlated subqueries execute independently of the outer query—they run once and return a result. Correlated subqueries reference columns from the outer query, potentially executing multiple times.
# Non-correlated: Find all employees in the highest-budget department
non_correlated = spark.sql("""
    SELECT name, department, salary
    FROM employees
    WHERE department = (
        SELECT dept_name
        FROM departments
        ORDER BY budget DESC
        LIMIT 1
    )
""")
# Correlated: Find employees earning above their department's average
correlated = spark.sql("""
    SELECT e1.name, e1.department, e1.salary
    FROM employees e1
    WHERE e1.salary > (
        SELECT AVG(e2.salary)
        FROM employees e2
        WHERE e2.department = e1.department
    )
    ORDER BY e1.department, e1.salary DESC
""")
correlated.show()
The correlated subquery is conceptually cleaner but can be less efficient. Spark’s optimizer often converts this to a join with aggregation, but you should verify with .explain() for critical queries.
Subqueries with IN, EXISTS, and NOT EXISTS
These predicates enable sophisticated filtering based on subquery results.
# IN: Find employees who have made sales
employees_with_sales = spark.sql("""
    SELECT name, department
    FROM employees
    WHERE emp_id IN (SELECT DISTINCT emp_id FROM sales)
""")
# EXISTS: Find departments with at least one senior employee
depts_with_seniors = spark.sql("""
    SELECT DISTINCT d.dept_name, d.budget
    FROM departments d
    WHERE EXISTS (
        SELECT 1
        FROM employees e
        WHERE e.department = d.dept_name
          AND e.level = 'Senior'
    )
""")
# NOT EXISTS: Find employees without any sales
employees_no_sales = spark.sql("""
    SELECT name, department, level
    FROM employees e
    WHERE NOT EXISTS (
        SELECT 1
        FROM sales s
        WHERE s.emp_id = e.emp_id
    )
""")
employees_no_sales.show()
Prefer EXISTS over IN when you only care whether a match exists: EXISTS can short-circuit as soon as one matching row is found, and NOT EXISTS handles NULLs predictably, whereas NOT IN returns no rows at all if the subquery produces any NULL values.
Subqueries in FROM Clause (Derived Tables)
Subqueries in the FROM clause create temporary result sets, perfect for multi-level aggregations.
# Calculate department rankings based on total sales per employee
dept_rankings = spark.sql("""
    SELECT
        department,
        avg_sales_per_employee,
        RANK() OVER (ORDER BY avg_sales_per_employee DESC) as rank
    FROM (
        SELECT
            e.department,
            AVG(sales_total) as avg_sales_per_employee
        FROM employees e
        LEFT JOIN (
            SELECT emp_id, SUM(amount) as sales_total
            FROM sales
            GROUP BY emp_id
        ) s ON e.emp_id = s.emp_id
        GROUP BY e.department
    ) dept_sales
""")
dept_rankings.show()
This pattern is cleaner than trying to cram everything into a single-level query. Each subquery has a clear purpose: the innermost aggregates sales by employee, the middle joins and aggregates by department, and the outer ranks departments.
Performance Considerations and Best Practices
Not all subqueries are created equal in distributed systems. Here’s what you need to know:
When to use subqueries vs joins: Subqueries are great for readability, but explicit joins often give you more control over execution strategy. Compare execution plans:
# Subquery approach
subquery_plan = spark.sql("""
    SELECT name, salary
    FROM employees
    WHERE department IN (SELECT dept_name FROM departments WHERE budget > 250000)
""")
# Join approach
join_plan = spark.sql("""
    SELECT DISTINCT e.name, e.salary
    FROM employees e
    INNER JOIN departments d ON e.department = d.dept_name
    WHERE d.budget > 250000
""")
print("Subquery Plan:")
subquery_plan.explain(mode="formatted")
print("\nJoin Plan:")
join_plan.explain(mode="formatted")
Caching intermediate results: If a subquery is expensive and used multiple times, cache it:
# Cache the subquery result
high_performers = spark.sql("""
    SELECT emp_id, name, department
    FROM employees
    WHERE salary > (SELECT AVG(salary) * 1.2 FROM employees)
""")
high_performers.cache()
# Use it multiple times
high_performers.count() # Materializes the cache
# Now subsequent operations are fast
by_dept = high_performers.groupBy("department").count()
by_dept.show()
Avoid correlated subqueries in tight loops: If you’re filtering millions of records with a correlated subquery, consider rewriting as a window function or self-join:
# Instead of correlated subquery, use window functions
from pyspark.sql.window import Window
window_spec = Window.partitionBy("department")
employees_above_dept_avg = employees.withColumn(
    "dept_avg", avg("salary").over(window_spec)
).filter(col("salary") > col("dept_avg"))
employees_above_dept_avg.select("name", "department", "salary", "dept_avg").show()
Monitor broadcast joins: Small subquery results can be broadcast to all nodes, avoiding shuffles. Spark does this automatically for small tables, but you can force it:
from pyspark.sql.functions import broadcast
# Force broadcast of small lookup table
result = employees.join(
    broadcast(departments),
    employees.department == departments.dept_name
)
The key takeaway: write subqueries for clarity, then profile and optimize. PySpark’s Catalyst optimizer handles many cases well, but understanding the underlying execution is essential for production workloads. Always check execution plans with .explain() and monitor stage-level metrics in the Spark UI for queries processing significant data volumes.