Using regular recursion, each recursive call pushes another entry onto the call stack. When the recursion is completed, the application has to pop each entry off all the way back down. If there are much recursive function calls it can end up with a huge stack.
Scala automatically removes the recursion in case it finds the recursive call in tail position. The annotation (@tailrec
) can be added to recursive functions to ensure that tail call optimization is performed. The compiler then shows an error message if it can't optimize your recursion.
This example is not tail recursive because when the recursive call is made, the function needs to keep track of the multiplication it needs to do with the result after the call returns.
def fact(i : Int) : Int = {
if(i <= 1) i
else i * fact(i-1)
}
println(fact(5))
The function call with the parameter will result in a stack that looks like this:
(fact 5)
(* 5 (fact 4))
(* 5 (* 4 (fact 3)))
(* 5 (* 4 (* 3 (fact 2))))
(* 5 (* 4 (* 3 (* 2 (fact 1)))))
(* 5 (* 4 (* 3 (* 2 (* 1 (fact 0))))))
(* 5 (* 4 (* 3 (* 2 (* 1 * 1)))))
(* 5 (* 4 (* 3 (* 2))))
(* 5 (* 4 (* 6)))
(* 5 (* 24))
120
If we try to annotate this example with @tailrec
we will get the following error message: could not optimize @tailrec annotated method fact: it contains a recursive call not in tail position
In tail recursion, you perform your calculations first, and then you execute the recursive call, passing the results of your current step to the next recursive step.
def fact_with_tailrec(i : Int) : Long = {
@tailrec
def fact_inside(i : Int, sum: Long) : Long = {
if(i <= 1) sum
else fact_inside(i-1,sum*i)
}
fact_inside(i,1)
}
println(fact_with_tailrec(5))
In contrast, the stack trace for the tail recursive factorial looks like the following:
(fact_with_tailrec 5)
(fact_inside 5 1)
(fact_inside 4 5)
(fact_inside 3 20)
(fact_inside 2 60)
(fact_inside 1 120)
There is only the need to keep track of the same amount of data for every call to fact_inside
because the function is simply returning the value it got right through to the top. This means that even if fact_with_tail 1000000
is called, it needs only the same amount of space as fact_with_tail 3
. This is not the case with the non-tail-recursive call, and as such large values may cause a stack overflow.
It is very common to get a StackOverflowError
error while calling recursive function. Scala standard library offers TailCall to avoid stack overflow by using heap objects and continuations to store the local state of the recursion.
Two examples from the scaladoc of TailCalls
import scala.util.control.TailCalls._
def isEven(xs: List[Int]): TailRec[Boolean] =
if (xs.isEmpty) done(true) else tailcall(isOdd(xs.tail))
def isOdd(xs: List[Int]): TailRec[Boolean] =
if (xs.isEmpty) done(false) else tailcall(isEven(xs.tail))
// Does this List contain an even number of elements?
isEven((1 to 100000).toList).result
def fib(n: Int): TailRec[Int] =
if (n < 2) done(n) else for {
x <- tailcall(fib(n - 1))
y <- tailcall(fib(n - 2))
} yield (x + y)
// What is the 40th entry of the Fibonacci series?
fib(40).result