search term:

Scala 3 マクロ入門

はじめに

マクロは楽しくかつ強力なツールだが、使いすぎは害もある。責任を持って適度にマクロを楽しんでほしい。

マクロとは何だろうか? よくある説明はマクロはコードを入力として受け取り、コードを出力するプログラムだとされる。それ自体は正しいが、map {...} のような高階関数や名前渡しパラメータのように一見コードのブロックを渡して回っている機能に親しんでいる Scala プログラマには「コードを入力として受け取る」の意味が一見分かりづらいかもしれない。

以下は、僕が Scala 3 にも移植した Expecty という assersion マクロの用例だ:

scala> import com.eed3si9n.expecty.Expecty.assert
import com.eed3si9n.expecty.Expecty.assert

scala> assert(person.say(word1, word2) == "pong pong")
java.lang.AssertionError: assertion failed

assert(person.say(word1, word2) == "pong pong")
       |      |   |      |      |
       |      |   ping   pong   false
       |      ping pong
       Person(Fred,42)

  at com.eed3si9n.expecty.Expecty$ExpectyListener.expressionRecorded(Expecty.scala:35)
  at com.eed3si9n.expecty.RecorderRuntime.recordExpression(RecorderRuntime.scala:39)
  ... 36 elided

例えば assert(...) で名前渡しの引数を使ったとしたら、その値を得るタイミングは制御できるが false しか得ることができない。一方マクロでは、person.say(word1, word2) == "pong pong" というソースコードの形そのものを受け取り、全ての式の評価値を含んだエラーメッセージを自動生成するということができる。頑張って書こうと思えば Predef.assert(...) を使っても手でこのようなエラーメッセージを書くことができるが、非常に退屈な作業となる。マクロの全貌はこれだけでは無い。

よくありがちな考え方としてコンパイラはソースコードをマシンコードへと翻訳するものだとものがある。確かにそういう側面もあるが、コンパイラは他にも多くの事を行っている。型検査 (type checking) はそのうちの一つだ。バイトコード (や JS) を最後に生成する他に、Scala コンパイラはライトウェイトな証明システムとして振る舞い、タイポや引数の型合わせなど様々なエラーを事前にキャッチする。Java の仮想機械は、Scala の型システムが何を行っているかをほとんど知らない。この情報のロスは、何か悪いことかのように型消去とも呼ばれるが、この型とランタイムという二元性によって Scala が JVM、JS、Native 上にゲスト・プログラミング言語として存在することができる。

Scala において、マクロはコンパイル時にアクションを取る方法を提供してくれ、これは Scala の型システムと直接話すことができるホットラインだ。具体例で説明すると、型 A があるとき、ランタイム上からこれが case class であるかを正確に確認する方法は無いと思う。マクロを使うとこれが 5行で書ける:

import scala.quoted.*

inline def isCaseClass[A]: Boolean = ${ isCaseClassImpl[A] }
private def isCaseClassImpl[A: Type](using qctx: Quotes) : Expr[Boolean] =
  import qctx.reflect.*
  val sym = TypeRepr.of[A].typeSymbol
  Expr(sym.isClassDef && sym.flags.is(Flags.Case))

上記の ${ isCaseClassImpl[A] } は Scala 3 マクロの一例で、スプライスと呼ばれる。

クォートとスプライス

公式ドキュメントの Macros の説明では:

Macro はクォートとスプライスという 2つの基礎的な演算から成り立っている。式のクォートは '{...} と書かれ、スプライスは ${...} と書かれる。

クォートは「引用する」、スプライスは「縄などを継ぎ合わせる」というという意味で、「式をクォートする」というふうに動詞として使われる。マクロのエントリーポイントのみでは例外的にトップレベルで ${ isCaseClassImpl[A] } のようにスプライスが出てくる。通常は ${...} はクォート式 '{ ... } の中に現れる。

e が式の場合、'{e} は e の型付けされた抽象構文木を表す。T が型の場合、Type.of[T]T の型構造を表す。「型付けされた抽象構文木」や「型構造」の正確な定義は一旦置いておいて、直感をつかむための用語だと思ってほしい。逆に、${e} は式 e は型付けされた抽象構文木へと評価されることが期待され、その結果は式 (もしくは型) として直近のプログラムへと埋め込まれる。

クォートの中にはスプライスされたパーツを含むことができる。その場合、埋め込まれたスプライスはクォートの形成の一環として評価される。

というわけで、一般的なプロセスとしては、項レベルのパラメータもしくは型を捕獲して、Expr[A] と呼ばれる型付けされた抽象構文木を返す形となる。

Quotes Reflection API

型や項をコードで作ることができる Quotes Reflection API はクォートコンテキストである Quotes trait 以下に公開されている。

注意: 最初は Reflection API が馴染みがあるように見えて、実際に便利なのだが、Scala 3 マクロを学ぶ過程は使わなくても良いときには Reflection を使わずに素のクォートやクォートのパターンマッチなど構文的な (syntactic) な機能を使うことを学ぶことでもある。

Reflection API は一部 Reflection にドキュメント化されているが、僕は Quotes.scala をブラウザで開いてソースを直接読んでいる。

quoted.Exprquoted.Type を用いることでコードを作るだけではなく、AST を検査してコードの分析を行うことができる。マクロは生成されるコードが型安全であることを保証する。Quote Reflection を使うとこれらの保証が無くなるため、マクロ展開時に失敗する可能性があり、追加で明示的なチェックを行う必要がある。

マクロにリフレクション能力を提供するためには、scala.quoted.Quotes 型の givens パラメータを追加して、使用するスコープ内で quotes.reflect.* を import する必要がある。

Reflection API は TypeTypeReprSymbol といった豊富な型ファミリー、そして他にも色々な API を導入する。

