Tail Recursive Monads (FlatMap) 

In 2015 Phil Freeman (@paf31) wrote Stack Safety for Free while working with PureScript hosted on JavaScript, a strict language like Scala.

The paper gives a hat tip to Rúnar (@runarorama)’s Stackless Scala With Free Monads, but presents a more drastic solution to the stack safety problem.

The stack problem 

As background, the Scala compiler is able to optimize self-recursive tail calls into a loop. For example, here's a function whose inner helper go makes a self-recursive tail call, checked with @tailrec.

import scala.annotation.tailrec

def pow(n: Long, exp: Long): Long =
  {
    @tailrec def go(acc: Long, p: Long): Long =
      (acc, p) match {
        case (acc, 0) => acc
        case (acc, p) => go(acc * n, p - 1)
      }
    go(1, exp)
  }

pow(2, 3)
// res0: Long = 8L

Here’s an example that’s not self-recursive: odd and even call each other, so the compiler cannot optimize the calls, and a large n blows up the stack.

scala> :paste
object OddEven0 {
  def odd(n: Int): String = even(n - 1)
  def even(n: Int): String = if (n <= 0) "done" else odd(n - 1)
}

// Exiting paste mode, now interpreting.

defined object OddEven0

scala> OddEven0.even(200000)
java.lang.StackOverflowError
  at OddEven0$.even(<console>:15)
  at OddEven0$.odd(<console>:14)
  at OddEven0$.even(<console>:15)
  at OddEven0$.odd(<console>:14)
  at OddEven0$.even(<console>:15)
  ....
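
As a side note, the standard library offers one way around this particular failure. The following is a minimal sketch, assuming scala.util.control.TailCalls is acceptable for the use case; the object name OddEven1 is made up for illustration and is not part of the original article. Each method returns a description of the next step rather than calling it directly, and result runs those steps in a loop.

```scala
import scala.util.control.TailCalls._

// Hypothetical trampolined variant of OddEven0.
object OddEven1 {
  // tailcall defers the call, so no stack frame is kept for the caller
  def odd(n: Int): TailRec[String] = tailcall(even(n - 1))

  // done wraps the final answer
  def even(n: Int): TailRec[String] =
    if (n <= 0) done("done") else tailcall(odd(n - 1))
}
```

Calling OddEven1.even(200000).result should return "done" instead of overflowing, at the cost of allocating a small object per step.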

Next, we’ll try using the Writer datatype to do the pow calculation with a LongProduct monoid.

import cats._, cats.data._, cats.syntax.all._

case class LongProduct(value: Long)

implicit val longProdMonoid: Monoid[LongProduct] = new Monoid[LongProduct] {
  def empty: LongProduct = LongProduct(1)
  def combine(x: LongProduct, y: LongProduct): LongProduct = LongProduct(x.value * y.value)
}
// longProdMonoid: Monoid[LongProduct] = repl.MdocSession1@6b0f544

def powWriter(x: Long, exp: Long): Writer[LongProduct, Unit] =
  exp match {
    case 0 => Writer(LongProduct(1L), ())
    case m =>
      Writer(LongProduct(x), ()) >>= { _ => powWriter(x, exp - 1) }
  }

powWriter(2, 3).run
// res1: (LongProduct, Unit) = (LongProduct(value = 8L), ())

The recursive call now happens inside the function passed to >>=, so it is no longer a self-recursive tail call, and it will blow the stack with a large exp.

scala> powWriter(1, 10000).run
java.lang.StackOverflowError
  at $anonfun$powWriter$1.apply(<console>:35)
  at $anonfun$powWriter$1.apply(<console>:35)
  at cats.data.WriterT$$anonfun$flatMap$1.apply(WriterT.scala:37)
  at cats.data.WriterT$$anonfun$flatMap$1.apply(WriterT.scala:36)
  at cats.package$$anon$1.flatMap(package.scala:34)
  at cats.data.WriterT.flatMap(WriterT.scala:36)
  at cats.data.WriterTFlatMap1$class.flatMap(WriterT.scala:249)
  at cats.data.WriterTInstances2$$anon$4.flatMap(WriterT.scala:130)
  at cats.data.WriterTInstances2$$anon$4.flatMap(WriterT.scala:130)
  at cats.FlatMap$class.$greater$greater$eq(FlatMap.scala:26)
  at cats.data.WriterTInstances2$$anon$4.$greater$greater$eq(WriterT.scala:130)
  at cats.FlatMap$Ops$class.$greater$greater$eq(FlatMap.scala:20)
  at cats.syntax.FlatMapSyntax1$$anon$1.$greater$greater$eq(flatMap.scala:6)
  at .powWriter1(<console>:35)
  at $anonfun$powWriter$1.apply(<console>:35)

This characteristic of Scala limits the usefulness of monadic composition, where flatMap can call a monadic function f, which can in turn call flatMap, and so on.

FlatMap (MonadRec) 

The paper's solution is to reduce the candidates for the target monad m from an arbitrary monad to the class of so-called tail-recursive monads, defined in PureScript as:

class (Monad m) <= MonadRec m where
  tailRecM :: forall a b. (a -> m (Either a b)) -> a -> m b

Here’s the same function in Scala:

  /**
   * Keeps calling `f` until a `scala.util.Right[B]` is returned.
   */
  def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]
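
To make the contract concrete, here is a minimal sketch for the simplest possible case, where F[X] is just X (the identity). The object name IdTailRec and its pow helper are made up for illustration and are not Cats APIs. A Left means "loop again with this new state" and a Right means "finished with this result".

```scala
import scala.annotation.tailrec

// Hypothetical stand-alone version of tailRecM with no wrapper type.
object IdTailRec {
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Either[A, B]): B =
    f(a) match {
      case Left(a1) => tailRecM(a1)(f) // not finished: loop with the next state
      case Right(b) => b               // finished: return the result
    }

  // pow written as a step function over (accumulator, remaining exponent)
  def pow(n: Long, exp: Long): Long =
    tailRecM[(Long, Long), Long]((1L, exp)) { case (acc, p) =>
      if (p <= 0L) Right(acc) else Left((acc * n, p - 1L))
    }
}
```

The recursion lives entirely inside tailRecM, so the step function passed to it never has to call itself.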

As it turns out, Oscar Boykin (@posco) brought tailRecM into FlatMap in #1280 (Remove FlatMapRec make all FlatMap implement tailRecM), and it’s now part of Cats 0.7.0. In other words, all FlatMap/Monads in Cats 0.7.0 are tail-recursive.

For example, we can obtain the tailRecM for Writer:

def tailRecM[A, B] = FlatMap[Writer[Vector[String], *]].tailRecM[A, B] _

Here’s how we can make a stack-safe powWriter:

def powWriter2(x: Long, exp: Long): Writer[LongProduct, Unit] =
  FlatMap[Writer[LongProduct, *]].tailRecM(exp) {
    case 0L      => Writer.value[LongProduct, Either[Long, Unit]](Right(()))
    case m: Long => Writer.tell(LongProduct(x)) >>= { _ => Writer.value(Left(m - 1)) }
  }

powWriter2(2, 3).run
// res2: (LongProduct, Unit) = (LongProduct(value = 8L), ())

powWriter2(1, 10000).run
// res3: (LongProduct, Unit) = (LongProduct(value = 1L), ())

This guarantees greater safety for the users of the FlatMap typeclass, but it means that the implementers of each instance need to provide a stack-safe tailRecM.

For example, here’s the one for Option:

@tailrec
def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
  f(a) match {
    case None => None
    case Some(Left(a1)) => tailRecM(a1)(f)
    case Some(Right(b)) => Some(b)
  }
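
Along the same lines, an instance that accumulates a log can thread the log through a tail-recursive loop. The following is only a sketch under simplifying assumptions: the writer is modeled as a plain pair of a Long log (combined by multiplication, like LongProduct) and a value, and the names PairWriter, loop, and pow are made up for illustration. It is not how Cats implements tailRecM for WriterT.

```scala
import scala.annotation.tailrec

// Hypothetical writer-like type: (log, value), with the log combined by multiplication.
object PairWriter {
  def tailRecM[A, B](a: A)(f: A => (Long, Either[A, B])): (Long, B) = {
    @tailrec
    def loop(log: Long, current: A): (Long, B) =
      f(current) match {
        case (w, Left(next)) => loop(log * w, next) // keep going, accumulating the log
        case (w, Right(b))   => (log * w, b)        // done: return final log and value
      }
    loop(1L, a) // 1L is the identity for multiplication
  }

  // Counterpart of powWriter2 for the simplified pair representation
  def pow(x: Long, exp: Long): (Long, Unit) =
    tailRecM[Long, Unit](exp) { m =>
      if (m <= 0L) (1L, Right(())) else (x, Left(m - 1L))
    }
}
```

With this sketch, PairWriter.pow(2L, 3L) evaluates to (8L, ()), and a large exponent runs in constant stack space because loop is a self-recursive tail call.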

That’s it for today!