LF: Improve safety of the serialization of proto messages. (#12686)

This is a follow-up to #12638, applied to the LF support for KV.

CHANGELOG_BEGIN
CHANGELOG_END
This commit is contained in:
Remy 2022-02-01 15:45:17 +01:00 committed by GitHub
parent b4ed15bab7
commit 183f936def
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 59 additions and 20 deletions

View File

@ -6,7 +6,10 @@ package archive
import java.io.File
sealed abstract class Error(val msg: String) extends RuntimeException(msg)
sealed abstract class Error(val msg: String)
extends RuntimeException(msg)
with Product
with Serializable
object Error {
@ -42,4 +45,6 @@ object Error {
extends Error(s"Unsupported file extension: ${file.getAbsolutePath}")
final case class Parsing(override val msg: String) extends Error(msg)
final case class Encoding(override val msg: String) extends Error(msg)
}

View File

@ -31,6 +31,7 @@ da_scala_library(
"//daml-lf/archive:daml_lf_archive_reader",
"//daml-lf/data",
"//daml-lf/language",
"//libs-scala/safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
],
)

View File

@ -1,10 +1,12 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.archive.testing
package com.daml.lf
package archive.testing
import com.daml.SafeProto
import java.security.MessageDigest
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.language.Ast.Package
import com.daml.lf.language.{LanguageMajorVersion, LanguageVersion}
@ -35,7 +37,7 @@ object Encode {
final def encodeArchive(pkg: (PackageId, Package), version: LanguageVersion): PLF.Archive = {
val payload = encodePayloadOfVersion(pkg, version).toByteString
val payload = data.assertRight(SafeProto.toByteString(encodePayloadOfVersion(pkg, version)))
val hash = PackageId.assertFromString(
MessageDigest.getInstance("SHA-256").digest(payload.toByteArray).map("%02x" format _).mkString
)

View File

@ -26,6 +26,7 @@ da_scala_library(
"//daml-lf/transaction",
"//daml-lf/transaction:transaction_proto_java",
"//daml-lf/transaction:value_proto_java",
"//libs-scala/safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
],
)

View File

@ -13,6 +13,8 @@ object ConversionError {
extends ConversionError(errorMessage)
final case class DecodeError(cause: ValueCoder.DecodeError)
extends ConversionError(cause.errorMessage)
final case class EncodeError(cause: ValueCoder.EncodeError)
extends ConversionError(cause.errorMessage)
final case class InternalError(override val errorMessage: String)
extends ConversionError(errorMessage)
}

View File

@ -3,6 +3,7 @@
package com.daml.lf.kv.archives
import com.daml.SafeProto
import com.daml.lf.archive.{ArchiveParser, Decode, Error => ArchiveError}
import com.daml.lf.data.Ref
import com.daml.lf.data.Ref.PackageId
@ -21,12 +22,15 @@ object ArchiveConversions {
def parsePackageIdsAndRawArchives(
archives: List[com.daml.daml_lf_dev.DamlLf.Archive]
): Either[ArchiveError.Parsing, Map[Ref.PackageId, RawArchive]] =
): Either[ArchiveError, Map[Ref.PackageId, RawArchive]] =
archives.partitionMap { archive =>
Ref.PackageId.fromString(archive.getHash).map(_ -> RawArchive(archive.toByteString))
for {
pkgId <- Ref.PackageId.fromString(archive.getHash).left.map(ArchiveError.Parsing)
bytes <- SafeProto.toByteString(archive).left.map(ArchiveError.Encoding)
} yield pkgId -> RawArchive(bytes)
} match {
case (Nil, hashesAndRawArchives) => Right(hashesAndRawArchives.toMap)
case (errors, _) => Left(ArchiveError.Parsing(errors.head))
case (errors, _) => Left(errors.head)
}
def decodePackages(

View File

@ -3,6 +3,7 @@
package com.daml.lf.kv.contracts
import com.daml.SafeProto
import com.daml.lf.kv.ConversionError
import com.daml.lf.transaction.{TransactionCoder, TransactionOuterClass}
import com.daml.lf.value.{Value, ValueCoder}
@ -14,9 +15,10 @@ object ContractConversions {
def encodeContractInstance(
coinst: Value.VersionedContractInstance
): Either[ValueCoder.EncodeError, RawContractInstance] =
TransactionCoder
.encodeContractInstance(ValueCoder.CidEncoder, coinst)
.map(contractInstance => RawContractInstance(contractInstance.toByteString))
for {
message <- TransactionCoder.encodeContractInstance(ValueCoder.CidEncoder, coinst)
bytes <- SafeProto.toByteString(message).left.map(ValueCoder.EncodeError(_))
} yield RawContractInstance(bytes)
def decodeContractInstance(
rawContractInstance: RawContractInstance

View File

@ -3,6 +3,7 @@
package com.daml.lf.kv.transactions
import com.daml.SafeProto
import com.daml.lf.data.{FrontStack, FrontStackCons, ImmArray}
import com.daml.lf.kv.ConversionError
import com.daml.lf.transaction.TransactionOuterClass.Node.NodeTypeCase
@ -25,9 +26,11 @@ object TransactionConversions {
def encodeTransaction(
tx: VersionedTransaction
): Either[ValueCoder.EncodeError, RawTransaction] =
TransactionCoder
.encodeTransaction(TransactionCoder.NidEncoder, ValueCoder.CidEncoder, tx)
.map(transaction => RawTransaction(transaction.toByteString))
for {
msg <-
TransactionCoder.encodeTransaction(TransactionCoder.NidEncoder, ValueCoder.CidEncoder, tx)
bytes <- SafeProto.toByteString(msg).left.map(ValueCoder.EncodeError(_))
} yield RawTransaction(bytes)
def decodeTransaction(
rawTx: RawTransaction
@ -63,7 +66,7 @@ object TransactionConversions {
def reconstructTransaction(
transactionVersion: String,
nodesWithIds: Seq[TransactionNodeIdWithNode],
): Either[ConversionError.ParseError, RawTransaction] = {
): Either[ConversionError, RawTransaction] = {
import scalaz.std.either._
import scalaz.std.list._
import scalaz.syntax.traverse._
@ -94,7 +97,14 @@ object TransactionConversions {
}
.toList
.sequence_
.map(_ => RawTransaction(transactionBuilder.build.toByteString))
.flatMap(_ =>
SafeProto.toByteString(transactionBuilder.build()) match {
case Right(bytes) =>
Right(RawTransaction(bytes))
case Left(msg) =>
Left(ConversionError.EncodeError(ValueCoder.EncodeError(msg)))
}
)
}
/** Decodes and extracts outputs of a submitted transaction, that is the IDs and keys of contracts created or updated
@ -210,7 +220,7 @@ object TransactionConversions {
}
}
goNodesToKeep(transaction.getRootsList.asScala.to(FrontStack), Set.empty).map {
goNodesToKeep(transaction.getRootsList.asScala.to(FrontStack), Set.empty).flatMap {
nodesToKeep =>
val filteredRoots = transaction.getRootsList.asScala.filter(nodesToKeep)
@ -239,7 +249,14 @@ object TransactionConversions {
.addAllNodes(filteredNodes.asJavaCollection)
.setVersion(transaction.getVersion)
.build()
RawTransaction(newTransaction.toByteString)
SafeProto.toByteString(newTransaction) match {
case Right(bytes) =>
Right(RawTransaction(bytes))
case Left(msg) =>
// Should not happen, as removing nodes should result in a smaller transaction.
Left(ConversionError.InternalError(msg))
}
}
}
}

View File

@ -47,6 +47,11 @@ object TransactionTraversal {
case Left(error) => Left(ConversionError.DecodeError(error))
case Right(nodeWitnesses) =>
val witnesses = parentWitnesses union nodeWitnesses
// Here node.toByteString is safe.
// Indeed, node is a submessage of the transaction `rawTx` that we received in serialized
// form as input of `traverseTransactionWithWitnesses` and successfully decoded, i.e.
// `rawTx` requires less than 2GB to be serialized, and so does `node`.
// See com.daml.SafeProto for more details about issues with the toByteString method.
f(nodeId, RawTransaction.Node(node.toByteString), witnesses)
// Recurse into children (if any).
node.getNodeTypeCase match {
@ -62,7 +67,7 @@ object TransactionTraversal {
}
}
private def informeesOfNode(
private[this] def informeesOfNode(
txVersion: TransactionVersion,
node: TransactionOuterClass.Node,
): Either[ValueCoder.DecodeError, Set[Ref.Party]] =

View File

@ -18,7 +18,7 @@ da_scala_library(
da_scala_test_suite(
name = "safe-protot-test",
srcs = glob(["src/test/scala/**/*.scala"]),
max_heap_size = "4g",
max_heap_size = "3g",
deps = [
":safe-proto",
"@maven//:com_google_protobuf_protobuf_java",

View File

@ -28,7 +28,7 @@ object SafeProto {
case e: RuntimeException
if e.isInstanceOf[NegativeArraySizeException] ||
e.getCause != null && e.getCause.isInstanceOf[CodedOutputStream.OutOfSpaceException] =>
Left(s"the ${message.getClass.getName} message is too big to be serialized")
Left(s"the ${message.getClass.getName} message is too large to be serialized")
}
def toByteString(message: AbstractMessageLite[_, _]): Either[String, ByteString] =