Different types before after closure conversion (#11661)

* avoid single letter (s./t.) import prefixes

* Split type SExpr0 -> SExpr{0,1} for before/after closure conversion

CHANGELOG_BEGIN
CHANGELOG_END

* remove unnecessary constructors in SExpr{0,1}

* remove SExpr0.SExprAtomic
This commit is contained in:
nickchapman-da 2021-11-12 08:54:43 +00:00 committed by GitHub
parent 2f4476c12a
commit 3192d5eb74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 310 additions and 197 deletions

View File

@ -30,7 +30,7 @@ package com.daml.lf.speedy
*/
import com.daml.lf.data.Trampoline.{Bounce, Land, Trampoline}
import com.daml.lf.speedy.{SExpr0 => source}
import com.daml.lf.speedy.{SExpr1 => source}
import com.daml.lf.speedy.{SExpr => target}
import com.daml.lf.speedy.Compiler.CompilationError
@ -190,7 +190,6 @@ private[lf] object Anf {
case source.SEValue(x) => target.SEValue(x)
case source.SEBuiltin(x) => target.SEBuiltin(x)
case source.SEBuiltinRecursiveDefinition(x) => target.SEBuiltinRecursiveDefinition(x)
case _: source.SEVar => sys.error(s"Anf1.convertAtom, unexpected: $x")
}
}
@ -375,7 +374,7 @@ private[lf] object Anf {
val body: target.SExpr = flattenExp(depth, env, body0)(anf => Land(anf.wrapped)).bounce
Bounce(() => transform(depth, target.SEScopeExercise(body), k))
case _: source.SEAbs | _: source.SEDamlException =>
case _: source.SEDamlException =>
throw CompilationError(s"flatten: unexpected: $exp")
}

View File

