Scala - groupBy with Examples


Key Insights

• The groupBy method transforms collections into Maps by partitioning elements based on a discriminator function, enabling efficient data categorization and aggregation patterns
• Scala's groupBy works seamlessly across all collection types (List, Seq, Array, Set) and returns immutable Maps by default, with values maintaining the original collection type
• Common pitfalls include memory issues with large datasets and inefficient discriminator functions; use groupMap or groupMapReduce for better performance in transformation scenarios

Understanding groupBy Fundamentals

The groupBy method is a higher-order function available on Scala collections that partitions elements into groups based on a discriminator function. The function returns a Map[K, C] where K is the key type returned by the discriminator and C is the collection type containing elements that share the same key.

val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
val grouped = numbers.groupBy(n => n % 2 == 0)

println(grouped)
// Map(false -> List(1, 3, 5, 7, 9), true -> List(2, 4, 6, 8, 10))

The discriminator function n => n % 2 == 0 evaluates to a Boolean, creating two groups: even numbers (true) and odd numbers (false). The resulting Map preserves the original collection type (List) for its values.
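For instance, the values of the returned Map keep the type of the source collection: grouping a Vector yields Vectors and grouping a Set yields Sets. A quick sketch:

```scala
// The Map's values keep the source collection's type
val vecGrouped: Map[Boolean, Vector[Int]] =
  Vector(1, 2, 3, 4).groupBy(_ % 2 == 0)

val setGrouped: Map[Int, Set[String]] =
  Set("ant", "bee", "horse").groupBy(_.length)

println(vecGrouped) // values are Vectors
println(setGrouped) // values are Sets
```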

Grouping by Object Properties

A practical use case involves grouping domain objects by their attributes. This pattern appears frequently in data processing and reporting scenarios.

case class Employee(id: Int, name: String, department: String, salary: Double)

val employees = List(
  Employee(1, "Alice", "Engineering", 95000),
  Employee(2, "Bob", "Sales", 75000),
  Employee(3, "Carol", "Engineering", 105000),
  Employee(4, "David", "Sales", 80000),
  Employee(5, "Eve", "Marketing", 70000),
  Employee(6, "Frank", "Engineering", 90000)
)

val byDepartment = employees.groupBy(_.department)

byDepartment.foreach { case (dept, emps) =>
  println(s"$dept: ${emps.map(_.name).mkString(", ")}")
}
// Engineering: Alice, Carol, Frank
// Sales: Bob, David
// Marketing: Eve

You can also group by computed values or multiple criteria by returning tuples or custom objects from the discriminator function.

// Group by salary range
val bySalaryRange = employees.groupBy { emp =>
  emp.salary match {
    case s if s < 80000 => "Junior"
    case s if s < 100000 => "Mid"
    case _ => "Senior"
  }
}

// Group by multiple criteria (department and salary range)
val byDeptAndRange = employees.groupBy(emp => 
  (emp.department, if (emp.salary >= 90000) "High" else "Standard")
)

byDeptAndRange.foreach { case ((dept, range), emps) =>
  println(s"$dept - $range: ${emps.size} employees")
}

Aggregating Grouped Data

After grouping, you typically want to perform aggregations on each group. Use view.mapValues (materialized with toMap) to transform each group's collection.

val departmentStats = employees
  .groupBy(_.department)
  .view
  .mapValues { emps =>
    // (count, average salary, max salary); a tuple keeps the value types precise
    (emps.size, emps.map(_.salary).sum / emps.size, emps.map(_.salary).max)
  }
  .toMap

departmentStats.foreach { case (dept, (count, avg, max)) =>
  println(f"$dept: $count employees, avg salary: $$$avg%.2f, max: $$$max%.2f")
}

Note: In Scala 2.13+, calling mapValues directly on a Map is deprecated; it now returns a lazy MapView, so use .view.mapValues(...).toMap, as above, to get a strict Map back. In Scala 2.12 and earlier, mapValues itself returned a lazy view, which silently recomputed values on every access.

Using groupMap for Transformations

Scala 2.13 introduced groupMap, which combines grouping and mapping in a single operation, offering better performance than groupBy followed by mapValues.

// Extract just names grouped by department
val namesByDept = employees.groupMap(_.department)(_.name)

println(namesByDept)
// Map(Engineering -> List(Alice, Carol, Frank), 
//     Sales -> List(Bob, David), 
//     Marketing -> List(Eve))

// Group by department, extract salaries
val salariesByDept = employees.groupMap(_.department)(_.salary)

val avgSalariesByDept = salariesByDept.view.mapValues { salaries =>
  salaries.sum / salaries.size
}.toMap

The signature is groupMap[K, B](key: A => K)(f: A => B): Map[K, CC[B]], where the first parameter function determines the grouping key and the second transforms each element.
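As a quick illustration of that signature, the sketch below (with made-up numbers) keys integers by their remainder mod 3 and squares each element in the same pass; it produces the same result as groupBy followed by mapping the values:

```scala
val nums = List(1, 2, 3, 4, 5, 6)

// One pass: key by remainder mod 3, transform each element to its square
val squaresByRem = nums.groupMap(_ % 3)(n => n * n)
// equivalent to: nums.groupBy(_ % 3).view.mapValues(_.map(n => n * n)).toMap

println(squaresByRem)
// Map(1 -> List(1, 16), 2 -> List(4, 25), 0 -> List(9, 36))
```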

Using groupMapReduce for Aggregations

For aggregation scenarios, groupMapReduce combines grouping, mapping, and reduction in a single pass, avoiding intermediate collections entirely.

// Sum salaries by department
val totalSalariesByDept = employees.groupMapReduce(_.department)(_.salary)(_ + _)

println(totalSalariesByDept)
// Map(Engineering -> 290000.0, Sales -> 155000.0, Marketing -> 70000.0)

// Count employees by department
val employeeCount = employees.groupMapReduce(_.department)(_ => 1)(_ + _)

// Find highest paid employee per department
val highestPaidByDept = employees.groupMapReduce(_.department)(identity) { (emp1, emp2) =>
  if (emp1.salary > emp2.salary) emp1 else emp2
}

highestPaidByDept.foreach { case (dept, emp) =>
  println(s"$dept highest paid: ${emp.name} ($$${emp.salary})")
}

The signature is groupMapReduce[K, B](key: A => K)(f: A => B)(reduce: (B, B) => B): Map[K, B]. The reduction function combines values with the same key.
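groupMapReduce can also compute averages in one pass if each element is mapped to a (sum, count) pair and the pairs are added component-wise. A minimal sketch, with hypothetical (department, salary) tuples standing in for the Employee list above:

```scala
// Hypothetical (department, salary) records standing in for the employees above
val records = List(
  ("Engineering", 95000.0), ("Sales", 75000.0),
  ("Engineering", 105000.0), ("Sales", 80000.0)
)

// Map each record to (salary, 1), then add the pairs component-wise per key
val sumAndCount = records.groupMapReduce(_._1)(r => (r._2, 1)) {
  (a, b) => (a._1 + b._1, a._2 + b._2)
}

val avgByDept = sumAndCount.view.mapValues { case (sum, count) => sum / count }.toMap
println(avgByDept) // Map(Engineering -> 100000.0, Sales -> 77500.0)
```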

Working with Nested Grouping

Complex data often requires multiple levels of grouping. Chain groupBy operations or use nested discriminator functions.

case class Transaction(id: Int, category: String, subcategory: String, amount: Double)

val transactions = List(
  Transaction(1, "Food", "Groceries", 150.0),
  Transaction(2, "Food", "Restaurants", 75.0),
  Transaction(3, "Transport", "Gas", 60.0),
  Transaction(4, "Food", "Groceries", 120.0),
  Transaction(5, "Transport", "Public", 30.0),
  Transaction(6, "Food", "Restaurants", 90.0)
)

// Nested grouping
val nestedGroups = transactions
  .groupBy(_.category)
  .view
  .mapValues(_.groupBy(_.subcategory))
  .toMap

nestedGroups.foreach { case (cat, subgroups) =>
  println(s"\n$cat:")
  subgroups.foreach { case (subcat, txns) =>
    val total = txns.map(_.amount).sum
    println(f"  $subcat: $$${total}%.2f")
  }
}
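When you only need per-(category, subcategory) totals rather than the nested Map itself, a flat tuple key with groupMapReduce produces them in a single pass. A sketch using a trimmed copy of the transactions above:

```scala
case class Transaction(id: Int, category: String, subcategory: String, amount: Double)

// Trimmed copy of the transactions above, so the snippet runs standalone
val transactions = List(
  Transaction(1, "Food", "Groceries", 150.0),
  Transaction(2, "Food", "Restaurants", 75.0),
  Transaction(3, "Transport", "Gas", 60.0),
  Transaction(4, "Food", "Groceries", 120.0)
)

// Flat (category, subcategory) key: no intermediate nested Maps
val totals = transactions
  .groupMapReduce(t => (t.category, t.subcategory))(_.amount)(_ + _)

totals.foreach { case ((cat, sub), total) =>
  println(f"$cat / $sub: $$$total%.2f")
}
```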

Performance Considerations

The groupBy method creates intermediate collections and can consume significant memory with large datasets. Consider these alternatives:

// Instead of groupBy + mapValues + map operations
val inefficient = employees
  .groupBy(_.department)
  .view
  .mapValues(_.map(_.salary))
  .mapValues(_.sum)
  .toMap

// Use groupMapReduce for better performance
val efficient = employees.groupMapReduce(_.department)(_.salary)(_ + _)

// For streaming or large datasets, consider using iterators
def groupByIterator[A, K](iter: Iterator[A], f: A => K): Map[K, List[A]] = {
  val grouped = iter.foldLeft(Map.empty[K, List[A]]) { (acc, item) =>
    val key = f(item)
    // Prepend for O(1) insertion; each group is built in reverse order
    acc + (key -> (item :: acc.getOrElse(key, Nil)))
  }
  grouped.view.mapValues(_.reverse).toMap // restore encounter order
}

Grouping with Custom Keys and Ordering

You can group by custom case classes or use sorted maps for ordered results.

import scala.collection.immutable.TreeMap

case class DateRange(year: Int, quarter: Int) extends Ordered[DateRange] {
  def compare(that: DateRange): Int = {
    val yearComp = this.year.compare(that.year)
    if (yearComp != 0) yearComp else this.quarter.compare(that.quarter)
  }
}

case class Sale(amount: Double, year: Int, quarter: Int)

val sales = List(
  Sale(1000, 2023, 1), Sale(1500, 2023, 2),
  Sale(2000, 2023, 1), Sale(1200, 2024, 1)
)

val salesByPeriod = sales
  .groupBy(s => DateRange(s.year, s.quarter))
  .view
  .mapValues(_.map(_.amount).sum)
  .to(TreeMap) // keep the sorted ordering; a trailing .toMap would discard it

salesByPeriod.foreach { case (period, total) =>
  println(f"${period.year} Q${period.quarter}: $$${total}%.2f")
}
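Extending Ordered works, but for ad-hoc grouping keys it is often simpler to group by a tuple, which already has a lexicographic Ordering. A sketch with the same data, redefined here so it runs standalone:

```scala
import scala.collection.immutable.TreeMap

// Same Sale data as above, repeated so the snippet stands alone
case class Sale(amount: Double, year: Int, quarter: Int)

val sales = List(
  Sale(1000, 2023, 1), Sale(1500, 2023, 2),
  Sale(2000, 2023, 1), Sale(1200, 2024, 1)
)

// Tuple keys get a lexicographic Ordering for free, so no Ordered subclass is needed
val salesByTuplePeriod: TreeMap[(Int, Int), Double] = sales
  .groupBy(s => (s.year, s.quarter))
  .view
  .mapValues(_.map(_.amount).sum)
  .to(TreeMap)

salesByTuplePeriod.foreach { case ((year, quarter), total) =>
  println(f"$year Q$quarter: $$$total%.2f")
}
```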

The groupBy method is fundamental for data analysis in Scala. Choose groupMap when you need to transform elements during grouping, and groupMapReduce when performing aggregations. For large datasets, be mindful of memory usage and consider streaming approaches or specialized libraries like Apache Spark for distributed processing.
