Spark Scala - Unit Testing Spark Applications
Key Insights
- Spark’s distributed nature makes testing challenging, but isolating transformation logic into pure functions that accept and return DataFrames enables comprehensive unit testing without complex infrastructure.
- A shared SparkSession trait across test suites dramatically reduces test execution time—creating a SparkSession is expensive, so reuse it wherever possible.
- Dependency injection through Scala traits allows you to swap production data sources with test fixtures, making your Spark jobs both testable and maintainable.
Introduction to Spark Testing Challenges
Testing Spark applications feels different from testing typical Scala code. You’re dealing with a distributed computing framework that expects cluster resources, manages its own memory, and requires a SparkSession for virtually every operation. Many teams skip comprehensive testing entirely, relying on manual validation against sample datasets. This approach fails spectacularly when edge cases appear in production.
The core challenges are real: SparkSession creation is slow (often 5-10 seconds), DataFrames aren’t easily comparable with simple equality checks, and your transformation logic is often tightly coupled to data source specifics. But none of these challenges are insurmountable.
Comprehensive unit testing for Spark applications catches schema mismatches before deployment, validates business logic against known inputs, and gives you confidence when refactoring complex pipelines. The investment pays dividends when you’re maintaining data pipelines that process millions of records daily.
Setting Up the Test Environment
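Start with the right dependencies. You need ScalaTest for the testing framework and Spark SQL itself. Spark's own test jars are optional: you only need them if you want to reuse Spark's internal test helpers. Add these to your build.sbt:
libraryDependencies ++= Seq(
  "org.scalatest" %% "scalatest" % "3.2.17" % Test,
  // "provided" keeps Spark out of the assembly jar but still puts it on the test classpath
  "org.apache.spark" %% "spark-sql" % "3.5.0" % "provided",
  // Optional: Spark's test jars (the sql/catalyst test helpers build on classes in the core test jar)
  "org.apache.spark" %% "spark-core" % "3.5.0" % Test classifier "tests",
  "org.apache.spark" %% "spark-sql" % "3.5.0" % Test classifier "tests",
  "org.apache.spark" %% "spark-catalyst" % "3.5.0" % Test classifier "tests"
)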
The critical piece is a shared SparkSession that persists across tests. Creating a new session for each test method will make your test suite unbearably slow. Here’s a trait that handles setup and teardown properly:
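import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, Suite}

trait SparkTestBase extends BeforeAndAfterAll { self: Suite =>

  // Built on first use and reused by every test in the suite
  @transient lazy val spark: SparkSession = {
    SparkSession.builder()
      .appName("unit-tests")
      .master("local[2]")                           // two worker threads
      .config("spark.ui.enabled", "false")          // don't start the web UI
      .config("spark.sql.shuffle.partitions", "4")  // default is 200, far too many for tiny fixtures
      .getOrCreate()
  }

  override def afterAll(): Unit = {
    try {
      spark.stop()
    } finally {
      super.afterAll()
    }
  }
}
Key configuration choices here: local[2] runs tasks on two threads. That exercises your code with more than one task in flight, which can surface bugs that local[1] hides, such as logic that silently assumes a single partition. Disabling the UI avoids starting a web server for every suite. Four shuffle partitions, instead of the default 200, keep aggregation tests from scheduling hundreds of near-empty tasks. The lazy val defers creation until a test first touches spark. @transient is a defensive habit: if the suite instance is ever serialized as part of a closure, the session field is skipped. Driver memory is not set in the builder. In local mode the driver is the test JVM itself, which is already running by the time this code executes, so spark.driver.memory has no effect here. Size the heap through sbt instead.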
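The trait above still builds and stops one session per suite. If startup time dominates, one alternative is to hold the session in a JVM-wide object and never stop it per suite. The sketch below is one way to do that. SharedSpark and SharedSparkSession are hypothetical names, not Spark or ScalaTest APIs, and the sketch assumes suites run one at a time in the same JVM:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical JVM-wide holder: the first suite that touches it creates the session,
// and every later suite in the same JVM reuses it.
object SharedSpark {
  lazy val session: SparkSession =
    SparkSession.builder()
      .appName("unit-tests")
      .master("local[2]")
      .config("spark.ui.enabled", "false")
      .config("spark.sql.shuffle.partitions", "4")
      .getOrCreate()
  // No explicit stop(): SparkContext registers its own JVM shutdown hook that stops it.
}

// Hypothetical mixin used instead of SparkTestBase; there is deliberately no afterAll that stops the session.
trait SharedSparkSession {
  // A lazy val (not a def) so that `import spark.implicits._` sees a stable identifier.
  @transient lazy val spark: SparkSession = SharedSpark.session
}
```

The trade-off is isolation. Temp views, cached tables, and runtime SQL settings now leak from one suite to the next, so each suite should clean up whatever it registers.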
Testing DataFrame Transformations
The secret to testable Spark code is isolation. Extract your transformation logic into pure functions that take DataFrames as input and return DataFrames as output. Avoid functions that read from files or databases directly.
import org.apache.spark.sql.{DataFrame, functions => F}

object SalesTransformations {

  def filterActiveSales(df: DataFrame): DataFrame = {
    df.filter(F.col("status") === "active")
      .filter(F.col("amount") > 0)
  }

  def aggregateSalesByRegion(df: DataFrame): DataFrame = {
    df.groupBy("region")
      .agg(
        F.sum("amount").alias("total_sales"),
        F.count("*").alias("transaction_count"),
        F.avg("amount").alias("avg_sale")
      )
  }
}
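A side benefit of the DataFrame => DataFrame shape is that such functions chain with Dataset.transform while each one stays testable alone. The sketch below uses two hypothetical steps, keepPositive and addDoubledAmount, to show the composition:

```scala
import org.apache.spark.sql.{DataFrame, functions => F}

// Hypothetical steps with the same DataFrame => DataFrame shape as SalesTransformations.
object CompositionExample {
  def keepPositive(df: DataFrame): DataFrame =
    df.filter(F.col("amount") > 0)

  def addDoubledAmount(df: DataFrame): DataFrame =
    df.withColumn("amount_doubled", F.col("amount") * 2)

  // Dataset.transform applies the steps left to right without nesting calls.
  def pipeline(df: DataFrame): DataFrame =
    df.transform(keepPositive).transform(addDoubledAmount)
}
```

The functions above compose the same way: input.transform(SalesTransformations.filterActiveSales).transform(SalesTransformations.aggregateSalesByRegion).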
Now test these transformations with controlled inputs:
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class SalesTransformationsTest extends AnyFunSuite
    with SparkTestBase with Matchers {

  import spark.implicits._

  test("filterActiveSales removes inactive and zero-amount records") {
    val input = Seq(
      ("sale1", "active", 100.0),
      ("sale2", "inactive", 50.0),
      ("sale3", "active", 0.0),
      ("sale4", "active", 200.0)
    ).toDF("id", "status", "amount")

    val result = SalesTransformations.filterActiveSales(input)

    result.count() shouldBe 2
    result.select("id").as[String].collect() should contain theSameElementsAs
      Seq("sale1", "sale4")
  }

  test("aggregateSalesByRegion calculates correct metrics") {
    val input = Seq(
      ("North", 100.0),
      ("North", 200.0),
      ("South", 150.0)
    ).toDF("region", "amount")

    // region -> (total_sales, transaction_count, avg_sale), read by column name
    val result = SalesTransformations.aggregateSalesByRegion(input)
      .collect()
      .map(r => r.getAs[String]("region") ->
        (r.getAs[Double]("total_sales"), r.getAs[Long]("transaction_count"), r.getAs[Double]("avg_sale")))
      .toMap

    // Double parentheses pass the tuple as a single argument
    result("North") shouldBe ((300.0, 2L, 150.0))
    result("South") shouldBe ((150.0, 1L, 150.0))
  }
}
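The tests above compare collected values one by one because, as noted earlier, DataFrames have no useful built-in equality. A small reusable assertion helps once suites grow. The sketch below assumes a hypothetical DataFrameAssertions helper, not part of Spark or ScalaTest. It collects both sides to the driver, so it only suits small fixtures. It compares column names and types but ignores nullability, ignores row order, and compares Double values exactly:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical helper: order-insensitive DataFrame comparison for small test fixtures.
object DataFrameAssertions {

  // Names and types only; toDF() inputs and results often differ in nullability.
  private def shape(schema: StructType): Seq[(String, String)] =
    schema.fields.toSeq.map(f => (f.name, f.dataType.simpleString))

  def assertSameRows(actual: DataFrame, expected: DataFrame): Unit = {
    val actualShape = shape(actual.schema)
    val expectedShape = shape(expected.schema)
    if (actualShape != expectedShape)
      throw new AssertionError(s"Schema mismatch:\n  actual:   $actualShape\n  expected: $expectedShape")

    // Sort by each row's string form so partition order does not matter.
    val actualRows = actual.collect().toVector.sortBy(_.toString)
    val expectedRows = expected.collect().toVector.sortBy(_.toString)
    if (actualRows != expectedRows)
      throw new AssertionError(s"Row mismatch:\n  actual:   $actualRows\n  expected: $expectedRows")
  }
}
```

In a suite, build the expected result with toDF and call DataFrameAssertions.assertSameRows(result, expected). The failure message shows both sides.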
Testing with Sample Data
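For complex schemas, explicit type definitions prevent subtle bugs. The toDF() method derives the schema from Scala types. Primitive fields such as Double become non-nullable columns, and scala.math.BigDecimal becomes DecimalType(38, 18). A fixture built that way may therefore not match production data with specific nullability or decimal precision requirements.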
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._

trait SalesTestFixtures { self: SparkTestBase =>

  val salesSchema: StructType = StructType(Seq(
    StructField("transaction_id", StringType, nullable = false),
    StructField("region", StringType, nullable = false),
    StructField("amount", DecimalType(10, 2), nullable = false),
    StructField("status", StringType, nullable = false),
    StructField("timestamp", TimestampType, nullable = true)
  ))

  def createSalesDataFrame(data: Seq[Row]): DataFrame = {
    spark.createDataFrame(
      spark.sparkContext.parallelize(data),
      salesSchema
    )
  }

  // java.math.BigDecimal is the external type Spark uses for DecimalType columns
  val sampleSalesData: Seq[Row] = Seq(
    Row("TXN001", "North", BigDecimal("150.00").bigDecimal, "active", null),
    Row("TXN002", "South", BigDecimal("200.50").bigDecimal, "active", null),
    Row("TXN003", "North", BigDecimal("75.25").bigDecimal, "inactive", null)
  )
}
Use these fixtures in your tests for consistency:
import org.apache.spark.sql.types.DecimalType
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class SchemaAwareTest extends AnyFunSuite
    with SparkTestBase with SalesTestFixtures with Matchers {

  test("transformation preserves schema precision") {
    val input = createSalesDataFrame(sampleSalesData)
    val result = SalesTransformations.filterActiveSales(input)
    result.schema("amount").dataType shouldBe DecimalType(10, 2)
  }
}
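The decimal inference problem mentioned earlier is easy to see directly. The sketch below (DecimalInferenceDemo is a hypothetical name) builds a column from Scala BigDecimal values with toDF(), which yields DecimalType(38, 18). It then shows one way to pin the precision with an explicit cast, as an alternative to the Row-plus-schema fixture above:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession, functions => F}
import org.apache.spark.sql.types.DecimalType

object DecimalInferenceDemo {
  // toDF() on scala.math.BigDecimal values: Spark falls back to its default DecimalType(38, 18).
  def inferred(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq(BigDecimal("150.00"), BigDecimal("200.50")).toDF("amount")
  }

  // Casting afterwards pins the precision and scale the production schema expects.
  def pinned(spark: SparkSession): DataFrame =
    inferred(spark).withColumn("amount", F.col("amount").cast(DecimalType(10, 2)))
}
```

The cast fixes precision but not nullability. When both matter, the explicit StructType fixture remains the more faithful option.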
Mocking External Dependencies
Production Spark jobs read from databases, S3 buckets, and Kafka topics. Testing against these systems is slow and brittle. Instead, abstract your data access behind traits:
import org.apache.spark.sql.{DataFrame, SparkSession}

trait DataReader {
  def readSales(spark: SparkSession): DataFrame
  def readRegions(spark: SparkSession): DataFrame
}

class ProductionDataReader(basePath: String) extends DataReader {
  override def readSales(spark: SparkSession): DataFrame = {
    spark.read.parquet(s"$basePath/sales")
  }

  override def readRegions(spark: SparkSession): DataFrame = {
    spark.read.jdbc(
      "jdbc:postgresql://prod-db:5432/warehouse",
      "regions",
      new java.util.Properties()  // credentials and driver settings omitted
    )
  }
}

// Test double: hands back prebuilt DataFrames instead of touching storage
class TestDataReader(salesData: DataFrame, regionsData: DataFrame)
    extends DataReader {
  override def readSales(spark: SparkSession): DataFrame = salesData
  override def readRegions(spark: SparkSession): DataFrame = regionsData
}
Your job class accepts a reader through constructor injection:
class SalesPipeline(reader: DataReader) {
  def run(spark: SparkSession): DataFrame = {
    val sales = reader.readSales(spark)
    val regions = reader.readRegions(spark)

    val activeSales = SalesTransformations.filterActiveSales(sales)
    activeSales.join(regions, Seq("region"), "left")
  }
}
Testing becomes straightforward:
class SalesPipelineTest extends AnyFunSuite with SparkTestBase with Matchers {
  import spark.implicits._

  test("pipeline joins sales with region metadata") {
    val testSales = Seq(("North", 100.0, "active")).toDF("region", "amount", "status")
    val testRegions = Seq(("North", "US")).toDF("region", "country")

    val reader = new TestDataReader(testSales, testRegions)
    val pipeline = new SalesPipeline(reader)
    val result = pipeline.run(spark)

    result.columns should contain("country")
    result.select("country").as[String].collect() shouldBe Array("US")
  }
}
Integration Testing Strategies
Some tests need to verify actual file I/O. Use temporary directories that clean up automatically:
import java.nio.file.{Files, Path}
import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class ParquetIntegrationTest extends AnyFunSuite
    with SparkTestBase with BeforeAndAfterEach with Matchers {

  private var tempDir: Path = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    tempDir = Files.createTempDirectory("spark-test")
  }

  override def afterEach(): Unit = {
    try {
      // commons-io is already on the classpath as a Spark dependency
      FileUtils.deleteDirectory(tempDir.toFile)
    } finally {
      super.afterEach()
    }
  }

  test("pipeline writes partitioned parquet correctly") {
    import spark.implicits._

    val input = Seq(
      ("2024-01", "North", 100.0),
      ("2024-01", "South", 150.0),
      ("2024-02", "North", 200.0)
    ).toDF("month", "region", "amount")

    val outputPath = tempDir.resolve("output").toString
    input.write
      .partitionBy("month")
      .parquet(outputPath)

    val reloaded = spark.read.parquet(outputPath)
    reloaded.count() shouldBe 3

    // Verify partition structure: one directory per distinct month
    val partitions = new java.io.File(outputPath).listFiles()
      .filter(_.isDirectory)
      .map(_.getName)
    partitions should contain allOf ("month=2024-01", "month=2024-02")
  }
}
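The suite above keeps the temporary directory in a mutable field managed by beforeEach and afterEach. A loan-pattern helper scopes the directory to one test body instead and cleans up even when the body throws. The sketch below uses a hypothetical TempDirs object and only JDK APIs:

```scala
import java.nio.file.{Files, Path}
import java.util.Comparator

// Hypothetical loan-pattern helper: create a temp directory, lend it to the body, always delete it.
object TempDirs {
  def withTempDir[A](prefix: String)(body: Path => A): A = {
    val dir = Files.createTempDirectory(prefix)
    try body(dir)
    finally deleteRecursively(dir)
  }

  // Deletes children before parents by walking the tree in reverse path order.
  def deleteRecursively(root: Path): Unit = {
    if (Files.exists(root)) {
      val paths = Files.walk(root)
      try paths.sorted(Comparator.reverseOrder[Path]()).forEach(p => Files.delete(p))
      finally paths.close()
    }
  }
}
```

Inside a test, wrap the body as TempDirs.withTempDir("spark-test") { dir => ... } and derive the output path from dir. No beforeEach or afterEach overrides are needed.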
Best Practices and CI/CD Considerations
Keep your test suite fast by minimizing SparkSession creation. One session per test class is ideal—one per test method is a performance killer. Structure your tests so related transformations share a suite.
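Configure your CI runner with enough memory. In local mode the driver and its task threads all run inside the test JVM, so that JVM's heap is the memory your tests get. About 2 GB is a reasonable starting point for small fixtures. SBT_OPTS sizes the sbt process itself, which only doubles as the test JVM when tests are not forked. In your CI configuration:
# GitHub Actions example
- name: Run tests
  run: sbt test
  env:
    SBT_OPTS: "-Xmx4G -XX:+UseG1GC"
    SPARK_LOCAL_DIRS: /tmp/spark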
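Parallelize at the JVM level, not inside one JVM. Only one SparkContext can run per JVM, and getOrCreate() hands suites running concurrently in the same JVM the same session. One suite's afterAll can then stop a session that another suite is still using. A safe default is to fork a single test JVM, size its heap there, and run suites sequentially inside it:
// build.sbt
Test / testOptions += Tests.Argument("-oDF")  // ScalaTest reporter: durations (D) and full stack traces (F)
Test / fork := true                           // run tests in a separate JVM
Test / javaOptions ++= Seq("-Xmx2G")          // heap of the forked test JVM; SBT_OPTS does not apply here
Test / parallelExecution := false             // suites share one SparkContext, so run them one at a time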
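If the suite count grows large enough that sequential runs hurt, sbt can run each suite in its own forked JVM, so every suite gets a private SparkContext. The fragment below is a sketch based on sbt 1.x's testGrouping and Tests.SubProcess settings. The -Xmx2G value and the limit of four concurrent JVMs are assumptions. Exact parallelism behavior can vary between sbt versions, so check it against the sbt testing documentation:

```scala
// build.sbt (sbt 1.x): one forked JVM per test suite, so suites never share a SparkContext.
Test / fork := true

Test / testGrouping := (Test / definedTests).value.map { suite =>
  new Tests.Group(
    name = suite.name,
    tests = Seq(suite),
    runPolicy = Tests.SubProcess(ForkOptions().withRunJVMOptions(Vector("-Xmx2G")))
  )
}

// Allow up to four forked test JVMs at a time (assumed value; tune to the CI machine).
Global / concurrentRestrictions += Tags.limit(Tags.ForkedTestGroup, 4)
```

Each JVM pays the SparkSession startup cost again, so this only pays off when suites are long enough to amortize it.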
Finally, separate fast unit tests from slower integration tests using tags or separate source directories. Run unit tests on every commit; run integration tests on merge to main. This keeps developer feedback loops tight while maintaining comprehensive coverage.
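With ScalaTest, the tag-based split can look like the sketch below. The tag name com.example.tags.SlowIntegration and the example suite are placeholders. Any fully qualified string works as long as CI filters on the same one:

```scala
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

// Placeholder tag; ScalaTest filters on this string, not on the object name.
object SlowIntegration extends Tag("com.example.tags.SlowIntegration")

class TaggingExampleTest extends AnyFunSuite {
  test("fast transformation check") {
    assert(Seq(1, 2, 3).sum == 6)
  }

  // Tagged: excluded from the per-commit run, included in the merge run.
  test("slow parquet round trip", SlowIntegration) {
    assert(Seq("month=2024-01").nonEmpty)  // a real suite would touch the filesystem here
  }
}
```

ScalaTest's -l argument excludes tagged tests and -n runs only them. CI can therefore run sbt "testOnly -- -l com.example.tags.SlowIntegration" on every commit and sbt "testOnly -- -n com.example.tags.SlowIntegration" on merge to main.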
Testing Spark applications requires upfront investment in architecture and test infrastructure. The patterns shown here—isolated transformations, dependency injection, shared sessions—create a foundation for reliable, maintainable data pipelines. Start with transformation unit tests, add integration tests for critical paths, and expand coverage as your pipeline complexity grows.