Tail Recursive Monads (FlatMap) 

In 2015, Phil Freeman (@paf31) wrote Stack Safety for Free while working with PureScript, which is hosted on JavaScript, a strict language like Java.

The paper gives a hat tip to Rúnar (@runarorama)’s Stackless Scala With Free Monads, but presents a more drastic solution to the stack safety problem.

The stack problem 

As background: in Scala, the compiler can optimize self-recursive tail calls. Here's an example of a self-recursive tail call:

scala> import scala.annotation.tailrec
import scala.annotation.tailrec

scala> :paste
// Entering paste mode (ctrl-D to finish)
def pow(n: Long, exp: Long): Long =
  {
    @tailrec def go(acc: Long, p: Long): Long =
      (acc, p) match {
        case (acc, 0) => acc
        case (acc, p) => go(acc * n, p - 1)
      }
    go(1, exp)
  }

// Exiting paste mode, now interpreting.

pow: (n: Long, exp: Long)Long

scala> pow(2, 3)
res3: Long = 8

Here's an example that isn't self-recursive. odd and even call each other, and this mutual recursion blows the stack.

scala> :paste
object OddEven0 {
  def odd(n: Int): String = even(n - 1)
  def even(n: Int): String = if (n <= 0) "done" else odd(n - 1)
}

// Exiting paste mode, now interpreting.

defined object OddEven0

scala> OddEven0.even(200000)
java.lang.StackOverflowError
  at OddEven0$.even(<console>:15)
  at OddEven0$.odd(<console>:14)
  at OddEven0$.even(<console>:15)
  at OddEven0$.odd(<console>:14)
  at OddEven0$.even(<console>:15)
  ....
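The Scala standard library offers one way to make the odd/even pair above stack-safe: trampolining with scala.util.control.TailCalls. It is related in spirit to the trampolines in Stackless Scala With Free Monads, but it is not the tailRecM approach this post goes on to describe. Treat it as a sketch; the object name OddEven1 is just an illustration.

```scala
import scala.util.control.TailCalls._

// Trampolined odd/even. tailcall(...) returns a suspended step instead of
// calling the other method directly, and .result runs the steps in a loop,
// so the JVM stack depth stays constant.
object OddEven1 {
  def odd(n: Int): TailRec[String] = tailcall(even(n - 1))

  def even(n: Int): TailRec[String] =
    if (n <= 0) done("done") else tailcall(odd(n - 1))
}
```

With this version, OddEven1.even(200000).result returns "done" instead of throwing a StackOverflowError.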

Next, let's try using the Writer datatype to do the pow calculation with a LongProduct monoid.

scala> import cats._, cats.instances.all._, cats.data.Writer
import cats._
import cats.instances.all._
import cats.data.Writer

scala> import cats.syntax.flatMap._
import cats.syntax.flatMap._

scala> :paste
// Entering paste mode (ctrl-D to finish)
case class LongProduct(value: Long)
implicit val longProdMonoid: Monoid[LongProduct] = new Monoid[LongProduct] {
  def empty: LongProduct = LongProduct(1)
  def combine(x: LongProduct, y: LongProduct): LongProduct = LongProduct(x.value * y.value)
}
def powWriter(x: Long, exp: Long): Writer[LongProduct, Unit] =
  exp match {
    case 0 => Writer(LongProduct(1L), ())
    case m =>
      Writer(LongProduct(x), ()) >>= { _ => powWriter(x, exp - 1) }
  }

// Exiting paste mode, now interpreting.

defined class LongProduct
longProdMonoid: cats.Monoid[LongProduct] = $anon$1@538028a
powWriter: (x: Long, exp: Long)cats.data.Writer[LongProduct,Unit]

scala> powWriter(2, 3).run
res4: cats.Id[(LongProduct, Unit)] = (LongProduct(8),())

This is no longer self-recursive, so it blows the stack for a large exp.

scala> powWriter(1, 10000).run
java.lang.StackOverflowError
  at $anonfun$powWriter$1.apply(<console>:35)
  at $anonfun$powWriter$1.apply(<console>:35)
  at cats.data.WriterT$$anonfun$flatMap$1.apply(WriterT.scala:37)
  at cats.data.WriterT$$anonfun$flatMap$1.apply(WriterT.scala:36)
  at cats.package$$anon$1.flatMap(package.scala:34)
  at cats.data.WriterT.flatMap(WriterT.scala:36)
  at cats.data.WriterTFlatMap1$class.flatMap(WriterT.scala:249)
  at cats.data.WriterTInstances2$$anon$4.flatMap(WriterT.scala:130)
  at cats.data.WriterTInstances2$$anon$4.flatMap(WriterT.scala:130)
  at cats.FlatMap$class.$greater$greater$eq(FlatMap.scala:26)
  at cats.data.WriterTInstances2$$anon$4.$greater$greater$eq(WriterT.scala:130)
  at cats.FlatMap$Ops$class.$greater$greater$eq(FlatMap.scala:20)
  at cats.syntax.FlatMapSyntax1$$anon$1.$greater$greater$eq(flatMap.scala:6)
  at .powWriter1(<console>:35)
  at $anonfun$powWriter$1.apply(<console>:35)

This characteristic of Scala limits the usefulness of monadic composition: flatMap can call a monadic function f, which can then call flatMap again, and so on, with each round adding frames to the stack.
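To see where the nesting comes from without pulling in Cats, here is a minimal sketch using a hypothetical, hand-rolled writer (NaiveWriter and NaivePow are illustrative names, not Cats types). Its flatMap runs f eagerly, so the recursive powW nests one flatMap call per step, just as powWriter does.

```scala
// Hypothetical minimal writer whose log is a Long product (not cats' Writer).
final case class NaiveWriter[A](product: Long, value: A) {
  // flatMap calls f immediately; if f recurses back into flatMap,
  // every step adds frames on top of the previous one.
  def flatMap[B](f: A => NaiveWriter[B]): NaiveWriter[B] = {
    val next = f(value)
    NaiveWriter(product * next.product, next.value)
  }
}

object NaivePow {
  // powW -> flatMap -> f -> powW -> flatMap -> ...
  // Stack depth grows linearly with exp, so a large exp overflows.
  def powW(x: Long, exp: Long): NaiveWriter[Unit] =
    if (exp <= 0) NaiveWriter(1L, ())
    else NaiveWriter(x, ()).flatMap(_ => powW(x, exp - 1))
}
```

NaivePow.powW(2, 3).product is 8, but the same call with a very large exp would be expected to overflow for the same reason as powWriter.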

FlatMap (MonadRec) 

Freeman's paper describes its approach this way: "Our solution is to reduce the candidates for the target monad m from an arbitrary monad to the class of so-called tail-recursive monads." In PureScript, that class is MonadRec:

class (Monad m) <= MonadRec m where
  tailRecM :: forall a b. (a -> m (Either a b)) -> a -> m b

Here’s the same function in Scala:

  /**
   * Keeps calling `f` until a `scala.util.Right[B]` is returned.
   */
  def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]
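To make the signature concrete, here is a minimal, self-contained sketch of a type class with the same tailRecM shape, plus an instance for Either[String, *]. The names TailRecMonad and ErrOr are hypothetical and not part of Cats; the point is that the instance implements tailRecM as a loop rather than as nested flatMap calls.

```scala
import scala.annotation.tailrec

// Hypothetical type class mirroring MonadRec / FlatMap#tailRecM.
// TailRecMonad and ErrOr are illustrative names, not part of Cats.
trait TailRecMonad[F[_]] {
  def pure[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

  /** Keeps calling `f` until a `Right[B]` is returned. */
  def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]
}

object TailRecMonad {
  // Either with the error type fixed to String, so it has the F[_] shape.
  type ErrOr[A] = Either[String, A]

  implicit val errOrInstance: TailRecMonad[ErrOr] = new TailRecMonad[ErrOr] {
    def pure[A](a: A): ErrOr[A] = Right(a)

    def flatMap[A, B](fa: ErrOr[A])(f: A => ErrOr[B]): ErrOr[B] = fa.flatMap(f)

    def tailRecM[A, B](a: A)(f: A => ErrOr[Either[A, B]]): ErrOr[B] = {
      // A local @tailrec loop compiles to a jump, so stack depth stays constant.
      @tailrec
      def loop(current: A): ErrOr[B] =
        f(current) match {
          case Left(err)         => Left(err)  // the effect failed: stop
          case Right(Left(next)) => loop(next) // not finished: iterate
          case Right(Right(b))   => Right(b)   // finished
        }
      loop(a)
    }
  }
}
```

For example, errOrInstance.tailRecM[Int, String](1000000)(n => if (n == 0) Right(Right("done")) else Right(Left(n - 1))) iterates a million times in constant stack and yields Right("done"), while a step that returns Left("boom") stops the loop with that error.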

As it turns out, Oscar Boykin (@posco) brought tailRecM into FlatMap in #1280 (Remove FlatMapRec make all FlatMap implement tailRecM), and it's now part of Cats 0.7.0. In other words, all FlatMap and Monad instances in Cats 0.7.0 are tail-recursive.

For example, we can obtain the tailRecM for Writer:

scala> def tailRecM[A, B] = FlatMap[Writer[Vector[String], ?]].tailRecM[A, B] _
tailRecM: [A, B]=> A => ((A => cats.data.WriterT[[A]A,scala.collection.immutable.Vector[String],Either[A,B]]) => cats.data.WriterT[[A]A,scala.collection.immutable.Vector[String],B])

Here’s how we can make a stack-safe powWriter:

scala> :paste
// Entering paste mode (ctrl-D to finish)
def powWriter2(x: Long, exp: Long): Writer[LongProduct, Unit] =
  FlatMap[Writer[LongProduct, ?]].tailRecM(exp) {
    case 0L      => Writer.value[LongProduct, Either[Long, Unit]](Right(()))
    case m: Long => Writer.tell(LongProduct(x)) >>= { _ => Writer.value(Left(m - 1)) }
  }

// Exiting paste mode, now interpreting.

powWriter2: (x: Long, exp: Long)cats.data.Writer[LongProduct,Unit]

scala> powWriter2(2, 3).run
res5: cats.Id[(LongProduct, Unit)] = (LongProduct(8),())

scala> powWriter2(1, 10000).run
res6: cats.Id[(LongProduct, Unit)] = (LongProduct(1),())
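The reason powWriter2 survives a large exp is that the recursion now lives inside tailRecM, which the instance can implement as a loop. The following stdlib-only sketch shows how a writer-like tailRecM might do that with a hypothetical RecWriter type; it is an illustration of the idea, not cats' actual WriterT implementation.

```scala
import scala.annotation.tailrec

// Hypothetical writer whose log is a Long product (not cats' WriterT).
final case class RecWriter[A](product: Long, value: A)

object RecWriter {
  // Stack-safe tailRecM: instead of nesting flatMap calls, it loops and
  // multiplies each step's log into an accumulator.
  def tailRecM[A, B](a: A)(f: A => RecWriter[Either[A, B]]): RecWriter[B] = {
    @tailrec
    def loop(acc: Long, current: A): RecWriter[B] = {
      val step = f(current)
      step.value match {
        case Left(next) => loop(acc * step.product, next)
        case Right(b)   => RecWriter(acc * step.product, b)
      }
    }
    loop(1L, a)
  }

  // Mirrors powWriter2: each step logs x and decrements the exponent.
  def pow(x: Long, exp: Long): RecWriter[Unit] =
    tailRecM[Long, Unit](exp) { m =>
      if (m <= 0) RecWriter(1L, Right(()))
      else RecWriter(x, Left(m - 1))
    }
}
```

RecWriter.pow(2, 3).product is 8, and RecWriter.pow(1, 1000000) completes without growing the stack, matching the behavior of powWriter2 in the transcript.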

This gives users of the FlatMap typeclass a stronger safety guarantee, but it means that every implementer of an instance needs to provide a stack-safe tailRecM.

Here's the one for Option, for example:

@tailrec
def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
  f(a) match {
    case None => None
    case Some(Left(a1)) => tailRecM(a1)(f)
    case Some(Right(b)) => Some(b)
  }
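The Option tailRecM above can be tried outside Cats as a standalone method. In this sketch the object OptionRec and the sumTo helper are hypothetical; sumTo carries a (counter, accumulator) pair as the loop state, returns None for a negative input, and otherwise sums 1..n in constant stack.

```scala
import scala.annotation.tailrec

object OptionRec {
  // Same shape as the Option instance above, as a standalone method.
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
    f(a) match {
      case None           => None            // short-circuit on failure
      case Some(Left(a1)) => tailRecM(a1)(f) // continue with the new state
      case Some(Right(b)) => Some(b)         // done
    }

  // Hypothetical helper: sums 1..n, or None if n is negative.
  // The loop state is (remaining counter, running total).
  def sumTo(n: Long): Option[Long] =
    tailRecM[(Long, Long), Long]((n, 0L)) {
      case (i, _) if i < 0 => None
      case (0L, acc)       => Some(Right(acc))
      case (i, acc)        => Some(Left((i - 1, acc + i)))
    }
}
```

OptionRec.sumTo(100) returns Some(5050), and OptionRec.sumTo(1000000) returns Some(500000500000) without overflowing the stack.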

That’s it for today!