Scala - Recursion and Tail Recursion
Key Insights
- Recursion in Scala is more idiomatic than imperative loops, but naive implementations can cause stack overflow errors for large inputs because every call adds a stack frame
- Tail recursion lets the compiler turn recursive calls into loops, eliminating stack growth and keeping memory use constant; the @tailrec annotation makes the compiler verify that this optimization actually applies
- Understanding when recursion is tail-recursive, and when it needs an accumulator to become so, is critical for writing performant functional code in Scala
Understanding Basic Recursion
Recursion occurs when a function calls itself to solve a problem by breaking it down into smaller subproblems. In Scala, recursion is the preferred approach over imperative loops for many algorithms, aligning with functional programming principles.
Here’s a simple recursive factorial implementation:
def factorial(n: Int): Int = {
  if (n <= 1) 1
  else n * factorial(n - 1)
}
println(factorial(5)) // 120
This works correctly for small inputs, but there’s a critical problem. Each recursive call adds a new frame to the call stack. For factorial(5), the call stack looks like:
factorial(5)
  5 * factorial(4)
    4 * factorial(3)
      3 * factorial(2)
        2 * factorial(1)
          1
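The computation can only complete after all these frames unwind. With a large enough argument, such as factorial(10000), the stack runs out and you get a StackOverflowError (the exact threshold depends on the JVM's thread stack size). Separately, Int silently overflows for any n above 12, which is one reason the tail-recursive version below returns Long.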
The Stack Overflow Problem
Let’s demonstrate the stack limitation with a simple sum function:
def sum(n: Int): Int = {
  if (n <= 0) 0
  else n + sum(n - 1)
}
// This works fine
println(sum(100)) // 5050
// This will crash with StackOverflowError
// println(sum(100000))
Each recursive call consumes stack space, and the JVM's default thread stack size (typically 1 MB, configurable with -Xss) limits how deep recursion can go. sum(100000) needs 100,000 nested stack frames, which is normally more than that stack can hold.
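To see the limit on your own JVM, a quick probe can recurse until the stack runs out and report how deep it got. This is a sketch for experimentation only: dive and deepest are hypothetical names, and the number printed varies with the JVM, its -Xss setting, and frame sizes.

```scala
// Records the deepest recursion level reached before the stack runs out
var deepest = 0

// Deliberately NOT tail-recursive: the "1 +" runs after the call returns,
// so every level keeps its own stack frame alive
def dive(depth: Int): Int = {
  deepest = depth
  1 + dive(depth + 1)
}

// StackOverflowError is an Error, not an Exception, so it must be named explicitly
try dive(1) catch { case _: StackOverflowError => () }

println(s"Overflowed after roughly $deepest nested calls")
```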
Tail Recursion: The Solution
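A function is tail-recursive when the recursive call is the last operation it performs: nothing remains to be computed after the call returns. The Scala compiler can then rewrite the recursion as a loop that reuses a single stack frame. This applies to calls a method makes to itself, and only when the method cannot be overridden (a local function, a private or final method, or a method of an object).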
Here’s the tail-recursive version of factorial:
import scala.annotation.tailrec
def factorial(n: Int): Long = {
  @tailrec
  def loop(n: Int, accumulator: Long): Long = {
    if (n <= 1) accumulator
    else loop(n - 1, n * accumulator)
  }
  loop(n, 1)
}
println(factorial(20)) // 2432902008176640000
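The @tailrec annotation is crucial: it makes the compiler verify that the function is truly tail-recursive, and compilation fails with an error message if it is not. The annotation does not switch the optimization on (the compiler rewrites eligible tail calls either way); its value is that a change which accidentally breaks tail position is caught at compile time instead of surfacing as a StackOverflowError at runtime.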
The key difference: in the tail-recursive version, we pass the accumulated result forward through parameters rather than computing it on the way back up the call stack.
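The difference is easiest to see by recording the arguments at each step. The sketch below is a variant of the loop helper above, renamed to the hypothetical traceLoop, that also collects each (n, accumulator) pair it is called with. The partial product travels forward in the arguments, so nothing is left to do when the last call returns.

```scala
import scala.annotation.tailrec

// Hypothetical tracing variant of factorial's loop helper: besides the
// accumulator, it records every (n, accumulator) pair it is called with
@tailrec
def traceLoop(n: Int, accumulator: Long, steps: Vector[(Int, Long)]): (Long, Vector[(Int, Long)]) = {
  val seen = steps :+ (n -> accumulator)
  if (n <= 1) (accumulator, seen)
  else traceLoop(n - 1, n * accumulator, seen) // the self-call is the last thing that happens
}

val (result, steps) = traceLoop(5, 1L, Vector.empty)
steps.foreach { case (n, acc) => println(s"loop($n, $acc)") }
println(result)
// loop(5, 1)
// loop(4, 5)
// loop(3, 20)
// loop(2, 60)
// loop(1, 120)
// 120
```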
Tail Recursion with Accumulators
The accumulator pattern is the standard technique for converting regular recursion to tail recursion. Instead of computing results during the unwinding phase, we build the result as we go deeper into recursion.
Here’s a tail-recursive sum:
import scala.annotation.tailrec
def sum(n: Int): Long = {
  @tailrec
  def loop(current: Int, accumulator: Long): Long = {
    if (current <= 0) accumulator
    else loop(current - 1, accumulator + current)
  }
  loop(n, 0L)
}
// Long accumulator: the total exceeds Int.MaxValue (an Int version would silently print 705082704)
println(sum(100000)) // 5000050000 (works without stack overflow)
For list operations, the accumulator can itself be a data structure. Prepending each element to an accumulator list reverses a list in a single pass:
import scala.annotation.tailrec
def reverse[A](list: List[A]): List[A] = {
  @tailrec
  def loop(remaining: List[A], accumulator: List[A]): List[A] = {
    remaining match {
      case Nil => accumulator
      case head :: tail => loop(tail, head :: accumulator)
    }
  }
  loop(list, Nil)
}
println(reverse(List(1, 2, 3, 4, 5))) // List(5, 4, 3, 2, 1)
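A common variation builds the result back-to-front in the accumulator and reverses it once at the end, because prepending to a List is cheap while appending is not. Here it is for a map-like function (mapList is a hypothetical name, not a standard-library method):

```scala
import scala.annotation.tailrec

// Hypothetical helper: applies f to every element, tail-recursively
def mapList[A, B](list: List[A])(f: A => B): List[B] = {
  @tailrec
  def loop(remaining: List[A], accumulator: List[B]): List[B] = remaining match {
    case Nil          => accumulator.reverse           // accumulator was built back-to-front
    case head :: tail => loop(tail, f(head) :: accumulator)
  }
  loop(list, Nil)
}

println(mapList(List(1, 2, 3))(_ * 10)) // List(10, 20, 30)
```

The final reverse costs one extra pass, but the whole function still runs in constant stack space, so it handles long lists.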
When Tail Recursion Isn’t Possible
Not all recursive algorithms can be made tail-recursive without significant restructuring. Tree traversals are a classic example:
case class Tree(value: Int, left: Option[Tree] = None, right: Option[Tree] = None)
// Not tail-recursive: two recursive calls, with an addition after both
def sumTree(tree: Option[Tree]): Int = tree match {
  case None => 0
  case Some(node) => node.value + sumTree(node.left) + sumTree(node.right)
}
val tree = Some(Tree(1,
Some(Tree(2, Some(Tree(4)), Some(Tree(5)))),
Some(Tree(3))
))
println(sumTree(tree)) // 15
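The problem: after the recursive call to sumTree(node.left) returns, we still need to call sumTree(node.right) and add the results. Because work remains after the recursive calls, they are not in tail position, and tail call optimization cannot apply. The recursion depth equals the height of the tree, so this is only a real risk for very deep trees, such as degenerate ones that behave like linked lists.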
For such cases, you can use explicit stack-based iteration:
def sumTreeIterative(tree: Option[Tree]): Int = {
  var sum = 0
  var stack = List(tree) // subtrees still waiting to be visited
  while (stack.nonEmpty) {
    stack.head match {
      case None => stack = stack.tail
      case Some(node) =>
        sum += node.value
        stack = node.left :: node.right :: stack.tail
    }
  }
  sum
}
println(sumTreeIterative(tree)) // 15
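The same explicit-stack idea can also be written without mutable variables: the list of pending subtrees becomes a parameter of a @tailrec helper, alongside the running sum. A minimal sketch (sumTreeTailRec is a hypothetical name; the Tree class is repeated so the snippet stands alone):

```scala
import scala.annotation.tailrec

case class Tree(value: Int, left: Option[Tree] = None, right: Option[Tree] = None)

def sumTreeTailRec(tree: Option[Tree]): Int = {
  // pending plays the role of the call stack; acc is the running total
  @tailrec
  def loop(pending: List[Option[Tree]], acc: Int): Int = pending match {
    case Nil                => acc
    case None :: rest       => loop(rest, acc)
    case Some(node) :: rest => loop(node.left :: node.right :: rest, acc + node.value)
  }
  loop(List(tree), 0)
}

val tree = Some(Tree(1,
  Some(Tree(2, Some(Tree(4)), Some(Tree(5)))),
  Some(Tree(3))
))
println(sumTreeTailRec(tree)) // 15
```

Because the pending list lives on the heap, even a degenerate tree tens of thousands of levels deep is summed without a StackOverflowError.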
Mutual Recursion and Trampolining
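When two functions call each other recursively, the JVM offers no general tail-call elimination, and the Scala compiler only optimizes a method's calls to itself, so mutually recursive calls still consume stack frames. For this scenario, Scala's standard library provides scala.util.control.TailCalls: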
import scala.util.control.TailCalls._
def isEven(n: Int): TailRec[Boolean] = {
  if (n == 0) done(true)
  else tailcall(isOdd(n - 1))
}
def isOdd(n: Int): TailRec[Boolean] = {
  if (n == 0) done(false)
  else tailcall(isEven(n - 1))
}
println(isEven(100000).result) // true
println(isOdd(100000).result) // false
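The TailRec type and the tailcall construct implement a trampoline: instead of calling the other function directly, each step returns a description of the next call, and result evaluates those steps in a loop, so the stack does not grow. The price is an object allocation per step, which typically makes it slower than a plain @tailrec loop.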
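TailCalls is not limited to mutual recursion. Its TailRec values also support map and flatMap, so a non-tail-recursive function such as sumTree can be made stack-safe by wrapping each recursive call in tailcall and combining the results in a for-comprehension. A sketch of that approach (sumTreeT is a hypothetical name; Tree is repeated so the block stands alone):

```scala
import scala.util.control.TailCalls._

case class Tree(value: Int, left: Option[Tree] = None, right: Option[Tree] = None)

// Each recursive call is deferred with tailcall; the for-comprehension
// (flatMap/map) combines the results on the heap instead of the JVM stack
def sumTreeT(tree: Option[Tree]): TailRec[Int] = tree match {
  case None => done(0)
  case Some(node) =>
    for {
      l <- tailcall(sumTreeT(node.left))
      r <- tailcall(sumTreeT(node.right))
    } yield node.value + l + r
}

val tree = Some(Tree(1,
  Some(Tree(2, Some(Tree(4)), Some(Tree(5)))),
  Some(Tree(3))
))
println(sumTreeT(tree).result) // 15
```

This keeps the shape of the original recursive definition, which can be easier to read than an explicit stack, at the cost of extra allocations per node.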
Performance Considerations
Tail-recursive functions compile to efficient loops. Conceptually, the compiler turns our tail-recursive factorial into the equivalent of this hand-written loop:
// Conceptual equivalent (not literal compiler output):
def factorial(n: Int): Long = {
  var current = n
  var accumulator = 1L
  while (current > 1) {
    accumulator = current * accumulator
    current = current - 1
  }
  accumulator
}
This needs O(1) stack space instead of the O(n) used by the naive recursive version, and its performance is essentially that of a hand-written imperative loop.
Practical Guidelines
Use the @tailrec annotation religiously. It documents intent and catches mistakes:
import scala.annotation.tailrec
// Compilation error: "could not optimize @tailrec annotated method"
@tailrec
def fibonacci(n: Int): Int = {
  if (n <= 1) n
  else fibonacci(n - 1) + fibonacci(n - 2)
}
This error message immediately tells you the function isn’t tail-recursive and needs refactoring.
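One way to refactor it is to carry the two most recent Fibonacci numbers as accumulators, so each step is a single self-call in tail position. The sketch below returns BigInt rather than Int, an assumption made because Fibonacci numbers exceed Int.MaxValue from n = 47 onward.

```scala
import scala.annotation.tailrec

def fibonacci(n: Int): BigInt = {
  require(n >= 0, "n must be non-negative")
  // prev and curr hold F(k) and F(k + 1); i counts the steps still to take
  @tailrec
  def loop(i: Int, prev: BigInt, curr: BigInt): BigInt =
    if (i == 0) prev
    else loop(i - 1, curr, prev + curr)
  loop(n, BigInt(0), BigInt(1))
}

println(fibonacci(10)) // 55
```

Besides compiling under @tailrec, this runs in linear time, whereas the doubly recursive original recomputes the same values exponentially many times.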
For recursive algorithms on collections, consider Scala's built-in methods first; they're already optimized and stack-safe:
// Instead of writing recursive sum
// (Long elements, because the total exceeds Int.MaxValue)
val numbers = (1L to 100000L).toList
println(numbers.sum) // 5000050000
// Instead of recursive filter
println(numbers.filter(_ % 2 == 0).size) // 50000
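Many accumulator-style recursions map directly onto foldLeft, which threads an accumulator through a collection the same way the loop helpers above do. A brief sketch:

```scala
val numbers = (1L to 100000L).toList

// The first argument is the initial accumulator; the function combines
// the accumulator with each element in turn, from left to right
val total = numbers.foldLeft(0L)((acc, x) => acc + x)
val evens = numbers.foldLeft(0)((count, x) => if (x % 2 == 0) count + 1 else count)

println(total) // 5000050000
println(evens) // 50000
```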
When writing recursive functions, ask: “Is the recursive call the absolute last thing that happens?” If you need to do anything with the returned value, it’s not tail-recursive and needs an accumulator.