Simplify Anf.scala by removing the continuations. (#17242)

The core functions of the ANF transformation in Anf.scala handle 3 different types of continuations:

1. The "transformations" (called tx), that are an inherent aspect of any ANF transformation expressed in pseudo-CPS style for convenience.
2. The "continuations" (called k) that had been introduced to avoid stack oveflows.
3. The trampolines, that have also been introduced to avoid stack overflows.

This commit recognizes that types 2 and 3 are redundant and merges them into only type 3. It relies on the fact that trampolines are monadic, and so we switch from the hand-rolled trampolines to that of scala's std library since they already implement flatMap.
This commit is contained in:
Paul Brauner 2023-08-09 12:19:01 +02:00 committed by GitHub
parent 5c31bc3f21
commit 4e54c69c9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,14 +29,13 @@ package com.daml.lf.speedy
* We use "source." and "t." for lightweight discrimination. * We use "source." and "t." for lightweight discrimination.
*/ */
import com.daml.lf.data.Trampoline.{Bounce, Land, Trampoline}
import com.daml.lf.speedy.{SExpr1 => source} import com.daml.lf.speedy.{SExpr1 => source}
import com.daml.lf.speedy.{SExpr => target} import com.daml.lf.speedy.{SExpr => target}
import com.daml.lf.speedy.Compiler.CompilationError import com.daml.lf.speedy.Compiler.CompilationError
import scala.annotation.nowarn import scala.annotation.nowarn
import scala.annotation.tailrec import scala.annotation.tailrec
import scala.util.control.TailCalls._
import scalaz.{@@, Tag} import scalaz.{@@, Tag}
private[lf] object Anf { private[lf] object Anf {
@ -52,9 +51,7 @@ private[lf] object Anf {
private def flattenToAnfInternal(exp: source.SExpr): target.SExpr = { private def flattenToAnfInternal(exp: source.SExpr): target.SExpr = {
val depth = DepthA(0) val depth = DepthA(0)
val env = initEnv val env = initEnv
flattenExp(depth, env, exp) { exp => flattenExp(depth, env, exp).result
Land(exp)
}.bounce
} }
/** The transformation code is implemented using a what looks like /** The transformation code is implemented using a what looks like
@ -71,32 +68,10 @@ private[lf] object Anf {
* mapping between them. See the types: DepthE, DepthA and Env. * mapping between them. See the types: DepthE, DepthA and Env.
* *
* Using a coding style that looks like CPS is a natural way to express the * Using a coding style that looks like CPS is a natural way to express the
* ANF transformation. However, this means the transformation is very * ANF transformation. The transformation is however not tail recursive: it
* stack-intensive. To address that, we need the code to be in "true" CPS * is only mostly in CPS form, and is actually stack intensive. To address
* form, which is not quite the same as the semantics required by the ANF * this, all functions of the core loop return trampolines
* transform. Having the code in ANF form allows us to use the trampoline * (scala.util.control.TailCalls).
* technique to execute the computation in constant stack space.
*
* This means that, in some sense, the following code has two (interleaved)
* levels of CPS-looking style. For the sake of clarity, in further comments
* as well as in the code, we will use the term "continuation" and the
* variable names "k", strictly for the "true" continuations that have
* been added to achieve constant stack space, and use the term
* "transformation function" and the variable names "transform", "tx" for
* the functions that express the semantics of the ANF transformation.
*
* Things are further muddied by the following:
* 1. A number of top-level functions defined in this object also qualify as
* "transformation functions", even though they themselves receive
* transformation functions as arguments and/or define new ones on the fly
* (flattenExp, transformLet1, flattenAlts, transformExp, atomizeExp,
* atomizeExps).
* 2. To achieve full CPS, transformation functions themselves need to accept
* (and apply) a continuation.
*
* Not all functions in this object are in CPS (only the ones that are part of
* the main recursive loop), but those that do always take the continuation as
* their last argument.
*/ */
/** `DepthE` tracks the stack-depth of the original expression being traversed */ /** `DepthE` tracks the stack-depth of the original expression being traversed */
@ -129,30 +104,14 @@ private[lf] object Anf {
Env(absMap = env.absMap ++ extra, oldDepth = env.oldDepth.incr(n)) Env(absMap = env.absMap ++ extra, oldDepth = env.oldDepth.incr(n))
} }
private[this] type Res = Trampoline[target.SExpr] private[this] type Res = TailRec[target.SExpr]
/** Tx is the type for the stacked transformation functions managed by the ANF /** Tx is the type for the stacked transformation functions managed by the ANF
* transformation, mainly transformExp. * transformation, mainly transformExp.
* *
* All of the transformation functions would, without CPS, return a target.SExpr,
* so that is the input type of the continuation.
*
* Both type parameters are ultimately needed because SCaseAlt does not
* extend source.SExpr. If it did, T would always be source.SExpr and A would always be
* target.SExpr.
*
* Note that Scala does not seem to be able to generate anonymous function of
* a parameterized type, so we use nested `defs` instead.
*
* @tparam T The type of expression this will be applied to. * @tparam T The type of expression this will be applied to.
*/ */
private[this] type Tx[T] = (DepthA, T) => K[target.SExpr] => Res private[this] type Tx[T] = (DepthA, T) => Res
/** K Is the type for continuations.
*
* @tparam T Type the function would have returned had it not been in CPS.
*/
private[this] type K[T] = T => Res
/** During conversion we need to deal with bindings which are made/found at a given /** During conversion we need to deal with bindings which are made/found at a given
* absolute stack depth. These are represented using `AbsBinding`. An absolute stack * absolute stack depth. These are represented using `AbsBinding`. An absolute stack
@ -208,15 +167,8 @@ private[lf] object Anf {
case Right(binding) => target.SELocS(makeRelativeB(depth, binding)) case Right(binding) => target.SELocS(makeRelativeB(depth, binding))
} }
private[this] def flattenExp(depth: DepthA, env: Env, exp: source.SExpr)( private[this] def flattenExp(depth: DepthA, env: Env, exp: source.SExpr): Res = {
k: K[target.SExpr] transformExp(depth, env, exp) { (_, sexpr) => done(sexpr) }
): Res = {
transformExp(depth, env, exp)(k) { (_, sexpr) => k =>
Bounce { () =>
k(sexpr)
}
}
} }
private[this] def transformLet1( private[this] def transformLet1(
@ -224,46 +176,36 @@ private[lf] object Anf {
env: Env, env: Env,
rhs: source.SExpr, rhs: source.SExpr,
body: source.SExpr, body: source.SExpr,
)(k: K[target.SExpr])(transform: Tx[target.SExpr]): Res = { )(transform: Tx[target.SExpr]): Res = {
transformExp(depth, env, rhs) { (depth, rhs) =>
transformExp(depth, env, rhs)(k) { (depth, rhs) => k =>
val depth1 = depth.incr(1) val depth1 = depth.incr(1)
val env1 = trackBindings(depth, env, 1) val env1 = trackBindings(depth, env, 1)
transformExp(depth1, env1, body) { body => transformExp(depth1, env1, body)(transform).map { body =>
Bounce { () => target.SELet1(rhs, body)
k(target.SELet1(rhs, body))
}
}(transform)
}
}
private[this] def flattenAlts(depth: DepthA, env: Env, alts0: List[source.SCaseAlt])(
k: K[List[target.SCaseAlt]]
): Res = {
def loop(acc: List[target.SCaseAlt], alts: List[source.SCaseAlt]): Res = {
alts match {
case alt :: alts =>
flattenAlt(depth, env, alt) { alt =>
loop(alt :: acc, alts)
}
case Nil =>
k(acc.reverse)
} }
} }
loop(Nil, alts0)
} }
private[this] def flattenAlt(depth: DepthA, env: Env, alt: source.SCaseAlt)( private[this] def flattenAlts(
k: K[target.SCaseAlt] depth: DepthA,
): Res = { env: Env,
alts0: List[source.SCaseAlt],
): TailRec[List[SExpr.SCaseAlt]] = {
traverse(alts0, (alt: source.SCaseAlt) => flattenAlt(depth, env, alt))
}
private[this] def flattenAlt(
depth: DepthA,
env: Env,
alt: source.SCaseAlt,
): TailRec[SExpr.SCaseAlt] = {
alt match { alt match {
case source.SCaseAlt(pat, body) => case source.SCaseAlt(pat, body) =>
val n = patternNArgs(pat) val n = patternNArgs(pat)
val env1 = trackBindings(depth, env, n) val env1 = trackBindings(depth, env, n)
flattenExp(depth.incr(n), env1, body) { body => flattenExp(depth.incr(n), env1, body).map { body =>
k(target.SCaseAlt(pat, body)) target.SCaseAlt(pat, body)
} }
} }
} }
@ -290,14 +232,14 @@ private[lf] object Anf {
* achieve constant stack through trampoline. * achieve constant stack through trampoline.
*/ */
private[this] def transformExp(depth: DepthA, env: Env, exp: source.SExpr)( private[this] def transformExp(depth: DepthA, env: Env, exp: source.SExpr)(
k: K[target.SExpr] transform: Tx[target.SExpr]
)(transform: Tx[target.SExpr]): Res = Bounce { () => ): Res = tailcall {
exp match { exp match {
case atom0: source.SExprAtomic => case atom0: source.SExprAtomic =>
val atom = makeRelativeA(depth)(makeAbsoluteA(env, atom0)) val atom = makeRelativeA(depth)(makeAbsoluteA(env, atom0))
transform(depth, atom)(k) transform(depth, atom)
case source.SEVal(x) => transform(depth, target.SEVal(x))(k) case source.SEVal(x) => transform(depth, target.SEVal(x))
case source.SEApp(func, args) => case source.SEApp(func, args) =>
// It's safe to perform ANF if the func-expression has no effects when evaluated. // It's safe to perform ANF if the func-expression has no effects when evaluated.
@ -312,63 +254,63 @@ private[lf] object Anf {
// It's also safe to perform ANF for applications of a single argument. // It's also safe to perform ANF for applications of a single argument.
val singleArg = args.lengthCompare(1) == 0 val singleArg = args.lengthCompare(1) == 0
if (safeFunc || singleArg) { if (safeFunc || singleArg) {
transformMultiApp(depth, env, func, args, k)(transform) transformMultiApp(depth, env, func, args)(transform)
} else { } else {
transformMultiAppSafely(depth, env, func, args, k)(transform) transformMultiAppSafely(depth, env, func, args)(transform)
} }
case source.SEMakeClo(fvs0, arity, body) => case source.SEMakeClo(fvs0, arity, body) =>
val fvs = fvs0.map((loc) => makeRelativeL(depth)(makeAbsoluteL(env, loc))) val fvs = fvs0.map((loc) => makeRelativeL(depth)(makeAbsoluteL(env, loc)))
flattenExp(DepthA(0), initEnv, body) { body => flattenExp(DepthA(0), initEnv, body).flatMap { body =>
transform(depth, target.SEMakeClo(fvs.toArray, arity, body))(k) transform(depth, target.SEMakeClo(fvs.toArray, arity, body))
} }
case source.SECase(scrut, alts0) => case source.SECase(scrut, alts0) =>
atomizeExp(depth, env, scrut, k) { (depth, scrut) => k => atomizeExp(depth, env, scrut) { (depth, scrut) =>
val scrut1 = makeRelativeA(depth)(scrut) val scrut1 = makeRelativeA(depth)(scrut)
flattenAlts(depth, env, alts0) { alts => flattenAlts(depth, env, alts0).flatMap { alts =>
transform(depth, target.SECaseAtomic(scrut1, alts.toArray))(k) transform(depth, target.SECaseAtomic(scrut1, alts.toArray))
} }
} }
case source.SELet(rhss, body) => case source.SELet(rhss, body) =>
val expanded = expandMultiLet(rhss, body) val expanded = expandMultiLet(rhss, body)
transformExp(depth, env, expanded)(k)(transform) transformExp(depth, env, expanded)(transform)
case source.SELet1General(rhs, body) => case source.SELet1General(rhs, body) =>
transformLet1(depth, env, rhs, body)(k)(transform) transformLet1(depth, env, rhs, body)(transform)
case source.SELocation(loc, body) => case source.SELocation(loc, body) =>
transformExp(depth, env, body)(k) { (depth, body) => k => transformExp(depth, env, body) { (depth, body) =>
Bounce { () => tailcall {
transform(depth, target.SELocation(loc, body))(k) transform(depth, target.SELocation(loc, body))
} }
} }
case source.SELabelClosure(label, exp) => case source.SELabelClosure(label, exp) =>
transformExp(depth, env, exp)(k) { (depth, exp) => k => transformExp(depth, env, exp) { (depth, exp) =>
Bounce { () => tailcall {
transform(depth, target.SELabelClosure(label, exp))(k) transform(depth, target.SELabelClosure(label, exp))
} }
} }
case source.SETryCatch(body, handler0) => case source.SETryCatch(body, handler0) =>
// we must not lift applications from either the body or the handler outside of // we must not lift applications from either the body or the handler outside of
// the try-catch block, so we flatten each separately: // the try-catch block, so we flatten each separately:
flattenExp(depth, env, body) { body => flattenExp(depth, env, body).flatMap { body =>
flattenExp(depth.incr(1), trackBindings(depth, env, 1), handler0) { handler => flattenExp(depth.incr(1), trackBindings(depth, env, 1), handler0).flatMap { handler =>
transform(depth, target.SETryCatch(body, handler))(k) transform(depth, target.SETryCatch(body, handler))
} }
} }
case source.SEScopeExercise(body) => case source.SEScopeExercise(body) =>
flattenExp(depth, env, body) { body => flattenExp(depth, env, body).flatMap { body =>
transform(depth, target.SEScopeExercise(body))(k) transform(depth, target.SEScopeExercise(body))
} }
case source.SEPreventCatch(body) => case source.SEPreventCatch(body) =>
flattenExp(depth, env, body) { body => flattenExp(depth, env, body).flatMap { body =>
transform(depth, target.SEPreventCatch(body))(k) transform(depth, target.SEPreventCatch(body))
} }
} }
@ -378,40 +320,36 @@ private[lf] object Anf {
depth: DepthA, depth: DepthA,
env: Env, env: Env,
exps: List[source.SExpr], exps: List[source.SExpr],
k: K[target.SExpr],
)(transform: Tx[List[AbsAtom]]): Res = )(transform: Tx[List[AbsAtom]]): Res =
exps match { exps match {
case Nil => transform(depth, Nil)(k) case Nil => transform(depth, Nil)
case exp :: exps => case exp :: exps =>
atomizeExp(depth, env, exp, k) { (depth, atom) => k => atomizeExp(depth, env, exp) { (depth, atom) =>
Bounce { () => tailcall {
atomizeExps(depth, env, exps, k) { (depth, atoms) => k => atomizeExps(depth, env, exps) { (depth, atoms) =>
Bounce { () => tailcall {
transform(depth, atom :: atoms)(k) transform(depth, atom :: atoms)
} }
} }
} }
} }
} }
private[this] def atomizeExp(depth: DepthA, env: Env, exp: source.SExpr, k: K[target.SExpr])( private[this] def atomizeExp(depth: DepthA, env: Env, exp: source.SExpr)(
transform: Tx[AbsAtom] transform: Tx[AbsAtom]
): Res = { ): Res = {
exp match { exp match {
case ea: source.SExprAtomic => transform(depth, makeAbsoluteA(env, ea))(k) case ea: source.SExprAtomic => transform(depth, makeAbsoluteA(env, ea))
case _ => { case _ =>
transformExp(depth, env, exp)(k) { (depth, exp) => k => transformExp(depth, env, exp) { (depth, exp) =>
val atom = Right(AbsBinding(depth)) val atom = Right(AbsBinding(depth))
Bounce { () => tailcall {
transform(depth.incr(1), atom) { body => transform(depth.incr(1), atom).map { body =>
Bounce { () => target.SELet1(exp, body)
k(target.SELet1(exp, body))
}
} }
} }
} }
}
} }
} }
@ -435,14 +373,13 @@ private[lf] object Anf {
env: Env, env: Env,
func: source.SExpr, func: source.SExpr,
args: List[source.SExpr], args: List[source.SExpr],
k: K[target.SExpr],
)(transform: Tx[target.SExpr]): Res = { )(transform: Tx[target.SExpr]): Res = {
atomizeExp(depth, env, func, k) { (depth, func) => k => atomizeExp(depth, env, func) { (depth, func) =>
atomizeExps(depth, env, args, k) { (depth, args) => k => atomizeExps(depth, env, args) { (depth, args) =>
val func1 = makeRelativeA(depth)(func) val func1 = makeRelativeA(depth)(func)
val args1 = args.map(makeRelativeA(depth)) val args1 = args.map(makeRelativeA(depth))
transform(depth, target.SEAppAtomic(func1, args1.toArray))(k) transform(depth, target.SEAppAtomic(func1, args1.toArray))
} }
} }
} }
@ -457,34 +394,38 @@ private[lf] object Anf {
env: Env, env: Env,
func: source.SExpr, func: source.SExpr,
args: List[source.SExpr], args: List[source.SExpr],
k: K[target.SExpr],
)(transform: Tx[target.SExpr]): Res = { )(transform: Tx[target.SExpr]): Res = {
atomizeExp(depth, env, func, k) { (depth, func) => k => atomizeExp(depth, env, func) { (depth, func) =>
val func1 = makeRelativeA(depth)(func) val func1 = makeRelativeA(depth)(func)
// we dont atomize the args here // we dont atomize the args here
flattenExpList(depth, env, args) { args => flattenExpList(depth, env, args).flatMap { args =>
// we build a non-atomic application here (only the function is atomic) // we build a non-atomic application here (only the function is atomic)
transform(depth, target.SEAppOnlyFunIsAtomic(func1, args.toArray))(k) transform(depth, target.SEAppOnlyFunIsAtomic(func1, args.toArray))
} }
} }
} }
private[this] def flattenExpList(depth: DepthA, env: Env, exps0: List[source.SExpr])( private[this] def flattenExpList(
k: K[List[target.SExpr]] depth: DepthA,
): Res = { env: Env,
exps0: List[source.SExpr],
def loop(acc: List[target.SExpr], exps: List[source.SExpr]): Res = { ): TailRec[List[SExpr.SExpr]] = {
exps match { traverse(exps0, (exp: source.SExpr) => flattenExp(depth, env, exp))
case exp :: exps =>
flattenExp(depth, env, exp) { exp =>
loop(exp :: acc, exps)
}
case Nil =>
k(acc.reverse)
}
}
loop(Nil, exps0)
} }
/** Monadic map for [[TailRec]]. */
private[this] def traverse[A, B](
xs: List[A],
f: A => TailRec[B],
): TailRec[List[B]] = {
xs match {
case Nil => done(Nil)
case x :: xs =>
for {
x <- f(x)
xs <- tailcall { traverse(xs, f) }
} yield (x :: xs)
}
}
} }