LF: ad-hoc data struct to store LF Struct like data (#7220)

A `Struct[+X]` is a list of pair `(FieldName, X)` ordered by first
component.

We use this data-structure to represent TStruct LF type.

It will be used in upcomming PRs to sorted fields in Struct values.

CHANGELOG_BEGIN
CHANGELOG_END
This commit is contained in:
Remy 2020-08-26 12:18:18 +02:00 committed by GitHub
parent 48e05e25e2
commit ada8ba033b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 145 additions and 52 deletions

View File

@ -8,7 +8,7 @@ import java.util
import com.daml.lf.archive.Decode.ParseError import com.daml.lf.archive.Decode.ParseError
import com.daml.lf.data.Ref._ import com.daml.lf.data.Ref._
import com.daml.lf.data.{Decimal, ImmArray, Numeric, Time} import com.daml.lf.data.{Decimal, ImmArray, Numeric, Struct, Time}
import ImmArray.ImmArraySeq import ImmArray.ImmArraySeq
import com.daml.lf.language.Ast._ import com.daml.lf.language.Ast._
import com.daml.lf.language.Util._ import com.daml.lf.language.Util._
@ -639,8 +639,7 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
val struct = lfType.getStruct val struct = lfType.getStruct
val fields = struct.getFieldsList.asScala val fields = struct.getFieldsList.asScala
assertNonEmpty(fields, "fields") assertNonEmpty(fields, "fields")
TStruct(fields.map(decodeFieldWithType)(breakOut)) TStruct(Struct(fields.map(decodeFieldWithType): _*))
case PLF.Type.SumCase.SUM_NOT_SET => case PLF.Type.SumCase.SUM_NOT_SET =>
throw ParseError("Type.SUM_NOT_SET") throw ParseError("Type.SUM_NOT_SET")
} }

View File

@ -0,0 +1,48 @@
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf
package data
import com.daml.lf.data.Ref.Name
final case class Struct[+X] private (sortedFields: ImmArray[(Ref.Name, X)]) extends NoCopy {
def lookup(name: Ref.Name): Option[X] =
sortedFields.find(_._1 == name).map(_._2)
def mapValue[Y](f: X => Y) = new Struct(sortedFields.map { case (k, v) => k -> f(v) })
def toImmArray: ImmArray[(Ref.Name, X)] = sortedFields
def names: Iterator[Ref.Name] = iterator.map(_._1)
def values: Iterator[X] = iterator.map(_._2)
def iterator: Iterator[(Ref.Name, X)] = sortedFields.iterator
def foreach(f: ((Ref.Name, X)) => Unit): Unit = sortedFields.foreach(f)
}
object Struct {
def apply[X](fields: (Name, X)*): Struct[X] =
new Struct(fields.sortBy(_._1: String).to[ImmArray])
def apply[X](fields: ImmArray[(Name, X)]): Struct[X] = apply(fields.toSeq: _*)
def fromSortedImmArray[X](fields: ImmArray[(Ref.Name, X)]): Either[String, Struct[X]] = {
val struct = new Struct(fields)
Either.cond(
(struct.names zip struct.names.drop(1)).forall { case (x, y) => (x compare y) <= 0 },
struct,
s"the list $fields is not sorted by name"
)
}
private[this] val Emtpy = new Struct(ImmArray.empty)
def empty[X]: Struct[X] = Emtpy.asInstanceOf[Struct[X]]
}

View File