@ -12,7 +12,8 @@ package com.daml.lf.speedy
* following ANF transformation phase.
*/
import com.daml.lf.speedy.{SExpr0 => s}
import com.daml.lf.speedy.{SExpr0 => source}
import com.daml.lf.speedy.{SExpr1 => target}
private[speedy] object ClosureConversion {
@ -41,86 +42,86 @@ private[speedy] object ClosureConversion {
*/
// TODO: Introduce a new type expression for the result of closure conversion
private[speedy] def closureConvert(expr: s.SExpr): s.SExpr = {
private[speedy] def closureConvert(expr: source.SExpr): target.SExpr = {
closureConvert(Map.empty, expr)
}
private def closureConvert(remaps: Map[Int, s.SELoc], expr: s.SExpr): s.SExpr = {
private def closureConvert(remaps: Map[Int, target.SELoc], expr: source.SExpr): target.SExpr = {
// remaps is a function which maps the relative offset from variables (SEVar) to their runtime location
// The Map must contain a binding for every variable referenced.
// The Map is consulted when translating variable references (SEVar) and free variables of an abstraction (SEAbs)
def remap(i: Int): s.SELoc =
def remap(i: Int): target.SELoc =
remaps.get(i) match {
case Some(loc) => loc
case None =>
throw CompilationError(s"remap($i),remaps=$remaps")
}
expr match {
case s.SEVar(i) => remap(i)
case v: s.SEVal => v
case be: s.SEBuiltin => be
case pl: s.SEValue => pl
case f: s.SEBuiltinRecursiveDefinition => f
case s.SELocation(loc, body) =>
s.SELocation(loc, closureConvert(remaps, body))
case source.SEVar(i) => remap(i)
case source.SEVal(ref) => target.SEVal(ref)
case source.SEBuiltin(b) => target.SEBuiltin(b)
case source.SEValue(v) => target.SEValue(v)
case source.SEBuiltinRecursiveDefinition(f) => target.SEBuiltinRecursiveDefinition(f)
case source.SELocation(loc, body) =>
target.SELocation(loc, closureConvert(remaps, body))
case s.SEAbs(0, _) =>
case source.SEAbs(0, _) =>
throw CompilationError("empty SEAbs")
case s.SEAbs(arity, body) =>
case source.SEAbs(arity, body) =>
val fvs = freeVars(body, arity).toList.sorted
val newRemapsF: Map[Int, s.SELoc] = fvs.zipWithIndex.map { case (orig, i) =>
(orig + arity) -> s.SELocF(i)
val newRemapsF: Map[Int, target.SELoc] = fvs.zipWithIndex.map { case (orig, i) =>
(orig + arity) -> target.SELocF(i)
}.toMap
val newRemapsA = (1 to arity).map { case i =>
i -> s.SELocA(arity - i)
i -> target.SELocA(arity - i)
}
// The keys in newRemapsF and newRemapsA are disjoint
val newBody = closureConvert(newRemapsF ++ newRemapsA, body)
s.SEMakeClo(fvs.map(remap).toArray, arity, newBody)
target.SEMakeClo(fvs.map(remap).toArray, arity, newBody)
case s.SEAppGeneral(fun, args) =>
case source.SEAppGeneral(fun, args) =>
val newFun = closureConvert(remaps, fun)
val newArgs = args.map(closureConvert(remaps, _))
s.SEApp(newFun, newArgs)
target.SEApp(newFun, newArgs)
case s.SECase(scrut, alts) =>
s.SECase(
case source.SECase(scrut, alts) =>
target.SECase(
closureConvert(remaps, scrut),
alts.map { case s.SCaseAlt(pat, body) =>
alts.map { case source.SCaseAlt(pat, body) =>
val n = pat.numArgs
s.SCaseAlt(
target.SCaseAlt(
pat,
closureConvert(shift(remaps, n), body),
)
},
)
case s.SELet(bounds, body) =>
s.SELet(
case source.SELet(bounds, body) =>
target.SELet(
bounds.zipWithIndex.map { case (b, i) =>
closureConvert(shift(remaps, i), b)
},
closureConvert(shift(remaps, bounds.length), body),
)
case s.SETryCatch(body, handler) =>
s.SETryCatch(
case source.SETryCatch(body, handler) =>
target.SETryCatch(
closureConvert(remaps, body),
closureConvert(shift(remaps, 1), handler),
)
case s.SEScopeExercise(body) =>
s.SEScopeExercise(closureConvert(remaps, body))
case source.SEScopeExercise(body) =>
target.SEScopeExercise(closureConvert(remaps, body))
case s.SELabelClosure(label, expr) =>
s.SELabelClosure(label, closureConvert(remaps, expr))
case source.SELabelClosure(label, expr) =>
target.SELabelClosure(label, closureConvert(remaps, expr))
case s.SELet1General(bound, body) =>
s.SELet1General(closureConvert(remaps, bound), closureConvert(shift(remaps, 1), body))
case source.SELet1General(bound, body) =>
target.SELet1General(closureConvert(remaps, bound), closureConvert(shift(remaps, 1), body))
case _: s.SELoc | _: s.SEMakeClo | _: s.SEDamlException | _: s.SEImportValue =>
case _: source.SEDamlException | _: source.SEImportValue =>
throw CompilationError(s"closureConvert: unexpected $expr")
}
}
@ -131,60 +132,59 @@ private[speedy] object ClosureConversion {
// We must modify `remaps` because it is keyed by indexes relative to the end of the stack.
// And any values in the map which are of the form SELocS must also be _shifted_
// because SELocS indexes are also relative to the end of the stack.
private[this] def shift(remaps: Map[Int, s.SELoc], n: Int): Map[Int, s.SELoc] = {
private[this] def shift(remaps: Map[Int, target.SELoc], n: Int): Map[Int, target.SELoc] = {
// We must update both the keys of the map (the relative-indexes from the original SEVar)
// And also any values in the map which are stack located (SELocS), which are also indexed relatively
val m1 = remaps.map { case (k, loc) => (n + k, shiftLoc(loc, n)) }
// And create mappings for the `n` new stack items
val m2 = (1 to n).map(i => (i, s.SELocS(i)))
val m2 = (1 to n).map(i => (i, target.SELocS(i)))
m1 ++ m2
}
private[this] def shiftLoc(loc: s.SELoc, n: Int): s.SELoc = loc match {
case s.SELocS(i) => s.SELocS(i + n)
case s.SELocA(_) | s.SELocF(_) => loc
private[this] def shiftLoc(loc: target.SELoc, n: Int): target.SELoc = loc match {
case target.SELocS(i) => target.SELocS(i + n)
case target.SELocA(_) | target.SELocF(_) => loc
}
/** Compute the free variables in a speedy expression.
* The returned free variables are de bruijn indices
* adjusted to the stack of the caller.
*/
private[this] def freeVars(expr: s.SExpr, initiallyBound: Int): Set[Int] = {
def go(expr: s.SExpr, bound: Int, free: Set[Int]): Set[Int] =
private[this] def freeVars(expr: source.SExpr, initiallyBound: Int): Set[Int] = {
def go(expr: source.SExpr, bound: Int, free: Set[Int]): Set[Int] =
expr match {
case s.SEVar(i) =>
case source.SEVar(i) =>
if (i > bound) free + (i - bound) else free /* adjust to caller's environment */
case _: s.SEVal => free
case _: s.SEBuiltin => free
case _: s.SEValue => free
case _: s.SEBuiltinRecursiveDefinition => free
case s.SELocation(_, body) =>
case _: source.SEVal => free
case _: source.SEBuiltin => free
case _: source.SEValue => free
case _: source.SEBuiltinRecursiveDefinition => free
case source.SELocation(_, body) =>
go(body, bound, free)
case s.SEAppGeneral(fun, args) =>
case source.SEAppGeneral(fun, args) =>
args.foldLeft(go(fun, bound, free))((acc, arg) => go(arg, bound, acc))
case s.SEAbs(n, body) =>
case source.SEAbs(n, body) =>
go(body, bound + n, free)
case s.SECase(scrut, alts) =>
alts.foldLeft(go(scrut, bound, free)) { case (acc, s.SCaseAlt(pat, body)) =>
case source.SECase(scrut, alts) =>
alts.foldLeft(go(scrut, bound, free)) { case (acc, source.SCaseAlt(pat, body)) =>
val n = pat.numArgs
go(body, bound + n, acc)
}
case s.SELet(bounds, body) =>
case source.SELet(bounds, body) =>
bounds.zipWithIndex.foldLeft(go(body, bound + bounds.length, free)) {
case (acc, (expr, idx)) => go(expr, bound + idx, acc)
}
case s.SELabelClosure(_, expr) =>
case source.SELabelClosure(_, expr) =>
go(expr, bound, free)
case s.SETryCatch(body, handler) =>
case source.SETryCatch(body, handler) =>
go(body, bound, go(handler, 1 + bound, free))
case s.SEScopeExercise(body) =>
case source.SEScopeExercise(body) =>
go(body, bound, free)
case _: s.SELoc | _: s.SEMakeClo | _: s.SEDamlException | _: s.SEImportValue |
_: s.SELet1General =>
case _: source.SEDamlException | _: source.SEImportValue | _: source.SELet1General =>
throw CompilationError(s"freeVars: unexpected $expr")
}

View File

@ -25,18 +25,22 @@ package speedy
* Stage 3 is in Anf.scala
* Stage 4 is in ValidateCompilation.scala
*
* During Stage3 (ANF transformation), we move from this type (SExpr0) to SExpr,
* and so have the expression form suitable for execution on a speedy machine.
* During Stage2 (Closure Conversion), we move from SExpr0 to SExpr1,
* During Stage3 (ANF transformation), we move from SExpr1 to SExpr.
*
* Here is a summary of the differences between SExp0 and SExpr:
* Summary of which constructors are contained by: SExp0, SExpr1 and SExpr:
*
* - Constructors in both: SEAppGeneral, SEBuiltin, SEBuiltinRecursiveDefinition,
* SEDamlException, SEImportValue, SELabelClosure, SELet1General, SELocA, SELocF,
* SELocS, SELocation, SEMakeClo, SEScopeExercise, SETryCatch, SEVal, SEValue,
* - In SExpr{0,1,} (everywhere): SEAppGeneral, SEBuiltin, SEBuiltinRecursiveDefinition,
* SEDamlException, SEImportValue, SELabelClosure, SELet1General, SELocation,
* SEScopeExercise, SETryCatch, SEVal, SEValue,
*
* - Only in SExpr0: SEAbs, SECase, SELet, SEVar
* - In SExpr0: SEAbs, SEVar
*
* - Only in SExpr: SEAppAtomicFun, SEAppAtomicGeneral, SEAppAtomicSaturatedBuiltin,
* - In SExpr{0,1}: SECase, SELet
*
* - In SExpr{1,}: SELocA, SELocF, SELocS, SEMakeClo,
*
* - In SExpr: SEAppAtomicFun, SEAppAtomicGeneral, SEAppAtomicSaturatedBuiltin,
* SECaseAtomic, SELet1Builtin, SELet1BuiltinArithmetic
*/
@ -52,15 +56,13 @@ private[speedy] object SExpr0 {
sealed abstract class SExpr extends Product with Serializable
sealed abstract class SExprAtomic extends SExpr
/** Reference to a variable. 'index' is the 1-based de Bruijn index,
* that is, SEVar(1) points to the nearest enclosing variable binder.
* which could be an SELam, SELet, or a binding variant of SECasePat.
* https://en.wikipedia.org/wiki/De_Bruijn_index
* This expression form is only allowed prior to closure conversion
*/
final case class SEVar(index: Int) extends SExprAtomic
final case class SEVar(index: Int) extends SExpr
/** Reference to a value. On first lookup the evaluated expression is
* stored in 'cached'.
@ -68,16 +70,14 @@ private[speedy] object SExpr0 {
final case class SEVal(ref: SDefinitionRef) extends SExpr
/** Reference to a builtin function */
final case class SEBuiltin(b: SBuiltin) extends SExprAtomic
final case class SEBuiltin(b: SBuiltin) extends SExpr
/** A pre-computed value, usually primitive literal, e.g. integer, text, boolean etc. */
final case class SEValue(v: SValue) extends SExprAtomic
final case class SEValue(v: SValue) extends SExpr
object SEValue extends SValueContainer[SEValue]
/** Function application:
* General case: 'fun' and 'args' are any kind of expression
*/
/** Function application */
final case class SEAppGeneral(fun: SExpr, args: Array[SExpr]) extends SExpr with SomeArrayEquals
object SEApp {
@ -86,10 +86,7 @@ private[speedy] object SExpr0 {
}
}
/** Lambda abstraction. Transformed into SEMakeClo in lambda lifting.
* NOTE(JM): Compilation done in two passes so that closure conversion
* can be written against this simplified expression type.
*/
/** Lambda abstraction. Transformed to SEMakeClo during closure conversion */
final case class SEAbs(arity: Int, body: SExpr) extends SExpr
object SEAbs {
@ -100,29 +97,6 @@ private[speedy] object SExpr0 {
val identity: SEAbs = SEAbs(1, SEVar(1))
}
/** Closure creation. Create a new closure object storing the free variables
* in 'body'.
*/
final case class SEMakeClo(fvs: Array[SELoc], arity: Int, body: SExpr)
extends SExpr
with SomeArrayEquals
/** SELoc -- Reference to the runtime location of a variable.
*
* This is the closure-converted form of SEVar. There are three sub-forms, with sufffix:
* S/A/F, indicating [S]tack, [A]argument, or [F]ree variable captured by a closure.
*/
sealed abstract class SELoc extends SExprAtomic
// SELocS -- variable is located on the stack (SELet & binding forms of SECasePat)
final case class SELocS(n: Int) extends SELoc
// SELocS -- variable is located in the args array of the application
final case class SELocA(n: Int) extends SELoc
// SELocF -- variable is located in the free-vars array of the closure being applied
final case class SELocF(n: Int) extends SELoc
/** Pattern match. */
final case class SECase(scrut: SExpr, alts: Array[SCaseAlt]) extends SExpr with SomeArrayEquals
@ -180,6 +154,6 @@ private[speedy] object SExpr0 {
// TODO: simplify here: There is only kind of SEBuiltinRecursiveDefinition! - EqualList
final case class SEBuiltinRecursiveDefinition(ref: runTime.SEBuiltinRecursiveDefinition.Reference)
extends SExprAtomic
extends SExpr
}

View File

@ -0,0 +1,125 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf
package speedy
/** SExpr1 -- AST for the speedy compiler pipeline. (after closure conversion phase)
*
* These are *not* the expression forms which run on the speedy machine. See SExpr.
*/
import com.daml.lf.data.Ref._
import com.daml.lf.language.Ast
import com.daml.lf.value.{Value => V}
import com.daml.lf.speedy.SValue._
import com.daml.lf.speedy.SExpr.{SDefinitionRef, SCasePat}
import com.daml.lf.speedy.{SExpr => runTime}
private[speedy] object SExpr1 {
sealed abstract class SExpr extends Product with Serializable
sealed abstract class SExprAtomic extends SExpr
/** Reference to a value. On first lookup the evaluated expression is
* stored in 'cached'.
*/
final case class SEVal(ref: SDefinitionRef) extends SExpr
/** Reference to a builtin function */
final case class SEBuiltin(b: SBuiltin) extends SExprAtomic
/** A pre-computed value, usually primitive literal, e.g. integer, text, boolean etc. */
final case class SEValue(v: SValue) extends SExprAtomic
object SEValue extends SValueContainer[SEValue]
/** Function application:
* General case: 'fun' and 'args' are any kind of expression
*/
final case class SEAppGeneral(fun: SExpr, args: Array[SExpr]) extends SExpr with SomeArrayEquals
object SEApp {
def apply(fun: SExpr, args: Array[SExpr]): SExpr = {
SEAppGeneral(fun, args)
}
}
/** Closure creation. Create a new closure object storing the free variables
* in 'body'.
*/
final case class SEMakeClo(fvs: Array[SELoc], arity: Int, body: SExpr)
extends SExpr
with SomeArrayEquals
/** SELoc -- Reference to the runtime location of a variable.
*
* This is the closure-converted form of SEVar. There are three sub-forms, with sufffix:
* S/A/F, indicating [S]tack, [A]argument, or [F]ree variable captured by a closure.
*/
sealed abstract class SELoc extends SExprAtomic
// SELocS -- variable is located on the stack (SELet & binding forms of SECasePat)
final case class SELocS(n: Int) extends SELoc
// SELocS -- variable is located in the args array of the application
final case class SELocA(n: Int) extends SELoc
// SELocF -- variable is located in the free-vars array of the closure being applied
final case class SELocF(n: Int) extends SELoc
/** Pattern match. */
final case class SECase(scrut: SExpr, alts: Array[SCaseAlt]) extends SExpr with SomeArrayEquals
/** A let-expression with a single RHS
* This form only exists *during* the ANF transformation, but not when the ANF
* transformation is finished.
*/
final case class SELet1General(rhs: SExpr, body: SExpr) extends SExpr with SomeArrayEquals
/** A non-recursive, non-parallel let block.
* It is used as an intermediary data structure by the compiler to
* mitigate stack overflow issues, but are later exploded into
* [[SELet1General]] and [[SELet1Builtin]] by the ANF transformation.
*/
final case class SELet(bounds: List[SExpr], body: SExpr) extends SExpr
/** Location annotation. When encountered the location is stored in the 'lastLocation'
* variable of the machine. When commit is begun the location is stored in 'commitLocation'.
*/
final case class SELocation(loc: Location, expr: SExpr) extends SExpr
/** This is used only during profiling. When a package is compiled with
* profiling enabled, the right hand sides of top-level and let bindings,
* lambdas and some builtins are wrapped into [[SELabelClosure]]. During
* runtime, if the value resulting from evaluating [[expr]] is a
* (partially applied) closure, the label of the closure is set to the
* [[label]] given here.
* See [[com.daml.lf.speedy.Profile]] for an explanation why we use
* [[AnyRef]] for the label.
*/
final case class SELabelClosure(label: Profile.Label, expr: SExpr) extends SExpr
/** We cannot crash in the engine call back.
* Rather, we set the control to this expression and then crash when executing.
*/
final case class SEDamlException(error: interpretation.Error) extends SExpr
final case class SEImportValue(typ: Ast.Type, value: V) extends SExpr
/** Exception handler */
final case class SETryCatch(body: SExpr, handler: SExpr) extends SExpr
/** Exercise scope (begin..end) */
final case class SEScopeExercise(body: SExpr) extends SExpr
/** Case alternative. If the 'pattern' matches, then the environment is accordingly
* extended and 'body' is evaluated.
*/
final case class SCaseAlt(pattern: SCasePat, body: SExpr)
final case class SEBuiltinRecursiveDefinition(ref: runTime.SEBuiltinRecursiveDefinition.Reference)
extends SExprAtomic
}

View File

@ -9,7 +9,7 @@ package com.daml.lf.speedy
*/
import com.daml.lf.speedy.SValue._
import com.daml.lf.speedy.{SExpr => t}
import com.daml.lf.speedy.SExpr._
import scala.annotation.tailrec
@ -17,7 +17,7 @@ private[lf] object ValidateCompilation {
case class CompilationError(error: String) extends RuntimeException(error, null, true, false)
private[speedy] def validateCompilation(expr0: t.SExpr): t.SExpr = {
private[speedy] def validateCompilation(expr0: SExpr): SExpr = {
def goV(v: SValue): Unit =
v match {
@ -37,73 +37,73 @@ private[lf] object ValidateCompilation {
throw CompilationError("validate: unexpected s.SEValue")
}
def goBody(maxS: Int, maxA: Int, maxF: Int): t.SExpr => Unit = {
def goBody(maxS: Int, maxA: Int, maxF: Int): SExpr => Unit = {
def goLoc(loc: t.SELoc) = loc match {
case t.SELocS(i) =>
def goLoc(loc: SELoc) = loc match {
case SELocS(i) =>
if (i < 1 || i > maxS)
throw CompilationError(s"validate: SELocS: index $i out of range ($maxS..1)")
case t.SELocA(i) =>
case SELocA(i) =>
if (i < 0 || i >= maxA)
throw CompilationError(s"validate: SELocA: index $i out of range (0..$maxA-1)")
case t.SELocF(i) =>
case SELocF(i) =>
if (i < 0 || i >= maxF)
throw CompilationError(s"validate: SELocF: index $i out of range (0..$maxF-1)")
}
def go(expr: t.SExpr): Unit = expr match {
case loc: t.SELoc => goLoc(loc)
case _: t.SEVal => ()
case _: t.SEBuiltin => ()
case _: t.SEBuiltinRecursiveDefinition => ()
case t.SEValue(v) => goV(v)
case t.SEAppAtomicGeneral(fun, args) =>
def go(expr: SExpr): Unit = expr match {
case loc: SELoc => goLoc(loc)
case _: SEVal => ()
case _: SEBuiltin => ()
case _: SEBuiltinRecursiveDefinition => ()
case SEValue(v) => goV(v)
case SEAppAtomicGeneral(fun, args) =>
go(fun)
args.foreach(go)
case t.SEAppAtomicSaturatedBuiltin(_, args) =>
case SEAppAtomicSaturatedBuiltin(_, args) =>
args.foreach(go)
case t.SEAppGeneral(fun, args) =>
case SEAppGeneral(fun, args) =>
go(fun)
args.foreach(go)
case t.SEAppAtomicFun(fun, args) =>
case SEAppAtomicFun(fun, args) =>
go(fun)
args.foreach(go)
case t.SEMakeClo(fvs, n, body) =>
case SEMakeClo(fvs, n, body) =>
fvs.foreach(goLoc)
goBody(0, n, fvs.length)(body)
case t.SECaseAtomic(scrut, alts) =>
case SECaseAtomic(scrut, alts) =>
go(scrut)
alts.foreach { case t.SCaseAlt(pat, body) =>
alts.foreach { case SCaseAlt(pat, body) =>
val n = pat.numArgs
goBody(maxS + n, maxA, maxF)(body)
}
case _: t.SELet1General => goLets(maxS)(expr)
case _: t.SELet1Builtin => goLets(maxS)(expr)
case _: t.SELet1BuiltinArithmetic => goLets(maxS)(expr)
case t.SELocation(_, body) =>
case _: SELet1General => goLets(maxS)(expr)
case _: SELet1Builtin => goLets(maxS)(expr)
case _: SELet1BuiltinArithmetic => goLets(maxS)(expr)
case SELocation(_, body) =>
go(body)
case t.SELabelClosure(_, expr) =>
case SELabelClosure(_, expr) =>
go(expr)
case t.SETryCatch(body, handler) =>
case SETryCatch(body, handler) =>
go(body)
goBody(maxS + 1, maxA, maxF)(handler)
case t.SEScopeExercise(body) =>
case SEScopeExercise(body) =>
go(body)
case _: t.SEDamlException | _: t.SEImportValue =>
case _: SEDamlException | _: SEImportValue =>
throw CompilationError(s"validate: unexpected $expr")
}
@tailrec
def goLets(maxS: Int)(expr: t.SExpr): Unit = {
def goLets(maxS: Int)(expr: SExpr): Unit = {
def go = goBody(maxS, maxA, maxF)
expr match {
case t.SELet1General(rhs, body) =>
case SELet1General(rhs, body) =>
go(rhs)
goLets(maxS + 1)(body)
case t.SELet1Builtin(_, args, body) =>
case SELet1Builtin(_, args, body) =>
args.foreach(go)
goLets(maxS + 1)(body)
case t.SELet1BuiltinArithmetic(_, args, body) =>
case SELet1BuiltinArithmetic(_, args, body) =>
args.foreach(go)
goLets(maxS + 1)(body)
case expr =>

View File

@ -7,8 +7,8 @@ import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import com.daml.lf.speedy.{SExpr0 => s}
import com.daml.lf.speedy.{SExpr => t}
import com.daml.lf.speedy.{SExpr1 => source}
import com.daml.lf.speedy.{SExpr => target}
import com.daml.lf.speedy.SValue._
import com.daml.lf.speedy.SBuiltin._
import com.daml.lf.speedy.Anf.flattenToAnf
@ -129,16 +129,16 @@ class AnfTest extends AnyWordSpec with Matchers {
"error applied to 1 arg" should {
"be transformed to ANF as expected" in {
val original = slam(1, s.SEApp(s.SEBuiltin(SBError), Array(sarg0)))
val expected = lam(1, t.SEAppAtomicSaturatedBuiltin(SBError, Array(arg0)))
val original = slam(1, source.SEApp(source.SEBuiltin(SBError), Array(sarg0)))
val expected = lam(1, target.SEAppAtomicSaturatedBuiltin(SBError, Array(arg0)))
testTransform(original, expected)
}
}
"error (over) applied to 2 arg" should {
"be transformed to ANF as expected" in {
val original = slam(2, s.SEApp(s.SEBuiltin(SBError), Array(sarg0, sarg1)))
val expected = lam(2, t.SEAppAtomicFun(t.SEBuiltin(SBError), Array(arg0, arg1)))
val original = slam(2, source.SEApp(source.SEBuiltin(SBError), Array(sarg0, sarg1)))
val expected = lam(2, target.SEAppAtomicFun(target.SEBuiltin(SBError), Array(arg0, arg1)))
testTransform(original, expected)
}
}
@ -194,47 +194,60 @@ class AnfTest extends AnyWordSpec with Matchers {
}
// expression builders
private def lam(n: Int, body: t.SExpr): t.SExpr = t.SEMakeClo(Array(), n, body)
private def clo1(fv: t.SELoc, n: Int, body: t.SExpr): t.SExpr = t.SEMakeClo(Array(fv), n, body)
private def lam(n: Int, body: target.SExpr): target.SExpr = target.SEMakeClo(Array(), n, body)
private def clo1(fv: target.SELoc, n: Int, body: target.SExpr): target.SExpr =
target.SEMakeClo(Array(fv), n, body)
private def app2n(func: t.SExprAtomic, arg1: t.SExpr, arg2: t.SExpr): t.SExpr =
t.SEAppAtomicFun(func, Array(arg1, arg2))
private def app2n(
func: target.SExprAtomic,
arg1: target.SExpr,
arg2: target.SExpr,
): target.SExpr =
target.SEAppAtomicFun(func, Array(arg1, arg2))
// anf builders
private def let1(rhs: t.SExpr, body: t.SExpr): t.SExpr =
t.SELet1General(rhs, body)
private def let1(rhs: target.SExpr, body: target.SExpr): target.SExpr =
target.SELet1General(rhs, body)
private def let1b2(
op: SBuiltinPure,
arg1: t.SExprAtomic,
arg2: t.SExprAtomic,
body: t.SExpr,
): t.SExpr =
t.SELet1Builtin(op, Array(arg1, arg2), body)
arg1: target.SExprAtomic,
arg2: target.SExprAtomic,
body: target.SExpr,
): target.SExpr =
target.SELet1Builtin(op, Array(arg1, arg2), body)
private def let1b2(
op: SBuiltinArithmetic,
arg1: t.SExprAtomic,
arg2: t.SExprAtomic,
body: t.SExpr,
): t.SExpr =
t.SELet1BuiltinArithmetic(op, Array(arg1, arg2), body)
arg1: target.SExprAtomic,
arg2: target.SExprAtomic,
body: target.SExpr,
): target.SExpr =
target.SELet1BuiltinArithmetic(op, Array(arg1, arg2), body)
private def appa(func: t.SExprAtomic, arg: t.SExprAtomic): t.SExpr =
t.SEAppAtomicGeneral(func, Array(arg))
private def appa(func: target.SExprAtomic, arg: target.SExprAtomic): target.SExpr =
target.SEAppAtomicGeneral(func, Array(arg))
private def binopa(op: SBuiltinArithmetic, x: t.SExprAtomic, y: t.SExprAtomic): t.SExpr =
t.SEAppAtomicSaturatedBuiltin(op, Array(x, y))
private def binopa(
op: SBuiltinArithmetic,
x: target.SExprAtomic,
y: target.SExprAtomic,
): target.SExpr =
target.SEAppAtomicSaturatedBuiltin(op, Array(x, y))
private def itea(i: t.SExprAtomic, th: t.SExpr, e: t.SExpr): t.SExpr =
t.SECaseAtomic(i, Array(t.SCaseAlt(patTrue, th), t.SCaseAlt(patFalse, e)))
private def itea(i: target.SExprAtomic, th: target.SExpr, e: target.SExpr): target.SExpr =
target.SECaseAtomic(i, Array(target.SCaseAlt(patTrue, th), target.SCaseAlt(patFalse, e)))
// true/false case-patterns
private def patTrue: t.SCasePat =
t.SCPVariant(Identifier.assertFromString("P:M:bool"), IdString.Name.assertFromString("True"), 1)
private def patTrue: target.SCasePat =
target.SCPVariant(
Identifier.assertFromString("P:M:bool"),
IdString.Name.assertFromString("True"),
1,
)
private def patFalse: t.SCasePat =
t.SCPVariant(
private def patFalse: target.SCasePat =
target.SCPVariant(
Identifier.assertFromString("P:M:bool"),
IdString.Name.assertFromString("False"),
2,
@ -242,46 +255,48 @@ class AnfTest extends AnyWordSpec with Matchers {
// atoms
private def arg0 = t.SELocA(0)
private def arg1 = t.SELocA(1)
private def arg2 = t.SELocA(2)
private def arg3 = t.SELocA(3)
private def free0 = t.SELocF(0)
private def stack1 = t.SELocS(1)
private def stack2 = t.SELocS(2)
private def stack3 = t.SELocS(3)
private def arg0 = target.SELocA(0)
private def arg1 = target.SELocA(1)
private def arg2 = target.SELocA(2)
private def arg3 = target.SELocA(3)
private def free0 = target.SELocF(0)
private def stack1 = target.SELocS(1)
private def stack2 = target.SELocS(2)
private def stack3 = target.SELocS(3)
private def num0 = num(0)
private def num1 = num(1)
private def num2 = num(2)
private def num(n: Long): t.SExprAtomic = t.SEValue(SInt64(n))
private def num(n: Long): target.SExprAtomic = target.SEValue(SInt64(n))
// We have different expression types before/after the ANF transform, so we different constructors.
// Use "s" (for "source") as a prefix to distinguish.
private def slam(n: Int, body: s.SExpr): s.SExpr = s.SEMakeClo(Array(), n, body)
private def sclo1(fv: s.SELoc, n: Int, body: s.SExpr): s.SExpr = s.SEMakeClo(Array(fv), n, body)
private def sapp(func: s.SExpr, arg: s.SExpr): s.SExpr = s.SEAppGeneral(func, Array(arg))
private def sbinop(op: SBuiltinPure, x: s.SExpr, y: s.SExpr): s.SExpr =
s.SEApp(s.SEBuiltin(op), Array(x, y))
private def sbinop(op: SBuiltinArithmetic, x: s.SExpr, y: s.SExpr): s.SExpr =
s.SEApp(s.SEBuiltin(op), Array(x, y))
private def sapp2(func: s.SExpr, arg1: s.SExpr, arg2: s.SExpr): s.SExpr =
s.SEAppGeneral(func, Array(arg1, arg2))
private def site(i: s.SExpr, t: s.SExpr, e: s.SExpr): s.SExpr =
s.SECase(i, Array(s.SCaseAlt(patTrue, t), s.SCaseAlt(patFalse, e)))
private def sarg0 = s.SELocA(0)
private def sarg1 = s.SELocA(1)
private def sarg2 = s.SELocA(2)
private def sarg3 = s.SELocA(3)
private def sfree0 = s.SELocF(0)
private def slam(n: Int, body: source.SExpr): source.SExpr = source.SEMakeClo(Array(), n, body)
private def sclo1(fv: source.SELoc, n: Int, body: source.SExpr): source.SExpr =
source.SEMakeClo(Array(fv), n, body)
private def sapp(func: source.SExpr, arg: source.SExpr): source.SExpr =
source.SEAppGeneral(func, Array(arg))
private def sbinop(op: SBuiltinPure, x: source.SExpr, y: source.SExpr): source.SExpr =
source.SEApp(source.SEBuiltin(op), Array(x, y))
private def sbinop(op: SBuiltinArithmetic, x: source.SExpr, y: source.SExpr): source.SExpr =
source.SEApp(source.SEBuiltin(op), Array(x, y))
private def sapp2(func: source.SExpr, arg1: source.SExpr, arg2: source.SExpr): source.SExpr =
source.SEAppGeneral(func, Array(arg1, arg2))
private def site(i: source.SExpr, t: source.SExpr, e: source.SExpr): source.SExpr =
source.SECase(i, Array(source.SCaseAlt(patTrue, t), source.SCaseAlt(patFalse, e)))
private def sarg0 = source.SELocA(0)
private def sarg1 = source.SELocA(1)
private def sarg2 = source.SELocA(2)
private def sarg3 = source.SELocA(3)
private def sfree0 = source.SELocF(0)
private def snum0 = snum(0)
private def snum1 = snum(1)
private def snum2 = snum(2)
private def snum(n: Long): s.SExprAtomic = s.SEValue(SInt64(n))
private def snum(n: Long): source.SExprAtomic = source.SEValue(SInt64(n))
// run a test...
private def testTransform(
original: s.SExpr,
expected: t.SExpr,
original: source.SExpr,
expected: target.SExpr,
show: Boolean = false,
): Assertion = {
val transformed = flattenToAnf(original)
@ -293,7 +308,7 @@ class AnfTest extends AnyWordSpec with Matchers {
transformed shouldBe (expected)
}
private def pp(e: t.SExpr): String = {
private def pp(e: target.SExpr): String = {
prettySExpr(0)(e).render(80)
}