mirror of
https://github.com/digital-asset/daml.git
synced 2024-09-19 16:57:40 +03:00
Ensure stack-safety during closure-conversion. (#11778)
ClosureConversion -> Suffix with "Old" CHANGELOG_BEGIN CHANGELOG_END ClosureConversion old-vs-new diff check ClosureConversionNew, first cut. All tests in SBuiltinTest work. In addition we change some Array --> List in SExpr1 (for human pp). And we throw away ClosureConversionDup. adapt AnfTest from Array to List change for SExpr1 all tests pass in daml-lf/interpreter remove SExpr0.SELet1General reorder things testing for stack-safety of closure conversion file/class renames improve naming pass cont as sep arg to commit (move out of Up/Down) comment stack-safe closure conversion fix bug: failed to use env1 fix 2x unmoored doc comment comment stack safety testing Remove old closure-conversion code & diff-check between old/new. loose StackSafe suffix on ClosureConversion class/file rename StackSafetyTest.scala to ClosureConversionTest.scala prefer "sealed abstract class" to "sealed trait" fvs.zipWithIndex --> fvs.view.zipWithIndex (SExpr1) SEAppGeneral -> SEApp; prefer List to Array in SEApp/SECase prefer xs.toArray to Array(xs: _*) access SExpr0 via "source." two more .view improve comment and fix typo link to Issue switch to a continuation stack; avoids nesting in the Cont type
This commit is contained in:
parent
e63c80dddd
commit
970243dd46
@ -93,16 +93,16 @@ object PlaySpeedy {
|
||||
|
||||
// The trailing numeral is the number of args at the scala level
|
||||
|
||||
def decrement1(x: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), Array(x, SEValue(SInt64(1))))
|
||||
def decrement1(x: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, SEValue(SInt64(1))))
|
||||
val decrement = SEAbs(1, decrement1(SEVar(1)))
|
||||
|
||||
def subtract2(x: SExpr, y: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), Array(x, y))
|
||||
def subtract2(x: SExpr, y: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, y))
|
||||
val subtract = SEAbs(2, subtract2(SEVar(2), SEVar(1)))
|
||||
|
||||
def twice2(f: SExpr, x: SExpr): SExpr = SEApp(f, Array(SEApp(f, Array(x))))
|
||||
def twice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(x))))
|
||||
val twice = SEAbs(2, twice2(SEVar(2), SEVar(1)))
|
||||
|
||||
def thrice2(f: SExpr, x: SExpr): SExpr = SEApp(f, Array(SEApp(f, Array(SEApp(f, Array(x))))))
|
||||
def thrice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(SEApp(f, List(x))))))
|
||||
val thrice = SEAbs(2, thrice2(SEVar(2), SEVar(1)))
|
||||
|
||||
val examples = List(
|
||||
@ -119,30 +119,30 @@ object PlaySpeedy {
|
||||
(
|
||||
"subF", //88-55
|
||||
33,
|
||||
SEApp(subtract, Array(num(88), num(55))),
|
||||
SEApp(subtract, List(num(88), num(55))),
|
||||
),
|
||||
(
|
||||
"thrice", // thrice (\x -> x - 1) 0
|
||||
-3,
|
||||
SEApp(thrice, Array(decrement, num(0))),
|
||||
SEApp(thrice, List(decrement, num(0))),
|
||||
),
|
||||
(
|
||||
"thrice-thrice", //thrice thrice (\x -> x - 1) 0
|
||||
-27,
|
||||
SEApp(thrice, Array(thrice, decrement, num(0))),
|
||||
SEApp(thrice, List(thrice, decrement, num(0))),
|
||||
),
|
||||
(
|
||||
"free", // let (a,b,c) = (30,100,21) in twice (\x -> x - (a-c)) b
|
||||
82,
|
||||
SELet1General(
|
||||
num(30),
|
||||
SELet1General(
|
||||
num(100),
|
||||
SELet1General(
|
||||
num(21),
|
||||
SELet(
|
||||
List(num(30)),
|
||||
SELet(
|
||||
List(num(100)),
|
||||
SELet(
|
||||
List(num(21)),
|
||||
SEApp(
|
||||
twice,
|
||||
Array(SEAbs(1, subtract2(SEVar(1), subtract2(SEVar(4), SEVar(2)))), SEVar(2)),
|
||||
List(SEAbs(1, subtract2(SEVar(1), subtract2(SEVar(4), SEVar(2)))), SEVar(2)),
|
||||
),
|
||||
), //100
|
||||
),
|
||||
|
@ -301,7 +301,7 @@ private[lf] object Anf {
|
||||
|
||||
case source.SEVal(x) => Bounce(() => transform(depth, target.SEVal(x), k))
|
||||
|
||||
case source.SEAppGeneral(func, args) =>
|
||||
case source.SEApp(func, args) =>
|
||||
// It's safe to perform ANF if the func-expression has no effects when evaluated.
|
||||
val safeFunc =
|
||||
func match {
|
||||
@ -311,22 +311,22 @@ private[lf] object Anf {
|
||||
}
|
||||
// It's also safe to perform ANF for applications of a single argument.
|
||||
if (safeFunc || args.size == 1) {
|
||||
transformMultiApp[A](depth, env, func, args, k)(transform)
|
||||
transformMultiApp[A](depth, env, func, args.toArray, k)(transform)
|
||||
} else {
|
||||
transformMultiAppSafely[A](depth, env, func, args, k)(transform)
|
||||
transformMultiAppSafely[A](depth, env, func, args.toArray, k)(transform)
|
||||
}
|
||||
|
||||
case source.SEMakeClo(fvs0, arity, body0) =>
|
||||
val fvs = fvs0.map((loc) => makeRelativeL(depth)(makeAbsoluteL(env, loc)))
|
||||
val body = flattenToAnfInternal(body0).wrapped
|
||||
Bounce(() => transform(depth, target.SEMakeClo(fvs, arity, body), k))
|
||||
Bounce(() => transform(depth, target.SEMakeClo(fvs.toArray, arity, body), k))
|
||||
|
||||
case source.SECase(scrut, alts0) => {
|
||||
Bounce(() =>
|
||||
atomizeExp(depth, env, scrut, k) { (depth, scrut, txK) =>
|
||||
val scrut1 = makeRelativeA(depth)(scrut)
|
||||
Bounce(() =>
|
||||
flattenAlts(depth, env, alts0) { alts =>
|
||||
flattenAlts(depth, env, alts0.toArray) { alts =>
|
||||
Bounce(() => transform(depth, target.SECaseAtomic(scrut1, alts), txK))
|
||||
}
|
||||
)
|
||||
|
@ -3,149 +3,323 @@
|
||||
|
||||
package com.daml.lf.speedy
|
||||
|
||||
/** Closure Conversion (Phase of the speedy compiler pipeline)
|
||||
/** Closure Conversion (Phase of the speedy compiler pipeline)
|
||||
*
|
||||
* This compilation phase transforms from SExpr0 to SExpr1.
|
||||
* This compilation phase transforms from SExpr0 to SExpr1.
|
||||
*/
|
||||
|
||||
import com.daml.lf.data.Ref
|
||||
|
||||
import com.daml.lf.speedy.SExpr.SCasePat
|
||||
import com.daml.lf.speedy.{SExpr0 => source}
|
||||
import com.daml.lf.speedy.{SExpr1 => target}
|
||||
|
||||
import scala.annotation.tailrec
|
||||
|
||||
private[speedy] object ClosureConversion {
|
||||
|
||||
case class CompilationError(error: String) extends RuntimeException(error, null, true, false)
|
||||
private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {
|
||||
|
||||
/** Convert abstractions in a speedy expression into
|
||||
* explicit closure creations.
|
||||
* This step computes the free variables in an abstraction
|
||||
* body, then translates the references in the body into
|
||||
* references to the immediate top of the argument stack,
|
||||
* and changes the abstraction into a closure creation node
|
||||
* describing the free variables that need to be captured.
|
||||
*
|
||||
* For example:
|
||||
* SELet(..two-bindings..) in
|
||||
* SEAbs(2,
|
||||
* SEVar(4) .. [reference to first let-bound variable]
|
||||
* SEVar(2)) [reference to first function-arg]
|
||||
* =>
|
||||
* SELet(..two-bindings..) in
|
||||
* SEMakeClo(
|
||||
* Array(SELocS(2)), [capture the first let-bound variable, from the stack]
|
||||
* 2,
|
||||
* SELocF(0) .. [reference the first let-bound variable via the closure]
|
||||
* SELocA(0)) [reference the first function arg]
|
||||
*/
|
||||
// TODO: Recode the 'Env' management to avoid the polynomial-complexity of 'shift'. Issue #11830
|
||||
case class Env(mapping: Map[Int, target.SELoc]) {
|
||||
|
||||
// TODO: Introduce a new type expression for the result of closure conversion
|
||||
private[speedy] def closureConvert(expr: source.SExpr): target.SExpr = {
|
||||
closureConvert(Map.empty, expr)
|
||||
}
|
||||
def lookup(i: Int): target.SELoc =
|
||||
mapping.get(i) match {
|
||||
case Some(loc) => loc
|
||||
case None =>
|
||||
throw sys.error(s"lookup($i),in:$mapping")
|
||||
}
|
||||
|
||||
private def closureConvert(remaps: Map[Int, target.SELoc], expr: source.SExpr): target.SExpr = {
|
||||
|
||||
// remaps is a function which maps the relative offset from variables (SEVar) to their runtime location
|
||||
// The Map must contain a binding for every variable referenced.
|
||||
// The Map is consulted when translating variable references (SEVar) and free variables of an abstraction (SEAbs)
|
||||
def remap(i: Int): target.SELoc =
|
||||
remaps.get(i) match {
|
||||
case Some(loc) => loc
|
||||
case None =>
|
||||
throw CompilationError(s"remap($i),remaps=$remaps")
|
||||
def shift(n: Int): Env = {
|
||||
def shiftLoc(loc: target.SELoc, n: Int): target.SELoc = loc match {
|
||||
case target.SELocS(i) => target.SELocS(i + n)
|
||||
case target.SELocA(_) | target.SELocF(_) => loc
|
||||
}
|
||||
// We must update both the keys of the map (the relative-indexes from the original SEVar)
|
||||
// And also any values in the map which are stack located (SELocS), which are also indexed relatively
|
||||
val m1 = mapping.map { case (k, loc) => (n + k, shiftLoc(loc, n)) }
|
||||
// And create mappings for the `n` new stack items
|
||||
val m2 = (1 to n).view.map(i => (i, target.SELocS(i)))
|
||||
Env(m1 ++ m2)
|
||||
}
|
||||
expr match {
|
||||
case source.SEVar(i) => remap(i)
|
||||
case source.SEVal(ref) => target.SEVal(ref)
|
||||
case source.SEBuiltin(b) => target.SEBuiltin(b)
|
||||
case source.SEValue(v) => target.SEValue(v)
|
||||
case source.SELocation(loc, body) =>
|
||||
target.SELocation(loc, closureConvert(remaps, body))
|
||||
}
|
||||
|
||||
case source.SEAbs(0, _) =>
|
||||
throw CompilationError("empty SEAbs")
|
||||
|
||||
case source.SEAbs(arity, body) =>
|
||||
val fvs = freeVars(body, arity).toList.sorted
|
||||
val newRemapsF: Map[Int, target.SELoc] = fvs.zipWithIndex.map { case (orig, i) =>
|
||||
object Env {
|
||||
def apply(): Env = {
|
||||
Env(Map.empty)
|
||||
}
|
||||
def absBody(arity: Int, fvs: List[Int]): Env = {
|
||||
val newRemapsF: Map[Int, target.SELoc] = fvs.view.zipWithIndex.map { case (orig, i) =>
|
||||
(orig + arity) -> target.SELocF(i)
|
||||
}.toMap
|
||||
val newRemapsA = (1 to arity).map { case i =>
|
||||
val newRemapsA = (1 to arity).view.map { case i =>
|
||||
i -> target.SELocA(arity - i)
|
||||
}
|
||||
// The keys in newRemapsF and newRemapsA are disjoint
|
||||
val newBody = closureConvert(newRemapsF ++ newRemapsA, body)
|
||||
target.SEMakeClo(fvs.map(remap).toArray, arity, newBody)
|
||||
|
||||
case source.SEAppGeneral(fun, args) =>
|
||||
val newFun = closureConvert(remaps, fun)
|
||||
val newArgs = args.map(closureConvert(remaps, _))
|
||||
target.SEApp(newFun, newArgs)
|
||||
|
||||
case source.SECase(scrut, alts) =>
|
||||
target.SECase(
|
||||
closureConvert(remaps, scrut),
|
||||
alts.map { case source.SCaseAlt(pat, body) =>
|
||||
val n = pat.numArgs
|
||||
target.SCaseAlt(
|
||||
pat,
|
||||
closureConvert(shift(remaps, n), body),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
case source.SELet(bounds, body) =>
|
||||
target.SELet(
|
||||
bounds.zipWithIndex.map { case (b, i) =>
|
||||
closureConvert(shift(remaps, i), b)
|
||||
},
|
||||
closureConvert(shift(remaps, bounds.length), body),
|
||||
)
|
||||
|
||||
case source.SETryCatch(body, handler) =>
|
||||
target.SETryCatch(
|
||||
closureConvert(remaps, body),
|
||||
closureConvert(shift(remaps, 1), handler),
|
||||
)
|
||||
|
||||
case source.SEScopeExercise(body) =>
|
||||
target.SEScopeExercise(closureConvert(remaps, body))
|
||||
|
||||
case source.SELabelClosure(label, expr) =>
|
||||
target.SELabelClosure(label, closureConvert(remaps, expr))
|
||||
|
||||
case source.SELet1General(bound, body) =>
|
||||
target.SELet1General(closureConvert(remaps, bound), closureConvert(shift(remaps, 1), body))
|
||||
val m1 = newRemapsF ++ newRemapsA
|
||||
Env(m1)
|
||||
}
|
||||
}
|
||||
|
||||
/** Closure-conversion, Traversal:
|
||||
*
|
||||
* To ensure stack-safety, the input expression is traversed by a single tail-recursive 'loop'.
|
||||
* During the 'Traversal', we are either:
|
||||
* - going 'Down' a source expression (subtree), with an 'Env' for context, or
|
||||
* - coming 'Up' with a target expression (result for a subtree)
|
||||
*
|
||||
* In both cases we have a continuation ('List[Cont]') argument which says how to proceed.
|
||||
*/
|
||||
sealed abstract class Traversal
|
||||
object Traversal {
|
||||
final case class Down(exp: source.SExpr, env: Env) extends Traversal
|
||||
final case class Up(exp: target.SExpr) extends Traversal
|
||||
}
|
||||
import Traversal._
|
||||
|
||||
/** Closure Conversion, Cont:
|
||||
*
|
||||
* The multiple forms for a continuation describe how the result of transforming a
|
||||
* sub-expression should be embedded in the continuing traversal. The continuation
|
||||
* forms correspond to the source expression forms: specifically, the location of
|
||||
* recursive expression instances (values of type SExpr).
|
||||
*
|
||||
* For expression forms with no recursive instance (i.e. SEVar, SEVal), there are
|
||||
* no corresponding continuation forms.
|
||||
*
|
||||
* For expression forms with a single recursive instance (i.e. SELocation), there
|
||||
* is a single continuation form: (Cont.Location).
|
||||
*
|
||||
* For expression forms with two recursive instances (i.e. SETryCatch), there are
|
||||
* two corresponding continuation forms: (Cont.TryCatch1, Cont.TryCatch2).
|
||||
*
|
||||
* For the more complex expression forms containing a list of recursive instances
|
||||
* (i.e. SEAppGeneral), the corresponding continuation forms are also more complex,
|
||||
* but will generally have two cases (i.e. Cont.App1, Cont.App2), corresponding to
|
||||
* the Nil/Cons cases of the list of recursive instances.
|
||||
*
|
||||
* And so on. In effect, 'Cont' is a zipper type for expressions.
|
||||
*
|
||||
* Another way to understand the continuation forms is by observing the presence of
|
||||
* an 'env: Env' component indicates more source-expression processing to be done
|
||||
* (generally with the components following the 'env'). Any components before the
|
||||
* 'env' (or all components if there is no 'env') represent transform-(sub)-results
|
||||
* which need combining into the final result.
|
||||
*/
|
||||
sealed abstract class Cont
|
||||
object Cont {
|
||||
|
||||
final case class Location(loc: Ref.Location) extends Cont
|
||||
|
||||
final case class Abs(arity: Int, fvs: List[target.SELoc]) extends Cont
|
||||
|
||||
final case class App1(env: Env, args: List[source.SExpr]) extends Cont
|
||||
|
||||
final case class App2(
|
||||
funDone: target.SExpr,
|
||||
argsDone: List[target.SExpr],
|
||||
env: Env,
|
||||
args: List[source.SExpr],
|
||||
) extends Cont
|
||||
|
||||
final case class Case1(env: Env, alts: List[source.SCaseAlt]) extends Cont
|
||||
|
||||
final case class Case2(
|
||||
scrut: target.SExpr,
|
||||
altsDone: List[target.SCaseAlt],
|
||||
pat: SCasePat,
|
||||
env: Env,
|
||||
alts: List[source.SCaseAlt],
|
||||
) extends Cont
|
||||
|
||||
final case class Let1(
|
||||
boundsDone: List[target.SExpr],
|
||||
env: Env,
|
||||
bounds: List[source.SExpr],
|
||||
body: source.SExpr,
|
||||
) extends Cont
|
||||
|
||||
final case class Let2(
|
||||
boundsDone: List[target.SExpr]
|
||||
) extends Cont
|
||||
|
||||
final case class TryCatch1(
|
||||
env: Env,
|
||||
handler: source.SExpr,
|
||||
) extends Cont
|
||||
|
||||
final case class TryCatch2(
|
||||
body: target.SExpr
|
||||
) extends Cont
|
||||
|
||||
final case object ScopeExercise extends Cont
|
||||
|
||||
final case class LabelClosure(label: Profile.Label) extends Cont
|
||||
}
|
||||
|
||||
/* The entire traversal in performed by this single tail recursive 'loop' function.
|
||||
*
|
||||
* The 'loop' has two arguments:
|
||||
* - The traversal item (Down/Up), and a continuation-stack 'conts'.
|
||||
*
|
||||
* The traversal is matched to see if we are going 'Down`, or 'Up.
|
||||
* - When going 'Down', we perform case-analysis on the source-expression being traversed.
|
||||
* - when going 'Up, we perform case-analysis on the continuation-stack.
|
||||
* When the continuation-stack is empty, we are finished.
|
||||
*/
|
||||
@tailrec
|
||||
def loop(traversal: Traversal, conts: List[Cont]): target.SExpr = {
|
||||
|
||||
traversal match {
|
||||
|
||||
// Going Down: match on expression form...
|
||||
case Down(exp, env) =>
|
||||
exp match {
|
||||
case source.SEVar(i) => loop(Up(env.lookup(i)), conts)
|
||||
case source.SEVal(x) => loop(Up(target.SEVal(x)), conts)
|
||||
case source.SEBuiltin(x) => loop(Up(target.SEBuiltin(x)), conts)
|
||||
case source.SEValue(x) => loop(Up(target.SEValue(x)), conts)
|
||||
|
||||
case source.SELocation(loc, body) =>
|
||||
loop(Down(body, env), Cont.Location(loc) :: conts)
|
||||
|
||||
case source.SEAbs(arity, body) =>
|
||||
val fvsAsListInt = freeVars(body, arity).toList.sorted
|
||||
val fvs = fvsAsListInt.map(i => env.lookup(i))
|
||||
loop(Down(body, Env.absBody(arity, fvsAsListInt)), Cont.Abs(arity, fvs) :: conts)
|
||||
|
||||
case source.SEApp(fun, args) =>
|
||||
loop(Down(fun, env), Cont.App1(env, args) :: conts)
|
||||
|
||||
case source.SECase(scrut, alts) =>
|
||||
loop(Down(scrut, env), Cont.Case1(env, alts) :: conts)
|
||||
|
||||
case source.SELet(bounds, body) =>
|
||||
bounds match {
|
||||
case Nil =>
|
||||
loop(Down(body, env), Cont.Let2(Nil) :: conts)
|
||||
case bound :: bounds =>
|
||||
loop(Down(bound, env), Cont.Let1(Nil, env, bounds, body) :: conts)
|
||||
}
|
||||
|
||||
case source.SETryCatch(body, handler) =>
|
||||
loop(Down(body, env), Cont.TryCatch1(env, handler) :: conts)
|
||||
|
||||
case source.SEScopeExercise(body) =>
|
||||
loop(Down(body, env), Cont.ScopeExercise :: conts)
|
||||
|
||||
case source.SELabelClosure(label, expr) =>
|
||||
loop(Down(expr, env), Cont.LabelClosure(label) :: conts)
|
||||
}
|
||||
|
||||
// Going Up: match on continuation...
|
||||
case Up(result) =>
|
||||
conts match {
|
||||
|
||||
case Nil => result // The final result of the tail-recursive 'loop'.
|
||||
|
||||
case cont :: conts =>
|
||||
cont match {
|
||||
|
||||
// We rebind the current result (i.e. 'val scrut = result') to help
|
||||
// indicate how it is embedded into the target expression being constructed.
|
||||
|
||||
case Cont.Location(loc) =>
|
||||
val body = result
|
||||
loop(Up(target.SELocation(loc, body)), conts)
|
||||
|
||||
case Cont.Abs(arity, fvs) =>
|
||||
val body = result
|
||||
loop(Up(target.SEMakeClo(fvs, arity, body)), conts)
|
||||
|
||||
case Cont.App1(env, args) =>
|
||||
val fun = result
|
||||
args match {
|
||||
case Nil =>
|
||||
loop(Up(target.SEApp(fun, Nil)), conts)
|
||||
case arg :: args =>
|
||||
loop(Down(arg, env), Cont.App2(fun, Nil, env, args) :: conts)
|
||||
}
|
||||
|
||||
case Cont.App2(fun, argsDone0, env, args) =>
|
||||
val argsDone = result :: argsDone0
|
||||
args match {
|
||||
case Nil =>
|
||||
loop(Up(target.SEApp(fun, argsDone.reverse)), conts)
|
||||
case arg :: args =>
|
||||
loop(Down(arg, env), Cont.App2(fun, argsDone, env, args) :: conts)
|
||||
}
|
||||
|
||||
case Cont.Case1(env, alts) =>
|
||||
val scrut = result
|
||||
alts match {
|
||||
case Nil =>
|
||||
loop(Up(target.SECase(scrut, Nil)), conts)
|
||||
case source.SCaseAlt(pat, rhs) :: alts =>
|
||||
val n = pat.numArgs
|
||||
loop(Down(rhs, env.shift(n)), Cont.Case2(scrut, Nil, pat, env, alts) :: conts)
|
||||
}
|
||||
|
||||
case Cont.Case2(scrut, altsDone0, pat, env, alts) =>
|
||||
val altsDone = target.SCaseAlt(pat, result) :: altsDone0
|
||||
alts match {
|
||||
case Nil =>
|
||||
loop(Up(target.SECase(scrut, altsDone.reverse)), conts)
|
||||
case source.SCaseAlt(pat, rhs) :: alts =>
|
||||
val n = pat.numArgs
|
||||
val env1 = env.shift(n)
|
||||
loop(Down(rhs, env1), Cont.Case2(scrut, altsDone, pat, env, alts) :: conts)
|
||||
}
|
||||
|
||||
case Cont.Let1(boundsDone0, env, bounds, body) =>
|
||||
val boundsDone = result :: boundsDone0
|
||||
val depth = boundsDone.length
|
||||
val env1 = env.shift(depth)
|
||||
bounds match {
|
||||
case Nil =>
|
||||
loop(Down(body, env1), Cont.Let2(boundsDone) :: conts)
|
||||
case bound :: bounds =>
|
||||
loop(Down(bound, env1), Cont.Let1(boundsDone, env, bounds, body) :: conts)
|
||||
}
|
||||
|
||||
case Cont.Let2(boundsDone) =>
|
||||
val body = result
|
||||
loop(Up(target.SELet(boundsDone.reverse, body)), conts)
|
||||
|
||||
case Cont.TryCatch1(env, handler) =>
|
||||
val body = result
|
||||
loop(Down(handler, env.shift(1)), Cont.TryCatch2(body) :: conts)
|
||||
|
||||
case Cont.TryCatch2(body) =>
|
||||
val handler = result
|
||||
loop(Up(target.SETryCatch(body, handler)), conts)
|
||||
|
||||
case Cont.ScopeExercise =>
|
||||
val body = result
|
||||
loop(Up(target.SEScopeExercise(body)), conts)
|
||||
|
||||
case Cont.LabelClosure(label) =>
|
||||
val expr = result
|
||||
loop(Up(target.SELabelClosure(label, expr)), conts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* The (stack-safe) transformation is started here, passing the original source
|
||||
* expression (source1), an empty environment, and an empty continuation-stack.
|
||||
*/
|
||||
loop(Down(source0, Env()), Nil)
|
||||
}
|
||||
|
||||
// Modify/extend `remaps` to reflect when new values are pushed on the stack. This
|
||||
// happens as we traverse into SELet and SECase bodies which have bindings which at
|
||||
// runtime will appear on the stack.
|
||||
// We must modify `remaps` because it is keyed by indexes relative to the end of the stack.
|
||||
// And any values in the map which are of the form SELocS must also be _shifted_
|
||||
// because SELocS indexes are also relative to the end of the stack.
|
||||
private[this] def shift(remaps: Map[Int, target.SELoc], n: Int): Map[Int, target.SELoc] = {
|
||||
|
||||
// We must update both the keys of the map (the relative-indexes from the original SEVar)
|
||||
// And also any values in the map which are stack located (SELocS), which are also indexed relatively
|
||||
val m1 = remaps.map { case (k, loc) => (n + k, shiftLoc(loc, n)) }
|
||||
|
||||
// And create mappings for the `n` new stack items
|
||||
val m2 = (1 to n).map(i => (i, target.SELocS(i)))
|
||||
|
||||
m1 ++ m2
|
||||
}
|
||||
|
||||
private[this] def shiftLoc(loc: target.SELoc, n: Int): target.SELoc = loc match {
|
||||
case target.SELocS(i) => target.SELocS(i + n)
|
||||
case target.SELocA(_) | target.SELocF(_) => loc
|
||||
}
|
||||
// TODO: Recode to avoid polynomial-complexity of 'freeVars' computation. Issue #11830
|
||||
|
||||
/** Compute the free variables in a speedy expression.
|
||||
* The returned free variables are de bruijn indices
|
||||
* adjusted to the stack of the caller.
|
||||
*/
|
||||
private[this] def freeVars(expr: source.SExpr, initiallyBound: Int): Set[Int] = {
|
||||
// @tailrec // TODO: This implementation is not stack-safe. Issue #11830
|
||||
def go(expr: source.SExpr, bound: Int, free: Set[Int]): Set[Int] =
|
||||
expr match {
|
||||
case source.SEVar(i) =>
|
||||
@ -155,7 +329,7 @@ private[speedy] object ClosureConversion {
|
||||
case _: source.SEValue => free
|
||||
case source.SELocation(_, body) =>
|
||||
go(body, bound, free)
|
||||
case source.SEAppGeneral(fun, args) =>
|
||||
case source.SEApp(fun, args) =>
|
||||
args.foldLeft(go(fun, bound, free))((acc, arg) => go(arg, bound, acc))
|
||||
case source.SEAbs(n, body) =>
|
||||
go(body, bound + n, free)
|
||||
@ -174,9 +348,6 @@ private[speedy] object ClosureConversion {
|
||||
go(body, bound, go(handler, 1 + bound, free))
|
||||
case source.SEScopeExercise(body) =>
|
||||
go(body, bound, free)
|
||||
|
||||
case _: source.SELet1General =>
|
||||
throw CompilationError(s"freeVars: unexpected $expr")
|
||||
}
|
||||
|
||||
go(expr, initiallyBound, Set.empty)
|
||||
|
@ -91,7 +91,7 @@ private[lf] object Compiler {
|
||||
private val SEGetTime = s.SEBuiltin(SBGetTime)
|
||||
|
||||
private def SBCompareNumeric(b: SBuiltinPure) =
|
||||
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), Array(s.SEVar(2), s.SEVar(1))))
|
||||
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), List(s.SEVar(2), s.SEVar(1))))
|
||||
private val SBLessNumeric = SBCompareNumeric(SBLess)
|
||||
private val SBLessEqNumeric = SBCompareNumeric(SBLessEq)
|
||||
private val SBGreaterNumeric = SBCompareNumeric(SBGreater)
|
||||
@ -128,14 +128,14 @@ private[lf] object Compiler {
|
||||
}
|
||||
|
||||
// Hand-implemented `map` uses less stack.
|
||||
private def mapToArray[A, B: ClassTag](input: ImmArray[A])(f: A => B): Array[B] = {
|
||||
private def mapToArray[A, B: ClassTag](input: ImmArray[A])(f: A => B): List[B] = {
|
||||
val output = Array.ofDim[B](input.length)
|
||||
var i = 0
|
||||
input.foreach { value =>
|
||||
output(i) = f(value)
|
||||
i += 1
|
||||
}
|
||||
output
|
||||
output.toList
|
||||
}
|
||||
|
||||
}
|
||||
@ -263,7 +263,7 @@ private[lf] final class Compiler(
|
||||
case None => expr
|
||||
}
|
||||
|
||||
private[this] def app(f: s.SExpr, a: s.SExpr) = s.SEApp(f, Array(a))
|
||||
private[this] def app(f: s.SExpr, a: s.SExpr) = s.SEApp(f, List(a))
|
||||
|
||||
private[this] def let(env: Env, bound: s.SExpr)(f: (Position, Env) => s.SExpr): s.SELet =
|
||||
f(env.nextPosition, env.pushVar) match {
|
||||
@ -513,7 +513,7 @@ private[lf] final class Compiler(
|
||||
case ECons(_, front, tail) =>
|
||||
// TODO(JM): Consider emitting SEValue(SList(...)) for
|
||||
// constant lists?
|
||||
val args = (front.iterator.map(compile(env, _)) ++ Seq(compile(env, tail))).toArray
|
||||
val args = (front.iterator.map(compile(env, _)) ++ Seq(compile(env, tail))).toList
|
||||
if (front.length == 1) {
|
||||
s.SEApp(s.SEBuiltin(SBCons), args)
|
||||
} else {
|
||||
@ -742,7 +742,7 @@ private[lf] final class Compiler(
|
||||
else
|
||||
s.SEApp(
|
||||
s.SEBuiltin(SBRecCon(tApp.tycon, fields.map(_._1))),
|
||||
fields.iterator.map(f => compile(env, f._2)).toArray,
|
||||
fields.iterator.map(f => compile(env, f._2)).toList,
|
||||
)
|
||||
|
||||
private[this] def compileERecUpd(env: Env, erecupd: ERecUpd): s.SExpr = {
|
||||
@ -899,7 +899,7 @@ private[lf] final class Compiler(
|
||||
case _ if args.isEmpty =>
|
||||
compile(env, expr0)
|
||||
case _ =>
|
||||
s.SEApp(compile(env, expr0), args.toArray)
|
||||
s.SEApp(compile(env, expr0), args)
|
||||
}
|
||||
|
||||
private[this] def translateType(env: Env, typ: Type): Option[s.SExpr] =
|
||||
@ -1305,7 +1305,7 @@ private[lf] final class Compiler(
|
||||
(Iterator(compile(env2, tmpl.precond)) ++ implementsPrecondsIterator ++ Iterator(
|
||||
s.SEValue.EmptyList
|
||||
)).to(ImmArray)
|
||||
val preconds = s.SEApp(s.SEBuiltin(SBConsMany(precondsArray.length - 1)), precondsArray.toArray)
|
||||
val preconds = s.SEApp(s.SEBuiltin(SBConsMany(precondsArray.length - 1)), precondsArray.toList)
|
||||
// We check precondition in a separated builtin to prevent
|
||||
// further evaluation of agreement, signatories, observers and key
|
||||
// in case of failed precondition.
|
||||
|
@ -52,7 +52,7 @@ private[speedy] sealed abstract class SBuiltin(val arity: Int) {
|
||||
|
||||
// TODO: move this into the speedy compiler code
|
||||
private[lf] def apply(args: compileTime.SExpr*): compileTime.SExpr =
|
||||
compileTime.SEApp(compileTime.SEBuiltin(this), args.toArray)
|
||||
compileTime.SEApp(compileTime.SEBuiltin(this), args.toList)
|
||||
|
||||
// TODO: avoid constructing application expression at run time
|
||||
private[lf] def apply(args: runTime.SExpr*): runTime.SExpr =
|
||||
|
@ -408,7 +408,7 @@ object SExpr {
|
||||
def modName: ModuleName = ref.qualifiedName.module
|
||||
// TODO: move this into the speedy compiler code
|
||||
private[this] val eval = compileTime.SEVal(this)
|
||||
def apply(args: compileTime.SExpr*) = compileTime.SEApp(eval, args.toArray)
|
||||
def apply(args: compileTime.SExpr*) = compileTime.SEApp(eval, args.toList)
|
||||
}
|
||||
|
||||
// references to definitions that come from the archive
|
||||
|
@ -30,14 +30,14 @@ package speedy
|
||||
*
|
||||
* Summary of which constructors are contained by: SExp0, SExpr1 and SExpr:
|
||||
*
|
||||
* - In SExpr{0,1,} (everywhere): SEAppGeneral, SEBuiltin, SELabelClosure, SELet1General,
|
||||
* - In SExpr{0,1,} (everywhere): SEAppGeneral, SEBuiltin, SELabelClosure,
|
||||
* SELocation, SEScopeExercise, SETryCatch, SEVal, SEValue,
|
||||
*
|
||||
* - In SExpr0: SEAbs, SEVar
|
||||
*
|
||||
* - In SExpr{0,1}: SECase, SELet
|
||||
*
|
||||
* - In SExpr{1,}: SELocA, SELocF, SELocS, SEMakeClo,
|
||||
* - In SExpr{1,}: SELocA, SELocF, SELocS, SEMakeClo, SELet1General,
|
||||
*
|
||||
* - In SExpr: SEAppAtomicFun, SEAppAtomicGeneral, SEAppAtomicSaturatedBuiltin,
|
||||
* SECaseAtomic, SELet1Builtin, SELet1BuiltinArithmetic
|
||||
@ -76,13 +76,7 @@ private[speedy] object SExpr0 {
|
||||
object SEValue extends SValueContainer[SEValue]
|
||||
|
||||
/** Function application */
|
||||
final case class SEAppGeneral(fun: SExpr, args: Array[SExpr]) extends SExpr with SomeArrayEquals
|
||||
|
||||
object SEApp {
|
||||
def apply(fun: SExpr, args: Array[SExpr]): SExpr = {
|
||||
SEAppGeneral(fun, args)
|
||||
}
|
||||
}
|
||||
final case class SEApp(fun: SExpr, args: List[SExpr]) extends SExpr
|
||||
|
||||
/** Lambda abstraction. Transformed to SEMakeClo during closure conversion */
|
||||
final case class SEAbs(arity: Int, body: SExpr) extends SExpr
|
||||
@ -96,13 +90,7 @@ private[speedy] object SExpr0 {
|
||||
}
|
||||
|
||||
/** Pattern match. */
|
||||
final case class SECase(scrut: SExpr, alts: Array[SCaseAlt]) extends SExpr with SomeArrayEquals
|
||||
|
||||
/** A let-expression with a single RHS
|
||||
* This form only exists *during* the ANF transformation, but not when the ANF
|
||||
* transformation is finished.
|
||||
*/
|
||||
final case class SELet1General(rhs: SExpr, body: SExpr) extends SExpr with SomeArrayEquals
|
||||
final case class SECase(scrut: SExpr, alts: List[SCaseAlt]) extends SExpr
|
||||
|
||||
/** A non-recursive, non-parallel let block.
|
||||
* It is used as an intermediary data structure by the compiler to
|
||||
|
@ -35,20 +35,13 @@ private[speedy] object SExpr1 {
|
||||
/** Function application:
|
||||
* General case: 'fun' and 'args' are any kind of expression
|
||||
*/
|
||||
final case class SEAppGeneral(fun: SExpr, args: Array[SExpr]) extends SExpr with SomeArrayEquals
|
||||
|
||||
object SEApp {
|
||||
def apply(fun: SExpr, args: Array[SExpr]): SExpr = {
|
||||
SEAppGeneral(fun, args)
|
||||
}
|
||||
}
|
||||
final case class SEApp(fun: SExpr, args: List[SExpr]) extends SExpr
|
||||
|
||||
/** Closure creation. Create a new closure object storing the free variables
|
||||
* in 'body'.
|
||||
*/
|
||||
final case class SEMakeClo(fvs: Array[SELoc], arity: Int, body: SExpr)
|
||||
extends SExpr
|
||||
with SomeArrayEquals
|
||||
final case class SEMakeClo(fvs: List[SELoc], arity: Int, body: SExpr) extends SExpr
|
||||
|
||||
/** SELoc -- Reference to the runtime location of a variable.
|
||||
*
|
||||
@ -67,7 +60,7 @@ private[speedy] object SExpr1 {
|
||||
final case class SELocF(n: Int) extends SELoc
|
||||
|
||||
/** Pattern match. */
|
||||
final case class SECase(scrut: SExpr, alts: Array[SCaseAlt]) extends SExpr with SomeArrayEquals
|
||||
final case class SECase(scrut: SExpr, alts: List[SCaseAlt]) extends SExpr
|
||||
|
||||
/** A let-expression with a single RHS
|
||||
* This form only exists *during* the ANF transformation, but not when the ANF
|
||||
|
@ -129,7 +129,7 @@ class AnfTest extends AnyWordSpec with Matchers {
|
||||
|
||||
"error applied to 1 arg" should {
|
||||
"be transformed to ANF as expected" in {
|
||||
val original = slam(1, source.SEApp(source.SEBuiltin(SBError), Array(sarg0)))
|
||||
val original = slam(1, source.SEApp(source.SEBuiltin(SBError), List(sarg0)))
|
||||
val expected = lam(1, target.SEAppAtomicSaturatedBuiltin(SBError, Array(arg0)))
|
||||
testTransform(original, expected)
|
||||
}
|
||||
@ -137,7 +137,7 @@ class AnfTest extends AnyWordSpec with Matchers {
|
||||
|
||||
"error (over) applied to 2 arg" should {
|
||||
"be transformed to ANF as expected" in {
|
||||
val original = slam(2, source.SEApp(source.SEBuiltin(SBError), Array(sarg0, sarg1)))
|
||||
val original = slam(2, source.SEApp(source.SEBuiltin(SBError), List(sarg0, sarg1)))
|
||||
val expected = lam(2, target.SEAppAtomicFun(target.SEBuiltin(SBError), Array(arg0, arg1)))
|
||||
testTransform(original, expected)
|
||||
}
|
||||
@ -270,19 +270,19 @@ class AnfTest extends AnyWordSpec with Matchers {
|
||||
|
||||
// We have different expression types before/after the ANF transform, so we different constructors.
|
||||
// Use "s" (for "source") as a prefix to distinguish.
|
||||
private def slam(n: Int, body: source.SExpr): source.SExpr = source.SEMakeClo(Array(), n, body)
|
||||
private def slam(n: Int, body: source.SExpr): source.SExpr = source.SEMakeClo(List(), n, body)
|
||||
private def sclo1(fv: source.SELoc, n: Int, body: source.SExpr): source.SExpr =
|
||||
source.SEMakeClo(Array(fv), n, body)
|
||||
source.SEMakeClo(List(fv), n, body)
|
||||
private def sapp(func: source.SExpr, arg: source.SExpr): source.SExpr =
|
||||
source.SEAppGeneral(func, Array(arg))
|
||||
source.SEApp(func, List(arg))
|
||||
private def sbinop(op: SBuiltinPure, x: source.SExpr, y: source.SExpr): source.SExpr =
|
||||
source.SEApp(source.SEBuiltin(op), Array(x, y))
|
||||
source.SEApp(source.SEBuiltin(op), List(x, y))
|
||||
private def sbinop(op: SBuiltinArithmetic, x: source.SExpr, y: source.SExpr): source.SExpr =
|
||||
source.SEApp(source.SEBuiltin(op), Array(x, y))
|
||||
source.SEApp(source.SEBuiltin(op), List(x, y))
|
||||
private def sapp2(func: source.SExpr, arg1: source.SExpr, arg2: source.SExpr): source.SExpr =
|
||||
source.SEAppGeneral(func, Array(arg1, arg2))
|
||||
source.SEApp(func, List(arg1, arg2))
|
||||
private def site(i: source.SExpr, t: source.SExpr, e: source.SExpr): source.SExpr =
|
||||
source.SECase(i, Array(source.SCaseAlt(patTrue, t), source.SCaseAlt(patFalse, e)))
|
||||
source.SECase(i, List(source.SCaseAlt(patTrue, t), source.SCaseAlt(patFalse, e)))
|
||||
private def sarg0 = source.SELocA(0)
|
||||
private def sarg1 = source.SELocA(1)
|
||||
private def sarg2 = source.SELocA(2)
|
||||
|
@ -0,0 +1,143 @@
|
||||
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.speedy
|
||||
|
||||
import com.daml.lf.data.Ref
|
||||
|
||||
import com.daml.lf.speedy.{SExpr0 => source}
|
||||
import com.daml.lf.speedy.{SExpr1 => target}
|
||||
import com.daml.lf.speedy.{SExpr => expr}
|
||||
import com.daml.lf.speedy.{SValue => v}
|
||||
|
||||
import org.scalatest.freespec.AnyFreeSpec
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||
|
||||
import scala.annotation.tailrec
|
||||
|
||||
class ClosureConversionTest extends AnyFreeSpec with Matchers with TableDrivenPropertyChecks {
|
||||
|
||||
import source._
|
||||
|
||||
// Construct one level of source-expression at various 'recursion-points'.
|
||||
// This list is intended to be exhaustive.
|
||||
private val location = (x: SExpr) => SELocation(loc, x)
|
||||
private val abs1 = (x: SExpr) => SEAbs(1, x)
|
||||
private val appF = (x: SExpr) => SEApp(x, List(leaf, leaf))
|
||||
private val app1 = (x: SExpr) => SEApp(leaf, List(x, leaf))
|
||||
private val app2 = (x: SExpr) => SEApp(leaf, List(leaf, x))
|
||||
private val scrut = (x: SExpr) => SECase(x, List(SCaseAlt(pat, leaf), SCaseAlt(pat, leaf)))
|
||||
private val alt1 = (x: SExpr) => SECase(leaf, List(SCaseAlt(pat, x), SCaseAlt(pat, leaf)))
|
||||
private val alt2 = (x: SExpr) => SECase(leaf, List(SCaseAlt(pat, leaf), SCaseAlt(pat, x)))
|
||||
private val let1 = (x: SExpr) => SELet(List(x, leaf), leaf)
|
||||
private val let2 = (x: SExpr) => SELet(List(leaf, x), leaf)
|
||||
private val letBody = (x: SExpr) => SELet(List(leaf, leaf), x)
|
||||
private val tryCatch1 = (x: SExpr) => SETryCatch(leaf, x)
|
||||
private val tryCatch2 = (x: SExpr) => SETryCatch(x, leaf)
|
||||
private val labelClosure = (x: SExpr) => SELabelClosure(label, x)
|
||||
|
||||
"closure conversion" - {
|
||||
|
||||
// This is the code under test...
|
||||
def transform(e: SExpr): target.SExpr = {
|
||||
import com.daml.lf.speedy.ClosureConversion.closureConvert
|
||||
closureConvert(e)
|
||||
}
|
||||
|
||||
/* We test stack-safety by building deep expressions through each of the different
|
||||
* recursion points of an expression, using one of the builder functions above, and
|
||||
* then ensuring we can 'transform' the expression using 'closureConvert'.
|
||||
*/
|
||||
def runTest(depth: Int, cons: SExpr => SExpr) = {
|
||||
// Make an expression by iterating the 'cons' function, 'depth' times
|
||||
@tailrec def loop(x: SExpr, n: Int): SExpr = if (n == 0) x else loop(cons(x), n - 1)
|
||||
val exp: SExpr = loop(leaf, depth)
|
||||
val _: target.SExpr = transform(exp)
|
||||
true
|
||||
}
|
||||
|
||||
/* The testcases are split into two sets:
|
||||
*
|
||||
* For both sets the code under test is stack-safe, but the 2nd set provokes an
|
||||
* unrelated quadratic-or-worse time-issue in the handling of 'Env' management and the
|
||||
* free-vars computation, during the closure-conversion transform.
|
||||
*/
|
||||
val testCases1 = {
|
||||
Table[String, SExpr => SExpr](
|
||||
("name", "recursion-point"),
|
||||
("Location", location),
|
||||
("AppF", appF),
|
||||
("App1", app1),
|
||||
("App2", app2),
|
||||
("Scrut", scrut),
|
||||
("Let1", let1),
|
||||
("TryCatch2", tryCatch2),
|
||||
("Labelclosure", labelClosure),
|
||||
)
|
||||
}
|
||||
|
||||
// These 'quadratic' testcases pertain to recursion-points under a binder.
|
||||
val testCases2 = {
|
||||
Table[String, SExpr => SExpr](
|
||||
("name", "recursion-point"),
|
||||
("Abs", abs1),
|
||||
("Alt1", alt1),
|
||||
("Alt2", alt2),
|
||||
("Let2", let2),
|
||||
("LetBody", letBody),
|
||||
("TryCatch1", tryCatch1),
|
||||
)
|
||||
}
|
||||
|
||||
{
|
||||
// All tests. Shallow enough for pre-stack-safe closure-conversion code to pass.
|
||||
val depth = 100
|
||||
s"depth = $depth" - {
|
||||
forEvery(testCases1 ++ testCases2) { (name: String, recursionPoint: SExpr => SExpr) =>
|
||||
name in {
|
||||
runTest(depth, recursionPoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Only first set. At this depth we can be really sure that we are stack-safe.
|
||||
val depth = 100000
|
||||
s"depth = $depth" - {
|
||||
forEvery(testCases1) { (name: String, recursionPoint: SExpr => SExpr) =>
|
||||
name in {
|
||||
runTest(depth, recursionPoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Only 2nd set. This depth is not really deep enough to ensure stack-safety, but
|
||||
// much deeper and the quadratic-or-worse time-complexity starts to seriously slow
|
||||
// down the test run.
|
||||
// TODO: fix quadratic time issue to allow these tests to be run at depth 100000.
|
||||
val depth = 1000
|
||||
s"depth = $depth" - {
|
||||
forEvery(testCases2) { (name: String, recursionPoint: SExpr => SExpr) =>
|
||||
name in {
|
||||
runTest(depth, recursionPoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val leaf = SEValue(v.SText("leaf"))
|
||||
private val label: Profile.Label = expr.AnonymousClosure
|
||||
private val pat: expr.SCasePat = expr.SCPCons
|
||||
private val loc = Ref.Location(
|
||||
Ref.PackageId.assertFromString("P"),
|
||||
Ref.ModuleName.assertFromString("M"),
|
||||
"X",
|
||||
(1, 2),
|
||||
(3, 4),
|
||||
)
|
||||
}
|
Loading…
Reference in New Issue
Block a user