Ensure stack-safety during closure-conversion. (#11778)

ClosureConversion -> Suffix with "Old"

CHANGELOG_BEGIN
CHANGELOG_END

ClosureConversion old-vs-new diff check

ClosureConversionNew, first cut. All tests in SBuiltinTest work.

In addition we change some Array --> List in SExpr1 (for human pp).
And we throw away ClosureConversionDup.

adapt AnfTest from Array to List change for SExpr1

all tests pass in daml-lf/interpreter

remove SExpr0.SELet1General

reorder things

testing for stack-safety of closure conversion

file/class renames

improve naming

pass cont as sep arg to commit (move out of Up/Down)

comment stack-safe closure conversion

fix bug: failed to use env1

fix 2x unmoored doc comment

comment stack safety testing

Remove old closure-conversion code & diff-check between old/new.

loose StackSafe suffix on ClosureConversion class/file

rename StackSafetyTest.scala to ClosureConversionTest.scala

prefer "sealed abstract class" to "sealed trait"

fvs.zipWithIndex --> fvs.view.zipWithIndex

(SExpr1) SEAppGeneral -> SEApp; prefer List to Array in SEApp/SECase

prefer xs.toArray to Array(xs: _*)

access SExpr0 via "source."

two more .view

improve comment and fix typo

link to Issue

switch to a continuation stack; avoids nesting in the Cont type
This commit is contained in:
nickchapman-da 2021-11-24 11:47:39 +00:00 committed by GitHub
parent e63c80dddd
commit 970243dd46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 479 additions and 184 deletions

View File

@ -93,16 +93,16 @@ object PlaySpeedy {
// The trailing numeral is the number of args at the scala level
def decrement1(x: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), Array(x, SEValue(SInt64(1))))
def decrement1(x: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, SEValue(SInt64(1))))
val decrement = SEAbs(1, decrement1(SEVar(1)))
def subtract2(x: SExpr, y: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), Array(x, y))
def subtract2(x: SExpr, y: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, y))
val subtract = SEAbs(2, subtract2(SEVar(2), SEVar(1)))
def twice2(f: SExpr, x: SExpr): SExpr = SEApp(f, Array(SEApp(f, Array(x))))
def twice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(x))))
val twice = SEAbs(2, twice2(SEVar(2), SEVar(1)))
def thrice2(f: SExpr, x: SExpr): SExpr = SEApp(f, Array(SEApp(f, Array(SEApp(f, Array(x))))))
def thrice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(SEApp(f, List(x))))))
val thrice = SEAbs(2, thrice2(SEVar(2), SEVar(1)))
val examples = List(
@ -119,30 +119,30 @@ object PlaySpeedy {
(
"subF", //88-55
33,
SEApp(subtract, Array(num(88), num(55))),
SEApp(subtract, List(num(88), num(55))),
),
(
"thrice", // thrice (\x -> x - 1) 0
-3,
SEApp(thrice, Array(decrement, num(0))),
SEApp(thrice, List(decrement, num(0))),
),
(
"thrice-thrice", //thrice thrice (\x -> x - 1) 0
-27,
SEApp(thrice, Array(thrice, decrement, num(0))),
SEApp(thrice, List(thrice, decrement, num(0))),
),
(
"free", // let (a,b,c) = (30,100,21) in twice (\x -> x - (a-c)) b
82,
SELet1General(
num(30),
SELet1General(
num(100),
SELet1General(
num(21),
SELet(
List(num(30)),
SELet(
List(num(100)),
SELet(
List(num(21)),
SEApp(
twice,
Array(SEAbs(1, subtract2(SEVar(1), subtract2(SEVar(4), SEVar(2)))), SEVar(2)),
List(SEAbs(1, subtract2(SEVar(1), subtract2(SEVar(4), SEVar(2)))), SEVar(2)),
),
), //100
),

View File

@ -301,7 +301,7 @@ private[lf] object Anf {
case source.SEVal(x) => Bounce(() => transform(depth, target.SEVal(x), k))
case source.SEAppGeneral(func, args) =>
case source.SEApp(func, args) =>
// It's safe to perform ANF if the func-expression has no effects when evaluated.
val safeFunc =
func match {
@ -311,22 +311,22 @@ private[lf] object Anf {
}
// It's also safe to perform ANF for applications of a single argument.
if (safeFunc || args.size == 1) {
transformMultiApp[A](depth, env, func, args, k)(transform)
transformMultiApp[A](depth, env, func, args.toArray, k)(transform)
} else {
transformMultiAppSafely[A](depth, env, func, args, k)(transform)
transformMultiAppSafely[A](depth, env, func, args.toArray, k)(transform)
}
case source.SEMakeClo(fvs0, arity, body0) =>
val fvs = fvs0.map((loc) => makeRelativeL(depth)(makeAbsoluteL(env, loc)))
val body = flattenToAnfInternal(body0).wrapped
Bounce(() => transform(depth, target.SEMakeClo(fvs, arity, body), k))
Bounce(() => transform(depth, target.SEMakeClo(fvs.toArray, arity, body), k))
case source.SECase(scrut, alts0) => {
Bounce(() =>
atomizeExp(depth, env, scrut, k) { (depth, scrut, txK) =>
val scrut1 = makeRelativeA(depth)(scrut)
Bounce(() =>
flattenAlts(depth, env, alts0) { alts =>
flattenAlts(depth, env, alts0.toArray) { alts =>
Bounce(() => transform(depth, target.SECaseAtomic(scrut1, alts), txK))
}
)

View File

@ -3,149 +3,323 @@
package com.daml.lf.speedy
/** Closure Conversion (Phase of the speedy compiler pipeline)
/** Closure Conversion (Phase of the speedy compiler pipeline)
*
* This compilation phase transforms from SExpr0 to SExpr1.
* This compilation phase transforms from SExpr0 to SExpr1.
*/
import com.daml.lf.data.Ref
import com.daml.lf.speedy.SExpr.SCasePat
import com.daml.lf.speedy.{SExpr0 => source}
import com.daml.lf.speedy.{SExpr1 => target}
import scala.annotation.tailrec
private[speedy] object ClosureConversion {
case class CompilationError(error: String) extends RuntimeException(error, null, true, false)
private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {
/** Convert abstractions in a speedy expression into
* explicit closure creations.
* This step computes the free variables in an abstraction
* body, then translates the references in the body into
* references to the immediate top of the argument stack,
* and changes the abstraction into a closure creation node
* describing the free variables that need to be captured.
*
* For example:
* SELet(..two-bindings..) in
* SEAbs(2,
* SEVar(4) .. [reference to first let-bound variable]
* SEVar(2)) [reference to first function-arg]
* =>
* SELet(..two-bindings..) in
* SEMakeClo(
* Array(SELocS(2)), [capture the first let-bound variable, from the stack]
* 2,
* SELocF(0) .. [reference the first let-bound variable via the closure]
* SELocA(0)) [reference the first function arg]
*/
// TODO: Recode the 'Env' management to avoid the polynomial-complexity of 'shift'. Issue #11830
case class Env(mapping: Map[Int, target.SELoc]) {
// TODO: Introduce a new type expression for the result of closure conversion
private[speedy] def closureConvert(expr: source.SExpr): target.SExpr = {
closureConvert(Map.empty, expr)
}
def lookup(i: Int): target.SELoc =
mapping.get(i) match {
case Some(loc) => loc
case None =>
throw sys.error(s"lookup($i),in:$mapping")
}
private def closureConvert(remaps: Map[Int, target.SELoc], expr: source.SExpr): target.SExpr = {
// remaps is a function which maps the relative offset from variables (SEVar) to their runtime location
// The Map must contain a binding for every variable referenced.
// The Map is consulted when translating variable references (SEVar) and free variables of an abstraction (SEAbs)
def remap(i: Int): target.SELoc =
remaps.get(i) match {
case Some(loc) => loc
case None =>
throw CompilationError(s"remap($i),remaps=$remaps")
def shift(n: Int): Env = {
def shiftLoc(loc: target.SELoc, n: Int): target.SELoc = loc match {
case target.SELocS(i) => target.SELocS(i + n)
case target.SELocA(_) | target.SELocF(_) => loc
}
// We must update both the keys of the map (the relative-indexes from the original SEVar)
// And also any values in the map which are stack located (SELocS), which are also indexed relatively
val m1 = mapping.map { case (k, loc) => (n + k, shiftLoc(loc, n)) }
// And create mappings for the `n` new stack items
val m2 = (1 to n).view.map(i => (i, target.SELocS(i)))
Env(m1 ++ m2)
}
expr match {
case source.SEVar(i) => remap(i)
case source.SEVal(ref) => target.SEVal(ref)
case source.SEBuiltin(b) => target.SEBuiltin(b)
case source.SEValue(v) => target.SEValue(v)
case source.SELocation(loc, body) =>
target.SELocation(loc, closureConvert(remaps, body))
}
case source.SEAbs(0, _) =>
throw CompilationError("empty SEAbs")
case source.SEAbs(arity, body) =>
val fvs = freeVars(body, arity).toList.sorted
val newRemapsF: Map[Int, target.SELoc] = fvs.zipWithIndex.map { case (orig, i) =>
object Env {
def apply(): Env = {
Env(Map.empty)
}
def absBody(arity: Int, fvs: List[Int]): Env = {
val newRemapsF: Map[Int, target.SELoc] = fvs.view.zipWithIndex.map { case (orig, i) =>
(orig + arity) -> target.SELocF(i)
}.toMap
val newRemapsA = (1 to arity).map { case i =>
val newRemapsA = (1 to arity).view.map { case i =>
i -> target.SELocA(arity - i)
}
// The keys in newRemapsF and newRemapsA are disjoint
val newBody = closureConvert(newRemapsF ++ newRemapsA, body)
target.SEMakeClo(fvs.map(remap).toArray, arity, newBody)
case source.SEAppGeneral(fun, args) =>
val newFun = closureConvert(remaps, fun)
val newArgs = args.map(closureConvert(remaps, _))
target.SEApp(newFun, newArgs)
case source.SECase(scrut, alts) =>
target.SECase(
closureConvert(remaps, scrut),
alts.map { case source.SCaseAlt(pat, body) =>
val n = pat.numArgs
target.SCaseAlt(
pat,
closureConvert(shift(remaps, n), body),
)
},
)
case source.SELet(bounds, body) =>
target.SELet(
bounds.zipWithIndex.map { case (b, i) =>
closureConvert(shift(remaps, i), b)
},
closureConvert(shift(remaps, bounds.length), body),
)
case source.SETryCatch(body, handler) =>
target.SETryCatch(
closureConvert(remaps, body),
closureConvert(shift(remaps, 1), handler),
)
case source.SEScopeExercise(body) =>
target.SEScopeExercise(closureConvert(remaps, body))
case source.SELabelClosure(label, expr) =>
target.SELabelClosure(label, closureConvert(remaps, expr))
case source.SELet1General(bound, body) =>
target.SELet1General(closureConvert(remaps, bound), closureConvert(shift(remaps, 1), body))
val m1 = newRemapsF ++ newRemapsA
Env(m1)
}
}
/** Closure-conversion, Traversal:
*
* To ensure stack-safety, the input expression is traversed by a single tail-recursive 'loop'.
* During the 'Traversal', we are either:
* - going 'Down' a source expression (subtree), with an 'Env' for context, or
* - coming 'Up' with a target expression (result for a subtree)
*
* In both cases we have a continuation ('List[Cont]') argument which says how to proceed.
*/
sealed abstract class Traversal
object Traversal {
final case class Down(exp: source.SExpr, env: Env) extends Traversal
final case class Up(exp: target.SExpr) extends Traversal
}
import Traversal._
/** Closure Conversion, Cont:
*
* The multiple forms for a continuation describe how the result of transforming a
* sub-expression should be embedded in the continuing traversal. The continuation
* forms correspond to the source expression forms: specifically, the location of
* recursive expression instances (values of type SExpr).
*
* For expression forms with no recursive instance (i.e. SEVar, SEVal), there are
* no corresponding continuation forms.
*
* For expression forms with a single recursive instance (i.e. SELocation), there
* is a single continuation form: (Cont.Location).
*
* For expression forms with two recursive instances (i.e. SETryCatch), there are
* two corresponding continuation forms: (Cont.TryCatch1, Cont.TryCatch2).
*
* For the more complex expression forms containing a list of recursive instances
* (i.e. SEAppGeneral), the corresponding continuation forms are also more complex,
* but will generally have two cases (i.e. Cont.App1, Cont.App2), corresponding to
* the Nil/Cons cases of the list of recursive instances.
*
* And so on. In effect, 'Cont' is a zipper type for expressions.
*
* Another way to understand the continuation forms is by observing the presence of
* an 'env: Env' component indicates more source-expression processing to be done
* (generally with the components following the 'env'). Any components before the
* 'env' (or all components if there is no 'env') represent transform-(sub)-results
* which need combining into the final result.
*/
sealed abstract class Cont
object Cont {
final case class Location(loc: Ref.Location) extends Cont
final case class Abs(arity: Int, fvs: List[target.SELoc]) extends Cont
final case class App1(env: Env, args: List[source.SExpr]) extends Cont
final case class App2(
funDone: target.SExpr,
argsDone: List[target.SExpr],
env: Env,
args: List[source.SExpr],
) extends Cont
final case class Case1(env: Env, alts: List[source.SCaseAlt]) extends Cont
final case class Case2(
scrut: target.SExpr,
altsDone: List[target.SCaseAlt],
pat: SCasePat,
env: Env,
alts: List[source.SCaseAlt],
) extends Cont
final case class Let1(
boundsDone: List[target.SExpr],
env: Env,
bounds: List[source.SExpr],
body: source.SExpr,
) extends Cont
final case class Let2(
boundsDone: List[target.SExpr]
) extends Cont
final case class TryCatch1(
env: Env,
handler: source.SExpr,
) extends Cont
final case class TryCatch2(
body: target.SExpr
) extends Cont
final case object ScopeExercise extends Cont
final case class LabelClosure(label: Profile.Label) extends Cont
}
/* The entire traversal in performed by this single tail recursive 'loop' function.
*
* The 'loop' has two arguments:
* - The traversal item (Down/Up), and a continuation-stack 'conts'.
*
* The traversal is matched to see if we are going 'Down`, or 'Up.
* - When going 'Down', we perform case-analysis on the source-expression being traversed.
* - when going 'Up, we perform case-analysis on the continuation-stack.
* When the continuation-stack is empty, we are finished.
*/
@tailrec
def loop(traversal: Traversal, conts: List[Cont]): target.SExpr = {
traversal match {
// Going Down: match on expression form...
case Down(exp, env) =>
exp match {
case source.SEVar(i) => loop(Up(env.lookup(i)), conts)
case source.SEVal(x) => loop(Up(target.SEVal(x)), conts)
case source.SEBuiltin(x) => loop(Up(target.SEBuiltin(x)), conts)
case source.SEValue(x) => loop(Up(target.SEValue(x)), conts)
case source.SELocation(loc, body) =>
loop(Down(body, env), Cont.Location(loc) :: conts)
case source.SEAbs(arity, body) =>
val fvsAsListInt = freeVars(body, arity).toList.sorted
val fvs = fvsAsListInt.map(i => env.lookup(i))
loop(Down(body, Env.absBody(arity, fvsAsListInt)), Cont.Abs(arity, fvs) :: conts)
case source.SEApp(fun, args) =>
loop(Down(fun, env), Cont.App1(env, args) :: conts)
case source.SECase(scrut, alts) =>
loop(Down(scrut, env), Cont.Case1(env, alts) :: conts)
case source.SELet(bounds, body) =>
bounds match {
case Nil =>
loop(Down(body, env), Cont.Let2(Nil) :: conts)
case bound :: bounds =>
loop(Down(bound, env), Cont.Let1(Nil, env, bounds, body) :: conts)
}
case source.SETryCatch(body, handler) =>
loop(Down(body, env), Cont.TryCatch1(env, handler) :: conts)
case source.SEScopeExercise(body) =>
loop(Down(body, env), Cont.ScopeExercise :: conts)
case source.SELabelClosure(label, expr) =>
loop(Down(expr, env), Cont.LabelClosure(label) :: conts)
}
// Going Up: match on continuation...
case Up(result) =>
conts match {
case Nil => result // The final result of the tail-recursive 'loop'.
case cont :: conts =>
cont match {
// We rebind the current result (i.e. 'val scrut = result') to help
// indicate how it is embedded into the target expression being constructed.
case Cont.Location(loc) =>
val body = result
loop(Up(target.SELocation(loc, body)), conts)
case Cont.Abs(arity, fvs) =>
val body = result
loop(Up(target.SEMakeClo(fvs, arity, body)), conts)
case Cont.App1(env, args) =>
val fun = result
args match {
case Nil =>
loop(Up(target.SEApp(fun, Nil)), conts)
case arg :: args =>
loop(Down(arg, env), Cont.App2(fun, Nil, env, args) :: conts)
}
case Cont.App2(fun, argsDone0, env, args) =>
val argsDone = result :: argsDone0
args match {
case Nil =>
loop(Up(target.SEApp(fun, argsDone.reverse)), conts)
case arg :: args =>
loop(Down(arg, env), Cont.App2(fun, argsDone, env, args) :: conts)
}
case Cont.Case1(env, alts) =>
val scrut = result
alts match {
case Nil =>
loop(Up(target.SECase(scrut, Nil)), conts)
case source.SCaseAlt(pat, rhs) :: alts =>
val n = pat.numArgs
loop(Down(rhs, env.shift(n)), Cont.Case2(scrut, Nil, pat, env, alts) :: conts)
}
case Cont.Case2(scrut, altsDone0, pat, env, alts) =>
val altsDone = target.SCaseAlt(pat, result) :: altsDone0
alts match {
case Nil =>
loop(Up(target.SECase(scrut, altsDone.reverse)), conts)
case source.SCaseAlt(pat, rhs) :: alts =>
val n = pat.numArgs
val env1 = env.shift(n)
loop(Down(rhs, env1), Cont.Case2(scrut, altsDone, pat, env, alts) :: conts)
}
case Cont.Let1(boundsDone0, env, bounds, body) =>
val boundsDone = result :: boundsDone0
val depth = boundsDone.length
val env1 = env.shift(depth)
bounds match {
case Nil =>
loop(Down(body, env1), Cont.Let2(boundsDone) :: conts)
case bound :: bounds =>
loop(Down(bound, env1), Cont.Let1(boundsDone, env, bounds, body) :: conts)
}
case Cont.Let2(boundsDone) =>
val body = result
loop(Up(target.SELet(boundsDone.reverse, body)), conts)
case Cont.TryCatch1(env, handler) =>
val body = result
loop(Down(handler, env.shift(1)), Cont.TryCatch2(body) :: conts)
case Cont.TryCatch2(body) =>
val handler = result
loop(Up(target.SETryCatch(body, handler)), conts)
case Cont.ScopeExercise =>
val body = result
loop(Up(target.SEScopeExercise(body)), conts)
case Cont.LabelClosure(label) =>
val expr = result
loop(Up(target.SELabelClosure(label, expr)), conts)
}
}
}
}
/* The (stack-safe) transformation is started here, passing the original source
* expression (source1), an empty environment, and an empty continuation-stack.
*/
loop(Down(source0, Env()), Nil)
}
// Modify/extend `remaps` to reflect when new values are pushed on the stack. This
// happens as we traverse into SELet and SECase bodies which have bindings which at
// runtime will appear on the stack.
// We must modify `remaps` because it is keyed by indexes relative to the end of the stack.
// And any values in the map which are of the form SELocS must also be _shifted_
// because SELocS indexes are also relative to the end of the stack.
private[this] def shift(remaps: Map[Int, target.SELoc], n: Int): Map[Int, target.SELoc] = {
// We must update both the keys of the map (the relative-indexes from the original SEVar)
// And also any values in the map which are stack located (SELocS), which are also indexed relatively
val m1 = remaps.map { case (k, loc) => (n + k, shiftLoc(loc, n)) }
// And create mappings for the `n` new stack items
val m2 = (1 to n).map(i => (i, target.SELocS(i)))
m1 ++ m2
}
private[this] def shiftLoc(loc: target.SELoc, n: Int): target.SELoc = loc match {
case target.SELocS(i) => target.SELocS(i + n)
case target.SELocA(_) | target.SELocF(_) => loc
}
// TODO: Recode to avoid polynomial-complexity of 'freeVars' computation. Issue #11830
/** Compute the free variables in a speedy expression.
* The returned free variables are de bruijn indices
* adjusted to the stack of the caller.
*/
private[this] def freeVars(expr: source.SExpr, initiallyBound: Int): Set[Int] = {
// @tailrec // TODO: This implementation is not stack-safe. Issue #11830
def go(expr: source.SExpr, bound: Int, free: Set[Int]): Set[Int] =
expr match {
case source.SEVar(i) =>
@ -155,7 +329,7 @@ private[speedy] object ClosureConversion {
case _: source.SEValue => free
case source.SELocation(_, body) =>
go(body, bound, free)
case source.SEAppGeneral(fun, args) =>
case source.SEApp(fun, args) =>
args.foldLeft(go(fun, bound, free))((acc, arg) => go(arg, bound, acc))
case source.SEAbs(n, body) =>
go(body, bound + n, free)
@ -174,9 +348,6 @@ private[speedy] object ClosureConversion {
go(body, bound, go(handler, 1 + bound, free))
case source.SEScopeExercise(body) =>
go(body, bound, free)
case _: source.SELet1General =>
throw CompilationError(s"freeVars: unexpected $expr")
}
go(expr, initiallyBound, Set.empty)

View File

@ -91,7 +91,7 @@ private[lf] object Compiler {
private val SEGetTime = s.SEBuiltin(SBGetTime)
private def SBCompareNumeric(b: SBuiltinPure) =
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), Array(s.SEVar(2), s.SEVar(1))))
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), List(s.SEVar(2), s.SEVar(1))))
private val SBLessNumeric = SBCompareNumeric(SBLess)
private val SBLessEqNumeric = SBCompareNumeric(SBLessEq)
private val SBGreaterNumeric = SBCompareNumeric(SBGreater)
@ -128,14 +128,14 @@ private[lf] object Compiler {
}
// Hand-implemented `map` uses less stack.
private def mapToArray[A, B: ClassTag](input: ImmArray[A])(f: A => B): Array[B] = {
private def mapToArray[A, B: ClassTag](input: ImmArray[A])(f: A => B): List[B] = {
val output = Array.ofDim[B](input.length)
var i = 0
input.foreach { value =>
output(i) = f(value)
i += 1
}
output
output.toList
}
}
@ -263,7 +263,7 @@ private[lf] final class Compiler(
case None => expr
}
private[this] def app(f: s.SExpr, a: s.SExpr) = s.SEApp(f, Array(a))
private[this] def app(f: s.SExpr, a: s.SExpr) = s.SEApp(f, List(a))
private[this] def let(env: Env, bound: s.SExpr)(f: (Position, Env) => s.SExpr): s.SELet =
f(env.nextPosition, env.pushVar) match {
@ -513,7 +513,7 @@ private[lf] final class Compiler(
case ECons(_, front, tail) =>
// TODO(JM): Consider emitting SEValue(SList(...)) for
// constant lists?
val args = (front.iterator.map(compile(env, _)) ++ Seq(compile(env, tail))).toArray
val args = (front.iterator.map(compile(env, _)) ++ Seq(compile(env, tail))).toList
if (front.length == 1) {
s.SEApp(s.SEBuiltin(SBCons), args)
} else {
@ -742,7 +742,7 @@ private[lf] final class Compiler(
else
s.SEApp(
s.SEBuiltin(SBRecCon(tApp.tycon, fields.map(_._1))),
fields.iterator.map(f => compile(env, f._2)).toArray,
fields.iterator.map(f => compile(env, f._2)).toList,
)
private[this] def compileERecUpd(env: Env, erecupd: ERecUpd): s.SExpr = {
@ -899,7 +899,7 @@ private[lf] final class Compiler(
case _ if args.isEmpty =>
compile(env, expr0)
case _ =>
s.SEApp(compile(env, expr0), args.toArray)
s.SEApp(compile(env, expr0), args)
}
private[this] def translateType(env: Env, typ: Type): Option[s.SExpr] =
@ -1305,7 +1305,7 @@ private[lf] final class Compiler(
(Iterator(compile(env2, tmpl.precond)) ++ implementsPrecondsIterator ++ Iterator(
s.SEValue.EmptyList
)).to(ImmArray)
val preconds = s.SEApp(s.SEBuiltin(SBConsMany(precondsArray.length - 1)), precondsArray.toArray)
val preconds = s.SEApp(s.SEBuiltin(SBConsMany(precondsArray.length - 1)), precondsArray.toList)
// We check precondition in a separated builtin to prevent
// further evaluation of agreement, signatories, observers and key
// in case of failed precondition.

View File

@ -52,7 +52,7 @@ private[speedy] sealed abstract class SBuiltin(val arity: Int) {
// TODO: move this into the speedy compiler code
private[lf] def apply(args: compileTime.SExpr*): compileTime.SExpr =
compileTime.SEApp(compileTime.SEBuiltin(this), args.toArray)
compileTime.SEApp(compileTime.SEBuiltin(this), args.toList)
// TODO: avoid constructing application expression at run time
private[lf] def apply(args: runTime.SExpr*): runTime.SExpr =

View File

@ -408,7 +408,7 @@ object SExpr {
def modName: ModuleName = ref.qualifiedName.module
// TODO: move this into the speedy compiler code
private[this] val eval = compileTime.SEVal(this)
def apply(args: compileTime.SExpr*) = compileTime.SEApp(eval, args.toArray)
def apply(args: compileTime.SExpr*) = compileTime.SEApp(eval, args.toList)
}
// references to definitions that come from the archive

View File

@ -30,14 +30,14 @@ package speedy
*
* Summary of which constructors are contained by: SExp0, SExpr1 and SExpr:
*
* - In SExpr{0,1,} (everywhere): SEAppGeneral, SEBuiltin, SELabelClosure, SELet1General,
* - In SExpr{0,1,} (everywhere): SEAppGeneral, SEBuiltin, SELabelClosure,
* SELocation, SEScopeExercise, SETryCatch, SEVal, SEValue,
*
* - In SExpr0: SEAbs, SEVar
*
* - In SExpr{0,1}: SECase, SELet
*
* - In SExpr{1,}: SELocA, SELocF, SELocS, SEMakeClo,
* - In SExpr{1,}: SELocA, SELocF, SELocS, SEMakeClo, SELet1General,
*
* - In SExpr: SEAppAtomicFun, SEAppAtomicGeneral, SEAppAtomicSaturatedBuiltin,
* SECaseAtomic, SELet1Builtin, SELet1BuiltinArithmetic
@ -76,13 +76,7 @@ private[speedy] object SExpr0 {
object SEValue extends SValueContainer[SEValue]
/** Function application */
final case class SEAppGeneral(fun: SExpr, args: Array[SExpr]) extends SExpr with SomeArrayEquals
object SEApp {
def apply(fun: SExpr, args: Array[SExpr]): SExpr = {
SEAppGeneral(fun, args)
}
}
final case class SEApp(fun: SExpr, args: List[SExpr]) extends SExpr
/** Lambda abstraction. Transformed to SEMakeClo during closure conversion */
final case class SEAbs(arity: Int, body: SExpr) extends SExpr
@ -96,13 +90,7 @@ private[speedy] object SExpr0 {
}
/** Pattern match. */
final case class SECase(scrut: SExpr, alts: Array[SCaseAlt]) extends SExpr with SomeArrayEquals
/** A let-expression with a single RHS
* This form only exists *during* the ANF transformation, but not when the ANF
* transformation is finished.
*/
final case class SELet1General(rhs: SExpr, body: SExpr) extends SExpr with SomeArrayEquals
final case class SECase(scrut: SExpr, alts: List[SCaseAlt]) extends SExpr
/** A non-recursive, non-parallel let block.
* It is used as an intermediary data structure by the compiler to

View File

@ -35,20 +35,13 @@ private[speedy] object SExpr1 {
/** Function application:
* General case: 'fun' and 'args' are any kind of expression
*/
final case class SEAppGeneral(fun: SExpr, args: Array[SExpr]) extends SExpr with SomeArrayEquals
object SEApp {
def apply(fun: SExpr, args: Array[SExpr]): SExpr = {
SEAppGeneral(fun, args)
}
}
final case class SEApp(fun: SExpr, args: List[SExpr]) extends SExpr
/** Closure creation. Create a new closure object storing the free variables
* in 'body'.
*/
final case class SEMakeClo(fvs: Array[SELoc], arity: Int, body: SExpr)
extends SExpr
with SomeArrayEquals
final case class SEMakeClo(fvs: List[SELoc], arity: Int, body: SExpr) extends SExpr
/** SELoc -- Reference to the runtime location of a variable.
*
@ -67,7 +60,7 @@ private[speedy] object SExpr1 {
final case class SELocF(n: Int) extends SELoc
/** Pattern match. */
final case class SECase(scrut: SExpr, alts: Array[SCaseAlt]) extends SExpr with SomeArrayEquals
final case class SECase(scrut: SExpr, alts: List[SCaseAlt]) extends SExpr
/** A let-expression with a single RHS
* This form only exists *during* the ANF transformation, but not when the ANF

View File

@ -129,7 +129,7 @@ class AnfTest extends AnyWordSpec with Matchers {
"error applied to 1 arg" should {
"be transformed to ANF as expected" in {
val original = slam(1, source.SEApp(source.SEBuiltin(SBError), Array(sarg0)))
val original = slam(1, source.SEApp(source.SEBuiltin(SBError), List(sarg0)))
val expected = lam(1, target.SEAppAtomicSaturatedBuiltin(SBError, Array(arg0)))
testTransform(original, expected)
}
@ -137,7 +137,7 @@ class AnfTest extends AnyWordSpec with Matchers {
"error (over) applied to 2 arg" should {
"be transformed to ANF as expected" in {
val original = slam(2, source.SEApp(source.SEBuiltin(SBError), Array(sarg0, sarg1)))
val original = slam(2, source.SEApp(source.SEBuiltin(SBError), List(sarg0, sarg1)))
val expected = lam(2, target.SEAppAtomicFun(target.SEBuiltin(SBError), Array(arg0, arg1)))
testTransform(original, expected)
}
@ -270,19 +270,19 @@ class AnfTest extends AnyWordSpec with Matchers {
// We have different expression types before/after the ANF transform, so we different constructors.
// Use "s" (for "source") as a prefix to distinguish.
private def slam(n: Int, body: source.SExpr): source.SExpr = source.SEMakeClo(Array(), n, body)
private def slam(n: Int, body: source.SExpr): source.SExpr = source.SEMakeClo(List(), n, body)
private def sclo1(fv: source.SELoc, n: Int, body: source.SExpr): source.SExpr =
source.SEMakeClo(Array(fv), n, body)
source.SEMakeClo(List(fv), n, body)
private def sapp(func: source.SExpr, arg: source.SExpr): source.SExpr =
source.SEAppGeneral(func, Array(arg))
source.SEApp(func, List(arg))
private def sbinop(op: SBuiltinPure, x: source.SExpr, y: source.SExpr): source.SExpr =
source.SEApp(source.SEBuiltin(op), Array(x, y))
source.SEApp(source.SEBuiltin(op), List(x, y))
private def sbinop(op: SBuiltinArithmetic, x: source.SExpr, y: source.SExpr): source.SExpr =
source.SEApp(source.SEBuiltin(op), Array(x, y))
source.SEApp(source.SEBuiltin(op), List(x, y))
private def sapp2(func: source.SExpr, arg1: source.SExpr, arg2: source.SExpr): source.SExpr =
source.SEAppGeneral(func, Array(arg1, arg2))
source.SEApp(func, List(arg1, arg2))
private def site(i: source.SExpr, t: source.SExpr, e: source.SExpr): source.SExpr =
source.SECase(i, Array(source.SCaseAlt(patTrue, t), source.SCaseAlt(patFalse, e)))
source.SECase(i, List(source.SCaseAlt(patTrue, t), source.SCaseAlt(patFalse, e)))
private def sarg0 = source.SELocA(0)
private def sarg1 = source.SELocA(1)
private def sarg2 = source.SELocA(2)

View File

@ -0,0 +1,143 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.speedy
import com.daml.lf.data.Ref
import com.daml.lf.speedy.{SExpr0 => source}
import com.daml.lf.speedy.{SExpr1 => target}
import com.daml.lf.speedy.{SExpr => expr}
import com.daml.lf.speedy.{SValue => v}
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.prop.TableDrivenPropertyChecks
import scala.annotation.tailrec
class ClosureConversionTest extends AnyFreeSpec with Matchers with TableDrivenPropertyChecks {
import source._
// Construct one level of source-expression at various 'recursion-points'.
// This list is intended to be exhaustive.
private val location = (x: SExpr) => SELocation(loc, x)
private val abs1 = (x: SExpr) => SEAbs(1, x)
private val appF = (x: SExpr) => SEApp(x, List(leaf, leaf))
private val app1 = (x: SExpr) => SEApp(leaf, List(x, leaf))
private val app2 = (x: SExpr) => SEApp(leaf, List(leaf, x))
private val scrut = (x: SExpr) => SECase(x, List(SCaseAlt(pat, leaf), SCaseAlt(pat, leaf)))
private val alt1 = (x: SExpr) => SECase(leaf, List(SCaseAlt(pat, x), SCaseAlt(pat, leaf)))
private val alt2 = (x: SExpr) => SECase(leaf, List(SCaseAlt(pat, leaf), SCaseAlt(pat, x)))
private val let1 = (x: SExpr) => SELet(List(x, leaf), leaf)
private val let2 = (x: SExpr) => SELet(List(leaf, x), leaf)
private val letBody = (x: SExpr) => SELet(List(leaf, leaf), x)
private val tryCatch1 = (x: SExpr) => SETryCatch(leaf, x)
private val tryCatch2 = (x: SExpr) => SETryCatch(x, leaf)
private val labelClosure = (x: SExpr) => SELabelClosure(label, x)
"closure conversion" - {
// This is the code under test...
def transform(e: SExpr): target.SExpr = {
import com.daml.lf.speedy.ClosureConversion.closureConvert
closureConvert(e)
}
/* We test stack-safety by building deep expressions through each of the different
* recursion points of an expression, using one of the builder functions above, and
* then ensuring we can 'transform' the expression using 'closureConvert'.
*/
def runTest(depth: Int, cons: SExpr => SExpr) = {
// Make an expression by iterating the 'cons' function, 'depth' times
@tailrec def loop(x: SExpr, n: Int): SExpr = if (n == 0) x else loop(cons(x), n - 1)
val exp: SExpr = loop(leaf, depth)
val _: target.SExpr = transform(exp)
true
}
/* The testcases are split into two sets:
*
* For both sets the code under test is stack-safe, but the 2nd set provokes an
* unrelated quadratic-or-worse time-issue in the handling of 'Env' management and the
* free-vars computation, during the closure-conversion transform.
*/
val testCases1 = {
Table[String, SExpr => SExpr](
("name", "recursion-point"),
("Location", location),
("AppF", appF),
("App1", app1),
("App2", app2),
("Scrut", scrut),
("Let1", let1),
("TryCatch2", tryCatch2),
("Labelclosure", labelClosure),
)
}
// These 'quadratic' testcases pertain to recursion-points under a binder.
val testCases2 = {
Table[String, SExpr => SExpr](
("name", "recursion-point"),
("Abs", abs1),
("Alt1", alt1),
("Alt2", alt2),
("Let2", let2),
("LetBody", letBody),
("TryCatch1", tryCatch1),
)
}
{
// All tests. Shallow enough for pre-stack-safe closure-conversion code to pass.
val depth = 100
s"depth = $depth" - {
forEvery(testCases1 ++ testCases2) { (name: String, recursionPoint: SExpr => SExpr) =>
name in {
runTest(depth, recursionPoint)
}
}
}
}
{
// Only first set. At this depth we can be really sure that we are stack-safe.
val depth = 100000
s"depth = $depth" - {
forEvery(testCases1) { (name: String, recursionPoint: SExpr => SExpr) =>
name in {
runTest(depth, recursionPoint)
}
}
}
}
{
// Only 2nd set. This depth is not really deep enough to ensure stack-safety, but
// much deeper and the quadratic-or-worse time-complexity starts to seriously slow
// down the test run.
// TODO: fix quadratic time issue to allow these tests to be run at depth 100000.
val depth = 1000
s"depth = $depth" - {
forEvery(testCases2) { (name: String, recursionPoint: SExpr => SExpr) =>
name in {
runTest(depth, recursionPoint)
}
}
}
}
}
private val leaf = SEValue(v.SText("leaf"))
private val label: Profile.Label = expr.AnonymousClosure
private val pat: expr.SCasePat = expr.SCPCons
private val loc = Ref.Location(
Ref.PackageId.assertFromString("P"),
Ref.ModuleName.assertFromString("M"),
"X",
(1, 2),
(3, 4),
)
}