Apache Spark SQL - Complete Tutorial

Key Insights

  • Spark SQL bridges the gap between structured data processing and distributed computing, offering a unified interface for batch and streaming analytics, with performance optimizations from the Catalyst query optimizer and the Tungsten execution engine
  • DataFrames and Datasets are optimized abstractions over RDDs that benefit from Catalyst's automatic query optimization; Datasets (Scala and Java only) add compile-time type safety, and SQL queries and programmatic transformations can be mixed interchangeably
  • Understanding partitioning, bucketing, and caching strategies is critical for production deployments—poorly optimized Spark SQL jobs can be orders of magnitude slower than well-tuned ones

Setting Up Spark SQL

Spark SQL requires a SparkSession as the entry point. Introduced in Spark 2.0, this unified interface replaced the older SQLContext and HiveContext.

from pyspark.sql import SparkSession

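# Both settings below match Spark's defaults (AQE is on by default since
# Spark 3.2); they are set explicitly here for visibility.
spark = SparkSession.builder \
    .appName("SparkSQLTutorial") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.sql.adaptive.enabled", "true") \
    .enableHiveSupport() \
    .getOrCreate()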

# Verify setup
print(spark.version)

For Scala:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("SparkSQLTutorial")
  .config("spark.sql.shuffle.partitions", "200")
  .config("spark.sql.adaptive.enabled", "true")
  .enableHiveSupport()
  .getOrCreate()

import spark.implicits._

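The enableHiveSupport() method connects Spark to a persistent Hive metastore and enables Hive SerDes, Hive UDFs, and Hive file formats. It requires a Spark build with Hive support (the Hive classes must be on the classpath).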

Creating DataFrames

DataFrames can be created from various sources: RDDs, structured files, external databases, or programmatically.

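# From JSON file (JSON Lines by default: one object per line;
# add .option("multiLine", "true") for pretty-printed JSON)
df_json = spark.read.json("data/users.json")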

# From CSV with schema inference (inferSchema costs an extra pass over the data)
df_csv = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .csv("data/sales.csv")

# From Parquet (columnar format)
df_parquet = spark.read.parquet("data/transactions.parquet")

# Programmatically with explicit schema
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True)
])

data = [(1, "Alice", 29), (2, "Bob", 35), (3, "Charlie", 42)]
df = spark.createDataFrame(data, schema)

df.show()
df.printSchema()

SQL Queries vs DataFrame API

Spark SQL supports both SQL syntax and programmatic DataFrame operations. Both go through the same Catalyst optimizer, so equivalent queries compile to the same optimized execution plan.

# Register DataFrame as temporary view
df.createOrReplaceTempView("users")

# SQL approach
result_sql = spark.sql("""
    SELECT name, age
    FROM users
    WHERE age > 30
    ORDER BY age DESC
""")

# DataFrame API approach
result_df = df.select("name", "age") \
    .filter(df.age > 30) \
    .orderBy(df.age.desc())

# Both produce equivalent execution plans
result_sql.explain()
result_df.explain()

# ...and the same rows
assert result_sql.collect() == result_df.collect()

Complex Transformations

Spark SQL excels at complex analytical queries with window functions, aggregations, and joins.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Sample sales data
sales_data = [
    ("2024-01-01", "Electronics", "Laptop", 1200, 2),
    ("2024-01-01", "Electronics", "Mouse", 25, 10),
    ("2024-01-02", "Electronics", "Laptop", 1200, 1),
    ("2024-01-02", "Clothing", "Shirt", 40, 5),
    ("2024-01-03", "Clothing", "Pants", 60, 3)
]

sales_df = spark.createDataFrame(
    sales_data, 
    ["date", "category", "product", "price", "quantity"]
)

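# Window functions: a running revenue total per category (ordered by date;
# ISO date strings sort chronologically) and a revenue rank within each category
window_spec = Window.partitionBy("category").orderBy("date")
rank_spec = Window.partitionBy("category").orderBy(F.col("revenue").desc())

enriched = sales_df \
    .withColumn("revenue", F.col("price") * F.col("quantity")) \
    .withColumn("running_total", F.sum("revenue").over(window_spec)) \
    .withColumn("rank", F.rank().over(rank_spec))

enriched.show()

# Per-category aggregation; groupBy collapses rows, so window columns
# computed before it do not appear in this result
analysis = enriched \
    .groupBy("category") \
    .agg(
        F.sum("revenue").alias("total_revenue"),
        F.avg("price").alias("avg_price"),
        F.count("*").alias("transaction_count")
    )

analysis.show()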
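
To check what the window columns in enriched should contain, here is a plain-Python model (no Spark needed) of the two window expressions on the same five sample rows. It mirrors Spark's documented default frame: when a window has orderBy but no explicit frame, the frame is RANGE from UNBOUNDED PRECEDING to CURRENT ROW, so rows that tie on date see the same running total. This is a sketch of the semantics, not Spark's implementation.

```python
from itertools import groupby
from operator import itemgetter

# The same five rows as sales_df: (date, category, product, price, quantity)
sales_data = [
    ("2024-01-01", "Electronics", "Laptop", 1200, 2),
    ("2024-01-01", "Electronics", "Mouse", 25, 10),
    ("2024-01-02", "Electronics", "Laptop", 1200, 1),
    ("2024-01-02", "Clothing", "Shirt", 40, 5),
    ("2024-01-03", "Clothing", "Pants", 60, 3),
]


def revenue(row):
    return row[3] * row[4]


def window_columns(rows):
    """Return {(date, category, product): (running_total, rank)}."""
    result = {}
    by_category = itemgetter(1)
    for _, group in groupby(sorted(rows, key=by_category), key=by_category):
        partition = list(group)
        for row in partition:
            # Default frame with orderBy: RANGE UNBOUNDED PRECEDING .. CURRENT ROW,
            # so every row with date <= this row's date counts, including ties
            running_total = sum(revenue(r) for r in partition if r[0] <= row[0])
            # rank(): 1 + number of rows in the partition with strictly higher revenue
            rank = 1 + sum(1 for r in partition if revenue(r) > revenue(row))
            result[(row[0], row[1], row[2])] = (running_total, rank)
    return result


for key, (total, rank) in sorted(window_columns(sales_data).items()):
    print(key, total, rank)
```

Both Electronics rows dated 2024-01-01 share a running total, because they are peers in the RANGE frame. For a strict row-by-row total, the window spec can use rowsBetween(Window.unboundedPreceding, Window.currentRow), keeping in mind that the order among tied dates is then not deterministic.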

Joins and Broadcasting

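Join performance depends heavily on data size and distribution. A broadcast join ships a copy of the small table to every executor, so the large table is joined where it sits instead of being shuffled. Spark does this automatically when a table's estimated size is below spark.sql.autoBroadcastJoinThreshold (10MB by default); the broadcast() hint requests it explicitly.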
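
The mechanics of a broadcast hash join fit in a few lines of plain Python: build a hash table from the small side once, then stream the large side through it. The sketch below models the left join used in this section; the rows and column layouts are hypothetical, and it is an illustration of the technique rather than Spark's code.

```python
# Hypothetical rows; the column layout is illustrative only
products = [(1, "Laptop"), (2, "Mouse")]                 # small side: (id, name)
transactions = [(101, 1, 2), (102, 2, 10), (103, 3, 1)]  # large side: (txn_id, product_id, qty)


def broadcast_left_join(large, small, large_key, small_key, small_width):
    # Build phase: hash the small table once. In Spark, the small table is
    # collected, turned into a hash table, and shipped to every executor.
    lookup = {}
    for row in small:
        lookup.setdefault(row[small_key], []).append(row)
    # Probe phase: stream the large table; its rows never move between executors
    for row in large:
        matches = lookup.get(row[large_key], [])
        if not matches:
            # Left join: keep the unmatched row and pad the right side with None
            yield row + (None,) * small_width
        for match in matches:
            yield row + match


joined = list(broadcast_left_join(transactions, products, 1, 0, 2))
print(joined)
```

The design only pays off while the build side fits comfortably in each executor's memory, which is why the threshold exists.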

# Large transaction table
transactions = spark.read.parquet("data/transactions.parquet")

# Small dimension table (< 10MB)
products = spark.read.parquet("data/products.parquet")

# Explicit broadcast for small table
from pyspark.sql.functions import broadcast

result = transactions.join(
    broadcast(products),
    transactions.product_id == products.id,
    "left"
)

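# For large-to-large joins, consider bucketing: write both tables with the
# same number of buckets on their join keys (bucketBy requires saveAsTable)
transactions.write \
    .bucketBy(100, "product_id") \
    .sortBy("transaction_date") \
    .mode("overwrite") \
    .saveAsTable("bucketed_transactions")

products.write \
    .bucketBy(100, "id") \
    .mode("overwrite") \
    .saveAsTable("bucketed_products")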

# Matching bucket layouts on the join keys let Spark skip the shuffle (verify with explain())
bucketed_txn = spark.table("bucketed_transactions")
bucketed_prod = spark.table("bucketed_products")

efficient_join = bucketed_txn.join(
    bucketed_prod,
    bucketed_txn.product_id == bucketed_prod.id
)
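
Why bucketing removes the shuffle can be modeled in plain Python: if both tables are split with the same hash function and the same bucket count, a key can only appear in bucket i on both sides, so each bucket pair is joined locally. In the sketch, zlib.crc32 is a deterministic stand-in for the Murmur3 hash Spark actually uses, the bucket count is reduced to 4, and the rows are hypothetical.

```python
import zlib

NUM_BUCKETS = 4  # the tutorial uses 100; a small number keeps the model readable


def bucket_of(key, num_buckets=NUM_BUCKETS):
    # Stand-in hash: Spark uses Murmur3; crc32 is used here only because it is
    # deterministic across Python processes
    return zlib.crc32(str(key).encode()) % num_buckets


def write_bucketed(rows, key_index):
    # Model of bucketBy(NUM_BUCKETS, col): each row lands in a fixed bucket
    buckets = [[] for _ in range(NUM_BUCKETS)]
    for row in rows:
        buckets[bucket_of(row[key_index])].append(row)
    return buckets


def bucketed_join(left_buckets, right_buckets, left_key, right_key):
    # Bucket i on the left can only match bucket i on the right, so each pair
    # is joined on its own; no row moves between buckets (no shuffle)
    for left, right in zip(left_buckets, right_buckets):
        index = {}
        for row in right:
            index.setdefault(row[right_key], []).append(row)
        for row in left:
            for match in index.get(row[left_key], []):
                yield row + match


transactions = [(101, 1), (102, 2), (103, 1), (104, 3)]  # hypothetical (txn_id, product_id)
products = [(1, "Laptop"), (2, "Mouse"), (3, "Shirt")]   # hypothetical (id, name)

bucket_joined = sorted(
    bucketed_join(write_bucketed(transactions, 1), write_bucketed(products, 0), 1, 0)
)
print(bucket_joined)
```

The model also shows the precondition: if the two tables used different bucket counts or different key columns, matching keys could land in different bucket indices and Spark would have to shuffle again.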

User-Defined Functions (UDFs)

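UDFs extend Spark SQL with custom logic, but use them sparingly: Catalyst treats a UDF as a black box, so it cannot optimize the logic inside it, and Python UDFs also pay to serialize every row between the JVM and a Python worker process. Prefer built-in functions when they can express the same logic.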

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, IntegerType

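# Python UDF (slower: each row is serialized to a Python worker)
@udf(returnType=StringType())
def categorize_age(age):
    # age is nullable in the schema, so handle None explicitly
    if age is None:
        return None
    if age < 18:
        return "Minor"
    elif age < 65:
        return "Adult"
    else:
        return "Senior"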

df_with_category = df.withColumn("age_category", categorize_age(df.age))

# Pandas UDF (vectorized through Apache Arrow; usually much faster than a
# row-at-a-time Python UDF, and requires PyArrow)
from pyspark.sql.functions import pandas_udf
import pandas as pd

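@pandas_udf(IntegerType())
def complex_calculation(prices: pd.Series, quantities: pd.Series) -> pd.Series:
    # Price x quantity plus 15% tax, computed on whole column batches;
    # astype(int) truncates the fractional part
    return (prices * quantities * 1.15).astype(int)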

sales_df = sales_df.withColumn(
    "total_with_tax",
    complex_calculation(F.col("price"), F.col("quantity"))
)
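
Much of the pandas UDF speedup comes from amortizing per-call overhead: Spark hands the function whole Arrow batches (up to spark.sql.execution.arrow.maxRecordsPerBatch rows, 10,000 by default) instead of one row at a time. The plain-Python sketch below only counts invocations for the same arithmetic as complex_calculation on 30,000 hypothetical rows; a real pandas UDF also runs the arithmetic vectorized in pandas and NumPy, which this model does not capture.

```python
ARROW_BATCH_SIZE = 10_000  # Spark's default for spark.sql.execution.arrow.maxRecordsPerBatch


def total_with_tax(price, quantity):
    # Same arithmetic as complex_calculation, for a single row
    return int(price * quantity * 1.15)


def run_row_udf(rows):
    calls, out = 0, []
    for price, quantity in rows:
        calls += 1  # one Python call (and one row serialization) per row
        out.append(total_with_tax(price, quantity))
    return out, calls


def run_batched_udf(rows, batch_size=ARROW_BATCH_SIZE):
    calls, out = 0, []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        calls += 1  # one call per batch; a pandas UDF sees whole columns here
        out.extend(total_with_tax(p, q) for p, q in batch)
    return out, calls


rows = [(1200, 2), (25, 10), (40, 5)] * 10_000  # 30,000 hypothetical rows
row_out, row_calls = run_row_udf(rows)
batch_out, batch_calls = run_batched_udf(rows)
print(row_calls, batch_calls)
```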

Performance Optimization

Partitioning and Caching

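# Repartition by the key used in later joins or aggregations
# (repartition itself triggers a full shuffle)
df_repartitioned = sales_df.repartition(100, "category")

# Cache frequently accessed data (lazy: nothing is stored until an action runs)
df_cached = df.cache()
df_cached.count()  # Materializes cache

# Persist with an explicit storage level (cache() uses the default level)
from pyspark import StorageLevel
sales_df.persist(StorageLevel.MEMORY_AND_DISK)

# Cache a table or view by name, then check its status
spark.catalog.cacheTable("users")
print(spark.catalog.isCached("users"))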
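
Hash partitioning sends every row with the same key to the same partition, so repartition(n, col) can fill at most as many partitions as col has distinct values. The plain-Python sketch below models this for the two categories in sales_df, using zlib.crc32 as a deterministic stand-in for Spark's Murmur3 hash (an assumption; the exact partition numbers differ in Spark).

```python
import zlib
from collections import Counter


def hash_partition_sizes(keys, num_partitions):
    # Model of repartition(num_partitions, col): row -> hash(key) % num_partitions.
    # crc32 stands in for Spark's Murmur3 hash purely for determinism.
    return Counter(zlib.crc32(key.encode()) % num_partitions for key in keys)


# The category column of sales_df: five rows, two distinct values
categories = ["Electronics", "Electronics", "Electronics", "Clothing", "Clothing"]
sizes = hash_partition_sizes(categories, 100)

non_empty = len(sizes)
print(f"{non_empty} of 100 partitions hold rows; the other {100 - non_empty} are empty")
```

With skewed keys the same effect shows up as a few oversized partitions. A higher-cardinality key, or repartition(n) without a column (round-robin distribution), spreads rows more evenly.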

Adaptive Query Execution

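# Enable AQE for runtime optimization (on by default since Spark 3.2):
# it re-plans queries using statistics from completed query stages
spark.conf.set("spark.sql.adaptive.enabled", "true")
# Merge small post-shuffle partitions
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
# Split skewed partitions in shuffle joins
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# Inspect query plans ("cost" mode shows statistics when available)
df.explain("extended")
df.explain("cost")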
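
Partition coalescing merges adjacent small shuffle partitions after a stage finishes, aiming at spark.sql.adaptive.advisoryPartitionSizeInBytes (64MB by default). The greedy plain-Python sketch below illustrates the idea on hypothetical partition sizes; Spark's actual algorithm also honors a minimum partition size and coordinates across shuffles, so treat this as a simplified model.

```python
MB = 1024 * 1024
ADVISORY_BYTES = 64 * MB  # default spark.sql.adaptive.advisoryPartitionSizeInBytes


def coalesce_partitions(sizes, target=ADVISORY_BYTES):
    """Greedily group adjacent partition indices so each group stays near target."""
    groups, current, current_size = [], [], 0
    for index, size in enumerate(sizes):
        # Start a new group when adding this partition would overshoot the target
        if current and current_size + size > target:
            groups.append(current)
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        groups.append(current)
    return groups


# Hypothetical sizes of eight shuffle partitions
partition_sizes = [10 * MB, 20 * MB, 30 * MB, 5 * MB, 70 * MB, 1 * MB, 2 * MB, 3 * MB]
print(coalesce_partitions(partition_sizes))
```

Eight partitions become four tasks; the oversized partition 4 stays alone, which is the case the skew-join setting addresses for joins.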

Working with Structured Streaming

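Spark SQL supports streaming queries with the same DataFrame API. Reading from Kafka requires the spark-sql-kafka-0-10 connector on the classpath (for example, via the --packages option of spark-submit).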

# Read from Kafka stream
streaming_df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "transactions") \
    .load()

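# Parse JSON payload
from pyspark.sql.functions import from_json, col
from pyspark.sql.types import TimestampType

schema = StructType([
    StructField("transaction_id", StringType()),
    StructField("amount", IntegerType()),
    StructField("timestamp", TimestampType())  # a real timestamp, so it can drive windows
])

parsed_stream = streaming_df \
    .select(from_json(col("value").cast("string"), schema).alias("data")) \
    .select("data.*")

# Sliding windowed aggregation: 1-hour windows starting every 15 minutes.
# The watermark (10 minutes is an example threshold) lets Spark drop state
# for windows that end well before the latest event time it has seen.
windowed_agg = parsed_stream \
    .withWatermark("timestamp", "10 minutes") \
    .groupBy(
        F.window(F.col("timestamp"), "1 hour", "15 minutes")
    ) \
    .agg(F.sum("amount").alias("total_amount"))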

# Write to console
query = windowed_agg.writeStream \
    .outputMode("update") \
    .format("console") \
    .start()

query.awaitTermination()
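
With a 1-hour window sliding every 15 minutes, each event is counted in four overlapping windows. The plain-Python sketch below models how those windows are aligned, assuming the default startTime of 0, so window starts fall on multiples of the slide measured from the Unix epoch; the naive datetimes stand for UTC.

```python
from datetime import datetime, timedelta

WINDOW = timedelta(hours=1)    # F.window(..., "1 hour", ...)
SLIDE = timedelta(minutes=15)  # ... "15 minutes")
EPOCH = datetime(1970, 1, 1)   # naive datetimes here represent UTC


def windows_for(event_time):
    """Return the [start, end) windows that contain event_time, oldest first."""
    offset = (event_time - EPOCH) % SLIDE  # time since the latest slide boundary
    start = event_time - offset
    windows = []
    # Walk back one slide at a time while the window still covers the event
    while start + WINDOW > event_time:
        windows.append((start, start + WINDOW))
        start -= SLIDE
    return sorted(windows)


for start, end in windows_for(datetime(2024, 1, 1, 10, 20)):
    print(start.time(), "-", end.time())
```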

Production Considerations

Monitor and tune these critical parameters:

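# Memory management: application-level settings that are fixed once the
# SparkContext starts. Pass them at launch, for example:
#   spark-submit --executor-memory 4g --driver-memory 2g \
#     --conf spark.memory.fraction=0.8 app.py
# (or add the executor settings to the SparkSession.builder call before the
# first getOrCreate()). In Spark 3.x, spark.conf.set() rejects these keys at
# runtime; in client mode, driver memory must be set at launch (spark-submit
# or spark-defaults.conf) because the driver JVM is already running.

# Shuffle and file-scan sizing: runtime SQL settings (both values are the defaults)
spark.conf.set("spark.sql.shuffle.partitions", "200")
spark.conf.set("spark.sql.files.maxPartitionBytes", "134217728")  # 128MB

# Collect table and column statistics so the optimizer can estimate sizes
# (for example, when choosing broadcast joins); ANALYZE TABLE targets
# catalog tables such as the bucketed table saved earlier
spark.sql("ANALYZE TABLE bucketed_transactions COMPUTE STATISTICS")
spark.sql("ANALYZE TABLE bucketed_transactions COMPUTE STATISTICS FOR COLUMNS product_id")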
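
To see what spark.memory.fraction controls, the sketch below approximates Spark's unified memory model: 300MB of the executor heap is reserved, the remainder is split into a unified execution/storage pool (spark.memory.fraction, default 0.6) and user memory, and a storage share of the pool (spark.memory.storageFraction, default 0.5) is protected from eviction by execution. It ignores off-heap memory and container overhead, and the JVM reports a slightly smaller heap than the configured size, so the numbers are approximate.

```python
RESERVED_MB = 300  # fixed reserved system memory in Spark's unified memory manager


def unified_memory_mb(executor_heap_mb, memory_fraction=0.6, storage_fraction=0.5):
    """Approximate per-executor memory regions in MB (simplified model)."""
    usable = executor_heap_mb - RESERVED_MB
    unified = usable * memory_fraction          # shared by execution and storage
    storage_floor = unified * storage_fraction  # cached blocks here are not evicted by execution
    user = usable - unified                     # user data structures, UDF objects, etc.
    return unified, storage_floor, user


# The 4g executor with spark.memory.fraction=0.8 from the settings above
unified, storage, user = unified_memory_mb(4096, memory_fraction=0.8)
print(f"unified={unified:.1f} MB, storage floor={storage:.1f} MB, user={user:.1f} MB")
```

Raising the fraction from 0.6 to 0.8 enlarges the pool available for shuffles, joins, and caching, at the cost of memory for user objects, which matters for workloads with heavy Python or JVM-side UDF state.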

Use the Spark UI (default port 4040) to identify bottlenecks in stages, tasks, and storage. Profile execution plans with explain() and adjust partitioning strategies accordingly. For production workloads, always benchmark with representative data volumes—Spark’s performance characteristics change dramatically at scale.
