PySpark - Read from JDBC/Database


Key Insights

• PySpark’s JDBC connector enables distributed reading from relational databases with automatic partitioning across executors, but requires careful configuration of partition columns and bounds to avoid bottlenecks
• Connection tuning (fetch size, timeouts) and predicate pushdown are critical for performance: poorly configured JDBC reads can create single-threaded bottlenecks that negate Spark’s distributed processing advantages
• The numPartitions, partitionColumn, lowerBound, and upperBound parameters control parallelism; without them, Spark reads the entire table through a single connection

Basic JDBC Configuration

Reading from databases in PySpark requires the appropriate JDBC driver JAR and connection parameters. The driver must be available on the classpath before creating your SparkSession.

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("JDBC Read Example") \
    .config("spark.jars", "/path/to/postgresql-42.6.0.jar") \
    .getOrCreate()

# Basic connection properties
jdbc_url = "jdbc:postgresql://localhost:5432/mydb"
connection_properties = {
    "user": "dbuser",
    "password": "dbpassword",
    "driver": "org.postgresql.Driver"
}

df = spark.read.jdbc(
    url=jdbc_url,
    table="customers",
    properties=connection_properties
)

df.show()

For production environments, avoid hardcoding credentials. Use environment variables or secret management systems:

import os

connection_properties = {
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
    "driver": "org.postgresql.Driver"
}
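A slightly more defensive variant is sketched below. It assumes the same DB_USER and DB_PASSWORD environment variables as above; the helper name build_connection_properties is hypothetical and not part of PySpark. Failing fast on a missing variable avoids handing a None value to the JDBC layer, where the resulting error is usually harder to read.

```python
import os


def build_connection_properties(env=os.environ, driver="org.postgresql.Driver"):
    """Build JDBC connection properties from environment variables.

    Hypothetical helper: raises immediately if a credential is missing or empty.
    """
    missing = [name for name in ("DB_USER", "DB_PASSWORD") if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {
        "user": env["DB_USER"],
        "password": env["DB_PASSWORD"],
        "driver": driver,
    }


# Usage (requires both variables to be set in the environment):
# connection_properties = build_connection_properties()
```

The env parameter exists only so the function can be exercised with a plain dict in tests; in a job you would normally rely on the default.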

Parallel Reading with Partitioning

Without partitioning configuration, Spark reads the entire table through a single JDBC connection, creating a performance bottleneck. Proper partitioning distributes the read across multiple executors.

# Partitioned read using numeric column
df = spark.read.jdbc(
    url=jdbc_url,
    table="orders",
    column="order_id",  # Must be numeric
    lowerBound=1,
    upperBound=1000000,
    numPartitions=10,
    properties=connection_properties
)

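# Spark generates one query per partition, roughly like this
# (exact boundary values vary by Spark version):
# SELECT * FROM orders WHERE order_id < 100001 OR order_id IS NULL
# SELECT * FROM orders WHERE order_id >= 100001 AND order_id < 200001
# ... and so on, ending with
# SELECT * FROM orders WHERE order_id >= 900001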

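The lowerBound and upperBound don’t filter data; they only define the range used to compute the partition stride. Rows outside these bounds are still read: values below lowerBound land in the first partition (along with rows where the partition column is NULL), and values above upperBound land in the last. Bounds that don’t match the real data therefore produce skewed partitions. Choose values that reflect the actual data distribution: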

# Get actual bounds from the database
bounds_df = spark.read.jdbc(
    url=jdbc_url,
    table="(SELECT MIN(order_id) as min_id, MAX(order_id) as max_id FROM orders) as bounds",
    properties=connection_properties
)

bounds = bounds_df.collect()[0]

df = spark.read.jdbc(
    url=jdbc_url,
    table="orders",
    column="order_id",
    lowerBound=bounds["min_id"],
    upperBound=bounds["max_id"],
    numPartitions=20,
    properties=connection_properties
)
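To see why the bounds shape partitions rather than filter rows, here is a simplified pure-Python model of how the per-partition WHERE clauses are formed. It is a sketch under stated assumptions, not Spark's actual code: the function name is hypothetical, it uses plain integer stride arithmetic, and real Spark versions differ in rounding and edge-case handling.

```python
def partition_where_clauses(column, lower_bound, upper_bound, num_partitions):
    """Simplified model of the WHERE clauses for a partitioned JDBC range read.

    Assumption: integer stride arithmetic; Spark's real rounding differs by version.
    """
    # Never use more partitions than there are distinct values in the range
    num_partitions = min(num_partitions, upper_bound - lower_bound)
    if num_partitions <= 1:
        return [None]  # a single, unfiltered query

    stride = upper_bound // num_partitions - lower_bound // num_partitions
    clauses = []
    current = lower_bound
    for i in range(num_partitions):
        lower = f"{column} >= {current}" if i != 0 else None
        current += stride
        upper = f"{column} < {current}" if i != num_partitions - 1 else None
        if upper is None:
            # Last partition: no upper limit, so values above upper_bound land here
            clauses.append(lower)
        elif lower is None:
            # First partition: no lower limit, and NULLs are included
            clauses.append(f"{upper} OR {column} IS NULL")
        else:
            clauses.append(f"{lower} AND {upper}")
    return clauses


for clause in partition_where_clauses("order_id", 1, 1000000, 10):
    print(clause)
```

Because the first and last clauses are open-ended, every row is read exactly once regardless of the bounds; only the balance between partitions changes.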

Custom Partition Predicates

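The column, lowerBound, and upperBound arguments of jdbc() expect integer bounds. For non-numeric partition columns or complex partitioning logic, use custom predicates. Each predicate becomes one partition, so the predicates should not overlap and should together cover every row you want to read: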

# Partition by date ranges
predicates = [
    "order_date >= '2024-01-01' AND order_date < '2024-02-01'",
    "order_date >= '2024-02-01' AND order_date < '2024-03-01'",
    "order_date >= '2024-03-01' AND order_date < '2024-04-01'",
    "order_date >= '2024-04-01' AND order_date < '2024-05-01'"
]

df = spark.read.jdbc(
    url=jdbc_url,
    table="orders",
    predicates=predicates,
    properties=connection_properties
)

# Partition by categorical values
category_predicates = [
    "category = 'Electronics'",
    "category = 'Clothing'",
    "category = 'Books'",
    "category NOT IN ('Electronics', 'Clothing', 'Books') OR category IS NULL"  # Catch-all, including NULLs
]

df = spark.read.jdbc(
    url=jdbc_url,
    table="products",
    predicates=category_predicates,
    properties=connection_properties
)
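Hand-written date predicates are easy to get wrong (an overlapping or missing month means duplicated or dropped rows), so it can help to generate them. The following is a minimal sketch; month_range_predicates is a hypothetical helper, and it assumes the database accepts ISO date literals as in the example above.

```python
from datetime import date


def month_range_predicates(column, start, end):
    """Return one half-open predicate per calendar month covering [start, end)."""
    predicates = []
    current = start
    while current < end:
        # First day of the month after `current`
        if current.month == 12:
            next_month = date(current.year + 1, 1, 1)
        else:
            next_month = date(current.year, current.month + 1, 1)
        upper = min(next_month, end)
        predicates.append(
            f"{column} >= '{current.isoformat()}' AND {column} < '{upper.isoformat()}'"
        )
        current = upper
    return predicates


predicates = month_range_predicates("order_date", date(2024, 1, 1), date(2024, 5, 1))
for p in predicates:
    print(p)
```

The resulting list can be passed as the predicates argument of spark.read.jdbc. Rows outside the start and end dates are not read at all, which is the intended behavior here but worth keeping in mind.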

Using Subqueries and Filtering

Push filtering logic to the database level using SQL subqueries to reduce data transfer:

# Instead of reading entire table and filtering in Spark
query = """
(SELECT o.*, c.customer_name, c.region
 FROM orders o
 JOIN customers c ON o.customer_id = c.customer_id
 WHERE o.order_date >= '2024-01-01'
   AND o.status = 'completed'
   AND c.region = 'APAC') as filtered_orders
"""

df = spark.read.jdbc(
    url=jdbc_url,
    table=query,
    column="order_id",
    lowerBound=1,
    upperBound=1000000,
    numPartitions=10,
    properties=connection_properties
)

This approach leverages database indexes and reduces network transfer. The subquery must be aliased for JDBC compatibility.
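The aliasing requirement is easy to forget when queries are built dynamically. A small wrapper along these lines can enforce it; as_jdbc_subquery is a hypothetical name, and the sketch deliberately omits the AS keyword because some databases (Oracle, for example) do not accept it for table aliases.

```python
def as_jdbc_subquery(sql, alias):
    """Wrap a SELECT statement as '(<sql>) <alias>' for use as the JDBC table argument."""
    if not alias.isidentifier():
        raise ValueError(f"Alias must be a simple identifier, got {alias!r}")
    # Drop surrounding whitespace and a trailing semicolon, which is invalid inside a subquery
    body = sql.strip().rstrip(";").strip()
    return f"({body}) {alias}"


table_expr = as_jdbc_subquery(
    "SELECT order_id, status FROM orders WHERE status = 'completed';",
    alias="completed_orders",
)
print(table_expr)
```

The returned string is what you would pass as table= in spark.read.jdbc.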

Connection Pool Configuration

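Spark’s JDBC source does not pool connections itself: each partition opens its own connection, so numPartitions also caps the number of concurrent connections a read can open against the database. For production workloads, tune the connection properties so those connections are used efficiently: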

connection_properties = {
    "user": "dbuser",
    "password": "dbpassword",
    "driver": "org.postgresql.Driver",
    "fetchsize": "10000",  # Rows fetched per round trip
    "batchsize": "10000",  # For writes
    "queryTimeout": "300",  # Seconds
    "connectTimeout": "60"  # Seconds; pgJDBC property name, other drivers use different names
}

# PostgreSQL-specific optimizations
connection_properties.update({
    "ssl": "true",
    "sslmode": "require",
    "ApplicationName": "SparkJDBCReader"
})

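The fetchsize parameter is critical for both throughput and memory management. Default values vary by driver: a very small default (Oracle’s is 10 rows) multiplies network round trips, while a driver that buffers the whole result set by default can cause out-of-memory errors with large result sets.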
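As a rough way to reason about a fetchsize value, you can estimate how much row data an executor holds in fetch buffers at once. This is only back-of-envelope arithmetic under stated assumptions: the function name is hypothetical, the average row size is something you would have to measure yourself, and driver and JVM overhead are ignored.

```python
def estimate_fetch_buffer_bytes(fetchsize, avg_row_bytes, concurrent_tasks):
    """Rough estimate of bytes buffered per executor by in-flight JDBC fetches.

    Assumes each concurrently running task holds one batch of `fetchsize` rows.
    """
    return fetchsize * avg_row_bytes * concurrent_tasks


# Illustrative numbers only: 10,000 rows per fetch, ~500 bytes per row, 4 tasks per executor
estimate = estimate_fetch_buffer_bytes(10_000, 500, 4)
print(f"{estimate / 1_000_000:.0f} MB")
```

If the estimate is a large fraction of executor memory, a smaller fetchsize or fewer cores per executor is worth trying before raising memory limits.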

Database-Specific Examples

MySQL

mysql_properties = {
    "user": "root",
    "password": "password",
    "driver": "com.mysql.cj.jdbc.Driver",
    "fetchsize": "10000",
    "zeroDateTimeBehavior": "convertToNull",
    "useSSL": "false"
}

df = spark.read.jdbc(
    url="jdbc:mysql://localhost:3306/mydb",
    table="transactions",
    column="transaction_id",
    lowerBound=1,
    upperBound=5000000,
    numPartitions=16,
    properties=mysql_properties
)

SQL Server

sqlserver_properties = {
    "user": "sa",
    "password": "Password123",
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
    "encrypt": "true",
    "trustServerCertificate": "true"
}

df = spark.read.jdbc(
    url="jdbc:sqlserver://localhost:1433;databaseName=mydb",
    table="dbo.sales",
    column="sale_id",
    lowerBound=1,
    upperBound=10000000,
    numPartitions=20,
    properties=sqlserver_properties
)

Oracle

oracle_properties = {
    "user": "system",
    "password": "oracle",
    "driver": "oracle.jdbc.OracleDriver",  # oracle.jdbc.driver.OracleDriver is the deprecated legacy name
    "fetchsize": "5000"
}

df = spark.read.jdbc(
    url="jdbc:oracle:thin:@localhost:1521:ORCL",
    table="HR.EMPLOYEES",
    properties=oracle_properties
)

Handling Large Tables

For tables with billions of rows, implement incremental reads using timestamp columns:

from datetime import datetime, timedelta

# Read last 7 days of data
end_date = datetime.now()
start_date = end_date - timedelta(days=7)

query = f"""
(SELECT * FROM events 
 WHERE event_timestamp >= TIMESTAMP '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
   AND event_timestamp < TIMESTAMP '{end_date.strftime('%Y-%m-%d %H:%M:%S')}') as recent_events
"""

df = spark.read.jdbc(
    url=jdbc_url,
    table=query,
    column="event_id",
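    lowerBound=1,  # Ideally MIN(event_id) within the window
    upperBound=100000000,  # Ideally MAX(event_id) within the window; stale bounds cause skew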
    numPartitions=50,
    properties=connection_properties
)

# Cache if reusing
df.cache()
df.count()  # Trigger caching
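For repeated incremental loads it helps to pass the window in explicitly instead of calling datetime.now() inline, so a rerun reads exactly the same rows. Below is a minimal sketch of that idea; incremental_window_query is a hypothetical helper, it assumes the database accepts TIMESTAMP literals as in the query above, and table and column names are interpolated directly, so they must come from trusted code rather than user input.

```python
from datetime import datetime, timedelta


def incremental_window_query(table, ts_column, start, end, alias="recent_rows"):
    """Build an aliased subquery selecting rows with start <= ts_column < end."""
    fmt = "%Y-%m-%d %H:%M:%S"
    return (
        f"(SELECT * FROM {table} "
        f"WHERE {ts_column} >= TIMESTAMP '{start.strftime(fmt)}' "
        f"AND {ts_column} < TIMESTAMP '{end.strftime(fmt)}') {alias}"
    )


# A fixed 7-day window; the half-open range lets consecutive windows tile without overlap
window_start = datetime(2024, 1, 1)
window_end = window_start + timedelta(days=7)
query = incremental_window_query(
    "events", "event_timestamp", window_start, window_end, alias="recent_events"
)
print(query)
```

The next run would use window_end as its start, so no row is read twice and none is skipped.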

Performance Monitoring

Monitor partition distribution to identify skew:

from pyspark.sql.functions import spark_partition_id, count

# Check partition distribution
partition_stats = df.groupBy(spark_partition_id().alias("partition_id")) \
    .agg(count("*").alias("row_count")) \
    .orderBy("partition_id")

partition_stats.show()

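# Repartition if skewed. This shuffles the data after it has been read;
# it does not speed up the read itself, so fix the bounds or predicates where possible.
df_balanced = df.repartition(20)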

Check the Spark UI’s SQL tab, or call df.explain(), to verify predicate pushdown. Filters that Spark pushed down to the database appear under PushedFilters in the scan node of the physical plan.
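To turn the partition counts from the snippet above into a single number you can alert on, one option is the ratio of the largest partition to the mean partition size. This is a sketch of one possible metric, not a Spark API; skew_ratio is a hypothetical helper and the sample counts are made up for illustration.

```python
def skew_ratio(row_counts):
    """Largest partition size divided by the mean partition size (1.0 means perfectly even)."""
    if not row_counts:
        raise ValueError("row_counts must not be empty")
    mean = sum(row_counts) / len(row_counts)
    if mean == 0:
        return 1.0  # all partitions are empty, so there is nothing to skew
    return max(row_counts) / mean


# In a real job the counts would come from the DataFrame built earlier, e.g.:
# counts = [row["row_count"] for row in partition_stats.collect()]
counts = [100, 100, 100, 700]  # illustrative values only
print(skew_ratio(counts))
```

A ratio well above 1 means one task does most of the work, which usually points back to the bounds or predicates used for the read.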

Error Handling

Implement retry logic and connection validation:

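from py4j.protocol import Py4JJavaError
from pyspark.sql.utils import AnalysisException
import time

def read_with_retry(url, table, properties, max_retries=3):
    """Read a JDBC table, retrying with exponential backoff on failure."""
    for attempt in range(max_retries):
        try:
            df = spark.read.jdbc(
                url=url,
                table=table,
                properties=properties
            )
            # Validate the connection by triggering a small action
            df.limit(1).collect()
            return df
        except (Py4JJavaError, AnalysisException):
            # JDBC driver errors (e.g. java.sql.SQLException) typically surface
            # in PySpark as Py4JJavaError rather than AnalysisException
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt  # 1s, 2s, 4s, ...
                print(f"Attempt {attempt + 1} failed. Retrying in {wait_time}s...")
                time.sleep(wait_time)
            else:
                raise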

df = read_with_retry(jdbc_url, "customers", connection_properties)

Proper JDBC configuration transforms PySpark from a slow, single-threaded database reader into a distributed processing powerhouse. The key is understanding partition mechanics and pushing computation to the database where possible.
