LF: prevent duplicate fields in Struct (#7299)

CHANGELOG_BEGIN
CHANGELOG_END
This commit is contained in:
Remy 2020-09-03 16:07:38 +02:00 committed by GitHub
parent 0bfb4ba1d2
commit d08de60d9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 198 additions and 78 deletions

View File

@ -639,7 +639,13 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
val struct = lfType.getStruct
val fields = struct.getFieldsList.asScala
assertNonEmpty(fields, "fields")
TStruct(Struct(fields.map(decodeFieldWithType): _*))
TStruct(
Struct
.fromSeq(fields.map(decodeFieldWithType))
.fold(
name => throw ParseError(s"TStruct: duplicate field $name"),
identity
))
case PLF.Type.SumCase.SUM_NOT_SET =>
throw ParseError("Type.SUM_NOT_SET")
}

View File

@ -119,6 +119,18 @@ class DecodeV1Spec
List(1, 4, 6, 8).map(i => LV.Minor.Stable(i.toString)): _*
)
// Minor versions decoded without a string table: the tests feed them field
// names as inline strings.
private val preInterningVersions = Table(
"minVersion",
LV.Minor.Stable("6"),
)
// Minor versions decoded with a string table: the tests feed them field names
// as indices into that table.
private val postInterningVersions = Table(
"minVersion",
LV.Minor.Stable("7"),
LV.Minor.Stable("8"),
LV.Minor.Dev,
)
private val postContractIdTextConversionVersions = Table(
"minVersion",
// FIXME: https://github.com/digital-asset/daml/issues/7139
@ -261,6 +273,67 @@ class DecodeV1Spec
decoder.decodeType(buildPrimType(ANY)) shouldBe TAny
}
}
"reject Struct with duplicate field names" in {

  // Field-name lists that must decode without error (all names distinct).
  val negativeTestCases =
    Table("field names", List("a", "b", "c"))

  // Field-name lists that must be rejected (at least one name is repeated).
  val positiveTestCases =
    Table("field names", List("a", "a"), List("a", "b", "c", "a"), List("a", "b", "c", "b"))

  val unitType = DamlLf1.Type
    .newBuilder()
    .setPrim(DamlLf1.Type.Prim.newBuilder().setPrim(DamlLf1.PrimType.UNIT))
    .build

  // Assembles a struct type whose fields all have type unit; `mkField` decides
  // how each field name is encoded in the protobuf message.
  def buildTStruct(fieldNames: Seq[String])(
      mkField: String => DamlLf1.FieldWithType.Builder): DamlLf1.Type = {
    val structBuilder = DamlLf1.Type.Struct.newBuilder()
    fieldNames.foreach(fieldName => structBuilder.addFields(mkField(fieldName)))
    DamlLf1.Type.newBuilder().setStruct(structBuilder).build()
  }

  // Field whose name is carried as an inline string.
  def inlineField(fieldName: String): DamlLf1.FieldWithType.Builder =
    DamlLf1.FieldWithType.newBuilder().setFieldStr(fieldName).setType(unitType)

  val stringTable = ImmArraySeq("a", "b", "c")
  val indexInStringTable = stringTable.zipWithIndex.toMap

  // Field whose name is carried as an index into `stringTable`.
  def internedField(fieldName: String): DamlLf1.FieldWithType.Builder =
    DamlLf1.FieldWithType
      .newBuilder()
      .setFieldInternedStr(indexInStringTable(fieldName))
      .setType(unitType)

  forEvery(preInterningVersions) { minVersion =>
    val decoder = moduleDecoder(minVersion)
    forEvery(negativeTestCases) { fieldNames =>
      decoder.decodeType(buildTStruct(fieldNames)(inlineField))
    }
    forEvery(positiveTestCases) { fieldNames =>
      a[ParseError] shouldBe thrownBy(
        decoder.decodeType(buildTStruct(fieldNames)(inlineField)))
    }
  }

  forEvery(postInterningVersions) { minVersion =>
    val decoder = moduleDecoder(minVersion, stringTable)
    forEvery(negativeTestCases) { fieldNames =>
      decoder.decodeType(buildTStruct(fieldNames)(internedField))
    }
    forEvery(positiveTestCases) { fieldNames =>
      a[ParseError] shouldBe thrownBy(
        decoder.decodeType(buildTStruct(fieldNames)(internedField)))
    }
  }
}
}
"decodeExpr" should {

View File

@ -6,43 +6,75 @@ package data
import com.daml.lf.data.Ref.Name
final case class Struct[+X] private (sortedFields: ImmArray[(Ref.Name, X)]) extends NoCopy {
/** We use this container to describe structural records as sorted flat lists in various parts of the codebase.
`sortedFields` is sorted by field name and contains no duplicate names.
*/
final case class Struct[+X] private (private val sortedFields: ImmArray[(Ref.Name, X)])
extends NoCopy {
def lookup(name: Ref.Name): Option[X] =
sortedFields.find(_._1 == name).map(_._2)
/** O(n) */
@throws[IndexOutOfBoundsException]
def apply(name: Ref.Name): X = sortedFields(indexOf(name))._2
def mapValue[Y](f: X => Y) = new Struct(sortedFields.map { case (k, v) => k -> f(v) })
/** O(n) */
def indexOf(name: Ref.Name): Int = sortedFields.indexWhere(_._1 == name)
/** O(n) */
def lookup(name: Ref.Name): Option[X] = sortedFields.find(_._1 == name).map(_._2)
/** O(n) */
def mapValues[Y](f: X => Y) = new Struct(sortedFields.map { case (k, v) => k -> f(v) })
/** O(1) */
def toImmArray: ImmArray[(Ref.Name, X)] = sortedFields
/** O(1) */
def names: Iterator[Ref.Name] = iterator.map(_._1)
/** O(1) */
def values: Iterator[X] = iterator.map(_._2)
/** O(1) */
def iterator: Iterator[(Ref.Name, X)] = sortedFields.iterator
def foreach(f: ((Ref.Name, X)) => Unit): Unit = sortedFields.foreach(f)
/** O(1) */
def size: Int = sortedFields.length
/** O(n) */
override def toString: String = iterator.mkString("Struct(", ",", ")")
}
object Struct {
def apply[X](fields: (Name, X)*): Struct[X] =
new Struct(fields.sortBy(_._1: String).to[ImmArray])
/** Builds a Struct from `fields`, sorting them by field name.
  *
  * Returns `Left(name)` carrying a duplicated field name when two fields share
  * a name, and `Right(struct)` otherwise.
  * O(n log n)
  */
def fromSeq[X](fields: Seq[(Name, X)]): Either[Name, Struct[X]] =
  if (fields.isEmpty) rightEmpty
  else {
    val struct = Struct(ImmArray(fields.sortBy(_._1: String)))
    // Once sorted, equal names are adjacent: compare every name with its successor
    // and report the first pair that matches.
    (struct.names zip struct.names.drop(1))
      .collectFirst { case (current, next) if current == next => next }
      .toLeft(struct)
  }
def apply[X](fields: ImmArray[(Name, X)]): Struct[X] = apply(fields.toSeq: _*)
def fromSortedImmArray[X](fields: ImmArray[(Ref.Name, X)]): Either[String, Struct[X]] = {
val struct = new Struct(fields)
Either.cond(
(struct.names zip struct.names.drop(1)).forall { case (x, y) => (x compare y) <= 0 },
struct,
s"the list $fields is not sorted by name"
/** Builds a Struct like `fromSeq`, but throws instead of returning a `Left`.
  *
  * @throws IllegalArgumentException if a field name occurs more than once in `fields`.
  */
def assertFromSeq[X](fields: Seq[(Name, X)]): Struct[X] =
fromSeq(fields).fold(
name =>
throw new IllegalArgumentException(s"name $name duplicated when trying to build Struct"),
identity,
)
}
private[this] val Emtpy = new Struct(ImmArray.empty)
val Empty: Struct[Nothing] = new Struct(ImmArray.empty)
def empty[X]: Struct[X] = Emtpy.asInstanceOf[Struct[X]]
private[this] val rightEmpty = Right(Empty)
}

View File

@ -10,51 +10,56 @@ class StructSpec extends WordSpec with Matchers with PropertyChecks {
private[this] val List(f1, f2, f3) = List("f1", "f2", "f3").map(Ref.Name.assertFromString)
"SortMap.fromSortedImmArray" should {
"SortMap.toSeq" should {
"fail if the input list is not sorted" in {
val negativeTestCases =
Table(
"list",
ImmArray.empty,
ImmArray(f1 -> 1),
ImmArray(f1 -> 1, f1 -> 2),
ImmArray(f1 -> 1, f2 -> 2, f3 -> 3),
)
val positiveTestCases = Table(
"list",
ImmArray(f1 -> 1, f2 -> 2, f3 -> 3, f1 -> 2),
ImmArray(f2 -> 2, f3 -> 3, f1 -> 1),
)
forEvery(negativeTestCases)(l => Struct.fromSortedImmArray(l) shouldBe 'right)
forEvery(positiveTestCases)(l => Struct.fromSortedImmArray(l) shouldBe 'left)
}
}
"SortMap.apply" should {
"sorted fields" in {
"sort fields" in {
val testCases =
Table(
"list",
Struct(),
Struct(f1 -> 1),
Struct(f1 -> 1, f1 -> 2),
Struct(f1 -> 1, f2 -> 2, f3 -> 3),
Struct(f1 -> 1, f2 -> 2, f3 -> 3, f1 -> 2),
Struct(f2 -> 2, f3 -> 3, f1 -> 1),
List(),
List(f1 -> 1),
List(f1 -> 1, f2 -> 2, f3 -> 3),
List(f2 -> 1, f3 -> 2, f1 -> 3),
List(f3 -> 2, f2 -> 3, f1 -> 1),
)
forEvery(testCases) { s =>
s.names.toSeq shouldBe s.names.toSeq.sorted[String]
forEvery(testCases) { list =>
val struct = Struct.assertFromSeq(list)
(struct.names zip struct.names.drop(1)).foreach {
case (x, y) => (x: String) shouldBe <(y: String)
}
}
}
"""reject struct with duplicate name.""" in {
  // Every input repeats at least one field name, so construction must yield a Left.
  val duplicatedInputs =
    Table(
      "list",
      List(f1 -> 1, f1 -> 1),
      List(f1 -> 1, f1 -> 2),
      List(f1 -> 1, f2 -> 2, f3 -> 3, f1 -> 2),
      List(f2 -> 2, f3 -> 3, f2 -> 1, f3 -> 4, f3 -> 0),
    )
  forEvery(duplicatedInputs)(fields => Struct.fromSeq(fields) should be('left))
}
}
"Struct" should {
  "be equal if built in different order" in {
    // The same fields supplied in any order must produce equal results.
    val twoFields = Struct.fromSeq(List(f1 -> 1, f2 -> 2))
    twoFields shouldBe Struct.fromSeq(List(f2 -> 2, f1 -> 1))

    val threeFields = Struct.fromSeq(List(f1 -> 1, f2 -> 2, f3 -> 3))
    threeFields shouldBe Struct.fromSeq(List(f2 -> 2, f1 -> 1, f3 -> 3))
    Struct.fromSeq(List(f1 -> 1, f2 -> 2, f3 -> 3)) shouldBe
      Struct.fromSeq(List(f3 -> 3, f1 -> 1, f2 -> 2))
  }
}
}

View File

@ -167,24 +167,26 @@ class OrderingSpec
} yield STypeRep(Ast.TTyCon(Ref.Identifier(pkgId, Ref.QualifiedName(mod, name))))
}
// Test helper: builds a Struct from the given fields via Struct.assertFromSeq.
private def struct[X](fields: (Ref.Name, X)*): Struct[X] =
  Struct.assertFromSeq(fields)
private val typeStructReps = List(
Struct.empty,
Struct(Ref.Name.assertFromString("field0") -> AstUtil.TUnit),
Struct(Ref.Name.assertFromString("field0") -> AstUtil.TInt64),
Struct(Ref.Name.assertFromString("field1") -> AstUtil.TUnit),
Struct(
Struct.Empty,
struct(Ref.Name.assertFromString("field0") -> AstUtil.TUnit),
struct(Ref.Name.assertFromString("field0") -> AstUtil.TInt64),
struct(Ref.Name.assertFromString("field1") -> AstUtil.TUnit),
struct(
Ref.Name.assertFromString("field1") -> AstUtil.TUnit,
Ref.Name.assertFromString("field2") -> AstUtil.TUnit,
),
Struct(
struct(
Ref.Name.assertFromString("field1") -> AstUtil.TUnit,
Ref.Name.assertFromString("field2") -> AstUtil.TInt64,
),
Struct(
struct(
Ref.Name.assertFromString("field1") -> AstUtil.TInt64,
Ref.Name.assertFromString("field2") -> AstUtil.TUnit,
),
Struct(
struct(
Ref.Name.assertFromString("field1") -> AstUtil.TUnit,
Ref.Name.assertFromString("field3") -> AstUtil.TUnit,
),
@ -206,8 +208,8 @@ class OrderingSpec
Ast.TTyCon(VariantTypeCon),
Ast.TNat(Numeric.Scale.MinValue),
Ast.TNat(Numeric.Scale.MaxValue),
Ast.TStruct(Struct.empty),
Ast.TStruct(Struct(Ref.Name.assertFromString("field") -> AstUtil.TUnit)),
Ast.TStruct(Struct.Empty),
Ast.TStruct(struct(Ref.Name.assertFromString("field") -> AstUtil.TUnit)),
Ast.TApp(Ast.TBuiltin(Ast.BTArrow), Ast.TBuiltin(Ast.BTUnit)),
Ast.TApp(
Ast.TApp(Ast.TBuiltin(Ast.BTArrow), Ast.TBuiltin(Ast.BTUnit)),

View File

@ -36,6 +36,5 @@ da_scala_test(
":parser",
"//daml-lf/data",
"//daml-lf/language",
"@maven//:org_scalaz_scalaz_core_2_12",
],
)

View File

@ -42,7 +42,7 @@ private[daml] class AstRewriter(
case TForall(binder, body) =>
TForall(binder, apply(body))
case TStruct(fields) =>
TStruct(fields.mapValue(apply))
TStruct(fields.mapValues(apply))
}
def apply(nameWithType: (Name, Type)): (Name, Type) = nameWithType match {

View File

@ -57,7 +57,7 @@ private[parser] class TypeParser[P](parameters: ParserParameters[P]) {
id ~ `:` ~ typ ^^ { case name ~ _ ~ t => name -> t }
private lazy val tStruct: Parser[Type] =
`<` ~>! rep1sep(fieldType, `,`) <~ `>` ^^ (fs => TStruct(Struct(fs: _*)))
`<` ~>! rep1sep(fieldType, `,`) <~ `>` ^^ (fs => TStruct(Struct.assertFromSeq(fs)))
private lazy val tTypeSynApp: Parser[Type] =
`|` ~> fullIdentifier ~ rep(typ0) <~ `|` ^^ { case id ~ tys => TSynApp(id, ImmArray(tys)) }

View File

@ -98,7 +98,7 @@ class ParsersSpec extends WordSpec with TableDrivenPropertyChecks with Matchers
"a -> b -> a" -> TApp(TApp(TBuiltin(BTArrow), α), TApp(TApp(TBuiltin(BTArrow), β), α)),
"forall (a: *). Mod:T a" -> TForall((α.name, KStar), TApp(T, α)),
"<f1: a, f2: Bool, f3:Mod:T>" ->
TStruct(Struct(n"f1" -> α, n"f2" -> TBuiltin(BTBool), n"f3" -> T))
TStruct(Struct.assertFromSeq(List(n"f1" -> α, n"f2" -> TBuiltin(BTBool), n"f3" -> T)))
)
forEvery(testCases)((stringToParse, expectedType) =>

View File

@ -28,7 +28,7 @@ private[validation] object TypeSubst {
} else
TForall(v0 -> k, go(fv0 + v0, subst0 - v0, t))
case TStruct(ts) =>
TStruct(ts.mapValue(go(fv0, subst0, _)))
TStruct(ts.mapValues(go(fv0, subst0, _)))
}
private def freshTypeVarName(fv: Set[TypeVarName]): TypeVarName =

View File

@ -119,7 +119,8 @@ private[validation] object Typing {
BTextMapToList ->
TForall(
alpha.name -> KStar,
TTextMap(alpha) ->: TList(TStruct(Struct(keyFieldName -> TText, valueFieldName -> alpha)))
TTextMap(alpha) ->: TList(
TStruct(Struct.assertFromSeq(List(keyFieldName -> TText, valueFieldName -> alpha))))
),
BTextMapSize ->
TForall(
@ -452,7 +453,7 @@ private[validation] object Typing {
case TForall((v, k), b) =>
TForall((v, k), introTypeVar(v, k).expandTypeSynonyms(b))
case TStruct(recordType) =>
TStruct(recordType.mapValue(expandTypeSynonyms(_)))
TStruct(recordType.mapValues(expandTypeSynonyms(_)))
}
private def expandSynApp(syn: TypeSynName, tArgs: ImmArray[Type]): Type = {
@ -513,10 +514,10 @@ private[validation] object Typing {
throw EExpectedRecordType(ctx, typ0)
}
private def typeOfStructCon(fields: ImmArray[(FieldName, Expr)]): Type = {
checkUniq[FieldName](fields.keys, EDuplicateField(ctx, _))
TStruct(Struct(fields.iterator.map { case (f, x) => f -> typeOf(x) }.toSeq: _*))
}
// Types a struct construction: each field expression is typed, and the result is
// a TStruct unless a field name is duplicated, in which case EDuplicateField is thrown.
private def typeOfStructCon(fields: ImmArray[(FieldName, Expr)]): Type = {
  val typedFields = fields.iterator.map { case (fieldName, expr) => fieldName -> typeOf(expr) }.toSeq
  Struct.fromSeq(typedFields) match {
    case Right(structType) => TStruct(structType)
    case Left(duplicate) => throw EDuplicateField(ctx, duplicate)
  }
}
private def typeOfStructProj(field: FieldName, expr: Expr): Type = typeOf(expr) match {
case TStruct(structType) =>
@ -758,9 +759,11 @@ private[validation] object Typing {
// fetches return the contract id and the contract itself
TUpdate(
TStruct(
Struct(
(contractIdFieldName, TContractId(TTyCon(retrieveByKey.templateId))),
(contractFieldName, TTyCon(retrieveByKey.templateId)))))
Struct.assertFromSeq(
List(
contractIdFieldName -> TContractId(TTyCon(retrieveByKey.templateId)),
contractFieldName -> TTyCon(retrieveByKey.templateId)
))))
case UpdateLookupByKey(retrieveByKey) =>
checkRetrieveByKey(retrieveByKey)
TUpdate(TOptional(TContractId(TTyCon(retrieveByKey.templateId))))