Simplify Anf.scala by removing the continuations. (#17242)

The core functions of the ANF transformation in Anf.scala handle 3 different types of continuations:

1. The "transformations" (called tx), that are an inherent aspect of any ANF transformation expressed in pseudo-CPS style for convenience.
2. The "continuations" (called k) that had been introduced to avoid stack oveflows.
3. The trampolines, that have also been introduced to avoid stack overflows.

This commit recognizes that types 2 and 3 are redundant and merges them into only type 3. It relies on the fact that trampolines are monadic, and so we switch from the hand-rolled trampolines to that of scala's std library since they already implement flatMap.
This commit is contained in:
Paul Brauner 2023-08-09 12:19:01 +02:00 committed by GitHub
parent 5c31bc3f21
commit 4e54c69c9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,14 +29,13 @@ package com.daml.lf.speedy
* We use "source." and "t." for lightweight discrimination.
*/
import com.daml.lf.data.Trampoline.{Bounce, Land, Trampoline}
import com.daml.lf.speedy.{SExpr1 => source}
import com.daml.lf.speedy.{SExpr => target}
import com.daml.lf.speedy.Compiler.CompilationError
import scala.annotation.nowarn
import scala.annotation.tailrec
import scala.util.control.TailCalls._
import scalaz.{@@, Tag}
private[lf] object Anf {
@ -52,9 +51,7 @@ private[lf] object Anf {
private def flattenToAnfInternal(exp: source.SExpr): target.SExpr = {
val depth = DepthA(0)
val env = initEnv
flattenExp(depth, env, exp) { exp =>
Land(exp)
}.bounce
flattenExp(depth, env, exp).result
}
/** The transformation code is implemented using a what looks like
@ -71,32 +68,10 @@ private[lf] object Anf {
* mapping between them. See the types: DepthE, DepthA and Env.
*
* Using a coding style that looks like CPS is a natural way to express the
* ANF transformation. However, this means the transformation is very
* stack-intensive. To address that, we need the code to be in "true" CPS
* form, which is not quite the same as the semantics required by the ANF
* transform. Having the code in ANF form allows us to use the trampoline
* technique to execute the computation in constant stack space.
*
* This means that, in some sense, the following code has two (interleaved)
* levels of CPS-looking style. For the sake of clarity, in further comments
* as well as in the code, we will use the term "continuation" and the
* variable names "k", strictly for the "true" continuations that have
* been added to achieve constant stack space, and use the term
* "transformation function" and the variable names "transform", "tx" for
* the functions that express the semantics of the ANF transformation.
*
* Things are further muddied by the following:
* 1. A number of top-level functions defined in this object also qualify as
* "transformation functions", even though they themselves receive
* transformation functions as arguments and/or define new ones on the fly
* (flattenExp, transformLet1, flattenAlts, transformExp, atomizeExp,
* atomizeExps).
* 2. To achieve full CPS, transformation functions themselves need to accept
* (and apply) a continuation.
*
* Not all functions in this object are in CPS (only the ones that are part of
* the main recursive loop), but those that do always take the continuation as
* their last argument.
* ANF transformation. The transformation is however not tail recursive: it
* is only mostly in CPS form, and is actually stack intensive. To address
* this, all functions of the core loop return trampolines
* (scala.util.control.TailCalls).
*/
/** `DepthE` tracks the stack-depth of the original expression being traversed */
@ -129,30 +104,14 @@ private[lf] object Anf {
Env(absMap = env.absMap ++ extra, oldDepth = env.oldDepth.incr(n))
}
private[this] type Res = Trampoline[target.SExpr]
private[this] type Res = TailRec[target.SExpr]
/** Tx is the type for the stacked transformation functions managed by the ANF
* transformation, mainly transformExp.
*
* All of the transformation functions would, without CPS, return a target.SExpr,
* so that is the input type of the continuation.
*
* Both type parameters are ultimately needed because SCaseAlt does not
* extend source.SExpr. If it did, T would always be source.SExpr and A would always be
* target.SExpr.
*
* Note that Scala does not seem to be able to generate anonymous function of
* a parameterized type, so we use nested `defs` instead.
*
* @tparam T The type of expression this will be applied to.
*/
private[this] type Tx[T] = (DepthA, T) => K[target.SExpr] => Res
/** K Is the type for continuations.
*
* @tparam T Type the function would have returned had it not been in CPS.
*/
private[this] type K[T] = T => Res
private[this] type Tx[T] = (DepthA, T) => Res
/** During conversion we need to deal with bindings which are made/found at a given
* absolute stack depth. These are represented using `AbsBinding`. An absolute stack
@ -208,15 +167,8 @@ private[lf] object Anf {
case Right(binding) => target.SELocS(makeRelativeB(depth, binding))
}
private[this] def flattenExp(depth: DepthA, env: Env, exp: source.SExpr)(
k: K[target.SExpr]
): Res = {
transformExp(depth, env, exp)(k) { (_, sexpr) => k =>
Bounce { () =>
k(sexpr)
}
}
private[this] def flattenExp(depth: DepthA, env: Env, exp: source.SExpr): Res = {
transformExp(depth, env, exp) { (_, sexpr) => done(sexpr) }
}
private[this] def transformLet1(
@ -224,46 +176,36 @@ private[lf] object Anf {
env: Env,
rhs: source.SExpr,
body: source.SExpr,
)(k: K[target.SExpr])(transform: Tx[target.SExpr]): Res = {
transformExp(depth, env, rhs)(k) { (depth, rhs) => k =>
)(transform: Tx[target.SExpr]): Res = {
transformExp(depth, env, rhs) { (depth, rhs) =>
val depth1 = depth.incr(1)
val env1 = trackBindings(depth, env, 1)
transformExp(depth1, env1, body) { body =>
Bounce { () =>
k(target.SELet1(rhs, body))
}
}(transform)
}
}
private[this] def flattenAlts(depth: DepthA, env: Env, alts0: List[source.SCaseAlt])(
k: K[List[target.SCaseAlt]]
): Res = {
def loop(acc: List[target.SCaseAlt], alts: List[source.SCaseAlt]): Res = {
alts match {
case alt :: alts =>
flattenAlt(depth, env, alt) { alt =>
loop(alt :: acc, alts)
}
case Nil =>
k(acc.reverse)
transformExp(depth1, env1, body)(transform).map { body =>
target.SELet1(rhs, body)
}
}
loop(Nil, alts0)
}
private[this] def flattenAlt(depth: DepthA, env: Env, alt: source.SCaseAlt)(
k: K[target.SCaseAlt]
): Res = {
private[this] def flattenAlts(
depth: DepthA,
env: Env,
alts0: List[source.SCaseAlt],
): TailRec[List[SExpr.SCaseAlt]] = {
traverse(alts0, (alt: source.SCaseAlt) => flattenAlt(depth, env, alt))
}
private[this] def flattenAlt(
depth: DepthA,
env: Env,
alt: source.SCaseAlt,
): TailRec[SExpr.SCaseAlt] = {
alt match {
case source.SCaseAlt(pat, body) =>
val n = patternNArgs(pat)
val env1 = trackBindings(depth, env, n)
flattenExp(depth.incr(n), env1, body) { body =>
k(target.SCaseAlt(pat, body))
flattenExp(depth.incr(n), env1, body).map { body =>
target.SCaseAlt(pat, body)
}
}
}
@ -290,14 +232,14 @@ private[lf] object Anf {
* achieve constant stack through trampoline.
*/
private[this] def transformExp(depth: DepthA, env: Env, exp: source.SExpr)(
k: K[target.SExpr]
)(transform: Tx[target.SExpr]): Res = Bounce { () =>
transform: Tx[target.SExpr]
): Res = tailcall {
exp match {
case atom0: source.SExprAtomic =>
val atom = makeRelativeA(depth)(makeAbsoluteA(env, atom0))
transform(depth, atom)(k)
transform(depth, atom)
case source.SEVal(x) => transform(depth, target.SEVal(x))(k)
case source.SEVal(x) => transform(depth, target.SEVal(x))
case source.SEApp(func, args) =>
// It's safe to perform ANF if the func-expression has no effects when evaluated.
@ -312,63 +254,63 @@ private[lf] object Anf {
// It's also safe to perform ANF for applications of a single argument.
val singleArg = args.lengthCompare(1) == 0
if (safeFunc || singleArg) {
transformMultiApp(depth, env, func, args, k)(transform)
transformMultiApp(depth, env, func, args)(transform)
} else {
transformMultiAppSafely(depth, env, func, args, k)(transform)
transformMultiAppSafely(depth, env, func, args)(transform)
}
case source.SEMakeClo(fvs0, arity, body) =>
val fvs = fvs0.map((loc) => makeRelativeL(depth)(makeAbsoluteL(env, loc)))
flattenExp(DepthA(0), initEnv, body) { body =>
transform(depth, target.SEMakeClo(fvs.toArray, arity, body))(k)
flattenExp(DepthA(0), initEnv, body).flatMap { body =>
transform(depth, target.SEMakeClo(fvs.toArray, arity, body))
}
case source.SECase(scrut, alts0) =>
atomizeExp(depth, env, scrut, k) { (depth, scrut) => k =>
atomizeExp(depth, env, scrut) { (depth, scrut) =>
val scrut1 = makeRelativeA(depth)(scrut)
flattenAlts(depth, env, alts0) { alts =>
transform(depth, target.SECaseAtomic(scrut1, alts.toArray))(k)
flattenAlts(depth, env, alts0).flatMap { alts =>
transform(depth, target.SECaseAtomic(scrut1, alts.toArray))
}
}
case source.SELet(rhss, body) =>
val expanded = expandMultiLet(rhss, body)
transformExp(depth, env, expanded)(k)(transform)
transformExp(depth, env, expanded)(transform)
case source.SELet1General(rhs, body) =>
transformLet1(depth, env, rhs, body)(k)(transform)
transformLet1(depth, env, rhs, body)(transform)
case source.SELocation(loc, body) =>
transformExp(depth, env, body)(k) { (depth, body) => k =>
Bounce { () =>
transform(depth, target.SELocation(loc, body))(k)
transformExp(depth, env, body) { (depth, body) =>
tailcall {
transform(depth, target.SELocation(loc, body))
}
}
case source.SELabelClosure(label, exp) =>
transformExp(depth, env, exp)(k) { (depth, exp) => k =>
Bounce { () =>
transform(depth, target.SELabelClosure(label, exp))(k)
transformExp(depth, env, exp) { (depth, exp) =>
tailcall {
transform(depth, target.SELabelClosure(label, exp))
}
}
case source.SETryCatch(body, handler0) =>
// we must not lift applications from either the body or the handler outside of
// the try-catch block, so we flatten each separately:
flattenExp(depth, env, body) { body =>
flattenExp(depth.incr(1), trackBindings(depth, env, 1), handler0) { handler =>
transform(depth, target.SETryCatch(body, handler))(k)
flattenExp(depth, env, body).flatMap { body =>
flattenExp(depth.incr(1), trackBindings(depth, env, 1), handler0).flatMap { handler =>
transform(depth, target.SETryCatch(body, handler))
}
}
case source.SEScopeExercise(body) =>
flattenExp(depth, env, body) { body =>
transform(depth, target.SEScopeExercise(body))(k)
flattenExp(depth, env, body).flatMap { body =>
transform(depth, target.SEScopeExercise(body))
}
case source.SEPreventCatch(body) =>
flattenExp(depth, env, body) { body =>
transform(depth, target.SEPreventCatch(body))(k)
flattenExp(depth, env, body).flatMap { body =>
transform(depth, target.SEPreventCatch(body))
}
}
@ -378,40 +320,36 @@ private[lf] object Anf {
depth: DepthA,
env: Env,
exps: List[source.SExpr],
k: K[target.SExpr],
)(transform: Tx[List[AbsAtom]]): Res =
exps match {
case Nil => transform(depth, Nil)(k)
case Nil => transform(depth, Nil)
case exp :: exps =>
atomizeExp(depth, env, exp, k) { (depth, atom) => k =>
Bounce { () =>
atomizeExps(depth, env, exps, k) { (depth, atoms) => k =>
Bounce { () =>
transform(depth, atom :: atoms)(k)
atomizeExp(depth, env, exp) { (depth, atom) =>
tailcall {
atomizeExps(depth, env, exps) { (depth, atoms) =>
tailcall {
transform(depth, atom :: atoms)
}
}
}
}
}
private[this] def atomizeExp(depth: DepthA, env: Env, exp: source.SExpr, k: K[target.SExpr])(
private[this] def atomizeExp(depth: DepthA, env: Env, exp: source.SExpr)(
transform: Tx[AbsAtom]
): Res = {
exp match {
case ea: source.SExprAtomic => transform(depth, makeAbsoluteA(env, ea))(k)
case _ => {
transformExp(depth, env, exp)(k) { (depth, exp) => k =>
case ea: source.SExprAtomic => transform(depth, makeAbsoluteA(env, ea))
case _ =>
transformExp(depth, env, exp) { (depth, exp) =>
val atom = Right(AbsBinding(depth))
Bounce { () =>
transform(depth.incr(1), atom) { body =>
Bounce { () =>
k(target.SELet1(exp, body))
}
tailcall {
transform(depth.incr(1), atom).map { body =>
target.SELet1(exp, body)
}
}
}
}
}
}
@ -435,14 +373,13 @@ private[lf] object Anf {
env: Env,
func: source.SExpr,
args: List[source.SExpr],
k: K[target.SExpr],
)(transform: Tx[target.SExpr]): Res = {
atomizeExp(depth, env, func, k) { (depth, func) => k =>
atomizeExps(depth, env, args, k) { (depth, args) => k =>
atomizeExp(depth, env, func) { (depth, func) =>
atomizeExps(depth, env, args) { (depth, args) =>
val func1 = makeRelativeA(depth)(func)
val args1 = args.map(makeRelativeA(depth))
transform(depth, target.SEAppAtomic(func1, args1.toArray))(k)
transform(depth, target.SEAppAtomic(func1, args1.toArray))
}
}
}
@ -457,34 +394,38 @@ private[lf] object Anf {
env: Env,
func: source.SExpr,
args: List[source.SExpr],
k: K[target.SExpr],
)(transform: Tx[target.SExpr]): Res = {
atomizeExp(depth, env, func, k) { (depth, func) => k =>
atomizeExp(depth, env, func) { (depth, func) =>
val func1 = makeRelativeA(depth)(func)
// we dont atomize the args here
flattenExpList(depth, env, args) { args =>
flattenExpList(depth, env, args).flatMap { args =>
// we build a non-atomic application here (only the function is atomic)
transform(depth, target.SEAppOnlyFunIsAtomic(func1, args.toArray))(k)
transform(depth, target.SEAppOnlyFunIsAtomic(func1, args.toArray))
}
}
}
private[this] def flattenExpList(depth: DepthA, env: Env, exps0: List[source.SExpr])(
k: K[List[target.SExpr]]
): Res = {
def loop(acc: List[target.SExpr], exps: List[source.SExpr]): Res = {
exps match {
case exp :: exps =>
flattenExp(depth, env, exp) { exp =>
loop(exp :: acc, exps)
}
case Nil =>
k(acc.reverse)
}
}
loop(Nil, exps0)
private[this] def flattenExpList(
depth: DepthA,
env: Env,
exps0: List[source.SExpr],
): TailRec[List[SExpr.SExpr]] = {
traverse(exps0, (exp: source.SExpr) => flattenExp(depth, env, exp))
}
/** Monadic map for [[TailRec]]. */
private[this] def traverse[A, B](
xs: List[A],
f: A => TailRec[B],
): TailRec[List[B]] = {
xs match {
case Nil => done(Nil)
case x :: xs =>
for {
x <- f(x)
xs <- tailcall { traverse(xs, f) }
} yield (x :: xs)
}
}
}