PySpark - Write to JDBC/Database

Key Insights

• PySpark's JDBC writer supports four save modes (append, overwrite, error, ignore) and gives fine-grained control over partitioning and batch size for good database performance.
• Capping concurrent connections, choosing a sensible partition count, and tuning batch size are critical for writing large datasets efficiently; poor configuration can exhaust database connections or slow writes badly.
• Setting the partition count (or the numPartitions option), batchsize, and isolationLevel deliberately can improve write throughput by 10x or more over the defaults in some workloads; measure the gain on your own database.

Understanding PySpark JDBC Write Mechanics

PySpark writes data to JDBC-compatible databases by distributing the workload across executors. Each task opens its own database connection and writes one partition of the DataFrame, so every partition running at the same time holds one open connection. This parallel approach can dramatically speed up bulk inserts but requires careful configuration to avoid overwhelming the target database.

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("JDBC Write Example") \
    .config("spark.jars", "/path/to/postgresql-42.6.0.jar") \
    .getOrCreate()

# Sample DataFrame
data = [
    (1, "Alice", 28, "Engineering"),
    (2, "Bob", 35, "Marketing"),
    (3, "Charlie", 42, "Sales")
]
columns = ["id", "name", "age", "department"]
df = spark.createDataFrame(data, columns)

# Basic JDBC write
jdbc_url = "jdbc:postgresql://localhost:5432/mydb"
connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver"
}

df.write \
    .jdbc(url=jdbc_url, table="employees", mode="append", properties=connection_properties)
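The same write can also go through the generic `format("jdbc")` writer, where `dbtable` names the target table and the credentials become options. A minimal sketch: `jdbc_options` and `write_jdbc` are hypothetical helper names, and the connection values in the usage comment are the placeholders used above.

```python
def jdbc_options(url, table, user, password, driver, **extra):
    """Build the option map for Spark's JDBC data source (hypothetical helper)."""
    options = {
        "url": url,
        "dbtable": table,  # Target table name
        "user": user,
        "password": password,
        "driver": driver,
    }
    # Extra tuning options such as batchsize; PySpark converts values to strings
    options.update(extra)
    return options


def write_jdbc(df, options, mode="append"):
    """Write a PySpark DataFrame through the generic JDBC writer."""
    df.write.format("jdbc").options(**options).mode(mode).save()


# Usage with the placeholder values from the example above:
# opts = jdbc_options("jdbc:postgresql://localhost:5432/mydb", "employees",
#                     "postgres", "secret", "org.postgresql.Driver", batchsize=10000)
# write_jdbc(df, opts)
```

Both forms end up in the same JDBC data source; the options form is convenient when options are assembled programmatically.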

Write Modes and Their Use Cases

PySpark provides four save modes that determine what happens when the target table already exists ("error" can also be written "errorifexists").

# Append: Add new records to existing table
df.write.mode("append").jdbc(jdbc_url, "employees", properties=connection_properties)

# Overwrite: Drop and recreate the table
df.write.mode("overwrite").jdbc(jdbc_url, "employees", properties=connection_properties)

# Error (default): Throw exception if table exists
df.write.mode("error").jdbc(jdbc_url, "employees", properties=connection_properties)

# Ignore: Do nothing if table exists
df.write.mode("ignore").jdbc(jdbc_url, "employees", properties=connection_properties)

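By default, overwrite mode issues DROP TABLE followed by CREATE TABLE generated from the DataFrame schema, which means you lose indexes, constraints, and grants, and column types may change. For production systems, consider truncating the table separately and then writing in append mode (the two steps are not atomic, so a failed write leaves the table empty or partially loaded):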

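# Better approach for overwrite scenarios (PostgreSQL, requires psycopg2)
import psycopg2
from psycopg2 import sql

def safe_overwrite(df, jdbc_url, table_name, properties, host="localhost", database="mydb"):
    # Truncate through a separate Python DB-API connection (not JDBC);
    # host and database must point at the same database as jdbc_url
    conn = psycopg2.connect(
        host=host,
        dbname=database,
        user=properties["user"],
        password=properties["password"]
    )
    try:
        with conn.cursor() as cursor:
            # Quote the (unqualified) table name instead of formatting it into the SQL string
            cursor.execute(sql.SQL("TRUNCATE TABLE {}").format(sql.Identifier(table_name)))
        conn.commit()
    finally:
        conn.close()

    # Append into the emptied table; indexes, constraints, and grants are kept
    df.write.mode("append").jdbc(jdbc_url, table_name, properties=properties)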
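Spark's JDBC writer also has a built-in `truncate` option: in overwrite mode it issues TRUNCATE TABLE instead of dropping the table, so indexes and grants survive. Support depends on the JDBC dialect (the option is ignored for unsupported ones), and it does not work when the new data has a different schema. A minimal sketch; `truncate_overwrite_options` and `overwrite_keeping_table` are hypothetical helpers.

```python
def truncate_overwrite_options(url, table, properties):
    """JDBC options for an overwrite that truncates instead of dropping (hypothetical helper)."""
    options = dict(properties)  # Copy so the caller's dict is untouched
    options.update({
        "url": url,
        "dbtable": table,
        # Only takes effect together with mode("overwrite")
        "truncate": "true",
    })
    return options


def overwrite_keeping_table(df, url, table, properties):
    """Overwrite a PySpark DataFrame's target table without DROP/CREATE where supported."""
    df.write.format("jdbc") \
        .options(**truncate_overwrite_options(url, table, properties)) \
        .mode("overwrite") \
        .save()
```

Like the manual approach, this is not atomic: the table is truncated before the new rows are written.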

Optimizing Write Performance with Partitioning

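The number of DataFrame partitions determines how many write tasks run, and each running task holds one connection, so concurrency is bounded by both the partition count and the executor cores available to the job. Too few partitions create bottlenecks; too many exhaust the database's connection limit. The numPartitions write option sets an upper bound: if the DataFrame has more partitions, Spark coalesces it down to that number before writing.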

# Repartition before writing to control parallelism (triggers a full shuffle)
optimal_partitions = 8  # Adjust to the connections your database can spare

df.repartition(optimal_partitions) \
    .write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)

# To only reduce the partition count, coalesce avoids a full shuffle
# (it cannot increase partitions, and the resulting partitions may be uneven)
df.coalesce(optimal_partitions) \
    .write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
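Spark's numPartitions write option caps the number of write partitions, and with it the number of connections. A sketch with two hypothetical helpers: `capped_options` builds the options, and `effective_write_connections` only models the documented rule (the write uses the smaller of the DataFrame's partition count and numPartitions); it ignores the additional limit set by available executor cores.

```python
def capped_options(url, table, properties, max_connections):
    """Hypothetical helper: JDBC write options with a connection cap."""
    if max_connections < 1:
        raise ValueError("max_connections must be at least 1")
    options = dict(properties)  # Copy so the caller's properties are untouched
    options.update({"url": url, "dbtable": table, "numPartitions": str(max_connections)})
    return options


def effective_write_connections(dataframe_partitions, num_partitions_option):
    """Model of the cap: Spark coalesces down to numPartitions but never splits up."""
    return min(dataframe_partitions, num_partitions_option)


# Usage (df, jdbc_url and connection_properties as defined earlier):
# df.write.format("jdbc") \
#     .options(**capped_options(jdbc_url, "employees", connection_properties, 20)) \
#     .mode("append").save()

print(effective_write_connections(200, 20))  # 20: coalesced down to the cap
print(effective_write_connections(4, 20))    # 4: the cap never adds partitions
```

Unlike an explicit repartition, the cap never increases parallelism, so it is a safety limit rather than a tuning knob.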

Derive the partition count from your database's connection limit and the number of Spark jobs writing to it at the same time:

# If your database allows 100 connections and you run 5 concurrent jobs:
# at most 100 / 5 = 20 partitions per job (fewer if other clients also need connections)
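The same arithmetic as a function, with an optional reserve for other clients. `partitions_per_job` and its `reserved` parameter are hypothetical; how many connections to hold back is your call.

```python
def partitions_per_job(max_connections, concurrent_jobs, reserved=0):
    """Hypothetical helper: share a database's connection limit across Spark jobs."""
    if concurrent_jobs < 1:
        raise ValueError("concurrent_jobs must be at least 1")
    # Keep some connections back for applications, admin sessions and monitoring
    available = max_connections - reserved
    if available < concurrent_jobs:
        raise ValueError("not enough connections for one partition per job")
    return available // concurrent_jobs


print(partitions_per_job(100, 5))               # 20, the figure from the comment above
print(partitions_per_job(100, 5, reserved=10))  # 18
```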

Batch Size Tuning for Throughput

The batchsize option sets how many rows each write task sends to the database in one JDBC batch. Spark's default is 1000, which is often suboptimal for bulk loads.

# Configure batch size in connection properties
connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "batchsize": "10000"  # Increase for better throughput
}

df.write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
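To see why batch size matters, count round trips: each write task calls executeBatch() once per batchsize rows, plus once for any remainder. A back-of-the-envelope sketch, assuming rows are spread evenly across partitions; `batches_per_partition` is a hypothetical helper and the row counts are illustrative, not measurements.

```python
import math


def batches_per_partition(total_rows, num_partitions, batchsize):
    """Approximate JDBC batches (round trips) per write task, assuming even partitions."""
    rows_per_partition = math.ceil(total_rows / num_partitions)
    return math.ceil(rows_per_partition / batchsize)


# 10 million rows written through 20 partitions
for size in (1000, 10000):
    print(size, batches_per_partition(10_000_000, 20, size))
# 1000 500
# 10000 50
```

Driver flags such as reWriteBatchedInserts (PostgreSQL) or rewriteBatchedStatements (MySQL) change how each batch is sent on the wire, so the actual gain still has to be measured.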

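Test different batch sizes to find the sweet spot for your database. Use a DataFrame of realistic size (the three-row sample above is too small to show a difference), and cache it so each run times only the write:

import time

batch_sizes = [1000, 5000, 10000, 20000, 50000]

# Materialize the DataFrame once so each timing measures only the JDBC write
df.cache()
df.count()

for batch_size in batch_sizes:
    props = connection_properties.copy()
    props["batchsize"] = str(batch_size)

    start = time.perf_counter()
    # Each run writes to its own scratch table; drop the test_batch_* tables afterwards
    df.write.jdbc(jdbc_url, f"test_batch_{batch_size}", mode="overwrite", properties=props)
    duration = time.perf_counter() - start

    print(f"Batch size {batch_size}: {duration:.2f} seconds")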

Handling Transactions and Isolation Levels

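Control transaction behavior with the isolationLevel property. It sets the isolation level of the transaction each write task opens for its partition (Spark's default is READ_UNCOMMITTED). Because every partition commits separately, a failed job can leave some partitions committed and others not.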

connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "isolationLevel": "READ_COMMITTED"  # Options: NONE, READ_UNCOMMITTED, READ_COMMITTED, REPEATABLE_READ, SERIALIZABLE
}

df.write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)

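Setting NONE tells Spark not to wrap each partition in a transaction, so batches are committed as they execute under the connection's default auto-commit. Whether that is faster depends on the database, and a failed or retried task can leave partial or duplicate rows, so reserve it for loads you can clean up and rerun:

# No per-partition transaction for bulk loads (use with caution)
connection_properties["isolationLevel"] = "NONE"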

Database-Specific Optimizations

Different databases require different JDBC drivers and optimization strategies.

PostgreSQL:

postgres_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "batchsize": "10000",
    "reWriteBatchedInserts": "true",  # PostgreSQL-specific optimization
    "stringtype": "unspecified"  # Let the server infer the type of string parameters (e.g. for uuid, json, or enum columns)
}

df.write.jdbc(
    "jdbc:postgresql://localhost:5432/mydb",
    "employees",
    mode="append",
    properties=postgres_properties
)

MySQL:

mysql_properties = {
    "user": "root",
    "password": "secret",
    "driver": "com.mysql.cj.jdbc.Driver",
    "batchsize": "10000",
    "rewriteBatchedStatements": "true",  # MySQL batch optimization
    "useServerPrepStmts": "false"  # Client-side prepared statements (already the Connector/J default)
}

df.write.jdbc(
    "jdbc:mysql://localhost:3306/mydb",
    "employees",
    mode="append",
    properties=mysql_properties
)

SQL Server:

sqlserver_properties = {
    "user": "sa",
    "password": "secret",
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
    "batchsize": "10000",
    "useBulkCopyForBatchInsert": "true"  # Route batched inserts through the Bulk Copy API
}

df.write.jdbc(
    "jdbc:sqlserver://localhost:1433;databaseName=mydb",
    "employees",
    mode="append",
    properties=sqlserver_properties
)
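The three property sets above differ mainly in one driver flag each. A sketch of a hypothetical `tuned_properties` helper that adds the batch size and the matching flag based on the JDBC URL prefix; it covers only the flags shown in this section, and any value already present in the base properties takes precedence.

```python
# Driver-specific batching flags from the examples above, keyed by JDBC URL prefix
DRIVER_TUNING = {
    "jdbc:postgresql:": {"reWriteBatchedInserts": "true"},
    "jdbc:mysql:": {"rewriteBatchedStatements": "true"},
    "jdbc:sqlserver:": {"useBulkCopyForBatchInsert": "true"},
}


def tuned_properties(jdbc_url, base_properties, batchsize=10000):
    """Hypothetical helper: add batchsize and driver flags for the database in jdbc_url."""
    props = dict(base_properties)  # Copy so the caller's dict is untouched
    props.setdefault("batchsize", str(batchsize))
    for prefix, flags in DRIVER_TUNING.items():
        if jdbc_url.startswith(prefix):
            for key, value in flags.items():
                props.setdefault(key, value)  # Explicit user settings win
            break
    return props


# Usage:
# df.write.jdbc(url, "employees", mode="append", properties=tuned_properties(url, base))
```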

Creating Tables with Custom DDL

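Control table creation with the createTableColumnTypes and createTableOptions properties. createTableColumnTypes overrides the column types Spark would otherwise choose (they must be valid Spark SQL types, written in CREATE TABLE column syntax), and createTableOptions appends database-specific clauses after the column list. Neither can declare column constraints such as primary keys, and both apply only when Spark creates the table.

connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "createTableColumnTypes": "id INTEGER, name VARCHAR(100), age INTEGER, department VARCHAR(50)"
}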

df.write.jdbc(jdbc_url, "employees", mode="overwrite", properties=connection_properties)
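Because createTableColumnTypes is a single comma-separated string, it can be convenient to build it from a mapping. A minimal sketch; `column_types_option` and `create_table_properties` are hypothetical helpers, and "ENGINE=InnoDB" is a MySQL example of a clause that createTableOptions would append after the column list.

```python
def column_types_option(column_types):
    """Render {"name": "VARCHAR(100)"} as a createTableColumnTypes string."""
    return ", ".join(f"{name} {sql_type}" for name, sql_type in column_types.items())


def create_table_properties(base_properties, column_types, table_options=None):
    """Hypothetical helper: add table-creation options to JDBC connection properties."""
    props = dict(base_properties)  # Copy so the caller's dict is untouched
    props["createTableColumnTypes"] = column_types_option(column_types)
    if table_options:
        # Appended verbatim after the column list, e.g. "ENGINE=InnoDB" on MySQL
        props["createTableOptions"] = table_options
    return props


print(column_types_option({"id": "INTEGER", "name": "VARCHAR(100)"}))
# id INTEGER, name VARCHAR(100)
```

Every column named in the mapping must exist in the DataFrame; columns left out keep Spark's default type mapping.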

For partitioned tables or advanced constraints, create the table manually first:

import psycopg2

conn = psycopg2.connect(host="localhost", database="mydb", user="postgres", password="secret")
cursor = conn.cursor()

cursor.execute("""
    CREATE TABLE IF NOT EXISTS employees (
        id INTEGER PRIMARY KEY,
        name VARCHAR(100) NOT NULL,
        age INTEGER CHECK (age > 0),
        department VARCHAR(50),
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )
""")

conn.commit()
cursor.close()
conn.close()

# Write to existing table
df.write.mode("append").jdbc(jdbc_url, "employees", properties=connection_properties)

Error Handling and Monitoring

Implement robust error handling to catch connection failures and constraint violations:

from pyspark.sql.utils import AnalysisException
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Count once up front; cache df first if recomputing its lineage is expensive
record_count = df.count()

try:
    df.write \
        .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
    logger.info(f"Successfully wrote {record_count} records")
except AnalysisException as e:
    # Raised on the driver before any rows are sent,
    # e.g. the table already exists and the mode is "error"
    logger.error(f"Analysis error: {e}")
    raise
except Exception as e:
    # Connection errors and constraint violations usually arrive as a Py4J-wrapped
    # Java exception; some partitions may already be committed at this point
    logger.error(f"Write failed: {e}")
    raise
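Transient failures such as dropped connections can be retried, but only when the write is idempotent: retrying an append can re-insert rows from partitions that were already committed. A sketch of a hypothetical `run_with_retries` wrapper with exponential backoff; the staging-table usage below is an assumption about how you might make the write idempotent, not something Spark does for you.

```python
import logging
import time

logger = logging.getLogger(__name__)


def run_with_retries(write, attempts=3, base_delay=2.0):
    """Call write() until it succeeds, sleeping base_delay * 2**n between attempts.

    Only safe for idempotent writes (e.g. an overwrite of a staging table).
    """
    for attempt in range(1, attempts + 1):
        try:
            return write()
        except Exception as exc:
            if attempt == attempts:
                raise  # Out of attempts: surface the original error
            delay = base_delay * 2 ** (attempt - 1)
            logger.warning("Write attempt %d failed (%s); retrying in %.1fs", attempt, exc, delay)
            time.sleep(delay)


# Usage ("employees_staging" is a hypothetical staging table):
# run_with_retries(lambda: df.write.jdbc(jdbc_url, "employees_staging",
#                                        mode="overwrite", properties=connection_properties))
```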

Monitor progress on large datasets by writing in slices and logging after each one. The slices are written one after another, so this trades some throughput for visibility and gives you a natural restart point:

# Cache so each slice does not recompute the whole DataFrame
df.cache()

for i in range(optimal_partitions):
    # pmod keeps negative ids in a valid slice; Spark's % can return negative values
    slice_df = df.filter(f"pmod(id, {optimal_partitions}) = {i}")
    slice_df.write.jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
    logger.info(f"Wrote slice {i+1}/{optimal_partitions}")
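To resume after a failure without rewriting finished slices, the loop can take a starting index. A sketch with hypothetical helpers `slice_predicates` and `write_slices`, assuming an integer `id` column; `write_slice` is whatever function writes one filtered DataFrame, for example an append via `.write.jdbc(...)`.

```python
def slice_predicates(column, num_slices):
    """Spark SQL filters that split rows into num_slices disjoint groups.

    pmod() keeps negative values in range, unlike Spark's % operator.
    """
    return [f"pmod({column}, {num_slices}) = {i}" for i in range(num_slices)]


def write_slices(df, predicates, write_slice, start=0, log=print):
    """Write each filtered slice in order; pass start to resume after a failure."""
    for i in range(start, len(predicates)):
        write_slice(df.filter(predicates[i]))
        log(f"Wrote slice {i + 1}/{len(predicates)}")


# Usage: resume at slice 4 (0-based) after slices 0-3 were written successfully
# write_slices(df, slice_predicates("id", 8),
#              lambda part: part.write.jdbc(jdbc_url, "employees", mode="append",
#                                           properties=connection_properties),
#              start=4)
```

A slice that failed partway may already have committed some partitions, so clean up its rows before resuming from that slice.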
