Data Quality Checks with PySpark

Key Insights

  • Data quality checks should fail fast and fail loudly—catching schema mismatches, null violations, and duplicate records early in your pipeline prevents costly downstream errors and silent data corruption.
  • Build reusable validation functions that return structured results rather than throwing exceptions, allowing you to aggregate findings into actionable reports and make informed decisions about data acceptance.
  • Treat data quality thresholds as configuration, not code—what’s acceptable for one dataset (5% nulls) might be catastrophic for another, so parameterize your checks accordingly.

Why Data Quality Matters at Scale

Bad data is expensive. A malformed record in a batch of millions can cascade through your pipeline, corrupt aggregations, and ultimately lead to wrong business decisions. At scale, you can’t eyeball a DataFrame to spot problems—you need automated, systematic checks.

PySpark is the natural choice for data quality validation in big data environments. It handles distributed processing, integrates with common data lake formats, and provides the DataFrame API flexibility needed for complex validation logic.

The most common data quality issues fall into predictable categories: schema drift (columns appearing, disappearing, or changing types), completeness problems (nulls and missing values), uniqueness violations (duplicate records), and validity failures (data that doesn’t conform to business rules). A solid DQ framework addresses all four.

Setting Up Your Data Quality Framework

Before diving into specific checks, establish a consistent structure. Create a dedicated module for DQ functions that return standardized results rather than raising exceptions. This approach lets you run all checks and aggregate findings.

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType
from dataclasses import dataclass
from typing import List, Optional, Tuple
from datetime import datetime

spark = SparkSession.builder \
    .appName("DataQualityChecks") \
    .config("spark.sql.adaptive.enabled", "true") \
    .getOrCreate()

@dataclass
class DQResult:
    check_name: str
    passed: bool
    metric_value: float
    threshold: float
    details: Optional[str] = None
    timestamp: Optional[str] = None
    
    def __post_init__(self):
        # Stamp the result at creation time unless a timestamp was supplied
        if self.timestamp is None:
            self.timestamp = datetime.now().isoformat()

# Sample dataset for examples
data = [
    ("ORD001", "customer_123", 150.00, "2024-01-15", "COMPLETED"),
    ("ORD002", "customer_456", -50.00, "2024-01-15", "COMPLETED"),  # Invalid negative amount
    ("ORD003", None, 200.00, "2024-01-16", "PENDING"),  # Null customer
    ("ORD001", "customer_123", 150.00, "2024-01-15", "COMPLETED"),  # Duplicate
    ("ORD004", "customer_789", 0.00, "invalid-date", "UNKNOWN"),  # Invalid date, unknown status
]

orders_df = spark.createDataFrame(data, ["order_id", "customer_id", "amount", "order_date", "status"])

Schema Validation Checks

Schema validation is your first line of defense. Catch structural problems before wasting compute on data that’s fundamentally broken.

def validate_schema(
    df: DataFrame, 
    expected_schema: StructType,
    allow_extra_columns: bool = True
) -> DQResult:
    """Validate DataFrame schema against expected structure."""
    actual_fields = {f.name: f.dataType for f in df.schema.fields}
    expected_fields = {f.name: f.dataType for f in expected_schema.fields}
    
    missing_columns = set(expected_fields.keys()) - set(actual_fields.keys())
    type_mismatches = []
    
    for col_name, expected_type in expected_fields.items():
        if col_name in actual_fields:
            actual_type = actual_fields[col_name]
            if actual_type != expected_type:
                type_mismatches.append(
                    f"{col_name}: expected {expected_type}, got {actual_type}"
                )
    
    extra_columns = set(actual_fields.keys()) - set(expected_fields.keys())
    
    issues = []
    if missing_columns:
        issues.append(f"Missing columns: {missing_columns}")
    if type_mismatches:
        issues.append(f"Type mismatches: {type_mismatches}")
    if extra_columns and not allow_extra_columns:
        issues.append(f"Unexpected columns: {extra_columns}")
    
    passed = len(missing_columns) == 0 and len(type_mismatches) == 0
    
    return DQResult(
        check_name="schema_validation",
        passed=passed,
        metric_value=len(issues),
        threshold=0,
        details="; ".join(issues) if issues else "Schema valid"
    )

# Define expected schema
expected_schema = StructType([
    StructField("order_id", StringType(), False),
    StructField("customer_id", StringType(), False),
    StructField("amount", DoubleType(), False),
    StructField("order_date", StringType(), False),
    StructField("status", StringType(), False),
])

schema_result = validate_schema(orders_df, expected_schema)
print(f"Schema check passed: {schema_result.passed}")

For production pipelines dealing with schema evolution, compare against a schema registry or versioned schema definitions. Store schema history to track drift over time.
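
One lightweight approach, sketched below, is to persist each schema version as JSON and compare the latest stored version against incoming data. The storage path and helper names here are illustrative, not prescriptive:

import json

def snapshot_schema(df: DataFrame, path: str) -> None:
    """Append the current schema as a timestamped JSON row for drift tracking."""
    snapshot = spark.createDataFrame(
        [(datetime.now().isoformat(), df.schema.json())],
        ["captured_at", "schema_json"]
    )
    snapshot.write.mode("append").json(path)

def load_latest_schema(path: str) -> StructType:
    """Rebuild the most recently captured schema from the snapshot store."""
    latest = spark.read.json(path).orderBy(F.col("captured_at").desc()).first()
    return StructType.fromJson(json.loads(latest["schema_json"]))

# snapshot_schema(orders_df, "/dq/schemas/orders")
# expected = load_latest_schema("/dq/schemas/orders")
# validate_schema(orders_df, expected)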

Completeness Checks

Null values are the silent killers of data pipelines. Quantify them systematically and enforce thresholds.

def check_completeness(
    df: DataFrame,
    columns: Optional[List[str]] = None,
    null_threshold: float = 0.0
) -> List[DQResult]:
    """Check null/empty percentages for specified columns."""
    if columns is None:
        columns = df.columns
    
    total_count = df.count()
    results = []
    
    # Build one aggregation expression per column, matched to its type:
    # isnan() only applies to float/double columns, and the empty-string
    # check only makes sense for strings
    dtypes = dict(df.dtypes)
    null_exprs = []
    for c in columns:
        condition = F.col(c).isNull()
        if dtypes[c] in ("float", "double"):
            condition = condition | F.isnan(F.col(c))
        elif dtypes[c] == "string":
            condition = condition | (F.col(c) == "")
        null_exprs.append(
            F.sum(F.when(condition, 1).otherwise(0)).alias(f"{c}_nulls")
        )
    
    null_counts = df.agg(*null_exprs).collect()[0]
    
    for col in columns:
        null_count = null_counts[f"{col}_nulls"] or 0
        null_pct = (null_count / total_count) * 100 if total_count > 0 else 0
        
        results.append(DQResult(
            check_name=f"completeness_{col}",
            passed=null_pct <= null_threshold,
            metric_value=round(null_pct, 2),
            threshold=null_threshold,
            details=f"{null_count}/{total_count} records null/empty"
        ))
    
    return results

# Run completeness checks with 1% threshold
completeness_results = check_completeness(
    orders_df, 
    columns=["order_id", "customer_id", "amount"],
    null_threshold=1.0
)

for result in completeness_results:
    print(f"{result.check_name}: {result.metric_value}% null (threshold: {result.threshold}%)")

Uniqueness and Duplicate Detection

Duplicate detection at scale requires careful consideration of what constitutes a duplicate. Sometimes it’s a single primary key; sometimes it’s a composite of multiple columns.

def check_uniqueness(
    df: DataFrame,
    key_columns: List[str],
    duplicate_threshold: float = 0.0
) -> DQResult:
    """Check for duplicate records based on key columns."""
    total_count = df.count()
    
    duplicate_df = df.groupBy(key_columns) \
        .count() \
        .filter(F.col("count") > 1)
    
    duplicate_groups = duplicate_df.count()
    duplicate_records = duplicate_df.agg(F.sum("count")).collect()[0][0] or 0
    
    duplicate_pct = (duplicate_records / total_count) * 100 if total_count > 0 else 0
    
    return DQResult(
        check_name=f"uniqueness_{'_'.join(key_columns)}",
        passed=duplicate_pct <= duplicate_threshold,
        metric_value=round(duplicate_pct, 2),
        threshold=duplicate_threshold,
        details=f"{duplicate_groups} duplicate groups, {duplicate_records} total duplicate records"
    )

def deduplicate_with_strategy(
    df: DataFrame,
    key_columns: List[str],
    order_column: str,
    keep: str = "last"
) -> DataFrame:
    """Remove duplicates keeping first or last record by order column."""
    from pyspark.sql.window import Window
    
    order_expr = F.col(order_column).desc() if keep == "last" else F.col(order_column).asc()
    
    window = Window.partitionBy(key_columns).orderBy(order_expr)
    
    return df.withColumn("_row_num", F.row_number().over(window)) \
        .filter(F.col("_row_num") == 1) \
        .drop("_row_num")

# Check uniqueness on order_id
uniqueness_result = check_uniqueness(orders_df, ["order_id"])
print(f"Uniqueness check: {uniqueness_result.details}")

Validity and Business Rule Checks

This is where domain knowledge meets code. Encode your business rules as explicit, testable validations.

def check_validity_rules(df: DataFrame) -> Tuple[DataFrame, List[DQResult]]:
    """Apply business rule validations and flag invalid records."""
    
    # Coalesce each rule to False so that nulls count as invalid instead of
    # propagating as null and slipping past boolean filters downstream
    validated_df = df.withColumn(
        "valid_amount", 
        F.coalesce(F.col("amount") > 0, F.lit(False))
    ).withColumn(
        "valid_status",
        F.coalesce(
            F.col("status").isin(["PENDING", "COMPLETED", "CANCELLED", "REFUNDED"]),
            F.lit(False)
        )
    ).withColumn(
        "valid_date",
        F.to_date(F.col("order_date"), "yyyy-MM-dd").isNotNull()
    ).withColumn(
        "valid_customer_id",
        F.coalesce(F.col("customer_id").rlike("^customer_[0-9]+$"), F.lit(False))
    ).withColumn(
        "is_valid_record",
        F.col("valid_amount") & F.col("valid_status") & 
        F.col("valid_date") & F.col("valid_customer_id")
    )
    
    total_count = validated_df.count()
    
    # Aggregate validation results
    validation_stats = validated_df.agg(
        F.sum(F.when(~F.col("valid_amount"), 1).otherwise(0)).alias("invalid_amount"),
        F.sum(F.when(~F.col("valid_status"), 1).otherwise(0)).alias("invalid_status"),
        F.sum(F.when(~F.col("valid_date"), 1).otherwise(0)).alias("invalid_date"),
        F.sum(F.when(~F.col("valid_customer_id"), 1).otherwise(0)).alias("invalid_customer"),
    ).collect()[0]
    
    results = [
        DQResult("validity_amount_positive", validation_stats["invalid_amount"] == 0,
                 validation_stats["invalid_amount"], 0, "Amount must be positive"),
        DQResult("validity_status_enum", validation_stats["invalid_status"] == 0,
                 validation_stats["invalid_status"], 0, "Status must be valid enum"),
        DQResult("validity_date_format", validation_stats["invalid_date"] == 0,
                 validation_stats["invalid_date"], 0, "Date must be yyyy-MM-dd"),
        DQResult("validity_customer_pattern", validation_stats["invalid_customer"] == 0,
                 validation_stats["invalid_customer"], 0, "Customer ID must match pattern"),
    ]
    
    return validated_df, results

validated_df, validity_results = check_validity_rules(orders_df)
invalid_records = validated_df.filter(~F.col("is_valid_record"))
print(f"Found {invalid_records.count()} invalid records")

Building a Reusable DQ Report

Aggregate all checks into a single report that can drive alerting and pipeline decisions.

def run_full_dq_suite(
    df: DataFrame,
    dataset_name: str,
    expected_schema: StructType,
    key_columns: List[str],
    completeness_threshold: float = 1.0
) -> DataFrame:
    """Run complete DQ suite and return results as DataFrame."""
    all_results = []
    
    # Schema check
    all_results.append(validate_schema(df, expected_schema))
    
    # Completeness checks
    all_results.extend(check_completeness(df, null_threshold=completeness_threshold))
    
    # Uniqueness check
    all_results.append(check_uniqueness(df, key_columns))
    
    # Validity checks
    _, validity_results = check_validity_rules(df)
    all_results.extend(validity_results)
    
    # Convert to DataFrame, casting metrics to float so Spark infers a single
    # column type (some checks report integer counts, others percentages)
    results_data = [
        (dataset_name, r.check_name, r.passed, float(r.metric_value), 
         float(r.threshold), r.details, r.timestamp)
        for r in all_results
    ]
    
    results_df = spark.createDataFrame(
        results_data,
        ["dataset", "check_name", "passed", "metric_value", 
         "threshold", "details", "timestamp"]
    )
    
    return results_df

# Generate full report
dq_report = run_full_dq_suite(
    orders_df,
    dataset_name="orders",
    expected_schema=expected_schema,
    key_columns=["order_id"],
    completeness_threshold=5.0
)

dq_report.show(truncate=False)

# Write to Delta for historical tracking
# dq_report.write.format("delta").mode("append").save("/dq/reports/orders")

# Check if pipeline should proceed
failed_checks = dq_report.filter(~F.col("passed")).count()
if failed_checks > 0:
    print(f"WARNING: {failed_checks} DQ checks failed")

The key to effective data quality is treating it as a first-class concern in your pipeline architecture. Run checks early, fail explicitly, and maintain historical records of data quality metrics. Your future self—debugging a production issue at 2 AM—will thank you.
