Scala - collect with Partial Functions

Key Insights

  • collect combines filter and map in a single operation using partial functions, eliminating intermediate collections and improving code readability
  • Partial functions defined with case statements allow pattern matching directly in collection transformations, making code more declarative and type-safe
  • Compared with filter/map chains, collect avoids building an intermediate collection and handles Option values and pattern-matching scenarios more cleanly

Understanding Partial Functions in Scala

Partial functions in Scala are functions defined only for a subset of possible input values. Unlike total functions that handle all inputs, partial functions explicitly define their domain using the isDefinedAt method. The PartialFunction[A, B] trait extends Function1[A, B] with this additional capability.

val divide: PartialFunction[(Int, Int), Int] = {
  case (numerator, denominator) if denominator != 0 => 
    numerator / denominator
}

divide.isDefinedAt((10, 2))  // true
divide.isDefinedAt((10, 0))  // false
divide((10, 2))              // 5
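Applying a partial function outside its domain throws a scala.MatchError. The sketch below repeats the divide definition so it runs on its own, and shows two standard-library ways to call it safely: lift, which turns it into a total function returning Option, and applyOrElse, which supplies a fallback value. The -1 fallback is an arbitrary choice for illustration.

```scala
import scala.util.Try

// Same partial function as above: undefined when the denominator is 0
val divide: PartialFunction[(Int, Int), Int] = {
  case (numerator, denominator) if denominator != 0 =>
    numerator / denominator
}

// lift turns the partial function into a total one returning Option
val safeDivide: ((Int, Int)) => Option[Int] = divide.lift
val ok = safeDivide((10, 2))         // Some(5)
val undefined = safeDivide((10, 0))  // None

// applyOrElse uses the fallback only where divide is not defined
val fallback = divide.applyOrElse((10, 0), (_: (Int, Int)) => -1)  // -1

// Applying divide directly outside its domain throws scala.MatchError
val failed = Try(divide((10, 0))).isFailure  // true
```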

The collect method leverages partial functions to transform collections by applying the function only to elements where it’s defined, effectively filtering and mapping in one pass.

Basic collect Usage

A simplified form of the signature is def collect[B](pf: PartialFunction[A, B]): Collection[B], where Collection stands in for the concrete result type (List, Vector, and so on). It applies the partial function to every element for which it is defined and keeps those results, dropping all other elements.

val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

val evenSquares = numbers.collect {
  case n if n % 2 == 0 => n * n
}
// Result: List(4, 16, 36, 64, 100)

Compare this to the equivalent filter and map chain:

val evenSquares = numbers
  .filter(n => n % 2 == 0)
  .map(n => n * n)

The collect version is more concise and builds the result directly, whereas the filter/map chain first materializes an intermediate filtered list and then maps over it.
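Conceptually, collect keeps pf(x) for each element x where the partial function is defined. The helper below, collectSketch, is hypothetical and not part of the standard library; it is a simplified model of that behavior for List, checked against the real collect.

```scala
// Hypothetical helper, not a standard-library method:
// a simplified model of what List#collect does.
def collectSketch[A, B](xs: List[A])(pf: PartialFunction[A, B]): List[B] = {
  val builder = List.newBuilder[B]
  xs.foreach { x =>
    // lift returns None where pf is undefined, so those elements are skipped
    pf.lift(x).foreach(builder += _)
  }
  builder.result()
}

val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
val viaSketch = collectSketch(numbers) { case n if n % 2 == 0 => n * n }
val viaCollect = numbers.collect { case n if n % 2 == 0 => n * n }
// Both: List(4, 16, 36, 64, 100)
```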

Pattern Matching with collect

The real power of collect emerges when working with algebraic data types and pattern matching. Consider a domain model with different event types:

sealed trait Event
case class UserLogin(userId: String, timestamp: Long) extends Event
case class UserLogout(userId: String, timestamp: Long) extends Event
case class PageView(userId: String, page: String, timestamp: Long) extends Event
case class Purchase(userId: String, amount: Double, timestamp: Long) extends Event

val events: List[Event] = List(
  UserLogin("user1", 1000),
  PageView("user1", "/home", 1001),
  Purchase("user1", 99.99, 1002),
  UserLogout("user1", 1003),
  PageView("user2", "/products", 1004),
  Purchase("user2", 149.99, 1005)
)

// Extract all purchase amounts
val purchaseAmounts = events.collect {
  case Purchase(_, amount, _) => amount
}
// Result: List(99.99, 149.99)

// Extract user IDs from login events only
val loggedInUsers = events.collect {
  case UserLogin(userId, _) => userId
}
// Result: List("user1")

This approach is significantly cleaner than using filter with type checking and casting:

// Verbose alternative
val purchaseAmounts = events
  .filter(_.isInstanceOf[Purchase])
  .map(_.asInstanceOf[Purchase].amount)
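A typed pattern inside collect removes the cast entirely: in case p: Purchase, the bound p already has type Purchase. The sketch below uses a trimmed-down copy of the event model so it runs on its own.

```scala
sealed trait Event
case class UserLogin(userId: String, timestamp: Long) extends Event
case class Purchase(userId: String, amount: Double, timestamp: Long) extends Event

val events: List[Event] = List(
  UserLogin("user1", 1000),
  Purchase("user1", 99.99, 1002),
  Purchase("user2", 149.99, 1005)
)

// Typed pattern: p is statically a Purchase inside the case body
val purchases: List[Purchase] = events.collect { case p: Purchase => p }

// Same result as the filter/asInstanceOf chain, without an explicit cast
val purchaseAmounts: List[Double] = events.collect { case p: Purchase => p.amount }
```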

Handling Options with collect

When working with collections of Option values, collect provides an elegant solution to extract and transform defined values:

val maybeNumbers: List[Option[Int]] = List(Some(1), None, Some(3), None, Some(5))

val doubled = maybeNumbers.collect {
  case Some(n) => n * 2
}
// Result: List(2, 6, 10)

This gives the same result as flattening and then mapping, but in a single pass without the intermediate list:

val doubled = maybeNumbers.flatten.map(_ * 2)

You can also transform and filter simultaneously:

val evenDoubled = maybeNumbers.collect {
  case Some(n) if n % 2 == 0 => n * 2
}
// Result: List(), since 1, 3 and 5 are all odd
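The sketch below uses different sample data (it includes even values, so the guard has something to keep) to check that the collect and flatten/map versions agree. It also shows that Option itself defines collect, which returns Some only when the value matches the pattern and guard.

```scala
// Sample data chosen for illustration; it differs from the list above
val maybeNumbers: List[Option[Int]] = List(Some(1), None, Some(4), None, Some(6))

// Unwrap Some values and transform them in one pass
val viaCollect = maybeNumbers.collect { case Some(n) => n * 2 }

// Same result, with an intermediate List[Int] from flatten
val viaFlatten = maybeNumbers.flatten.map(_ * 2)

// The guard keeps only even values
val evenDoubled = maybeNumbers.collect { case Some(n) if n % 2 == 0 => n * 2 }

// Option.collect: Some if the partial function applies, otherwise None
val kept: Option[Int] = Some(4).collect { case n if n > 3 => n * 10 }
val dropped: Option[Int] = Some(2).collect { case n if n > 3 => n * 10 }
```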

Complex Transformations

Combine multiple conditions and transformations in a single collect operation:

case class Transaction(id: String, amount: Double, status: String, category: String)

val transactions = List(
  Transaction("t1", 100.0, "completed", "electronics"),
  Transaction("t2", 50.0, "pending", "books"),
  Transaction("t3", 200.0, "completed", "electronics"),
  Transaction("t4", 75.0, "completed", "books"),
  Transaction("t5", 300.0, "failed", "electronics")
)

// Extract amounts from completed electronics purchases over $100
val significantElectronics = transactions.collect {
  case Transaction(_, amount, "completed", "electronics") if amount > 100 => amount
}
// Result: List(200.0)

// Transform to different type
case class Summary(id: String, total: Double)

val completedSummaries = transactions.collect {
  case Transaction(id, amount, "completed", _) => 
    Summary(id, amount * 1.1) // Add 10% tax
}
// Result: approximately List(Summary("t1", 110.0), Summary("t3", 220.0), Summary("t4", 82.5))
// Double arithmetic is inexact: t1 and t3 come out a hair above 110.0 and 220.0
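When a collect block has several case clauses, they are tried top to bottom and the first match wins; elements that match no clause are dropped. The sketch below reuses the Transaction data; the "large" and "small" labels and the 100 threshold are illustrative choices, not part of the original example.

```scala
case class Transaction(id: String, amount: Double, status: String, category: String)

val transactions = List(
  Transaction("t1", 100.0, "completed", "electronics"),
  Transaction("t2", 50.0, "pending", "books"),
  Transaction("t3", 200.0, "completed", "electronics"),
  Transaction("t4", 75.0, "completed", "books"),
  Transaction("t5", 300.0, "failed", "electronics")
)

// First matching case wins; the pending t2 matches nothing and is dropped
val labels = transactions.collect {
  case Transaction(id, _, "failed", _) => s"$id: failed"
  case Transaction(id, amount, "completed", _) if amount >= 100 => s"$id: large"
  case Transaction(id, _, "completed", _) => s"$id: small"
}
// labels: List("t1: large", "t3: large", "t4: small", "t5: failed")
```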

Working with Maps

The collect method works on maps by pattern matching on key-value tuples:

val userScores = Map(
  "alice" -> 95,
  "bob" -> 67,
  "charlie" -> 82,
  "diana" -> 45,
  "eve" -> 91
)

// Extract names of users who passed (score >= 70)
val passedUsers = userScores.collect {
  case (name, score) if score >= 70 => name.capitalize
}
// Result: Iterable("Alice", "Charlie", "Eve"), in no guaranteed order

// Transform passing scores to grades
val grades = userScores.collect {
  case (name, score) if score >= 90 => (name, "A")
  case (name, score) if score >= 80 => (name, "B")
  case (name, score) if score >= 70 => (name, "C")
}
// Result: Map("alice" -> "A", "charlie" -> "B", "eve" -> "A")
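The result type of collect on a Map depends on what the partial function returns: a (key, value) pair keeps a Map, anything else gives an Iterable. Iteration order is not guaranteed either (the default immutable Map switches to a hash-based implementation above four entries), so the sketch below sorts when a stable order matters.

```scala
val userScores = Map(
  "alice" -> 95, "bob" -> 67, "charlie" -> 82, "diana" -> 45, "eve" -> 91
)

// The partial function returns a (key, value) pair, so the result is a Map
val passingScores: Map[String, Int] = userScores.collect {
  case (name, score) if score >= 70 => (name, score)
}

// The partial function returns a String, so the result is an Iterable
val passedNames: Iterable[String] = userScores.collect {
  case (name, score) if score >= 70 => name.capitalize
}

// Sort explicitly when a deterministic order is needed
val sortedNames: List[String] = passedNames.toList.sorted
```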

Nested Pattern Matching

Handle nested structures elegantly with deep pattern matching:

case class Address(city: String, country: String)
case class User(name: String, address: Option[Address], age: Int)

val users = List(
  User("Alice", Some(Address("NYC", "USA")), 30),
  User("Bob", None, 25),
  User("Charlie", Some(Address("London", "UK")), 35),
  User("Diana", Some(Address("Paris", "France")), 28)
)

// Extract cities for users over 30 with addresses
val cities = users.collect {
  case User(_, Some(Address(city, _)), age) if age > 30 => city
}
// Result: List("London")

// Complex nested extraction
case class Company(name: String, users: List[User])

val companies = List(
  Company("TechCorp", users.take(2)),
  Company("GlobalInc", users.drop(2))
)

val usaCities = companies.flatMap { company =>
  // The outer step always succeeds, so flatMap replaces collect + flatten
  company.users.collect {
    case User(_, Some(Address(city, "USA")), _) => city
  }
}
// Result: List("NYC")
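Deep patterns can also bind an intermediate value with @, so a case can return a nested object while still matching on its fields. In the sketch below, the grouping of "UK" and "France" and the name europeanAddresses are illustrative only.

```scala
case class Address(city: String, country: String)
case class User(name: String, address: Option[Address], age: Int)

val users = List(
  User("Alice", Some(Address("NYC", "USA")), 30),
  User("Bob", None, 25),
  User("Charlie", Some(Address("London", "UK")), 35),
  User("Diana", Some(Address("Paris", "France")), 28)
)

// addr @ Address(...) binds the whole nested Address while matching its country
val europeanAddresses: List[(String, Address)] = users.collect {
  case User(name, Some(addr @ Address(_, "UK" | "France")), _) => (name, addr)
}
```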

Performance Considerations

The collect method makes a single pass through the collection and builds only the result collection. Chaining filter and map traverses the data twice and allocates an intermediate collection in between, so collect typically does less work:

// Single pass - more efficient
// (squares are widened to Long: n * n overflows Int once n exceeds 46,340)
val result1 = (1 to 1000000).collect {
  case n if n % 2 == 0 => n.toLong * n
}

// Two passes - less efficient
val result2 = (1 to 1000000)
  .filter(n => n % 2 == 0)
  .map(n => n.toLong * n)

However, for simple transformations without filtering, map is more appropriate. Use collect when you need both filtering and transformation, especially with pattern matching.
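If a filter/map chain reads better, calling view first makes the chain lazy, so it also avoids materializing the intermediate filtered collection. The sketch below checks that both approaches produce the same values; the input size is arbitrary.

```scala
val n = 1000000

// Single pass with collect; widen to Long before squaring to avoid Int overflow
val viaCollect = (1 to n).collect { case i if i % 2 == 0 => i.toLong * i }

// view defers filter and map until toVector, so no intermediate collection is built
val viaView = (1 to n).view.filter(_ % 2 == 0).map(i => i.toLong * i).toVector
```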

Chaining collect Operations

Multiple collect operations can be chained for complex multi-stage transformations:

sealed trait Result[+A]
case class Success[A](value: A) extends Result[A]
case class Failure(error: String) extends Result[Nothing]

val results: List[Result[Int]] = List(
  Success(10),
  Failure("error1"),
  Success(20),
  Success(30),
  Failure("error2")
)

val processedValues = results
  .collect { case Success(n) => n }
  .collect { case n if n > 15 => n * 2 }
// Result: List(40, 60)

For better performance with multiple filtering conditions, combine them in a single collect:

val processedValues = results.collect {
  case Success(n) if n > 15 => n * 2
}
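A self-contained sketch that checks the chained and single-pass versions agree, and uses the same style to gather the messages from the Failure cases.

```scala
sealed trait Result[+A]
case class Success[A](value: A) extends Result[A]
case class Failure(error: String) extends Result[Nothing]

val results: List[Result[Int]] = List(
  Success(10), Failure("error1"), Success(20), Success(30), Failure("error2")
)

// Two passes: unwrap, then filter and transform
val chained = results
  .collect { case Success(n) => n }
  .collect { case n if n > 15 => n * 2 }

// One pass, with the guard folded into the pattern
val combined = results.collect { case Success(n) if n > 15 => n * 2 }

// Pull out error messages with the same technique
val errors = results.collect { case Failure(msg) => msg }
```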

Custom Partial Functions

Define reusable partial functions for common patterns:

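// Reusable for any numeric type via the Numeric type class, instead of type tests and casts
def onlyPositive[A](implicit num: Numeric[A]): PartialFunction[A, A] = {
  case x if num.gt(x, num.zero) => x
}

val doubled: PartialFunction[Int, Int] = {
  case n => n * 2
}

val numbers = List(-5, -2, 0, 3, 7, -1, 9)
// The explicit [Int] lets the compiler resolve Numeric[Int];
// andThen passes only the values onlyPositive accepts on to doubled
val result = numbers.collect(onlyPositive[Int] andThen doubled)
// Result: List(6, 14, 18)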

The orElse combinator provides fallback behavior: inputs the first partial function does not handle are passed to the second.

val handleInts: PartialFunction[Any, String] = {
  case n: Int => s"Integer: $n"
}

val handleStrings: PartialFunction[Any, String] = {
  case s: String => s"String: $s"
}

val mixed: List[Any] = List(1, "hello", 2, "world", 3.14)
val described = mixed.collect(handleInts orElse handleStrings)
// Result: List("Integer: 1", "String: hello", "Integer: 2", "String: world")
// 3.14 matches neither function, so it is dropped
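Ending an orElse chain with a catch-all case makes the combined function defined for every input, at which point collect keeps every element and behaves like map. The handleOther fallback below is a hypothetical addition for illustration.

```scala
val handleInts: PartialFunction[Any, String] = { case n: Int => s"Integer: $n" }
val handleStrings: PartialFunction[Any, String] = { case s: String => s"String: $s" }

// Hypothetical catch-all fallback: defined for every input
val handleOther: PartialFunction[Any, String] = { case other => s"Other: $other" }

val mixed: List[Any] = List(1, "hello", 2, "world", 3.14)

// With the catch-all last, nothing is dropped
val describedAll = mixed.collect(handleInts orElse handleStrings orElse handleOther)

// Without it, the combined domain is only Int and String
val partial = handleInts orElse handleStrings
val coversDouble = partial.isDefinedAt(3.14)  // false
```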
