PySpark - Write to JDBC/Database
Key Insights
• PySpark’s JDBC writer supports multiple write modes (append, overwrite, error, ignore) and allows fine-grained control over partitioning and batch size for optimal database performance
• Connection pooling, proper partitioning strategy, and batch size tuning are critical for writing large datasets efficiently—poor configuration can lead to connection exhaustion or slow writes
• Using numPartitions, batchsize, and isolationLevel parameters correctly can improve write throughput by 10x or more compared to default settings
Understanding PySpark JDBC Write Mechanics
PySpark writes data to JDBC-compatible databases by distributing the workload across executors. Each executor opens its own database connection and writes a partition of the DataFrame. This parallel approach can dramatically speed up bulk inserts but requires careful configuration to avoid overwhelming the target database.
from pyspark.sql import SparkSession
spark = SparkSession.builder \
    .appName("JDBC Write Example") \
    .config("spark.jars", "/path/to/postgresql-42.6.0.jar") \
    .getOrCreate()
# Sample DataFrame
data = [
    (1, "Alice", 28, "Engineering"),
    (2, "Bob", 35, "Marketing"),
    (3, "Charlie", 42, "Sales")
]
columns = ["id", "name", "age", "department"]
df = spark.createDataFrame(data, columns)
# Basic JDBC write
jdbc_url = "jdbc:postgresql://localhost:5432/mydb"
connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver"
}
df.write \
    .jdbc(url=jdbc_url, table="employees", mode="append", properties=connection_properties)
Write Modes and Their Use Cases
PySpark provides four write modes that determine how to handle existing data in the target table.
# Append: Add new records to existing table
df.write.mode("append").jdbc(jdbc_url, "employees", properties=connection_properties)
# Overwrite: Drop and recreate the table
df.write.mode("overwrite").jdbc(jdbc_url, "employees", properties=connection_properties)
# Error (default): Throw exception if table exists
df.write.mode("error").jdbc(jdbc_url, "employees", properties=connection_properties)
# Ignore: Do nothing if table exists
df.write.mode("ignore").jdbc(jdbc_url, "employees", properties=connection_properties)
The overwrite mode issues a DROP TABLE followed by CREATE TABLE, which means you lose indexes, constraints, and permissions. Setting the truncate property to "true" makes overwrite issue TRUNCATE TABLE instead, preserving the table definition on databases that support it. Alternatively, use append mode with a manual truncate operation:
# Better approach for overwrite scenarios
import psycopg2

def safe_overwrite(df, jdbc_url, table_name, properties):
    # Truncate using a separate database connection so indexes,
    # constraints, and permissions on the table are preserved
    conn = psycopg2.connect(
        host="localhost",
        database="mydb",
        user=properties["user"],
        password=properties["password"]
    )
    cursor = conn.cursor()
    cursor.execute(f"TRUNCATE TABLE {table_name}")
    conn.commit()
    cursor.close()
    conn.close()
    # Write data into the now-empty table
    df.write.mode("append").jdbc(jdbc_url, table_name, properties=properties)
Optimizing Write Performance with Partitioning
The numPartitions parameter controls how many parallel connections PySpark opens to the database. Too few partitions create bottlenecks; too many exhaust database connections.
# Repartition before writing for optimal parallelism
optimal_partitions = 8  # Adjust based on database connection pool size
df.repartition(optimal_partitions) \
    .write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
# If the DataFrame already has more partitions than needed,
# coalesce reduces the count without a full shuffle
df.coalesce(optimal_partitions) \
    .write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
Calculate optimal partitions based on your database’s max connections and concurrent Spark jobs:
# If your database allows 100 connections and you run 5 concurrent jobs:
# optimal_partitions = 100 / 5 = 20 partitions per job
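This rule of thumb can be wrapped in a small helper. The headroom factor below is an assumption of this sketch, not a Spark or database setting; it simply reserves a fraction of the connection pool for other clients.

```python
def partitions_per_job(max_connections, concurrent_jobs, headroom=0.8):
    """Rough estimate of how many JDBC partitions each Spark job can safely use.

    headroom reserves a fraction of connections for other clients;
    the 0.8 default is an illustrative assumption -- tune it for your database.
    """
    usable = int(max_connections * headroom)
    return max(1, usable // concurrent_jobs)

# 100 max connections, 5 concurrent jobs, 20% headroom -> 16 partitions per job
print(partitions_per_job(100, 5))
```

Pass the result to repartition() or coalesce() before the write, as shown above.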
Batch Size Tuning for Throughput
The batchsize parameter determines how many rows are inserted in a single batch. Default is typically 1000, but this is often suboptimal.
# Configure batch size in connection properties
connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "batchsize": "10000"  # Increase for better throughput
}
df.write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
Test different batch sizes to find the sweet spot for your database:
import time

batch_sizes = [1000, 5000, 10000, 20000, 50000]
for batch_size in batch_sizes:
    props = connection_properties.copy()
    props["batchsize"] = str(batch_size)
    start = time.time()
    df.write.jdbc(jdbc_url, f"test_batch_{batch_size}", mode="overwrite", properties=props)
    duration = time.time() - start
    print(f"Batch size {batch_size}: {duration:.2f} seconds")
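The timings from a loop like this can be collected into a dictionary and the fastest setting picked programmatically. A minimal helper (the sample timings are illustrative, not measured):

```python
def best_batch_size(timings):
    """Given a dict of {batch_size: seconds}, return the fastest batch size."""
    return min(timings, key=timings.get)

# Illustrative timings from a benchmark run
results = {1000: 42.1, 10000: 11.9, 50000: 14.3}
print(best_batch_size(results))  # -> 10000
```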
Handling Transactions and Isolation Levels
Control transaction behavior using the isolationLevel property (Spark's default is READ_UNCOMMITTED). This affects how concurrent reads see uncommitted data during writes.
connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "isolationLevel": "READ_COMMITTED"  # Options: NONE, READ_UNCOMMITTED, READ_COMMITTED, REPEATABLE_READ, SERIALIZABLE
}
df.write \
    .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
For databases supporting it, use NONE for maximum write performance when isolation isn’t critical:
# Maximum performance for bulk loads (use with caution)
connection_properties["isolationLevel"] = "NONE"
Database-Specific Optimizations
Different databases require different JDBC drivers and optimization strategies.
PostgreSQL:
postgres_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "batchsize": "10000",
    "reWriteBatchedInserts": "true",  # PostgreSQL-specific optimization
    "stringtype": "unspecified"  # Helps with VARCHAR vs TEXT
}
df.write.jdbc(
    "jdbc:postgresql://localhost:5432/mydb",
    "employees",
    mode="append",
    properties=postgres_properties
)
MySQL:
mysql_properties = {
    "user": "root",
    "password": "secret",
    "driver": "com.mysql.cj.jdbc.Driver",
    "batchsize": "10000",
    "rewriteBatchedStatements": "true",  # MySQL batch optimization
    "useServerPrepStmts": "false"  # Faster for bulk inserts
}
df.write.jdbc(
    "jdbc:mysql://localhost:3306/mydb",
    "employees",
    mode="append",
    properties=mysql_properties
)
SQL Server:
sqlserver_properties = {
    "user": "sa",
    "password": "secret",
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
    "batchsize": "10000",
    "useBulkCopyForBatchInsert": "true"  # Use the SQL Server bulk copy API (requires a recent mssql-jdbc driver)
}
df.write.jdbc(
    "jdbc:sqlserver://localhost:1433;databaseName=mydb",
    "employees",
    mode="append",
    properties=sqlserver_properties
)
Creating Tables with Custom DDL
Control table creation by specifying column data types with the createTableColumnTypes property and table-level options (such as a storage engine or tablespace clause) with createTableOptions. Note that createTableColumnTypes accepts column types only; constraints such as PRIMARY KEY are not supported there and require creating the table manually.
connection_properties = {
    "user": "postgres",
    "password": "secret",
    "driver": "org.postgresql.Driver",
    "createTableColumnTypes": "id INTEGER, name VARCHAR(100), age INTEGER, department VARCHAR(50)"
}
df.write.jdbc(jdbc_url, "employees", mode="overwrite", properties=connection_properties)
For partitioned tables or advanced constraints, create the table manually first:
import psycopg2

conn = psycopg2.connect(host="localhost", database="mydb", user="postgres", password="secret")
cursor = conn.cursor()
cursor.execute("""
    CREATE TABLE IF NOT EXISTS employees (
        id INTEGER PRIMARY KEY,
        name VARCHAR(100) NOT NULL,
        age INTEGER CHECK (age > 0),
        department VARCHAR(50),
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )
""")
conn.commit()
cursor.close()
conn.close()
# Write to existing table
df.write.mode("append").jdbc(jdbc_url, "employees", properties=connection_properties)
Error Handling and Monitoring
Implement robust error handling to catch connection failures and constraint violations:
from pyspark.sql.utils import AnalysisException
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

record_count = df.count()  # Count before the write; note this triggers a separate Spark job
try:
    df.write \
        .jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
    logger.info(f"Successfully wrote {record_count} records")
except AnalysisException as e:
    logger.error(f"Analysis error: {e}")
    # Handle schema mismatches, missing tables, etc.
except Exception as e:
    logger.error(f"Write failed: {e}")
    # Handle connection errors, constraint violations, etc.
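Transient connection failures are common when many executors hit the database at once. A generic retry wrapper with exponential backoff can be put around the write call; this is a sketch, not part of the PySpark API, and the attempt counts and delays are illustrative defaults.

```python
import time
import logging

logger = logging.getLogger(__name__)

def write_with_retry(write_fn, max_attempts=3, base_delay=1.0):
    """Call write_fn(), retrying with exponential backoff on failure.

    write_fn is any zero-argument callable, e.g.
    lambda: df.write.jdbc(jdbc_url, "employees", mode="append",
                          properties=connection_properties)
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return write_fn()
        except Exception as e:
            if attempt == max_attempts:
                raise
            delay = base_delay * 2 ** (attempt - 1)
            logger.warning(f"Attempt {attempt} failed ({e}); retrying in {delay:.1f}s")
            time.sleep(delay)
```

Because Spark's JDBC append is not idempotent, pair a retry like this with a staging table or a de-duplication step so a partially committed attempt does not produce duplicate rows.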
Monitor write progress for large datasets by adding intermediate checkpoints:
total_records = df.count()
logger.info(f"Writing {total_records} records in {optimal_partitions} slices")
for i in range(optimal_partitions):
    # Assumes a numeric id column to slice the DataFrame deterministically
    partition_df = df.filter(f"id % {optimal_partitions} = {i}")
    partition_df.write.jdbc(jdbc_url, "employees", mode="append", properties=connection_properties)
    logger.info(f"Wrote partition {i+1}/{optimal_partitions}")