+- Tree -+- PackageClause
         |
         +- Statement -+- Import
         |             +- Export
         |             +- Definition --+- ClassDef
         |             |               +- TypeDef
         |             |               +- DefDef
         |             |               +- ValDef
         |             |
         |             +- Term --------+- Ref -+- Ident -+- Wildcard
         |                             |       +- Select
         |                             +- Apply
         |                             +- Block
....
         +- TypeTree ----+- Inferred
....
+- ParamClause -+- TypeParamClause
                +- TermParamClause
+- TypeRepr -+- NamedType -+- TermRef
             |             +- TypeRef
             +- ConstantType
....
+- Selector -+- SimpleSelector
....
+- Signature
+- Position
+- SourceFile
+- Constant -+- BooleanConstant
             +- ByteConstant
....
+- Symbol
+- Flags

マクロと Scala 3 コンパイラ実装を隔離させるために API は抽象型、その抽象型への拡張メソッド、コンパニオンオブジェクトを表す val、そしてコンパニオンオブジェクトの API を記述する trait の集合というパターンとなっている。

Tree

Tree は、Scala コンパイラが理解した形でのソースコードの形を表し、これは抽象構文木と呼ばれる。これは val ... といった定義そして関数呼び出しといった (Term) を含む。マクロでは、Term を扱うことが多いが、Tree のサブ型全般に提供される拡張メソッドの中にも有用なものがあるので、それを見ていく。以下が Quotes.scala からの API だ。拡張メソッドが定義されているのは TreeMethods なのでそこまで読み飛ばす。

/** Tree representing code written in the source */
type Tree <: AnyRef

/** Module object of `type Tree`  */
val Tree: TreeModule

/** Methods of the module object `val Tree` */
trait TreeModule { this: Tree.type => }

/** Makes extension methods on `Tree` available without any imports */
given TreeMethods: TreeMethods

/** Extension methods of `Tree` */
trait TreeMethods {

  extension (self: Tree)
    /** Position in the source code */
    def pos: Position

    /** Symbol of defined or referred by this tree */
    def symbol: Symbol

    /** Shows the tree as String */
    def show(using Printer[Tree]): String

    /** Does this tree represent a valid expression? */
    def isExpr: Boolean

    /** Convert this tree to an `quoted.Expr[Any]` if the tree is a valid expression or throws */
    def asExpr: Expr[Any]
  end extension

  /** Convert this tree to an `quoted.Expr[T]` if the tree is a valid expression or throws */
  extension (self: Tree)
    def asExprOf[T](using Type[T]): Expr[T]

  extension [ThisTree <: Tree](self: ThisTree)
    /** Changes the owner of the symbols in the tree */
    def changeOwner(newOwner: Symbol): ThisTree
  end extension

}

以下は show の使い方だ:

package com.eed3si9n.macroexample

import scala.quoted.*

