PySpark - Read CSV with Custom Schema
Key Insights
• Defining custom schemas in PySpark eliminates costly schema inference and prevents data type mismatches that cause runtime failures in production pipelines
• StructType and StructField let you declare column types, nullability, and metadata explicitly, so type problems surface at the ingestion layer
• Schema enforcement catches data quality issues early and skips the extra pass over the data that inferSchema=True requires, a saving that grows with dataset size
Why Custom Schemas Matter
With inferSchema=True, PySpark’s CSV reader makes an extra pass over the data to guess each column’s type; by default it examines every row (the samplingRatio option lowers that fraction). For a 10GB CSV file, this means reading the data twice: once for inference, once for actual processing. Without inferSchema, every column is read as a string. Custom schemas eliminate this overhead while giving you explicit control over data types and null handling.
Schema mismatches are a common cause of Spark job failures in production. A column containing “N/A” strings cannot be parsed as a number: depending on the read mode, the value silently becomes null or the job fails. Dates in inconsistent formats cause the same kind of parsing errors. Custom schemas let you handle these edge cases upfront rather than discovering them after hours of processing.
Basic Schema Definition
The simplest approach uses StructType with StructField definitions:
```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

spark = SparkSession.builder \
    .appName("CustomSchemaExample") \
    .getOrCreate()

schema = StructType([
    StructField("customer_id", IntegerType(), nullable=False),
    StructField("name", StringType(), nullable=False),
    StructField("email", StringType(), nullable=True),
    StructField("age", IntegerType(), nullable=True),
    StructField("purchase_amount", DoubleType(), nullable=True)
])

df = spark.read \
    .format("csv") \
    .option("header", "true") \
    .schema(schema) \
    .load("s3://bucket/customers.csv")

df.printSchema()
```
Declaring customer_id and name as nullable=False documents intent, but Spark does not enforce it when reading CSV files: file-based sources treat every column as nullable, so null values in these columns load as null instead of failing the read. Check required columns explicitly in your data quality checks.
Handling Complex Data Types
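Real-world datasets contain dates, timestamps, decimals, and booleans. CSV is a flat text format, so the CSV reader rejects array, map, and struct types in the schema; read such fields as strings and parse them after loading:
```python
from pyspark.sql.functions import col, split
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType,
    DateType, TimestampType, DecimalType, BooleanType
)

transaction_schema = StructType([
    StructField("transaction_id", StringType(), nullable=False),
    StructField("customer_id", IntegerType(), nullable=False),
    StructField("transaction_date", DateType(), nullable=False),
    StructField("created_at", TimestampType(), nullable=False),
    StructField("amount", DecimalType(10, 2), nullable=False),
    StructField("is_refund", BooleanType(), nullable=False),
    # CSV cannot hold ArrayType columns, so read tags as a plain string
    StructField("tags", StringType(), nullable=True)
])

df = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("dateFormat", "yyyy-MM-dd") \
    .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") \
    .schema(transaction_schema) \
    .load("transactions.csv")

# Convert tags to array<string> after loading; assumes the tags inside the
# field are separated by "|" (split takes a regex, so the pipe is escaped)
df = df.withColumn("tags", split(col("tags"), "\\|"))
```
DecimalType(10, 2) stores values exactly, with up to 10 digits of which 2 follow the decimal point, so monetary values keep their precision. This is critical for financial calculations where floating-point errors are unacceptable. The dateFormat and timestampFormat options tell Spark how to parse DateType and TimestampType columns; set them whenever your data does not match Spark’s default patterns.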
Schema with Mode Enforcement
The mode option controls how PySpark handles malformed records:
```python
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("value", DoubleType(), nullable=False),
    StructField("category", StringType(), nullable=False)
])

# PERMISSIVE (default): keeps every row; fields that fail to parse become null
df_permissive = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("mode", "PERMISSIVE") \
    .schema(schema) \
    .load("data.csv")

# DROPMALFORMED: silently drops rows that fail to parse
df_drop = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("mode", "DROPMALFORMED") \
    .schema(schema) \
    .load("data.csv")

# FAILFAST: raises an exception at the first malformed record
# (reads are lazy, so the error surfaces when an action runs)
df_failfast = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("mode", "FAILFAST") \
    .schema(schema) \
    .load("data.csv")
```
Use FAILFAST during development to catch schema issues immediately. In production, PERMISSIVE mode with a corrupt-record column gives better observability, because bad rows are kept and can be inspected:
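```python
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, StringType

# Build a new schema; schema.add() would mutate `schema` in place
schema_with_corrupt = StructType(
    schema.fields + [StructField("_corrupt_record", StringType(), True)]
)

# Cache the parsed result: Spark rejects queries on raw CSV files whose
# referenced columns include only the corrupt-record column
df = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("mode", "PERMISSIVE") \
    .option("columnNameOfCorruptRecord", "_corrupt_record") \
    .schema(schema_with_corrupt) \
    .load("data.csv") \
    .cache()

# Separate valid and invalid records
valid_df = df.filter(col("_corrupt_record").isNull())
corrupt_df = df.filter(col("_corrupt_record").isNotNull())

print(f"Corrupt records: {corrupt_df.count()}")
corrupt_df.select("_corrupt_record").show(truncate=False)
```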
Dynamic Schema Generation
For datasets with many columns, programmatically generate schemas from metadata:
```python
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

# Column definitions from configuration or metadata store
column_config = [
    {"name": "user_id", "type": "integer", "nullable": False},
    {"name": "username", "type": "string", "nullable": False},
    {"name": "score", "type": "double", "nullable": True},
    {"name": "level", "type": "integer", "nullable": True}
]

type_mapping = {
    "integer": IntegerType(),
    "string": StringType(),
    "double": DoubleType()
}

def create_schema(config):
    fields = []
    for column in config:
        field = StructField(
            column["name"],
            type_mapping[column["type"]],
            column["nullable"]
        )
        fields.append(field)
    return StructType(fields)

schema = create_schema(column_config)

df = spark.read \
    .format("csv") \
    .option("header", "true") \
    .schema(schema) \
    .load("users.csv")
```
This approach works well when integrating with data catalogs or when schema definitions are maintained in external configuration files.
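An alternative sketch: DataFrameReader.schema() also accepts a DDL-formatted string such as "user_id INT, username STRING", so the same configuration can be turned into a string with plain Python and passed as spark.read.schema(ddl). The helper config_to_ddl and its DDL_TYPES table are illustrative assumptions. Nullability is left out because Spark does not enforce it on CSV reads anyway:

```python
# Hypothetical mapping from config type names to Spark SQL DDL type names
DDL_TYPES = {
    "integer": "INT",
    "string": "STRING",
    "double": "DOUBLE",
}

def config_to_ddl(config):
    """Build a DDL schema string from a list of column definitions."""
    parts = []
    for column in config:
        type_name = column["type"]
        if type_name not in DDL_TYPES:
            # Fail with a clear message instead of a bare KeyError
            raise ValueError(f"Unsupported type {type_name!r} for column {column['name']!r}")
        # Backticks protect names that contain spaces or reserved words
        parts.append(f"`{column['name']}` {DDL_TYPES[type_name]}")
    return ", ".join(parts)

column_config = [
    {"name": "user_id", "type": "integer", "nullable": False},
    {"name": "username", "type": "string", "nullable": False},
    {"name": "score", "type": "double", "nullable": True},
    {"name": "level", "type": "integer", "nullable": True},
]

ddl = config_to_ddl(column_config)
print(ddl)  # `user_id` INT, `username` STRING, `score` DOUBLE, `level` INT
```

Raising on unknown types early keeps a typo in the configuration from surfacing later as an opaque error inside the Spark job.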
Handling Missing and Extra Columns
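Real CSV files rarely match your schema perfectly. With a user-supplied schema, Spark maps CSV columns to schema fields by position, not by header name. In the default PERMISSIVE mode, a row with more values than the schema has fields loses the extra values, and a row with fewer values gets nulls for the missing fields. Control this behavior explicitly: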
```python
from pyspark.sql.functions import lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("name", StringType(), nullable=False),
    StructField("value", DoubleType(), nullable=True)
])

# enforceSchema=false: check the CSV header names against the schema field
# names (by position) instead of silently applying the schema; a mismatch,
# including a different number of columns, raises an error when the file is read
df = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("enforceSchema", "false") \
    .schema(schema) \
    .load("data.csv")

# Add columns that downstream code expects but the schema does not define,
# filled with typed nulls
expected_columns = ["id", "name", "value", "category"]

df_loaded = spark.read \
    .format("csv") \
    .option("header", "true") \
    .schema(schema) \
    .load("data.csv")

for col_name in expected_columns:
    if col_name not in df_loaded.columns:
        df_loaded = df_loaded.withColumn(col_name, lit(None).cast(StringType()))

df_final = df_loaded.select(expected_columns)
```
Performance Optimization with Schema
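A schema alone does not enable partition pruning or let Spark skip columns in CSV. Pruning comes from a Hive-style directory layout (year=2024/month=06/...) combined with filters on the partition columns, and because CSV is a row-oriented text format, Spark reads every byte of each file it does not prune. What an explicit schema adds is typed columns on the first read, with no inference pass over every file:
```python
from pyspark.sql.functions import col
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, TimestampType
)

schema = StructType([
    StructField("event_id", StringType(), nullable=False),
    StructField("user_id", IntegerType(), nullable=False),
    StructField("event_type", StringType(), nullable=False),
    StructField("timestamp", TimestampType(), nullable=False),
    StructField("value", DoubleType(), nullable=True)
])

# basePath marks where partition discovery starts, so year, month and day
# become columns even though the load path uses wildcards
df = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("basePath", "s3://bucket/events/") \
    .schema(schema) \
    .load("s3://bucket/events/year=2024/month=*/day=*/")

# The month filter prunes directories; the other two are evaluated row by row
filtered_df = df.filter(col("month") == 6) \
    .filter(col("event_type") == "purchase") \
    .filter(col("value") > 100)
```
Filters on partition columns let Spark skip whole directories, which can cut I/O dramatically on large datasets. Filters on data columns such as event_type and value are applied as rows are parsed: they reduce downstream work, not the bytes read from storage.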
Validation and Testing
Always validate your schema against sample data before running production jobs:
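```python
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType

def validate_schema(df, schema):
    """Validate loaded data against the expected schema."""
    errors = []

    # Structural checks: these matter most when df comes from another source
    if len(df.columns) != len(schema.fields):
        errors.append(f"Column count mismatch: expected {len(schema.fields)}, got {len(df.columns)}")

    # Check data types
    for field in schema.fields:
        if field.name in df.columns:
            actual_type = df.schema[field.name].dataType
            if actual_type != field.dataType:
                errors.append(f"Type mismatch for {field.name}: expected {field.dataType}, got {actual_type}")
        else:
            errors.append(f"Missing column: {field.name}")

    # Spark does not enforce nullable=False on file reads, so count nulls
    # in required columns explicitly
    for field in schema.fields:
        if not field.nullable and field.name in df.columns:
            null_count = df.filter(col(field.name).isNull()).count()
            if null_count > 0:
                errors.append(f"{null_count} null value(s) in non-nullable column {field.name}")

    return errors

schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("value", DoubleType(), nullable=False)
])

df = spark.read.format("csv").option("header", "true").schema(schema).load("test.csv")

validation_errors = validate_schema(df, schema)
if validation_errors:
    for error in validation_errors:
        print(f"ERROR: {error}")
    raise ValueError("Schema validation failed")
```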
Custom schemas transform PySpark CSV reading from a guessing game into a controlled, predictable process. Define your schemas explicitly, test them thoroughly, and your data pipelines will be more reliable and performant.