Tail recursive monads (FlatMap)

In 2015, Phil Freeman (@paf31) wrote Stack Safety for Free, a paper on handling stack safety in PureScript. PureScript is a language hosted on JavaScript, which, like Java, is strict.

The paper also refers to Stackless Scala With Free Monads by Rúnar (@runarorama), but it proposes a more fundamental solution to stack safety.

The stack problem

In Scala, a function that calls itself in tail position can be annotated with `@tailrec`. The compiler then verifies that the recursive call really is a tail call and compiles it into a loop:

``````
scala> import scala.annotation.tailrec
import scala.annotation.tailrec
scala> :paste
// Entering paste mode (ctrl-D to finish)
def pow(n: Long, exp: Long): Long = {
  @tailrec def go(acc: Long, p: Long): Long =
    (acc, p) match {
      case (acc, 0) => acc
      case (acc, p) => go(acc * n, p - 1)
    }
  go(1, exp)
}

// Exiting paste mode, now interpreting.
pow: (n: Long, exp: Long)Long
scala> pow(2, 3)
res3: Long = 8
``````

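Mutual recursion between two functions, however, is not turned into a loop: the Scala compiler only optimizes self tail calls, and the JVM has no general tail-call elimination. Every call therefore adds a stack frame:

``````
scala> :paste
// Entering paste mode (ctrl-D to finish)
object OddEven0 {
  def odd(n: Int): String = even(n - 1)
  def even(n: Int): String = if (n <= 0) "done" else odd(n - 1)
}

// Exiting paste mode, now interpreting.

defined object OddEven0

scala> OddEven0.even(200000)
java.lang.StackOverflowError
  at OddEven0$.even(<console>:15)
  at OddEven0$.odd(<console>:14)
  at OddEven0$.even(<console>:15)
  at OddEven0$.odd(<console>:14)
  at OddEven0$.even(<console>:15)
....
``````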
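For comparison, here is a minimal sketch of trampolining, the technique that Stackless Scala With Free Monads builds on. It uses the standard library's `scala.util.control.TailCalls`. This is not the solution proposed in Stack Safety for Free, and `OddEven1` is a hypothetical name. Each step returns a `TailRec` value instead of calling the other function directly, and `result` then runs the suspended steps in a loop.

```scala
import scala.util.control.TailCalls._

object OddEven1 {
  // Instead of calling each other directly, odd and even return a TailRec
  // value; tailcall suspends the next step, so the JVM stack does not grow.
  def odd(n: Int): TailRec[String] = tailcall(even(n - 1))
  def even(n: Int): TailRec[String] =
    if (n <= 0) done("done") else tailcall(odd(n - 1))
}
```

With this version, `OddEven1.even(200000).result` evaluates to `"done"` instead of throwing `StackOverflowError`.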

Monadic recursion runs into the same problem. Below, `powWriter` computes a power with a `Writer` whose log is the `LongProduct` monoid (multiplication, with `1` as the identity), recursing through `>>=` (`flatMap`):

``````
scala> import cats._, cats.data._, cats.implicits._
import cats._
import cats.data._
import cats.implicits._
scala> :paste
// Entering paste mode (ctrl-D to finish)
case class LongProduct(value: Long)
implicit val longProdMonoid: Monoid[LongProduct] = new Monoid[LongProduct] {
  def empty: LongProduct = LongProduct(1)
  def combine(x: LongProduct, y: LongProduct): LongProduct = LongProduct(x.value * y.value)
}
def powWriter(x: Long, exp: Long): Writer[LongProduct, Unit] =
  exp match {
    case 0 => Writer(LongProduct(1L), ())
    case m =>
      Writer(LongProduct(x), ()) >>= { _ => powWriter(x, exp - 1) }
  }

// Exiting paste mode, now interpreting.
defined class LongProduct
longProdMonoid: cats.Monoid[LongProduct] = $anon$1@188587b7
powWriter: (x: Long, exp: Long)cats.data.Writer[LongProduct,Unit]
scala> powWriter(2, 3).run
res4: cats.Id[(LongProduct, Unit)] = (LongProduct(8),())
``````

This works for small exponents. Recursing 10000 times through `>>=`, however, overflows the stack:

``````
scala> powWriter(1, 10000).run
java.lang.StackOverflowError
  at $anonfun$powWriter$1.apply(<console>:35)
  at $anonfun$powWriter$1.apply(<console>:35)
  at cats.data.WriterT$$anonfun$flatMap$1.apply(WriterT.scala:37)
  at cats.data.WriterT$$anonfun$flatMap$1.apply(WriterT.scala:36)
  at cats.package$$anon$1.flatMap(package.scala:34)
  at cats.data.WriterT.flatMap(WriterT.scala:36)
  at cats.data.WriterTFlatMap1$class.flatMap(WriterT.scala:249)
  at cats.data.WriterTInstances2$$anon$4.flatMap(WriterT.scala:130)
  at cats.data.WriterTInstances2$$anon$4.flatMap(WriterT.scala:130)
  at cats.FlatMap$class.$greater$greater$eq(FlatMap.scala:26)
  at cats.data.WriterTInstances2$$anon$4.$greater$greater$eq(WriterT.scala:130)
  at cats.FlatMap$Ops$class.$greater$greater$eq(FlatMap.scala:20)
  at cats.syntax.FlatMapSyntax1$$anon$1.$greater$greater$eq(flatMap.scala:6)
  at .powWriter1(<console>:35)
  at $anonfun$powWriter$1.apply(<console>:35)
``````

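This property of Scala limits the usefulness of monadic composition in which `flatMap` calls a monadic function, which in turn calls `flatMap`, and so on. The paper's solution is the `MonadRec` type class, which extends `Monad` with a `tailRecM` operation. In PureScript it is defined as: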

``````
class (Monad m) <= MonadRec m where
  tailRecM :: forall a b. (a -> m (Either a b)) -> a -> m b
``````

Cats provides the same function on `FlatMap`. Written in Scala, it looks like this:

``````
/**
 * Keeps calling `f` until a `scala.util.Right[B]` is returned.
 */
def tailRecM[A, B](a: A)(f: A => F[Either[A, B]]): F[B]
``````
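To see this contract in isolation, here is a minimal sketch that specializes `tailRecM` to the identity case, where `F[X]` is simply `X`. `IdTailRec`, `tailRecMId`, and `sumTo` are hypothetical names and are not part of Cats. A `Left` means "call `f` again with this value", and a `Right` ends the loop. Because the recursion is a self tail call, it compiles to a loop.

```scala
import scala.annotation.tailrec

object IdTailRec {
  // Hypothetical helper: tailRecM specialized to the identity case, where
  // F[X] is just X. It keeps calling f until a Right is returned.
  @tailrec
  def tailRecMId[A, B](a: A)(f: A => Either[A, B]): B =
    f(a) match {
      case Left(next) => tailRecMId(next)(f) // Left: run f again with next
      case Right(b)   => b                   // Right: stop with the result
    }

  // Example: sum 1..n by threading (counter, accumulator) through Left.
  def sumTo(n: Long): Long =
    tailRecMId[(Long, Long), Long]((n, 0L)) { case (i, acc) =>
      if (i <= 0) Right(acc) else Left((i - 1, acc + i))
    }
}
```

`IdTailRec.sumTo(1000000L)` runs a million iterations without growing the stack.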

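In the REPL, the type of `FlatMap`'s `tailRecM` for `Writer[Vector[String], ?]` looks like this (`?` is kind-projector syntax for the type hole):

``````
scala> def tailRecM[A, B] = FlatMap[Writer[Vector[String], ?]].tailRecM[A, B] _
tailRecM: [A, B]=> A => ((A => cats.data.WriterT[[A]A,scala.collection.immutable.Vector[String],Either[A,B]]) => cats.data.WriterT[[A]A,scala.collection.immutable.Vector[String],B])
``````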

A stack-safe `powWriter` can be written like this:

``````
scala> :paste
// Entering paste mode (ctrl-D to finish)
def powWriter2(x: Long, exp: Long): Writer[LongProduct, Unit] =
  FlatMap[Writer[LongProduct, ?]].tailRecM(exp) {
    case 0L      => Writer.value[LongProduct, Either[Long, Unit]](Right(()))
    case m: Long => Writer.tell(LongProduct(x)) >>= { _ => Writer.value(Left(m - 1)) }
  }

// Exiting paste mode, now interpreting.
powWriter2: (x: Long, exp: Long)cats.data.Writer[LongProduct,Unit]
scala> powWriter2(2, 3).run
res5: cats.Id[(LongProduct, Unit)] = (LongProduct(8),())
scala> powWriter2(1, 10000).run
res6: cats.Id[(LongProduct, Unit)] = (LongProduct(1),())
``````
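To show what a stack-safe `tailRecM` does for a writer, here is a stdlib-only sketch. It uses a hypothetical `ProdWriter` type as a stand-in for `Writer[LongProduct, A]`, and it is not how Cats implements `WriterT`. `tailRecM` carries the accumulated log through a `@tailrec` loop instead of building nested `flatMap` calls. The hypothetical `PowProd.powProd` has the same shape as `powWriter2`.

```scala
import scala.annotation.tailrec

// Hypothetical stand-in for Writer[LongProduct, A]: the log is a running
// product of Long values, paired with a result value. Not the Cats type.
final case class ProdWriter[A](log: Long, value: A) {
  // flatMap multiplies the logs, mirroring the LongProduct monoid.
  def flatMap[B](f: A => ProdWriter[B]): ProdWriter[B] = {
    val next = f(value)
    ProdWriter(log * next.log, next.value)
  }
}

object ProdWriter {
  // Stack-safe tailRecM: the accumulated log is carried through a @tailrec
  // loop, so no chain of nested flatMap calls is ever built.
  def tailRecM[A, B](a: A)(f: A => ProdWriter[Either[A, B]]): ProdWriter[B] = {
    @tailrec
    def loop(acc: Long, current: A): ProdWriter[B] = {
      val step = f(current)
      step.value match {
        case Left(next) => loop(acc * step.log, next)
        case Right(b)   => ProdWriter(acc * step.log, b)
      }
    }
    loop(1L, a)
  }
}

object PowProd {
  // Same shape as powWriter2: each step logs x and continues with m - 1.
  def powProd(x: Long, exp: Long): ProdWriter[Unit] =
    ProdWriter.tailRecM[Long, Unit](exp) {
      case 0L => ProdWriter[Either[Long, Unit]](1L, Right(()))
      case m  => ProdWriter[Either[Long, Unit]](x, Left(m - 1))
    }
}
```

As with `powWriter2`, `PowProd.powProd(2, 3)` gives a log of `8`, and an exponent of `100000` runs without a `StackOverflowError`.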

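This gives users of the `FlatMap` type class a stronger safety guarantee. It also means that anyone implementing an instance must provide a stack-safe `tailRecM`. For example, here is the one for `Option`: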

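``````
import scala.annotation.tailrec

@tailrec
def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
  f(a) match {
    case None           => None            // stop: the computation failed
    case Some(Left(a1)) => tailRecM(a1)(f) // continue: self tail call, compiled to a loop
    case Some(Right(b)) => Some(b)         // done: return the result
  }
``````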
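The same pattern carries over to other instances. As a sketch, here is a hypothetical standalone `tailRecM` for `Either[E, *]`. Cats ships its own instance, and `EitherTailRec` is only an illustrative name. The three cases work as follows:

* A `Left` from `f` short-circuits.
* `Right(Left(a1))` continues the loop.
* `Right(Right(b))` finishes.

```scala
import scala.annotation.tailrec

object EitherTailRec {
  // Hypothetical standalone tailRecM for Either[E, *], following the same
  // pattern as the Option instance: short-circuit, loop, or finish.
  @tailrec
  def tailRecM[E, A, B](a: A)(f: A => Either[E, Either[A, B]]): Either[E, B] =
    f(a) match {
      case Left(e)         => Left(e)         // f failed: stop with the error
      case Right(Left(a1)) => tailRecM(a1)(f) // continue with the next state
      case Right(Right(b)) => Right(b)        // finished with a result
    }
}
```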