inline def showTree[A](inline a: A): String = ${showTreeImpl[A]('{ a })}

def showTreeImpl[A: Type](a: Expr[A])(using Quotes): Expr[String] =
  import quotes.reflect.*
  Expr(a.asTerm.show)

これは以下のように使える:

scala> import com.eed3si9n.macroexample.*

scala> showTree(List(1).map(x => x + 1))
val res0: String = scala.List.apply[scala.Int](1).map[scala.Int](((x: scala.Int) => x.+(1)))

型推論の結果を見たりするのに多少役立つかもしれないが、僕が見たかったのは任意のコードの木構造だ。

Printer

AST の構造を見るには Printer.TreeStructure.show(...) を使う:

package com.eed3si9n.macroexample

import scala.quoted.*

inline def showTree[A](inline a: A): String = ${showTreeImpl[A]('{ a })}

def showTreeImpl[A: Type](a: Expr[A])(using Quotes): Expr[String] =
  import quotes.reflect.*
  Expr(Printer.TreeStructure.show(a.asTerm))

仕切り直し:

scala> import com.eed3si9n.macroexample.*

scala> showTree(List(1).map(x => x + 1))
val res0: String = Inlined(None, Nil, Apply(TypeApply(Select(Apply(TypeApply(Select(Ident("List"), "apply"), List(Inferred())), List(Typed(Repeated(List(Literal(IntConstant(1))), Inferred()), Inferred()))), "map"), List(Inferred())), List(Block(List(DefDef("$anonfun", List(TermParamClause(List(ValDef("x", Inferred(), None)))), Inferred(), Some(Apply(Select(Ident("x"), "+"), List(Literal(IntConstant(1))))))), Closure(Ident("$anonfun"), None)))))

求めていたのは、これ。注意としては、この木のエンコードは Scala 3.x を通じて安定してるか分からないので、詳細にべったり依存するのは安全では無い可能性があるので、unapply 抽出子を使ったほうがいいと思う (これに関して互換性が保証するのかしないのかは僕は知らない)。しかし、コンパイラが構築したものと自分が人工的に構築したものを比べるツールとしてこれは役立つと思う。

Literal

通常は Literal(...) の木をこのように作る必要はあんまり無いが、基礎となる木なので、単独で説明を始めやすい:

/** `TypeTest` that allows testing at runtime in a pattern match if a `Tree` is a `Literal` */
given LiteralTypeTest: TypeTest[Tree, Literal]

/** Tree representing a literal value in the source code */
type Literal <: Term

/** Module object of `type Literal`  */
val Literal: LiteralModule

/** Methods of the module object `val Literal` */
trait LiteralModule { this: Literal.type =>

  /** Create a literal constant */
  def apply(constant: Constant): Literal

  def copy(original: Tree)(constant: Constant): Literal

  /** Matches a literal constant */
  def unapply(x: Literal): Some[Constant]
}

/** Makes extension methods on `Literal` available without any imports */
given LiteralMethods: LiteralMethods

/** Extension methods of `Literal` */
trait LiteralMethods:
  extension (self: Literal)
    /** Value of this literal */
    def constant: Constant
  end extension
end LiteralMethods

抽象型の type LiteralLiteral 木を表し、LiteralModule は、コンパニオンオブジェクト Literal を記述する。ここでは、apply(...)copy(...)unapply(...) を提供しているのが分かる。

これを使って、Int リテラルを受け取ってコンパイル時に 1を加算する addOne(...) マクロを実装できるはずだ。これは単に n + 1 を返すのとは違うことに注意してほしい。n + 1 は実行時に計算する。僕たちがやりたいのは、1 を渡すと *.class が計算無しで 2 を含んでいることだ。

package com.eed3si9n.macroexample

import scala.quoted.*

inline def addOne_bad(inline x: Int): Int = ${addOne_badImpl('{x})}

def addOne_badImpl(x: Expr[Int])(using Quotes): Expr[Int] =
  import quotes.reflect.*
  x.asTerm match
    case Inlined(_, _, Literal(IntConstant(n))) =>
      Literal(IntConstant(n + 1)).asExprOf[Int]

これは意味無く冗長な書き方になっている。

FromExpr 型クラス

Int を含む、FromExpr 型クラスのインスタンスを形成する型の場合は、Expr の拡張メソッドである .value を使った方が簡単だ。value は以下のように定義される:

def value(using FromExpr[T]): Option[T] =
  given Quotes = Quotes.this
  summon[FromExpr[T]].unapply(self)

同様に、ExprExpr.apply(...) を使って構築できる ToExpr 型クラスがある。

そのため、これらと .value の兄弟である .valueOrError を使うことで addOne(...) は 1行マクロとして書き換える事ができる:

package com.eed3si9n.macroexample

import scala.quoted.*

inline def addOne(inline x: Int): Int = ${addOneImpl('{x})}

def addOneImpl(x: Expr[Int])(using Quotes): Expr[Int] =
  Expr(x.valueOrError + 1)

こっちの方がシンプルであるだけじゃなく、Reflection API を使っていないのでより型安全だというのもポイントだ。

Position

マクロ機能のデモとして、Position も見ていこう。Position はソースコード内での位置を表し、ファイル名や行数などを保持する。

以下は Source.line 関数の実装だ。

package com.eed3si9n.macroexample

import scala.quoted.*

object Source:
  inline def line: Int = ${lineImpl()}
  def lineImpl()(using Quotes): Expr[Int] =
    import quotes.reflect.*
    val pos = Position.ofMacroExpansion
    Expr(pos.startLine + 1)
end Source

これは以下のように使うことができる:

package com.eed3si9n.macroexample

object PositionTest extends verify.BasicTestSuite:
  test("testLine") {
    assert(Source.line == 5)
  }
end PositionTest

Apply

実践的なマクロのほとんどはメソッドの呼び出しに関わると思うので Apply も見ていこう。addOne の結果を List で返すマクロの例だ。

package com.eed3si9n.macroexample

import scala.quoted.*

inline def addOneList(inline x: Int): List[Int] = ${addOneListImpl('{x})}

def addOneListImpl(x: Expr[Int])(using Quotes): Expr[List[Int]] =
  val inner = Expr(x.valueOrError + 1)
  '{ List($inner) }

手でゴリゴリ Apply(...) 木を作るのでは無く、普通の Scala を使って List(...) 呼び出しを書いて、中に式をスプライスして、それを丸っと '{ ... } でクォートすることができた。List(...) メソッドと言っても実際には _root_.scala.collection.immutable.List.apply[Int](...) みたいな形になることを考慮すると、それを正確に記述するだけで面倒な作業となるので、これは非常に便利だ。

しかしながら、メソッド呼び出しは頻出なので Term 全般に対して専用の拡張メソッドが提供されている。

/** A unary apply node with given argument: `tree(arg)` */
def appliedTo(arg: Term): Term

/** An apply node with given arguments: `tree(arg, args0, ..., argsN)` */
def appliedTo(arg: Term, args: Term*): Term

/** An apply node with given argument list `tree(args(0), ..., args(args.length - 1))` */
def appliedToArgs(args: List[Term]): Apply

/** The current tree applied to given argument lists:
*  `tree (argss(0)) ... (argss(argss.length -1))`
*/
def appliedToArgss(argss: List[List[Term]]): Term

/** The current tree applied to (): `tree()` */
def appliedToNone: Apply

1 を加算して、toString を呼び出すというおかしなマクロを書いてみよう:

package com.eed3si9n.macroexample

import scala.quoted.*

inline def addOneToString(inline x: Int): String = ${addOneToStringImpl('{x})}

def addOneToStringImpl(x: Expr[Int])(using Quotes): Expr[String] =
  import quotes.reflect.*
  val inner = Literal(IntConstant(x.valueOrError + 1))
  Select.unique(inner, "toString").appliedToNone.asExprOf[String]

Select

Select もメジャーだ。上記では、Select.unique(term, <method name>) として登場した。

Select はオーバーロードされたメソッドを区別するための関数が色々あったりする。

ValDef

ValDefval 定義を表す。

クォートを使って val x を定義して、その参照を返すマクロは以下のように書ける:

package com.eed3si9n.macroexample

import scala.quoted.*

inline def addOneX(inline x: Int): Int = ${addOneXImpl('{x})}

def addOneXImpl(x: Expr[Int])(using Quotes): Expr[Int] =
  val rhs = Expr(x.valueOrError + 1)
  '{
    val x = $rhs
    x
  }

何らかの理由でこれをコードを使ってやりたいとする。まずは新しい val のためのシンボルを作る必要がある。そのためには、TypoeReprFlags も必要になる。

inline def addOneXv2(inline x: Int): Int = ${addOneXv2Impl('{x})}

def addOneXv2Impl(x: Expr[Int])(using Quotes): Expr[Int] =
  import quotes.reflect.*
  val rhs = Expr(x.valueOrError + 1)
  val sym = Symbol.newVal(
    Symbol.spliceOwner,
    "x",
    TypeRepr.of[Int],
    Flags.EmptyFlags,
    Symbol.noSymbol,
  )
  val vd = ValDef(sym, Some(rhs.asTerm))
  Block(
    List(vd),
    Ref(sym)
  ).asExprOf[Int]

Symbol

便宜的にシンボルはクラス、val、型といったものへの正確な名前だと考えることができる。 シンボルは val などの実体を定義するときに作られ、後で val を参照したいときに使うことができる。

以下が Symbol API だ。

type Symbol <: AnyRef

/** Module object of `type Symbol`  */
val Symbol: SymbolModule

/** Methods of the module object `val Symbol` */
trait SymbolModule { this: Symbol.type =>

  /** Symbol of the definition that encloses the current splicing context.
   *
   *  For example, the following call to `spliceOwner` would return the symbol `x`.
   *  ```scala sc:nocompile
   *  val x = ${ ... Symbol.spliceOwner ... }
   *  ```
   *
   *  For a macro splice, it is the symbol of the definition where the macro expansion happens.
   *  @syntax markdown
   */
  def spliceOwner: Symbol

  /** Get package symbol if package is either defined in current compilation run or present on classpath. */
  def requiredPackage(path: String): Symbol

  /** Get class symbol if class is either defined in current compilation run or present on classpath. */
  def requiredClass(path: String): Symbol

  /** Get module symbol if module is either defined in current compilation run or present on classpath. */
  def requiredModule(path: String): Symbol

  /** Get method symbol if method is either defined in current compilation run or present on classpath. Throws if the method has an overload. */
  def requiredMethod(path: String): Symbol

  def classSymbol(fullName: String): Symbol

  def newMethod(parent: Symbol, name: String, tpe: TypeRepr): Symbol

  def newMethod(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol

  def newVal(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol

  def newBind(parent: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol

  def noSymbol: Symbol
}

/** Extension methods of `Symbol` */
trait SymbolMethods {
  extension (self: Symbol)

    /** Owner of this symbol. The owner is the symbol in which this symbol is defined. Throws if this symbol does not have an owner. */
    def owner: Symbol

    /** Owner of this symbol. The owner is the symbol in which this symbol is defined. Returns `NoSymbol` if this symbol does not have an owner. */
    def maybeOwner: Symbol

    /** Flags of this symbol */
    def flags: Flags

    /** This symbol is private within the resulting type */
    def privateWithin: Option[TypeRepr]

    /** This symbol is protected within the resulting type */
    def protectedWithin: Option[TypeRepr]

    /** The name of this symbol */
    def name: String

    /** The full name of this symbol up to the root package */
    def fullName: String

    /** The position of this symbol */
    def pos: Option[Position]

    /** The documentation for this symbol, if any */
    def docstring: Option[String]

    /** Tree of this definition
     *
     *  If this symbol `isClassDef` it will return `a `ClassDef`,
     *  if this symbol `isTypeDef` it will return `a `TypeDef`,
     *  if this symbol `isValDef` it will return `a `ValDef`,
     *  if this symbol `isDefDef` it will return `a `DefDef`
     *  if this symbol `isBind` it will return `a `Bind`,
     *  else will throw
     *
     *  **Warning**: avoid using this method in macros.
     *
     *  **Caveat**: The tree is not guaranteed to exist unless the compiler
     *  option `-Yretain-trees` is enabled.
     *
     *  **Anti-pattern**: The following code is an anti-pattern:
     *
     *      symbol.tree.tpe
     *
     *  It should be replaced by the following code:
     *
     *      tp.memberType(symbol)
     *
     */
    def tree: Tree

    /** Is the annotation defined with `annotSym` attached to this symbol? */
    def hasAnnotation(annotSym: Symbol): Boolean

    /** Get the annotation defined with `annotSym` attached to this symbol */
    def getAnnotation(annotSym: Symbol): Option[Term]

    /** Annotations attached to this symbol */
    def annotations: List[Term]

    /** Does this symbol come from a currently compiled source file? */
    def isDefinedInCurrentRun: Boolean

    /** Dummy val symbol that owns all statements within the initialization of the class.
    *  This may also contain local definitions such as classes defined in a `locally` block in the class.
    */
    def isLocalDummy: Boolean

    /** Is this symbol a class representing a refinement? */
    def isRefinementClass: Boolean

    /** Is this symbol an alias type? */
    def isAliasType: Boolean

    /** Is this symbol an anonymous class? */
    def isAnonymousClass: Boolean

    /** Is this symbol an anonymous function? */
    def isAnonymousFunction: Boolean

    /** Is this symbol an abstract type? */
    def isAbstractType: Boolean

    /** Is this the constructor of a class? */
    def isClassConstructor: Boolean

    /** Is this the definition of a type? */
    def isType: Boolean

    /** Is this the definition of a term? */
    def isTerm: Boolean

    /** Is this the definition of a PackageDef tree? */
    def isPackageDef: Boolean

    /** Is this the definition of a ClassDef tree? */
    def isClassDef: Boolean

    /** Is this the definition of a TypeDef tree */
    def isTypeDef: Boolean

    /** Is this the definition of a ValDef tree? */
    def isValDef: Boolean

    /** Is this the definition of a DefDef tree? */
    def isDefDef: Boolean

    /** Is this the definition of a Bind pattern? */
    def isBind: Boolean

    /** Does this symbol represent a no definition? */
    def isNoSymbol: Boolean

    /** Does this symbol represent a definition? */
    def exists: Boolean

    /** Field with the given name directly declared in the class */
    def declaredField(name: String): Symbol

    /** Fields directly declared in the class */
    def declaredFields: List[Symbol]

    /** Get named non-private fields declared or inherited */
    def fieldMember(name: String): Symbol

    /** Get all non-private fields declared or inherited */
    def fieldMembers: List[Symbol]

    /** Get non-private named methods defined directly inside the class */
    def declaredMethod(name: String): List[Symbol]

    /** Get all non-private methods defined directly inside the class, excluding constructors */
    def declaredMethods: List[Symbol]

    /** Get named non-private methods declared or inherited */
    def methodMember(name: String): List[Symbol]

    /** Get all non-private methods declared or inherited */
    def methodMembers: List[Symbol]

    /** Get non-private named methods defined directly inside the class */
    def declaredType(name: String): List[Symbol]

    /** Get all non-private methods defined directly inside the class, excluding constructors */
    def declaredTypes: List[Symbol]

    /** Type member with the given name directly declared in the class */
    def typeMember(name: String): Symbol

    /** Type member directly declared in the class */
    def typeMembers: List[Symbol]

    /** All members directly declared in the class */
    def declarations: List[Symbol]

    /** The symbols of each type parameter list and value parameter list of this
      *  method, or Nil if this isn't a method.
      */
    def paramSymss: List[List[Symbol]]

    /** Returns all symbols overridden by this symbol. */
    def allOverriddenSymbols: Iterator[Symbol]

    /** The symbol overriding this symbol in given subclass `ofclazz`.
     *
     *  @param ofclazz is a subclass of this symbol's owner
     */
    def overridingSymbol(ofclazz: Symbol): Symbol

    /** The primary constructor of a class or trait, `noSymbol` if not applicable. */
    def primaryConstructor: Symbol

    /** Fields of a case class type -- only the ones declared in primary constructor */
    def caseFields: List[Symbol]

    def isTypeParam: Boolean

    /** Signature of this definition */
    def signature: Signature

    /** The class symbol of the companion module class */
    def moduleClass: Symbol

    /** The symbol of the companion class */
    def companionClass: Symbol

    /** The symbol of the companion module */
    def companionModule: Symbol

    /** Case class or case object children of a sealed trait or cases of an `enum`. */
    def children: List[Symbol]
  end extension
}

包囲項 (enclosing term)

豊かな Symbol API の簡単なデモとして、マクロ適用の包囲項を探すのに使ってみる。 例えば、sbt ではこれを使ってコンフィギュレーションの名前を val から抜き出している。

lazy val Compile = config

// 上を以下のように展開したい
lazy val Compile = Config("Compile")

“Compile” という名前を抜き出す config マクロは以下のように実装できる:

package com.eed3si9n.macroexample

import scala.quoted.*

case class Config(name: String)

inline def config: Config = ${configImpl}

def configImpl(using Quotes): Expr[Config] =
  import quotes.reflect.*
  def enclosingTerm(sym: Symbol): Symbol =
    sym match
      case sym if sym.flags is Flags.Macro => enclosingTerm(sym.owner)
      case sym if !sym.isTerm              => enclosingTerm(sym.owner)
      case _                               => sym
  val n = enclosingTerm(Symbol.spliceOwner).name
  val nExpr = Expr(n)
  '{ Config($nExpr) }

config の用法は以下のようになっている:

scala> import com.eed3si9n.macroexample._

scala> lazy val Compile = config
lazy val Compile: com.eed3si9n.macroexample.Config

scala> Compile.name
val res0: String = Compile

この例は Symbol API の複数の側面を使っている。まず、Symbol.spliceOwner がある。 マクロでは、これはマクロ展開が発生した定義への参照を返す。 しかし、Scala 3.1.1 は macro という名前の人工的変数を作っているのでは、これは直接役に立たない。

次にやってみる事として、flags という拡張メソッドがある。 Scala コンパイラにおける全てのシンボルは様々なフラグが与えられており、そのシンボルが項か型なのか、人工的なのか否か、val なのか def なのか等をチェックすることができる。 この場合、sym.flags is Flags.Macro という検査を行う。

シンボル間ではグラフ構造を形成しており、Symbol#owner 拡張メソッドを使って 1つ上のレベルに上がることができる。 項に当たるまで enclosingTerm(sym.owner) を再帰的に呼び出す。 このテクニックを応用して包囲クラスなどを探すこともできる。 一般的に、シンボルはリッチな情報を保持しているので、構文木や型を見に行かなくてもシンボルだけで十分に情報が得られることがある。

補足すると、Symbol#tree 拡張メソッドもあって、マクロ開発中には

sys.error(Printer.TreeStructure.show(sym.tree))

を実行して構文木構造を見るのが役に立つこともある:

scala> lazy val Compile = config
-- Error: ----------------------------------------------------------------------
1 |lazy val Compile = config
  |                   ^^^^^^
  | Exception occurred while executing macro expansion.
  | java.lang.RuntimeException: ValDef("macro", Inferred(), None)
  |   at scala.sys.package$.error(package.scala:27)
  |   at com.eed3si9n.macroexample.Config$package$.configImpl(Config.scala:16)

しかし、-Yretain-trees フラグを立てないとシンボルは構文木を保持することは保証されていないので、一般的にマクロから Symbol#tree を呼ぶのは安全では無い。 これは Best Practices ガイドでも Avoid Symbol.tree として注意されている。

Ref

本物のコンパイラは import や入れ子になったブロックなども考慮して名前を正しいシンボルに解決するが、僕たちは既にシンボルを持っているので Ref(sym) と書くことができる。

TypeRepr

TypeRepr はマクロ時における型と型関連の演算を表す。実行時には型情報は消去されるため、マクロを使うことで Scala の型情報を直接取り扱うことができる。

A が case class かどうかを検査するコードは TypeRepr がどう取得されるかを見れる良い例だ。

import scala.quoted.*

inline def isCaseClass[A]: Boolean = ${ isCaseClassImpl[A] }

private def isCaseClassImpl[A: Type](using qctx: Quotes) : Expr[Boolean] =
  import qctx.reflect.*
  val sym = TypeRepr.of[A].typeSymbol
  Expr(sym.isClassDef && (sym.flags is Flags.Case))

以下が TypeRepr API だ。

/** A type, type constructors, type bounds or NoPrefix */
type TypeRepr

/** Module object of `type TypeRepr`  */
val TypeRepr: TypeReprModule

/** Methods of the module object `val TypeRepr` */
trait TypeReprModule { this: TypeRepr.type =>
  /** Returns the type or kind (TypeRepr) of T */
  def of[T <: AnyKind](using Type[T]): TypeRepr

  /** Returns the type constructor of the runtime (erased) class */
  def typeConstructorOf(clazz: Class[?]): TypeRepr
}

/** Makes extension methods on `TypeRepr` available without any imports */
given TypeReprMethods: TypeReprMethods

/** Extension methods of `TypeRepr` */
trait TypeReprMethods {
  extension (self: TypeRepr)

    /** Shows the type as a String */
    def show(using Printer[TypeRepr]): String

    /** Convert this `TypeRepr` to an `Type[?]` */
    def asType: Type[?]

    /** Is `self` type the same as `that` type?
    *  This is the case iff `self <:< that` and `that <:< self`.
    */
    def =:=(that: TypeRepr): Boolean

    /** Is this type a subtype of that type? */
    def <:<(that: TypeRepr): Boolean

    /** Widen from singleton type to its underlying non-singleton
     *  base type by applying one or more `underlying` dereferences,
     *  Also go from => T to T.
     *  Identity for all other types. Example:
     *
     *  class Outer { class C ; val x: C }
     *  def o: Outer
     *  <o.x.type>.widen = o.C
     */
    def widen: TypeRepr

    /** Widen from TermRef to its underlying non-termref
     *  base type, while also skipping ByName types.
     */
    def widenTermRefByName: TypeRepr

    /** Widen from ByName type to its result type. */
    def widenByName: TypeRepr

    /** Follow aliases, annotated types until type is no longer alias type, annotated type. */
    def dealias: TypeRepr

    /** A simplified version of this type which is equivalent wrt =:= to this type.
    *  Reduces typerefs, applied match types, and and or types.
    */
    def simplified: TypeRepr

    def classSymbol: Option[Symbol]
    def typeSymbol: Symbol
    def termSymbol: Symbol
    def isSingleton: Boolean
    def memberType(member: Symbol): TypeRepr

    /** The base classes of this type with the class itself as first element. */
    def baseClasses: List[Symbol]

    /** The least type instance of given class which is a super-type
    *  of this type.  Example:
    *  {{{
    *    class D[T]
    *    class C extends p.D[Int]
    *    ThisType(C).baseType(D) = p.D[Int]
    * }}}
    */
    def baseType(cls: Symbol): TypeRepr

    /** Is this type an instance of a non-bottom subclass of the given class `cls`? */
    def derivesFrom(cls: Symbol): Boolean

    /** Is this type a function type?
    *
    *  @return true if the dealiased type of `self` without refinement is `FunctionN[T1, T2, ..., Tn]`
    *
    *  @note The function
    *
    *     - returns true for `given Int => Int` and `erased Int => Int`
    *     - returns false for `List[Int]`, despite that `List[Int] <:< Int => Int`.
    */
    def isFunctionType: Boolean

    /** Is this type an context function type?
    *
    *  @see `isFunctionType`
    */
    def isContextFunctionType: Boolean

    /** Is this type an erased function type?
    *
    *  @see `isFunctionType`
    */
    def isErasedFunctionType: Boolean

    /** Is this type a dependent function type?
    *
    *  @see `isFunctionType`
    */
    def isDependentFunctionType: Boolean

    /** The type <this . sym>, reduced if possible */
    def select(sym: Symbol): TypeRepr

    /** The current type applied to given type arguments: `this[targ]` */
    def appliedTo(targ: TypeRepr): TypeRepr

    /** The current type applied to given type arguments: `this[targ0, ..., targN]` */
    def appliedTo(targs: List[TypeRepr]): TypeRepr

  end extension
}

TypeRepr の拡張メソッドを使ってみよう。以下は 2つの型が等しいかを比べるマクロだ:

package com.eed3si9n.macroexample

import scala.quoted.*

inline def typeEq[A1, A2]: Boolean = ${ typeEqImpl[A1, A2] }

def typeEqImpl[A1: Type, A2: Type](using Quotes): Expr[Boolean] =
  import quotes.reflect.*
  Expr(TypeRepr.of[A1] =:= TypeRepr.of[A2])

typeEq は以下のように使うことができる:

scala> import com.eed3si9n.macroexample.*

scala> typeEq[scala.Predef.String, java.lang.String]
val res0: Boolean = true

scala> typeEq[Int, java.lang.Integer]
val res1: Boolean = false

AppliedType

型消去で無くなる情報の 1つに List[Int] といったパラメータ化された型の型パラメータがある。TypeRepr の情報を型適用の部分に分解するのは少しトリッキーだ。

TypeTest[TypeRepr, AppliedType] を使うことも可能だが、コンパイラがマジックを使って通常のパターンマッチと同じように書けるようになっている。型パラメータの名前を返すマクロは以下のように書ける。

package com.eed3si9n.macroexample

import scala.quoted.*
import scala.reflect.*

inline def paramInfo[A]: List[String] = ${paramInfoImpl[A]}

def paramInfoImpl[A: Type](using Quotes): Expr[List[String]] =
  import quotes.reflect.*
  val tpe = TypeRepr.of[A]
  val targs = tpe.widenTermRefByName.dealias match
    case AppliedType(_, args) => args
    case _                    => Nil
  Expr(targs.map(_.show))

これはこのように使える:

scala> import com.eed3si9n.macroexample.*

scala> paramInfo[List[Int]]
val res0: List[String] = List(scala.Int)

scala> paramInfo[Int]
val res1: List[String] = List()

抽出子としての Select

これまでの所マクロには 1 みたいな素の値を渡して来た。マクロに関数の呼び出しを渡して、関数呼び出しを操作することで少しひねったマクロを書くことができる。

具体例で説明すると、まずは echo というダミー関数を作る:

import scala.annotation.compileTimeOnly

object Dummy:
  @compileTimeOnly("echo can only be used in lines macro")
  def echo(line: String): String = ???
end Dummy

次に、Dummy.echo(...) を入力された値と行番号を前置したものに置換する Source.lines(...) マクロを実装できる。

package com.eed3si9n.macroexample

import scala.annotation.compileTimeOnly
import scala.quoted.*

object Source:
  inline def lines_bad(inline xs: List[String]): List[String] = ${lines_badImpl('{ xs })}

  def lines_badImpl(xs: Expr[List[String]])(using Quotes): Expr[List[String]] =
    import quotes.reflect.*
    val dummySym = Symbol.requiredModule("com.eed3si9n.macroexample.Dummy")
    xs match
      case ListApply(args) =>
        val args2 = args map { arg =>
          arg.asTerm match
            case a @ Apply(Select(qual, "echo"), List(Literal(StringConstant(str)))) if qual.symbol == dummySym =>
              val pos = a.pos
              Expr(s"${pos.startLine + 1}: $str")
            case _ => arg
        }
        '{ List(${ Varargs[String](args2.toList) }: _*) }

  // bad example. see below for quoted pattern.
  object ListApply:
    def unapply(expr: Expr[List[String]])(using Quotes): Option[Seq[Expr[String]]] =
      import quotes.reflect.*
      def rec(tree: Term): Option[Seq[Expr[String]]] =
        tree match
          case Inlined(_, _, e) => rec(e)
          case Block(Nil, e)    => rec(e)
          case Typed(e, _)      => rec(e)
          case Apply(TypeApply(Select(obj, "apply"), _), List(e)) if obj.symbol.name == "List" => rec(e)
          case Repeated(elems, _) => Some(elems.map(_.asExprOf[String]))
      rec(expr.asTerm)
  end ListApply

end Source

object Dummy:
  @compileTimeOnly("echo can only be used in lines macro")
  def echo(line: String): String = ???
end Dummy

これは以下のようにテストできる:

package com.eed3si9n.macroexample

object LinesTest extends verify.BasicTestSuite:
  test("lines") {
    assert(Source.lines_bad(List(
      "foo",
      Dummy.echo("bar"),
    )) == List(
      "foo",
      "7: bar"
    ))
  }
end LinesTest

抽出子としてのクォート

上の例では List(...) 適用式の引数を抽出するのにかなり頑張っている。これはクォートを抽出子として用いることで改善できる。これは quoted patterns として公式ドキュメントに書いてある。

Scala がパターンを期待する位置に '{ ... } パターンを置くことができる。

Dummy.echo(...) を置換する lines(...) マクロの改善版は以下のようになる。

package com.eed3si9n.macroexample

import scala.annotation.compileTimeOnly
import scala.quoted.*

object Source:
  inline def lines(inline xs: List[String]): List[String] = ${linesImpl('{ xs })}

  def linesImpl(xs: Expr[List[String]])(using Quotes): Expr[List[String]] =
    import quotes.reflect.*
    xs match
      case '{ List[String]($vargs*) } =>
        vargs match
          case Varargs(args) =>
            val args2 = args map { arg =>
              arg match
                case '{ Dummy.echo($str) } =>
                  val pos = arg.asTerm.pos
                  Expr(s"${pos.startLine + 1}: ${ str.valueOrError }")
                case _ => arg
            }
            '{ List(${ Varargs[String](args2.toList) }: _*) }
end Source

object Dummy:
  @compileTimeOnly("echo can only be used in lines macro")
  def echo(line: String): String = ???
end Dummy

Dummy.echo メソッドの面倒なシンボル照会も無くすことができた。

型のスプライス

一旦 TypeRepr に戻る。TypeRepr を使って型を構築して、それを生成されるコードにスプライスするというパターンが出てくる。

a: AString の 2つのパラメータを受け取って、2つ目のパラメータが "String" ならば Either[String, A] を宣言して、もしも "List[String]" ならば Either[List[String], A] を作るマクロを作ってみよう。その Either を使うためには flatMap してゼロじゃないかをチェックする。

package com.eed3si9n.macroexample

import scala.quoted.*

inline def right[A](inline a: A, inline which: String): String =
  ${ rightImpl[A]('{ a }, '{ which }) }

def rightImpl[A: Type](a: Expr[A], which: Expr[String])(using Quotes): Expr[String] =
  import quotes.reflect.*
  val w = which.valueOrError
  val leftTpe = w match
    case "String"       => TypeRepr.of[String]
    case "List[String]" => TypeRepr.of[List[String]]
  val msg = w match
    case "String"       => Expr("empty not allowed")
    case "List[String]" => Expr(List("empty not allowed"))
  leftTpe.asType match
    case '[l] =>
      '{
        val e0: Either[l, A] = Right[l, A]($a)
        val e1 = e0 flatMap { x =>
          if x == null.asInstanceOf[A] then Left[l, A]($msg.asInstanceOf[l])
          else Right(x)
        }
        e1.toString
      }

つまり、マクロ内で型情報を扱うときは TypeRepr[_] を召喚 (summon) するが、Scala コードにスプライスし直すときは Type[_] を作る必要がある。使ってみよう:

scala> import com.eed3si9n.macroexample.*

scala> right(1, "String")
val res0: String = Right(1)

scala> right(0, "String")
val res1: String = Left(empty not allowed)

scala> right[String](null, "List[String]")
val res2: String = Left(List(empty not allowed))

あと、これは入力と出力は関数のシグネチャによって定義済みだが、入力によって内部実装で別の型を作っている例だ。

Lambda

ラムダ式 (匿名関数) を作るのはよくある作業なので、Reflection API は Lambda というヘルパーを提供する。これは以下のようにして使うことができる:

import scala.quoted.*

inline def mkLambda[A](inline a: A): A = ${mkLambdaImpl[A]('{ a })}

def mkLambdaImpl[A: Type](a: Expr[A])(using Quotes): Expr[A] =
  import quotes.reflect.*

  val lambdaTpe =
    MethodType(List("p0"))(_ => List(TypeRepr.of[Int] ), _ => TypeRepr.of[A])
  val lambda = Lambda(
    owner = Symbol.spliceOwner,
    tpe = lambdaTpe,
    rhsFn = (sym, params) => {
      val p0 = params.head.asInstanceOf[Term]
      a.asTerm.changeOwner(sym)
    }
  )
  '{
    val f: Int => A = ${ lambda.asExprOf[Int => A] }
    f(0)
  }

これは以下のようなラムダ式を作る。

val f: Int => A = (p0: Int) => {
  ....
}

ただし、マクロのに渡されたコードはラムダ式の中に移され、f(0) として呼び出される。用法は以下のようになる:

scala> import com.eed3si9n.macroexample.*

scala> mkLambda({
     |   val x = 1
     |   x + 2
     | })
val res0: Int = 3

引数 a.asTerm がラムダ式の中に移動されるとき、val x などのシンボルのオーナーをラムダ式へと譲渡するために changeOwner(sym) を呼ぶ必要があることに注意してほしい。そうしないと

[error] (run-main-1) java.util.NoSuchElementException: val x
[error] java.util.NoSuchElementException: val x

[error] java.lang.IllegalArgumentException: Could not find proxy for p0: Tuple2 in List(....)

といった変なエラーが発生する。

Restligeist マクロ

Restligeist マクロ、つまり地縛霊マクロは直ちに失敗するマクロだ。API を廃止した後でマイグレーションのためのメッセージを表示させるというユースケースがある。Scala 3 だとこのようなユーザランドでのコンパイルエラーが一行で書ける。

package com.eed3si9n.macroexample

object SomeDSL:
  inline def <<=[A](inline a: A): Option[A] =
    compiletime.error("<<= is removed; migrated to := instead")
end SomeDSL

使う側だとこのような感じに見える:

scala> import com.eed3si9n.macroexample.*

scala> SomeDSL.<<=((1, "foo"))
-- Error:
1 |SomeDSL.<<=((1, "foo"))
  |^^^^^^^^^^^^^^^^^^^^^^^
  |<<= is removed; migrated to := instead

compiletime.error(...) は便利だが、文字列リテラルと codeOf() しか扱えないという制限がある。もしエラーメッセージを再利用したい場合はマクロを作って report モジュールを使うという手がある:

/** Module containing error and warning reporting. */
val report: reportModule

/** Methods of the module object `val report` */
trait reportModule { self: report.type =>

  /** Report an error at the position of the macro expansion */
  def error(msg: String): Unit

  /** Report an error at the position of `expr` */
  def error(msg: String, expr: Expr[Any]): Unit

  /** Report an error message at the given position */
  def error(msg: String, pos: Position): Unit

  /** Report an error at the position of the macro expansion and throw a StopMacroExpansion */
  def errorAndAbort(msg: String): Nothing

  /** Report an error at the position of `expr` and throw a StopMacroExpansion */
  def errorAndAbort(msg: String, expr: Expr[Any]): Nothing

  /** Report an error message at the given position and throw a StopMacroExpansion */
  def errorAndAbort(msg: String, pos: Position): Nothing

  ....
}

より本格的な Restligeist マクロは以下のようになる:

package com.eed3si9n.macroexample

import scala.quoted.*

object SomeDSL:
  final val assignMigration = """<<= is removed; migrated to := instead
                                |go to link to documentation""".stripMargin

  inline def <<=[A](a: A): Option[A] = ${ assignImpl('a) }

  def assignImpl[A: Type](a: Expr[A])(using qctx: Quotes): Expr[Option[A]] =
    import qctx.reflect.*
    report.errorAndAbort(assignMigration)

end SomeDSL

使う側だとこのような感じに見える:

scala> import com.eed3si9n.macroexample.*

scala> SomeDSL.<<=((1, "foo"))
-- Error: ----------------------------------------------------------------------
1 |SomeDSL.<<=((1, "foo"))
  |^^^^^^^^^^^^^^^^^^^^^^^
  |<<= is removed; migrated to := instead
  |go to link to documentation

まとめ

Scala 3 のマクロは、Scala 構文そのものを使ってソースコードの形を操作したり、型システムと直接対話できるなど、今までと異なるレベルのプログラミング能力を引き出すことができる。可能な場合は、プログラムを使って AST を構築する (Quote) Reflection API を避け、Scala 構文を使ってクォートされるコードを構築する事が推奨される。

プログラム的な柔軟性を必要とする場合は、Reflection API が TreeSymbolTypeRepr といった豊富な型ファミリーを提供する。これは一部 Reflection としてドキュメント化されているが、現時点では Quotes.scala を読むのが最も便利な情報源だ。

クォートをパターンマッチで使う方が全般的に型安全であり、マクロが現行 Scala バージョンの実装に特定の Tree 実装に決め打ちになってしまうことを回避できる可能性もある。