@ -0,0 +1,60 @@
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.data
import org.scalatest.prop.PropertyChecks
import org.scalatest.{Matchers, WordSpec}
class StructSpec extends WordSpec with Matchers with PropertyChecks {
private[this] val List(f1, f2, f3) = List("f1", "f2", "f3").map(Ref.Name.assertFromString)
"SortMap.fromSortedImmArray" should {
"fail if the input list is not sorted" in {
val negativeTestCases =
Table(
"list",
ImmArray.empty,
ImmArray(f1 -> 1),
ImmArray(f1 -> 1, f1 -> 2),
ImmArray(f1 -> 1, f2 -> 2, f3 -> 3),
)
val positiveTestCases = Table(
"list",
ImmArray(f1 -> 1, f2 -> 2, f3 -> 3, f1 -> 2),
ImmArray(f2 -> 2, f3 -> 3, f1 -> 1),
)
forEvery(negativeTestCases)(l => Struct.fromSortedImmArray(l) shouldBe 'right)
forEvery(positiveTestCases)(l => Struct.fromSortedImmArray(l) shouldBe 'left)
}
}
"SortMap.apply" should {
"sorted fields" in {
val testCases =
Table(
"list",
Struct(),
Struct(f1 -> 1),
Struct(f1 -> 1, f1 -> 2),
Struct(f1 -> 1, f2 -> 2, f3 -> 3),
Struct(f1 -> 1, f2 -> 2, f3 -> 3, f1 -> 2),
Struct(f2 -> 2, f3 -> 3, f1 -> 1),
)
forEvery(testCases) { s =>
s.names.toSeq shouldBe s.names.toSeq.sorted[String]
}
}
}
}

View File

@ -262,7 +262,8 @@ private[daml] class EncodeV1(val minor: LV.Minor) {
PLF.Type.Forall.newBuilder().accumulateLeft(binders)(_ addVars _).setBody(body)) PLF.Type.Forall.newBuilder().accumulateLeft(binders)(_ addVars _).setBody(body))
case TStruct(fields) => case TStruct(fields) =>
expect(args.isEmpty) expect(args.isEmpty)
builder.setStruct(PLF.Type.Struct.newBuilder().accumulateLeft(fields)(_ addFields _)) builder.setStruct(
PLF.Type.Struct.newBuilder().accumulateLeft(fields.toImmArray)(_ addFields _))
case TSynApp(name, args) => case TSynApp(name, args) =>
val b = PLF.Type.Syn.newBuilder() val b = PLF.Type.Syn.newBuilder()
b.setTysyn(name) b.setTysyn(name)

View File

@ -6,7 +6,7 @@ package svalue
import java.util import java.util
import com.daml.lf.data.{FrontStack, ImmArray, Numeric, Ref, Time} import com.daml.lf.data.{FrontStack, ImmArray, Numeric, Ref, Struct, Time}
import com.daml.lf.language.{Ast, Util => AstUtil} import com.daml.lf.language.{Ast, Util => AstUtil}
import com.daml.lf.speedy.SResult._ import com.daml.lf.speedy.SResult._
import com.daml.lf.speedy.SValue._ import com.daml.lf.speedy.SValue._
@ -168,23 +168,23 @@ class OrderingSpec
} }
private val typeStructReps = List( private val typeStructReps = List(
ImmArray.empty, Struct.empty,
ImmArray(Ref.Name.assertFromString("field0") -> AstUtil.TUnit), Struct(Ref.Name.assertFromString("field0") -> AstUtil.TUnit),
ImmArray(Ref.Name.assertFromString("field0") -> AstUtil.TInt64), Struct(Ref.Name.assertFromString("field0") -> AstUtil.TInt64),
ImmArray(Ref.Name.assertFromString("field1") -> AstUtil.TUnit), Struct(Ref.Name.assertFromString("field1") -> AstUtil.TUnit),
ImmArray( Struct(
Ref.Name.assertFromString("field1") -> AstUtil.TUnit, Ref.Name.assertFromString("field1") -> AstUtil.TUnit,
Ref.Name.assertFromString("field2") -> AstUtil.TUnit, Ref.Name.assertFromString("field2") -> AstUtil.TUnit,
), ),
ImmArray( Struct(
Ref.Name.assertFromString("field1") -> AstUtil.TUnit, Ref.Name.assertFromString("field1") -> AstUtil.TUnit,
Ref.Name.assertFromString("field2") -> AstUtil.TInt64, Ref.Name.assertFromString("field2") -> AstUtil.TInt64,
), ),
ImmArray( Struct(
Ref.Name.assertFromString("field1") -> AstUtil.TInt64, Ref.Name.assertFromString("field1") -> AstUtil.TInt64,
Ref.Name.assertFromString("field2") -> AstUtil.TUnit, Ref.Name.assertFromString("field2") -> AstUtil.TUnit,
), ),
ImmArray( Struct(
Ref.Name.assertFromString("field1") -> AstUtil.TUnit, Ref.Name.assertFromString("field1") -> AstUtil.TUnit,
Ref.Name.assertFromString("field3") -> AstUtil.TUnit, Ref.Name.assertFromString("field3") -> AstUtil.TUnit,
), ),
@ -206,8 +206,8 @@ class OrderingSpec
Ast.TTyCon(VariantTypeCon), Ast.TTyCon(VariantTypeCon),
Ast.TNat(Numeric.Scale.MinValue), Ast.TNat(Numeric.Scale.MinValue),
Ast.TNat(Numeric.Scale.MaxValue), Ast.TNat(Numeric.Scale.MaxValue),
Ast.TStruct(ImmArray.empty), Ast.TStruct(Struct.empty),
Ast.TStruct(ImmArray(Ref.Name.assertFromString("field") -> AstUtil.TUnit)), Ast.TStruct(Struct(Ref.Name.assertFromString("field") -> AstUtil.TUnit)),
Ast.TApp(Ast.TBuiltin(Ast.BTArrow), Ast.TBuiltin(Ast.BTUnit)), Ast.TApp(Ast.TBuiltin(Ast.BTArrow), Ast.TBuiltin(Ast.BTUnit)),
Ast.TApp( Ast.TApp(
Ast.TApp(Ast.TBuiltin(Ast.BTArrow), Ast.TBuiltin(Ast.BTUnit)), Ast.TApp(Ast.TBuiltin(Ast.BTArrow), Ast.TBuiltin(Ast.BTUnit)),

View File

@ -219,7 +219,7 @@ object Ast {
case TForall((v, _), body) => case TForall((v, _), body) =>
maybeParens(prec > precTForall, "∀" + v + prettyForAll(body)) maybeParens(prec > precTForall, "∀" + v + prettyForAll(body))
case TStruct(fields) => case TStruct(fields) =>
"(" + fields "(" + fields.iterator
.map { case (n, t) => n + ": " + prettyType(t, precTForall) } .map { case (n, t) => n + ": " + prettyType(t, precTForall) }
.toSeq .toSeq
.mkString(", ") + ")" .mkString(", ") + ")"
@ -264,13 +264,7 @@ object Ast {
final case class TForall(binder: (TypeVarName, Kind), body: Type) extends Type final case class TForall(binder: (TypeVarName, Kind), body: Type) extends Type
/** Structs */ /** Structs */
final case class TStruct private (sortedFields: ImmArray[(FieldName, Type)]) extends Type final case class TStruct(fields: Struct[Type]) extends Type
object TStruct extends (ImmArray[(FieldName, Type)] => TStruct) {
// should be dropped once the compiler sort fields.
def apply(fields: ImmArray[(FieldName, Type)]): TStruct =
new TStruct(ImmArray(fields.toSeq.sortBy(_._1: String)))
}
sealed abstract class BuiltinType extends Product with Serializable sealed abstract class BuiltinType extends Product with Serializable

View File

@ -67,9 +67,7 @@ object TypeOrdering extends Ordering[Ast.Type] {
compareType(n1 compareTo n2, stack) compareType(n1 compareTo n2, stack)
case (Ast.TStruct(fields1), Ast.TStruct(fields2)) => case (Ast.TStruct(fields1), Ast.TStruct(fields2)) =>
compareType( compareType(
math.Ordering Ordering.Iterable[String].compare(fields1.names.toSeq, fields2.names.toSeq),
.Iterable[String]
.compare(fields1.toSeq.map(_._1), fields2.toSeq.map(_._1)),
zipAndPush(fields1.iterator.map(_._2), fields2.iterator.map(_._2), stack) zipAndPush(fields1.iterator.map(_._2), fields2.iterator.map(_._2), stack)
) )
case (Ast.TApp(t11, t12), Ast.TApp(t21, t22)) => case (Ast.TApp(t11, t12), Ast.TApp(t21, t22)) =>

View File

@ -36,5 +36,6 @@ da_scala_test(
":parser", ":parser",
"//daml-lf/data", "//daml-lf/data",
"//daml-lf/language", "//daml-lf/language",
"@maven//:org_scalaz_scalaz_core_2_12",
], ],
) )

View File

@ -42,7 +42,7 @@ private[daml] class AstRewriter(
case TForall(binder, body) => case TForall(binder, body) =>
TForall(binder, apply(body)) TForall(binder, apply(body))
case TStruct(fields) => case TStruct(fields) =>
TStruct(fields.map(apply)) TStruct(fields.mapValue(apply))
} }
def apply(nameWithType: (Name, Type)): (Name, Type) = nameWithType match { def apply(nameWithType: (Name, Type)): (Name, Type) = nameWithType match {

View File

@ -4,7 +4,7 @@
package com.daml.lf.testing.parser package com.daml.lf.testing.parser
import com.daml.lf.data import com.daml.lf.data
import com.daml.lf.data.{ImmArray, Ref} import com.daml.lf.data.{ImmArray, Ref, Struct}
import com.daml.lf.language.Ast._ import com.daml.lf.language.Ast._
import com.daml.lf.language.Util._ import com.daml.lf.language.Util._
import com.daml.lf.testing.parser.Parsers._ import com.daml.lf.testing.parser.Parsers._
@ -57,7 +57,7 @@ private[parser] class TypeParser[P](parameters: ParserParameters[P]) {
id ~ `:` ~ typ ^^ { case name ~ _ ~ t => name -> t } id ~ `:` ~ typ ^^ { case name ~ _ ~ t => name -> t }
private lazy val tStruct: Parser[Type] = private lazy val tStruct: Parser[Type] =
`<` ~>! rep1sep(fieldType, `,`) <~ `>` ^^ (fs => TStruct(ImmArray(fs))) `<` ~>! rep1sep(fieldType, `,`) <~ `>` ^^ (fs => TStruct(Struct(fs: _*)))
private lazy val tTypeSynApp: Parser[Type] = private lazy val tTypeSynApp: Parser[Type] =
`|` ~> fullIdentifier ~ rep(typ0) <~ `|` ^^ { case id ~ tys => TSynApp(id, ImmArray(tys)) } `|` ~> fullIdentifier ~ rep(typ0) <~ `|` ^^ { case id ~ tys => TSynApp(id, ImmArray(tys)) }

View File

@ -6,7 +6,7 @@ package com.daml.lf.testing.parser
import java.math.BigDecimal import java.math.BigDecimal
import com.daml.lf.data.Ref._ import com.daml.lf.data.Ref._
import com.daml.lf.data.{ImmArray, Numeric, Time} import com.daml.lf.data.{ImmArray, Numeric, Struct, Time}
import com.daml.lf.language.Ast._ import com.daml.lf.language.Ast._
import com.daml.lf.language.Util._ import com.daml.lf.language.Util._
import com.daml.lf.testing.parser.Implicits._ import com.daml.lf.testing.parser.Implicits._
@ -97,8 +97,8 @@ class ParsersSpec extends WordSpec with TableDrivenPropertyChecks with Matchers
"a -> b" -> TApp(TApp(TBuiltin(BTArrow), α), β), "a -> b" -> TApp(TApp(TBuiltin(BTArrow), α), β),
"a -> b -> a" -> TApp(TApp(TBuiltin(BTArrow), α), TApp(TApp(TBuiltin(BTArrow), β), α)), "a -> b -> a" -> TApp(TApp(TBuiltin(BTArrow), α), TApp(TApp(TBuiltin(BTArrow), β), α)),
"forall (a: *). Mod:T a" -> TForall((α.name, KStar), TApp(T, α)), "forall (a: *). Mod:T a" -> TForall((α.name, KStar), TApp(T, α)),
"<f1: a, f2: Bool, f3:Mod:T>" -> TStruct( "<f1: a, f2: Bool, f3:Mod:T>" ->
ImmArray[(FieldName, Type)](n"f1" -> α, n"f2" -> TBuiltin(BTBool), n"f3" -> T)) TStruct(Struct(n"f1" -> α, n"f2" -> TBuiltin(BTBool), n"f3" -> T))
) )
forEvery(testCases)((stringToParse, expectedType) => forEvery(testCases)((stringToParse, expectedType) =>

View File

@ -366,7 +366,7 @@ object Repl {
case TForall((v, _), body) => case TForall((v, _), body) =>
maybeParens(prec > precTForall, "∀" + v + prettyForAll(body)) maybeParens(prec > precTForall, "∀" + v + prettyForAll(body))
case TStruct(fields) => case TStruct(fields) =>
"(" + fields "(" + fields.iterator
.map { case (n, t) => n + ": " + prettyType(t, precTForall) } .map { case (n, t) => n + ": " + prettyType(t, precTForall) }
.toSeq .toSeq
.mkString(", ") + ")" .mkString(", ") + ")"

View File

@ -4,7 +4,6 @@
package com.daml.lf.validation package com.daml.lf.validation
import com.daml.lf.language.Ast._ import com.daml.lf.language.Ast._
import com.daml.lf.validation.Util._
private[validation] object AlphaEquiv { private[validation] object AlphaEquiv {
@ -31,7 +30,7 @@ private[validation] object AlphaEquiv {
binderDepthRhs + (varName2 -> currentDepth) binderDepthRhs + (varName2 -> currentDepth)
).alphaEquiv(b1, b2) ).alphaEquiv(b1, b2)
case (TStruct(fs1), TStruct(fs2)) => case (TStruct(fs1), TStruct(fs2)) =>
(fs1.keys sameElements fs1.keys) && (fs1.names sameElements fs2.names) &&
(fs1.values zip fs2.values).forall((alphaEquiv _).tupled) (fs1.values zip fs2.values).forall((alphaEquiv _).tupled)
case _ => false case _ => false
} }

View File

@ -28,9 +28,7 @@ private[validation] object TypeSubst {
} else } else
TForall(v0 -> k, go(fv0 + v0, subst0 - v0, t)) TForall(v0 -> k, go(fv0 + v0, subst0 - v0, t))
case TStruct(ts) => case TStruct(ts) =>
TStruct(ts.transform { (_, x) => TStruct(ts.mapValue(go(fv0, subst0, _)))
go(fv0, subst0, x)
})
} }
private def freshTypeVarName(fv: Set[TypeVarName]): TypeVarName = private def freshTypeVarName(fv: Set[TypeVarName]): TypeVarName =

View File

@ -3,7 +3,7 @@
package com.daml.lf.validation package com.daml.lf.validation
import com.daml.lf.data.{ImmArray, Numeric} import com.daml.lf.data.{ImmArray, Numeric, Struct}
import com.daml.lf.data.Ref._ import com.daml.lf.data.Ref._
import com.daml.lf.language.Ast._ import com.daml.lf.language.Ast._
import com.daml.lf.language.Util._ import com.daml.lf.language.Util._
@ -119,8 +119,7 @@ private[validation] object Typing {
BTextMapToList -> BTextMapToList ->
TForall( TForall(
alpha.name -> KStar, alpha.name -> KStar,
TTextMap(alpha) ->: TList( TTextMap(alpha) ->: TList(TStruct(Struct(keyFieldName -> TText, valueFieldName -> alpha)))
TStruct(ImmArray(keyFieldName -> TText, valueFieldName -> alpha)))
), ),
BTextMapSize -> BTextMapSize ->
TForall( TForall(
@ -431,8 +430,8 @@ private[validation] object Typing {
case TForall((v, k), b) => case TForall((v, k), b) =>
introTypeVar(v, k).checkType(b, KStar) introTypeVar(v, k).checkType(b, KStar)
KStar KStar
case TStruct(recordType) => case TStruct(fields) =>
checkRecordType(recordType) checkRecordType(fields.toImmArray)
KStar KStar
} }
@ -453,9 +452,7 @@ private[validation] object Typing {
case TForall((v, k), b) => case TForall((v, k), b) =>
TForall((v, k), introTypeVar(v, k).expandTypeSynonyms(b)) TForall((v, k), introTypeVar(v, k).expandTypeSynonyms(b))
case TStruct(recordType) => case TStruct(recordType) =>
TStruct(recordType.transform { (_, x) => TStruct(recordType.mapValue(expandTypeSynonyms(_)))
expandTypeSynonyms(x)
})
} }
private def expandSynApp(syn: TypeSynName, tArgs: ImmArray[Type]): Type = { private def expandSynApp(syn: TypeSynName, tArgs: ImmArray[Type]): Type = {
@ -518,14 +515,12 @@ private[validation] object Typing {
private def typeOfStructCon(fields: ImmArray[(FieldName, Expr)]): Type = { private def typeOfStructCon(fields: ImmArray[(FieldName, Expr)]): Type = {
checkUniq[FieldName](fields.keys, EDuplicateField(ctx, _)) checkUniq[FieldName](fields.keys, EDuplicateField(ctx, _))
TStruct(fields.transform { (_, x) => TStruct(Struct(fields.iterator.map { case (f, x) => f -> typeOf(x) }.toSeq: _*))
typeOf(x)
})
} }
private def typeOfStructProj(field: FieldName, expr: Expr): Type = typeOf(expr) match { private def typeOfStructProj(field: FieldName, expr: Expr): Type = typeOf(expr) match {
case TStruct(structType) => case TStruct(structType) =>
structType.lookup(field, EUnknownField(ctx, field)) structType.lookup(field).getOrElse(throw EUnknownField(ctx, field))
case typ => case typ =>
throw EExpectedStructType(ctx, typ) throw EExpectedStructType(ctx, typ)
} }
@ -533,7 +528,7 @@ private[validation] object Typing {
private def typeOfStructUpd(field: FieldName, struct: Expr, update: Expr): Type = private def typeOfStructUpd(field: FieldName, struct: Expr, update: Expr): Type =
typeOf(struct) match { typeOf(struct) match {
case typ @ TStruct(structType) => case typ @ TStruct(structType) =>
checkExpr(update, structType.lookup(field, EUnknownField(ctx, field))) checkExpr(update, structType.lookup(field).getOrElse(throw EUnknownField(ctx, field)))
typ typ
case typ => case typ =>
throw EExpectedStructType(ctx, typ) throw EExpectedStructType(ctx, typ)
@ -763,7 +758,7 @@ private[validation] object Typing {
// fetches return the contract id and the contract itself // fetches return the contract id and the contract itself
TUpdate( TUpdate(
TStruct( TStruct(
ImmArray( Struct(
(contractIdFieldName, TContractId(TTyCon(retrieveByKey.templateId))), (contractIdFieldName, TContractId(TTyCon(retrieveByKey.templateId))),
(contractFieldName, TTyCon(retrieveByKey.templateId))))) (contractFieldName, TTyCon(retrieveByKey.templateId)))))
case UpdateLookupByKey(retrieveByKey) => case UpdateLookupByKey(retrieveByKey) =>