From 62c9864514d736e1987015877735658ff95a4660 Mon Sep 17 00:00:00 2001 From: Carl Pulley <106966370+carlpulley-da@users.noreply.github.com> Date: Tue, 29 Aug 2023 20:46:01 +0100 Subject: [PATCH] Verification of the Contract State Machine in UCK mode (#17293) (#17312) PR of Andrea Gilot's work from #17293 --------- Co-authored-by: Andrea Gilot <126685183+andreagilot-da@users.noreply.github.com> --- canton_dep.bzl | 4 +- .../lf/transaction/PartialTransaction.scala | 6 +- .../lf/transaction/ContractStateMachine.scala | 121 +- .../daml/lf/transaction/Transaction.scala | 5 +- .../ContractStateMachineSpec.scala | 40 +- daml-lf/verification/README.md | 62 + daml-lf/verification/latex/proof.tex | 899 +++++++ .../scripts/stainless_imports.txt | 27 + .../scripts/verification_script.sh | 85 + daml-lf/verification/stainless.nix | 17 + .../verification/transaction/CSMAdvance.scala | 307 +++ .../verification/transaction/CSMEither.scala | 673 ++++++ .../verification/transaction/CSMHelpers.scala | 174 ++ .../transaction/CSMInconsistency.scala | 698 ++++++ .../transaction/CSMInvariant.scala | 825 +++++++ .../transaction/CSMKeysProperties.scala | 1096 +++++++++ .../CSMLocallyCreatedProperties.scala | 198 ++ .../transaction/ContractStateMachineAlt.scala | 334 +++ .../translation/CSMConversion.scala | 626 +++++ .../verification/tree/TransactionTree.scala | 532 +++++ .../tree/TransactionTreeAdvance.scala | 480 ++++ .../tree/TransactionTreeChecks.scala | 1175 +++++++++ .../tree/TransactionTreeFull.scala | 1004 ++++++++ .../tree/TransactionTreeInconsistency.scala | 1086 +++++++++ .../tree/TransactionTreeInvariant.scala | 270 +++ .../tree/TransactionTreeKeys.scala | 1109 +++++++++ daml-lf/verification/utils/AxiomaticMap.scala | 257 ++ daml-lf/verification/utils/AxiomaticSet.scala | 287 +++ daml-lf/verification/utils/GlobalKey.scala | 14 + daml-lf/verification/utils/Helpers.scala | 45 + .../utils/InvListProperties.scala | 184 ++ .../verification/utils/MapProperties.scala | 1104 +++++++++ daml-lf/verification/utils/Node.scala | 57 + .../verification/utils/SetProperties.scala | 2106 +++++++++++++++++ daml-lf/verification/utils/Transaction.scala | 57 + daml-lf/verification/utils/Tree.scala | 655 +++++ daml-lf/verification/utils/Value.scala | 14 + 37 files changed, 16548 insertions(+), 85 deletions(-) create mode 100644 daml-lf/verification/README.md create mode 100644 daml-lf/verification/latex/proof.tex create mode 100644 daml-lf/verification/scripts/stainless_imports.txt create mode 100755 daml-lf/verification/scripts/verification_script.sh create mode 100644 daml-lf/verification/stainless.nix create mode 100644 daml-lf/verification/transaction/CSMAdvance.scala create mode 100644 daml-lf/verification/transaction/CSMEither.scala create mode 100644 daml-lf/verification/transaction/CSMHelpers.scala create mode 100644 daml-lf/verification/transaction/CSMInconsistency.scala create mode 100644 daml-lf/verification/transaction/CSMInvariant.scala create mode 100644 daml-lf/verification/transaction/CSMKeysProperties.scala create mode 100644 daml-lf/verification/transaction/CSMLocallyCreatedProperties.scala create mode 100644 daml-lf/verification/transaction/ContractStateMachineAlt.scala create mode 100644 daml-lf/verification/translation/CSMConversion.scala create mode 100644 daml-lf/verification/tree/TransactionTree.scala create mode 100644 daml-lf/verification/tree/TransactionTreeAdvance.scala create mode 100644 daml-lf/verification/tree/TransactionTreeChecks.scala create mode 100644 daml-lf/verification/tree/TransactionTreeFull.scala create mode 100644 daml-lf/verification/tree/TransactionTreeInconsistency.scala create mode 100644 daml-lf/verification/tree/TransactionTreeInvariant.scala create mode 100644 daml-lf/verification/tree/TransactionTreeKeys.scala create mode 100644 daml-lf/verification/utils/AxiomaticMap.scala create mode 100644 daml-lf/verification/utils/AxiomaticSet.scala create mode 100644 daml-lf/verification/utils/GlobalKey.scala create mode 100644 daml-lf/verification/utils/Helpers.scala create mode 100644 daml-lf/verification/utils/InvListProperties.scala create mode 100644 daml-lf/verification/utils/MapProperties.scala create mode 100644 daml-lf/verification/utils/Node.scala create mode 100644 daml-lf/verification/utils/SetProperties.scala create mode 100644 daml-lf/verification/utils/Transaction.scala create mode 100644 daml-lf/verification/utils/Tree.scala create mode 100644 daml-lf/verification/utils/Value.scala diff --git a/canton_dep.bzl b/canton_dep.bzl index 7182eebf30..ad475a794e 100644 --- a/canton_dep.bzl +++ b/canton_dep.bzl @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 canton = { - "sha": "bc2fd1611ffbc3b990909743f3d6dd9e98af071e5de00f6f54719b76c852fdbd", - "url": "https://www.canton.io/releases/canton-open-source-2.8.0-snapshot.20230828.11069.0.v0819eb02.tar.gz", + "sha": "9cf983138e0a579b9cc713c6b0cd971c76be9c34031c2218f0fcecf4678d7580", + "url": "https://www.canton.io/releases/canton-open-source-2.8.0-snapshot.20230829.11082.0.v69754d9b.tar.gz", "local": False, } diff --git a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/transaction/PartialTransaction.scala b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/transaction/PartialTransaction.scala index 30bdf35b0e..9efd4b5b9e 100644 --- a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/transaction/PartialTransaction.scala +++ b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/transaction/PartialTransaction.scala @@ -187,7 +187,7 @@ private[lf] object PartialTransaction { nodes = HashMap.empty, actionNodeSeeds = BackStack.empty, context = Context(initialSeeds, committers), - contractState = new ContractStateMachine[NodeId](contractKeyUniqueness).initial, + contractState = ContractStateMachine.initial[NodeId](contractKeyUniqueness), actionNodeLocations = BackStack.empty, authorizationChecker = authorizationChecker, ) @@ -221,7 +221,7 @@ private[speedy] case class PartialTransaction( nodes: HashMap[NodeId, Node], actionNodeSeeds: BackStack[crypto.Hash], context: PartialTransaction.Context, - contractState: ContractStateMachine[NodeId]#State, + contractState: ContractStateMachine.State[NodeId], actionNodeLocations: BackStack[Option[Location]], authorizationChecker: AuthorizationChecker, ) { @@ -652,7 +652,7 @@ private[speedy] case class PartialTransaction( node: Node.LeafOnlyAction, version: TxVersion, optLocation: Option[Location], - newContractState: ContractStateMachine[NodeId]#State, + newContractState: ContractStateMachine.State[NodeId], ): PartialTransaction = { val _ = version val nid = NodeId(nextNodeIdx) diff --git a/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/ContractStateMachine.scala b/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/ContractStateMachine.scala index a2abde2ed9..e845b8c0ad 100644 --- a/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/ContractStateMachine.scala +++ b/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/ContractStateMachine.scala @@ -31,10 +31,7 @@ import com.daml.lf.value.Value.ContractId * [[com.daml.lf.transaction.ContractKeyUniquenessMode.Strict]] and * @see ContractStateMachineSpec.visitSubtree for iteration in all modes */ -class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { - import ContractStateMachine._ - - def initial: State = State.empty +object ContractStateMachine { /** @param locallyCreated * Tracks all contracts created by a node processed so far (including nodes under a rollback). @@ -82,11 +79,12 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { * keyset are in [[activeState]].[[ActiveLedgerState.keys]], * and similarly for all [[ActiveLedgerState]]s in [[rollbackStack]]. */ - case class State private ( + case class State[Nid] private[lf] ( locallyCreated: Set[ContractId], globalKeyInputs: Map[GlobalKey, KeyInput], - activeState: ActiveLedgerState[Nid], - rollbackStack: List[ActiveLedgerState[Nid]], + activeState: ContractStateMachine.ActiveLedgerState[Nid], + rollbackStack: List[ContractStateMachine.ActiveLedgerState[Nid]], + mode: ContractKeyUniquenessMode, ) { /** The return value indicates if the given contract is either consumed, inactive, or otherwise @@ -106,8 +104,6 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { } } - def mode: ContractKeyUniquenessMode = ContractStateMachine.this.mode - /** Lookup the given key k. Returns * - Some(KeyActive(cid)) if k maps to cid and cid is active. * - Some(KeyInactive) if there is no active contract with the given key. @@ -124,13 +120,13 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { } /** Visit a create node */ - def handleCreate(node: Node.Create): Either[KeyInputError, State] = + def handleCreate(node: Node.Create): Either[KeyInputError, State[Nid]] = visitCreate(node.coid, node.gkeyOpt).left.map(Right(_)) private[lf] def visitCreate( contractId: ContractId, mbKey: Option[GlobalKey], - ): Either[DuplicateContractKey, State] = { + ): Either[DuplicateContractKey, State[Nid]] = { val me = this.copy( locallyCreated = locallyCreated + contractId, @@ -160,7 +156,7 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { } } - def handleExercise(nid: Nid, exe: Node.Exercise): Either[KeyInputError, State] = + def handleExercise(nid: Nid, exe: Node.Exercise): Either[KeyInputError, State[Nid]] = visitExercise( nid, exe.targetCoid, @@ -180,7 +176,7 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { mbKey: Option[GlobalKey], byKey: Boolean, consuming: Boolean, - ): Either[InconsistentContractKey, State] = { + ): Either[InconsistentContractKey, State[Nid]] = { for { state <- if (byKey || mode == ContractKeyUniquenessMode.Strict) @@ -197,7 +193,7 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { /** Must be used to handle lookups iff in [[com.daml.lf.transaction.ContractKeyUniquenessMode.Strict]] mode */ - def handleLookup(lookup: Node.LookupByKey): Either[KeyInputError, State] = { + def handleLookup(lookup: Node.LookupByKey): Either[KeyInputError, State[Nid]] = { // If the key has not yet been resolved, we use the resolution from the lookup node, // but this only makes sense if `activeState.keys` is updated by every node and not only by by-key nodes. if (mode != ContractKeyUniquenessMode.Strict) @@ -221,7 +217,7 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { def handleLookupWith( lookup: Node.LookupByKey, keyInput: Option[ContractId], - ): Either[KeyInputError, State] = { + ): Either[KeyInputError, State[Nid]] = { if (mode != ContractKeyUniquenessMode.Off) throw new UnsupportedOperationException( "handleLookupWith can only be used if only by-key nodes are considered" @@ -233,7 +229,7 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { gk: GlobalKey, keyInput: Option[ContractId], keyResolution: Option[ContractId], - ): Either[InconsistentContractKey, State] = { + ): Either[InconsistentContractKey, State[Nid]] = { val (keyMapping, next) = resolveKey(gk) match { case Right(result) => result case Left(handle) => handle(keyInput) @@ -247,13 +243,13 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { private[lf] def resolveKey( gkey: GlobalKey - ): Either[Option[ContractId] => (KeyMapping, State), (KeyMapping, State)] = { + ): Either[Option[ContractId] => (KeyMapping, State[Nid]), (KeyMapping, State[Nid])] = { lookupActiveKey(gkey) match { case Some(keyMapping) => Right(keyMapping -> this) case None => // if we cannot find it here, send help, and make sure to update keys after // that. - def handleResult(result: Option[ContractId]): (KeyMapping, State) = { + def handleResult(result: Option[ContractId]): (KeyMapping, State[Nid]) = { // Update key inputs. Create nodes never call this method, // so NegativeKeyLookup is the right choice for the global key input. val keyInput = result match { @@ -273,23 +269,23 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { } } - def handleFetch(node: Node.Fetch): Either[KeyInputError, State] = + def handleFetch(node: Node.Fetch): Either[KeyInputError, State[Nid]] = visitFetch(node.coid, node.gkeyOpt, node.byKey).left.map(Left(_)) private[lf] def visitFetch( contractId: ContractId, mbKey: Option[GlobalKey], byKey: Boolean, - ): Either[InconsistentContractKey, State] = + ): Either[InconsistentContractKey, State[Nid]] = if (byKey || mode == ContractKeyUniquenessMode.Strict) assertKeyMapping(contractId, mbKey) else Right(this) - private[this] def assertKeyMapping( + private[lf] def assertKeyMapping( cid: Value.ContractId, mbKey: Option[GlobalKey], - ): Either[InconsistentContractKey, State] = + ): Either[InconsistentContractKey, State[Nid]] = mbKey match { case None => Right(this) case Some(gk) => @@ -309,7 +305,7 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { id: Nid, node: Node.Action, keyInput: => Option[ContractId], - ): Either[KeyInputError, State] = node match { + ): Either[KeyInputError, State[Nid]] = node match { case create: Node.Create => handleCreate(create) case fetch: Node.Fetch => handleFetch(fetch) case lookup: Node.LookupByKey => @@ -324,13 +320,13 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { /** To be called when interpretation enters a try block or iteration enters a Rollback node * Must be matched by [[endRollback]] or [[dropRollback]]. */ - def beginRollback(): State = - this.copy(rollbackStack = activeState +: rollbackStack) + def beginRollback(): State[Nid] = + this.copy(rollbackStack = activeState :: rollbackStack) /** To be called when interpretation does insert a Rollback node or iteration leaves a Rollback node. * Must be matched by a [[beginRollback]]. */ - def endRollback(): State = rollbackStack match { + def endRollback(): State[Nid] = rollbackStack match { case Nil => throw new IllegalStateException("Not inside a rollback scope") case headState :: tailStack => this.copy(activeState = headState, rollbackStack = tailStack) } @@ -338,12 +334,12 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { /** To be called if interpretation notices that a try block did not lead to a Rollback node * Must be matched by a [[beginRollback]]. */ - def dropRollback(): State = rollbackStack match { + def dropRollback(): State[Nid] = rollbackStack match { case Nil => throw new IllegalStateException("Not inside a rollback scope") case _ :: tailStack => this.copy(rollbackStack = tailStack) } - private def withinRollbackScope: Boolean = rollbackStack.nonEmpty + private[lf] def withinRollbackScope: Boolean = rollbackStack.nonEmpty /** Let `this` state be the result of iterating over a transaction `tx` until just before a node `n`. * Let `substate` be the state obtained after fully iterating over the subtree rooted at `n` @@ -368,28 +364,32 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { * [[com.daml.lf.transaction.ContractKeyUniquenessMode.Strict]] and * @see ContractStateMachineSpec.visitSubtree for iteration in all modes */ - def advance(resolver: KeyResolver, substate: State): Either[KeyInputError, State] = { + def advance(resolver: KeyResolver, substate: State[Nid]): Either[KeyInputError, State[Nid]] = { require( !substate.withinRollbackScope, "Cannot lift a state over a substate with unfinished rollback scopes", ) // We want consistent key lookups within an action in any contract key mode. - def consistentGlobalKeyInputs: Either[KeyInputError, Unit] = { + def consistentGlobalKeyInputs: Either[KeyInputError, Unit] = substate.globalKeyInputs - .collectFirst { - case (key, KeyCreate) - if lookupActiveKey(key).exists(_ != KeyInactive) && - mode == ContractKeyUniquenessMode.Strict => - Right(DuplicateContractKey(key)) - case (key, NegativeKeyLookup) if lookupActiveKey(key).exists(_ != KeyInactive) => - Left(InconsistentContractKey(key)) - case (key, Transaction.KeyActive(cid)) - if lookupActiveKey(key).exists(_ != KeyActive(cid)) => - Left(InconsistentContractKey(key)) - } - .toLeft(()) - } + .find { + case (key, KeyCreate) => + lookupActiveKey(key).exists( + _ != KeyInactive + ) && mode == ContractKeyUniquenessMode.Strict + case (key, NegativeKeyLookup) => lookupActiveKey(key).exists(_ != KeyInactive) + case (key, Transaction.KeyActive(cid)) => + lookupActiveKey(key).exists(_ != KeyActive(cid)) + case _ => false + } match { + case Some((key, KeyCreate)) => Left[KeyInputError, Unit](Right(DuplicateContractKey(key))) + case Some((key, NegativeKeyLookup)) => + Left[KeyInputError, Unit](Left(InconsistentContractKey(key))) + case Some((key, Transaction.KeyActive(_))) => + Left[KeyInputError, Unit](Left(InconsistentContractKey(key))) + case _ => Right[KeyInputError, Unit](()) + } for { _ <- consistentGlobalKeyInputs @@ -446,11 +446,16 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) { } object State { - val empty: State = new State(Set.empty, Map.empty, ActiveLedgerState.empty, List.empty) + def empty[Nid](mode: ContractKeyUniquenessMode): State[Nid] = new State( + Set.empty, + Map.empty, + ContractStateMachine.ActiveLedgerState.empty, + List.empty, + mode, + ) } -} -object ContractStateMachine { + def initial[Nid](mode: ContractKeyUniquenessMode): State[Nid] = State.empty(mode) /** Represents the answers for [[com.daml.lf.engine.ResultNeedKey]] requests * that may arise during Daml interpretation. @@ -479,12 +484,12 @@ object ContractStateMachine { * was consumed or not. That information is stored in consumedBy. * It also _only_ includes local contracts not global contracts. */ - final case class ActiveLedgerState[+Nid]( + final case class ActiveLedgerState[Nid]( locallyCreatedThisTimeline: Set[ContractId], consumedBy: Map[ContractId, Nid], - private val localKeys: Map[GlobalKey, Value.ContractId], + private[lf] val localKeys: Map[GlobalKey, Value.ContractId], ) { - def consume[Nid2 >: Nid](contractId: ContractId, nodeId: Nid2): ActiveLedgerState[Nid2] = + def consume(contractId: ContractId, nodeId: Nid): ActiveLedgerState[Nid] = this.copy(consumedBy = consumedBy.updated(contractId, nodeId)) def createKey(key: GlobalKey, cid: Value.ContractId): ActiveLedgerState[Nid] = @@ -492,16 +497,16 @@ object ContractStateMachine { /** Equivalence relative to locallyCreatedThisTimeline, consumedBy & localActiveKeys. */ - def isEquivalent[Nid2 >: Nid](other: ActiveLedgerState[Nid2]): Boolean = + def isEquivalent(other: ActiveLedgerState[Nid]): Boolean = this.locallyCreatedThisTimeline == other.locallyCreatedThisTimeline && this.consumedBy == other.consumedBy && this.localActiveKeys == other.localActiveKeys /** See docs of [[ContractStateMachine.advance]] */ - private[ContractStateMachine] def advance[Nid2 >: Nid]( - substate: ActiveLedgerState[Nid2] - ): ActiveLedgerState[Nid2] = + private[lf] def advance( + substate: ActiveLedgerState[Nid] + ): ActiveLedgerState[Nid] = ActiveLedgerState( locallyCreatedThisTimeline = this.locallyCreatedThisTimeline .union(substate.locallyCreatedThisTimeline), @@ -512,9 +517,9 @@ object ContractStateMachine { /** localKeys filter by whether contracts have been consumed already. */ def localActiveKeys: Map[GlobalKey, KeyMapping] = - localKeys.map { case (k, v) => - k -> (if (consumedBy.contains(v)) KeyInactive else KeyActive(v)) - } + localKeys.view + .mapValues((v: ContractId) => if (consumedBy.contains(v)) KeyInactive else KeyActive(v)) + .toMap /** Lookup in localActiveKeys. */ @@ -527,9 +532,7 @@ object ContractStateMachine { } object ActiveLedgerState { - private val EMPTY: ActiveLedgerState[Nothing] = - ActiveLedgerState(Set.empty, Map.empty, Map.empty) - def empty[Nid]: ActiveLedgerState[Nid] = EMPTY + def empty[Nid]: ActiveLedgerState[Nid] = ActiveLedgerState(Set.empty, Map.empty, Map.empty) } } diff --git a/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/Transaction.scala b/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/Transaction.scala index ba61a95cc8..01d459f276 100644 --- a/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/Transaction.scala +++ b/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/Transaction.scala @@ -490,9 +490,8 @@ sealed abstract class HasTxNodes { "If a contract key contains a contract id" ) def contractKeyInputs: Either[KeyInputError, Map[GlobalKey, KeyInput]] = { - val machine = new ContractStateMachine[NodeId](mode = ContractKeyUniquenessMode.Strict) - foldInExecutionOrder[Either[KeyInputError, machine.State]]( - Right(machine.initial) + foldInExecutionOrder[Either[KeyInputError, ContractStateMachine.State[NodeId]]]( + Right(ContractStateMachine.initial[NodeId](ContractKeyUniquenessMode.Strict)) )( exerciseBegin = (acc, nid, exe) => (acc.flatMap(_.handleExercise(nid, exe)), Transaction.ChildrenRecursion.DoRecurse), diff --git a/daml-lf/transaction/src/test/scala/com/digitalasset/daml/lf/transaction/ContractStateMachineSpec.scala b/daml-lf/transaction/src/test/scala/com/digitalasset/daml/lf/transaction/ContractStateMachineSpec.scala index 11472b7109..a5c1ef75b6 100644 --- a/daml-lf/transaction/src/test/scala/com/digitalasset/daml/lf/transaction/ContractStateMachineSpec.scala +++ b/daml-lf/transaction/src/test/scala/com/digitalasset/daml/lf/transaction/ContractStateMachineSpec.scala @@ -162,7 +162,7 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive val tx = builder.build() val expected = Right( Map(gkey("key1") -> KeyCreate) -> - ActiveLedgerState(Set(1), Map.empty, Map(gkey("key1") -> 1)) + ActiveLedgerState[Unit](Set(1), Map.empty, Map(gkey("key1") -> 1)) ) TestCase( "Create|Rb-Ex-LBK|LBK", @@ -331,7 +331,7 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive val tx = builder.build() val expectedOff = Right( Map(gkey("key1") -> KeyCreate) -> - ActiveLedgerState( + ActiveLedgerState[Unit]( Set(2, 3), Map.empty, Map( @@ -358,7 +358,7 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive val tx = builder.build() val expected = Right( Map(gkey("key1") -> NegativeKeyLookup) -> - ActiveLedgerState(Set.empty, Map.empty, Map.empty) + ActiveLedgerState[Unit](Set.empty, Map.empty, Map.empty) ) TestCase( "DivulgedLookup", @@ -427,7 +427,7 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive val tx = builder.build() val expected = Right( Map(gkey("key1") -> KeyCreate) -> - ActiveLedgerState(Set(3), Map.empty, Map(gkey("key1") -> cid(3))) + ActiveLedgerState[Unit](Set(3), Map.empty, Map(gkey("key1") -> cid(3))) ) TestCase( "CreateAfterRbExercise", @@ -475,7 +475,7 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive val tx = builder.build() val expectedOff = Right( Map(gkey("key1") -> KeyCreate, gkey("key2") -> KeyCreate) -> - ActiveLedgerState( + ActiveLedgerState[Unit]( Set(1, 3, 4, 5), Map.empty, Map(gkey("key1") -> cid(4), gkey("key2") -> cid(5)), @@ -561,10 +561,14 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive expected.foreach { case (mode, expectedResult) => s"mode $mode" in { // We use `Unit` instead of `NodeId` so that we don't have to fiddle with node ids - val ksm = new ContractStateMachine[Unit](mode) val actualResolver: KeyResolver = if (mode == ContractKeyUniquenessMode.Strict) Map.empty else resolver - val result = visitSubtrees(ksm)(tx.nodes, tx.roots.toSeq, actualResolver, ksm.initial) + val result = visitSubtrees( + tx.nodes, + tx.roots.toSeq, + actualResolver, + ContractStateMachine.initial[Unit](mode), + ) (result, expectedResult) match { case (Left(err1), Left(err2)) => err1 shouldBe err2 @@ -652,12 +656,12 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive * for handling [[com.daml.lf.transaction.Node.LookupByKey]]. * Ignored in mode [[com.daml.lf.transaction.ContractKeyUniquenessMode.Strict]]. */ - private def visitSubtree(ksm: ContractStateMachine[Unit])( + private def visitSubtree( nodes: Map[NodeId, Node], root: NodeId, resolver: KeyResolver, - state: ksm.State, - ): Either[Transaction.KeyInputError, ksm.State] = { + state: ContractStateMachine.State[Unit], + ): Either[Transaction.KeyInputError, ContractStateMachine.State[Unit]] = { val node = nodes(root) for { next <- node match { @@ -667,7 +671,7 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive Right(state.beginRollback()) } afterChildren <- withClue(s"visiting children of $node") { - visitSubtrees(ksm)(nodes, children(node).toSeq, resolver, next) + visitSubtrees(nodes, children(node).toSeq, resolver, next) } exited = node match { case _: Node.Rollback => afterChildren.endRollback() @@ -682,26 +686,26 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive * * @see visitSubtree for how visiting nodes updates the state */ - private def visitSubtrees(ksm: ContractStateMachine[Unit])( + private def visitSubtrees( nodes: Map[NodeId, Node], roots: Seq[NodeId], resolver: KeyResolver, - state: ksm.State, - ): Either[Transaction.KeyInputError, ksm.State] = { + state: ContractStateMachine.State[Unit], + ): Either[Transaction.KeyInputError, ContractStateMachine.State[Unit]] = { roots match { case Seq() => Right(state) case root +: tail => val node = nodes(root) - val directVisit = visitSubtree(ksm)(nodes, root, resolver, state) + val directVisit = visitSubtree(nodes, root, resolver, state) // Now project the resolver and visit the subtree from a fresh state and check whether we end up the same using advance - val fresh = ksm.initial + val fresh = ContractStateMachine.initial[Unit](state.mode) val projectedResolver: KeyResolver = if (state.mode == ContractKeyUniquenessMode.Strict) Map.empty else state.projectKeyResolver(resolver) withClue( s"Advancing over subtree rooted at $node with projected resolver $projectedResolver; projection state=$state; original resolver=$resolver" ) { - val freshVisit = visitSubtree(ksm)(nodes, root, projectedResolver, fresh) + val freshVisit = visitSubtree(nodes, root, projectedResolver, fresh) val advanced = freshVisit.flatMap(substate => state.advance(resolver, substate)) (directVisit, advanced) match { @@ -715,7 +719,7 @@ class ContractStateMachineSpec extends AnyWordSpec with Matchers with TableDrive case _ => fail(s"$directVisit was knot equal to $advanced") } } - directVisit.flatMap(next => visitSubtrees(ksm)(nodes, tail, resolver, next)) + directVisit.flatMap(next => visitSubtrees(nodes, tail, resolver, next)) } } } diff --git a/daml-lf/verification/README.md b/daml-lf/verification/README.md new file mode 100644 index 0000000000..892b891ebb --- /dev/null +++ b/daml-lf/verification/README.md @@ -0,0 +1,62 @@ +# Proof of the advance method of the Contract State Machine in UCK mode + +Formal verification in [Stainless](https://stainless.epfl.ch/) of the advance method of the [ContractStateMachine](../transaction/src/main/scala/com/digitalasset/daml/lf/transaction/ContractStateMachine.scala) in strict mode. + +More precisely, under reasonable assumptions, if $tr$ is a transaction, $init$ a State, and $\textup{traverse}(tr, \varepsilon)$ is well-defined then: + + $$\textup{traverse}(tr, init) = init.\textup{advance}(\textup{traverse}(tr, \varepsilon))$$ + + where $\varepsilon$ is the empty State. + +A pen and paper proof of the main component of the verification can be found [here](latex/proof.pdf). + +## Components + +- `latex` : the pen and paper proof. +- `scripts` : the verification scripts and its helpers. +- `transaction` : all the proofs related to how the CSM handles a node. +- `translation` : original file and proofs of the translation to a simplified version of it +- `tree` : all the proofs related to how the CSM handles a transaction. +- `utils` : helpers and theorems about collections. + +## Developer Environment Setup + +``` nix-shell ./stainless.nix -A stainlessEnv ``` + +## Build + +To build the Stainless version used in the proof: + + 1. Clone [Stainless repo](https://github.com/epfl-lara/stainless) + 2. Run sbt universal:stage + 3. The generated binary can be found at the following location (in the following, this path is referred to as ``): + `$STAINLESS_REPO_ROOT/frontends/dotty/target/universal/stage/bin/stainless-dotty` + 4. The verification currently works with JDK 17 + + +## Verification + +The verification happens in 2 major steps: + - The translation from the original file to a [simplified version](transaction/ContractStateMachineAlt.scala) that is easy to work with in stainless and a proof of the soudness of this translation. + - The verification of the property + +To verify the former you can execute the verification script with the following argument: + +``` scripts/verification_script.sh translate``` + +The scripts takes the original files and uses a regex to create a temporary copy of the file that modifies the imports, +removes the exceptions and in general features that are not yet supported in Stainless. + + + + To verify the latter, you can either execute the following command: + +```stainless utils/* transaction/* tree/* --watch=false --timeout=30 --vc-cache=false --compact=true --solvers=nativez3``` + +or execute the script with the following argument: + +``` scripts/verification_script.sh verify``` + +If in the first command you find that the timeout is too big you can reduce it as you wish (it is recommended to keep it above 10 if you don't want to have any suprises). + + diff --git a/daml-lf/verification/latex/proof.tex b/daml-lf/verification/latex/proof.tex new file mode 100644 index 0000000000..2663e9f9cd --- /dev/null +++ b/daml-lf/verification/latex/proof.tex @@ -0,0 +1,899 @@ +% Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +% SPDX-License-Identifier: Apache-2.0 + +\documentclass{article} + +\usepackage[margin=2.5cm]{geometry} + +\usepackage{amsmath} +\usepackage{amssymb} +\usepackage{amsthm} +\usepackage{amsthm} +\usepackage{parskip} +\usepackage{hyperref} +\usepackage{cleveref} + +\newtheorem{lemma}{Lemma} +\crefformat{lemma}{Lem.~#2#1#3} +\newtheorem{theorem}{Theorem} +\crefformat{theorem}{Thm.~#2#1#3} +\newtheorem{definition}[lemma]{Definition} +\crefformat{definition}{Def.~#2#1#3} +\newtheorem{claim}[lemma]{Claim} +\crefformat{claim}{Claim~#2#1#3} +\newtheorem{corollary}[lemma]{Corollary} +\crefformat{corollary}{Cor.~#2#1#3} +\newtheorem*{res}{Final result} + +\newtheorem*{term}{Terminology} + +\newcommand{\textfun}[1]{\textup{#1}} +\newcommand{\textcode}[1]{\texttt{#1}} +\newcommand{\bolddef}[1]{\textbf{\ensuremath{\mathbf{#1}}}} +\newcommand{\negat}[1]{\neg\,#1} + +\newcommand{\fdown}[2]{\ensuremath{f_{down}(#1, #2)}} +\newcommand{\fup}[2]{\ensuremath{f_{up}(#1, #2)}} +\newcommand{\fone}[2]{\ensuremath{f_{1}(#1, #2)}} +\newcommand{\ftwo}[2]{\ensuremath{f_{2}(#1, #2)}} +\newcommand{\fcoll}[2]{\ensuremath{f_{collect}(#1, #2)}} + +\newcommand{\emptyList}{\ensuremath{[\ ]}} +\newcommand{\concat}{{{\ ++\ }}} + +% NODES % +\newcommand{\fetchNode}[2]{\textfun{Fetch} \ #1\ #2} +\newcommand{\lookupNode}[2]{\textfun{Lookup}\ #1\ #2} +\newcommand{\createNode}[2]{\textfun{Create}\ #1\ #2} +\newcommand{\exNode}[2]{\textfun{Exercise}\ #1\ #2} + + +%TREES +\newcommand{\nilNode}{\textfun{Endpoint}} +\newcommand{\contentNode}[2]{\textfun{ContentNode}\ #1\ #2} +\newcommand{\artNode}[2]{\textfun{ArticulationNode}\ #1\ #2} + +\newcommand{\longtraverse}[4]{\textfun{traverse}\ #1\ #2\ #3 \ #4} +\newcommand{\traverse}[2]{\textfun{traverse}\ #1\ #2} +\newcommand{\longscan}[4]{\textfun{scan}\ #1\ #2\ #3 \ #4} +\newcommand{\scan}[2]{\textfun{scan}\ #1\ #2} +\newcommand{\collect}[1]{\textfun{collect}\ #1} +\newcommand{\collectTr}[1]{\textfun{collectTrace}\ #1} + + +\newcommand{\up}{\uparrow} +\newcommand{\down}{\downarrow} + +\newcommand{\hNode}[2]{\textfun{handleNode}\ #1\ #2} +\newcommand{\beginRb}[1]{\textfun{beginRollback}\ #1} +\newcommand{\enRb}[1]{\textfun{endRollback}\ #1} +\newcommand{\defined}[1]{\textfun{defined}\ #1} +\newcommand{\get}[1]{\textfun{get}\ #1} +\newcommand{\mapping}[1]{\textfun{mapping}\ #1} +\newcommand{\size}[1]{\textfun{size}\ #1} + +\newcommand{\fst}[1]{\textfun{proj}_1\ #1} +\newcommand{\snd}[1]{\textfun{proj}_2\ #1} +\newcommand{\trd}[1]{\textfun{proj}_3\ #1} + +\newcommand{\keyAct}[1]{\textfun{KeyActive}\ #1} +\newcommand{\keyInact}{\textfun{KeyInactive}} +\newcommand{\state}{\textfun{State}} +\newcommand{\appfst}[2]{\textfun{app}_{#1} #2 = 1} +\newcommand{\actkey}[2]{\textfun{act}_{#1}\ #2} +\newcommand{\emptyState}[1]{\varepsilon_{#1}} + + + + +\title{\huge Proof of \texttt{advance} in the UCK State Machine} +\date{} +\author{} + +\begin{document} + +\maketitle + + +\begin{term}\ + \begin{itemize} + \item A \textbf{well-defined state} is an instance of the type $\state$ or $\textfun{Right[KeyInputError, \state]}$ + \item A \textbf{state} is an instance of the type $\textfun{Either[KeyInputError, \state]}$ + \item A \textbf{node} is an instance of the type $\textfun{Node}$ or $\textfun{(Node, NodeId)}$. + \end{itemize} + When multiple interpretations are possible, the context will always make clear which one we are referring to. + In particular, depending on the situation, $\beginRb{s}$ and $\enRb{s}$ will return either a $\state$ or a $\textfun{Right[KeyInputError, \state]}$. +\end{term} + +\begin{definition} + Let $s$ be a state, \bolddef{\defined{s}} is true if $s$ is a well-defined state. + When this is the case, \bolddef{\get{s}} yields the $\textfun{State}$ instance out of the option. +\end{definition} + +\begin{definition} + Let $t := (v_1, v_2, v_3)$ a triple. The projections of the triple are $\bolddef{\fst{t}} := v_1$, $\bolddef{\snd{t}} := v_2$ and $\bolddef{\trd{t}} := v_3$. +\end{definition} + + + + + + + + + + + + + + + + +%TREES + +\section*{Trees} + +\begin{definition}[\textcode{Tree}] + A tree of nodes is a data structure defined inductively as either: + \begin{itemize} + \item $\nilNode$ + \item $\contentNode{sub}{n}$, where $sub$ is a tree and $n$ a node. + \item $\artNode{l}{r}$, where $l$ and $r$ are both trees. + \end{itemize} +\end{definition} + +Although this is not the most intuitive definition, it has the advantage of being able to represent forests +without running into measure decreaseness problems. In fact one can see content nodes as being trees and articulation nodes as forests. + + +\begin{definition}[\textcode{Tree.size}] + \label{size_def} + Let $tr$ be a tree of nodes, \bolddef{\size{tr}} is defined inductively as: + \begin{itemize} + \item $\size{\nilNode} := 0$ + \item $\size{(\contentNode{sub}{n})} := (\size{sub}) + 1$ + \item $\size{(\artNode{l}{r})} := (\size{l}) + (\size{r})$ + \end{itemize} +\end{definition} + +\begin{definition}[\textcode{Tree.traverse}] + \label{traverse_def} + Let $tr$ be a tree of nodes and $init$ a value, \bolddef{\longtraverse{tr}{init}{f_1}{f_2}} is defined inductively as: + \begin{itemize} + \item $\longtraverse{\nilNode}{init}{f_1}{f_2} := init$ + \item $\longtraverse{(\contentNode{sub}{n})}{init}{f_1}{f_2} := \ftwo{\longtraverse{sub}{\fone{init}{n}}{f_1}{f_2}}{n}$ + \item $\longtraverse{(\artNode{l}{r})}{init}{f_1}{f_2} := \longtraverse{r}{(\longtraverse{l}{init}{f_1}{f_2})}{f_1}{f_2}$ + \end{itemize} +\end{definition} + +\begin{definition}[\textcode{Tree.scan}] + \label{scan_def} + Let $tr$ be a tree of nodes and $init$ a value, \bolddef{\longscan{tr}{init}{f_1}{f_2}} is defined inductively as: + \begin{itemize} + \item $\longscan{\nilNode}{init}{f_1}{f_2} := \emptyList$ + \item $\longscan{(\contentNode{sub}{n})}{init}{f_1}{f_2} :=$ + \[ [(init, n,\,\down)] \concat \longscan{sub}{f_1(init, n)}{f_1}{f_2} \concat [(\longtraverse{sub}{f_1(init, n)}{f_1}{f_2}, n,\,\up)]\] + \item $\longscan{(\artNode{l}{r})}{init}{f_1}{f_2} :=$ + \[\longscan{l}{init}{f_1}{f_2} \concat \longscan{r}{(\longtraverse{l}{init}{f_1}{f_2})}{f_1}{f_2}\] + \end{itemize} +\end{definition} + +\begin{claim}[postcondition of \textcode{Tree.size}] + \label{size_pos} + Let $tr$ be a tree of nodes, + \[\size{tr} \geq 0 \] +\end{claim} +\begin{proof} + Straight induction on $tr$. +\end{proof} + +\begin{claim}[postcondition of \textcode{Tree.scan}] + \label{scan_size} + For any tree of nodes $tr$, value $init$ and functions $f_1, f_2$ + \[\size{(\longscan{tr}{init}{f_1}{f_2})} = 2 \cdot \size{tr}\] +\end{claim} + +\begin{proof} + Straight induction on tr. +\end{proof} + +\begin{lemma}[\textcode{scanIndexing}] + \label{scan_indexing} + For any tree of nodes $tr$, value $init$ and functions $f_1, f_2$ + + If $tr = \contentNode{sub}{n}$: + \[(\longscan{tr}{init}{f_1}{f_2})[i] = + \begin{cases} + (init, n,\,\down) & \text{if $i = 0$}\\ + (\longtraverse{sub}{\fone{init}{n}}{f_1}{f_2}, n,\,\up) & \text{if $i = 2 \cdot \size{tr} - 1$}\\ + (\longscan{sub}{\fone{init}{n}}{f_1}{f_2})[i - 1] & \text{otherwise} + \end{cases}\] + + If $tr = \artNode{l}{r}$: + \[(\longscan{tr}{init}{f_1}{f_2})[i] = + \begin{cases} + (\longscan{l}{init}{f_1}{f_2})[i] & \text{if $i < 2\cdot \size{l}$}\\ + (\longscan{r}{(\longtraverse{l}{init}{f_1}{f_2})}{f_1}{f_2})[i - 2 \cdot \size{l}] & \text{otherwise} + \end{cases}\] + +\end{lemma} + +\begin{proof} + Application of union indexing property in lists +\end{proof} + +\begin{lemma}[\textcode{scanIndexingState}] + \label{scan_indexing_state} + Let $tr$ be a tree of nodes $tr$, $init$ a value, $f_1, f_2$ two functions and $i < 2 \cdot \size{tr}$ a non-negative integer. Let $l := \longscan{tr}{init}{f_1}{f_2}$: + \begin{align} + \fst{(l[i])} =& \begin{cases} + init & \text{if $i = 0$}\\ + f_1(\fst{(l[i - 1])}, \snd{(l[i - 1])}) & \text{if $\trd{(l[i - 1])} = \, \down$} \\ + f_2(\fst{(l[i - 1])}, \snd{(l[i - 1])}) & \text{if $\trd{(l[i - 1])} = \, \up$} + \end{cases}\\ + \intertext{\hskip 25em and} + \longtraverse{tr}{init}{f_1}{f_2} =& + \begin{cases} + init &\text{if $\size{tr} = 0$} \\ + f_2(\fst{(l[2 \cdot \size{tr} - 1])}, \snd{(l[2 \cdot \size{tr} - 1])}) &\text{otherwise} + \end{cases} + \end{align} +\end{lemma} + +\begin{proof} +Induction on tr: + +\begin{itemize} + \item If $tr = \nilNode$, then $\size{tr} = 0$. + \item If $tr = \contentNode{sub}{n}$: + + \begin{itemize} + \item If $i = 0$ then by \cref{scan_indexing} $l[i] = (init, n, \down)$. + \item If $i = 2\, \cdot \, \size{tr} \, - \, 1$, then by \cref{scan_indexing} $l[i] = (\longtraverse{sub}{\fone{init}{n}}{f_1}{f_2}, n, \up)$. + + Let $l_{sub} := \longscan{sub}{\fone{init}{n}}{f_1}{f_2}$. + + If $\size{sub} = 0$, then $i = 1$: + \begin{align*} + \longtraverse{sub}{\fone{init}{n}}{f_1}{f_2} =& \fone{init}{n} & \text{by IH} \\ + =& \fone{\fst{(l[0])}}{\snd{(l[0])}} & \text{by \cref{scan_indexing}} \\ + =& \fone{\fst{(l[i - 1])}}{\snd{(l[i - 1])}} + \end{align*} + Moreover we know that $\trd{(l[0])} = \down$ + + Otherwise: + + + \begin{align*} + &\longtraverse{sub}{\fone{init}{n}}{f_1}{f_2} \\ + =& \ftwo{\fst{(l_{sub}[2 \cdot \size{sub} - 1])}}{\snd{(l_{sub}[2 \cdot \size{sub} - 1])}} \\ + =& \ftwo{\fst{(l_{sub}[2 \cdot \size{tr} - 3])}}{\snd{(l_{sub}[2 \cdot \size{tr} - 3])}} & \text{unfolding \cref{size_def}} \\ + =& \ftwo{\fst{(l[2 \cdot \size{tr} - 2])}}{\snd{(l[2 \cdot \size{tr} - 2])}} & \text{by \cref{scan_indexing}} \\ + =& \ftwo{\fst{(l[i - 1])}}{\snd{(l[i - 1])}} + \end{align*} + + \item Else apply induction hypothesis with $sub$, $\fone{init}{n}$ and $i - 1$. Use \cref{scan_indexing} when $i > 1$. + \item To prove (2) we know by \cref{size_def} that $\size{tr} > 0$. The result is then immediate from \cref{scan_indexing}. + \end{itemize} + \item If $tr = \artNode{le}{ri}$: + \begin{itemize} + \item If $i < 2 \cdot \size{le}$, apply induction hypothesis with $le$, $init$ and $i$, and \cref{scan_indexing}. + \item If $i = 2 \cdot \size{le}$: + + Let $l_{le} := (\longscan{le}{init}{f_1}{f_2})$ and $l_{ri} := (\longscan{ri}{(\longtraverse{le}{init}{f_1}{f_2})}{f_1}{f_2})$ + + \begin{align*} + \fst{(l_{ri}[0])} &= \longtraverse{le}{init}{f_1}{f_2}\: &\text{by IH on $ri$}\\ + \fst{(l_{ri}[i - 2 \cdot \size{le}])} &= \longtraverse{le}{init}{f_1}{f_2} \\ + \fst{(l[i])} &= \longtraverse{le}{init}{f_1}{f_2} &\text{by \cref{scan_indexing}} \\ + &= \ftwo{\fst{(l_{le}[2 \cdot \size{le} - 1])}}{\snd{(l_{le}[2 \cdot \size{le} - 1])}} &\text{by IH on $le$} \\ + &= \ftwo{\fst{(l_{le}[i - 1])}}{\snd{(l_{le}[i - 1])}} + \end{align*} + \item Else apply induction hypothesis with $ri$, $\longtraverse{le}{init}{f_1}{f_2}$ and $i - 2 \cdot \size{le}$ and \cref*{scan_indexing}. + \item If $\size{tr} = 0$ then by \cref{size_pos} and unfolding \cref{size_def}, $\size{le} = 0$ and $\size{ri} = 0$. + \begin{align*} + \longtraverse{tr}{init}{f_1}{f_2} &= \longtraverse{ri}{(\longtraverse{le}{init}{f_1}{f_2})}{f_1}{f_2} &\text{unfolding \cref{traverse_def}} \\ + &= \longtraverse{le}{init}{f_1}{f_2} & \text{by IH on $ri$} \\ + &= init & \text{by IH on $l$} + \end{align*} + + + Else if $\size{ri} = 0$: + \begin{align*} + \longtraverse{tr}{init}{f_1}{f_2} + =\ & \longtraverse{ri}{(\longtraverse{le}{init}{f_1}{f_2})}{f_1}{f_2} &\text{unfolding \cref{traverse_def}} \\ + =\ & \longtraverse{le}{init}{f_1}{f_2} & \text{by IH on $ri$} \\ + =\ & f_2(\fst{(l_{le}[2 \cdot \size{le} - 1])}, \snd{(l_{le}[2 \cdot \size{le} - 1])}) & \text{by IH on $le$} \\ + =\ & f_2(\fst{(l[2 \cdot \size{le} - 1])}, \snd{(l[2 \cdot \size{le} - 1])}) & \text{by \cref{scan_indexing}} \\ + =\ & f_2(\fst{(l[2 \cdot \size{tr} - 1])}, \snd{(l[2 \cdot \size{tr} - 1])}) & \text{by \cref{size_def}} + \end{align*} + + When $\size{ri} > 0$: + \begin{align*} + &\longtraverse{tr}{init}{f_1}{f_2} \\ + =\ & \longtraverse{ri}{(\longtraverse{le}{init}{f_1}{f_2})}{f_1}{f_2} &\text{unfolding \cref{traverse_def}} \\ + =\ & f_2(\fst{(l_{ri}[2 \cdot \size{ri} - 1])}, \snd{(l_{ri}[2 \cdot \size{ri} - 1])}) & \text{by IH on $ri$} \\ + =\ & f_2(\fst{(l[2 \cdot \size{ri} + 2 \cdot \size{le} - 1])}, \snd{(l[2 \cdot \size{ri} + 2 \cdot \size{le} - 1])}) & \text{by \cref{scan_indexing}} \\ + =\ & f_2(\fst{(l[2 \cdot \size{tr} - 1])}, \snd{(l[2 \cdot \size{tr} - 1])}) & \text{by \cref{size_def}} + \end{align*} + \end{itemize} +\end{itemize} + +\end{proof} + +\begin{claim}[\textcode{scanIndexingNode}] + \label{scan_indexing_node} + For any tree of nodes $tr$, values $init_1,\, init_2$, functions $f^1_1, f^1_2, f^2_1, f^2_2$ and non-negative integer $i < 2\, \cdot\, \size{tr}$. + \[ \snd{((\longscan{tr}{init_1}{f^1_1}{f^1_2})[i])} = \snd{((\longscan{tr}{init_2}{f^2_1}{f^2_2})[i])}\] + \[\text{and}\] + \[ \trd{((\longscan{tr}{init_1}{f^1_1}{f^1_2})[i])} = \trd{((\longscan{tr}{init_2}{f^2_1}{f^2_2})[i])}\] +\end{claim} + +\begin{proof} + Induction on tr using \cref{scan_indexing}. +\end{proof} + +We now define the two functions we will be applying during a tree traversal. + +\begin{definition}[\textcode{traverseInFun}, \textcode{traverseOutFun}] + Let $s$ be a state and $n$ be a node, + \[\fdown{s}{n} : = + \begin{cases} + \hNode{(\get{s})}{n}& \text{if $n$ is an Action node and $\defined{s}$} \\ + \beginRb{(\get{s})} & \text{if $n$ is a Rollback node and $\defined{s}$} \\ + s& \text{otherwise} + \end{cases} \] + \[\fup{s}{n} : = + \begin{cases} + \enRb{(\get{s})} & \text{if $n$ is a Rollback node and $\defined{s}$} \\ + s& \text{otherwise} + \end{cases} \] +\end{definition} + +In the rest of the document, unless stated otherwise, init will be a state and $\traverse{tr}{init}$ and $\scan{tr}{init}$ will be referring to as respectively +$\longtraverse{tr}{init}{f_{down}}{f_{up}}$ and $\longscan{tr}{init}{f_{down}}{f_{up}}$ + +\begin{claim}[\textcode{traverseTransactionProp}] + For any tree of nodes tr, state init, if $\defined{(\traverse{tr}{init})}$ then + \label{traverse_prop} + \[\defined{init} \land (\get{(\traverse{tr}{init})}).rollbackStack = (\get{init}).rollbackStack \ \land\] + \[(\get{(\traverse{tr}{init})}).globalKeys = (\get{init}).globalKeys\] +\end{claim} + +\begin{proof} + Straight induction on tr. +\end{proof} + +\begin{claim}[\textcode{scanTransactionProp}] + \label{defined_prop} + Let $tr$ be a tree of nodes, $init$ a state, $0 \leq i \leq j < 2 \cdot\, \size{tr} $ two integers and let $l:= \scan{tr}{init}$. + \[\defined (\fst{(l[j])}) \implies\, \defined (\fst{(l[i])})\] +\end{claim} + +\begin{proof} + By induction on $j$: the base case is immediate and \cref{scan_indexing_state} + combined with the induction hypothesis concludes the proof. +\end{proof} + +\begin{corollary}[\textcode{scanTransactionProp}] + \label{defined_traverse_prop} + Let $tr$ be a tree of nodes, $init$ a state, $0 \leq i < 2 \cdot\, \size{tr} $ an integer and let $l:= \scan{tr}{init}$. + \[\defined (\traverse{tr}{init}) \implies\, \defined (\fst{(l[i])})\] +\end{corollary} + +\begin{proof} + By \cref{scan_indexing_state} and \cref{defined_prop} setting $j = 2 \cdot\, \size{tr} - 1$. +\end{proof} + +\begin{lemma} + \label{defined_alt_def} + Let $tr$ be a tree of nodes, $init$ a state and $l := \scan{tr}{init}$ then + \[\defined{(\traverse{tr}{init})} \iff + \begin{cases} + \defined{init} \\ + \forall\, 0 \leq i < 2 \cdot\, \size{tr} - 1, \, \defined{(\fst{(l[i])})} \implies \defined{(\fst{(l[i + 1])})} \, + \end{cases}\] + +\end{lemma} + +\begin{proof} + Let's first note that by \cref{scan_indexing_state}, $\fst{(l[0])} = init$. + + If $\defined{(\traverse{tr}{init})}$ then by \cref{defined_traverse_prop}, $\defined{\fst{(l[i])}}$ for all $0 \leq i < 2 \cdot\, \size{tr}$. + + If the right statement is true, then $\defined{(\fst{(l[2 \cdot \, \size{tr} - 1])})}$ and therefore by \cref{scan_indexing_state} we have $\defined{(\traverse{tr}{init})}$. +\end{proof} + + +\begin{claim}[\textcode{findBeginRollback}] + \label{find_two_begin_rb} + Let tr be a tree of nodes, $init_1, init_2$ be states and let $l_1 := \scan{tr}{init_1}$, $l_2 := \scan{tr}{init_2}$. + If there exists a non-negative integer $i < 2 \cdot (\size{tr})$, well-defined states $s^1_i$, $s^2_i$ and a Rollback node $n$ such that + $l_1[i] = (s^1_i, n, \up)$ and $l_2[i] = (s^2_i, n, \up)$, then there is an integer $j$, well-defined states $s^1_j, s^2_j$ and a tree $sub$ such that + \[0 \leq j < i \qquad l_1[j] = (s^1_j, n, \down) \qquad l_2[j] = (s^2_j, n, \down) \qquad \size{sub} < \size{tr}\] + \[s^1_i = \traverse{sub}{(\beginRb{(\get{s^1_j})})} \qquad s^2_i = \traverse{sub}{(\beginRb{(\get{s^2_j})})}\] +\end{claim} + +\begin{proof} + Induction on $tr$: + \begin{itemize} + \item If $tr = \nilNode$, then $\size{tr} = 0$ which means the precondition is never met. + \item If $tr = \contentNode{str}{c}$: + + \begin{itemize} + \item If $c \neq n$ then by \cref{scan_indexing}, $ 0 < i < 2 \cdot (\size{tr}) - 1$, $l_1[i] = (\scan{str}{\fdown{init_1}{n}})[i - 1]$ and + $l_2[i] = (\scan{str}{\fdown{init_2}{n}})[i - 1]$. + By induction hypothesis there is an integer $j$, well-defined state $s^1_j$, $s^2_j$ and a tree $sub$ such that: + \[(\scan{str}{\fdown{init_1}{n}})[j] = (s^1_j, n, \down) \qquad (\scan{str}{\fdown{init_2}{n}})[j] = (s^2_j, n, \down)\] + \[s^1_i = \traverse{sub}{(\beginRb{(\get{s^1_j})})} \qquad s^2_i = \traverse{sub}{(\beginRb{(\get{s^2_j})})}\] + \[0 \leq j < i - 1 \qquad \size{sub} < \size{tr} \qquad \] + + Since $l_1[j + 1] = (\scan{str}{\fdown{init_1}{n}})[j]$ and $l_2[j + 1] = (\scan{str}{\fdown{init_2}{n}})[j]$, $j + 1$, $sub$, $s^1_j$ and $s^2_j$ satisfy the above conditions. + + \item If $c = n$, then $i = 2\cdot (\size{tr}) - 1$ is valid and therefore $j = 0$, $s^1_j = init_1$, $s^2_j = init_2$ and $sub = str$. + \end{itemize} + \item If $tr = \artNode{left}{right}$: + \begin{itemize} + \item If $i < 2 \cdot \size{left}$, then by \cref{scan_indexing}, $l_1[i] = (\scan{left}{init_1})[i]$ and $l_2[i] = (\scan{left}{init_2})[i]$. + By induction hypothesis, there are $j$, well-defined $s^1_j$, $s^2_j$ and $sub$ such that. + \[(\scan{left}{init_1})[j] = (s^1_j, n, \down) \qquad\, (\scan{left}{init_2})[j] = (s^2_j, n, \down)\] + \[s^1_i = \traverse{sub}{(\beginRb{(\get{s^1_j})})} \qquad s^2_i = \traverse{sub}{(\beginRb{(\get{s^2_j})})}\] + \[0 \leq j < i \qquad \size{sub} < \size{left}\] + Since $l_1[j] = (\scan{left}{init_1})[j]$ and $l_2[j] = (\scan{left}{init_2})[j]$, $j$, $s^1_j$, $s^2_j$ and $sub$ satisfy the claim. + \item If $i \geq 2 \cdot \size{left}$, then by \cref{scan_indexing}, $l_1[i] = (\scan{right}{(\traverse{left}{init_1})})[i - 2 \cdot \size{left}]$ and $l_2[i] = (\scan{right}{(\traverse{left}{init_2})})[i - 2 \cdot \size{left}]$. + By induction hypothesis, there are $j$, well-defined $s^1_j$, $s^2_j$ and $sub$ such that. + \[(\scan{right}{(\traverse{left}{init_1})})[j] = (s^1_j, n, \down) \qquad (\scan{right}{(\traverse{left}{init_2})})[j] = (s^2_j, n, \down)\] + \[s^1_i = \traverse{sub}{(\beginRb{(\get{s^1_j})})} \qquad s^2_i = \traverse{sub}{(\beginRb{(\get{s^2_j})})}\] + \[ 0 \leq\, j < i - 2\cdot \size{left} \qquad \size{sub} < \size{right}\] + + Since $l_1[j + 2 \cdot\, \size{left}] = (\scan{right}{(\traverse{left}{init_1})})[j] $ and $l_2[j + 2 \cdot \size{left}] = (\scan{right}{(\traverse{left}{init_2})})[j]$, $j + 2 \cdot \size{left}$, $s^1_j$, $s^2_j$ and $sub$ satisfy the claim. + \end{itemize} + \end{itemize} +\end{proof} + +\begin{corollary}[\textcode{findBeginRollback}] + \label{find_begin_rb} + Let tr be a tree of nodes, $init$ a states and let $l := \scan{tr}{init}$. + If there exists a non-negative integer $i < 2 \cdot (\size{tr})$, a well-defined states $s_i$ and a rollback node $n$ such that + $l[i] = (s_i, n, \up)$, then there is an integer $j$, a well-defined states $s_j,$ and a tree $sub$ such that + \[0 \leq j < i \qquad l[j] = (s_j, n, \down) \qquad s_i = \traverse{sub}{(\beginRb{s_j})} \qquad \size{sub} < \size{tr}\] +\end{corollary} + +\begin{proof} + By \cref{find_two_begin_rb} with $init_1 = init_2$. +\end{proof} + + + + + + + + + +\newpage + + +\section*{Active Keys Lemmas} + +\begin{definition}[\textcode{State.activeKeys.get}] + Let $s$ be a well-defined state and $k$ a key, \bolddef{\actkey{k}{s}} is the value associated to key in the active keys of the state + (i.e. we first look at the local keys, then the global ones filtering the consumed contracts). +\end{definition} + +\begin{definition}[\textcode{nodeActionKeyMapping}] + Let $n$ be a node, \bolddef{\mapping{n}} is defined as: + \begin{itemize} + \item \mapping{(\createNode{id}{k})} := \keyInact + \item \mapping{(\fetchNode{id}{k})} := \keyAct{id} + \item \mapping{(\lookupNode{result}{k})} := result + \item \mapping{(\exNode{id}{k})} := \keyAct{id} + \end{itemize} +\end{definition} + +\begin{lemma}[\textcode{handleNodeUndefined}] + \label{defined_def} + For any well-defined state s and Action node n, + \[\defined{(\hNode{s}{n})} \iff {\actkey{n.k}{s} = \mapping{n}}\] +\end{lemma} + + +\begin{corollary}[\textcode{handleSameNodeActiveKeys}] + \label{same_key_hnode} + For any well-defined states $s_1$, $s_2$ and Action node n, if $\defined{(\hNode{s_ 1}{n})}$ and $\defined{(\hNode{s_ 2}{n})}$ + \[\actkey{n.k}{s_1} = \actkey{n.k}{s_2}\] +\end{corollary} + +\begin{proof} + Direct consequence of \cref{defined_def}. +\end{proof} + +\begin{lemma}[\textcode{handleNodeDifferentStatesActiveKeysGet}] + \label{same_key_after_hnode} + For any well-defined states $s_1$, $s_2$ and Action node n, if $\defined{(\hNode{s_ 1}{n})}$ and $\defined{(\hNode{s_ 2}{n})}$ + \[\actkey{n.k}{(\get{(\hNode{s_1}{n})})} = \actkey{n.k}{(\get{(\hNode{s_2}{n})})}\] +\end{lemma} + + +\begin{lemma}[\textcode{handleNodeActiveKeysGet}] + \label{actkey_handle_node} + For any well-defined state $s$, Action node $n$, key $k_2$, if $n$ has no key or $k_2 \neq n.k$ and if $\defined{(\hNode{s}{n})}$, + \[\actkey{k_2}{(\get{(\hNode{s}{n})})} = \actkey{k_2}{s}\] +\end{lemma} + +\begin{lemma}[\textcode{beginRollbackActiveKeysGet}] + \label{act_begin_rb} + For any well-defined state s, node n, key $k$, + \[\actkey{k}{(\beginRb{s})} = \actkey{k}{s}\] +\end{lemma} + + +\begin{lemma}[\textcode{activeKeysGetRollbackScope}] + \label{skip_rb} + For any well-defined state s, node n, key $k$, function $g: \state \rightarrow \state $ and: + \begin{itemize} + \item $g(\beginRb{s}).rollbackStack = (\beginRb{s}).rollbackStack$ + \item $g(\beginRb{s}).globalKeys = (\beginRb{s}).globalKeys$ + \end{itemize} + + We have: + \[\actkey{k}{(\enRb{g(\beginRb{s})})} = \actkey{k}{s}\] + +\end{lemma} + + + + + + +\newpage + + + + +\section*{The real deal} + + +\begin{definition}[\textcode{appearsAtIndex}, \textcode{doesNotAppearBefore}, \textcode{firstAppears}] + \label{appear_def} + Let tr be a tree of nodes, init a value, $k$ a key, $f_1, f_2$ two functions, $i < 2 \cdot \size{tr}$ a non-negative integer and let $l := \longscan{tr}{init}{f_1}{f_2}$. + + We say that $k$ does not appear before $i$ if for all $0 \leq j < i,\ (\snd{l[j]}).k \neq k$ or $\trd{l[j]} =\, \up$ + + We say that $i$ is the first appearance of $k$ in $l$ if $(\snd{l[i]}).k = k$, $\trd{l[i]} = \ \down$ and $k$ does not appear before $i$. +\end{definition} + + +\begin{claim}[\textcode{doesNotAppearBeforeSame}, \textcode{firstAppearsSame}] + \label{appear_same} + Let tr be a tree of nodes, $init, init_2$ a state, $f^1_1, f^1_2, f^2_1, f^2_2$ functions, $k$ a key, $i$ a non-negative integer smaller than $2 \cdot \size{tr}$, + $l_1 := \longscan{tr}{init_1}{f^1_1}{f^1_2}$ and $l_2 := \longscan{tr}{init_2}{f^2_1}{f^2_2}$. + \begin{center} + $k$ does not appear before $i$ in $l_1$ $\iff$ $k$ does not appear before $i$ in $l_2$ + \end{center} + and in particular + \begin{center} + $i$ is the first appearance of $k$ in $l_1$ $\iff$ $i$ is the first appearance of $k$ in $l_2$ + \end{center} +\end{claim} + +\begin{proof} + Consequence of \cref{scan_indexing_node}. +\end{proof} + +\begin{claim}[\textcode{findFirstAppears}] + \label{find_first_appear} + Let tr be a tree of nodes, $init$ a state, $k$ a key, $0 \leq i_1 < i_2 < 2\cdot \size{tr}$ twos integers and + let $l := \scan{tr}{init}$. If $k$ appears before $i_2$ in $l$ but does not before $i_1$, then there exists an integer + $i_1 \leq j < i_2$ such that $j$ is the first appearance of $k$ in $l$. +\end{claim} + +\begin{proof} + Immediate from \cref{appear_def}. +\end{proof} + +\begin{claim}[\textcode{doesNotAppearBeforeSameActiveKeysGet}] + \label{actkey_not_appear} + Let tr be a tree of nodes, init a state, $k$ a key, $0 \leq j < i < 2\cdot\, \size{tr}$ two integers and let $l := \scan{tr}{init}$. + If $k$ does not appear before $i$ in $l$ and $\defined{(\fst{(l[i])})}$ (and $\defined{(\fst{(l[j])})}$ by \cref{defined_prop}) then + \[\actkey{k}{(\get{(\fst{(l[i])})})} = \actkey{k}{(\get{(\fst{(l[j])})})}\] +\end{claim} + +\begin{proof} + By induction on $i$. + + If $i = 0$ then the precondition is never met. + + Else let $(s_{i -1}, n_{i - 1}, dir_{i - 1}) := l[i - 1]$ and $s_i := \fst{(l[i])}$. By \cref{defined_prop} $\defined{s_{i - 1}}$. + + By \cref{scan_indexing_state} we either have: + \begin{itemize} + \item $s_i = \fdown{s_{i - 1}}{n_{i - 1}}$ and $dir_{i - 1} =\, \down$. + + If $n_{i - 1}$ is an Action node, since $k$ does not appear before $i$, $k \neq n_{i - 1}.k$, then: + \begin{align*} + \actkey{k}{(\get{s_i})} &= \actkey{k}{(\get{(\hNode{(\get{s_{i - 1}})}{n_{i - 1}})})} \\ + &= \actkey{k}{(\get{s_{i - 1}})} &\text{from \cref{actkey_handle_node}} \\ + &= \actkey{k}{(\get{(\fst{(l[j])})})} &\text{for all $j < i - 1$ by IH} + \end{align*} + + If $n_{i - 1}$ is a Rollback node then + \begin{align*} + \actkey{k}{(\get{s_i})} &= \actkey{k}{(\beginRb{(\get{s_{i - 1}})})}\\ + &= \actkey{k}{(\get{s_{i - 1}})} &\text{by \cref{act_begin_rb}} \\ + &= \actkey{k}{(\get{(\fst{(l[j])})})} &\text{for all $j < i - 1$ by IH } + \end{align*} + + \item $s_i = \fup{s_{i - 1}}{n_{i - 1}}$ and $dir_{i - 1} =\, \up$. + + If $n_{i - 1}$ is an Action node then + \begin{align*} + \actkey{k}{(\get{s_i})} &= \actkey{k}{(\get{s_{i - 1}})} \\ + &=\actkey{k}{(\get{(\fst{(l[j])})})} &\text{\qquad \qquad for all $j < i - 1$ by IH } + \end{align*} + + If $n_{i - 1}$ is a Rollback node then + \[\actkey{k}{(\get{s_i})} = \actkey{k}{(\enRb{(\get{s_{i - 1}})})} \] + By \cref{find_begin_rb} there exists an integer $0 \leq j' < i - 1$, a well-defined state $s_{j'} = \fst{(l[j'])}$ and a tree $sub$ such that $s_{i - 1} = \traverse{sub}{(\beginRb{(\get{s_{j'}})})}$. Therefore + \begin{align*} + \actkey{k}{(\get{s_i})} &= \actkey{k}{(\enRb{(\get{(\traverse{sub}{(\beginRb{(\get{s_j'})})})})})} \\ + &= \actkey{k}{(\get{s_{j'}})} \text{\hskip 14.5em by \cref{skip_rb}} \\ + &= \actkey{k}{(\get{(\fst{(l[j'])})})} + \end{align*} + In addition, by the induction hypothesis + \begin{align*} + \actkey{k}{(\get{s_{i - 1}})} &= \actkey{k}{(\get{(\fst{(l[j])})})} \text{\hskip 6em for all $j < i - 1$} \\ + &= \actkey{k}{(\get{(\fst{(l[j'])})})} \\ + &= \actkey{k}{(\get{s_i})} + \end{align*} + \end{itemize} + +\end{proof} + +\begin{corollary}[\textcode{firstAppearsSameActiveKeysGet}] + \label{actkey_first_appear} + Let tr be a tree of nodes, init a state, $k$ a key, $0 \leq j < i < 2\cdot \size{tr}$ two integers and let $l := \scan{tr}{init}$. + If $i$ is the first appearance of $k$ in $l$ and $\defined{(\fst{(l[i])})}$ (and $\defined{(\fst{(l[j])})}$ by \cref{defined_prop}), then + \[\actkey{k}{(\get{(\fst{(l[i])})})} = \actkey{k}{(\get{(\fst{(l[j])})})}\] +\end{corollary} +\begin{proof} + Direct consequence of \cref{actkey_not_appear}. +\end{proof} + +\begin{corollary}[\textcode{firstAppearsHandleNodeUndefined}] + \label{defined_first_appear} + Let tr be a tree of nodes, init a state, $0 \leq i < 2\cdot \size{tr}$ an integer, let $l := \scan{tr}{init}$ and $(s, n, dir) := l[i]$. + If $n$ is an Action node with a well-defined key, $i$ is the first appearance of $n.k$ in $l$ and $\defined{s}$, then + \[\actkey{n.k}{(\get{init})} = \actkey{n.k}{(\get{s})}\] + \[\text{and in particular}\] + \[\defined{(\hNode{(\get{s})}{n})} \iff \actkey{n.k}{(\get{init})} = \mapping{n}\] +\end{corollary} +\begin{proof} + The first statement is a direct consequence of \cref{actkey_first_appear}. Applying \cref{defined_def} gives us the second statement. +\end{proof} + +\begin{claim}[\textcode{appearsBeforeSameActiveKeysGet}] + \label{actkey_after_appear} + Let tr be a tree of nodes, $init_1$, $init_2$ two state, $k$ a key, $0 \leq i < 2\cdot \size{tr}$ an integer, + let $l_1 := \scan{tr}{init_1}$, $l_2 := \scan{tr}{init_2}$, $s^1_i := \fst{(l_1[i])}$ and $s^2_i := \fst{(l_2[i])}$. + If $k$ appears before $i$ in $l_1$ and $l_2$, $\defined{s^1_i}$ and $\defined{s^2_i}$, then + \[\actkey{k}{(\get{s^1_i})} = \actkey{k}{(\get{s^2_i})}\] +\end{claim} +\begin{proof} + By strong induction on i: + + If $i = 0$ then the precondition is never met. + + Else let $(s^1_{i -1}, n^1_{i - 1}, dir^1_{i - 1}) := l_1[i - 1]$, $(s^2_{i -1}, n^2_{i - 1}, dir^2_{i - 1}) := l_2[i - 1]$. + By \cref{scan_indexing_node}, $n_{i - 1} := n^1_{i - 1} = n^2_{i - 1}$ and $dir_{i - 1} := dir^1_{i - 1} = dir^2_{i - 1}$. + + By \cref{scan_indexing_state} we either have: + \begin{itemize} + \item $s^1_i = \fdown{s^1_{i - 1}}{n_{i - 1}}$, $s^2_i = \fdown{s^2_{i - 1}}{n_{i - 1}}$ and $dir_{i - 1} =\, \down$ + + If $n_{i - 1}$ is an Action node then: + \begin{itemize} + \item If $n_{i - 1}.k = k$: + \begin{align*} + \actkey{n_{i - 1}.k}{(\get{s^1_i})} &= \actkey{n_{i - 1}.k}{(\get{(\hNode{(\get{s^1_{i - 1}})}{n_{i - 1}})})} \\ + &= \actkey{n_{i - 1}.k}{(\get{(\hNode{(\get{s^2_{i - 1}})}{n_{i - 1}})})} & \text{by \cref{same_key_after_hnode}}\\ + &= \actkey{n_{i - 1}.k}{(\get{s^2_{i}})} \\ + \end{align*} + \item Otherwise, k appears before $i - 1$ in $l_1$ and $l_2$: + \begin{align*} + \actkey{k}{(\get{s^1_i})} &= \actkey{k}{(\get{(\hNode{(\get{s^1_{i - 1}})}{n_{i - 1}})})} \\ + &= \actkey{k}{(\get{s^1_{i - 1}})} &\text{by \cref{actkey_handle_node}} \\ + &= \actkey{k}{(\get{s^2_{i - 1}})} &\text{by IH} \\ + &= \actkey{k}{(\get{(\hNode{(\get{s^2_{i - 1}})}{n_{i - 1}})})} &\text{by \cref{actkey_handle_node}} \\ + &= \actkey{k}{(\get{s^2_i})} + \end{align*} + \end{itemize} + + If $n_{i - 1}$ is a Rollback node then + \begin{align*} + \actkey{k}{(\get{s^1_i})} &= \actkey{k}{(\beginRb{(\get{s^1_{i - 1}})})}\\ + &= \actkey{k}{(\get{s^1_{i - 1}})} &\text{by \cref{act_begin_rb}} \\ + &= \actkey{k}{(\get{s^2_{i - 1}})} &\text{by IH} \\ + &= \actkey{k}{(\beginRb{(\get{s^2_{i - 1}})})} &\text{by \cref{act_begin_rb}}\\ + &= \actkey{k}{(\get{s^2_i})} + \end{align*} + + \item $s^1_i = \fup{s^1_{i - 1}}{n_{i - 1}}$, $s^2_i = \fup{s^2_{i - 1}}{n_{i - 1}}$ and $dir_{i - 1} =\, \up$ + + If $n_{i - 1}$ is an Action node then + \begin{align*} + \actkey{k}{(\get{s^1_i})} &= \actkey{k}{(\get{s^1_{i - 1}})} \\ + &= \actkey{k}{(\get{s^2_{i - 1}})} &\text{by IH}\\ + &=\actkey{k}{(\get{s^2_i})} + \end{align*} + + If $n_{i - 1}$ is a Rollback node then + \[\actkey{k}{(\get{s^1_i})} = \actkey{k}{(\enRb{(\get{s^1_{i - 1}})})} \] + By \cref{find_two_begin_rb} there exists an integer $0 \leq j < i - 1$ , well-defined states $s^1_{j}$, $s^2_{j}$ and tree $sub$ such that $s^1_{i - 1} = \traverse{sub}{(\beginRb{(\get{s^1_{j}})})}$, + $s^2_{i - 1} = \traverse{sub}{(\beginRb{(\get{s^2_{j}})})}$, $l_1[j] = (s^1_j, n_{i - 1}, \down)$, $l_2[j] = (s^2_j, n_{i - 1}, \down)$. Therefore + \begin{align*} + \actkey{k}{(\get{s^1_i})} &= \actkey{k}{(\enRb{(\get{(\traverse{sub}{(\beginRb{(\get{s^1_{j}})})})})})} \\ + &= \actkey{k}{(\get{s^1_{j}})} &\text{by \cref{skip_rb}} \\ + \intertext{If $k$ appears before $j$ in $l_1$ (and $l_2$ by \cref{appear_same}) then we can use the induction hypothesis and go backward.} + &= \actkey{k}{(\get{s^2_{j}})} &\text{by IH} \\ + &= \actkey{k}{(\enRb{(\get{(\traverse{sub}{(\beginRb{(\get{s^2_{j}})})})})})} &\text{by \cref{skip_rb}}\\ + &= \actkey{k}{(\get{s^2_i})} + \end{align*} + + If $k$ does not appear before $j$ in $l_1$ and $l_2$, we can use \cref{find_first_appear} to obtain the index $j \leq j' < i$ such that $j'$ is the + first appearance of $k$ in $l_1$ and $l_2$. + + Since $j' < i$ , by \cref{defined_prop}, $\defined{(\fst{(l_1[j'])})}$, $\defined{(\fst{(l_2[j'])})}$, $\defined{(\fst{(l_1[j' + 1])})}$ and $\defined{(\fst{(l_2[j' + 1])})}$. + By \cref{scan_indexing_state}, $\fst{(l_1[j' + 1])} = \hNode{(\get{(\fst{(l_1[j'])})})}{(\snd{(l_1[j'])})}$ and $(\snd{(l_1[j'])}).k = k$. + Similarly, $\fst{(l_2[j' + 1])} = \hNode{(\get{(\fst{(l_2[j'])})})}{(\snd{(l_2[j'])})}$ and $(\snd{(l_2[j'])}).k = k$. + Therefore: + \begin{align*} + \actkey{k}{s^1_{j}} &= \actkey{k}{(\get{(\fst{(l_1[j'])})})} &\text{by \cref{actkey_first_appear}} \\ + &= \actkey{k}{(\get{(\fst{(l_2[j'])})})} &\text{by \cref{same_key_hnode}} \\ + &= \actkey{k}{(\get{s^2_{j}})} &\text{by \cref{actkey_first_appear}} \\ + &= \actkey{k}{(\enRb{(\get{(\traverse{sub}{(\beginRb{(\get{s^2_j})})})})})} &\text{by \cref{act_begin_rb}} \\ + &= \actkey{k}{(\get{s^2_i})} + \end{align*} + + \end{itemize} + +\end{proof} + +\begin{corollary}[\textcode{appearsBeforeSameUndefined}] + \label{same_defined_appear} + Let tr be a tree of nodes, $init_1, init_2$ states, $0 \leq i < 2\cdot \size{tr} - 1$ an integer, let $l_1 := \scan{tr}{init_1}$, $l_2 := \scan{tr}{init_2}$, $(s^1, n^1, dir^1) := l_1[i]$ and $(s^2, n^2, dir^2) := l_2[i]$. + By \cref{scan_indexing_node}, $n := n^1 = n^2$. If $n$ is an Action node, $n.k$ appears before $i$ in $l_1$ and $l_2$, $\defined{s^1}$ and $\defined{s^2}$, then + \[\defined{(\fst{(l_1[i + 1])})} \iff \defined{(\fst{(l_2[i + 1])})}\] +\end{corollary} + +\begin{proof} + Consequence of \cref{actkey_after_appear} and \cref{defined_def}. +\end{proof} + +\newpage + +\section*{Empty state traversal} + +\begin{definition}[\textcode{collect}] + Let $tr$ be a tree of nodes. + We respectively define the mapping \bolddef{\collect{tr}} as $\longtraverse{tr}{\emptyList}{f_{collect}}{id}$ and \bolddef{\collectTr{tr}} as + $\longscan{tr}{\emptyList}{f_{collect}}{id}$, where + \[\fcoll{m}{n} = + \begin{cases} + m + (n.k \to \mapping{n}) &\text{if $n$ is an Action node with a well-defined key and $n.k \notin m$} \\ + m &\text{otherwise} + \end{cases} \] +\end{definition} + +\begin{definition}[\textcode{emptyState}] + We define the empty state $\emptyState{tr}$ as the state whose rollback stack, locally created contracts set, consumed contracts set and local keys map are empty, + but whose global key map is $\collect{tr}$ +\end{definition} + +\begin{claim}[\textcode{collectTraceProp}] + \label{collect_submap} + Let tr be a tree of nodes and $0 \leq i < 2 \cdot \, \size{tr}$ an integer + \[\fst{((\collectTr{tr})[i])} \subseteq \collect{tr}\] +\end{claim} + +\begin{proof} + By backward induction on $i$ and applying \cref{scan_indexing_state}. +\end{proof} + +\begin{claim}[\textcode{collectTraceDoesNotAppear}, \textcode{collectDoesNotAppear}] + \label{collect_not_contains} + Let tr be a tree of nodes, $0 \leq i < 2 \cdot \, \size{tr}$ an integer and $k$ a key. + \[k \text{ does not appear before } i \text{ in } \collectTr{tr} \iff k \notin \fst{((\collectTr{tr})[i])}\] + In particular + \[k \text{ does not appear before } 2 \cdot \size{tr} - 1 \text{ in } \collectTr{tr} \iff k \notin \collect{tr}\] +\end{claim} + +\begin{proof} + Induction on $i$, applying \cref{scan_indexing_state}. +\end{proof} + + +\begin{corollary}[\textcode{collectGet}] + \label{first_mapping} + Let tr be a tree of nodes, $0 \leq i < 2 \cdot \, \size{tr}$ an integer, + $n := \snd{((\collectTr{tr})[i])}$. If $n$ is an Action node with a well-defined key and $i$ is the first appearance of $n.k$ in $\collectTr{tr}$, then + \[(\collect{tr})[n.k] = \mapping{n}\] +\end{corollary} + +\begin{proof} + By \cref{collect_not_contains}, $n.k \notin \fst{((\collectTr{tr})[i])}$. By applying \cref{scan_indexing_state} we have that $(n.k \to \mapping{n}) \in \fst{((\collectTr{tr})[i + 1])}$ + if $i < 2 \cdot \, \size{tr} - 1$ and $(n.k \to \mapping{n}) \in \collect{tr}$ otherwise. In the first case we make use of \cref{collect_submap} to prove the claim. +\end{proof} + +\begin{corollary}[\textcode{collectContains}] + \label{find_k_first_appear} + Let $tr$ be a tree of nodes and $k$ a key. If $k \in \collect{tr}$ then there exists an integer $0 \leq i < 2 \cdot \size{tr} - 1$ such that + $i$ is the first appearance of $k$ in $\collectTr{tr}$. +\end{corollary} +\begin{proof} + By \cref{collect_not_contains}, if $k \in \collect{tr}$ then $k$ does not appear before $2 \cdot \size{tr} - 1$. \cref{find_first_appear} concludes the proof. +\end{proof} + +\begin{corollary}[\textcode{firstAppearsHandleNodeUndefinedEmpty}] + \label{defined_globalkeys_mapping} + Let tr be a tree of nodes, $init$ a state, $0 \leq i < 2\cdot \size{tr} - 1$ an integer, let $l := \scan{tr}{init}$, and $n := \snd{(l[i])}$. + If $n$ is an Action node with a well-defined key, $i$ is the first appearance of $n.k$ in $l$ and $\defined{(\fst{(l[i])})}$ then + \[\defined{(\fst{(l[i + 1])})} \iff \actkey{n.k}{(\get{init})} = (\emptyState{tr}.globalKeys)[n.k]\] +\end{corollary} + +\begin{proof} + Let $l_{\varepsilon} = \scan{tr}{\emptyState{tr}}$. We know by \cref{scan_indexing_node} that $n = \snd{(l_{\varepsilon}[i])}$. + and $\fst{(l_\varepsilon[i + 1])} = \hNode{(\fst{(l_\varepsilon[i])})}{n}$. + Furthemore by \cref{appear_same}, $n.k$ does not appear before $i$ in $l_{\varepsilon}$. Finally by \cref{defined_traverse_prop}, if $\defined{(\traverse{tr}{\emptyState{tr}})}$ then $\defined{(\fst{(l_{\varepsilon}[i])})}$. + + \begin{align*} + &\actkey{n.k}{(\get{init})} = (\emptyState{tr}.globalKeys)[n.k] \\ + \iff &\actkey{n.k}{(\get{init})} = (\collect{tr})[n.k] \\ + \iff &\actkey{n.k}{(\get{init})} =\, \mapping{n} &\text{by \cref{first_mapping}} \\ + \iff &\defined{(\hNode{(\get{(\fst{(l[i])})})}{n})} &\text{by \cref{defined_first_appear}}\\ + \iff &\defined{(\fst{(l[i + 1])})} &\text{by \cref{scan_indexing_state}} + \end{align*} +\end{proof} + +\begin{corollary} + \label{defined_empty_traverse} + Let tr be a tree of nodes, $init$ a state, $0 \leq i < 2\cdot \size{tr} - 1$ an integer, let $l := \scan{tr}{init}$, and $n := \snd{(l[i])}$. + If $n$ is an Action node with a well-defined key, $\defined{(\traverse{tr}{\emptyState{tr}})}$, $n.k$ appears before $i$ in $l$ and $\defined{(\fst{(l[i])})}$ then + \[\defined{(\fst{(l[i + 1])})}\] +\end{corollary} + +\begin{proof} + By \cref{defined_traverse_prop}, if $\defined{(\traverse{tr}{\emptyState{tr}})}$ then $\defined{(\fst{((\scan{tr}{\emptyState{tr}})[i])})}$ and \\ + $\defined{(\fst{((\scan{tr}{\emptyState{tr}})[i + 1])})}$. Applying \cref{same_defined_appear} concludes the proof. +\end{proof} + + +\begin{res}[\textcode{traverseTransactionEmptyDefined}] + Let $tr$ be a tree of nodes and $init$ a well-defined state. If $\defined{(\traverse{tr}{\emptyState{tr}})}$, then + \[ (\forall (k \to m) \in \emptyState{tr}.globalKeys,\ \actkey{k}{(\get{init})} = m) \iff \defined{(\traverse{tr}{init})} \] +\end{res} + + +\begin{proof} + Let $l := \scan{tr}{init}$. + + $(\Rightarrow)$ direction: By \cref{defined_alt_def}, we only need to prove that for an arbitrary $0 \leq i < 2 \cdot \, \size{tr} \, - 1$,\\ + $\defined{(\fst{(l[i])})} \implies \defined{(\fst{(l[i + 1])})}$. + If $n := \snd{(l[i])}$ is a Rollback node, if $\trd{(l[i])} =\, \up$ or if $n$ is an Action node with no key, then by \cref{scan_indexing_state} (and \cref{actkey_handle_node} in the Action node case) this is automatically true. + + If $n$ is an Action node then either: + \begin{itemize} + \item $i$ is the first appearance of $n.k$ in which case by \cref{defined_globalkeys_mapping} validates the claim + \item $n.k$ appears before $i$ in which case by \cref{defined_empty_traverse}, $\defined{(\fst{(l[i + 1])})}$ + \end{itemize} + + $(\Leftarrow)$ direction: + Assume there exists a key $k$ such that $k \in \emptyState{tr}.globalKeys$ and $\actkey{k}{(\get{init})} \neq (\emptyState{tr}.globalKeys)[k]$. + Then $k \in \collect{tr}$ and by \cref{find_k_first_appear}, there is a $0 \leq i < 2 \cdot \size{tr} - 1$ such that + $i$ is the first appearance of $k$ in $\collectTr{tr}$. In particular this means that $n := \snd{((\collectTr{tr})[i])}$ is an Action node with a well-defined + key and $k = n.k$. By \cref{defined_globalkeys_mapping} this means that either $\neg{\defined{(\fst{(l[i])})}}$ or $\neg{\defined{(\fst{(l[i + 1])})}}$, which both imply by \cref{defined_traverse_prop} + that $\neg{\defined{(\traverse{tr}{init})}}$. +\end{proof} + + +\end{document} diff --git a/daml-lf/verification/scripts/stainless_imports.txt b/daml-lf/verification/scripts/stainless_imports.txt new file mode 100644 index 0000000000..ccbda55611 --- /dev/null +++ b/daml-lf/verification/scripts/stainless_imports.txt @@ -0,0 +1,27 @@ +package lf.verified +package translation + +import stainless.lang._ +import stainless.annotation._ +import stainless.collection._ +import utils.{ + Either, + Map, + Set, + Value, + GlobalKey, + Transaction, + Unreachable, + Node, + ContractKeyUniquenessMode, + Option +} +import utils.Value.ContractId +import utils.Transaction.{ + KeyCreate, + KeyInputError, + NegativeKeyLookup, + InconsistentContractKey, + DuplicateContractKey, + KeyInput +} diff --git a/daml-lf/verification/scripts/verification_script.sh b/daml-lf/verification/scripts/verification_script.sh new file mode 100755 index 0000000000..7805865111 --- /dev/null +++ b/daml-lf/verification/scripts/verification_script.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +cd scripts + + +STAINLESS=$1 +ARGS="--watch=false --timeout=30 --vc-cache=false --compact=true --solvers=nativez3 --infer-measures=false" + +#Running stainless, there are 3 modes: +# - translate: translate the original file to a simplified version and verifies only that +# the translation is correct +# - verification: verifies the proof +# - test: test that stainless is executed correctly; does not verify any file +# Finally returns the exit code of stainless: +# 0 : everything verifies +# 1 : something does not verify +# 2 : the files do not compile + +if [[ $2 = "translate" ]]; then + DAML_ROOT=$(git rev-parse --show-toplevel); + FILE_LOCATION="$DAML_ROOT/daml-lf/transaction/src/main/scala/com/digitalasset/daml/lf/transaction/ContractStateMachine.scala"; + + #We first load the original file in a variable + FILE=$(cat $FILE_LOCATION) + + #To be able to quickly double check the output we first remove comments and emptyLines + FILE=$(sed '/^\s*\/\*/d' <<<"$FILE"); + FILE=$(sed '/^\s*\*/d' <<<"$FILE"); + FILE=$(sed '/^\s*\/\//d' <<<"$FILE"); + FILE=$(sed '/^$/d' <<<"$FILE"); + + #Replacing covariant options and lists by invariant ones + FILE=$(sed 's/\bNone\b/None()/g' <<<"$FILE"); + FILE=$(sed 's/\([A-Za-z0-9_]*\)\s*::\s*\([A-Za-z0-9_]*\)/Cons(\1, \2)/g' <<<"$FILE"); + FILE=$(sed 's/\(\([A-Za-z0-9_]\|\.\)*\).filterNot(\(\([A-Za-z0-9_]\|\.\)*\))/Option.filterNot(\1, \3)/g' <<<"$FILE"); + FILE=$(sed 's/\bNil\b/Nil()/g' <<<"$FILE"); + + #Replacing KeyMapping with Option + FILE=$(sed '/^\s*val\s*KeyActive\s*=/d' <<<"$FILE"); + FILE=$(sed '/^\s*val\s*KeyInactive\s*=/d' <<<"$FILE"); + FILE=$(sed 's/\s\(ContractStateMachine\.\)\?KeyInactive\b/ None[ContractId]()/g' <<<"$FILE"); + FILE=$(sed 's/\s\(ContractStateMachine\.\)\?KeyActive\b/ Some[ContractId]/g' <<<"$FILE"); + FILE=$(sed '/^\s*val KeyActive/d' <<<"$FILE"); + + #Replacing exceptions with Unreachable + FILE=$(sed -z 's/throw\s*new\s*[A-Za-z0-9]\+(\s*\("\([A-Za-z0-9(),;:_]\|\s\|\[\|\]\|\-\|\.\|\/\)*"\)\?\s*)/Unreachable()/g' <<<"$FILE"); + FILE=$(sed 's/throw\s*new\s*[A-Za-z0-9]\+(\("\([A-Za-z0-9(),;:_]\|\s\|\[\|\]\|\-\|\.\|\/\)*"\)\?)/Unreachable()/' <<<"$FILE"); + + + #Replacing imports and package name by our own (located in $2) and writing the output in translation/ContractStateMachine.scala + FILE=$(sed '/^package/d' <<<"$FILE"); + FILE=$(sed -z 's/\s*import\s*\([A-Za-z0-9]\|\.\)\+{\([A-Za-z0-9(),:;]\|\s\|\[\|\]\|\-\|\.\|\/\)*}//g' <<<"$FILE"); + FILE=$(sed '/^\s*import\s*\([A-Za-z0-9]\|\.\)\+/d' <<<"$FILE"); + FILE=$(sed 's/\(\s*\)\(case class State\)/\1@dropVCs\n\1\2/' <<<"$FILE"); + + ADD=$(cat stainless_imports.txt); + FILE_DESTINATION="../translation/ContractStateMachine.scala" + echo -e "${ADD}$FILE" > $FILE_DESTINATION; + + $STAINLESS ../utils/* ../translation/* ../transaction/* $ARGS; + + RES=$? + + #Cleaning everything up + rm $FILE_DESTINATION + + exit $RES + +elif [[ $2 = "verify" ]]; then + $STAINLESS ../utils/* ../transaction/* ../tree/* $ARGS; + exit $? + +elif [[ $2 = "test" ]]; then + $STAINLESS $ARGS; + exit $? +else + exit 3 +fi + + + + diff --git a/daml-lf/verification/stainless.nix b/daml-lf/verification/stainless.nix new file mode 100644 index 0000000000..0bfa194c71 --- /dev/null +++ b/daml-lf/verification/stainless.nix @@ -0,0 +1,17 @@ +let + pkgs = import {}; + stdenv = pkgs.stdenv; +in rec { + stainlessEnv = stdenv.mkDerivation rec { + name = "stainless-env"; + shellHook = '' + alias cls=clear + ''; + buildInputs = with pkgs; [ + stdenv + sbt + openjdk17 + z3 + ]; + }; +} diff --git a/daml-lf/verification/transaction/CSMAdvance.scala b/daml-lf/verification/transaction/CSMAdvance.scala new file mode 100644 index 0000000000..2a0f72873e --- /dev/null +++ b/daml-lf/verification/transaction/CSMAdvance.scala @@ -0,0 +1,307 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import scala.annotation.targetName +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ + +import ContractStateMachine._ +import CSMEitherDef._ + +/** This file expresses the change of active state after handling a node as a call to [[ActiveLedgerState.advance]] + * with another state representing what is added. + * + * It also contains necessary properties of the advance methods that are needed in order to prove the above claim. + */ +object CSMAdvanceDef { + + /** Express what is added to an active state after handling a node. + * + * @see [[CSMAdvance.handleNodeActiveState]] for an use case of this definition. + */ + @pure + @opaque + def actionActiveStateAddition(id: NodeId, n: Node.Action): ActiveLedgerState = { + n match { + case create: Node.Create => + ActiveLedgerState( + Set(create.coid), + Map.empty[ContractId, NodeId], + (create.gkeyOpt match { + case None() => Map.empty[GlobalKey, ContractId] + case Some(k) => Map[GlobalKey, ContractId](k -> create.coid) + }), + ) + case exe: Node.Exercise => + ActiveLedgerState( + Set.empty[ContractId], + (if (exe.consuming) Map(exe.targetCoid -> id) else Map.empty[ContractId, NodeId]), + Map.empty[GlobalKey, ContractId], + ) + case _ => ActiveLedgerState.empty + } + } + +} + +object CSMAdvance { + + import CSMAdvanceDef._ + + /** [[ActiveLedgerState.empty]] is a neutral element wrt [[ActiveLedgerState.advance]]. + * + * The set extensionality axiom is used here for the sake of simplicity. + */ + @pure + @opaque + def emptyAdvance(s: ActiveLedgerState): Unit = { + unfold(s.advance(ActiveLedgerState.empty)) + unfold(ActiveLedgerState.empty.advance(s)) + + SetProperties.unionEmpty(s.locallyCreatedThisTimeline) + SetAxioms.extensionality( + s.locallyCreatedThisTimeline ++ Set.empty[ContractId], + s.locallyCreatedThisTimeline, + ) + SetAxioms.extensionality( + Set.empty[ContractId] ++ s.locallyCreatedThisTimeline, + s.locallyCreatedThisTimeline, + ) + MapProperties.concatEmpty(s.consumedBy) + MapAxioms.extensionality(s.consumedBy ++ Map.empty[ContractId, NodeId], s.consumedBy) + MapAxioms.extensionality(Map.empty[ContractId, NodeId] ++ s.consumedBy, s.consumedBy) + MapProperties.concatEmpty(s.localKeys) + MapAxioms.extensionality(s.localKeys ++ Map.empty[GlobalKey, ContractId], s.localKeys) + MapAxioms.extensionality(Map.empty[GlobalKey, ContractId] ++ s.localKeys, s.localKeys) + + }.ensuring( + (s.advance(ActiveLedgerState.empty) == s) && + (ActiveLedgerState.empty.advance(s) == s) + ) + + /** [[ActiveLedgerState.advance]] is an associative operation. + * + * The set extensionality axiom is used here for the sake of simplicity. + */ + @pure + @opaque + def advanceAssociativity( + s1: ActiveLedgerState, + s2: ActiveLedgerState, + s3: ActiveLedgerState, + ): Unit = { + unfold(s1.advance(s2).advance(s3)) + unfold(s1.advance(s2)) + unfold(s1.advance(s2.advance(s3))) + unfold(s2.advance(s3)) + + SetProperties.unionAssociativity( + s1.locallyCreatedThisTimeline, + s2.locallyCreatedThisTimeline, + s3.locallyCreatedThisTimeline, + ) + SetAxioms.extensionality( + (s1.locallyCreatedThisTimeline ++ s2.locallyCreatedThisTimeline) ++ s3.locallyCreatedThisTimeline, + s1.locallyCreatedThisTimeline ++ (s2.locallyCreatedThisTimeline ++ s3.locallyCreatedThisTimeline), + ) + MapProperties.concatAssociativity(s1.consumedBy, s2.consumedBy, s3.consumedBy) + MapAxioms.extensionality( + (s1.consumedBy ++ s2.consumedBy) ++ s3.consumedBy, + s1.consumedBy ++ (s2.consumedBy ++ s3.consumedBy), + ) + MapProperties.concatAssociativity(s1.localKeys, s2.localKeys, s3.localKeys) + MapAxioms.extensionality( + (s1.localKeys ++ s2.localKeys) ++ s3.localKeys, + s1.localKeys ++ (s2.localKeys ++ s3.localKeys), + ) + }.ensuring( + s1.advance(s2).advance(s3) == s1.advance(s2.advance(s3)) + ) + + /** Two states are equal if and only if one is the other after calling [[ActiveLedgerState.advance]] with + * [[ActiveLedgerState.empty]]. + */ + @pure + @opaque + def sameActiveStateActiveState[T](s1: State, e2: Either[T, State]): Unit = { + require(sameActiveState(s1, e2)) + + unfold(sameActiveState(s1, e2)) + + e2 match { + case Right(s2) => emptyAdvance(s1.activeState) + case _ => Trivial() + } + + }.ensuring( + sameActiveState(s1, e2) == + e2.forall( + _.activeState == + s1.activeState.advance(ActiveLedgerState.empty) + ) + ) + + /** Express the [[State.activeState]] of a [[State]], after handling a [[Node.Create]], as a call to + * [[ActiveLedgerState.advance]]. + * + * The set extensionality axiom is used here for the sake of simplicity. + */ + @pure + @opaque + def visitCreateActiveState(s: State, contractId: ContractId, mbKey: Option[GlobalKey]): Unit = { + unfold(s.visitCreate(contractId, mbKey)) + unfold( + s.activeState.advance( + ActiveLedgerState( + Set(contractId), + Map.empty[ContractId, NodeId], + (mbKey match { + case None() => Map.empty[GlobalKey, ContractId] + case Some(k) => Map[GlobalKey, ContractId](k -> contractId) + }), + ) + ) + ) + + unfold(s.activeState.locallyCreatedThisTimeline.incl(contractId)) + MapProperties.concatEmpty(s.activeState.consumedBy) + MapAxioms.extensionality( + s.activeState.consumedBy ++ Map.empty[ContractId, NodeId], + s.activeState.consumedBy, + ) + + mbKey match { + case None() => + MapProperties.concatEmpty(s.activeState.localKeys) + MapAxioms.extensionality( + s.activeState.localKeys ++ Map.empty[GlobalKey, ContractId], + s.activeState.localKeys, + ) + case Some(k) => + unfold(s.activeState.localKeys.updated(k, contractId)) + } + + }.ensuring( + s.visitCreate(contractId, mbKey) + .forall( + _.activeState == s.activeState.advance( + ActiveLedgerState( + Set(contractId), + Map.empty[ContractId, NodeId], + (mbKey match { + case None() => Map.empty[GlobalKey, ContractId] + case Some(k) => Map[GlobalKey, ContractId](k -> contractId) + }), + ) + ) + ) + ) + + /** Express the [[State.activeState]] of a [[State]], after handling a [[Node.Exercise]], as a call to + * [[ActiveLedgerState.advance]]. + * + * The set extensionality axiom is used here for the sake of simplicity. + */ + @pure + @opaque + def visitExerciseActiveState( + s: State, + nodeId: NodeId, + targetId: ContractId, + mbKey: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + ): Unit = { + + unfold(s.visitExercise(nodeId, targetId, mbKey, byKey, consuming)) + unfold( + s.activeState.advance( + ActiveLedgerState( + Set.empty[ContractId], + (if (consuming) Map(targetId -> nodeId) else Map.empty[ContractId, NodeId]), + Map.empty[GlobalKey, ContractId], + ) + ) + ) + unfold(sameActiveState(s, s.assertKeyMapping(targetId, mbKey))) + + s.assertKeyMapping(targetId, mbKey) match { + case Right(state) => + unfold(state.consume(targetId, nodeId)) + unfold(state.activeState.consume(targetId, nodeId)) + unfold(state.activeState.consumedBy.updated(targetId, nodeId)) + MapProperties.concatEmpty(state.activeState.consumedBy) + MapAxioms.extensionality( + state.activeState.consumedBy ++ Map.empty[ContractId, NodeId], + state.activeState.consumedBy, + ) + + SetProperties.unionEmpty(state.activeState.locallyCreatedThisTimeline) + SetAxioms.extensionality( + state.activeState.locallyCreatedThisTimeline ++ Set.empty[ContractId], + state.activeState.locallyCreatedThisTimeline, + ) + MapProperties.concatEmpty(state.activeState.localKeys) + MapAxioms.extensionality( + state.activeState.localKeys ++ Map.empty[GlobalKey, ContractId], + state.activeState.localKeys, + ) + case _ => Trivial() + } + + }.ensuring( + s.visitExercise(nodeId, targetId, mbKey, byKey, consuming) + .forall( + _.activeState == s.activeState.advance( + ActiveLedgerState( + Set.empty[ContractId], + (if (consuming) Map(targetId -> nodeId) else Map.empty[ContractId, NodeId]), + Map.empty[GlobalKey, ContractId], + ) + ) + ) + ) + + /** Express the [[State.activeState]] of a [[State]], after handling a [[Node.Action]], as a call to + * [[ActiveLedgerState.advance]]. + */ + @pure + @opaque + def handleNodeActiveState(s: State, nodeId: NodeId, n: Node.Action): Unit = { + + unfold(s.handleNode(nodeId, n)) + unfold(actionActiveStateAddition(nodeId, n)) + + n match { + case create: Node.Create => visitCreateActiveState(s, create.coid, create.gkeyOpt) + case fetch: Node.Fetch => + sameActiveStateActiveState(s, s.assertKeyMapping(fetch.coid, fetch.gkeyOpt)) + case lookup: Node.LookupByKey => + sameActiveStateActiveState(s, s.visitLookup(lookup.gkey, lookup.result)) + case exe: Node.Exercise => + visitExerciseActiveState(s, nodeId, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming) + } + + }.ensuring( + s.handleNode(nodeId, n) + .forall(sf => sf.activeState == s.activeState.advance(actionActiveStateAddition(nodeId, n))) + ) + +} diff --git a/daml-lf/verification/transaction/CSMEither.scala b/daml-lf/verification/transaction/CSMEither.scala new file mode 100644 index 0000000000..f2da453d9d --- /dev/null +++ b/daml-lf/verification/transaction/CSMEither.scala @@ -0,0 +1,673 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import scala.annotation.targetName +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ + +import ContractStateMachine._ + +/** Having pattern matching inside ensuring or require clauses generates a lot of verification conditions + * and considerably slows down stainless. In order to avoid that, every pattern match in a postcondition + * or precondition should, if possible, be enclosed in an opaque function. + * + * This files gather all field equalities between a [[State]] and a Either[T, State], or between an Either[T, State] + * and an Either[U, State], where T and U are generic type. + * In practice T will be either [[KeyInputError]], [[InconsistentContractKey]] or [[DuplicateContractKey]]. Having a + * generic type avoids Stainless to spend time on what kind of type that is, making it sort of 'opaque'. + * + * It also describes error propagation, which happens when processing a transaction. The exact definitions of error + * propagation are described lower in the file. + */ +object CSMEitherDef { + + /** Checks state equality. Equivalent to checking equality for every field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1 is equal to s2 if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameState[T](s1: State, s2: Either[T, State]): Boolean = { + unfold(sameStack(s1, s2)) + unfold(sameGlobalKeys(s1, s2)) + unfold(sameActiveState(s1, s2)) + unfold(sameLocallyCreated(s1, s2)) + unfold(sameConsumed(s1, s2)) + + s2.forall((s: State) => s1 == s) + }.ensuring( + _ == (sameStack(s1, s2) && sameGlobalKeys(s1, s2) && sameActiveState( + s1, + s2, + ) && sameLocallyCreated(s1, s2) && sameConsumed(s1, s2)) + ) + + /** Checks state equality. Equivalent to checking equality for every field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1 is equal to s2 when both are well-defined, otherwise returns true + */ + @pure + @opaque + def sameState[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + unfold(sameStack(s1, s2)) + unfold(sameGlobalKeys(s1, s2)) + unfold(sameActiveState(s1, s2)) + unfold(sameLocallyCreated(s1, s2)) + unfold(sameConsumed(s1, s2)) + s1.forall((s: State) => sameState(s, s2)) + }.ensuring( + _ == (sameStack(s1, s2) && sameGlobalKeys(s1, s2) && sameActiveState( + s1, + s2, + ) && sameLocallyCreated(s1, s2) && sameConsumed(s1, s2)) + ) + + /** Checks that two states have the same [[State.locallyCreated]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.locallyCreated is equal to s2.locallyCreated if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameLocallyCreated[T](s1: State, s2: Either[T, State]): Boolean = { + s2.forall((s: State) => s1.locallyCreated == s.locallyCreated) + } + + /** Checks that two states have the same [[State.locallyCreated]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.locallyCreated is equal to s2.locallyCreated when both are well-defined, + * otherwise returns true + */ + @pure + @opaque + def sameLocallyCreated[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + s1.forall((s: State) => sameLocallyCreated(s, s2)) + } + + /** Checks that two states have the same [[State.consumed]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.consumed is equal to s2.consumed if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameConsumed[T](s1: State, s2: Either[T, State]): Boolean = { + s2.forall((s: State) => s1.consumed == s.consumed) + } + + /** Checks that two states have the same [[State.consumed]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.consumed is equal to s2.locallyCreated when both are well-defined, + * otherwise returns true + */ + @pure + @opaque + def sameConsumed[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + s1.forall((s: State) => sameConsumed(s, s2)) + } + + /** Checks that two states have the same [[State.activeState]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState is equal to s2.activeState if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameActiveState[T](s1: State, s2: Either[T, State]): Boolean = { + unfold(sameLocallyCreatedThisTimeline(s1, s2)) + unfold(sameConsumedBy(s1, s2)) + unfold(sameLocalKeys(s1, s2)) + s2.forall((s: State) => s1.activeState == s.activeState) + }.ensuring( + _ == (sameLocallyCreatedThisTimeline(s1, s2) && sameConsumedBy(s1, s2) && sameLocalKeys(s1, s2)) + ) + + /** Checks that two states have the same [[State.activeState]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState is equal to s2.activeState when both are well-defined, + * otherwise returns true + */ + @pure + @opaque + def sameActiveState[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + unfold(sameLocallyCreatedThisTimeline(s1, s2)) + unfold(sameConsumedBy(s1, s2)) + unfold(sameLocalKeys(s1, s2)) + s1.forall((s: State) => sameActiveState(s, s2)) + }.ensuring( + _ == (sameLocallyCreatedThisTimeline(s1, s2) && sameConsumedBy(s1, s2) && sameLocalKeys(s1, s2)) + ) + + /** Checks that two states have the same [[State.activeState.localKeys]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState.localKeys is equal to s2.activeState.localKeys if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameLocalKeys[T](s1: State, s2: Either[T, State]): Boolean = { + s2.forall((s: State) => s1.activeState.localKeys == s.activeState.localKeys) + } + + /** Checks that two states have the same [[State.activeState.localKeys]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState.localKeys is equal to s2.activeState.localKeys when both are well-defined, + * otherwise returns true + */ + @pure + @opaque + def sameLocalKeys[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + s1.forall((s: State) => sameLocalKeys(s, s2)) + } + + /** Checks that two states have the same [[State.activeState.locallyCreatedThisTimeline]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState.locallyCreatedThisTimeline is equal to s2.activeState.locallyCreatedThisTimeline + * if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameLocallyCreatedThisTimeline[T](s1: State, s2: Either[T, State]): Boolean = { + s2.forall((s: State) => + s1.activeState.locallyCreatedThisTimeline == s.activeState.locallyCreatedThisTimeline + ) + } + + /** Checks that two states have the same [[State.activeState.locallyCreatedThisTimeline]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState.locallyCreatedThisTimeline is equal to s2.activeState.locallyCreatedThisTimeline + * when both are well-defined, otherwise returns true + */ + @pure + @opaque + def sameLocallyCreatedThisTimeline[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + s1.forall((s: State) => sameLocallyCreatedThisTimeline(s, s2)) + } + + /** Checks that two states have the same [[State.activeState.consumedBy]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState.consumedBy is equal to s2.activeState.consumedBy if s2 is well-defined, otherwise + * returns true + */ + @pure + @opaque + def sameConsumedBy[T](s1: State, s2: Either[T, State]): Boolean = { + s2.forall((s: State) => s1.activeState.consumedBy == s.activeState.consumedBy) + } + + /** Checks that two states have the same [[State.activeState.consumedBy]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState.consumedBy is equal to s2.activeState.consumedBy when both are well-defined, + * otherwise returns true + */ + @pure + @opaque + def sameConsumedBy[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + s1.forall((s: State) => sameConsumedBy(s, s2)) + } + + /** Checks that two states have the same [[State.globalKeys]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.globalKeys is equal to s2.globalKeys if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameGlobalKeys[T](s1: State, s2: Either[T, State]): Boolean = { + s2.forall((s: State) => s1.globalKeys == s.globalKeys) + } + + /** Checks that two states have the same [[State.globalKeys]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.globalKeys is equal to s2.globalKeys when both are well-defined, + * otherwise returns true + */ + @pure + @opaque + def sameGlobalKeys[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + s1.forall((s: State) => sameGlobalKeys(s, s2)) + } + + /** Checks that two states have the same [[State.rollbackStack]] field. + * + * @param s1 A well-defined state + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.rollbackStack is equal to s2.rollbackStack if s2 is well-defined, otherwise returns true + */ + @pure + @opaque + def sameStack[T](s1: State, s2: Either[T, State]): Boolean = { + s2.forall((s: State) => s1.rollbackStack == s.rollbackStack) + } + + /** Checks that two states have the same [[State.rollbackStack]] field. + * + * @param s1 A state that can be either well-defined or erroneous + * @param s2 A state that can be either well-defined or erroneous + * @return Whether s1.activeState.rollbackStack is equal to s2.rollbackStack when both are well-defined, + * otherwise returns true + */ + @pure + @opaque + def sameStack[U, T](s1: Either[U, State], s2: Either[T, State]): Boolean = { + s1.forall((s: State) => sameStack(s, s2)) + } + + /** Various definitions of error propagation + * + * @param s1 + * The state or error before the operation + * + * @param s2 + * The state or error after the operation + * + * 1. PropagatesError: if s1 is an error state then s2 will be an error as well. Both errors + * do not necessarily need to be the same. + * + * 2. PropagatesBothError: s1 is an error state if and only if s2 is an error state as well, + * although both errors do not necessarily need to be the same + * + * 3. PropagatesSameError: if s1 is an error state then s1 == s2 + * + * 4. SameError: if s1 or s2 is an error state then s1 == s2 + * + * Implications: + * + * SameError ==> PropagatesSameError ==> PropagatesError + * SameError ==> PropagatesBothError ==> PropagatesError + */ + + /** Checks that if a state is erroneous than the other one is as well. + * + * @see above for further details + */ + @pure + @opaque + def propagatesError[U, T, V](s1: Either[U, V], s2: Either[T, V]): Boolean = { + s1.isLeft ==> s2.isLeft + } + + /** Checks that a state is erroneous if and only if the other one is as well. + * + * @see above for further details + */ + @pure + def propagatesBothError[U, T, V](s1: Either[U, V], s2: Either[T, V]): Boolean = { + propagatesError(s1, s2) && propagatesError(s2, s1) + } + + /** Checks that if a state is erroneous then the other one is erroneous has well and outputs the same result. + * + * @see above for further details + */ + @pure + @opaque + def propagatesSameError[T, V](s1: Either[T, V], s2: Either[T, V]): Boolean = { + unfold(propagatesError(s1, s2)) + s1.isLeft ==> (s1 == s2) + }.ensuring(_ ==> propagatesError(s1, s2)) + + /** Checks that a state is erroneous if and only if the other one is erroneous. Moreover when this happens, + * both errors are the same. + * + * @see above for further details + */ + @pure + def sameError[T, V](s1: Either[T, V], s2: Either[T, V]): Boolean = { + propagatesSameError(s1, s2) && propagatesSameError(s2, s1) + }.ensuring(_ ==> propagatesBothError(s1, s2)) + +} + +object CSMEither { + + import CSMEitherDef._ + + /** Field equality extension is always reflexive. In order to be transitivity, we need error propagation + * on at least the second operation. In fact transitivity does not hold if we have s1 -> error -> s2 + * with s1.field =/= s2.field. + */ + + /** Reflexivity of [[State.rollbackStack]] field equality. + */ + @pure + @opaque + def sameStackReflexivity[T](s: Either[T, State]): Unit = { + unfold(sameStack(s, s)) + s match { + case Left(_) => Trivial() + case Right(s2) => unfold(sameStack(s2, s)) + } + }.ensuring(sameStack(s, s)) + + /** Transitivity of [[State.rollbackStack]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameStackTransitivity[V, W](s1: State, s2: Either[V, State], s3: Either[W, State]): Unit = { + require(sameStack(s1, s2)) + require(sameStack(s2, s3)) + require(propagatesError(s2, s3)) + + unfold(sameStack(s1, s2)) + unfold(sameStack(s2, s3)) + unfold(sameStack(s1, s3)) + unfold(propagatesError(s2, s3)) + + s2 match { + case Left(_) => Trivial() + case Right(s) => unfold(sameStack(s, s3)) + } + + }.ensuring(sameStack(s1, s3)) + + /** Transitivity of [[State.rollbackStack]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameStackTransitivity[U, V, W]( + s1: Either[U, State], + s2: Either[V, State], + s3: Either[W, State], + ): Unit = { + require(sameStack(s1, s2)) + require(sameStack(s2, s3)) + require(propagatesError(s2, s3)) + + unfold(sameStack(s1, s2)) + unfold(sameStack(s1, s3)) + s1 match { + case Left(_) => Trivial() + case Right(s) => sameStackTransitivity(s, s2, s3) + } + }.ensuring(sameStack(s1, s3)) + + /** Reflexivity of [[State.globaKeys]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameGlobalKeysReflexivity[T](s: Either[T, State]): Unit = { + unfold(sameGlobalKeys(s, s)) + s match { + case Left(_) => Trivial() + case Right(s2) => unfold(sameGlobalKeys(s2, s)) + } + }.ensuring(sameGlobalKeys(s, s)) + + /** Transitivity of [[State.globalKeys]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameGlobalKeysTransitivity[V, W]( + s1: State, + s2: Either[V, State], + s3: Either[W, State], + ): Unit = { + require(sameGlobalKeys(s1, s2)) + require(sameGlobalKeys(s2, s3)) + require(propagatesError(s2, s3)) + + unfold(sameGlobalKeys(s1, s2)) + unfold(sameGlobalKeys(s2, s3)) + unfold(sameGlobalKeys(s1, s3)) + unfold(propagatesError(s2, s3)) + + s2 match { + case Left(_) => Trivial() + case Right(s) => unfold(sameGlobalKeys(s, s3)) + } + + }.ensuring(sameGlobalKeys(s1, s3)) + + /** Transitivity of [[State.globalKeys]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameGlobalKeysTransitivity[U, V, W]( + s1: Either[U, State], + s2: Either[V, State], + s3: Either[W, State], + ): Unit = { + require(sameGlobalKeys(s1, s2)) + require(sameGlobalKeys(s2, s3)) + require(propagatesError(s2, s3)) + + unfold(sameGlobalKeys(s1, s2)) + unfold(sameGlobalKeys(s1, s3)) + s1 match { + case Left(_) => Trivial() + case Right(s) => sameGlobalKeysTransitivity(s, s2, s3) + } + }.ensuring(sameGlobalKeys(s1, s3)) + + /** Reflexivity of [[State.locallyCreated]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @opaque + def sameLocallyCreatedReflexivity[T](s: Either[T, State]): Unit = { + unfold(sameLocallyCreated(s, s)) + s match { + case Left(_) => Trivial() + case Right(s2) => unfold(sameLocallyCreated(s2, s)) + } + }.ensuring(sameLocallyCreated(s, s)) + + /** Transitivity of [[State.locallyCreated]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameLocallyCreatedTransitivity[V, W]( + s1: State, + s2: Either[V, State], + s3: Either[W, State], + ): Unit = { + require(sameLocallyCreated(s1, s2)) + require(sameLocallyCreated(s2, s3)) + require(propagatesError(s2, s3)) + + unfold(sameLocallyCreated(s1, s2)) + unfold(sameLocallyCreated(s2, s3)) + unfold(sameLocallyCreated(s1, s3)) + unfold(propagatesError(s2, s3)) + + s2 match { + case Left(_) => Trivial() + case Right(s) => unfold(sameLocallyCreated(s, s3)) + } + + }.ensuring(sameLocallyCreated(s1, s3)) + + /** Reflexivity of [[State.consumed]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameConsumedReflexivity[T](s: Either[T, State]): Unit = { + unfold(sameConsumed(s, s)) + s match { + case Left(_) => Trivial() + case Right(s2) => unfold(sameConsumed(s2, s)) + } + }.ensuring(sameConsumed(s, s)) + + /** Transitivity of [[State.consumed]] field equality. Holds only if error propagates between the + * second and the third state. + */ + @pure + @opaque + def sameConsumedTransitivity[V, W]( + s1: State, + s2: Either[V, State], + s3: Either[W, State], + ): Unit = { + require(sameConsumed(s1, s2)) + require(sameConsumed(s2, s3)) + require(propagatesError(s2, s3)) + + unfold(sameConsumed(s1, s2)) + unfold(sameConsumed(s2, s3)) + unfold(sameConsumed(s1, s3)) + unfold(propagatesError(s2, s3)) + + s2 match { + case Left(_) => Trivial() + case Right(s) => unfold(sameConsumed(s, s3)) + } + + }.ensuring(sameConsumed(s1, s3)) + + /** Reflexivity of error propagation. + */ + @pure + @opaque + def propagatesErrorReflexivity[T, V](s: Either[T, V]): Unit = { + unfold(propagatesError(s, s)) + }.ensuring(propagatesError(s, s)) + + /** Transitivity of error propagation. + */ + @pure + @opaque + def propagatesErrorTransitivity[T, V]( + s1: Either[T, V], + s2: Either[T, V], + s3: Either[T, V], + ): Unit = { + require(propagatesError(s1, s2)) + require(propagatesError(s2, s3)) + unfold(propagatesError(s1, s2)) + unfold(propagatesError(s2, s3)) + unfold(propagatesError(s1, s3)) + }.ensuring(propagatesError(s1, s3)) + + /** Reflexivity of error propagation. + */ + @pure + @opaque + def propagatesSameErrorReflexivity[T, V](s: Either[T, V]): Unit = { + unfold(propagatesSameError(s, s)) + }.ensuring(propagatesSameError(s, s)) + + /** Transitivity of error propagation. + */ + @pure + @opaque + def propagatesSameErrorTransitivity[T, V]( + s1: Either[T, V], + s2: Either[T, V], + s3: Either[T, V], + ): Unit = { + require(propagatesSameError(s1, s2)) + require(propagatesSameError(s2, s3)) + unfold(propagatesSameError(s1, s2)) + unfold(propagatesSameError(s2, s3)) + unfold(propagatesSameError(s1, s3)) + }.ensuring(propagatesSameError(s1, s3)) + + /** Reflexivity of error propagation. + */ + @pure + @opaque + def sameErrorReflexivity[T, V](s: Either[T, V]): Unit = { + propagatesSameErrorReflexivity(s) + }.ensuring(sameError(s, s)) + + /** Transitivity of error propagation. + */ + @pure + @opaque + def sameErrorTransitivity[T, V](s1: Either[T, V], s2: Either[T, V], s3: Either[T, V]): Unit = { + require(sameError(s1, s2)) + require(sameError(s2, s3)) + propagatesSameErrorTransitivity(s1, s2, s3) + propagatesSameErrorTransitivity(s3, s2, s1) + }.ensuring(sameError(s1, s3)) + + /** Right map propagates error in both directions. + */ + @pure + @opaque + def propagatesSameErrorMap[T, V](e: Either[T, V], f: V => V): Unit = { + unfold(propagatesSameError(e, e.map(f))) + unfold(propagatesSameError(e.map(f), e)) + }.ensuring(propagatesSameError(e, e.map(f)) && propagatesSameError(e.map(f), e)) + + /** Right map propagates error in both direction. Furthermore they are the same. + */ + @pure + @opaque + def sameErrorMap[T, V](e: Either[T, V], f: V => V): Unit = { + propagatesSameErrorMap(e, f) + }.ensuring(sameError(e, e.map(f))) + + /** Left map propagates error in both directions. Furthermore the resulting state is equal to the input. + */ + @pure + @opaque + def sameStateLeftProj[U, T](s: Either[U, State], f: U => T): Unit = { + + val lMap = s.left.map(f) + + unfold(sameState(s, lMap)) + unfold(propagatesError(s, lMap)) + unfold(propagatesError(lMap, s)) + s match { + case Left(_) => Trivial() + case Right(state) => unfold(sameState(state, lMap)) + } + }.ensuring( + sameState(s, s.left.map(f)) && + propagatesBothError(s, s.left.map(f)) + ) + +} diff --git a/daml-lf/verification/transaction/CSMHelpers.scala b/daml-lf/verification/transaction/CSMHelpers.scala new file mode 100644 index 0000000000..cb1f1a8146 --- /dev/null +++ b/daml-lf/verification/transaction/CSMHelpers.scala @@ -0,0 +1,174 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import scala.annotation.targetName +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ + +import ContractStateMachine._ +import CSMEitherDef._ +import CSMEither._ + +/** Helpers definitions + * + * Definitions tightly related to ContractStateMachineAlt content. Most of them are extensions of already existing + * methods with Either arguments type instead of State. + */ +object CSMHelpers { + + /** Extension of [[State.handleNode]] method to a union type argument. If e is defined calls handleNode otherwise returns + * the same error. + */ + @pure + def handleNode( + e: Either[KeyInputError, State], + id: NodeId, + node: Node.Action, + ): Either[KeyInputError, State] = { + val res = e match { + case Left(e) => Left[KeyInputError, State](e) + case Right(s) => s.handleNode(id, node) + } + unfold(propagatesError(e, res)) + unfold(sameGlobalKeys(e, res)) + res + }.ensuring(res => + propagatesError(e, res) && + sameGlobalKeys(e, res) + ) + + /** Converts consumed active key mappings into an inactive key mappings. + */ + @pure + @opaque + def keyMappingToActiveMapping(consumedBy: Map[ContractId, NodeId]): KeyMapping => KeyMapping = { + case Some(cid) if !consumedBy.contains(cid) => KeyActive(cid) + case _ => KeyInactive + } + + /** Converts the error of an union type argument (InconsistentContractKey or DuplicateContractkey) into the more + * generic error type KeyInputError. + */ + @pure + def toKeyInputError(e: Either[InconsistentContractKey, State]): Either[KeyInputError, State] = { + sameStateLeftProj(e, Left[InconsistentContractKey, DuplicateContractKey](_)) + e.left.map(Left[InconsistentContractKey, DuplicateContractKey](_)) + }.ensuring(res => + propagatesBothError(e, res) && + sameState(e, res) + ) + @pure + @targetName("toKeyInputErrorDuplicateContractKey") + def toKeyInputError(e: Either[DuplicateContractKey, State]): Either[KeyInputError, State] = { + sameStateLeftProj(e, Right[InconsistentContractKey, DuplicateContractKey](_)) + e.left.map(Right[InconsistentContractKey, DuplicateContractKey](_)) + }.ensuring(res => + propagatesBothError(e, res) && + sameState(e, res) + ) + + /** Mapping that is added in the global keys when the n is handled. + * + * In the original ContractStateMachine, handleNode first adds the node's key with this mapping to the global keys + * and then process the node in the same way as ContractStateMachineAlt. In the simplified version, both operations + * are separated which brings the need for this function. + * + * @see mapping n, in the latex document + */ + @pure + @opaque + def nodeActionKeyMapping(n: Node.Action): KeyMapping = { + n match { + case create: Node.Create => KeyInactive + case fetch: Node.Fetch => KeyActive(fetch.coid) + case lookup: Node.LookupByKey => lookup.result + case exe: Node.Exercise => KeyActive(exe.targetCoid) + } + } + + /** Extension of State.beginRollback() method to a union type argument. If e is defined calls beginRollback otherwise + * returns the same error. + */ + @pure + @opaque + def beginRollback[T](e: Either[T, State]): Either[T, State] = { + sameErrorMap(e, s => s.beginRollback()) + val res = e.map(s => s.beginRollback()) + unfold(sameGlobalKeys(e, res)) + unfold(sameLocallyCreated(e, res)) + unfold(sameConsumed(e, res)) + e match { + case Left(t) => Trivial() + case Right(s) => + unfold(s.beginRollback()) + unfold(sameGlobalKeys(s, res)) + unfold(sameLocallyCreated(s, res)) + unfold(sameConsumed(s, res)) + } + res + }.ensuring((res: Either[T, State]) => + sameGlobalKeys[T, T](e, res) && + sameLocallyCreated(e, res) && + sameConsumed(e, res) && + sameError(e, res) + ) + + /** Extension of State.endRollback() method to a union type argument. If e is defined calls endRollback otherwise + * returns the same error. + * + * Due to a bug in Stainless, writing the properties in the enduring clause of endRollback makes the JVM crash. + * Therefore, all of these are grouped in a different function endRollbackProp that often needs to be called + * separately when using endRollback with a union type argument. + */ + @pure + @opaque + def endRollback(e: Either[KeyInputError, State]): Either[KeyInputError, State] = { + e match { + case Left(t) => Left[KeyInputError, State](t) + case Right(s) => s.endRollback() + } + } + + @pure @opaque + def endRollbackProp(e: Either[KeyInputError, State]): Unit = { + unfold(endRollback(e)) + unfold(sameLocallyCreated(e, endRollback(e))) + unfold(sameConsumed(e, endRollback(e))) + unfold(sameGlobalKeys(e, endRollback(e))) + unfold(propagatesError(e, endRollback(e))) + }.ensuring( + sameLocallyCreated(e, endRollback(e)) && + sameConsumed(e, endRollback(e)) && + sameGlobalKeys(e, endRollback(e)) && + propagatesError(e, endRollback(e)) + ) + + @pure + @opaque + def advanceIsDefined(init: State, substate: State): Unit = { + require(!substate.withinRollbackScope) + }.ensuring( + init.advance(substate).isRight == + substate.globalKeys.keySet.forall(k => + init.activeKeys.get(k).forall(m => Some(m) == substate.globalKeys.get(k)) + ) + ) + +} diff --git a/daml-lf/verification/transaction/CSMInconsistency.scala b/daml-lf/verification/transaction/CSMInconsistency.scala new file mode 100644 index 0000000000..40385c141c --- /dev/null +++ b/daml-lf/verification/transaction/CSMInconsistency.scala @@ -0,0 +1,698 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import scala.annotation.targetName +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ + +import ContractStateMachine._ +import CSMHelpers._ +import CSMEitherDef._ +import CSMKeysPropertiesDef._ +import CSMKeysProperties._ + +/** File stating under which conditions a state is well defined after handling a node and how its active keys behave + * after that. + */ +object CSMInconsistencyDef { + + /** * + * Alternative and general definition of an error in a transaction + * + * The objective of bringing this definition in are proving that: + * - s.handleNode(nid, node) leads to an error <=> inconsistencyCheck(s, node.gkeyOpt, node.m) is true + * - If node.gkeyOpt =/= Some(k2) then + * inconsistencyCheck(s, k2, m) == inconsistencyCheck(s.handleNode(nid, node), k2, m) + * + * @param s + * The current state + * @param k + * The key of the node we are handling. If the key is not defined then the transaction never fails, so + * the inconsistencyCheck is false + * @param m + * The key mapping that would have been added to the globalKeys if there was not an entry already for that key. + * This mapping (let denote it node.m) is: + * - KeyInactive for Node.Create + * - KeyActive(coid) for Node.Fetch + * - result for Node.Lookup + * - KeyActive(targetCoid) for Node.Exercise + */ + @pure + def inconsistencyCheck(s: State, k: GlobalKey, m: KeyMapping): Boolean = { + s.activeKeys.get(k).exists(_ != m) + } + + @pure + def inconsistencyCheck(s: State, k: Option[GlobalKey], m: KeyMapping): Boolean = { + k.exists(gk => inconsistencyCheck(s, gk, m)) + } + +} + +object CSMInconsistency { + + import CSMInconsistencyDef._ + import CSMInvariantDef._ + import CSMInvariantDerivedProp._ + + /** The resulting state after handling a Create node is defined if any only if the inconsistencyCondition is not met + */ + @pure + @opaque + def visitCreateUndefined(s: State, contractId: ContractId, mbKey: Option[GlobalKey]): Unit = { + unfold(inconsistencyCheck(s, mbKey, KeyInactive)) + unfold(s.visitCreate(contractId, mbKey)) + + mbKey match { + case None() => Trivial() + case Some(gk) => unfold(inconsistencyCheck(s, gk, KeyInactive)) + } + + }.ensuring( + (s.visitCreate(contractId, mbKey).isLeft == inconsistencyCheck(s, mbKey, KeyInactive)) + ) + + /** The resulting state after handling a Fetch node is defined if any only if the inconsistencyCondition is not met + */ + @pure + @opaque + def assertKeyMappingUndefined(s: State, cid: ContractId, gkey: Option[GlobalKey]): Unit = { + require(containsOptionKey(s)(gkey)) + + unfold(containsOptionKey(s)(gkey)) + unfold(inconsistencyCheck(s, gkey, KeyActive(cid))) + unfold(s.assertKeyMapping(cid, gkey)) + + gkey match { + case None() => Trivial() + case Some(gk) => visitLookupUndefined(s, gk, KeyActive(cid)) + } + + }.ensuring(s.assertKeyMapping(cid, gkey).isLeft == inconsistencyCheck(s, gkey, KeyActive(cid))) + + /** The resulting state after handling an Exercise node is defined if any only if the inconsistencyCondition is not met + */ + @pure + @opaque + def visitExerciseUndefined( + s: State, + nodeId: NodeId, + targetId: ContractId, + gk: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + ): Unit = { + require(containsOptionKey(s)(gk)) + unfold(s.visitExercise(nodeId, targetId, gk, byKey, consuming)) + assertKeyMappingUndefined(s, targetId, gk) + }.ensuring( + s.visitExercise(nodeId, targetId, gk, byKey, consuming).isLeft == inconsistencyCheck( + s, + gk, + KeyActive(targetId), + ) + ) + + /** The resulting state after handling a Lookup node is defined if any only if the inconsistencyCondition is not met + */ + @pure + @opaque + def visitLookupUndefined(s: State, gk: GlobalKey, keyResolution: Option[ContractId]): Unit = { + require(containsKey(s)(gk)) + + unfold(inconsistencyCheck(s, gk, keyResolution)) + unfold(s.visitLookup(gk, keyResolution)) + unfold(containsKey(s)(gk)) + unfold(s.globalKeys.contains) + + activeKeysGetOrElse(s, gk, KeyInactive) + activeKeysGet(s, gk) + + }.ensuring(s.visitLookup(gk, keyResolution).isLeft == inconsistencyCheck(s, gk, keyResolution)) + + /** The resulting state after handling a node is defined if any only if the inconsistencyCondition is not met + */ + @pure + @opaque + def handleNodeUndefined(s: State, id: NodeId, node: Node.Action): Unit = { + require(containsActionKey(s)(node)) + + unfold(s.handleNode(id, node)) + unfold(nodeActionKeyMapping(node)) + unfold(containsActionKey(s)(node)) + + node match { + case create: Node.Create => visitCreateUndefined(s, create.coid, create.gkeyOpt) + case fetch: Node.Fetch => assertKeyMappingUndefined(s, fetch.coid, fetch.gkeyOpt) + case lookup: Node.LookupByKey => + unfold(containsOptionKey(s)(node.gkeyOpt)) + unfold(inconsistencyCheck(s, node.gkeyOpt, nodeActionKeyMapping(node))) + visitLookupUndefined(s, lookup.gkey, lookup.result) + case exe: Node.Exercise => + visitExerciseUndefined(s, id, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming) + } + }.ensuring( + s.handleNode(id, node).isLeft == inconsistencyCheck(s, node.gkeyOpt, nodeActionKeyMapping(node)) + ) + + /** If two states are defined after handling a node then their activeKeys shared the same entry for the node's key + * before entering the node. + */ + @pure + @opaque + def handleSameNodeActiveKeys(s1: State, s2: State, id: NodeId, node: Node.Action): Unit = { + require(containsActionKey(s1)(node)) + require(containsActionKey(s2)(node)) + require(s1.handleNode(id, node).isRight) + require(s2.handleNode(id, node).isRight) + activeKeysContainsKey(s1, node) + activeKeysContainsKey(s2, node) + handleNodeUndefined(s1, id, node) + handleNodeUndefined(s2, id, node) + unfold(s1.activeKeys.contains) + unfold(s2.activeKeys.contains) + }.ensuring(node.gkeyOpt.forall(k => s1.activeKeys.get(k) == s2.activeKeys.get(k))) + + /** If a state stay well defined after handling a Create node and if we are given a key different from the node's one, + * then the entry for that key in the activeKeys of the state did not change + */ + @opaque + @pure + def visitCreateActiveKeysGet( + s: State, + contractId: ContractId, + mbKey: Option[GlobalKey], + k2: GlobalKey, + ): Unit = { + require(s.visitCreate(contractId, mbKey).isRight) + require(mbKey.forall(k1 => k1 != k2)) + unfold(s.visitCreate(contractId, mbKey)) + + activeKeysGet(s, k2) + + val me = + s.copy( + locallyCreated = s.locallyCreated + contractId, + activeState = s.activeState + .copy(locallyCreatedThisTimeline = s.activeState.locallyCreatedThisTimeline + contractId), + ) + + mbKey match { + case None() => + activeKeysGet(me, k2) + case Some(gk) => + activeKeysGet(me.copy(activeState = me.activeState.createKey(gk, contractId)), k2) + MapProperties.updatedGet(me.activeState.localKeys, gk, contractId, k2) + } + + }.ensuring(s.visitCreate(contractId, mbKey).get.activeKeys.get(k2) == s.activeKeys.get(k2)) + + /** If a state stay well defined after handling a Lookup node and if we are given a key different from the node's one, + * then the entry for that key in the activeKeys of the state did not change + */ + @opaque + @pure + def visitLookupActiveKeysGet( + s: State, + gk: GlobalKey, + keyResolution: Option[ContractId], + k2: GlobalKey, + ): Unit = { + require(s.visitLookup(gk, keyResolution).isRight) + unfold(sameState(s, s.visitLookup(gk, keyResolution))) + }.ensuring(s.visitLookup(gk, keyResolution).get.activeKeys.get(k2) == s.activeKeys.get(k2)) + + /** If a state stay well defined after handling a Fetch node and if we are given a key different from the node's one, + * then the entry for that key in the activeKeys of the state did not change + */ + @opaque + @pure + def assertKeyMappingActiveKeysGet( + s: State, + cid: ContractId, + mbKey: Option[GlobalKey], + k2: GlobalKey, + ): Unit = { + require(s.assertKeyMapping(cid, mbKey).isRight) + unfold(sameState(s, s.assertKeyMapping(cid, mbKey))) + }.ensuring(s.assertKeyMapping(cid, mbKey).get.activeKeys.get(k2) == s.activeKeys.get(k2)) + + /** If a state stay well defined after handling an Exercise node and if we are given a key different from the node's one, + * then the entry for that key in the activeKeys of the state did not change + */ + @pure + @opaque + def visitExerciseActiveKeysGet( + s: State, + nodeId: NodeId, + targetId: ContractId, + gk: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + k2: GlobalKey, + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(s.visitExercise(nodeId, targetId, gk, byKey, consuming).isRight) + require(containsOptionKey(s)(gk)) + require(!gk.isDefined ==> unbound.contains(targetId)) + require(gk.forall(k1 => k1 != k2)) + require(stateInvariant(s)(unbound, lc)) + + unfold(s.visitExercise(nodeId, targetId, gk, byKey, consuming)) + + visitExerciseUndefined(s, nodeId, targetId, gk, byKey, consuming) + unfold(inconsistencyCheck(s, gk, KeyActive(targetId))) + unfold(containsOptionKey(s)(gk)) + + s.assertKeyMapping(targetId, gk) match { + case Left(_) => Trivial() + case Right(state) => + keysGet(s, k2) + unfold(s.globalKeys.contains) + unfold(s.activeState.localKeys.contains) + + gk match { + case Some(k) => + // (s.globalKeys.get(k) == Some(KeyActive(targetId))) || (s.activeState.localKeys.get(k).map(KeyActive) == Some(KeyActive(targetId))) + activeKeysGet(s, k) + unfold(containsKey(s)(k)) + unfold(keyMappingToActiveMapping(s.activeState.consumedBy)) + if (s.globalKeys.get(k) == Some(KeyActive(targetId))) { + invariantGetGlobalKeys(s, k, k2, targetId, unbound, lc) + } else { + invariantGetLocalKeys(s, k, k2, targetId, unbound, lc) + } + case None() => + SetProperties.mapContains(unbound, KeyActive, targetId) + SetProperties.disjointContains( + unbound.map(KeyActive), + s.globalKeys.values, + KeyActive(targetId), + ) + if (s.globalKeys.get(k2) == Some(KeyActive(targetId))) { + MapAxioms.valuesContains(s.globalKeys, k2) + } + + SetProperties.disjointContains(unbound, s.activeState.localKeys.values, targetId) + if (s.activeState.localKeys.get(k2) == Some(targetId)) { + MapAxioms.valuesContains(s.activeState.localKeys, k2) + } + } + + unfold(sameState(s, s.assertKeyMapping(targetId, gk))) + unfold(state.consume(targetId, nodeId)) + unfold(state.activeState.consume(targetId, nodeId)) + activeKeysGet(state.consume(targetId, nodeId), k2) + activeKeysGet(state, k2) + unfold(keyMappingToActiveMapping(state.activeState.consumedBy)) + unfold(keyMappingToActiveMapping(state.activeState.consumedBy.updated(targetId, nodeId))) + + (s.activeState.localKeys.get(k2).map(KeyActive), s.globalKeys.get(k2)) match { + case (Some(Some(c1)), _) => + MapProperties.updatedContains(state.activeState.consumedBy, targetId, nodeId, c1) + case (_, Some(Some(c2))) => + MapProperties.updatedContains(state.activeState.consumedBy, targetId, nodeId, c2) + case _ => Trivial() + } + } + + }.ensuring( + s.visitExercise(nodeId, targetId, gk, byKey, consuming).get.activeKeys.get(k2) == s.activeKeys + .get(k2) + ) + + /** If a state stay well defined after handling a node and if we are given a key different from the node's one, + * then the entry for that key in the activeKeys of the state did not change + */ + @pure + @opaque + def handleNodeActiveKeysGet( + s: State, + id: NodeId, + node: Node.Action, + k2: GlobalKey, + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(s.handleNode(id, node).isRight) + require(containsActionKey(s)(node)) + require(node.gkeyOpt.forall(k1 => k1 != k2)) + require(stateInvariant(s)(unbound, lc)) + require(stateNodeCompatibility(s, node, unbound, lc, TraversalDirection.Down)) + + unfold(s.handleNode(id, node)) + unfold(containsActionKey(s)(node)) + + node match { + case create: Node.Create => visitCreateActiveKeysGet(s, create.coid, create.gkeyOpt, k2) + case fetch: Node.Fetch => assertKeyMappingActiveKeysGet(s, fetch.coid, fetch.gkeyOpt, k2) + case lookup: Node.LookupByKey => visitLookupActiveKeysGet(s, lookup.gkey, lookup.result, k2) + case exe: Node.Exercise => + unfold(stateNodeCompatibility(s, node, unbound, lc, TraversalDirection.Down)) + visitExerciseActiveKeysGet( + s, + id, + exe.targetCoid, + exe.gkeyOpt, + exe.byKey, + exe.consuming, + k2, + unbound, + lc, + ) + } + + }.ensuring(s.handleNode(id, node).get.activeKeys.get(k2) == s.activeKeys.get(k2)) + + /** The entry for a given key in the activeKeys of the state does not change after entering a beginRollback + */ + @pure + @opaque + def beginRollbackActiveKeysGet(s: State, k: GlobalKey): Unit = { + unfold(s.beginRollback()) + activeKeysGetSameFields(s.beginRollback(), s, k) + }.ensuring(s.beginRollback().activeKeys.get(k) == s.activeKeys.get(k)) + + /** If two states are defined after handling a Create node then they share the same mapping for the node's key entry in + * the activeKeys. + */ + @opaque + @pure + def visitCreateDifferentStatesActiveKeysGet( + s1: State, + s2: State, + contractId: ContractId, + mbKey: Option[GlobalKey], + ): Unit = { + require(s1.visitCreate(contractId, mbKey).isRight) + require(s2.visitCreate(contractId, mbKey).isRight) + require(!s1.activeState.consumedBy.contains(contractId)) + require(!s2.activeState.consumedBy.contains(contractId)) + + unfold(s1.visitCreate(contractId, mbKey)) + unfold(s2.visitCreate(contractId, mbKey)) + + val me1 = + s1.copy( + locallyCreated = s1.locallyCreated + contractId, + activeState = s1.activeState + .copy(locallyCreatedThisTimeline = s1.activeState.locallyCreatedThisTimeline + contractId), + ) + + val me2 = + s2.copy( + locallyCreated = s2.locallyCreated + contractId, + activeState = s2.activeState + .copy(locallyCreatedThisTimeline = s2.activeState.locallyCreatedThisTimeline + contractId), + ) + + mbKey match { + case None() => Trivial() + case Some(gk) => + val sf1 = me1.copy(activeState = me1.activeState.createKey(gk, contractId)) + val sf2 = me2.copy(activeState = me2.activeState.createKey(gk, contractId)) + activeKeysGet(sf1, gk) + activeKeysGet(sf2, gk) + unfold(me1.activeState.createKey(gk, contractId)) + unfold(me2.activeState.createKey(gk, contractId)) + unfold(keyMappingToActiveMapping(s1.activeState.consumedBy)) + unfold(keyMappingToActiveMapping(s2.activeState.consumedBy)) + } + }.ensuring( + mbKey.forall(k => + s1.visitCreate(contractId, mbKey).get.activeKeys.get(k) == s2 + .visitCreate(contractId, mbKey) + .get + .activeKeys + .get(k) + ) + ) + + /** If two states are defined after handling a Lookup node then they share the same mapping for the node's key entry in + * the activeKeys. + */ + @opaque + @pure + def visitLookupDifferentStatesActiveKeysGet( + s1: State, + s2: State, + gk: GlobalKey, + keyResolution: Option[ContractId], + unbound1: Set[ContractId], + lc1: Set[ContractId], + unbound2: Set[ContractId], + lc2: Set[ContractId], + ): Unit = { + require(s1.visitLookup(gk, keyResolution).isRight) + require(s2.visitLookup(gk, keyResolution).isRight) + require( + keyResolution.isDefined || (containsKey(s1)(gk) && containsKey(s2)(gk) && stateInvariant(s1)( + unbound1, + lc1, + ) && stateInvariant(s2)(unbound2, lc2)) + ) + + unfold(s1.visitLookup(gk, keyResolution)) + unfold(s2.visitLookup(gk, keyResolution)) + unfold(s1.activeKeys.getOrElse(gk, KeyInactive)) + unfold(s2.activeKeys.getOrElse(gk, KeyInactive)) + + keyResolution match { + case None() => + invariantContainsKey(s1, gk, unbound1, lc1) + invariantContainsKey(s2, gk, unbound2, lc2) + unfold(s1.activeKeys.contains) + unfold(s2.activeKeys.contains) + case Some(_) => Trivial() + } + + }.ensuring( + (s1.visitLookup(gk, keyResolution).get.activeKeys.get(gk) == s2 + .visitLookup(gk, keyResolution) + .get + .activeKeys + .get(gk)) && + (s1.visitLookup(gk, keyResolution).get.activeKeys.get(gk) == Some(keyResolution)) + ) + + /** If two states are defined after handling a Fetch node then they share the same mapping for the node's key entry in + * the activeKeys. + */ + @opaque + @pure + def assertKeyMappingDifferentStatesActiveKeysGet( + s1: State, + s2: State, + cid: ContractId, + mbKey: Option[GlobalKey], + ): Unit = { + require(s1.assertKeyMapping(cid, mbKey).isRight) + require(s2.assertKeyMapping(cid, mbKey).isRight) + + unfold(s1.assertKeyMapping(cid, mbKey)) + unfold(s2.assertKeyMapping(cid, mbKey)) + mbKey match { + case None() => Trivial() + case Some(k) => + visitLookupDifferentStatesActiveKeysGet( + s1, + s2, + k, + Some(cid), + Set.empty[ContractId], + Set.empty[ContractId], + Set.empty[ContractId], + Set.empty[ContractId], + ) + } + + }.ensuring( + mbKey.forall(gk => + (s1.assertKeyMapping(cid, mbKey).get.activeKeys.get(gk) == s2 + .assertKeyMapping(cid, mbKey) + .get + .activeKeys + .get(gk)) && + (s1.assertKeyMapping(cid, mbKey).get.activeKeys.get(gk) == Some(KeyActive(cid))) + ) + ) + + /** If two states are defined after handling an Exercise node then they share the same mapping for the node's key entry in + * the activeKeys. + */ + @opaque + @pure + def visitExerciseDifferentStatesActiveKeysGet( + s1: State, + s2: State, + nodeId: NodeId, + targetId: ContractId, + mbKey: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + ): Unit = { + require(s1.visitExercise(nodeId, targetId, mbKey, byKey, consuming).isRight) + require(s2.visitExercise(nodeId, targetId, mbKey, byKey, consuming).isRight) + unfold(s1.visitExercise(nodeId, targetId, mbKey, byKey, consuming)) + unfold(s2.visitExercise(nodeId, targetId, mbKey, byKey, consuming)) + + assertKeyMappingDifferentStatesActiveKeysGet(s1, s2, targetId, mbKey) + + (s1.assertKeyMapping(targetId, mbKey), s2.assertKeyMapping(targetId, mbKey), mbKey) match { + case (Left(_), _, _) => Unreachable() + case (_, Left(_), _) => Unreachable() + case (Right(state1), Right(state2), Some(k)) => + unfold(sameState(s1, s1.assertKeyMapping(targetId, mbKey))) + unfold(sameState(s2, s2.assertKeyMapping(targetId, mbKey))) + unfold(state1.consume(targetId, nodeId)) + unfold(state1.activeState.consume(targetId, nodeId)) + unfold(state2.consume(targetId, nodeId)) + unfold(state2.activeState.consume(targetId, nodeId)) + + activeKeysGet(state1, k) + activeKeysGet(state2, k) + unfold(keyMappingToActiveMapping(state1.activeState.consumedBy)) + unfold(keyMappingToActiveMapping(state2.activeState.consumedBy)) + + activeKeysGet(state1.consume(targetId, nodeId), k) + activeKeysGet(state2.consume(targetId, nodeId), k) + unfold(keyMappingToActiveMapping(state1.activeState.consumedBy.updated(targetId, nodeId))) + unfold(keyMappingToActiveMapping(state2.activeState.consumedBy.updated(targetId, nodeId))) + + case _ => Trivial() + } + + }.ensuring( + mbKey.forall(k => + s1.visitExercise(nodeId, targetId, mbKey, byKey, consuming).get.activeKeys.get(k) == + s2.visitExercise(nodeId, targetId, mbKey, byKey, consuming).get.activeKeys.get(k) + ) + ) + + /** If two states are defined after handling a node then they share the same mapping for the node's key entry in + * the activeKeys. + */ + @opaque + @pure + def handleNodeDifferentStatesActiveKeysGet( + s1: State, + s2: State, + nid: NodeId, + node: Node.Action, + unbound1: Set[ContractId], + lc1: Set[ContractId], + unbound2: Set[ContractId], + lc2: Set[ContractId], + ): Unit = { + require(s1.handleNode(nid, node).isRight) + require(s2.handleNode(nid, node).isRight) + require(containsActionKey(s1)(node)) + require(containsActionKey(s2)(node)) + require(stateInvariant(s1)(unbound1, lc1)) + require(stateInvariant(s2)(unbound2, lc2)) + require(stateNodeCompatibility(s1, node, unbound1, lc1, TraversalDirection.Down)) + require(stateNodeCompatibility(s2, node, unbound2, lc2, TraversalDirection.Down)) + + unfold(s1.handleNode(nid, node)) + unfold(s2.handleNode(nid, node)) + unfold(containsActionKey(s1)(node)) + unfold(containsActionKey(s2)(node)) + unfold(stateNodeCompatibility(s1, node, unbound1, lc1, TraversalDirection.Down)) + unfold(stateNodeCompatibility(s2, node, unbound2, lc2, TraversalDirection.Down)) + + node match { + case create: Node.Create => + if (s1.activeState.consumedBy.contains(create.coid)) { + MapProperties.keySetContains(s1.activeState.consumedBy, create.coid) + SetProperties.subsetOfContains(s1.activeState.consumedBy.keySet, s1.consumed, create.coid) + Unreachable() + } + if (s2.activeState.consumedBy.contains(create.coid)) { + MapProperties.keySetContains(s2.activeState.consumedBy, create.coid) + SetProperties.subsetOfContains(s2.activeState.consumedBy.keySet, s2.consumed, create.coid) + Unreachable() + } + visitCreateDifferentStatesActiveKeysGet(s1, s2, create.coid, create.gkeyOpt) + case fetch: Node.Fetch => + assertKeyMappingDifferentStatesActiveKeysGet(s1, s2, fetch.coid, fetch.gkeyOpt) + case lookup: Node.LookupByKey => + unfold(containsOptionKey(s1)(node.gkeyOpt)) + unfold(containsOptionKey(s2)(node.gkeyOpt)) + visitLookupDifferentStatesActiveKeysGet( + s1, + s2, + lookup.gkey, + lookup.result, + unbound1, + lc1, + unbound2, + lc2, + ) + case exe: Node.Exercise => + visitExerciseDifferentStatesActiveKeysGet( + s1, + s2, + nid, + exe.targetCoid, + exe.gkeyOpt, + exe.byKey, + exe.consuming, + ) + } + + }.ensuring( + node.gkeyOpt.forall(k => + s1.handleNode(nid, node).get.activeKeys.get(k) == + s2.handleNode(nid, node).get.activeKeys.get(k) + ) + ) + + /** If two states have the same activeState and the same globalKeys then their activeKeys are also the same. + */ + @pure @opaque + def activeKeysGetSameFields(s1: State, s2: State, k: GlobalKey): Unit = { + require(s1.activeState == s2.activeState) + require(s1.globalKeys == s2.globalKeys) + unfold(s1.activeKeys) + unfold(s2.activeKeys) + unfold(s1.keys) + unfold(s2.keys) + }.ensuring(s1.activeKeys.get(k) == s2.activeKeys.get(k)) + + /** beginRollback followed by a function that does not modify the rollbackStack and the globalKeys, followed by an + * endRollback does not change the activeKeys of a state. + */ + @opaque + @pure + def activeKeysGetRollbackScope(s: State, k: GlobalKey, f: State => State): Unit = { + require(f(s.beginRollback()).endRollback().isRight) + require(f(s.beginRollback()).rollbackStack == s.beginRollback().rollbackStack) + require(f(s.beginRollback()).globalKeys == s.beginRollback().globalKeys) + + unfold(s.beginRollback()) + unfold(f(s.beginRollback()).endRollback()) + activeKeysGetSameFields(s, f(s.beginRollback()).endRollback().get, k) + }.ensuring( + f(s.beginRollback()).endRollback().get.activeKeys.get(k) == + s.activeKeys.get(k) + ) + +} diff --git a/daml-lf/verification/transaction/CSMInvariant.scala b/daml-lf/verification/transaction/CSMInvariant.scala new file mode 100644 index 0000000000..e5bde78ce2 --- /dev/null +++ b/daml-lf/verification/transaction/CSMInvariant.scala @@ -0,0 +1,825 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import scala.annotation.targetName +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ + +import ContractStateMachine._ +import CSMHelpers._ +import CSMKeysPropertiesDef._ +import CSMKeysProperties._ +import CSMEitherDef._ + +object CSMInvariantDef { + + /** * + * List of invariants that hold for well-defined transactions + * + * The following relationships tie together the various sets of ContractId. + * Let + * globalIds := globalKeys.values.filter(_.isDefined) + * act.localIds := act.localKeys.values + * + * a. unbound is disjoint from globalIds + * + * b. for every act in the rollbackStack and for the active state: + * act.localIds ⊆ locallyCreated ⊆ lc + * + * c. lc.map(KeyActive) is disjoint from globalIds + * + * d. for every act in the rollbackStack and for the active state: + * act.localIds is disjoint from unbound + * + * e. for every act in the rollbackStack and for the active state: + * act.consumedBy.values ⊆ consumed + * + * f. for every contractId v in globalIds: + * globalKeys.preimage(v).size <= 1 + * + * g. for every contractId v in act.localIds + * act.localKeys.preimage(v).size <= 1 + * + * Conditions e and f mean that both maps are injective, i.e. that no contract is assigned to two keys + * at the same time. + * + * Regarding the keys the only invariant is that local keys are a subset of global keys + * + * @param s + * The state for which those invariants hold + * + * @param lc + * Set of contracts contracted during the transaction, before executing + * @param unbound + * Set of contracts created throughout the transaction but not linked to any key + */ + + @pure + def invariantWrtActiveState( + s: State + )(unbound: Set[ContractId], lc: Set[ContractId]): ActiveLedgerState => Boolean = + a => + // point b.1. + a.localKeys.values.subsetOf(s.locallyCreated) && + // point d. + unbound.disjoint(a.localKeys.values) && + // point e. + a.consumedBy.keySet.subsetOf(s.consumed) && + // point g. + (a.localKeys.values).forall(v => a.localKeys.preimage(v).size <= BigInt(1)) && + // keySets + a.localKeys.keySet.subsetOf(s.globalKeys.keySet) + + @pure + def invariantWrtList(s: State)(unbound: Set[ContractId], lc: Set[ContractId])( + l: List[ActiveLedgerState] + ): Boolean = l.forall(invariantWrtActiveState(s)(unbound, lc)) + + @pure + def stateInvariant(s: State)(unbound: Set[ContractId], lc: Set[ContractId]): Boolean = { + + val invariant: Boolean = { + + // point a. + unbound.map[KeyMapping](KeyActive).disjoint(s.globalKeys.values) && + // point b.2. + s.locallyCreated.subsetOf(lc) && + // point c. + lc.map[KeyMapping](KeyActive).disjoint(s.globalKeys.values) && + // point f. + (s.globalKeys.values + .filter(_.isDefined)) + .forall(v => s.globalKeys.preimage(v).size <= BigInt(1)) && + // the invariants hold for the active state + invariantWrtActiveState(s)(unbound, lc)(s.activeState) && + // the invariants hold for all the states in the stack + invariantWrtList(s)(unbound, lc)(s.rollbackStack) + } + + /** Properties that can be derived from the invariants: + * - globalKeys.keySet == activeKeys.keySet + */ + @opaque + def invariantProperties: Unit = { + require(invariant) + unfold(invariantWrtActiveState) + + // globalKeys.keySet == activeKeys.keySet + unfold(s.activeKeys) + activeKeysKeySet(s) + SetProperties.unionAltDefinition(s.activeState.localKeys.keySet, s.globalKeys.keySet) + SetProperties.equalsTransitivityStrong( + s.globalKeys.keySet, + s.globalKeys.keySet ++ s.activeState.localKeys.keySet, + s.activeKeys.keySet, + ) + + // localKeys.values.subsetOf(lc) + SetProperties.subsetOfTransitivity(s.activeState.localKeys.values, s.locallyCreated, lc) + + // localKeys.values.map(KeyActive).disjoint(s.globalKeys.values) + SetProperties.mapSubsetOf[ContractId, KeyMapping]( + s.activeState.localKeys.values, + lc, + KeyActive, + ) + SetProperties.disjointSubsetOf( + lc.map[KeyMapping](KeyActive), + s.activeState.localKeys.values.map[KeyMapping](KeyActive), + s.globalKeys.values, + ) + + }.ensuring( + (s.globalKeys.keySet === s.activeKeys.keySet) && + s.activeState.localKeys.values.map[KeyMapping](KeyActive).disjoint(s.globalKeys.values) + ) + + if (invariant) { + invariantProperties + } + invariant + + } + + /** Properties that need to be true for any pair intermediate state - node during a transaction so that the invariants + * are preserved. + */ + + @pure + @opaque + def stateNodeCompatibility( + s: State, + n: Node, + unbound: Set[ContractId], + lc: Set[ContractId], + dir: TraversalDirection, + ): Boolean = { + (dir == TraversalDirection.Down) ==> + (n match { + case Node.Create(coid, mbKey) => + lc.contains(coid) && + (mbKey.isDefined == !unbound.contains(coid)) && + !s.locallyCreated.contains(coid) && + !s.consumed.contains(coid) + case exe: Node.Exercise => exe.gkeyOpt.isDefined == !unbound.contains(exe.targetCoid) + case _ => true + }) + } + + /** The invariants are considered true for an error state which means that they are preserved even when the transaction + * leads to an error + */ + @pure + def invariantWrtList[T](e: Either[T, State])(unbound: Set[ContractId], lc: Set[ContractId])( + l: List[ActiveLedgerState] + ): Boolean = + e.forall(s => invariantWrtList(s)(unbound, lc)(l)) + + @pure + def stateInvariant[T]( + e: Either[T, State] + )(unbound: Set[ContractId], lc: Set[ContractId]): Boolean = { + e.forall(s => stateInvariant(s)(unbound, lc)) + } + +} + +/** The scope of this object is to prove that the invariants are preserved when traversing a transaction. This includes + * proving that they still hold: + * - after handling a node (given some conditions on the node that hold for well defined transactions) + * - after entering or exiting a rollback node + * - after adding a global mapping (given some conditions on the mapping) + */ + +object CSMInvariant { + + import CSMInvariantDef._ + + /** If a state fulfills the invariants then the state obtained after calling beginRollback also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantBeginRollback(s: State, unbound: Set[ContractId], lc: Set[ContractId]): Unit = { + require(stateInvariant(s)(unbound, lc)) + unfold(s.beginRollback()) + }.ensuring(stateInvariant(s.beginRollback())(unbound, lc)) + + /** If a state fulfills the invariants then the state obtained after calling endRollback also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantEndRollback(s: State, unbound: Set[ContractId], lc: Set[ContractId]): Unit = { + require(stateInvariant(s)(unbound, lc)) + unfold(s.endRollback()) + }.ensuring(stateInvariant(s.endRollback())(unbound, lc)) + + /** If a state fulfills the invariants and an other state has its same fields except for the locallyCreatedThisTimeline + * then it also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantSameFields[T]( + s1: State, + s2: Either[T, State], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s1)(unbound, lc)) + require(sameLocalKeys(s1, s2)) + require(sameConsumedBy(s1, s2)) + require(sameGlobalKeys(s1, s2)) + require(sameStack(s1, s2)) + require(sameLocallyCreated(s1, s2)) + require(sameConsumed(s1, s2)) + + unfold(sameLocalKeys(s1, s2)) + unfold(sameConsumedBy(s1, s2)) + unfold(sameGlobalKeys(s1, s2)) + unfold(sameStack(s1, s2)) + unfold(sameLocallyCreated(s1, s2)) + unfold(sameConsumed(s1, s2)) + }.ensuring(stateInvariant(s2)(unbound, lc)) + + /** If a state fulfills the invariants and an other state has its same fields except for the locallyCreatedThisTimeline + * then it also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantSameFields[U, T]( + s1: Either[U, State], + s2: Either[T, State], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s1)(unbound, lc)) + require(sameLocalKeys(s1, s2)) + require(sameConsumedBy(s1, s2)) + require(sameGlobalKeys(s1, s2)) + require(sameStack(s1, s2)) + require(sameLocallyCreated(s1, s2)) + require(sameConsumed(s1, s2)) + require(propagatesError(s1, s2)) // we want to avoid the case s1.isLeft and s2.isRight + + unfold(sameLocalKeys(s1, s2)) + unfold(sameConsumedBy(s1, s2)) + unfold(sameGlobalKeys(s1, s2)) + unfold(sameStack(s1, s2)) + unfold(sameLocallyCreated(s1, s2)) + unfold(sameConsumed(s1, s2)) + unfold(propagatesError(s1, s2)) + + s1 match { + case Right(s) => stateInvariantSameFields(s, s2, unbound, lc) + case Left(_) => Trivial() + } + }.ensuring(stateInvariant(s2)(unbound, lc)) + + /** If a state fulfills the invariants and an other state is equal to it then it also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantSameState[T]( + s1: State, + s2: Either[T, State], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s1)(unbound, lc)) + require(sameState(s1, s2)) + + stateInvariantSameFields(s1, s2, unbound, lc) + }.ensuring(stateInvariant(s2)(unbound, lc)) + + /** If a state fulfills the invariants and an other state is equal to it then it also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantSameState[U, T]( + s1: Either[U, State], + s2: Either[T, State], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s1)(unbound, lc)) + require(sameState(s1, s2)) + require(propagatesError(s1, s2)) // we want to avoid the case s1.isLeft and s2.isRight + + stateInvariantSameFields(s1, s2, unbound, lc) + }.ensuring(stateInvariant(s2)(unbound, lc)) + + /** If a state fulfills the invariants then the state obtained after calling assertKeyMapping also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantAssertKeyMapping( + s: State, + cid: ContractId, + mbKey: Option[GlobalKey], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + + stateInvariantSameState(s, s.assertKeyMapping(cid, mbKey), unbound, lc) + }.ensuring(stateInvariant(s.assertKeyMapping(cid, mbKey))(unbound, lc)) + + /** If a state fulfills the invariants then the state obtained after calling visitLookup also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantVisitLookup( + s: State, + gk: GlobalKey, + keyResolution: Option[ContractId], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + + stateInvariantSameState(s, s.visitLookup(gk, keyResolution), unbound, lc) + }.ensuring(stateInvariant(s.visitLookup(gk, keyResolution))(unbound, lc)) + + /** If a state fulfills the invariants then the state obtained after calling visitExercise also fulfills the invariants. + */ + @pure + @opaque + def stateInvariantVisitExercise( + s: State, + nodeId: NodeId, + targetId: ContractId, + mbKey: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + unfold(s.visitExercise(nodeId, targetId, mbKey, byKey, consuming)) + + s.assertKeyMapping(targetId, mbKey) match { + case Left(_) => Trivial() + case Right(state) => + unfold(sameState(s, s.assertKeyMapping(targetId, mbKey))) + if (consuming) { + unfold(state.consume(targetId, nodeId)) + unfold(state.activeState.consume(targetId, nodeId)) + SetProperties.subsetOfIncl(s.activeState.consumedBy.keySet, s.consumed, targetId) + SetProperties.equalsSubsetOfTransitivity( + s.activeState.consumedBy.updated(targetId, nodeId).keySet, + s.activeState.consumedBy.keySet + targetId, + s.consumed + targetId, + ) + + val res = state.consume(targetId, nodeId) + + if (!invariantWrtList(res)(unbound, lc)(s.rollbackStack)) { + val a = ListProperties.notForallWitness( + s.rollbackStack, + invariantWrtActiveState(res)(unbound, lc), + ) + ListProperties.forallContains( + s.rollbackStack, + invariantWrtActiveState(s)(unbound, lc), + a, + ) + SetProperties.subsetOfTransitivity(a.consumedBy.keySet, s.consumed, res.consumed) + Unreachable() + } + } + } + + }.ensuring( + stateInvariant(s.visitExercise(nodeId, targetId, mbKey, byKey, consuming))(unbound, lc) + ) + + /** Invariants are preserved after a create node, given some conditions on the state and the node: + * - The key associated to the new contract has to appear in the global keys. This is always the case since + * we update all the globalKeys before traversing the transaction + * - lc has to contain contractId (since it is the set of all newly created contracts in the transaction) + * - if unbound contains the contract then it means that no key is given + * - The contract should not have been created previously in the transaction + */ + @pure + @opaque + def stateInvariantVisitCreate( + s: State, + contractId: ContractId, + mbKey: Option[GlobalKey], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + require(containsOptionKey(s)(mbKey)) + require(lc.contains(contractId)) + require(mbKey.isDefined ==> !unbound.contains(contractId)) + require(!s.locallyCreated.contains(contractId)) + + unfold(s.visitCreate(contractId, mbKey)) + + val me = + s.copy( + locallyCreated = s.locallyCreated + contractId, + activeState = s.activeState + .copy(locallyCreatedThisTimeline = s.activeState.locallyCreatedThisTimeline + contractId), + ) + + /** STEP 1: the invariants still hold for every state in the rollback stack + * + * At this stage all the modifications relevant to the rollbackStack has already been done. + * In fact what remains is changing the active state. Therefore we can already prove that the + * invariants are preserved for every state in the rollback stack (by contradiction) + */ + + if (!invariantWrtList(me)(unbound, lc)(me.rollbackStack)) { + val a = + ListProperties.notForallWitness(me.rollbackStack, invariantWrtActiveState(me)(unbound, lc)) + ListProperties.forallContains(s.rollbackStack, invariantWrtActiveState(s)(unbound, lc), a) + SetProperties.subsetOfTransitivity(a.localKeys.values, s.locallyCreated, me.locallyCreated) + Unreachable() + } + + /** STEP 2 locallyCreated + contractId ⊆ lc + * + * Here we need the fact that lc contains contractId + */ + + SetProperties.subsetOfIncl(s.locallyCreated, lc, contractId) + + mbKey match { + case None() => + /** STEP 3.a localKeys.values ⊆ locallyCreated + contractId + */ + SetProperties.subsetOfTransitivity( + s.activeState.localKeys.values, + s.locallyCreated, + me.locallyCreated, + ) + + case Some(gk) => + val ns = me.copy(activeState = me.activeState.createKey(gk, contractId)) + + unfold(me.activeState.createKey(gk, contractId)) + MapProperties.updatedValues(s.activeState.localKeys, gk, contractId) + + /** STEP 3.b localKeys.updated(gk, contractId).values ⊆ locallyCreated + contractId + */ + SetProperties.subsetOfIncl(s.activeState.localKeys.values, s.locallyCreated, contractId) + SetProperties.subsetOfTransitivity( + s.activeState.createKey(gk, contractId).localKeys.values, + s.activeState.localKeys.values + contractId, + s.locallyCreated + contractId, + ) + + /** STEP 4 unbound is disjoint from localKeys.updated(gk, contractId).values + * + * Here we need the fact that contractId is not in unbound if it is bound to some key + */ + SetProperties.disjointSubsetOf( + unbound, + s.activeState.localKeys.values + contractId, + s.activeState.createKey(gk, contractId).localKeys.values, + ) + SetProperties.disjointIncl(unbound, s.activeState.localKeys.values, contractId) + + /** STEP 5 (localKeys.updated(gk, contractId).values).forall(v => localKeys.updated(gk, contractId).preimage(v).size <= BigInt(1)) + * + * We show that the statement is true for contractId and prove by contradiction that it remains true for + * the previous values. We will need the fact that locallyCreated and therefore localKeys.values do not + * contain contractId. + */ + val newF: ContractId => Boolean = + v => ns.activeState.localKeys.preimage(v).size <= BigInt(1) + + SetProperties.forallIncl(s.activeState.localKeys.values, contractId, newF) + + // STEP 5.1 localKeys.values do not contain contractId + if (s.activeState.localKeys.values.contains(contractId)) { + SetProperties.subsetOfContains( + s.activeState.localKeys.values, + s.locallyCreated, + contractId, + ) + } + + // STEP 5.2 newF(contractId) + { + // STEP 5.2.1 localKeys.preimage(contractId).size == 0 + MapProperties.preimageIsEmpty(s.activeState.localKeys, contractId) + SetProperties.isEmptySize(s.activeState.localKeys.preimage(contractId)) + + // STEP 5.2.2 localKeys.updated(gk, contractId).preimage(contractId).size <= 1 + MapProperties.inclPreimage(s.activeState.localKeys, gk, contractId, contractId) + SetProperties.subsetOfSize( + s.activeState.localKeys.updated(gk, contractId).preimage(contractId), + s.activeState.localKeys.preimage(contractId) + gk, + ) + } + + // STEP 5.3 localKeys.values.forall(v => localKeys.updated(gk, contractId).preimage(v).size <= BigInt(1)) + if (!s.activeState.localKeys.values.forall(newF)) { + val w = SetProperties.notForallWitness(s.activeState.localKeys.values, newF) + SetProperties.forallContains( + s.activeState.localKeys.values, + v => s.activeState.localKeys.preimage(v).size <= BigInt(1), + w, + ) + MapProperties.inclPreimage(s.activeState.localKeys, gk, contractId, w) + SetProperties.subsetOfSize( + s.activeState.localKeys.updated(gk, contractId).preimage(w), + s.activeState.localKeys.preimage(w), + ) + Unreachable() + } + + // STEP 5.4 Final calls + MapProperties.updatedValues(s.activeState.localKeys, gk, contractId) + SetProperties.forallSubsetOf( + s.activeState.localKeys.updated(gk, contractId).values, + s.activeState.localKeys.values + contractId, + newF, + ) + + /** STEP 6 localKeys.keySet ⊆ globalKeys.keySet + * + * For this, we will need the fact that the key is already in the global keys. + */ + unfold(containsOptionKey(s)(mbKey)) + unfold(containsKey(s)(gk)) + MapProperties.keySetContains(s.globalKeys, gk) + SetProperties.subsetOfIncl(s.activeState.localKeys.keySet, s.globalKeys.keySet, gk) + SetProperties.equalsSubsetOfTransitivity( + s.activeState.localKeys.updated(gk, contractId).keySet, + s.activeState.localKeys.keySet + gk, + s.globalKeys.keySet, + ) + } + }.ensuring(stateInvariant(s.visitCreate(contractId, mbKey))(unbound, lc)) + + /** If a state fulfills the invariants and a node respects the [[CSMInvariantDef.stateNodeCompatibility]] conditions + * then the state obtained after handling it also fulfills the invariants. + */ + @pure + @opaque + def handleNodeInvariant( + s: State, + id: NodeId, + node: Node.Action, + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + require(stateNodeCompatibility(s, node, unbound, lc, TraversalDirection.Down)) + require(containsActionKey(s)(node)) + + unfold(s.handleNode(id, node)) + unfold(stateNodeCompatibility(s, node, unbound, lc, TraversalDirection.Down)) + unfold(containsActionKey(s)(node)) + + node match { + case create: Node.Create => + stateInvariantVisitCreate(s, create.coid, create.gkeyOpt, unbound, lc) + case fetch: Node.Fetch => + stateInvariantAssertKeyMapping(s, fetch.coid, fetch.gkeyOpt, unbound, lc) + case lookup: Node.LookupByKey => + unfold(containsOptionKey(s)(node.gkeyOpt)) + stateInvariantVisitLookup(s, lookup.gkey, lookup.result, unbound, lc) + case exe: Node.Exercise => + stateInvariantVisitExercise( + s, + id, + exe.targetCoid, + exe.gkeyOpt, + exe.byKey, + exe.consuming, + unbound, + lc, + ) + } + }.ensuring(stateInvariant(s.handleNode(id, node))(unbound, lc)) + +} + +object CSMInvariantDerivedProp { + + import CSMInvariantDef._ + + /** If a state fulfills the invariants then the activeKeys contains a key if and only if the globalKeys also do. + */ + @pure + @opaque + def invariantContainsKey( + s: State, + gk: GlobalKey, + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + unfold(containsKey(s)(gk)) + MapProperties.equalsKeySetContains(s.globalKeys, s.activeKeys, gk) + }.ensuring(containsKey(s)(gk) == s.activeKeys.contains(gk)) + + /** If a state fulfills the invariants and a key maps to a contract in the global keys, then no other key maps to the + * same contract in the global or local keys. + */ + @pure + @opaque + def invariantGetGlobalKeys( + s: State, + k: GlobalKey, + k2: GlobalKey, + c: ContractId, + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + require(s.globalKeys.get(k) == Some(KeyActive(c))) + require(k != k2) + + /** STEP 1: globalKeys.values.contains(KeyActive(c)) + */ + + unfold(s.globalKeys.contains) + MapAxioms.valuesContains(s.globalKeys, k) + + /** STEP 2: localKeys.get(k2) != Some(c) + */ + + if (s.activeState.localKeys.get(k2) == Some(c)) { + unfold(s.activeState.localKeys.contains) + MapAxioms.valuesContains(s.activeState.localKeys, k2) + SetProperties.mapContains[ContractId, KeyMapping]( + s.activeState.localKeys.values, + KeyActive, + s.activeState.localKeys(k2), + ) + SetProperties.disjointContains( + s.activeState.localKeys.values.map[KeyMapping](KeyActive), + s.globalKeys.values, + KeyActive(c), + ) + Unreachable() + } + + /** STEP 3: globalKeys.get(k2) != Some(KeyActive(c) + */ + + if (s.globalKeys.get(k2) == Some(KeyActive(c))) { + + // STEP 2.1: s.globalKeys.preimage(KeyActive(c)) contains k and k2 + MapProperties.preimageGet(s.globalKeys, KeyActive(c), k) + MapProperties.preimageGet(s.globalKeys, KeyActive(c), k2) + + // STEP 2.2: Set(k) ++ Set(k2) subsetOf s.globalKeys.preimage(KeyActive(c)) + SetProperties.singletonSubsetOf(s.globalKeys.preimage(KeyActive(c)), k) + SetProperties.singletonSubsetOf(s.globalKeys.preimage(KeyActive(c)), k2) + SetProperties.unionSubsetOf(Set(k), Set(k2), s.globalKeys.preimage(KeyActive(c))) + + // STEP 2.3: Set(k) ++ Set(k2) size == 2 + SetAxioms.singletonSize(k) + SetAxioms.singletonSize(k2) + SetProperties.disjointTwoSingleton(k, k2) + SetAxioms.unionDisjointSize(Set(k), Set(k2)) + + // STEP 2.4: Final calls + SetProperties.subsetOfSize(Set(k) ++ Set(k2), s.globalKeys.preimage(KeyActive(c))) + SetProperties.filterContains(s.globalKeys.values, _.isDefined, KeyActive(c)) + SetProperties.forallContains( + s.globalKeys.values.filter(_.isDefined), + v => s.globalKeys.preimage(v).size <= BigInt(1), + KeyActive(c), + ) + Unreachable() + } + }.ensuring( + (s.globalKeys.get(k2) != Some(KeyActive(c))) && (s.activeState.localKeys.get(k2) != Some(c)) + ) + + /** If a state fulfills the invariants and a key maps to a contract in the local keys, then no other key maps to the + * same contract in the global or local keys. + */ + @pure + @opaque + def invariantGetLocalKeys( + s: State, + k: GlobalKey, + k2: GlobalKey, + c: ContractId, + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(s)(unbound, lc)) + require(s.activeState.localKeys.get(k) == Some(c)) + require(k != k2) + + /** STEP 1: localKeys.values.contains(KeyActive(c)) + */ + + unfold(s.activeState.localKeys.contains) + MapAxioms.valuesContains(s.activeState.localKeys, k) + + /** STEP 2: globalKeys.get(k2) != Some(KeyActive(c)) + */ + + if (s.globalKeys.get(k2) == Some(KeyActive(c))) { + unfold(s.globalKeys.contains) + MapAxioms.valuesContains(s.globalKeys, k2) + SetProperties.mapContains[ContractId, KeyMapping]( + s.activeState.localKeys.values, + KeyActive, + c, + ) + SetProperties.disjointContains( + s.activeState.localKeys.values.map[KeyMapping](KeyActive), + s.globalKeys.values, + s.globalKeys(k2), + ) + Unreachable() + } + + /** STEP 3: localKeys.get(k2) != Some(c) + */ + + if (s.activeState.localKeys.get(k2) == Some(c)) { + + // STEP 2.1: s.activeState.localKeys.preimage(c) contains k and k2 + MapProperties.preimageGet(s.activeState.localKeys, c, k) + MapProperties.preimageGet(s.activeState.localKeys, c, k2) + + // STEP 2.2: Set(k) ++ Set(k2) subsetOf s.activeState.localKeys.preimage(c) + SetProperties.singletonSubsetOf(s.activeState.localKeys.preimage(c), k) + SetProperties.singletonSubsetOf(s.activeState.localKeys.preimage(c), k2) + SetProperties.unionSubsetOf(Set(k), Set(k2), s.activeState.localKeys.preimage(c)) + + // STEP 2.3: Set(k) ++ Set(k2) size == 2 + SetAxioms.singletonSize(k) + SetAxioms.singletonSize(k2) + SetProperties.disjointTwoSingleton(k, k2) + SetAxioms.unionDisjointSize(Set(k), Set(k2)) + + // STEP 2.4: Final calls + SetProperties.subsetOfSize(Set(k) ++ Set(k2), s.activeState.localKeys.preimage(c)) + SetProperties.forallContains( + s.activeState.localKeys.values, + v => s.activeState.localKeys.preimage(v).size <= BigInt(1), + c, + ) + Unreachable() + } + }.ensuring( + (s.globalKeys.get(k2) != Some(KeyActive(c))) && (s.activeState.localKeys.get(k2) != Some(c)) + ) + + /** The invariants are true for any state whose activeState is empty. + */ + @pure + @opaque + def emptyActiveStateInvariant(s: State, unbound: Set[ContractId], lc: Set[ContractId]): Unit = { + unfold(ActiveLedgerState.empty) + unfold(invariantWrtActiveState(s)(unbound, lc)(ActiveLedgerState.empty)) + + MapProperties.emptyValues[GlobalKey, ContractId] + SetProperties.isEmptySubsetOf(ActiveLedgerState.empty.localKeys.values, s.locallyCreated) + SetProperties.disjointIsEmpty(unbound, ActiveLedgerState.empty.localKeys.values) + SetProperties.forallIsEmpty( + ActiveLedgerState.empty.localKeys.values, + v => ActiveLedgerState.empty.localKeys.preimage(v).size <= BigInt(1), + ) + + MapProperties.emptyKeySet[GlobalKey, ContractId] + MapProperties.emptyKeySet[ContractId, NodeId] + SetProperties.isEmptySubsetOf(ActiveLedgerState.empty.consumedBy.keySet, s.consumed) + SetProperties.isEmptySubsetOf(ActiveLedgerState.empty.localKeys.keySet, s.globalKeys.keySet) + + }.ensuring( + invariantWrtActiveState(s)(unbound, lc)(ActiveLedgerState.empty) + ) + + /** The list invariants are always true for the empty state. + */ + @pure + @opaque + def emptyListInvariant(s: State, unbound: Set[ContractId], lc: Set[ContractId]): Unit = { + unfold(State.empty) + unfold(invariantWrtList(State.empty)(unbound, lc)(State.empty.rollbackStack)) + }.ensuring( + invariantWrtList(State.empty)(unbound, lc)(State.empty.rollbackStack) + ) + +} diff --git a/daml-lf/verification/transaction/CSMKeysProperties.scala b/daml-lf/verification/transaction/CSMKeysProperties.scala new file mode 100644 index 0000000000..d8d67fe4e5 --- /dev/null +++ b/daml-lf/verification/transaction/CSMKeysProperties.scala @@ -0,0 +1,1096 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import scala.annotation.targetName +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ + +import ContractStateMachine._ +import CSMHelpers._ +import CSMEither._ +import CSMEitherDef._ + +/** Proofs related to globalKeys and localKeys fields of a state. + * As usual CSMKeysPropertiesDef contains all the definitions whereas CSMKeysProperties contains all the theorems + * with their respective proofs. + * + * In the simplified version of the contract state machine, the handleNode function is split into two functions: + * - A simplified version of handleNode that does not modify the global keys + * - A function addKey that adds the right mapping to the global keys if the key is still not registered in the map. + * This mapping is described by nodeActionKeyMapping in CSMHelpers + * + * We prove in TransactionTreeFull that we can extract all addKey of a transaction to put them at the beginning. + * Going through a transaction is therefore equivalent to first update all the keys and then processing the transaction. + * + * Therefore when dealing with (the simplified version of) handleNode, we always make the assumption that there is + * already an entry in the activeKeys for the key of the node we are handling. + */ +object CSMKeysPropertiesDef { + + /** Checks whether there is already an entry for the key k in the globalKeys of the state. + * + * Since in a well-defined state the keys in localKeys are a subset of the keys in globalKeys + * this is easier yet equivalent to checking whether k is contained in the activeKeys. + * + * All the versions of the function are opaque which can make reasonning tedious. This is however + * necessary to reduce the complexity of the proofs and being able to keep a low timeout. + */ + @pure + @opaque + def containsKey(s: State)(k: GlobalKey): Boolean = { + s.globalKeys.contains(k) + } + + /** Checks whether there is already an entry for k in the globalKeys of the state in case k is defined. + * If the key is not defined, returns true. + * + * For more details cf containsKey above + */ + @pure + @opaque + def containsOptionKey(s: State)(k: Option[GlobalKey]): Boolean = { + k.forall(containsKey(s)) + } + + /** Checks whether there is already an entry for the key of n in the globalKeys of the state in case it is defined. + * If the key is not defined, returns true. + * + * For more details cf containsKey above + */ + @pure + @opaque + def containsActionKey(s: State)(n: Node.Action): Boolean = { + containsOptionKey(s)(n.gkeyOpt) + } + @pure + @opaque + def containsNodeKey(s: State)(n: Node): Boolean = { + n match { + case a: Node.Action => containsActionKey(s)(a) + case r: Node.Rollback => true + } + } + + /** Checks whether there is already an entry for the key of n in the globalKeys of the state in case both are defined. + * If the key or the state is not defined, returns true. + * + * For more details cf containsKey above + */ + @pure + @opaque + def containsKey[T](e: Either[T, State])(n: Node): Boolean = { + e match { + case Right(s) => containsNodeKey(s)(n) + case Left(_) => true + } + } + + /** Replace the globalKeys of s with glK one of their supermap + * + * Among the direct properties one can deduce from the definition we have that the activeKeys of s are also + * a submap of the activeKeys of the result + */ + + @pure + @opaque + def extendGlobalKeys(s: State, glK: Map[GlobalKey, KeyMapping]): State = { + require(s.globalKeys.submapOf(glK)) + + val res = s.copy(globalKeys = glK) + + @pure @opaque + def extendGlobalKeysProperties: Unit = { + MapProperties.submapOfKeySet(s.globalKeys, glK) + unfold(s.activeKeys) + unfold(res.activeKeys) + unfold(s.keys) + unfold(res.keys) + MapProperties.concatSubmapOf(s.globalKeys, glK, s.activeState.localKeys.mapValues(KeyActive)) + MapProperties.mapValuesSubmapOf( + s.keys, + res.keys, + keyMappingToActiveMapping(s.activeState.consumedBy), + ) + }.ensuring( + (s.rollbackStack == res.rollbackStack) && + (s.activeState == res.activeState) && + (s.locallyCreated == res.locallyCreated) && + (s.consumed == res.consumed) && + (res.globalKeys == glK) && + (s.activeKeys.submapOf(res.activeKeys)) + ) + + extendGlobalKeysProperties + + res + + }.ensuring(res => + (s.rollbackStack == res.rollbackStack) && + (s.activeState == res.activeState) && + (s.locallyCreated == res.locallyCreated) && + (s.consumed == res.consumed) && + (res.globalKeys == glK) && + (s.activeKeys.submapOf(res.activeKeys)) + ) + + /** Extends the globalKeys of s to the left. More precisely overwrite the globalKeys of glK with the globalKeys of s. + * + * Among the direct properties one can deduce from the definition we have that the activeKeys of s are also + * a submap of the activeKeys of the result. The same holds for the globalKeys. + */ + + @pure + @opaque + def concatLeftGlobalKeys(s: State, glK: Map[GlobalKey, KeyMapping]): State = { + MapProperties.concatSubmapOf(glK, s.globalKeys) + extendGlobalKeys(s, glK ++ s.globalKeys) + }.ensuring(res => + (s.rollbackStack == res.rollbackStack) && + (s.activeState == res.activeState) && + (s.locallyCreated == res.locallyCreated) && + (s.consumed == res.consumed) && + (s.activeKeys.submapOf(res.activeKeys)) && + (s.globalKeys.submapOf(res.globalKeys)) + ) + + /** If e is a well defined state, extends the globalKeys of s to the left. Otherwise returns the same error. + */ + + @pure + def concatLeftGlobalKeys[T]( + e: Either[T, State], + glK: Map[GlobalKey, KeyMapping], + ): Either[T, State] = { + val res = e.map(s => concatLeftGlobalKeys(s, glK)) + + @pure + @opaque + def concatLeftGlobalKeysProp: Unit = { + sameErrorMap(e, s => concatLeftGlobalKeys(s, glK)) + unfold(sameStack(e, res)) + unfold(sameActiveState(e, res)) + unfold(sameLocallyCreated(e, res)) + unfold(sameConsumed(e, res)) + e match { + case Left(_) => Trivial() + case Right(r) => + unfold(sameStack(r, res)) + unfold(sameActiveState(r, res)) + unfold(sameLocallyCreated(r, res)) + unfold(sameConsumed(r, res)) + } + }.ensuring( + sameError(e, res) && + sameStack(e, res) && + sameActiveState(e, res) && + sameLocallyCreated(e, res) && + sameConsumed(e, res) + ) + + concatLeftGlobalKeysProp + res + }.ensuring(res => + sameError(e, res) && + sameStack(e, res) && + sameActiveState(e, res) && + sameLocallyCreated(e, res) && + sameConsumed(e, res) + ) + + /** Adds the pair (k, mapping) to the globalKeys of s if s does not contain k. + */ + + @pure + @opaque + def addKey(s: State, k: GlobalKey, mapping: KeyMapping): State = { + MapProperties.submapOfReflexivity(s.globalKeys) + unfold(containsKey(s)(k)) + val res = + extendGlobalKeys(s, if (containsKey(s)(k)) s.globalKeys else s.globalKeys.updated(k, mapping)) + unfold(containsKey(res)(k)) + unfold(concatLeftGlobalKeys(s, Map[GlobalKey, KeyMapping](k -> mapping))) + if (containsKey(s)(k)) { + MapProperties.keySetContains(s.globalKeys, k) + MapProperties.singletonKeySet(k, mapping) + SetProperties.singletonSubsetOf(s.globalKeys.keySet, k) + SetProperties.equalsSubsetOfTransitivity( + Set(k), + Map[GlobalKey, KeyMapping](k -> mapping).keySet, + s.globalKeys.keySet, + ) + MapProperties.concatSubsetOfEquals(Map[GlobalKey, KeyMapping](k -> mapping), s.globalKeys) + MapAxioms.extensionality( + Map[GlobalKey, KeyMapping](k -> mapping) ++ s.globalKeys, + s.globalKeys, + ) + } else { + MapProperties.updatedCommutativity(s.globalKeys, k, mapping) + MapAxioms.extensionality( + Map[GlobalKey, KeyMapping](k -> mapping) ++ s.globalKeys, + s.globalKeys.updated(k, mapping), + ) + } + res + }.ensuring(res => + containsKey(res)(k) && + (res == concatLeftGlobalKeys(s, Map[GlobalKey, KeyMapping](k -> mapping))) && + (containsKey(s)(k) ==> (res == s)) + ) + + /** Adds the pair (k, mapping) to the globalKeys of s if k is defined and s does not contain k. + */ + @pure + @opaque + def addKey(s: State, k: Option[GlobalKey], mapping: KeyMapping): State = { + MapProperties.submapOfReflexivity(s.globalKeys) + val res = k match { + case None() => s + case Some(gk) => addKey(s, gk, mapping) + } + unfold(containsOptionKey(s)(k)) + unfold(containsOptionKey(res)(k)) + unfold(optionKeyMap(k, mapping)) + unfold(concatLeftGlobalKeys(s, optionKeyMap(k, mapping))) + MapProperties.concatEmpty(s.globalKeys) + MapAxioms.extensionality(Map.empty[GlobalKey, KeyMapping] ++ s.globalKeys, s.globalKeys) + res + }.ensuring(res => + containsOptionKey(res)(k) && + (res == concatLeftGlobalKeys(s, optionKeyMap(k, mapping))) && + (containsOptionKey(s)(k) ==> (res == s)) + ) + + /** Adds the pair (n.gkeyOpt, nodeActionKeyMapping(n)) to the globalKeys of s if the key of n is well-defined. + * + * nodeActionKeyMapping is the mapping corresponding to the node (cf. CSMHelpers). It is defined as: + * - KeyInactive if n is a Create Node + * - KeyActive(n.coid) if n is a Fetch Node + * - n.result if n is a Lookup Node + * - KeyActive(n.targetCoid) if n is an Exercise Node + */ + @pure + @opaque + def addKeyBeforeAction(s: State, n: Node.Action): State = { + val res = addKey(s, n.gkeyOpt, nodeActionKeyMapping(n)) + unfold(containsActionKey(s)(n)) + unfold(containsActionKey(res)(n)) + unfold(actionKeyMap(n)) + res + }.ensuring(res => + containsActionKey(res)(n) && + (res == concatLeftGlobalKeys(s, actionKeyMap(n))) && + (containsActionKey(s)(n) ==> (res == s)) + ) + @pure + @opaque + def addKeyBeforeNode(s: State, n: Node): State = { + val res = + n match { + case a: Node.Action => addKeyBeforeAction(s, a) + case r: Node.Rollback => s + } + unfold(containsNodeKey(res)(n)) + unfold(containsNodeKey(s)(n)) + unfold(nodeKeyMap(n)) + unfold(concatLeftGlobalKeys(s, nodeKeyMap(n))) + MapProperties.concatEmpty(s.globalKeys) + MapAxioms.extensionality(Map.empty[GlobalKey, KeyMapping] ++ s.globalKeys, s.globalKeys) + res + }.ensuring(res => + (res == concatLeftGlobalKeys(s, nodeKeyMap(n))) && + containsNodeKey(res)(n) && + (containsNodeKey(s)(n) ==> (res == s)) + ) + + /** If e is well-defined adds the pair corresponding to n to the global keys of the state. + * + * For more details cf. addKeyBeforeAction + */ + @pure + @opaque + def addKeyBeforeNode[T](e: Either[T, State], n: Node): Either[T, State] = { + + val res = e.map((s: State) => addKeyBeforeNode(s, n)) + + @pure + @opaque + def addKeyBeforeNodeProperties: Unit = { + sameErrorMap(e, (s: State) => addKeyBeforeNode(s, n)) + unfold(containsKey(res)(n)) + unfold(containsKey(e)(n)) + unfold(sameStack(e, res)) + unfold(concatLeftGlobalKeys(e, nodeKeyMap(n))) + e match { + case Left(_) => Trivial() + case Right(r) => unfold(sameStack(r, res)) + } + }.ensuring( + sameStack(e, res) && + sameError(e, res) && + containsKey(res)(n) && + (containsKey(e)(n) ==> (res == e)) && + (res == concatLeftGlobalKeys(e, nodeKeyMap(n))) + ) + + addKeyBeforeNodeProperties + res + }.ensuring(res => + sameStack(e, res) && + sameError(e, res) && + containsKey(res)(n) && + (containsKey(e)(n) ==> (res == e)) && + (res == concatLeftGlobalKeys(e, nodeKeyMap(n))) + ) + @pure + def addKeyBeforeNode[T](e: Either[T, State], p: (NodeId, Node)): Either[T, State] = { + addKeyBeforeNode(e, p._2) + } + + /** Map representing a key-mapping pair when the key is an option. If the key is not defined then the result is empty + * otherwise it is a singleton containing the pair. + */ + @pure + @opaque + def optionKeyMap(k: Option[GlobalKey], m: KeyMapping): Map[GlobalKey, KeyMapping] = { + k match { + case None() => Map.empty[GlobalKey, KeyMapping] + case Some(gk) => Map[GlobalKey, KeyMapping](gk -> m) + } + } + + /** Map representing the key-mapping pair of a node. If the key of the node is not defined then the result is empty + * otherwise it is a singleton containing the key with the node's corresponding mapping. + * + * The latter is defined as: + * - KeyInactive if n is a Create Node + * - KeyActive(n.coid) if n is a Fetch Node + * - n.result if n is a Lookup Node + * - KeyActive(n.targetCoid) if n is an Exercise Node + */ + @pure + @opaque + def actionKeyMap(n: Node.Action): Map[GlobalKey, KeyMapping] = { + optionKeyMap(n.gkeyOpt, nodeActionKeyMapping(n)) + } + @pure + @opaque + def nodeKeyMap(n: Node): Map[GlobalKey, KeyMapping] = { + n match { + case r: Node.Rollback => Map.empty[GlobalKey, KeyMapping] + case a: Node.Action => actionKeyMap(a) + } + } + +} + +object CSMKeysProperties { + + import CSMKeysPropertiesDef._ + + /** If s1 and s2 have the same globalKeys, then s1 contains k if and only if s2 contains k. + * + * Even though this is obvious from the definition, the opaque annotations require the needs for theorems + * stating this property. + */ + @pure + @opaque + def containsKeySameGlobalKeys(s1: State, s2: State, k: GlobalKey): Unit = { + require(s1.globalKeys == s2.globalKeys) + unfold(containsKey(s1)(k)) + unfold(containsKey(s2)(k)) + }.ensuring(containsKey(s1)(k) == containsKey(s2)(k)) + @pure + @opaque + def containsOptionKeySameGlobalKeys(s1: State, s2: State, k: Option[GlobalKey]): Unit = { + require(s1.globalKeys == s2.globalKeys) + unfold(containsOptionKey(s1)(k)) + unfold(containsOptionKey(s2)(k)) + k match { + case Some(gk) => containsKeySameGlobalKeys(s1, s2, gk) + case _ => Trivial() + } + }.ensuring(containsOptionKey(s1)(k) == containsOptionKey(s2)(k)) + + /** If s1 and s2 have the same globalKeys, then s1 contains the key of the argument node n if and only if s2 contains it. + * + * Even though this is obvious from the definition, the opaque annotations require the needs for theorems + * stating this property. + */ + @pure + @opaque + def containsActionKeySameGlobalKeys(s1: State, s2: State, n: Node.Action): Unit = { + require(s1.globalKeys == s2.globalKeys) + unfold(containsActionKey(s1)(n)) + unfold(containsActionKey(s2)(n)) + containsOptionKeySameGlobalKeys(s1, s2, n.gkeyOpt) + }.ensuring(containsActionKey(s1)(n) == containsActionKey(s2)(n)) + @pure + @opaque + def containsNodeKeySameGlobalKeys(s1: State, s2: State, n: Node): Unit = { + require(s1.globalKeys == s2.globalKeys) + unfold(containsNodeKey(s1)(n)) + unfold(containsNodeKey(s2)(n)) + n match { + case a: Node.Action => + containsActionKeySameGlobalKeys(s1, s2, a) + case r: Node.Rollback => Trivial() + } + }.ensuring(containsNodeKey(s1)(n) == containsNodeKey(s2)(n)) + + /** e1 and e2 are states such that when e1 is not defined e2 is not defined as well. + * If when both are defined, their globalKeys are the same then one contains the key of n if and only if the other + * also does. + * + * Even though this is obvious from the definition, the opaque annotations require the needs for theorems + * stating this property. + */ + @pure + @opaque + def containsKeySameGlobalKeys[S, T](e1: Either[S, State], e2: Either[T, State], n: Node): Unit = { + require(sameGlobalKeys(e1, e2)) + require(propagatesError(e1, e2)) + + unfold(sameGlobalKeys(e1, e2)) + unfold(containsKey(e1)(n)) + unfold(containsKey(e2)(n)) + unfold(propagatesError(e1, e2)) + + e1 match { + case Right(s1) => + unfold(sameGlobalKeys(s1, e2)) + e2 match { + case Right(s2) => containsNodeKeySameGlobalKeys(s1, s2, n) + case _ => Trivial() + } + case _ => Trivial() + } + + }.ensuring(containsKey(e1)(n) ==> containsKey(e2)(n)) + + /** Definition of s.keys.contains expressed in function containsKey + */ + @pure + @opaque + def keysContains(s: State, k: GlobalKey): Unit = { + unfold(s.keys) + unfold(containsKey(s)(k)) + + MapProperties.concatContains(s.globalKeys, s.activeState.localKeys.mapValues(KeyActive), k) + MapProperties.mapValuesContains(s.activeState.localKeys, KeyActive, k) + }.ensuring(s.keys.contains(k) == (containsKey(s)(k) || s.activeState.localKeys.contains(k))) + + /** Definition of s.activeKeys.contains expressed in function containsKey + */ + @pure + @opaque + def activeKeysContains(s: State, k: GlobalKey): Unit = { + unfold(s.activeKeys) + keysContains(s, k) + MapProperties.mapValuesContains(s.keys, keyMappingToActiveMapping(s.activeState.consumedBy), k) + }.ensuring(s.activeKeys.contains(k) == (containsKey(s)(k) || s.activeState.localKeys.contains(k))) + + /** If s contains k then the activeKeys of s also contains k + */ + @pure + @opaque + def activeKeysContainsKey(s: State, k: GlobalKey): Unit = { + require(containsKey(s)(k)) + activeKeysContains(s, k) + }.ensuring(s.activeKeys.contains(k)) + + /** If k is defined and s contains k then the activeKeys of s also contains k + */ + @pure + @opaque + def activeKeysContainsKey(s: State, k: Option[GlobalKey]): Unit = { + require(containsOptionKey(s)(k)) + unfold(containsOptionKey(s)(k)) + k match { + case Some(gk) => activeKeysContains(s, gk) + case None() => Trivial() + } + }.ensuring(k.forall(s.activeKeys.contains)) + + /** If n key is defined and s contains it then the activeKeys of s also contains the key + */ + @pure + @opaque + def activeKeysContainsKey(s: State, n: Node.Action): Unit = { + require(containsActionKey(s)(n)) + unfold(containsActionKey(s)(n)) + activeKeysContainsKey(s, n.gkeyOpt) + }.ensuring(n.gkeyOpt.forall(s.activeKeys.contains)) + + /** If s contains k then any state extending its global keys also contains k + */ + @pure + @opaque + def containsKeyExtend(s: State, k: GlobalKey, glK: Map[GlobalKey, KeyMapping]): Unit = { + require(s.globalKeys.submapOf(glK)) + require(containsKey(s)(k)) + unfold(containsKey(s)(k)) + unfold(containsKey(extendGlobalKeys(s, glK))(k)) + MapProperties.submapOfContains(s.globalKeys, extendGlobalKeys(s, glK).globalKeys, k) + }.ensuring(containsKey(extendGlobalKeys(s, glK))(k)) + + /** If k is defined and s contains k then any state extending its global keys also contains k + */ + @pure + @opaque + def containsOptionKeyExtend( + s: State, + k: Option[GlobalKey], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(s.globalKeys.submapOf(glK)) + require(containsOptionKey(s)(k)) + unfold(containsOptionKey(s)(k)) + unfold(containsOptionKey(extendGlobalKeys(s, glK))(k)) + k match { + case None() => Trivial() + case Some(gk) => containsKeyExtend(s, gk, glK) + } + }.ensuring(containsOptionKey(extendGlobalKeys(s, glK))(k)) + + /** If the key of a node n is defined and s contains its then any state extending its global keys also contains the key + */ + @pure + @opaque + def containsActionKeyExtend(s: State, n: Node.Action, glK: Map[GlobalKey, KeyMapping]): Unit = { + require(s.globalKeys.submapOf(glK)) + require(containsActionKey(s)(n)) + unfold(containsActionKey(s)(n)) + unfold(containsActionKey(extendGlobalKeys(s, glK))(n)) + containsOptionKeyExtend(s, n.gkeyOpt, glK) + }.ensuring(containsActionKey(extendGlobalKeys(s, glK))(n)) + + @pure + @opaque + def containsNodeKeyExtend(s: State, n: Node, glK: Map[GlobalKey, KeyMapping]): Unit = { + require(s.globalKeys.submapOf(glK)) + require(containsNodeKey(s)(n)) + unfold(containsNodeKey(s)(n)) + unfold(containsNodeKey(extendGlobalKeys(s, glK))(n)) + n match { + case a: Node.Action => containsActionKeyExtend(s, a, glK) + case r: Node.Rollback => Trivial() + } + }.ensuring(containsNodeKey(extendGlobalKeys(s, glK))(n)) + + /** If the key of a node n is defined and s contains its then any state extending its global keys to the left also contains the key + */ + @pure + @opaque + def containsNodeKeyConcatLeft(s: State, n: Node, glK: Map[GlobalKey, KeyMapping]): Unit = { + require(containsNodeKey(s)(n)) + unfold(concatLeftGlobalKeys(s, glK)) + containsNodeKeyExtend(s, n, glK ++ s.globalKeys) + }.ensuring(containsNodeKey(concatLeftGlobalKeys(s, glK))(n)) + + @pure + @opaque + def containsKeyConcatLeft[T]( + e: Either[T, State], + n: Node, + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsKey(e)(n)) + + unfold(containsKey(e)(n)) + unfold(containsKey(concatLeftGlobalKeys(e, glK))(n)) + unfold(concatLeftGlobalKeys(e, glK)) + + e match { + case Right(s) => containsNodeKeyConcatLeft(s, n, glK) + case _ => Trivial() + } + + }.ensuring(containsKey(concatLeftGlobalKeys(e, glK))(n)) + + /** If the key of a node n1 is defined and e contains it, then adding the key of n2 to the global keys of the state + * does not change the truth of the statement. + */ + @pure + @opaque + def containsKeyAddKeyBeforeNode[T](e: Either[T, State], n1: Node, n2: Node): Unit = { + require(containsKey(e)(n2)) + if (containsKey(e)(n1)) { + Trivial() + } else { + containsKeyConcatLeft(e, n2, nodeKeyMap(n1)) + } + }.ensuring(containsKey(addKeyBeforeNode(e, n1))(n2)) + + /** The key set of the activeKeys of a state are the concatenation between the keyset of its globalKeys and its localKeys + */ + @pure + @opaque + def activeKeysKeySet(s: State): Unit = { + unfold(s.activeKeys) + MapProperties.mapValuesKeySet(s.keys, keyMappingToActiveMapping(s.activeState.consumedBy)) + SetProperties.equalsTransitivity( + s.activeKeys.keySet, + s.keys.keySet, + s.globalKeys.keySet ++ s.activeState.localKeys.keySet, + ) + }.ensuring(s.activeKeys.keySet === s.globalKeys.keySet ++ s.activeState.localKeys.keySet) + + /** Getting a mapping in the keys of a state is the same as first looking into the local keys, mapping it to KeyActive + * (since localKeys contains contractIds) and in case of failure looking in the global keys + */ + @pure + @opaque + def keysGet(s: State, k: GlobalKey): Unit = { + unfold(s.keys) + MapAxioms.concatGet(s.globalKeys, s.activeState.localKeys.mapValues(KeyActive), k) + MapProperties.mapValuesGet(s.activeState.localKeys, KeyActive, k) + }.ensuring( + s.keys.get(k) == s.activeState.localKeys.get(k).map(KeyActive).orElse(s.globalKeys.get(k)) + ) + + /** Getting a mapping in the activeKeys of a state is the same as first looking into the local keys, mapping it to + * KeyActive (since localKeys contains contractIds) and afterward filtering it if is consumed, and in case of failure + * looking in the global keys before filtering it. + */ + @pure + @opaque + def activeKeysGet(s: State, k: GlobalKey): Unit = { + unfold(s.activeKeys) + keysGet(s, k) + MapProperties.mapValuesGet(s.keys, keyMappingToActiveMapping(s.activeState.consumedBy), k) + }.ensuring( + s.activeKeys.get(k) == s.activeState.localKeys + .get(k) + .map(KeyActive andThen keyMappingToActiveMapping(s.activeState.consumedBy)) + .orElse(s.globalKeys.get(k).map(keyMappingToActiveMapping(s.activeState.consumedBy))) + ) + + /** Getting a mapping in the activeKeys of a state is the same as first looking into the local keys, mapping it to + * KeyActive (since localKeys contains contractIds) and afterward filtering it if is consumed, and in case of failure + * looking in the global keys before filtering it. + */ + @pure + @opaque + def activeKeysGetOrElse(s: State, k: GlobalKey, d: KeyMapping): Unit = { + unfold(s.activeKeys.getOrElse(k, d)) + activeKeysGet(s, k) + }.ensuring( + s.activeKeys.getOrElse(k, d) == s.activeState.localKeys + .get(k) + .map(KeyActive andThen keyMappingToActiveMapping(s.activeState.consumedBy)) + .orElse(s.globalKeys.get(k).map(keyMappingToActiveMapping(s.activeState.consumedBy))) + .getOrElse(d) + ) + + /** If the globalKeys of a state already contain a key, then concatenating other keys to the globalKeys does not affect + * the mapping of that key in the activeKey. + */ + @pure + @opaque + def activeKeysConcatLeftGlobalKeys( + s: State, + k: GlobalKey, + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsKey(s)(k)) + activeKeysContains(s, k) + MapAxioms.submapOfGet(s.activeKeys, concatLeftGlobalKeys(s, glK).activeKeys, k) + }.ensuring( + concatLeftGlobalKeys(s, glK).activeKeys.get(k) == s.activeKeys.get(k) + ) + + /** Getting a mapping in the activeKeys after having added the key to the globalKeys is the same as searching it in the + * activeKeys of the state before the addition and in case of failure taking the newly added mapping before filtering + * it if is consumed. + */ + @pure + @opaque + def activeKeysAddKey(s: State, k: GlobalKey, m: KeyMapping): Unit = { + activeKeysGet(s, k) + activeKeysGet(addKey(s, k, m), k) + unfold(addKey(s, k, m)) + unfold(containsKey(s)(k)) + unfold(s.globalKeys.contains) + activeKeysContains(addKey(s, k, m), k) + }.ensuring( + addKey(s, k, m).activeKeys.contains(k) && + (addKey(s, k, m).activeKeys.get(k) == s.activeKeys + .get(k) + .orElse(Some(keyMappingToActiveMapping(s.activeState.consumedBy)(m)))) + ) + + /** If a state already contains the key of an Create node, then concatenating keys to the globalKeys and + * visiting the node leads to the same result than doing the same operations in the reverse order. + */ + @pure + @opaque + def visitCreateConcatLeftGlobalKeys( + s: State, + contractId: ContractId, + mbKey: Option[GlobalKey], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsOptionKey(s)(mbKey)) + + unfold(s.visitCreate(contractId, mbKey)) + unfold(concatLeftGlobalKeys(s, glK)) + unfold(concatLeftGlobalKeys(s, glK).visitCreate(contractId, mbKey)) + unfold(concatLeftGlobalKeys(s.visitCreate(contractId, mbKey), glK)) + unfold(containsOptionKey(s)(mbKey)) + + val me = + s.copy( + locallyCreated = s.locallyCreated + contractId, + activeState = s.activeState + .copy(locallyCreatedThisTimeline = s.activeState.locallyCreatedThisTimeline + contractId), + ) + + mbKey match { + case Some(k) => + activeKeysConcatLeftGlobalKeys(s, k, glK) + unfold( + concatLeftGlobalKeys(me.copy(activeState = me.activeState.createKey(k, contractId)), glK) + ) + case None() => + unfold(concatLeftGlobalKeys(me, glK)) + } + }.ensuring( + concatLeftGlobalKeys(s, glK).visitCreate(contractId, mbKey) == + concatLeftGlobalKeys(s.visitCreate(contractId, mbKey), glK) + ) + + /** If a state already contains the key of a Lookup node, then concatenating keys to the globalKeys and + * visiting the node leads to the same result than doing the same operations in the reverse order. + */ + @pure + @opaque + def visitLookupConcatLeftGlobalKeys( + s: State, + gk: GlobalKey, + keyResolution: Option[ContractId], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsKey(s)(gk)) + + unfold(s.visitLookup(gk, keyResolution)) + unfold(concatLeftGlobalKeys(s, glK).visitLookup(gk, keyResolution)) + unfold(concatLeftGlobalKeys(s.visitLookup(gk, keyResolution), glK)) + + unfold(s.activeKeys.getOrElse(gk, KeyInactive)) + unfold(concatLeftGlobalKeys(s, glK).activeKeys.getOrElse(gk, KeyInactive)) + + activeKeysConcatLeftGlobalKeys(s, gk, glK) + + }.ensuring( + concatLeftGlobalKeys(s, glK).visitLookup(gk, keyResolution) == + concatLeftGlobalKeys(s.visitLookup(gk, keyResolution), glK) + ) + + /** If a state already contains the key of a Fetch node, then concatenating keys to the globalKeys and + * visiting the node leads to the same result than doing the same operations in the reverse order. + */ + @pure + @opaque + def assertKeyMappingConcatLeftGlobalKeys( + s: State, + cid: ContractId, + mbKey: Option[GlobalKey], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsOptionKey(s)(mbKey)) + + unfold(s.assertKeyMapping(cid, mbKey)) + unfold(concatLeftGlobalKeys(s, glK).assertKeyMapping(cid, mbKey)) + unfold(concatLeftGlobalKeys(s.assertKeyMapping(cid, mbKey), glK)) + unfold(containsOptionKey(s)(mbKey)) + + mbKey match { + case None() => Trivial() + case Some(gk) => visitLookupConcatLeftGlobalKeys(s, gk, KeyActive(cid), glK) + } + + }.ensuring( + concatLeftGlobalKeys(s, glK).assertKeyMapping(cid, mbKey) == + concatLeftGlobalKeys(s.assertKeyMapping(cid, mbKey), glK) + ) + + /** Concatenating keys to the globalKeys and consuming a contract leads to the same result than doing the same + * operations in the reverse order. + */ + @pure + @opaque + def consumeConcatLeftGlobalKeys( + s: State, + cid: ContractId, + nid: NodeId, + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + unfold(s.consume(cid, nid)) + unfold(concatLeftGlobalKeys(s, glK).consume(cid, nid)) + unfold(concatLeftGlobalKeys(s.consume(cid, nid), glK)) + unfold(concatLeftGlobalKeys(s, glK)) + }.ensuring( + concatLeftGlobalKeys(s, glK).consume(cid, nid) == + concatLeftGlobalKeys(s.consume(cid, nid), glK) + ) + + /** If a state already contains the key of an Exercise node, then concatenating keys to the globalKeys and + * visiting the node leads to the same result than doing the same operations in the reverse order. + */ + @pure + @opaque + def visitExerciseConcatLeftGlobalKeys( + s: State, + nodeId: NodeId, + targetId: ContractId, + mbKey: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsOptionKey(s)(mbKey)) + + unfold(s.visitExercise(nodeId, targetId, mbKey, byKey, consuming)) + unfold(concatLeftGlobalKeys(s, glK).visitExercise(nodeId, targetId, mbKey, byKey, consuming)) + unfold(concatLeftGlobalKeys(s.visitExercise(nodeId, targetId, mbKey, byKey, consuming), glK)) + assertKeyMappingConcatLeftGlobalKeys(s, targetId, mbKey, glK) + unfold(concatLeftGlobalKeys(s.assertKeyMapping(targetId, mbKey), glK)) + + s.assertKeyMapping(targetId, mbKey) match { + case Right(s) => consumeConcatLeftGlobalKeys(s, targetId, nodeId, glK) + case _ => Trivial() + } + + }.ensuring( + concatLeftGlobalKeys(s, glK).visitExercise(nodeId, targetId, mbKey, byKey, consuming) == + concatLeftGlobalKeys(s.visitExercise(nodeId, targetId, mbKey, byKey, consuming), glK) + ) + + /** Concatenating keys to the globalKeys and then mapping the error leads to the same result than doing the same + * operations in the reverse order. + */ + @pure + @opaque + def toKeyInputErrorConcatLeftGlobalKeys( + e: Either[InconsistentContractKey, State], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + unfold(concatLeftGlobalKeys(e, glK)) + unfold(toKeyInputError(e)) + unfold(toKeyInputError(concatLeftGlobalKeys(e, glK))) + unfold(concatLeftGlobalKeys(toKeyInputError(e), glK)) + }.ensuring( + toKeyInputError(concatLeftGlobalKeys(e, glK)) == + concatLeftGlobalKeys(toKeyInputError(e), glK) + ) + + /** Concatenating keys to the globalKeys and then mapping the error leads to the same result than doing the same + * operations in the reverse order. + */ + @pure + @opaque + @targetName("toKeyInputErrorConcatLeftGlobalKeysDuplicateContractKey") + def toKeyInputErrorConcatLeftGlobalKeys( + e: Either[DuplicateContractKey, State], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + unfold(concatLeftGlobalKeys(e, glK)) + unfold(toKeyInputError(e)) + unfold(toKeyInputError(concatLeftGlobalKeys(e, glK))) + unfold(concatLeftGlobalKeys(toKeyInputError(e), glK)) + }.ensuring( + toKeyInputError(concatLeftGlobalKeys(e, glK)) == + concatLeftGlobalKeys(toKeyInputError(e), glK) + ) + + /** If a state already contains the key of a node, then concatenating keys to the globalKeys and handling the node + * leads to the same result than doing the same operations in the reverse order. + */ + @pure + @opaque + def handleNodeConcatLeftGlobalKeys( + s: State, + nid: NodeId, + n: Node.Action, + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsActionKey(s)(n)) + + unfold(concatLeftGlobalKeys(s, glK).handleNode(nid, n)) + unfold(s.handleNode(nid, n)) + unfold(containsActionKey(s)(n)) + n match { + case create: Node.Create => + visitCreateConcatLeftGlobalKeys(s, create.coid, create.gkeyOpt, glK) + toKeyInputErrorConcatLeftGlobalKeys(s.visitCreate(create.coid, create.gkeyOpt), glK) + case fetch: Node.Fetch => + assertKeyMappingConcatLeftGlobalKeys(s, fetch.coid, fetch.gkeyOpt, glK) + toKeyInputErrorConcatLeftGlobalKeys(s.assertKeyMapping(fetch.coid, fetch.gkeyOpt), glK) + case lookup: Node.LookupByKey => + unfold(containsOptionKey(s)(n.gkeyOpt)) + visitLookupConcatLeftGlobalKeys(s, lookup.gkey, lookup.result, glK) + toKeyInputErrorConcatLeftGlobalKeys(s.visitLookup(lookup.gkey, lookup.result), glK) + case exe: Node.Exercise => + visitExerciseConcatLeftGlobalKeys( + s, + nid, + exe.targetCoid, + exe.gkeyOpt, + exe.byKey, + exe.consuming, + glK, + ) + toKeyInputErrorConcatLeftGlobalKeys( + s.visitExercise(nid, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming), + glK, + ) + } + }.ensuring( + concatLeftGlobalKeys(s, glK).handleNode(nid, n) == + concatLeftGlobalKeys(s.handleNode(nid, n), glK) + ) + + /** If a state already contains the key of a node, then concatenating keys to the globalKeys and handling the node + * leads to the same result than doing the same operations in the reverse order. + */ + @pure + @opaque + def handleNodeConcatLeftGlobalKeys( + e: Either[KeyInputError, State], + nid: NodeId, + n: Node.Action, + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + require(containsKey(e)(n)) + unfold(containsKey(e)(n)) + unfold(concatLeftGlobalKeys(e, glK)) + unfold(handleNode(concatLeftGlobalKeys(e, glK), nid, n)) + unfold(handleNode(e, nid, n)) + e match { + case Right(s) => + unfold(containsNodeKey(s)(n)) + unfold(containsActionKey(s)(n)) + handleNodeConcatLeftGlobalKeys(s, nid, n, glK) + case _ => Trivial() + } + + }.ensuring( + handleNode(concatLeftGlobalKeys(e, glK), nid, n) == + concatLeftGlobalKeys(handleNode(e, nid, n), glK) + ) + + /** Concatenating keys to the globalKeys and calling beginRollback leads to the same result than doing the same operations in the reverse + * order + */ + @pure + @opaque + def beginRollbackConcatLeftGlobalKeys( + e: Either[KeyInputError, State], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + unfold(beginRollback(concatLeftGlobalKeys(e, glK))) + unfold(concatLeftGlobalKeys(e, glK)) + unfold(beginRollback(e)) + unfold(concatLeftGlobalKeys(beginRollback(e), glK)) + + e match { + case Right(s) => + unfold(concatLeftGlobalKeys(s, glK).beginRollback()) + unfold(concatLeftGlobalKeys(s, glK)) + unfold(s.beginRollback()) + unfold(concatLeftGlobalKeys(s.beginRollback(), glK)) + case _ => Trivial() + } + + }.ensuring( + beginRollback(concatLeftGlobalKeys(e, glK)) == + concatLeftGlobalKeys(beginRollback(e), glK) + ) + + /** Concatenating keys to the globalKeys and calling endRollback leads to the same result than doing the same operations in the reverse + * order + */ + @pure + @opaque + @dropVCs + def endRollbackConcatLeftGlobalKeys( + e: Either[KeyInputError, State], + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + unfold(endRollback(concatLeftGlobalKeys(e, glK))) + unfold(concatLeftGlobalKeys(e, glK)) + unfold(endRollback(e)) + unfold(concatLeftGlobalKeys(endRollback(e), glK)) + + e match { + case Right(s) => + unfold(concatLeftGlobalKeys(s, glK).endRollback()) + unfold(concatLeftGlobalKeys(s, glK)) + unfold(s.endRollback()) + unfold(concatLeftGlobalKeys(s.endRollback(), glK)) + case _ => Trivial() + } + + }.ensuring( + endRollback(concatLeftGlobalKeys(e, glK)) == + concatLeftGlobalKeys(endRollback(e), glK) + ) + + /** Concatenating keys on the left is an associative operation. + */ + @pure + @opaque + def concatLeftGlobalKeysAssociativity( + s: State, + glK1: Map[GlobalKey, KeyMapping], + glK2: Map[GlobalKey, KeyMapping], + ): Unit = { + unfold(concatLeftGlobalKeys(concatLeftGlobalKeys(s, glK1), glK2)) + unfold(concatLeftGlobalKeys(s, glK1)) + unfold(concatLeftGlobalKeys(s, glK2 ++ glK1)) + MapProperties.concatAssociativity(glK2, glK1, s.globalKeys) + MapAxioms.extensionality((glK2 ++ glK1) ++ s.globalKeys, glK2 ++ (glK1 ++ s.globalKeys)) + }.ensuring( + concatLeftGlobalKeys(concatLeftGlobalKeys(s, glK1), glK2) == + concatLeftGlobalKeys(s, glK2 ++ glK1) + ) + + /** Concatenating keys on the left is an associative operation. + */ + @pure + @opaque + def concatLeftGlobalKeysAssociativity[T]( + e: Either[T, State], + glK1: Map[GlobalKey, KeyMapping], + glK2: Map[GlobalKey, KeyMapping], + ): Unit = { + unfold(concatLeftGlobalKeys(e, glK1)) + unfold(concatLeftGlobalKeys(concatLeftGlobalKeys(e, glK1), glK2)) + unfold(concatLeftGlobalKeys(e, glK2 ++ glK1)) + e match { + case Right(s) => concatLeftGlobalKeysAssociativity(s, glK1, glK2) + case Left(_) => Trivial() + } + }.ensuring( + concatLeftGlobalKeys(concatLeftGlobalKeys(e, glK1), glK2) == + concatLeftGlobalKeys(e, glK2 ++ glK1) + ) + +} diff --git a/daml-lf/verification/transaction/CSMLocallyCreatedProperties.scala b/daml-lf/verification/transaction/CSMLocallyCreatedProperties.scala new file mode 100644 index 0000000000..2d1ba77ec0 --- /dev/null +++ b/daml-lf/verification/transaction/CSMLocallyCreatedProperties.scala @@ -0,0 +1,198 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import scala.annotation.targetName +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ + +import ContractStateMachine._ +import CSMHelpers._ +import CSMEither._ +import CSMEitherDef._ + +/** This file shows how the set of locallyCreated and consumed contracts of a state behave after handling a node. + */ +object CSMLocallyCreatedProperties { + + /** The set of locally created contracts has the new contract added to it when the node is an instance of Create and + * otherwise remains the same + */ + @pure + @opaque + def handleNodeLocallyCreated(s: State, id: NodeId, node: Node.Action): Unit = { + unfold(s.handleNode(id, node)) + + val res = s.handleNode(id, node) + + node match { + case create: Node.Create => Trivial() + case fetch: Node.Fetch => + sameLocallyCreatedTransitivity(s, s.assertKeyMapping(fetch.coid, fetch.gkeyOpt), res) + unfold(sameLocallyCreated(s, res)) + case lookup: Node.LookupByKey => + sameLocallyCreatedTransitivity(s, s.visitLookup(lookup.gkey, lookup.result), res) + unfold(sameLocallyCreated(s, res)) + case exe: Node.Exercise => + sameLocallyCreatedTransitivity( + s, + s.visitExercise(id, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming), + res, + ) + unfold(sameLocallyCreated(s, res)) + } + }.ensuring( + s.handleNode(id, node) + .forall(r => + node match { + case Node.Create(coid, _) => r.locallyCreated == s.locallyCreated + coid + case _ => s.locallyCreated == r.locallyCreated + } + ) + ) + + /** The set of locally created contracts has the new contract added to it when the node is an instance of Create and + * otherwise remains the same + */ + @pure + @opaque + def handleNodeLocallyCreated( + e: Either[KeyInputError, State], + id: NodeId, + node: Node.Action, + ): Unit = { + unfold(handleNode(e, id, node)) + e match { + case Left(_) => Trivial() + case Right(s) => handleNodeLocallyCreated(s, id, node) + } + }.ensuring( + handleNode(e, id, node).forall(r => + e.forall(s => + node match { + case Node.Create(coid, _) => r.locallyCreated == s.locallyCreated + coid + case _ => s.locallyCreated == r.locallyCreated + } + ) + ) + ) + + /** If two states propagates the error have the same set of locally created contracts then if the first is a subset of + * a third set then the second one also is. + */ + def sameLocallyCreatedSubsetOfTransivity[T, U]( + e1: Either[T, State], + e2: Either[U, State], + lc: Set[ContractId], + ): Unit = { + require(sameLocallyCreated(e1, e2)) + require(propagatesError(e1, e2)) + require(e1.forall(s1 => s1.locallyCreated.subsetOf(lc))) + + unfold(propagatesError(e1, e2)) + unfold(sameLocallyCreated(e1, e2)) + e1 match { + case Left(_) => Trivial() + case Right(s1) => unfold(sameLocallyCreated(s1, e2)) + } + }.ensuring(e2.forall(s2 => s2.locallyCreated.subsetOf(lc))) + + /** The set of consumed contracts has the consumed contract added to it when the node is an instance of a consuming + * Exercise and otherwise remains the same + */ + @pure + @opaque + def handleNodeConsumed(s: State, id: NodeId, node: Node.Action): Unit = { + unfold(s.handleNode(id, node)) + + val res = s.handleNode(id, node) + + node match { + case create: Node.Create => + sameConsumedTransitivity(s, s.visitCreate(create.coid, create.gkeyOpt), res) + unfold(sameConsumed(s, res)) + case fetch: Node.Fetch => + sameConsumedTransitivity(s, s.assertKeyMapping(fetch.coid, fetch.gkeyOpt), res) + unfold(sameConsumed(s, res)) + case lookup: Node.LookupByKey => + sameConsumedTransitivity(s, s.visitLookup(lookup.gkey, lookup.result), res) + unfold(sameConsumed(s, res)) + case exe: Node.Exercise => + unfold(s.visitExercise(id, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming)) + for { + state <- s.assertKeyMapping(exe.targetCoid, exe.gkeyOpt) + } yield { + unfold(sameConsumed(s, s.assertKeyMapping(exe.targetCoid, exe.gkeyOpt))) + unfold(state.consume(exe.targetCoid, id)) + } + () + } + }.ensuring( + s.handleNode(id, node) + .forall(r => + node match { + case Node.Exercise(targetCoid, true, _, _, _) => r.consumed == s.consumed + targetCoid + case _ => s.consumed == r.consumed + } + ) + ) + + /** The set of consumed contracts has the consumed contract added to it when the node is an instance of a consuming + * Exercise and otherwise remains the same + */ + @pure + @opaque + def handleNodeConsumed(e: Either[KeyInputError, State], id: NodeId, node: Node.Action): Unit = { + unfold(handleNode(e, id, node)) + e match { + case Left(_) => Trivial() + case Right(s) => handleNodeConsumed(s, id, node) + } + }.ensuring( + handleNode(e, id, node).forall(r => + e.forall(s => + node match { + case Node.Exercise(targetCoid, true, _, _, _) => r.consumed == s.consumed + targetCoid + case _ => s.consumed == r.consumed + } + ) + ) + ) + + /** If two states propagates the error have the same set of consumed contracts then if the first is a subset of + * a third set then the second one also is. + */ + def sameConsumedSubsetOfTransivity[T, U]( + e1: Either[T, State], + e2: Either[U, State], + lc: Set[ContractId], + ): Unit = { + require(sameConsumed(e1, e2)) + require(propagatesError(e1, e2)) + require(e1.forall(s1 => s1.consumed.subsetOf(lc))) + + unfold(propagatesError(e1, e2)) + unfold(sameConsumed(e1, e2)) + e1 match { + case Left(_) => Trivial() + case Right(s1) => unfold(sameConsumed(s1, e2)) + } + }.ensuring(e2.forall(s2 => s2.consumed.subsetOf(lc))) + +} diff --git a/daml-lf/verification/transaction/ContractStateMachineAlt.scala b/daml-lf/verification/transaction/ContractStateMachineAlt.scala new file mode 100644 index 0000000000..997d3f4687 --- /dev/null +++ b/daml-lf/verification/transaction/ContractStateMachineAlt.scala @@ -0,0 +1,334 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package transaction + +import stainless.lang._ +import stainless.annotation._ +import stainless.collection._ + +import utils.{ + Either, + Map, + Set, + Value, + GlobalKey, + Transaction, + Unreachable, + Node, + ContractKeyUniquenessMode, + NodeId, + MapProperties, + SetProperties, +} +import utils.Value.ContractId +import utils.Transaction.{ + KeyCreate, + KeyInputError, + NegativeKeyLookup, + InconsistentContractKey, + DuplicateContractKey, + KeyInput, +} +import CSMHelpers._ +import CSMEitherDef._ +import CSMEither._ + +/** Simplified version of the contract state machine. All the implementations are simplified, and [[State.globalKeys]] are + * not updated anymore during [[State.handleNode]]. [[CSMKeysPropertiesDef.addKeyBeforeNode]] has to be called beforehand. + */ +case class State( + locallyCreated: Set[ContractId], + consumed: Set[ContractId], + globalKeys: Map[GlobalKey, ContractStateMachine.KeyMapping], + activeState: ContractStateMachine.ActiveLedgerState, + rollbackStack: List[ContractStateMachine.ActiveLedgerState], +) { + + import ContractStateMachine._ + + @pure @opaque + def keys: Map[GlobalKey, KeyMapping] = { + MapProperties.mapValuesKeySet(activeState.localKeys, KeyActive) + SetProperties.unionEqualsRight( + globalKeys.keySet, + activeState.localKeys.keySet, + activeState.localKeys.mapValues(KeyActive).keySet, + ) + MapProperties.concatKeySet(globalKeys, activeState.localKeys.mapValues(KeyActive)) + SetProperties.equalsTransitivity( + (globalKeys ++ activeState.localKeys.mapValues(KeyActive)).keySet, + globalKeys.keySet ++ activeState.localKeys.mapValues(KeyActive).keySet, + globalKeys.keySet ++ activeState.localKeys.keySet, + ) + + globalKeys ++ activeState.localKeys.mapValues(KeyActive) + }.ensuring(res => res.keySet === globalKeys.keySet ++ activeState.localKeys.keySet) + + @pure @opaque + def activeKeys: Map[GlobalKey, KeyMapping] = { + keys.mapValues(keyMappingToActiveMapping(activeState.consumedBy)) + } + + @pure @opaque + def consume(cid: ContractId, nid: NodeId): State = { + unfold(activeState.consume(cid, nid)) + this.copy(activeState = activeState.consume(cid, nid), consumed = consumed + cid) + }.ensuring(res => + (globalKeys == res.globalKeys) && + (locallyCreated == res.locallyCreated) && + (rollbackStack == res.rollbackStack) && + (activeState.localKeys == res.activeState.localKeys) + ) + + @pure @opaque + def visitCreate( + contractId: ContractId, + mbKey: Option[GlobalKey], + ): Either[DuplicateContractKey, State] = { + + val me = + this.copy( + locallyCreated = locallyCreated + contractId, + activeState = this.activeState + .copy(locallyCreatedThisTimeline = + this.activeState.locallyCreatedThisTimeline + contractId + ), + ) + + val res = mbKey match { + case None() => Right[DuplicateContractKey, State](me) + case Some(gk) => + val conflict = activeKeys.get(gk).exists(_ != KeyInactive) + + Either.cond( + !conflict, + me.copy(activeState = me.activeState.createKey(gk, contractId)), + DuplicateContractKey(gk), + ) + } + + unfold(sameGlobalKeys(this, res)) + unfold(sameStack(this, res)) + unfold(sameConsumed(this, res)) + + res + }.ensuring(res => + sameGlobalKeys(this, res) && + sameStack(this, res) && + sameConsumed(this, res) && + res.forall(r => r.locallyCreated == locallyCreated + contractId) + ) + + @pure @opaque + def visitExercise( + nodeId: NodeId, + targetId: ContractId, + mbKey: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + ): Either[InconsistentContractKey, State] = { + val res = + for { + state <- assertKeyMapping(targetId, mbKey) + } yield { + + if (consuming) + state.consume(targetId, nodeId) + else state + } + unfold(sameGlobalKeys(this, assertKeyMapping(targetId, mbKey))) + unfold(sameStack(this, assertKeyMapping(targetId, mbKey))) + unfold(sameLocalKeys(this, assertKeyMapping(targetId, mbKey))) + unfold(sameLocallyCreated(this, assertKeyMapping(targetId, mbKey))) + unfold(sameGlobalKeys(this, res)) + unfold(sameLocalKeys(this, res)) + unfold(sameStack(this, res)) + unfold(sameLocallyCreated(this, res)) + res + }.ensuring(res => + sameGlobalKeys(this, res) && + sameStack(this, res) && + sameLocalKeys(this, res) && + sameLocallyCreated(this, res) + ) + + @pure @opaque + def visitLookup( + gk: GlobalKey, + keyResolution: Option[ContractId], + ): Either[InconsistentContractKey, State] = { + val res = Either.cond( + activeKeys.getOrElse(gk, KeyInactive) == keyResolution, + this, + InconsistentContractKey(gk), + ) + unfold(sameState(this, res)) + res + }.ensuring(res => sameState(this, res)) + + @pure @opaque + def assertKeyMapping( + cid: ContractId, + mbKey: Option[GlobalKey], + ): Either[InconsistentContractKey, State] = { + val res = mbKey match { + case None() => Right[InconsistentContractKey, State](this) + case Some(gk) => visitLookup(gk, KeyActive(cid)) + } + unfold(sameState(this, res)) + res + }.ensuring(res => sameState(this, res)) + + @pure @opaque + def handleNode(id: NodeId, node: Node.Action): Either[KeyInputError, State] = { + val res = node match { + case create: Node.Create => toKeyInputError(visitCreate(create.coid, create.gkeyOpt)) + case fetch: Node.Fetch => toKeyInputError(assertKeyMapping(fetch.coid, fetch.gkeyOpt)) + case lookup: Node.LookupByKey => toKeyInputError(visitLookup(lookup.gkey, lookup.result)) + case exe: Node.Exercise => + toKeyInputError(visitExercise(id, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming)) + } + + @pure @opaque + def sameHandleNode: Unit = { + node match { + case create: Node.Create => + sameGlobalKeysTransitivity(this, visitCreate(create.coid, create.gkeyOpt), res) + sameStackTransitivity(this, visitCreate(create.coid, create.gkeyOpt), res) + case fetch: Node.Fetch => + sameGlobalKeysTransitivity(this, assertKeyMapping(fetch.coid, fetch.gkeyOpt), res) + sameStackTransitivity(this, assertKeyMapping(fetch.coid, fetch.gkeyOpt), res) + case lookup: Node.LookupByKey => + sameGlobalKeysTransitivity(this, visitLookup(lookup.gkey, lookup.result), res) + sameStackTransitivity(this, visitLookup(lookup.gkey, lookup.result), res) + case exe: Node.Exercise => + sameGlobalKeysTransitivity( + this, + visitExercise(id, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming), + res, + ) + sameStackTransitivity( + this, + visitExercise(id, exe.targetCoid, exe.gkeyOpt, exe.byKey, exe.consuming), + res, + ) + } + }.ensuring(sameGlobalKeys(this, res) && sameStack(this, res)) + + sameHandleNode + + res + }.ensuring(res => + sameGlobalKeys(this, res) && + sameStack(this, res) + ) + + @pure @opaque + def beginRollback(): State = { + val res = this.copy(rollbackStack = activeState :: rollbackStack) + unfold(res.withinRollbackScope) + res + }.ensuring(res => + (globalKeys == res.globalKeys) && + (locallyCreated == res.locallyCreated) && + (consumed == res.consumed) + ) + + @pure @opaque + def endRollback(): Either[KeyInputError, State] = { + val res = rollbackStack match { + case Nil() => + Left[KeyInputError, State]( + Left[InconsistentContractKey, DuplicateContractKey]( + InconsistentContractKey(GlobalKey(BigInt(0))) + ) + ) + case Cons(headState, tailStack) => + Right[KeyInputError, State](this.copy(activeState = headState, rollbackStack = tailStack)) + } + unfold(sameGlobalKeys(this, res)) + unfold(sameLocallyCreated(this, res)) + unfold(sameConsumed(this, res)) + res + }.ensuring(res => + sameGlobalKeys(this, res) && + sameLocallyCreated(this, res) && + sameConsumed(this, res) + ) + + @pure @opaque + def withinRollbackScope: Boolean = !rollbackStack.isEmpty + + @pure + def advance(substate: State): Either[Unit, State] = { + require(!substate.withinRollbackScope) + if ( + substate.globalKeys.keySet + .forall(k => activeKeys.get(k).forall(m => Some(m) == substate.globalKeys.get(k))) + ) { + Right[Unit, State]( + this.copy( + locallyCreated = locallyCreated ++ substate.locallyCreated, + consumed = consumed ++ substate.consumed, + globalKeys = substate.globalKeys ++ globalKeys, + activeState = activeState.advance(substate.activeState), + ) + ) + } else { + Left[Unit, State](()) + } + } + +} + +object State { + def empty: State = new State( + Set.empty, + Set.empty, + Map.empty, + ContractStateMachine.ActiveLedgerState.empty, + List.empty, + ) +} + +object ContractStateMachine { + + type KeyResolver = Map[GlobalKey, KeyMapping] + + type KeyMapping = Option[ContractId] + val KeyInactive: KeyMapping = None[ContractId]() + val KeyActive: ContractId => KeyMapping = Some[ContractId](_) + + final case class ActiveLedgerState( + locallyCreatedThisTimeline: Set[ContractId], + consumedBy: Map[ContractId, NodeId], + localKeys: Map[GlobalKey, ContractId], + ) { + + @pure @opaque + def consume(contractId: ContractId, nodeId: NodeId): ActiveLedgerState = + this.copy(consumedBy = consumedBy.updated(contractId, nodeId)) + + def createKey(key: GlobalKey, cid: ContractId): ActiveLedgerState = + this.copy(localKeys = localKeys.updated(key, cid)) + + @pure @opaque + def advance(substate: ActiveLedgerState): ActiveLedgerState = + ActiveLedgerState( + locallyCreatedThisTimeline = + locallyCreatedThisTimeline ++ substate.locallyCreatedThisTimeline, + consumedBy = consumedBy ++ substate.consumedBy, + localKeys = localKeys ++ substate.localKeys, + ) + } + + object ActiveLedgerState { + def empty: ActiveLedgerState = ActiveLedgerState( + Set.empty[ContractId], + Map.empty[ContractId, NodeId], + Map.empty[GlobalKey, ContractId], + ) + } +} diff --git a/daml-lf/verification/translation/CSMConversion.scala b/daml-lf/verification/translation/CSMConversion.scala new file mode 100644 index 0000000000..d66a596364 --- /dev/null +++ b/daml-lf/verification/translation/CSMConversion.scala @@ -0,0 +1,626 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package translation + +import stainless.lang._ +import stainless.annotation._ +import stainless.collection._ +import utils.{ + Either, + Map, + Set, + Value, + GlobalKey, + Transaction, + Unreachable, + Node, + ContractKeyUniquenessMode, + MapProperties, + MapAxioms, + NodeId, + SetAxioms, + SetProperties, + Trivial, +} +import utils.Value.ContractId +import utils.Transaction.{ + KeyCreate, + KeyInputError, + NegativeKeyLookup, + InconsistentContractKey, + DuplicateContractKey, + KeyInput, +} + +import transaction.ContractStateMachine._ +import transaction.CSMHelpers._ +import transaction.CSMKeysPropertiesDef._ +import transaction.CSMKeysProperties._ +import transaction.CSMInvariantDerivedProp._ +import transaction.CSMInvariantDef._ +import transaction.State + +import ContractStateMachine.{State => StateOriginal, ActiveLedgerState => ActiveLedgerStateOriginal} + +object CSMConversion { + @pure + def toKeyMapping: KeyInput => KeyMapping = _.toKeyMapping + @pure + def toKeyInput: KeyMapping => KeyInput = { + case None() => NegativeKeyLookup + case Some(cid) => Transaction.KeyActive(cid) + } + @pure @opaque + def globalKeys(gki: Map[GlobalKey, KeyInput]): Map[GlobalKey, Option[ContractId]] = { + MapProperties.mapValuesKeySet(gki, toKeyMapping) + gki.mapValues(toKeyMapping) + }.ensuring(_.keySet === gki.keySet) + + @pure + @opaque + def globalKeyInputs(gki: Map[GlobalKey, KeyMapping]): Map[GlobalKey, KeyInput] = { + MapProperties.mapValuesKeySet(gki, toKeyInput) + gki.mapValues(toKeyInput) + }.ensuring(_.keySet === gki.keySet) + + @pure @opaque + def globalKeysGlobalKeyInputs(gki: Map[GlobalKey, KeyMapping]): Unit = { + unfold(globalKeyInputs(gki)) + unfold(globalKeys(globalKeyInputs(gki))) + MapProperties.mapValuesAndThen(gki, toKeyInput, toKeyMapping) + if (gki =/= gki.mapValues(toKeyInput andThen toKeyMapping)) { + val k = MapProperties.notEqualsWitness(gki, gki.mapValues(toKeyInput andThen toKeyMapping)) + MapProperties.mapValuesGet(gki, toKeyInput andThen toKeyMapping, k) + } + MapProperties.equalsTransitivity( + gki, + gki.mapValues(toKeyInput andThen toKeyMapping), + globalKeys(globalKeyInputs(gki)), + ) + MapAxioms.extensionality(gki, globalKeys(globalKeyInputs(gki))) + }.ensuring(globalKeys(globalKeyInputs(gki)) == gki) + + @pure + def toOriginal(s: ActiveLedgerState): ActiveLedgerStateOriginal[NodeId] = { + ActiveLedgerStateOriginal[NodeId](s.locallyCreatedThisTimeline, s.consumedBy, s.localKeys) + } + @pure + def toAlt(s: ActiveLedgerStateOriginal[NodeId]): ActiveLedgerState = { + ActiveLedgerState(s.locallyCreatedThisTimeline, s.consumedBy, s.localKeys) + } + + @pure @opaque + def altOriginalInverse(s: ActiveLedgerState): Unit = { + unfold(toOriginal(s)) + unfold(toAlt(toOriginal(s))) + }.ensuring(toAlt(toOriginal(s)) == s) + + @pure + @opaque + def originalAltInverse(s: ActiveLedgerStateOriginal[NodeId]): Unit = { + unfold(toAlt(s)) + unfold(toOriginal(toAlt(s))) + }.ensuring(toOriginal(toAlt(s)) == s) + + @pure + def toOriginal(s: State): StateOriginal[NodeId] = { + StateOriginal[NodeId]( + s.locallyCreated, + globalKeyInputs(s.globalKeys), + toOriginal(s.activeState), + s.rollbackStack.map(toOriginal), + ContractKeyUniquenessMode.Strict, + ) + } + @pure + def toAlt(consumed: Set[ContractId])(s: StateOriginal[NodeId]): State = { + State( + s.locallyCreated, + consumed, + globalKeys(s.globalKeyInputs), + toAlt(s.activeState), + s.rollbackStack.map(toAlt), + ) + } + + @pure + def toAlt[T](consumed: Set[ContractId])(e: Either[T, StateOriginal[NodeId]]): Either[T, State] = { + e.map(toAlt(consumed)) + } + + @pure + def toOriginal[T](e: Either[T, State]): Either[T, StateOriginal[NodeId]] = { + e.map(toOriginal) + } + + @pure + @opaque + def lookupActiveGlobalKeyInputToAlt( + s: StateOriginal[NodeId], + key: GlobalKey, + consumed: Set[ContractId], + ) = { + unfold(toAlt(consumed)(s)) + unfold(toAlt(s.activeState)) + unfold(keyMappingToActiveMapping(toAlt(s.activeState).consumedBy)) + unfold(globalKeys(s.globalKeyInputs)) + MapProperties.mapValuesGet(s.globalKeyInputs, toKeyMapping, key) + MapProperties.mapValuesGet( + toAlt(consumed)(s).globalKeys, + keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy), + key, + ) + + }.ensuring( + s.lookupActiveGlobalKeyInput(key) == (toAlt(consumed)(s)).globalKeys + .mapValues(keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)) + .get(key) + ) + + @pure + @opaque + def getLocalActiveKeyToAlt( + s: ActiveLedgerStateOriginal[NodeId], + key: GlobalKey, + consumed: Set[ContractId], + ): Unit = { + unfold(toAlt(s)) + unfold(keyMappingToActiveMapping(toAlt(s).consumedBy)) + MapProperties.mapValuesGet( + s.localKeys, + (v: ContractId) => if (s.consumedBy.contains(v)) KeyInactive else KeyActive(v), + key, + ) + MapProperties.mapValuesGet( + toAlt(s).localKeys.mapValues(KeyActive), + keyMappingToActiveMapping(toAlt(s).consumedBy), + key, + ) + MapProperties.mapValuesGet(toAlt(s).localKeys, KeyActive, key) + }.ensuring( + s.getLocalActiveKey(key) == toAlt(s).localKeys + .mapValues(KeyActive) + .mapValues(keyMappingToActiveMapping(toAlt(s).consumedBy)) + .get(key) + ) + + @pure + @opaque + def lookupActiveKeyToAlt( + s: StateOriginal[NodeId], + key: GlobalKey, + consumed: Set[ContractId], + ): Unit = { + unfold(toAlt(consumed)(s).activeKeys) + unfold(toAlt(consumed)(s).keys) + MapProperties.mapValuesConcat( + toAlt(consumed)(s).globalKeys, + toAlt(consumed)(s).activeState.localKeys.mapValues(KeyActive), + keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy), + ) + MapProperties.equalsGet( + toAlt(consumed)(s).activeKeys, + toAlt(consumed)(s).globalKeys + .mapValues(keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)) ++ + toAlt(consumed)(s).activeState.localKeys + .mapValues(KeyActive) + .mapValues(keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)), + key, + ) + MapAxioms.concatGet( + toAlt(consumed)(s).globalKeys + .mapValues(keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)), + toAlt(consumed)(s).activeState.localKeys + .mapValues(KeyActive) + .mapValues(keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)), + key, + ) + getLocalActiveKeyToAlt(s.activeState, key, consumed) + lookupActiveGlobalKeyInputToAlt(s, key, consumed) + }.ensuring( + s.lookupActiveKey(key) == toAlt(consumed)(s).activeKeys.get(key) + ) + + @pure + @opaque + def consumeToAlt( + s: StateOriginal[NodeId], + cid: ContractId, + nid: NodeId, + consumed: Set[ContractId], + ): Unit = { + unfold(toAlt(consumed)(s).consume(cid, nid)) + }.ensuring( + toAlt(consumed + cid)(s.copy(activeState = s.activeState.consume(cid, nid))) == toAlt(consumed)( + s + ).consume(cid, nid) + ) + +// + @pure @opaque + def globalKeyInputsContainsToAlt( + s: StateOriginal[NodeId], + gk: GlobalKey, + consumed: Set[ContractId], + ): Unit = { + unfold(containsKey(toAlt(consumed)(s))(gk)) + MapProperties.equalsKeySetContains(s.globalKeyInputs, globalKeys(s.globalKeyInputs), gk) + }.ensuring(s.globalKeyInputs.contains(gk) == containsKey(toAlt(consumed)(s))(gk)) + + @pure + @opaque + def globalKeyInputsUpdatedToAlt( + s: StateOriginal[NodeId], + gk: GlobalKey, + keyInput: KeyInput, + consumed: Set[ContractId], + ): Unit = { + unfold(globalKeys(s.globalKeyInputs)) + unfold(globalKeys(s.globalKeyInputs.updated(gk, keyInput))) + MapProperties.mapValuesUpdated(s.globalKeyInputs, gk, keyInput, toKeyMapping) + MapAxioms.extensionality( + s.globalKeyInputs.updated(gk, keyInput).mapValues(toKeyMapping), + s.globalKeyInputs.mapValues(toKeyMapping).updated(gk, keyInput.toKeyMapping), + ) + }.ensuring( + globalKeys(s.globalKeyInputs.updated(gk, keyInput)) == globalKeys(s.globalKeyInputs) + .updated(gk, keyInput.toKeyMapping) + ) + + @pure + @opaque + def visitCreateToAlt( + s: StateOriginal[NodeId], + contractId: ContractId, + mbKey: Option[GlobalKey], + consumed: Set[ContractId], + ): Unit = { + require(s.mode == ContractKeyUniquenessMode.Strict) + + unfold(addKey(toAlt(consumed)(s), mbKey, KeyInactive)) + unfold(addKey(toAlt(consumed)(s), mbKey, KeyInactive).visitCreate(contractId, mbKey)) + + mbKey match { + case None() => Trivial() + case Some(gk) => + // STEP 1: conflicts match + lookupActiveKeyToAlt(s, gk, consumed) + activeKeysAddKey(toAlt(consumed)(s), gk, KeyInactive) + unfold(keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)) + + val newKeyInputs = + if (s.globalKeyInputs.contains(gk)) s.globalKeyInputs + else s.globalKeyInputs.updated(gk, KeyCreate) + + // STEP 2: globalKeys(newKeyInputs) == toAlt(s).addKey(gk, KeyInactive).globalKeyInputs + unfold(addKey(toAlt(consumed)(s), gk, KeyInactive)) + globalKeyInputsContainsToAlt(s, gk, consumed) + globalKeyInputsUpdatedToAlt(s, gk, KeyCreate, consumed) + } + + }.ensuring( + addKey(toAlt(consumed)(s), mbKey, KeyInactive).visitCreate(contractId, mbKey) == toAlt( + consumed + )(s.visitCreate(contractId, mbKey)) + ) + + @pure + def extractPair( + e: Either[Option[ + ContractId + ] => (KeyMapping, StateOriginal[NodeId]), (KeyMapping, StateOriginal[NodeId])], + result: Option[ContractId], + ): (KeyMapping, StateOriginal[NodeId]) = { + e match { + case Left(handle) => handle(result) + case Right(p) => p + } + } + + @pure + def extractState( + e: Either[Option[ + ContractId + ] => (KeyMapping, StateOriginal[NodeId]), (KeyMapping, StateOriginal[NodeId])], + result: Option[ContractId], + ): StateOriginal[NodeId] = { + extractPair(e, result)._2 + } + + @pure + @opaque + def resolveKeyToAlt( + s: StateOriginal[NodeId], + result: Option[ContractId], + gkey: GlobalKey, + consumed: Set[ContractId], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(toAlt(consumed)(s))(unbound, lc)) + + lookupActiveKeyToAlt(s, gkey, consumed) + invariantContainsKey(toAlt(consumed)(s), gkey, unbound, lc) +// globalKeyInputsContainsToAlt(s, gkey, consumed) + activeKeysAddKey(toAlt(consumed)(s), gkey, result) + unfold(addKey(toAlt(consumed)(s), gkey, result)) + unfold(toAlt(consumed)(s).activeKeys.contains) + + unfold(addKey(toAlt(consumed)(s), gkey, result)) + globalKeyInputsUpdatedToAlt( + s, + gkey, + result match { + case None() => NegativeKeyLookup + case Some(cid) => Transaction.KeyActive(cid) + }, + consumed, + ) + unfold( + toAlt(consumed)(s).activeKeys.getOrElse( + gkey, + keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)(result), + ) + ) + unfold(keyMappingToActiveMapping(toAlt(consumed)(s).activeState.consumedBy)) + + }.ensuring( + ( + addKey(toAlt(consumed)(s), gkey, result).activeKeys(gkey), + addKey(toAlt(consumed)(s), gkey, result), + ) == + (extractPair(s.resolveKey(gkey), result)._1, toAlt(consumed)( + extractState(s.resolveKey(gkey), result) + )) + ) + + @pure + @opaque + def visitLookupToAlt( + s: StateOriginal[NodeId], + gk: GlobalKey, + keyInput: Option[ContractId], + keyResolution: Option[ContractId], + consumed: Set[ContractId], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(toAlt(consumed)(s))(unbound, lc)) + unfold(addKey(toAlt(consumed)(s), gk, keyInput).visitLookup(gk, keyResolution)) + resolveKeyToAlt(s, keyInput, gk, consumed, unbound, lc) + unfold(addKey(toAlt(consumed)(s), gk, keyInput).activeKeys.getOrElse(gk, KeyInactive)) + }.ensuring( + addKey(toAlt(consumed)(s), gk, keyInput).visitLookup(gk, keyResolution) == + toAlt(consumed)(s.visitLookup(gk, keyInput, keyResolution)) + ) + + @pure + @opaque + def assertKeyMappingToAlt( + s: StateOriginal[NodeId], + cid: ContractId, + gkey: Option[GlobalKey], + consumed: Set[ContractId], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(toAlt(consumed)(s))(unbound, lc)) + unfold(addKey(toAlt(consumed)(s), gkey, Some(cid)).assertKeyMapping(cid, gkey)) + unfold(addKey(toAlt(consumed)(s), gkey, Some(cid))) + gkey match { + case None() => Trivial() + case Some(k) => visitLookupToAlt(s, k, Some(cid), Some(cid), consumed, unbound, lc) + } + + }.ensuring( + addKey(toAlt(consumed)(s), gkey, Some(cid)).assertKeyMapping(cid, gkey) == toAlt(consumed)( + s.assertKeyMapping(cid, gkey) + ) + ) + + @pure + @opaque + def visitExerciseToAlt( + s: StateOriginal[NodeId], + nodeId: NodeId, + targetId: ContractId, + gk: Option[GlobalKey], + byKey: Boolean, + consuming: Boolean, + consumed: Set[ContractId], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(toAlt(consumed)(s))(unbound, lc)) + require(s.mode == ContractKeyUniquenessMode.Strict) + unfold( + addKey(toAlt(consumed)(s), gk, Some(targetId)) + .visitExercise(nodeId, targetId, gk, byKey, consuming) + ) + + assertKeyMappingToAlt(s, targetId, gk, consumed, unbound, lc) + s.assertKeyMapping(targetId, gk) match { + case Left(e) => Trivial() + case Right(state) => consumeToAlt(state, targetId, nodeId, consumed) + } + }.ensuring( + addKey(toAlt(consumed)(s), gk, Some(targetId)) + .visitExercise(nodeId, targetId, gk, byKey, consuming) == + toAlt(if (consuming) consumed + targetId else consumed)( + s.visitExercise(nodeId, targetId, gk, byKey, consuming) + ) + ) + + @pure @opaque + def nodeConsumed(consumed: Set[ContractId], n: Node): Set[ContractId] = { + n match { + case exe: Node.Exercise if exe.consuming => consumed + exe.targetCoid + case _ => consumed + } + } + + @pure + @opaque + def handleNodeOriginal( + e: Either[KeyInputError, StateOriginal[NodeId]], + nodeId: NodeId, + node: Node.Action, + keyInput: Option[ContractId], + ): Either[KeyInputError, StateOriginal[NodeId]] = { + e match { + case Right(s) => s.handleNode(nodeId, node, keyInput) + case Left(_) => e + } + } + + @pure @opaque + def handleNodeToAlt( + s: StateOriginal[NodeId], + nodeId: NodeId, + node: Node.Action, + keyInput: Option[ContractId], + consumed: Set[ContractId], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(toAlt(consumed)(s))(unbound, lc)) + require(s.mode == ContractKeyUniquenessMode.Strict) + + unfold(nodeConsumed(consumed, node)) + unfold(addKeyBeforeAction(toAlt(consumed)(s), node)) + unfold(addKeyBeforeAction(toAlt(consumed)(s), node).handleNode(nodeId, node)) + unfold(nodeActionKeyMapping(node)) + + node match { + case create: Node.Create => visitCreateToAlt(s, create.coid, create.gkeyOpt, consumed) + case fetch: Node.Fetch => + assertKeyMappingToAlt(s, fetch.coid, fetch.gkeyOpt, consumed, unbound, lc) + case lookup: Node.LookupByKey => + unfold(addKey(toAlt(consumed)(s), node.gkeyOpt, nodeActionKeyMapping(node))) + visitLookupToAlt(s, lookup.gkey, lookup.result, lookup.result, consumed, unbound, lc) + case exe: Node.Exercise => + visitExerciseToAlt( + s, + nodeId, + exe.targetCoid, + exe.gkeyOpt, + exe.byKey, + exe.consuming, + consumed, + unbound, + lc, + ) + } + }.ensuring( + addKeyBeforeAction(toAlt(consumed)(s), node).handleNode(nodeId, node) == + toAlt(nodeConsumed(consumed, node))(s.handleNode(nodeId, node, keyInput)) + ) + + @pure + @opaque + def handleNodeToAlt( + e: Either[KeyInputError, StateOriginal[NodeId]], + nodeId: NodeId, + node: Node.Action, + keyInput: Option[ContractId], + consumed: Set[ContractId], + unbound: Set[ContractId], + lc: Set[ContractId], + ): Unit = { + require(stateInvariant(toAlt(consumed)(e))(unbound, lc)) + require(e.forall(s => s.mode == ContractKeyUniquenessMode.Strict)) + + unfold(addKeyBeforeNode(toAlt(consumed)(e), node)) + unfold(handleNodeOriginal(e, nodeId, node, keyInput)) + + e match { + case Right(s) => + unfold(addKeyBeforeNode(toAlt(consumed)(s), node)) + handleNodeToAlt(s, nodeId, node, keyInput, consumed, unbound, lc) + case Left(_) => Trivial() + } + + }.ensuring( + handleNode(addKeyBeforeNode(toAlt(consumed)(e), node), nodeId, node) == + toAlt(nodeConsumed(consumed, node))(handleNodeOriginal(e, nodeId, node, keyInput)) + ) + + @pure + @opaque + def beginRollbackToAlt(s: StateOriginal[NodeId], consumed: Set[ContractId]): Unit = { + unfold(toAlt(consumed)(s).beginRollback()) + }.ensuring( + toAlt(consumed)(s).beginRollback() == + toAlt(consumed)(s.beginRollback()) + ) + + @pure @opaque + def endRollbackOriginal( + s: StateOriginal[NodeId] + ): Either[KeyInputError, StateOriginal[NodeId]] = { + if (s.withinRollbackScope) { + Right[KeyInputError, StateOriginal[NodeId]](s.endRollback()) + } else { + Left[KeyInputError, StateOriginal[NodeId]]( + Left[InconsistentContractKey, DuplicateContractKey]( + InconsistentContractKey(GlobalKey(BigInt(0))) + ) + ) + } + } + + @pure + @opaque + def endRollbackToAlt(s: StateOriginal[NodeId], consumed: Set[ContractId]): Unit = { + unfold(toAlt(consumed)(s).endRollback()) + unfold(endRollbackOriginal(s)) + }.ensuring( + toAlt(consumed)(s).endRollback() == + toAlt(consumed)(endRollbackOriginal(s)) + ) + + @pure + @opaque + def advanceToAlt( + s: ActiveLedgerStateOriginal[NodeId], + substate: ActiveLedgerStateOriginal[NodeId], + ): Unit = { + + unfold(toAlt(s).advance(toAlt(substate))) + }.ensuring(toAlt(s).advance(toAlt(substate)) == toAlt(s.advance(substate))) + + /** Proof of only a part of advance equivalence. The other part is not here due to a Stainless bug. + */ + @pure @opaque + def advanceToAlt( + s: StateOriginal[NodeId], + substate: StateOriginal[NodeId], + resolver: KeyResolver, + consumed1: Set[ContractId], + consumed2: Set[ContractId], + ): Unit = { + require(s.mode == ContractKeyUniquenessMode.Strict) + require(!substate.withinRollbackScope) + require(!toAlt(consumed2)(substate).withinRollbackScope) + require(toAlt(consumed1)(s).advance(toAlt(consumed2)(substate)).isRight) + require(s.advance(resolver, substate).isRight) + + unfold(globalKeys(s.globalKeyInputs)) + unfold(globalKeys(substate.globalKeyInputs)) + unfold(globalKeys(substate.globalKeyInputs ++ s.globalKeyInputs)) + + MapProperties.mapValuesConcat(substate.globalKeyInputs, s.globalKeyInputs, toKeyMapping) + MapAxioms.extensionality( + globalKeys(substate.globalKeyInputs ++ s.globalKeyInputs), + globalKeys(substate.globalKeyInputs) ++ globalKeys(s.globalKeyInputs), + ) + advanceToAlt(s.activeState, substate.activeState) + + }.ensuring( + (toAlt(consumed1)(s).advance(toAlt(consumed2)(substate)).get == + toAlt(consumed1 ++ consumed2)(s.advance(resolver, substate).get)) + ) + +} diff --git a/daml-lf/verification/tree/TransactionTree.scala b/daml-lf/verification/tree/TransactionTree.scala new file mode 100644 index 0000000000..336fa64d45 --- /dev/null +++ b/daml-lf/verification/tree/TransactionTree.scala @@ -0,0 +1,532 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package tree + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ +import utils.TreeProperties._ + +import transaction.CSMHelpers._ +import transaction.CSMEitherDef._ +import transaction.CSMEither._ +import transaction.{State} + +/** Definitions and basic properties of simplified transaction traversal. + * + * In the contract state maching, handling a node comes in two major step: + * - Adding the node's key to the global keys with its corresponding mapping + * - Processing the node + * In the simplified version of the contract state machine, this behavior is respectively split in two different + * functions [[transaction.CSMKeysPropertiesDef.addKeyBeforeNode]] and [[transaction.State.handleNode]] + * + * A key property of transaction traversal is that one can first add the key-mapping pairs of every node in the + * globalKeys and then process the transaction. The proof of this claims lies in [[TransactionTreeFull]]. + * We therefore define the processing part of a transaction assuming that the initial state already contains + * all the key mappings it needs. + */ +object TransactionTreeDef { + + /** Function called when a node is entered for the first time ([[utils.TraversalDirection.Down]]). + * - If the node is an instance of [[transaction.Node.Action]] we handle it. + * - If it is a [[transaction.Node.Rollback]] node we call [[transaction.State.beginRollback]]. + * + * Among the direct properties one can deduce we have that + * - the [[transaction.State.globalKeys]] are not modified (since we did before processing the transaction) + * - if the inital state is an error then the result is also an error + * + * @param s State before entering the node for the first time + * @param p Node and its id + * @see f,,down,, in the associated latex document + */ + @pure @opaque + def traverseInFun( + s: Either[KeyInputError, State], + p: (NodeId, Node), + ): Either[KeyInputError, State] = { + p._2 match { + case r: Node.Rollback => beginRollback(s) + case a: Node.Action => handleNode(s, p._1, a) + } + }.ensuring(res => + sameGlobalKeys(s, res) && + propagatesError(s, res) + ) + + /** Function called when a node is entered for the second time ([[utils.TraversalDirection.Up]]). + * - If the node is an instance of [[transaction.Node.Action]] nothing happens + * - If it is a [[transaction.Node.Rollback]] node we call [[transaction.State.endRollback]] + * + * Among the direct properties one can deduce we have that + * - the [[transaction.State.globalKeys]] are not modified (since we did before processing the transaction) + * - the [[transaction.State.locallyCreated]] and [[transaction.State.consumed]] sets + * are not modified as well. + * - if the inital state is an error then the result is also an error + * + * @param s State before entering the node for the second time + * @param p Node and its id + * @see f,,up,, in the associated latex document + */ + @pure @opaque + def traverseOutFun( + s: Either[KeyInputError, State], + p: (NodeId, Node), + ): Either[KeyInputError, State] = { + + val res = p._2 match { + case r: Node.Rollback => endRollback(s) + case a: Node.Action => s + } + + @pure @opaque + def traverseOutFunProperties: Unit = { + + // case Rollback + endRollbackProp(s) + + // case Action + sameGlobalKeysReflexivity(s) + sameLocallyCreatedReflexivity(s) + sameConsumedReflexivity(s) + propagatesErrorReflexivity(s) + + }.ensuring( + sameGlobalKeys(s, res) && + sameLocallyCreated(s, res) && + sameConsumed(s, res) && + propagatesError(s, res) + ) + + traverseOutFunProperties + res + + }.ensuring(res => + sameGlobalKeys(s, res) && + sameLocallyCreated(s, res) && + sameConsumed(s, res) && + propagatesError(s, res) + ) + + /** List of triples whose respective entries are: + * - The state before the i-th step of the traversal + * - The pair node id - node that is handle during the i-th step + * - The direction i.e. if that's the first or the second time we enter the node + * + * @param tr The transaction that is being processed + * @param init The initial state (containing already all the necessary key-mapping pairs in the global keys) + */ + @pure + def scanTransaction( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): List[(Either[KeyInputError, State], (NodeId, Node), TraversalDirection)] = { + tr.scan(init, traverseInFun, traverseOutFun) + } + + /** Resulting state after a transaction traversal. + * + * @param tr The transaction that is being processed + * @param init The initial state (containing already all the necessary key-mapping pairs in the global keys) + */ + @pure + def traverseTransaction( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Either[KeyInputError, State] = { + tr.traverse(init, traverseInFun, traverseOutFun) + } + +} + +object TransactionTree { + + import TransactionTreeDef._ + + /** Let i less or equal than j two integers. + * The states before the i-th and j-th step of the transaction traversal: + * - have the same [[transaction.State.globalKeys]] + * - are both erroneous if the state before the i-th step is erroneous + * + * @param tr The transaction that is being processed. + * @param init The initial state (containing already all the necessary key-mapping pairs in the global keys). + * @see The corresponding latex document for a pen and paper proof. + */ + @pure + @opaque + def scanTransactionProp( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + j: BigInt, + ): Unit = { + decreases(j) + require(0 <= i) + require(i <= j) + require(j < 2 * tr.size) + + val si = scanTransaction(tr, init)(i)._1 + + if (j == i) { + propagatesErrorReflexivity(si) + sameGlobalKeysReflexivity(si) + } else { + val sjm1 = scanTransaction(tr, init)(j - 1)._1 + val sj = scanTransaction(tr, init)(j)._1 + + scanTransactionProp(tr, init, i, j - 1) + scanIndexingState(tr, init, traverseInFun, traverseOutFun, j) + propagatesErrorTransitivity(si, sjm1, sj) + sameGlobalKeysTransitivity(si, sjm1, sj) + } + }.ensuring( + propagatesError(scanTransaction(tr, init)(i)._1, scanTransaction(tr, init)(j)._1) && + sameGlobalKeys(scanTransaction(tr, init)(i)._1, scanTransaction(tr, init)(j)._1) + ) + + /** Let i an integer. The inital state and the state before the i-th step of the transaction traversal: + * - have the same [[transaction.State.globalKeys]] + * - are both erroneous if the initial state is erroneous + * + * @param tr The transaction that is being processed. + * @param init The initial state (containing already all the necessary key-mapping pairs in the global keys). + * @see The corresponding latex document for a pen and paper proof. + */ + @pure + @opaque + def scanTransactionProp( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + require(0 <= i) + require(i < 2 * tr.size) + scanTransactionProp(tr, init, 0, i) + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + }.ensuring( + propagatesError(init, scanTransaction(tr, init)(i)._1) && + sameGlobalKeys(init, scanTransaction(tr, init)(i)._1) + ) + + /** Basic properties of transaction processing: + * - The [[transaction.State.globalKeys]] are not modified throughout the traversal + * - The [[transaction.State.rollbackStack]] is the same before and after the transaction + * - If the initial state is erroneous, so is the result of the traversal + * + * @param tr The transaction that is being processed. + * @param init The initial state (containing already all the necessary key-mapping pairs in the global keys). + * + * @note This is one of the only structural induction proof in the codebase. Whereas structural induction is + * not needed for the first and third property, it is for the second one. In fact, the property is not + * a local claim that is preserved every step. It is due to the symmetrical nature of the traversal (for + * rollbacks) and it would therefore not make sense finding a i such that the property is true at the + * i-th step but not the i+1-th one. + * Please also note that since we are working with opaqueness (almost) everywhere we will need to unfold + * many definition when proving such global properties. + * + * @see The corresponding latex document for a pen and paper proof. + * @see [[findBeginRollback]] and [[TransactionTreeChecks.traverseTransactionLC]] for other examples of + * structural induction proofs. + */ + @pure + @opaque + def traverseTransactionProp( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + decreases(tr) + unfold(tr.traverse(init, traverseInFun, traverseOutFun)) + tr match { + case Endpoint() => + sameGlobalKeysReflexivity(init) + sameStackReflexivity(init) + propagatesErrorReflexivity(init) + case ContentNode(n, sub) => + val e1 = traverseInFun(init, n) + val e2 = traverseTransaction(sub, e1) + val e3 = traverseTransaction(tr, init) + + traverseTransactionProp(sub, e1) + + // As the key property is a local one applying transitivity is sufficient + // we could have done the same with propagatesError but we need to unfold + // the latter to prove the stack equality property (which is global) + sameGlobalKeysTransitivity(init, e1, e2) + sameGlobalKeysTransitivity(init, e2, e3) + + unfold(traverseInFun(init, n)) + unfold(traverseOutFun(e2, n)) + unfold(beginRollback(init)) + unfold(endRollback(e2)) + + unfold(propagatesError(init, e1)) + unfold(propagatesError(e1, e2)) + unfold(propagatesError(e2, e3)) + unfold(propagatesError(init, e3)) + + unfold(sameStack(init, e1)) + unfold(sameStack(e1, e2)) + unfold(sameStack(e2, e3)) + unfold(sameStack(init, e3)) + + // If any of the intermediate states is not defined the result is immediate + // because of error propagation. Otherwise, we need to unfold the definition + // of beginRollback and endRollback as the result depends on the content of + // their body + (init, e1, e2) match { + case (Right(s0), Right(s1), Right(s2)) => + unfold(s0.beginRollback()) + unfold(s2.endRollback()) + + unfold(sameStack(s0, e1)) + unfold(sameStack(s1, e2)) + unfold(sameStack(s2, e3)) + unfold(sameStack(s0, e3)) + case (Right(s0), _, _) => unfold(sameStack(s0, e3)) + case _ => Trivial() + } + case ArticulationNode(l, r) => + val el = traverseTransaction(l, init) + val er = traverseTransaction(tr, init) + + traverseTransactionProp(l, init) + traverseTransactionProp(r, el) + + propagatesErrorTransitivity(init, el, er) + sameStackTransitivity(init, el, er) + sameGlobalKeysTransitivity(init, el, er) + } + }.ensuring( + sameGlobalKeys(init, traverseTransaction(tr, init)) && + sameStack(init, traverseTransaction(tr, init)) && + propagatesError(init, traverseTransaction(tr, init)) + ) + + /** If any state of the transaction traversal is erroneous, then the result of the traversal is erroneous as well + * + * @param tr The transaction that is being processed. + * @param init The initial state (containing already all the necessary key-mapping pairs in the global keys). + * @param i The number of the step during the traversal + * @see The corresponding latex document for a pen and paper proof. + */ + @pure + @opaque + def traverseTransactionDefined( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + require(0 <= i) + require(i < 2 * tr.size) + + if (tr.size > 0) { + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + } + scanTransactionProp(tr, init, i, 2 * tr.size - 1) + propagatesErrorTransitivity( + scanTransaction(tr, init)(i)._1, + scanTransaction(tr, init)(2 * tr.size - 1)._1, + traverseTransaction(tr, init), + ) + }.ensuring( + propagatesError(scanTransaction(tr, init)(i)._1, traverseTransaction(tr, init)) + ) + + /** Given a step number in which [[transaction.Node.Rollback]] is entered for the second time, returns the step + * number in which it is has been entered for the first time. Also returns a subtree with the following convenient + * properties. + * + * If the result of the function is (j, sub) then: + * - The size of sub is strictly smaller than the size of the transaction tree. + * - Due to causality, j is strictly smaller than the step number given in argument. + * - The state before entering the node for the second time is the result of traversing sub. The + * inital state of the traversal is the state after step j. + * + * This version of findBeginRollback deals with two traversal simultaneously. This is due to the fact that the result + * only depends on the shape of the transaction and is independent of the initial state. + * For a simplified version cf. below. + * + * Please note that the claim is valid only if the tree is unique. + * + * @param tr The transaction that is being processed. + * @param init1 The initial state of the first traversal. + * @param init2 The initial state of the second traversal. + * @param i The step number during which the node is entered for the second time. + * + * @note This is one of the only structural induction proof in the codebase. Because the global nature of the property, + * it is not possible to prove a local claim that is preserved during every step. This is due to the symmetry + * of the traversal (for rollbacks) and it would therefore not make sense finding a i such that the + * property is true at the i-th step but not the i+1-th one. + * + * @see The corresponding latex document for a pen and paper proof. + * @see [[traverseTransactionProp]] and [[TransactionTreeChecks.traverseTransactionLC]] for other examples of + * structural induction proofs. + */ + @pure + @opaque + def findBeginRollback( + tr: Tree[(NodeId, Node)], + init1: Either[KeyInputError, State], + init2: Either[KeyInputError, State], + i: BigInt, + ): (BigInt, Tree[(NodeId, Node)]) = { + decreases(tr) + require(tr.isUnique) + require(i >= 0) + require(i < 2 * tr.size) + require(scanTransaction(tr, init1)(i)._2._2.isInstanceOf[Node.Rollback]) + require(scanTransaction(tr, init1)(i)._3 == TraversalDirection.Up) + + unfold(tr.size) + unfold(tr.isUnique) + + scanIndexingNode( + tr, + init1, + init2, + traverseInFun, + traverseOutFun, + traverseInFun, + traverseOutFun, + i, + ) + + tr match { + case Endpoint() => Unreachable() + case ContentNode(c, str) => + scanIndexing(c, str, init1, traverseInFun, traverseOutFun, i) + scanIndexing(c, str, init2, traverseInFun, traverseOutFun, i) + + if (c == scanTransaction(tr, init1)(i)._2) { + scanIndexing(c, str, init1, traverseInFun, traverseOutFun, 0) + scanIndexing(c, str, init2, traverseInFun, traverseOutFun, 0) + scanIndexing(c, str, init1, traverseInFun, traverseOutFun, 2 * tr.size - 1) + scanIndexing(c, str, init2, traverseInFun, traverseOutFun, 2 * tr.size - 1) + isUniqueIndexing(tr, init1, traverseInFun, traverseOutFun, i, 2 * tr.size - 1) + isUniqueIndexing(tr, init2, traverseInFun, traverseOutFun, i, 2 * tr.size - 1) + + unfold(traverseInFun(init1, c)) + unfold(traverseInFun(init2, c)) + unfold(traverseOutFun(init1, c)) + unfold(traverseOutFun(init2, c)) + + (BigInt(0), str) + } else { + val (j, sub) = + findBeginRollback(str, traverseInFun(init1, c), traverseInFun(init2, c), i - 1) + scanIndexing(c, str, init1, traverseInFun, traverseOutFun, j + 1) + scanIndexing(c, str, init2, traverseInFun, traverseOutFun, j + 1) + + (j + 1, sub) + } + case ArticulationNode(l, r) => + scanIndexing(l, r, init1, traverseInFun, traverseOutFun, i) + scanIndexing(l, r, init2, traverseInFun, traverseOutFun, i) + + if (i < 2 * l.size) { + val (j, sub) = findBeginRollback(l, init1, init2, i) + scanIndexing(l, r, init1, traverseInFun, traverseOutFun, j) + scanIndexing(l, r, init2, traverseInFun, traverseOutFun, j) + + // Required for performance (timeout: 2) + assert( + scanTransaction(tr, init1)(i)._1 == traverseTransaction( + sub, + beginRollback(scanTransaction(tr, init1)(j)._1), + ) + ) + assert( + scanTransaction(tr, init2)(i)._1 == traverseTransaction( + sub, + beginRollback(scanTransaction(tr, init2)(j)._1), + ) + ) + + (j, sub) + } else { + val (j, sub) = findBeginRollback( + r, + traverseTransaction(l, init1), + traverseTransaction(l, init2), + i - 2 * l.size, + ) + scanIndexing(l, r, init1, traverseInFun, traverseOutFun, j + 2 * l.size) + scanIndexing(l, r, init2, traverseInFun, traverseOutFun, j + 2 * l.size) + + (j + 2 * l.size, sub) + } + } + }.ensuring((j, sub) => + 0 <= j && j < i && sub.size < tr.size && + (scanTransaction(tr, init1)(j)._2 == scanTransaction(tr, init1)(i)._2) && + (scanTransaction(tr, init2)(j)._2 == scanTransaction(tr, init2)(i)._2) && + (scanTransaction(tr, init1)(j)._3 == TraversalDirection.Down) && + (scanTransaction(tr, init2)(j)._3 == TraversalDirection.Down) && + (scanTransaction(tr, init1)(i)._1 == traverseTransaction( + sub, + beginRollback(scanTransaction(tr, init1)(j)._1), + )) && + (scanTransaction(tr, init2)(i)._1 == traverseTransaction( + sub, + beginRollback(scanTransaction(tr, init2)(j)._1), + )) + ) + + /** Given a step number in which [[transaction.Node.Rollback]] is entered for the second time, returns the step + * number in which it is has been entered for the first time. Also returns a subtree with the following convenient + * properties. + * + * If the result of the function is (j, sub) then: + * - The size of sub is strictly smaller than the size of the transaction tree. + * - Due to causality, j is strictly smaller than the step number given in argument. + * - The state before entering the node for the second time is the result of traversing sub. The + * inital state of the traversal is the state after step j. + * + * This results only depends on the shape of the transaction and is independent of the initial state. + * For a version dealing with multiple transaction at the same time cf. the more complex version above. + * + * Please note that the claim is valid only if the tree is unique. + * + * @param tr The transaction that is being processed. + * @param init The initial state of the traversal. + * @param i The step number during which the node is entered for the second time. + * @see The corresponding latex document for a pen and paper proof. + */ + @pure + @opaque + def findBeginRollback( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): (BigInt, Tree[(NodeId, Node)]) = { + require(tr.isUnique) + require(i >= 0) + require(i < 2 * tr.size) + require(scanTransaction(tr, init)(i)._2._2.isInstanceOf[Node.Rollback]) + require(scanTransaction(tr, init)(i)._3 == TraversalDirection.Up) + + findBeginRollback(tr, init, init, i) + }.ensuring((j, sub) => + 0 <= j && j < i && sub.size < tr.size && + (scanTransaction(tr, init)(j)._2 == scanTransaction(tr, init)(i)._2) && + (scanTransaction(tr, init)(j)._3 == TraversalDirection.Down) && + (scanTransaction(tr, init)(i)._1 == traverseTransaction( + sub, + beginRollback(scanTransaction(tr, init)(j)._1), + )) + ) +} diff --git a/daml-lf/verification/tree/TransactionTreeAdvance.scala b/daml-lf/verification/tree/TransactionTreeAdvance.scala new file mode 100644 index 0000000000..f92ecd91d8 --- /dev/null +++ b/daml-lf/verification/tree/TransactionTreeAdvance.scala @@ -0,0 +1,480 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package tree + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ +import utils.TreeProperties._ + +import transaction.{State} +import transaction.CSMHelpers._ +import transaction.CSMAdvanceDef._ +import transaction.CSMAdvance._ +import transaction.CSMEitherDef._ +import transaction.ContractStateMachine._ + +import TransactionTreeDef._ +import TransactionTree._ + +/** File stating how the activeState evolve while processing a transaction. We want to prove that the activeState of the + * resulting state of a traversal can be expressed in terms of the inital state active state and the advance method. + */ +object TransactionTreeAdvanceDef { + + /** Function called when a node is entered for the first time ([[utils.TraversalDirection.Down]]). + * Compute the active state and the rollbackStack of a state after processing a node given only the rollbackStack + * and the active state of it behorehand. + * + * @param s rollbackStack and activeState of the state before the node + * @param p Node and its id + */ + @pure + @opaque + def activeStateInFun( + s: (ActiveLedgerState, List[ActiveLedgerState]), + p: (NodeId, Node), + ): (ActiveLedgerState, List[ActiveLedgerState]) = { + p._2 match { + case a: Node.Action => (s._1.advance(actionActiveStateAddition(p._1, a)), s._2) + case r: Node.Rollback => (s._1, s._1 :: s._2) + } + } + + /** Function called when a node is entered for the second time ([[utils.TraversalDirection.Up]]). + * Compute the active state and the rollbackStack of a state after processing a node given only the rollbackStack + * and the active state of it behorehand. + * + * @param s rollbackStack and activeState of the state before the node + * @param p Node and its id + */ + @pure + @opaque + def activeStateOutFun( + s: (ActiveLedgerState, List[ActiveLedgerState]), + p: (NodeId, Node), + ): (ActiveLedgerState, List[ActiveLedgerState]) = { + p._2 match { + case a: Node.Action => s + case r: Node.Rollback => + s._2 match { + case Nil() => s + case Cons(h, t) => (h, t) + } + } + } + + /** List of triples whose respective entries are: + * - The activeState and the rollbackStack before the i-th step of the traversal + * - The pair node id - node that is handle during the i-th step + * - The direction i.e. if that's the first or the second time we enter the node + * + * @param tr The transaction that is being processed + * @param init The activeState and teh rollbackStack of the initial state of the traversal + */ + @pure + def scanActiveState( + tr: Tree[(NodeId, Node)], + init: (ActiveLedgerState, List[ActiveLedgerState]), + ): List[((ActiveLedgerState, List[ActiveLedgerState]), (NodeId, Node), TraversalDirection)] = { + tr.scan(init, activeStateInFun, activeStateOutFun) + } + + /** [[scanActiveState]] where the inital state is empty + */ + @pure + def scanActiveState( + tr: Tree[(NodeId, Node)] + ): List[((ActiveLedgerState, List[ActiveLedgerState]), (NodeId, Node), TraversalDirection)] = { + scanActiveState(tr, (ActiveLedgerState.empty, Nil[ActiveLedgerState]())) + } + + /** Computes the activeState and the rollbackStack of the state obtained after processing a transaction. + * + * @param tr The transaction that is being processed + * @param init The activeState and the rollbackStack of the initial state of the traversal + */ + @pure + def traverseActiveState( + tr: Tree[(NodeId, Node)], + init: (ActiveLedgerState, List[ActiveLedgerState]), + ): (ActiveLedgerState, List[ActiveLedgerState]) = { + tr.traverse(init, activeStateInFun, activeStateOutFun) + } + + /** [[traverseActiveState]] where the inital state is empty + */ + @pure + def traverseActiveState( + tr: Tree[(NodeId, Node)] + ): (ActiveLedgerState, List[ActiveLedgerState]) = { + traverseActiveState(tr, (ActiveLedgerState.empty, Nil[ActiveLedgerState]())) + } +} + +object TransactionTreeAdvance { + + import TransactionTreeAdvanceDef._ + + /** The rollbackStack of the initial and final state of a transaction traversal are equal + * + * @param tr The transaction that is being processed + * @param init The initial state of the traversal + * + * @see [[TransactionTree.traverseTransactionProp]] for an alternative proof of the same concept. + */ + @pure + @opaque + def traverseActiveStateSameStack( + tr: Tree[(NodeId, Node)], + init: (ActiveLedgerState, List[ActiveLedgerState]), + ): Unit = { + decreases(tr) + unfold(tr.traverse(init, activeStateInFun, activeStateOutFun)) + tr match { + case Endpoint() => Trivial() + case ContentNode(n, sub) => + val a1 = activeStateInFun(init, n) + traverseActiveStateSameStack(sub, a1) + unfold(activeStateInFun(init, n)) + unfold(activeStateOutFun(traverseActiveState(sub, a1), n)) + case ArticulationNode(l, r) => + val al = traverseActiveState(l, init) + traverseActiveStateSameStack(l, init) + traverseActiveStateSameStack(r, al) + } + }.ensuring( + init._2 == traverseActiveState(tr, init)._2 + ) + + /** The rollbackStacks of the state before entering a node for the first time and after having entered it for the second + * time are the same. + * + * @param tr The transaction that is being processed. + * @param init The activeState and the rollbackStack of the initial state of the transaction + * @param i The step number during which the node is entered for the first time. + * @param j The step number during which the node is entered for the second time. + * + * @note This is one of the only structural induction proof in the codebase. Because the global nature of the property, + * it is not possible to prove a local claim that is preserved during every step. This is due to the symmetry + * of the traversal (for rollbacks) and it would therefore not make sense finding a i such that the + * property is true at the i-th step but not the i+1-th one. + * @see [[TransactionTree.findBeginRollback]] for an alternative proof of the same concept. + */ + @pure + @opaque + def findBeginRollbackActiveState( + tr: Tree[(NodeId, Node)], + init: (ActiveLedgerState, List[ActiveLedgerState]), + i: BigInt, + j: BigInt, + ): Unit = { + decreases(tr) + + require(tr.isUnique) + require(i >= 0) + require(i <= j) + require(j < 2 * tr.size) + require(scanActiveState(tr, init)(i)._2 == scanActiveState(tr, init)(j)._2) + require(scanActiveState(tr, init)(i)._2._2.isInstanceOf[Node.Rollback]) + require(scanActiveState(tr, init)(i)._3 == TraversalDirection.Down) + require(scanActiveState(tr, init)(j)._3 == TraversalDirection.Up) + + unfold(tr.size) + unfold(tr.isUnique) + + tr match { + case Endpoint() => Unreachable() + case ContentNode(c, str) => + scanIndexing(c, str, init, activeStateInFun, activeStateOutFun, i) + scanIndexing(c, str, init, activeStateInFun, activeStateOutFun, j) + + if ((i == 0) || (j == 2 * tr.size - 1)) { + scanIndexing(c, str, init, activeStateInFun, activeStateOutFun, 0) + scanIndexing(c, str, init, activeStateInFun, activeStateOutFun, 2 * tr.size - 1) + isUniqueIndexing(tr, init, activeStateInFun, activeStateOutFun, 0, i) + isUniqueIndexing(tr, init, activeStateInFun, activeStateOutFun, j, 2 * tr.size - 1) + traverseActiveStateSameStack(str, activeStateInFun(init, c)) + unfold(activeStateInFun(init, c)) + } else { + findBeginRollbackActiveState(str, activeStateInFun(init, c), i - 1, j - 1) + } + case ArticulationNode(l, r) => + scanIndexing(l, r, init, activeStateInFun, activeStateOutFun, i) + scanIndexing(l, r, init, activeStateInFun, activeStateOutFun, j) + + if (j < 2 * l.size) { + findBeginRollbackActiveState(l, init, i, j) + } else if (i >= 2 * l.size) { + findBeginRollbackActiveState( + r, + l.traverse(init, activeStateInFun, activeStateOutFun), + i - 2 * l.size, + j - 2 * l.size, + ) + } else { + scanContains(l, init, activeStateInFun, activeStateOutFun, i) + scanContains( + r, + l.traverse(init, activeStateInFun, activeStateOutFun), + activeStateInFun, + activeStateOutFun, + j - 2 * l.size, + ) + SetProperties.disjointContains( + l.content, + r.content, + tr.scan(init, activeStateInFun, activeStateOutFun)(i)._2, + ) + SetProperties.disjointContains( + l.content, + r.content, + tr.scan(init, activeStateInFun, activeStateOutFun)(j)._2, + ) + Unreachable() + } + } + }.ensuring( + scanActiveState(tr, init)(i)._1._1 :: scanActiveState(tr, init)(i)._1._2 == scanActiveState( + tr, + init, + )(j)._1._2 + ) + + /** For any step in a transaction traversal, [[scanActiveState]] computes the activeState of the intermediate state + * of that step. + * + * @param tr The transaction that is being processed. + * @param init The initial state of the traversal. + * @param i The step number. + */ + @pure + @opaque + def scanActiveStateAdvance( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + decreases(i) + + require(i >= 0) + require(i < 2 * tr.size) + require(init.isRight) + require(scanTransaction(tr, init)(i)._1.isRight) + require(tr.isUnique) + + scanIndexingState(tr, init, traverseInFun, traverseOutFun, i) + scanIndexingState( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + activeStateInFun, + activeStateOutFun, + i, + ) + + if (i == 0) { + emptyAdvance(init.get.activeState) + } else { + scanTransactionProp(tr, init, i - 1, i) + unfold(propagatesError(scanTransaction(tr, init)(i - 1)._1, scanTransaction(tr, init)(i)._1)) + scanActiveStateAdvance(tr, init, i - 1) + scanIndexingNode( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + init, + activeStateInFun, + activeStateOutFun, + traverseInFun, + traverseOutFun, + i - 1, + ) + + val (si, n, dir) = scanTransaction(tr, init)(i - 1) + val ai = scanActiveState(tr)(i - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(traverseInFun(si, n)) + unfold(activeStateInFun(ai, n)) + + n._2 match { + case a: Node.Action => + unfold(handleNode(si, n._1, a)) + handleNodeActiveState(si.get, n._1, a) + advanceAssociativity(init.get.activeState, ai._1, actionActiveStateAddition(n._1, a)) + case r: Node.Rollback => + unfold(beginRollback(si)) + unfold(si.get.beginRollback()) + } + } else { + unfold(traverseOutFun(si, n)) + unfold(activeStateOutFun(ai, n)) + n._2 match { + case a: Node.Action => Trivial() + case r: Node.Rollback => + val (j, sub) = findBeginRollback(tr, init, i - 1) + val sj = scanTransaction(tr, init)(j)._1 + + traverseTransactionProp(sub, beginRollback(sj)) + scanTransactionProp(tr, init, j, i) + unfold(propagatesError(sj, scanTransaction(tr, init)(i)._1)) + + unfold(sameStack(beginRollback(sj), si)) + unfold(sameStack(sj.get.beginRollback(), si)) + unfold(endRollback(si)) + unfold(si.get.endRollback()) + unfold(beginRollback(sj)) + unfold(sj.get.beginRollback()) + + scanIndexingNode( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + init, + activeStateInFun, + activeStateOutFun, + traverseInFun, + traverseOutFun, + j, + ) + + findBeginRollbackActiveState( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + j, + i - 1, + ) + scanActiveStateAdvance(tr, init, j) + } + } + } + + }.ensuring( + (scanTransaction(tr, init)(i)._1.get.activeState == + init.get.activeState.advance(scanActiveState(tr)(i)._1._1)) + ) + + /** [[traverseActiveState]] computes the activeState of the resulting state of transaction traversal. + * + * @param tr The transaction that is being processed. + * @param init The initial state of the traversal. + */ + @pure + @opaque + def traverseActiveStateAdvance( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + + require(init.isRight) + require(traverseTransaction(tr, init).isRight) + require(tr.isUnique) + + if (tr.size == 0) { + emptyAdvance(init.get.activeState) + } else { + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + scanIndexingState( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + activeStateInFun, + activeStateOutFun, + 0, + ) + + traverseTransactionDefined(tr, init, 2 * tr.size - 1) + unfold( + propagatesError( + scanTransaction(tr, init)(2 * tr.size - 1)._1, + traverseTransaction(tr, init), + ) + ) + scanActiveStateAdvance(tr, init, 2 * tr.size - 1) + scanIndexingNode( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + init, + activeStateInFun, + activeStateOutFun, + traverseInFun, + traverseOutFun, + 2 * tr.size - 1, + ) + + val (si, n, dir) = scanTransaction(tr, init)(2 * tr.size - 1) + val ai = scanActiveState(tr)(2 * tr.size - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(traverseInFun(si, n)) + unfold(activeStateInFun(ai, n)) + + n._2 match { + case a: Node.Action => + unfold(handleNode(si, n._1, a)) + handleNodeActiveState(si.get, n._1, a) + advanceAssociativity(init.get.activeState, ai._1, actionActiveStateAddition(n._1, a)) + case r: Node.Rollback => + unfold(beginRollback(si)) + unfold(si.get.beginRollback()) + } + } else { + unfold(traverseOutFun(si, n)) + unfold(activeStateOutFun(ai, n)) + n._2 match { + case a: Node.Action => Trivial() + case r: Node.Rollback => + val (j, sub) = findBeginRollback(tr, init, 2 * tr.size - 1) + val sj = scanTransaction(tr, init)(j)._1 + + traverseTransactionProp(sub, beginRollback(sj)) + traverseTransactionDefined(tr, init, j) + unfold(propagatesError(sj, traverseTransaction(tr, init))) + + unfold(sameStack(beginRollback(sj), si)) + unfold(sameStack(sj.get.beginRollback(), si)) + unfold(endRollback(si)) + unfold(si.get.endRollback()) + unfold(beginRollback(sj)) + unfold(sj.get.beginRollback()) + + scanIndexingNode( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + init, + activeStateInFun, + activeStateOutFun, + traverseInFun, + traverseOutFun, + j, + ) + + findBeginRollbackActiveState( + tr, + (ActiveLedgerState.empty, Nil[ActiveLedgerState]()), + j, + 2 * tr.size - 1, + ) + scanActiveStateAdvance(tr, init, j) + } + } + } + + }.ensuring( + (traverseTransaction(tr, init).get.activeState == + init.get.activeState.advance(traverseActiveState(tr)._1)) + ) + +} diff --git a/daml-lf/verification/tree/TransactionTreeChecks.scala b/daml-lf/verification/tree/TransactionTreeChecks.scala new file mode 100644 index 0000000000..e74332fab9 --- /dev/null +++ b/daml-lf/verification/tree/TransactionTreeChecks.scala @@ -0,0 +1,1175 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package tree + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ +import utils.TreeProperties._ + +import transaction.CSMHelpers._ +import transaction.CSMEither._ +import transaction.CSMLocallyCreatedProperties._ +import transaction.{State} +import transaction.CSMEitherDef._ + +import TransactionTreeDef._ +import TransactionTree._ + +/** This file introduces two traversals whose purpose is to check whether the transaction is well-formed. + * + * The first one, [[TransactionTreeChecksDef.traverseUnbound]] checks whether for every contract in the + * transaction is either bound to a key every time it appears in a node (the key should be the same but + * this is not checked here yet) or it is not bound to any key every time it appears inside a node. + * + * The first one, [[TransactionTreeChecksDef.traverseLC]] checks whether for every contract in the + * transaction is created only once and it has not been consumed before. By doing so, we collect all the + * locally created and consumed contracts. We show later that the collected sets are actually the + * [[State.locallyCreated]] and [[State.consumed]] fields of the state obtained after processing a + * transaction. This allows us to isolate the two fields and manipulate them more easily. + */ +object TransactionTreeChecksDef { + + /** --------------------------------------------------------------------------------------------------------------- * + * -------------------------------------------------UNBOUND-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** Function used the first time we visit a [[Node]] during the [[Tree]] traversal that gathers all the unbound + * contracts i.e. those that are not associated to any key. + * + * @param s The state of the traversal before processing the node. It is a triple whose entries are: + * - The unbound contracts + * - The contracts not bound to any key + * - A boolean checking that no contract belongs to the two sets at the same time. In fact a tree where a + * contract appears both in a node with a well-defined key and at the same time in a node with no key, is + * not valid + * @param p The node we are currently processing + * @return The triple updated after processing the node. In particular checks that the new entry is not already + * contained in the second set if it is added to the first one and vice-versa. + */ + @pure + @opaque + def unboundFun( + s: (Set[ContractId], Set[ContractId], Boolean), + p: (NodeId, Node), + ): (Set[ContractId], Set[ContractId], Boolean) = { + SetProperties.subsetOfReflexivity(s._1) + SetProperties.subsetOfReflexivity(s._2) + p._2 match { + case Node.Create(coid, None()) => (s._1 + coid, s._2, s._3 && !s._2.contains(coid)) + case Node.Create(coid, _) => (s._1, s._2 + coid, s._3 && !s._1.contains(coid)) + case Node.Fetch(coid, None(), _) => (s._1 + coid, s._2, s._3 && !s._2.contains(coid)) + case Node.Fetch(coid, _, _) => (s._1, s._2 + coid, s._3 && !s._1.contains(coid)) + case Node.Exercise(targetCoid, _, _, None(), _) => + (s._1 + targetCoid, s._2, s._3 && !s._2.contains(targetCoid)) + case Node.Exercise(targetCoid, _, _, _, _) => + (s._1, s._2 + targetCoid, s._3 && !s._1.contains(targetCoid)) + case Node.LookupByKey(_, Some(coid)) => (s._1, s._2 + coid, s._3 && !s._1.contains(coid)) + case _ => s + } + }.ensuring(res => + s._1.subsetOf(res._1) && + s._2.subsetOf(res._2) && + (res._3 ==> s._3) + ) + + /** List of triples whose respective entries are: + * - The unbound, bound contracts and whether an error has been found before the i-th step of the traversal. + * - The pair node id - node that is handle during the i-th step + * - The direction i.e. if that's the first or the second time we enter the node + * + * @param tr The transaction that is being processed + */ + @pure + def scanUnbound( + tr: Tree[(NodeId, Node)] + ): List[((Set[ContractId], Set[ContractId], Boolean), (NodeId, Node), TraversalDirection)] = { + tr.scan((Set.empty[ContractId], Set.empty[ContractId], true), unboundFun, (z, t) => z) + } + + /** Set of unbound and bound contracts in the tree. A bound contract is a contract that has a global key assigned to + * it.The third entry of the triple indicates whether an error has been found while traversing the tree. + * If an error arises (the boolean is false) it is due to the fact that a contract appeared in a node that did not + * have a key and at the same time in a node with a well-defined key. + * + * In fact both sets should be disjoint as we prove in [[TransactionTreeChecks.traverseUnboundProp]]. + * + * @param tr The transaction that is being processed + */ + @pure + def traverseUnbound(tr: Tree[(NodeId, Node)]): (Set[ContractId], Set[ContractId], Boolean) = { + tr.traverse((Set.empty[ContractId], Set.empty[ContractId], true), unboundFun, (z, t) => z) + } + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------LC---------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** Function used the first time we visit a [[Node]] during the [[Tree]] traversal that gathers all the created and + * consumed contract in the tree. + * + * @param s The state of the traversal before processing the node. It is a triple whose entries are: + * - The created contracts in the tree until that point + * - The consumed contracts in the tree until that point + * - A boolean checking that a contract is not created twice or that a consumed contract is not created again. + * @param p The node we are currently processing + * @return The triple updated after processing the node. + */ + @pure + @opaque + def buildLC( + s: (Set[ContractId], Set[ContractId], Boolean), + p: (NodeId, Node), + ): (Set[ContractId], Set[ContractId], Boolean) = { + SetProperties.subsetOfReflexivity(s._1) + SetProperties.subsetOfReflexivity(s._2) + p._2 match { + case Node.Create(coid, _) => + (s._1 + coid, s._2, s._3 && !s._1.contains(coid) && !s._2.contains(coid)) + case exe: Node.Exercise if exe.consuming => + (s._1, s._2 + exe.targetCoid, s._3) + case _ => s + } + }.ensuring(res => s._1.subsetOf(res._1) && s._2.subsetOf(res._2) && (!s._3 ==> !res._3)) + + /** List of triples whose respective entries are: + * - The locally created, consumed contracts and whether an error has been found before the i-th step of the traversal. + * - The pair node id - node that is handle during the i-th step + * - The direction i.e. if that's the first or the second time we enter the node + * + * @param tr The transaction that is being processed + * @param lc The inital set of locally created contracts. Should be empty by default. + * @param consumed The inital set of consumed contracts. Should be empty by default. + * @param defined Indicates whether there is already an error or not. Should be true by default (i.e. no error) + */ + @pure + def scanLC( + tr: Tree[(NodeId, Node)], + lc: Set[ContractId], + consumed: Set[ContractId], + defined: Boolean, + ): List[((Set[ContractId], Set[ContractId], Boolean), (NodeId, Node), TraversalDirection)] = { + tr.scan((lc, consumed, defined), buildLC, (z, t) => z) + } + + /** Set of locally created and consumed contracts in the transaction. The third entry of the triple indicates whether + * an error has been found while traversing the tree. If an error arises (the boolean is false) it can either be due + * to: + * - A contract has been created twice with the same id + * - A contract has been created after it was consumed + * + * @param tr The transaction that is being processed + * @param lc The inital set of locally created contracts. Should be empty by default. + * @param consumed The inital set of consumed contracts. Should be empty by default. + * @param defined Indicates whether there is already an error or not. Should be true by default (i.e. no error) + */ + @pure + def traverseLC( + tr: Tree[(NodeId, Node)], + lc: Set[ContractId], + consumed: Set[ContractId], + defined: Boolean, + ): (Set[ContractId], Set[ContractId], Boolean) = { + tr.traverse((lc, consumed, defined), buildLC, (z, t) => z) + } + +} + +object TransactionTreeChecks { + + import TransactionTreeChecksDef._ + + /** --------------------------------------------------------------------------------------------------------------- * + * -------------------------------------------------UNBOUND-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** Properties between intermediate states of the unbound contracts traversal and the final one + * - Any intermediate state of the set of unbound contracts is a subset of the final state, i.e. the set of unbound + * contracts of the transaction. + * - The same claim holds for bound contracts. + * - If there is no error in the end, then there is no error during any intermediate step. + * - For every step of the traversal, if an error did not arise, the set of bound and unbound contracts is disjoint + * + * The proof goes by contradiction. + * + * @param tr The tree that is being traversed + * @param i The number of the step we are looking at in the traversal + */ + @pure + @opaque + def scanTraverseUnboundProp(tr: Tree[(NodeId, Node)], i: BigInt): Unit = { + require(i < 2 * tr.size) + require(0 <= i) + + val init = (Set.empty[ContractId], Set.empty[ContractId], true) + val (tr1, tr2, tr3) = traverseUnbound(tr) + + // proving the first three claims by contradiction + val p1: ((Set[ContractId], Set[ContractId], Boolean)) => Boolean = (s1, s2, b) => + s1.subsetOf(tr1) && + s2.subsetOf(tr2) && + (tr3 ==> b) + + SetProperties.subsetOfReflexivity(tr1) + SetProperties.subsetOfReflexivity(tr2) + + if (!p1(scanUnbound(tr)(i)._1)) { + val j = scanNotPropRev(tr, init, unboundFun, (z, t) => z, p1, i) + val (sj1, sj2, _) = scanUnbound(tr)(j)._1 + + SetProperties.subsetOfReflexivity(sj1) + SetProperties.subsetOfReflexivity(sj2) + if (j == 2 * tr.size - 1) { + scanIndexingState(tr, init, unboundFun, (z, t) => z, 0) + } else { + scanIndexingState(tr, init, unboundFun, (z, t) => z, j + 1) + SetProperties.subsetOfTransitivity(sj1, scanUnbound(tr)(j + 1)._1._1, tr1) + SetProperties.subsetOfTransitivity(sj2, scanUnbound(tr)(j + 1)._1._2, tr2) + } + } + + // proving the last claim by contradiction as well + val p2: ((Set[ContractId], Set[ContractId], Boolean)) => Boolean = + (s1, s2, b) => b ==> s1.disjoint(s2) + + SetProperties.disjointEmpty(Set.empty[ContractId]) + if (!p2(scanUnbound(tr)(i)._1)) { + val j = scanNotProp(tr, init, unboundFun, (z, t) => z, p2, i) + scanIndexingState(tr, init, unboundFun, (z, t) => z, j + 1) + val (s, p, _) = scanUnbound(tr)(j) + unfold(unboundFun(s, p)) + + if (!s._3) { + Trivial() + } else { + p._2 match { + case Node.Create(coid, None()) => + SetProperties.disjointSymmetry(s._2, s._1) + SetProperties.disjointIncl(s._2, s._1, coid) + SetProperties.disjointSymmetry(s._2, s._1 + coid) + case Node.Create(coid, _) => + SetProperties.disjointIncl(s._1, s._2, coid) + case Node.Fetch(coid, None(), _) => + SetProperties.disjointSymmetry(s._2, s._1) + SetProperties.disjointIncl(s._2, s._1, coid) + SetProperties.disjointSymmetry(s._2, s._1 + coid) + case Node.Fetch(coid, _, _) => + SetProperties.disjointIncl(s._1, s._2, coid) + case Node.Exercise(targetCoid, _, _, None(), _) => + SetProperties.disjointSymmetry(s._2, s._1) + SetProperties.disjointIncl(s._2, s._1, targetCoid) + SetProperties.disjointSymmetry(s._2, s._1 + targetCoid) + case Node.Exercise(targetCoid, _, _, _, _) => + SetProperties.disjointIncl(s._1, s._2, targetCoid) + case Node.LookupByKey(_, Some(coid)) => + SetProperties.disjointIncl(s._1, s._2, coid) + case _ => Trivial() + } + } + } + + }.ensuring( + scanUnbound(tr)(i)._1._1.subsetOf(traverseUnbound(tr)._1) && + scanUnbound(tr)(i)._1._2.subsetOf(traverseUnbound(tr)._2) && + (traverseUnbound(tr)._3 ==> scanUnbound(tr)(i)._1._3) && + (scanUnbound(tr)(i)._1._3 ==> scanUnbound(tr)(i)._1._1.disjoint(scanUnbound(tr)(i)._1._2)) + ) + + /** If at the end of the unbound contracts traversal no error did arise, then the set of unbound and bound contracts + * are disjoint. + * + * @param tr The tree that is being traversed + */ + @pure + @opaque + def traverseUnboundProp(tr: Tree[(NodeId, Node)]): Unit = { + require(traverseUnbound(tr)._3) + + if (tr.size > 0) { + scanTraverseUnboundProp(tr, 2 * tr.size - 1) + scanIndexingState( + tr, + (Set.empty[ContractId], Set.empty[ContractId], true), + unboundFun, + (z, t) => z, + 0, + ) + } else { + SetProperties.disjointEmpty(Set.empty[ContractId]) + } + + }.ensuring(traverseUnbound(tr)._1.disjoint(traverseUnbound(tr)._2)) + + /** When a node is processed for the first time, then its contract is contained in the set of unbound contracts if and + * only if the node has no key. Similarly, the contract belongs to the bound contracts if and only if the key of the + * node is well-defined. + * + * The claim is valid if no error arose in the traversal. + * + * @param tr The tree that is being traversed. + * @param i The step during which the node is processed. + */ + @pure + @opaque + def scanTraverseUnboundPropDown(tr: Tree[(NodeId, Node)], i: BigInt): Unit = { + require(i < 2 * tr.size) + require(0 <= i) + require(traverseUnbound(tr)._3) + require(scanUnbound(tr)(i)._3 == TraversalDirection.Down) + + if (i == 2 * tr.size - 1) { + Unreachable() + } else { + scanIndexingState(tr, (Set.empty, Set.empty, true), unboundFun, (z, t) => z, i + 1) + + val (s, p, _) = scanUnbound(tr)(i) + val (spu, spb, _) = scanUnbound(tr)(i + 1)._1 + val (tru, trb, _) = traverseUnbound(tr) + + unfold(unboundFun(s, p)) + + scanTraverseUnboundProp(tr, i + 1) + traverseUnboundProp(tr) + + p._2 match { + case Node.Create(coid, None()) => + SetProperties.subsetOfContains(spu, tru, coid) + SetProperties.disjointContains(tru, trb, coid) + case Node.Create(coid, _) => + SetProperties.subsetOfContains(spb, trb, coid) + SetProperties.disjointContains(tru, trb, coid) + case Node.Exercise(targetCoid, _, _, None(), _) => + SetProperties.subsetOfContains(spu, tru, targetCoid) + SetProperties.disjointContains(tru, trb, targetCoid) + case Node.Exercise(targetCoid, _, _, _, _) => + SetProperties.subsetOfContains(spb, trb, targetCoid) + SetProperties.disjointContains(tru, trb, targetCoid) + case _ => Trivial() + } + } + }.ensuring( + scanUnbound(tr)(i)._2._2 match { + case Node.Create(coid, opt) => + (traverseUnbound(tr)._1.contains(coid) == !opt.isDefined) && (traverseUnbound(tr)._2 + .contains(coid) == opt.isDefined) + case exe: Node.Exercise => + (traverseUnbound(tr)._1 + .contains(exe.targetCoid) == !exe.gkeyOpt.isDefined) && (traverseUnbound(tr)._2 + .contains(exe.targetCoid) == exe.gkeyOpt.isDefined) + case _ => true + } + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * -------------------------------------------------LC------------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** Properties between intermediate states of the locally created/consumed contracts traversal and the final one + * - Any intermediate state of the set of locally created contracts is a subset of the final state, i.e. the set of + * locally created contracts of the transaction. + * - The same claim holds for consumed contracts. + * - If there is no error in the end, then there is no error during any intermediate step. + * + * The proof goes by contradiction. + * + * @param tr The tree that is being traversed + * @param lc The inital set of locally created contracts. Should be empty by default. + * @param consumed The inital set of consumed contracts. Should be empty by default. + * @param defined Indicates whether there is already an error or not. Should be true by default (i.e. no error) + * @param i The number of the step we are looking at in the traversal + */ + @pure + @opaque + def scanTraverseLCProp( + tr: Tree[(NodeId, Node)], + lc: Set[ContractId], + consumed: Set[ContractId], + defined: Boolean, + i: BigInt, + ): Unit = { + decreases(2 * tr.size - i) + require(i < 2 * tr.size) + require(0 <= i) + + val p: ((Set[ContractId], Set[ContractId], Boolean)) => Boolean = (s1, s2, b) => + s1.subsetOf(traverseLC(tr, lc, consumed, defined)._1) && + s2.subsetOf(traverseLC(tr, lc, consumed, defined)._2) && + (traverseLC(tr, lc, consumed, defined)._3 ==> b) + + SetProperties.subsetOfReflexivity(tr.traverse((lc, consumed, defined), buildLC, (z, t) => z)._1) + SetProperties.subsetOfReflexivity(tr.traverse((lc, consumed, defined), buildLC, (z, t) => z)._2) + + if (!p(tr.scan((lc, consumed, defined), buildLC, (z, t) => z)(i)._1)) { + val j = scanNotPropRev(tr, (lc, consumed, defined), buildLC, (z, t) => z, p, i) + SetProperties.subsetOfReflexivity( + tr.scan((lc, consumed, defined), buildLC, (z, t) => z)(j)._1._1 + ) + SetProperties.subsetOfReflexivity( + tr.scan((lc, consumed, defined), buildLC, (z, t) => z)(j)._1._2 + ) + if (j == 2 * tr.size - 1) { + scanIndexingState(tr, (lc, consumed, defined), buildLC, (z, t) => z, 0) + } else { + scanIndexingState(tr, (lc, consumed, defined), buildLC, (z, t) => z, j + 1) + SetProperties.subsetOfTransitivity( + scanLC(tr, lc, consumed, defined)(j)._1._1, + scanLC(tr, lc, consumed, defined)(j + 1)._1._1, + traverseLC(tr, lc, consumed, defined)._1, + ) + SetProperties.subsetOfTransitivity( + scanLC(tr, lc, consumed, defined)(j)._1._2, + scanLC(tr, lc, consumed, defined)(j + 1)._1._2, + traverseLC(tr, lc, consumed, defined)._2, + ) + } + } + + }.ensuring( + scanLC(tr, lc, consumed, defined)(i)._1._1.subsetOf(traverseLC(tr, lc, consumed, defined)._1) && + scanLC(tr, lc, consumed, defined)(i)._1._2 + .subsetOf(traverseLC(tr, lc, consumed, defined)._2) && + (traverseLC(tr, lc, consumed, defined)._3 ==> scanLC(tr, lc, consumed, defined)(i)._1._3) + ) + + /** If a contract is locally created it belongs to the set of locally created contracts at the end of the traversal. + * The same holds when a contract is consumed. Furthermore when a contract is created, the set of locally created + * before reaching this step does not contain it. + * + * The claim holds obviously only when the node is visited for the first time and if no error arose during the + * traversal. + * + * @param tr The tree that is being traversed. + * @param lc The inital set of locally created contracts. Should be empty by default. + * @param consumed The inital set of consumed contracts. Should be empty by default. + * @param defined Indicates whether there is already an error or not. Should be true by default (i.e. no error). + * @param i The step during which the create or exercise node is processed. + */ + @pure + @opaque + def scanTraverseLCPropDown( + tr: Tree[(NodeId, Node)], + lc: Set[ContractId], + consumed: Set[ContractId], + defined: Boolean, + i: BigInt, + ): Unit = { + require(i < 2 * tr.size) + require(0 <= i) + require(traverseLC(tr, lc, consumed, defined)._3) + require(scanLC(tr, lc, consumed, defined)(i)._3 == TraversalDirection.Down) + + if (i == 2 * tr.size - 1) { + Unreachable() + } else { + scanIndexingState(tr, (lc, consumed, defined), buildLC, (z, t) => z, i + 1) + scanTraverseLCProp(tr, lc, consumed, defined, i + 1) + scanTraverseLCProp(tr, lc, consumed, defined, i) + + val (s, p, _) = scanLC(tr, lc, consumed, defined)(i) + + unfold(buildLC(s, p)) + p._2 match { + case Node.Create(coid, _) => + SetProperties.subsetOfContains( + scanLC(tr, lc, consumed, defined)(i + 1)._1._1, + traverseLC(tr, lc, consumed, defined)._1, + coid, + ) + case exe: Node.Exercise if exe.consuming => + SetProperties.subsetOfContains( + scanLC(tr, lc, consumed, defined)(i + 1)._1._2, + traverseLC(tr, lc, consumed, defined)._2, + exe.targetCoid, + ) + case _ => Trivial() + } + } + }.ensuring( + scanLC(tr, lc, consumed, defined)(i)._2._2 match { + case Node.Create(coid, mbKey) => + traverseLC(tr, lc, consumed, defined)._1.contains(coid) && + !scanLC(tr, lc, consumed, defined)(i)._1._1.contains(coid) && + !scanLC(tr, lc, consumed, defined)(i)._1._2.contains(coid) + case exe: Node.Exercise if exe.consuming => + traverseLC(tr, lc, consumed, defined)._2.contains(exe.targetCoid) + case _ => true + } + ) + + /** If the initial parameters of the traversal are subset of or is implied the one of another traversal, then every + * intermediate state of the first traversal will be a subset of or will by implied the same intermediate state of + * the second traversal. + * + * @param tr The tree that is being traversed. + * @param lc1 The inital set of locally created contracts of the first traversal. + * @param lc2 The inital set of locally created contracts of the second traversal. + * @param consumed1 The inital set of consumed contracts of the first traversal. + * @param consumed2 The inital set of consumed contracts of the second traversal. + * @param defined1 Indicates whether there is already an error or not. + * @param defined2 Indicates whether there is already an error or not. + * @param i The step number of the intermediate states we are looking at. + */ + @pure + @opaque + def scanLCSubsetOf( + tr: Tree[(NodeId, Node)], + lc1: Set[ContractId], + consumed1: Set[ContractId], + defined1: Boolean, + lc2: Set[ContractId], + consumed2: Set[ContractId], + defined2: Boolean, + i: BigInt, + ): Unit = { + decreases(i) + require(i < 2 * tr.size) + require(0 <= i) + require(lc1.subsetOf(lc2)) + require(consumed1.subsetOf(consumed2)) + require(defined2 ==> defined1) + + scanIndexingState(tr, (lc1, consumed1, defined1), buildLC, (z, t) => z, i) + scanIndexingState(tr, (lc2, consumed2, defined2), buildLC, (z, t) => z, i) + + if (i == 0) { + Trivial() + } else { + scanLCSubsetOf(tr, lc1, consumed1, defined1, lc2, consumed2, defined2, i - 1) + val (s11, s12, b13) = scanLC(tr, lc1, consumed1, defined1)(i - 1)._1 + val (s21, s22, b23) = scanLC(tr, lc2, consumed2, defined2)(i - 1)._1 + val n = scanLC(tr, lc1, consumed1, defined1)(i - 1)._2 + + scanIndexingNode( + tr, + (lc1, consumed1, defined1), + (lc2, consumed2, defined2), + buildLC, + (z, t) => z, + buildLC, + (z, t) => z, + i - 1, + ) + unfold(buildLC((s11, s12, b13), n)) + unfold(buildLC((s21, s22, b23), n)) + + n._2 match { + case Node.Create(coid, _) => + SetProperties.subsetOfIncl(s11, s21, coid) + if (s11.contains(coid)) { + SetProperties.subsetOfContains(s11, s21, coid) + } else if (s12.contains(coid)) { + SetProperties.subsetOfContains(s12, s22, coid) + } + case exe: Node.Exercise if exe.consuming => + SetProperties.subsetOfIncl(s12, s22, exe.targetCoid) + case _ => Trivial() + } + } + + }.ensuring( + scanLC(tr, lc1, consumed1, defined1)(i)._1._1 + .subsetOf(scanLC(tr, lc2, consumed2, defined2)(i)._1._1) && + scanLC(tr, lc1, consumed1, defined1)(i)._1._2 + .subsetOf(scanLC(tr, lc2, consumed2, defined2)(i)._1._2) && + (scanLC(tr, lc2, consumed2, defined2)(i)._1._3 ==> scanLC(tr, lc1, consumed1, defined1)( + i + )._1._3) + ) + + /** If the initial parameters of the traversal are subset of or is implied the one of another traversal, then the + * resulting state of the first traversal will be a subset of or will by implied the resulting state of the second + * traversal. + * + * @param tr The tree that is being traversed. + * @param lc1 The inital set of locally created contracts of the first traversal. + * @param lc2 The inital set of locally created contracts of the second traversal. + * @param consumed1 The inital set of consumed contracts of the first traversal. + * @param consumed2 The inital set of consumed contracts of the second traversal. + * @param defined1 Indicates whether there is already an error or not. + * @param defined2 Indicates whether there is already an error or not. + */ + @pure + @opaque + def traverseLCSubsetOf( + tr: Tree[(NodeId, Node)], + lc1: Set[ContractId], + consumed1: Set[ContractId], + defined1: Boolean, + lc2: Set[ContractId], + consumed2: Set[ContractId], + defined2: Boolean, + ): Unit = { + require(lc1.subsetOf(lc2)) + require(consumed1.subsetOf(consumed2)) + require(defined2 ==> defined1) + + if (tr.size == 0) { + Trivial() + } else { + + scanIndexingState(tr, (lc1, consumed1, defined1), buildLC, (z, t) => z, 0) + scanIndexingState(tr, (lc2, consumed2, defined2), buildLC, (z, t) => z, 0) + + scanLCSubsetOf(tr, lc1, consumed1, defined1, lc2, consumed2, defined2, 2 * tr.size - 1) + val (s11, s12, b13) = scanLC(tr, lc1, consumed1, defined1)(2 * tr.size - 1)._1 + val (s21, s22, b23) = scanLC(tr, lc2, consumed2, defined2)(2 * tr.size - 1)._1 + val n = scanLC(tr, lc1, consumed1, defined1)(2 * tr.size - 1)._2 + + scanIndexingNode( + tr, + (lc1, consumed1, defined1), + (lc2, consumed2, defined2), + buildLC, + (z, t) => z, + buildLC, + (z, t) => z, + 2 * tr.size - 1, + ) + unfold(buildLC((s11, s12, b13), n)) + unfold(buildLC((s21, s22, b23), n)) + + n._2 match { + case Node.Create(coid, _) => + SetProperties.subsetOfIncl(s11, s21, coid) + if (s11.contains(coid)) { + SetProperties.subsetOfContains(s11, s21, coid) + } else if (s12.contains(coid)) { + SetProperties.subsetOfContains(s12, s22, coid) + } + case exe: Node.Exercise if exe.consuming => + SetProperties.subsetOfIncl(s12, s22, exe.targetCoid) + case _ => Trivial() + } + + } + + }.ensuring( + traverseLC(tr, lc1, consumed1, defined1)._1 + .subsetOf(traverseLC(tr, lc2, consumed2, defined2)._1) && + traverseLC(tr, lc1, consumed1, defined1)._2.subsetOf( + traverseLC(tr, lc2, consumed2, defined2)._2 + ) && + (traverseLC(tr, lc2, consumed2, defined2)._3 ==> traverseLC(tr, lc1, consumed1, defined1)._3) + ) + + /** Any intermediate set of locally created contracts in the traversal is independent of the initial set. That is + * we can extract the set of the initial state, run the traversal with the empty set and afterward computing the union + * between the result and the inital set. + * + * @param tr The tree that is being traversed. + * @param lc The inital set of locally created contracts of the original traversal. + * @param consumed1 The inital set of consumed contracts of the original traversal. + * @param consumed2 The inital set of consumed contracts of the traversal with an empty locally created contracts set. + * @param defined Indicates whether there is already an error or not. + * @param i The step number of the intermediate set we are looking at. + */ + def scanLCExtractInitLC( + tr: Tree[(NodeId, Node)], + lc: Set[ContractId], + consumed1: Set[ContractId], + consumed2: Set[ContractId], + defined: Boolean, + i: BigInt, + ): Unit = { + decreases(i) + require(i < 2 * tr.size) + require(0 <= i) + + scanIndexingState(tr, (lc, consumed1, defined), buildLC, (z, t) => z, i) + scanIndexingState(tr, (Set.empty[ContractId], consumed2, defined), buildLC, (z, t) => z, i) + + if (i == 0) { + SetProperties.unionEmpty(lc) + } else { + scanLCExtractInitLC(tr, lc, consumed1, consumed2, defined, i - 1) + scanIndexingNode( + tr, + (lc, consumed1, defined), + (Set.empty[ContractId], consumed2, defined), + buildLC, + (z, t) => z, + buildLC, + (z, t) => z, + i - 1, + ) + + val (si, n, dir) = scanLC(tr, lc, consumed1, defined)(i - 1) + val se = scanLC(tr, Set.empty[ContractId], consumed2, defined)(i - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(buildLC(si, n)) + unfold(buildLC(se, n)) + + n._2 match { + case Node.Create(coid, _) => + SetProperties.equalsIncl(se._1 ++ lc, si._1, coid) + unfold((se._1 ++ lc).incl(coid)) + SetProperties.unionAssociativity(se._1, lc, Set(coid)) + SetProperties.unionCommutativity(lc, Set(coid)) + SetProperties.unionAssociativity(se._1, Set(coid), lc) + unfold(si._1.incl(coid)) + unfold(se._1.incl(coid)) + SetProperties.unionEqualsRight(se._1, Set(coid) ++ lc, lc ++ Set(coid)) + SetProperties.equalsTransitivity( + (se._1 ++ Set(coid)) ++ lc, + se._1 ++ (Set(coid) ++ lc), + se._1 ++ (lc ++ Set(coid)), + ) + SetProperties.equalsTransitivity( + (se._1 ++ Set(coid)) ++ lc, + se._1 ++ (lc ++ Set(coid)), + (se._1 ++ lc) ++ Set(coid), + ) + SetProperties.equalsTransitivity( + (se._1 ++ Set(coid)) ++ lc, + (se._1 ++ lc) ++ Set(coid), + si._1 ++ Set(coid), + ) + case _ => Trivial() + } + } + + } + + }.ensuring( + scanLC(tr, Set.empty[ContractId], consumed2, defined)(i)._1._1 ++ lc === + scanLC(tr, lc, consumed1, defined)(i)._1._1 + ) + + /** The set of locally created contracts obtained at the end of the traversal is independent of the initial set. That is + * we can extract the set of the initial state, run the traversal with the empty set and afterward computing the union + * between the result and the inital set. + * + * @param tr The tree that is being traversed. + * @param lc The inital set of locally created contracts of the original traversal. + * @param consumed1 The inital set of consumed contracts of the original traversal. + * @param consumed2 The inital set of consumed contracts of the traversal with an empty locally created contracts set. + * @param defined Indicates whether there is already an error or not. + */ + def traverseLCExtractInitLC( + tr: Tree[(NodeId, Node)], + lc: Set[ContractId], + consumed1: Set[ContractId], + consumed2: Set[ContractId], + defined: Boolean, + ): Unit = { + + if (tr.size > 0) { + + scanIndexingState(tr, (lc, consumed1, defined), buildLC, (z, t) => z, 0) + scanIndexingState(tr, (Set.empty[ContractId], consumed2, defined), buildLC, (z, t) => z, 0) + + scanLCExtractInitLC(tr, lc, consumed1, consumed2, defined, 2 * tr.size - 1) + scanIndexingNode( + tr, + (lc, consumed1, defined), + (Set.empty[ContractId], consumed2, defined), + buildLC, + (z, t) => z, + buildLC, + (z, t) => z, + 2 * tr.size - 1, + ) + + val (si, n, dir) = scanLC(tr, lc, consumed1, defined)(2 * tr.size - 1) + val se = scanLC(tr, Set.empty[ContractId], consumed2, defined)(2 * tr.size - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(buildLC(si, n)) + unfold(buildLC(se, n)) + + n._2 match { + case Node.Create(coid, _) => + SetProperties.equalsIncl(se._1 ++ lc, si._1, coid) + unfold((se._1 ++ lc).incl(coid)) + SetProperties.unionAssociativity(se._1, lc, Set(coid)) + SetProperties.unionCommutativity(lc, Set(coid)) + SetProperties.unionAssociativity(se._1, Set(coid), lc) + unfold(si._1.incl(coid)) + unfold(se._1.incl(coid)) + SetProperties.unionEqualsRight(se._1, Set(coid) ++ lc, lc ++ Set(coid)) + SetProperties.equalsTransitivity( + (se._1 ++ Set(coid)) ++ lc, + se._1 ++ (Set(coid) ++ lc), + se._1 ++ (lc ++ Set(coid)), + ) + SetProperties.equalsTransitivity( + (se._1 ++ Set(coid)) ++ lc, + se._1 ++ (lc ++ Set(coid)), + (se._1 ++ lc) ++ Set(coid), + ) + SetProperties.equalsTransitivity( + (se._1 ++ Set(coid)) ++ lc, + (se._1 ++ lc) ++ Set(coid), + si._1 ++ Set(coid), + ) + case _ => Trivial() + } + } + + } else { + SetProperties.unionEmpty(lc) + } + + }.ensuring( + traverseLC(tr, Set.empty[ContractId], consumed2, defined)._1 ++ lc === + traverseLC(tr, lc, consumed1, defined)._1 + ) + + /** Any intermediate set of consumed contracts in the traversal is independent of the initial set. That is + * we can extract the set of the initial state, run the traversal with the empty set and afterward computing the union + * between the result and the inital set. + * + * @param tr The tree that is being traversed. + * @param lc1 The inital set of locally created contracts of the original traversal. + * @param lc2 The inital set of locally created contracts of the traversal with an empty consumed contracts set. + * @param consumed The inital set of consumed contracts of the original traversal. + * @param defined Indicates whether there is already an error or not. + * @param i The step number of the intermediate set we are looking at. + */ + def scanLCExtractInitConsumed( + tr: Tree[(NodeId, Node)], + lc1: Set[ContractId], + lc2: Set[ContractId], + consumed: Set[ContractId], + defined: Boolean, + i: BigInt, + ): Unit = { + decreases(i) + require(i < 2 * tr.size) + require(0 <= i) + + scanIndexingState(tr, (lc1, consumed, defined), buildLC, (z, t) => z, i) + scanIndexingState(tr, (lc2, Set.empty[ContractId], defined), buildLC, (z, t) => z, i) + + if (i == 0) { + SetProperties.unionEmpty(consumed) + } else { + scanLCExtractInitConsumed(tr, lc1, lc2, consumed, defined, i - 1) + scanIndexingNode( + tr, + (lc1, consumed, defined), + (lc2, Set.empty[ContractId], defined), + buildLC, + (z, t) => z, + buildLC, + (z, t) => z, + i - 1, + ) + + val (si, n, dir) = scanLC(tr, lc1, consumed, defined)(i - 1) + val se = scanLC(tr, lc2, Set.empty[ContractId], defined)(i - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(buildLC(si, n)) + unfold(buildLC(se, n)) + + n._2 match { + case exe: Node.Exercise if exe.consuming => + SetProperties.equalsIncl(se._2 ++ consumed, si._2, exe.targetCoid) + unfold((se._2 ++ consumed).incl(exe.targetCoid)) + SetProperties.unionAssociativity(se._2, consumed, Set(exe.targetCoid)) + SetProperties.unionCommutativity(consumed, Set(exe.targetCoid)) + SetProperties.unionAssociativity(se._2, Set(exe.targetCoid), consumed) + unfold(si._2.incl(exe.targetCoid)) + unfold(se._2.incl(exe.targetCoid)) + SetProperties.unionEqualsRight( + se._2, + Set(exe.targetCoid) ++ consumed, + consumed ++ Set(exe.targetCoid), + ) + SetProperties.equalsTransitivity( + (se._2 ++ Set(exe.targetCoid)) ++ consumed, + se._2 ++ (Set(exe.targetCoid) ++ consumed), + se._2 ++ (consumed ++ Set(exe.targetCoid)), + ) + SetProperties.equalsTransitivity( + (se._2 ++ Set(exe.targetCoid)) ++ consumed, + se._2 ++ (consumed ++ Set(exe.targetCoid)), + (se._2 ++ consumed) ++ Set(exe.targetCoid), + ) + SetProperties.equalsTransitivity( + (se._2 ++ Set(exe.targetCoid)) ++ consumed, + (se._2 ++ consumed) ++ Set(exe.targetCoid), + si._2 ++ Set(exe.targetCoid), + ) + case _ => Trivial() + } + } + + } + + }.ensuring( + scanLC(tr, lc2, Set.empty[ContractId], defined)(i)._1._2 ++ consumed === + scanLC(tr, lc1, consumed, defined)(i)._1._2 + ) + + /** The set of consumed contracts obtained at the end of the traversal is independent of the initial set. That is + * we can extract the set of the initial state, run the traversal with the empty set and afterward computing the union + * between the result and the inital set. + * + * @param tr The tree that is being traversed. + * @param lc1 The inital set of locally created contracts of the original traversal. + * @param lc2 The inital set of locally created contracts of the traversal with an empty consumed contracts set. + * @param consumed The inital set of consumed contracts of the original traversal. + * @param defined Indicates whether there is already an error or not. + */ + def traverseLCExtractInitConsumed( + tr: Tree[(NodeId, Node)], + lc1: Set[ContractId], + lc2: Set[ContractId], + consumed: Set[ContractId], + defined: Boolean, + ): Unit = { + + if (tr.size > 0) { + + scanIndexingState(tr, (lc1, consumed, defined), buildLC, (z, t) => z, 0) + scanIndexingState(tr, (lc2, Set.empty[ContractId], defined), buildLC, (z, t) => z, 0) + + scanLCExtractInitConsumed(tr, lc1, lc2, consumed, defined, 2 * tr.size - 1) + scanIndexingNode( + tr, + (lc1, consumed, defined), + (lc2, Set.empty[ContractId], defined), + buildLC, + (z, t) => z, + buildLC, + (z, t) => z, + 2 * tr.size - 1, + ) + + val (si, n, dir) = scanLC(tr, lc1, consumed, defined)(2 * tr.size - 1) + val se = scanLC(tr, lc2, Set.empty[ContractId], defined)(2 * tr.size - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(buildLC(si, n)) + unfold(buildLC(se, n)) + + n._2 match { + case exe: Node.Exercise if exe.consuming => + SetProperties.equalsIncl(se._2 ++ consumed, si._2, exe.targetCoid) + unfold((se._2 ++ consumed).incl(exe.targetCoid)) + SetProperties.unionAssociativity(se._2, consumed, Set(exe.targetCoid)) + SetProperties.unionCommutativity(consumed, Set(exe.targetCoid)) + SetProperties.unionAssociativity(se._2, Set(exe.targetCoid), consumed) + unfold(si._2.incl(exe.targetCoid)) + unfold(se._2.incl(exe.targetCoid)) + SetProperties.unionEqualsRight( + se._2, + Set(exe.targetCoid) ++ consumed, + consumed ++ Set(exe.targetCoid), + ) + SetProperties.equalsTransitivity( + (se._2 ++ Set(exe.targetCoid)) ++ consumed, + se._2 ++ (Set(exe.targetCoid) ++ consumed), + se._2 ++ (consumed ++ Set(exe.targetCoid)), + ) + SetProperties.equalsTransitivity( + (se._2 ++ Set(exe.targetCoid)) ++ consumed, + se._2 ++ (consumed ++ Set(exe.targetCoid)), + (se._2 ++ consumed) ++ Set(exe.targetCoid), + ) + SetProperties.equalsTransitivity( + (se._2 ++ Set(exe.targetCoid)) ++ consumed, + (se._2 ++ consumed) ++ Set(exe.targetCoid), + si._2 ++ Set(exe.targetCoid), + ) + case _ => Trivial() + } + } + } else { + SetProperties.unionEmpty(consumed) + } + + }.ensuring( + traverseLC(tr, lc2, Set.empty[ContractId], defined)._2 ++ consumed === + traverseLC(tr, lc1, consumed, defined)._2 + ) + + /** Key theorem making the link between the locally created/consumed traversal and the classical transaction traversal. + * + * At any point in both traversals, the set of locally created and consumed contracts in the tree until that point + * are equal to the locallyCreated and consumed fields of the state obtained while processing the transaction until + * that point. In that case the starting sets of the former traversal are the fields of the initial state of the latter. + * + * @param tr The transaction that is being processed + * @param init The initial state of the classical transaction traversal + * @param defined If there is already an error when the checks are performed. Should be true by default. + * @param i The step of the traversal we are looking at. + */ + @pure + @opaque + def scanTransactionLC( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + defined: Boolean, + i: BigInt, + ): Unit = { + + decreases(i) + + require(0 <= i) + require(i < 2 * tr.size) + require(init.isRight) + require(scanTransaction(tr, init)(i)._1.isRight) + + val initTr = (init.get.locallyCreated, init.get.consumed, defined) + + scanIndexingState(tr, init, traverseInFun, traverseOutFun, i) + scanIndexingState(tr, initTr, buildLC, (z, t) => z, i) + + if (i == 0) { + Trivial() + } else { + + val (si, n, dir) = scanTransaction(tr, init)(i - 1) + val lci = scanLC(tr, init.get.locallyCreated, init.get.consumed, defined)(i - 1)._1 + + scanTransactionProp(tr, init, i - 1, i) + unfold(propagatesError(si, scanTransaction(tr, init)(i)._1)) + scanTransactionLC(tr, init, defined, i - 1) + scanIndexingNode(tr, init, initTr, traverseInFun, traverseOutFun, buildLC, (z, t) => z, i - 1) + + if (dir == TraversalDirection.Down) { + unfold(buildLC(lci, n)) + unfold(traverseInFun(si, n)) + val snext = traverseInFun(si, n) + + n._2 match { + case a: Node.Action => + handleNodeLocallyCreated(si, n._1, a) + handleNodeConsumed(si, n._1, a) + case _ => + unfold(sameLocallyCreated(si, snext)) + unfold(sameLocallyCreated(si.get, snext)) + unfold(sameConsumed(si, snext)) + unfold(sameConsumed(si.get, snext)) + } + } else { + val snext = traverseOutFun(si, n) + unfold(sameLocallyCreated(si, snext)) + unfold(sameLocallyCreated(si.get, snext)) + unfold(sameConsumed(si, snext)) + unfold(sameConsumed(si.get, snext)) + } + } + }.ensuring( + (scanTransaction(tr, init)(i)._1.get.locallyCreated == scanLC( + tr, + init.get.locallyCreated, + init.get.consumed, + defined, + )(i)._1._1) && + (scanTransaction(tr, init)(i)._1.get.consumed == scanLC( + tr, + init.get.locallyCreated, + init.get.consumed, + defined, + )(i)._1._2) + ) + + /** Key theorem making the link between the locally created/consumed traversal and the classical transaction traversal. + * + * The set of locally created and consumed contracts in the tree are equal to the locallyCreated and consumed fields + * of the state obtained after processing the transaction. + * In that case the starting sets of the former traversal are the fields of the initial state of the latter. + * + * @param tr The transaction that is being processed + * @param init The initial state of the classical transaction traversal + * @param defined If there is already an error when the checks are performed. Should be true by default. + */ + @pure + @opaque + def traverseTransactionLC( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + defined: Boolean, + ): Unit = { + + require(init.isRight) + require(traverseTransaction(tr, init).isRight) + + if (tr.size == 0) { + Trivial() + } else { + val initTr = (init.get.locallyCreated, init.get.consumed, defined) + + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + scanIndexingState(tr, initTr, buildLC, (z, t) => z, 0) + + val (si, n, dir) = scanTransaction(tr, init)(2 * tr.size - 1) + val lci = scanLC(tr, init.get.locallyCreated, init.get.consumed, defined)(2 * tr.size - 1)._1 + + traverseTransactionDefined(tr, init, 2 * tr.size - 1) + unfold(propagatesError(si, traverseTransaction(tr, init))) + scanTransactionLC(tr, init, defined, 2 * tr.size - 1) + scanIndexingNode( + tr, + init, + initTr, + traverseInFun, + traverseOutFun, + buildLC, + (z, t) => z, + 2 * tr.size - 1, + ) + + if (dir == TraversalDirection.Down) { + unfold(buildLC(lci, n)) + unfold(traverseInFun(si, n)) + val snext = traverseInFun(si, n) + + n._2 match { + case a: Node.Action => + handleNodeLocallyCreated(si, n._1, a) + handleNodeConsumed(si, n._1, a) + case _ => + unfold(sameLocallyCreated(si, snext)) + unfold(sameLocallyCreated(si.get, snext)) + unfold(sameConsumed(si, snext)) + unfold(sameConsumed(si.get, snext)) + } + } else { + val snext = traverseOutFun(si, n) + unfold(sameLocallyCreated(si, snext)) + unfold(sameLocallyCreated(si.get, snext)) + unfold(sameConsumed(si, snext)) + unfold(sameConsumed(si.get, snext)) + } + } + }.ensuring( + (traverseTransaction(tr, init).get.locallyCreated == traverseLC( + tr, + init.get.locallyCreated, + init.get.consumed, + defined, + )._1) && + (traverseTransaction(tr, init).get.consumed == traverseLC( + tr, + init.get.locallyCreated, + init.get.consumed, + defined, + )._2) + ) + +} diff --git a/daml-lf/verification/tree/TransactionTreeFull.scala b/daml-lf/verification/tree/TransactionTreeFull.scala new file mode 100644 index 0000000000..b2f790f2dc --- /dev/null +++ b/daml-lf/verification/tree/TransactionTreeFull.scala @@ -0,0 +1,1004 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package tree + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ +import utils.TreeProperties._ + +import transaction.CSMHelpers._ + +import transaction.CSMEitherDef._ +import transaction.CSMEither._ + +import transaction.CSMKeysPropertiesDef._ +import transaction.CSMKeysProperties._ + +import transaction.CSMAdvance._ +import transaction.CSMInvariantDef._ + +import transaction.{State} +import transaction.ContractStateMachine.{KeyMapping} + +import TransactionTree._ +import TransactionTreeDef._ +import TransactionTreeKeys._ +import TransactionTreeKeysDef._ +import TransactionTreeChecksDef._ +import TransactionTreeChecks._ +import TransactionTreeInconsistency._ +import TransactionTreeAdvance._ +import TransactionTreeAdvanceDef._ + +/** In the contract state maching, handling a node comes in two major step: + * - Adding the node's key to the global keys with its corresponding mapping + * - Processing the node + * In the simplified version of the contract state machine, this behavior is respectively split in two different + * functions [[transaction.CSMKeysPropertiesDef.addKeyBeforeNode]] and [[transaction.State.handleNode]] + * + * A key property of transaction traversal is that one can first add the key-mapping pairs of every node in the + * globalKeys and then process the transaction. The proof of this claims lies in [[TransactionTreeFull.scanTransactionFullCommute]]. + */ + +object TransactionTreeFullDef { + + /** Function called when a node is entered for the first time ([[utils.TraversalDirection.Down]]). + * - If the node is an instance of [[transaction.Node.Action]] we add its key and its corresponding + * mapping. Then, we handle it. + * - If it is a [[transaction.Node.Rollback]] node we call [[transaction.State.beginRollback]]. + * + * Among the direct properties one can deduce we have that + * - If the state already contains the node's key, then the function behaves in the same way that + * [[TransactionTreeDef.traverseInFun]] + * - if the inital state is an error then the result is also an error + * + * @param s State before entering the node for the first time + * @param p Node and its id + */ + @pure + @opaque + def traverseInFunFull( + s: Either[KeyInputError, State], + p: (NodeId, Node), + ): Either[KeyInputError, State] = { + val res = traverseInFun(addKeyBeforeNode(s, p), p) + propagatesErrorTransitivity(s, addKeyBeforeNode(s, p), res) + res + }.ensuring(res => + propagatesError(s, res) && + (containsKey(s)(p._2) ==> (res == traverseInFun(s, p))) + ) + + /** Tree traversal that adds all the necessary keys in the tree to the initial state. + * Returns a list of triples whose entries are: + * - The intermediate states of the traversal + * - The nodes that are being processed + * - The traversal directions (i.e. if the nodes are visited for the first or the second time) + */ + @pure + def scanAddKey( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): List[(Either[KeyInputError, State], (NodeId, Node), TraversalDirection)] = { + tr.scan(init, addKeyBeforeNode, (z, t) => z) + } + + /** Tree traversal that adds all the necessary keys in the tree to the initial state. + */ + @pure + def traverseAddKey( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Either[KeyInputError, State] = { + tr.traverse(init, addKeyBeforeNode, (z, t) => z) + } + + /** List of triples whose respective entries are: + * - The state before the i-th step of the traversal + * - The pair node id - node that is handle during the i-th step + * - The direction i.e. if that's the first or the second time we enter the node + * + * @param tr The transaction that is being processed + * @param init The initial state of the transaction + */ + @pure + def scanTransactionFull( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): List[(Either[KeyInputError, State], (NodeId, Node), TraversalDirection)] = { + tr.scan(init, traverseInFunFull, traverseOutFun) + } + + /** Resulting state after a transaction traversal. + * + * @param tr The transaction that is being processed + * @param init The initial state of the transaction + */ + @pure + def traverseTransactionFull( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Either[KeyInputError, State] = { + tr.traverse(init, traverseInFunFull, traverseOutFun) + } +} + +object TransactionTreeFull { + + import TransactionTreeFullDef._ + + /** Adding all the keys to a state up to a given point in a tree traversal is equivalent to collecting all of them + * and concatenating them to the global keys of the state. + * + * @param tr The tree that is being traversed + * @param init The state to which we are adding the keys + * @param i The point up to which we gather the keys + */ + @pure @opaque + def scanAddKeyIsConcatLeftGlobalKey( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, i) + scanIndexingState(tr, init, addKeyBeforeNode, (z, t) => z, i) + if (i == 0) { + unfold(concatLeftGlobalKeys(init, collectTrace(tr)(i)._1)) + if (init.isRight) { + unfold(concatLeftGlobalKeys(init.get, collectTrace(tr)(i)._1)) + MapProperties.concatEmpty(init.get.globalKeys) + MapAxioms.extensionality( + init.get.globalKeys, + Map.empty[GlobalKey, KeyMapping] ++ init.get.globalKeys, + ) + } + } else { + scanAddKeyIsConcatLeftGlobalKey(tr, init, i - 1) + scanIndexingNode( + tr, + Map.empty[GlobalKey, KeyMapping], + init, + collectFun, + (z, t) => z, + addKeyBeforeNode, + (z, t) => z, + i - 1, + ) + val (si, n, dir) = scanAddKey(tr, init)(i - 1) + val ci = collectTrace(tr)(i - 1)._1 + + if (dir == TraversalDirection.Down) { + collectFunConcat(ci, n) + concatLeftGlobalKeysAssociativity(init, ci, nodeKeyMap(n._2)) + } + } + }.ensuring(scanAddKey(tr, init)(i)._1 == concatLeftGlobalKeys(init, collectTrace(tr)(i)._1)) + + /** Adding all the keys of a tree to a state is equivalent to collecting all of them + * and concatenating them to the global keys of the state. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + */ + @pure + @opaque + def traverseAddKeyIsConcatLeftGlobalKey( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + + if (tr.size > 0) { + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, 0) + scanIndexingState(tr, init, addKeyBeforeNode, (z, t) => z, 0) + scanAddKeyIsConcatLeftGlobalKey(tr, init, 2 * tr.size - 1) + scanIndexingNode( + tr, + Map.empty[GlobalKey, KeyMapping], + init, + collectFun, + (z, t) => z, + addKeyBeforeNode, + (z, t) => z, + 2 * tr.size - 1, + ) + val (si, n, dir) = scanAddKey(tr, init)(2 * tr.size - 1) + val ci = collectTrace(tr)(2 * tr.size - 1)._1 + + if (dir == TraversalDirection.Down) { + collectFunConcat(ci, n) + concatLeftGlobalKeysAssociativity(init, ci, nodeKeyMap(n._2)) + } + } else { + unfold(concatLeftGlobalKeys(init, collect(tr))) + if (init.isRight) { + unfold(concatLeftGlobalKeys(init.get, collect(tr))) + MapProperties.concatEmpty(init.get.globalKeys) + MapAxioms.extensionality( + init.get.globalKeys, + Map.empty[GlobalKey, KeyMapping] ++ init.get.globalKeys, + ) + } + } + }.ensuring(traverseAddKey(tr, init) == concatLeftGlobalKeys(init, collect(tr))) + + /** If a state contains all the keys of a transaction up to a given point, then first concatenating to the left new + * keys to its glboal keys and processing the transaction afterward up to that point, is the same as first processing + * the transaction and then adding the keys. The operations commute. + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param i The point up to which the transaction is being processed + * @param glK The keys that are being concatenated to the globalKeys of the state + */ + @pure + @opaque + def scanTransactionConcatLeftGlobalKeys( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + glK: Map[GlobalKey, KeyMapping], + ): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + require(containsAllKeysBefore(tr, init, i)) + + scanIndexingState(tr, concatLeftGlobalKeys(init, glK), traverseInFun, traverseOutFun, i) + scanIndexingState(tr, init, traverseInFun, traverseOutFun, i) + + if (i == 0) { + Trivial() + } else { + scanIndexingNode( + tr, + init, + concatLeftGlobalKeys(init, glK), + traverseInFun, + traverseOutFun, + traverseInFun, + traverseOutFun, + i - 1, + ) + containsAllKeysBeforeImplies(tr, init, i, i - 1) + scanTransactionConcatLeftGlobalKeys(tr, init, i - 1, glK) + + val (si, n, dir) = scanTransaction(tr, init)(i - 1) + val sci = scanTransaction(tr, concatLeftGlobalKeys(init, glK))(i - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(traverseInFun(si, n)) + unfold(traverseInFun(sci, n)) + n._2 match { + case a: Node.Action => + scanIndexingNode( + tr, + init, + true, + traverseInFun, + traverseOutFun, + containsAllKeysFun(init), + containsAllKeysFun(init), + i - 1, + ) + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), i) + unfold(containsAllKeysFun(init)(containsAllKeysBefore(tr, init, i - 1), n)) + scanTransactionProp(tr, init, i - 1) + containsKeySameGlobalKeys(init, si, n._2) + handleNodeConcatLeftGlobalKeys(si, n._1, a, glK) + case r: Node.Rollback => beginRollbackConcatLeftGlobalKeys(si, glK) + } + } else { + unfold(traverseOutFun(si, n)) + unfold(traverseOutFun(sci, n)) + endRollbackConcatLeftGlobalKeys(si, glK) + } + } + + }.ensuring( + scanTransaction(tr, concatLeftGlobalKeys(init, glK))(i)._1 == + concatLeftGlobalKeys(scanTransaction(tr, init)(i)._1, glK) + ) + + /** If a state contains all the keys of a transaction up to a given point, then first adding a key mapping and processing + * the transaction afterward up to that point, is the same as first processing the transaction and then adding the + * key. The operations commute. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param i The point up to which the transaction is being processed + * @param n The node whose key is added. + */ + @pure + @opaque + def scanTransactionAddKeyBeforeNode( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + n: Node, + ): Unit = { + require(i >= 0) + require(i < 2 * tr.size) + require(containsAllKeysBefore(tr, init, i)) + scanTransactionConcatLeftGlobalKeys(tr, init, i, nodeKeyMap(n)) + }.ensuring( + scanTransaction(tr, addKeyBeforeNode(init, n))(i)._1 == + addKeyBeforeNode(scanTransaction(tr, init)(i)._1, n) + ) + + /** For any point in time adding all the necessary keys at the beginning and then processing the transaction is + * equivalent to add the key and then processing the node for every step. + * + * @param tr The transaction + * @param init The initial state of the traversal + * @param i The point up to which the transaction is being processed + */ + @pure + @opaque + def scanTransactionFullCommute( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + require(init.isRight) + + scanIndexingNode( + tr, + scanAddKey(tr, init)(i)._1, + init, + traverseInFun, + traverseOutFun, + traverseInFunFull, + traverseOutFun, + i, + ) + scanIndexingState(tr, init, traverseInFunFull, traverseOutFun, i) + scanIndexingState(tr, scanAddKey(tr, init)(i)._1, traverseInFun, traverseOutFun, i) + scanIndexingState(tr, init, addKeyBeforeNode, (z, t) => z, i) + + if (i == 0) { + Trivial() + } else { + scanTransactionFullCommute(tr, init, i - 1) + scanIndexingNode( + tr, + scanAddKey(tr, init)(i)._1, + init, + traverseInFun, + traverseOutFun, + traverseInFunFull, + traverseOutFun, + i - 1, + ) + scanIndexingNode( + tr, + init, + init, + traverseInFunFull, + traverseOutFun, + addKeyBeforeNode, + (z, t) => z, + i - 1, + ) + + val (tfi, n, dir) = scanTransactionFull(tr, init)(i - 1) + val ti = scanTransaction(tr, scanAddKey(tr, init)(i)._1)(i - 1)._1 + val si = scanAddKey(tr, init)(i - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(traverseInFunFull(tfi, n)) + scanAddKeyContainsAllKeysBefore(tr, init, i - 1) + scanTransactionAddKeyBeforeNode(tr, scanAddKey(tr, init)(i - 1)._1, i - 1, n._2) + } + } + + }.ensuring(scanTransactionFull(tr, init)(i) == scanTransaction(tr, scanAddKey(tr, init)(i)._1)(i)) + + /** Adding all the necessary keys at the beginning and then processing the transaction is + * equivalent to add the key and then processing the node for every step. + * + * @param tr The transaction + * @param init The initial state of the traversal + */ + @pure + @opaque + def traverseTransactionFullCommute( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + require(init.isRight) + + if (tr.size > 0) { + scanIndexingState(tr, init, traverseInFunFull, traverseOutFun, 0) + scanIndexingState(tr, traverseAddKey(tr, init), traverseInFun, traverseOutFun, 0) + scanIndexingState(tr, init, addKeyBeforeNode, (z, t) => z, 0) + + scanTransactionFullCommute(tr, init, 2 * tr.size - 1) + scanIndexingNode( + tr, + traverseAddKey(tr, init), + init, + traverseInFun, + traverseOutFun, + traverseInFunFull, + traverseOutFun, + 2 * tr.size - 1, + ) + scanIndexingNode( + tr, + init, + init, + traverseInFunFull, + traverseOutFun, + addKeyBeforeNode, + (z, t) => z, + 2 * tr.size - 1, + ) + + val (tfi, n, dir) = scanTransactionFull(tr, init)(2 * tr.size - 1) + val ti = scanTransaction(tr, traverseAddKey(tr, init))(2 * tr.size - 1)._1 + val si = scanAddKey(tr, init)(2 * tr.size - 1)._1 + + if (dir == TraversalDirection.Down) { + unfold(traverseInFunFull(tfi, n)) + scanAddKeyContainsAllKeysBefore(tr, init, 2 * tr.size - 1) + scanTransactionAddKeyBeforeNode( + tr, + scanAddKey(tr, init)(2 * tr.size - 1)._1, + 2 * tr.size - 1, + n._2, + ) + } + } + + }.ensuring(traverseTransactionFull(tr, init) == traverseTransaction(tr, traverseAddKey(tr, init))) + + /** Given a node in a tree traversal that adds all the necessary key of the tree to an initial state, any intermediate + * state in the traversal that comes after the node will contain its key + * + * @param tr The tree that is being traversed + * @param init The initial state of the transaction + * @param i The step number of the node + * @param j The step number of the interemediate state + */ + @pure + @opaque + def scanAddKeyContains( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + j: BigInt, + ): Unit = { + decreases(j) + require(i >= 0) + require(i < j) + require(j < 2 * tr.size) + require(scanAddKey(tr, init)(i)._3 == TraversalDirection.Down) + + scanIndexingState(tr, init, addKeyBeforeNode, (z, t) => z, j) + if (j == i + 1) { + Trivial() + } else { + scanAddKeyContains(tr, init, i, j - 1) + containsKeyAddKeyBeforeNode( + scanAddKey(tr, init)(j - 1)._1, + scanAddKey(tr, init)(j - 1)._2._2, + scanAddKey(tr, init)(i)._2._2, + ) + } + + }.ensuring(containsKey(scanAddKey(tr, init)(j)._1)(scanAddKey(tr, init)(i)._2._2)) + + /** If a state contains all the keys of a tree up to a given point, then adding a key to the global keys of the state + * does not change the truth of the statement. + */ + @pure + @opaque + def containsAllKeysBeforeAddKeyBeforeNode( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + n: Node, + i: BigInt, + ): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + require(containsAllKeysBefore(tr, init, i)) + + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), i) + scanIndexingState( + tr, + true, + containsAllKeysFun(addKeyBeforeNode(init, n)), + containsAllKeysFun(addKeyBeforeNode(init, n)), + i, + ) + if (i == 0) { + Trivial() + } else { + scanIndexingNode( + tr, + true, + true, + containsAllKeysFun(init), + containsAllKeysFun(init), + containsAllKeysFun(addKeyBeforeNode(init, n)), + containsAllKeysFun(addKeyBeforeNode(init, n)), + i - 1, + ) + val (s1, ni, dir) = tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(i - 1) + val s2 = tr + .scan( + true, + containsAllKeysFun(addKeyBeforeNode(init, n)), + containsAllKeysFun(addKeyBeforeNode(init, n)), + )(i - 1) + ._1 + unfold(containsAllKeysFun(init)(s1, ni)) + unfold(containsAllKeysFun(addKeyBeforeNode(init, n))(s2, ni)) + containsKeyAddKeyBeforeNode(init, n, ni._2) + containsAllKeysBeforeAddKeyBeforeNode(tr, init, n, i - 1) + } + + }.ensuring(containsAllKeysBefore(tr, addKeyBeforeNode(init, n), i)) + + /** Any intermediate state of the traversal that adds all the keys of the tree to the initial state contains all the keys + * of the tree up to that point + * + * @param tr Tne tree that is being traversed + * @param init The initial state of the traversal + * @param i The step number of the intermediate state + */ + @pure + @opaque + def scanAddKeyContainsAllKeysBefore( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + + containsAllKeysBeforeAlt(tr, scanAddKey(tr, init)(i)._1, i) + scanIndexingState(tr, true, containsAllKeysFun(scanAddKey(tr, init)(i)._1), (z, t) => z, i) + scanIndexingState(tr, init, addKeyBeforeNode, (z, t) => z, i) + if (i == 0) { + Trivial() + } else { + containsAllKeysBeforeAlt(tr, scanAddKey(tr, init)(i)._1, i - 1) + scanIndexingNode( + tr, + init, + true, + addKeyBeforeNode, + (z, t) => z, + containsAllKeysFun(scanAddKey(tr, init)(i)._1), + (z, t) => z, + i - 1, + ) + scanAddKeyContainsAllKeysBefore(tr, init, i - 1) + + val si = tr.scan(init, addKeyBeforeNode, (z, t) => z)(i - 1)._1 + val (ci, ni, dir) = + tr.scan(true, containsAllKeysFun(scanAddKey(tr, init)(i)._1), (z, t) => z)(i - 1) + + if (dir == TraversalDirection.Down) { + unfold(containsAllKeysFun(scanAddKey(tr, init)(i)._1)(ci, ni)) + scanAddKeyContains(tr, init, i - 1, i) + containsAllKeysBeforeAddKeyBeforeNode(tr, si, ni._2, i - 1) + } + } + + }.ensuring(containsAllKeysBefore(tr, scanAddKey(tr, init)(i)._1, i)) + + /** The final state of the traversal that adds all the keys of the tree to the initial state contains all the keys + * of the tree. + * + * @param tr Tne tree that is being traversed + * @param init The initial state of the traversal + */ + @pure + @opaque + def traverseAddKeyContainsAllKeys( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + + if (tr.size == 0) { + Trivial() + } else { + containsAllKeysAlt(tr, traverseAddKey(tr, init)) + scanIndexingState(tr, true, containsAllKeysFun(traverseAddKey(tr, init)), (z, t) => z, 0) + scanIndexingState(tr, init, addKeyBeforeNode, (z, t) => z, 0) + containsAllKeysBeforeAlt(tr, traverseAddKey(tr, init), 2 * tr.size - 1) + scanIndexingNode( + tr, + init, + true, + addKeyBeforeNode, + (z, t) => z, + containsAllKeysFun(traverseAddKey(tr, init)), + (z, t) => z, + 2 * tr.size - 1, + ) + scanAddKeyContainsAllKeysBefore(tr, init, 2 * tr.size - 1) + + val si = tr.scan(init, addKeyBeforeNode, (z, t) => z)(2 * tr.size - 1)._1 + val (ci, ni, dir) = + tr.scan(true, containsAllKeysFun(traverseAddKey(tr, init)), (z, t) => z)(2 * tr.size - 1) + + if (dir == TraversalDirection.Down) { + unfold(containsAllKeysFun(scanAddKey(tr, init)(2 * tr.size - 1)._1)(ci, ni)) +// scanAddKeyContains(tr, init, i - 1, i) + containsAllKeysBeforeAddKeyBeforeNode(tr, si, ni._2, 2 * tr.size - 1) + } + } + + }.ensuring(containsAllKeys(tr, traverseAddKey(tr, init))) + + @pure @opaque + def traverseTransactionFullEmpty(tr: Tree[(NodeId, Node)]): Unit = { + + val remptyNoKey = Right[KeyInputError, State](State.empty) + val trEmpty = traverseAddKey(tr, remptyNoKey) + + traverseTransactionFullCommute(tr, remptyNoKey) + traverseAddKeyIsConcatLeftGlobalKey(tr, remptyNoKey) + + unfold(concatLeftGlobalKeys(State.empty, collect(tr))) + + MapProperties.concatEmpty(collect(tr)) + MapAxioms.extensionality(collect(tr) ++ Map.empty[GlobalKey, KeyMapping], collect(tr)) + + unfold(emptyState(tr)) + + unfold(sameConsumed(remptyNoKey, trEmpty)) + unfold(sameConsumed(State.empty, trEmpty)) + + unfold(sameActiveState(remptyNoKey, trEmpty)) + unfold(sameActiveState(State.empty, trEmpty)) + + unfold(sameLocallyCreated(remptyNoKey, trEmpty)) + unfold(sameLocallyCreated(State.empty, trEmpty)) + + unfold(sameStack(remptyNoKey, trEmpty)) + unfold(sameStack(State.empty, trEmpty)) + + }.ensuring( + traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)) == + traverseTransaction(tr, Right[KeyInputError, State](emptyState(tr))) + ) + + /** The globalKeys of the resulting state of a transaction traversal starting from the empty state is the map + * obtained from collecting the keys of the tree. + */ + @pure + @opaque + def traverseTransactionFullEmptyKeys(tr: Tree[(NodeId, Node)]): Unit = { + + require(traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).isRight) + + val remptyNoKey = Right[KeyInputError, State](State.empty) + val rempty = Right[KeyInputError, State](emptyState(tr)) + val trFullEmpty = traverseTransactionFull(tr, remptyNoKey) + + traverseTransactionFullEmpty(tr) + traverseTransactionProp(tr, rempty) + unfold(sameGlobalKeys(rempty, trFullEmpty)) + unfold(sameGlobalKeys(emptyState(tr), trFullEmpty)) + unfold(emptyState(tr)) + + }.ensuring( + traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).get.globalKeys == collect( + tr + ) + ) + + /** Advance is defined if and only if the map obtained by gathering the keys of the trees is a submap of the active keys + * of the state on which advance is applied to. + * + * Unfortunately due to a bug with lambdas in stainless the proof does not verify so we ignore it. + */ + @pure + @opaque + @dropVCs + def traverseTransactionAdvanceDefined( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + + require(init.isRight) + require(traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).isRight) + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(traverseAddKey(tr, init))( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require( + stateInvariant(Right[KeyInputError, State](emptyState(tr)))( + traverseUnbound(tr)._1, + traverseLC(tr, Set.empty[ContractId], Set.empty[ContractId], true)._1, + ) + ) + require( + !traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).get.withinRollbackScope + ) + + val remptyNoKey = Right[KeyInputError, State](State.empty) + val trFullEmpty = traverseTransactionFull(tr, remptyNoKey) + val trFull = traverseTransactionFull(tr, init) + val trKey = traverseAddKey(tr, init) + val rempty = Right[KeyInputError, State](emptyState(tr)) + + traverseAddKeyIsConcatLeftGlobalKey(tr, init) + + @pure @opaque + def trKeyProp: Unit = { + unfold(propagatesError(init, trKey)) + + unfold(sameActiveState(init, trKey)) + unfold(sameActiveState(init.get, trKey)) + + unfold(sameConsumed(init, trKey)) + unfold(sameConsumed(init.get, trKey)) + + }.ensuring( + trKey.isRight && + (trKey.get.consumed == init.get.consumed) && + (trKey.get.activeState == init.get.activeState) + ) + + trKeyProp + + @pure + @opaque + def traverseTransactionAdvanceCondition: Unit = { + + unfold(emptyState(tr)) + unfold(concatLeftGlobalKeys(init.get, collect(tr))) + + traverseTransactionFullEmptyKeys(tr) + + val p1: GlobalKey => Boolean = k => trKey.get.activeKeys.get(k) == collect(tr).get(k) + val p2: GlobalKey => Boolean = + k => init.get.activeKeys.get(k).forall(m => Some(m) == trFullEmpty.get.globalKeys.get(k)) + + if (collect(tr).keySet.forall(p1) && !trFullEmpty.get.globalKeys.keySet.forall(p2)) { + val k = SetProperties.notForallWitness(collect(tr).keySet, p2) + SetProperties.forallContains(collect(tr).keySet, p1, k) + activeKeysGet(trKey.get, k) + activeKeysGet(init.get, k) + MapAxioms.concatGet(collect(tr), init.get.globalKeys, k) + } else if (!collect(tr).keySet.forall(p1) && trFullEmpty.get.globalKeys.keySet.forall(p2)) { + val k = SetProperties.notForallWitness(collect(tr).keySet, p1) + SetProperties.forallContains(collect(tr).keySet, p2, k) + activeKeysGet(trKey.get, k) + activeKeysGet(init.get, k) + MapAxioms.concatGet(collect(tr), init.get.globalKeys, k) + + if (!init.get.activeKeys.get(k).isDefined) { + // true bc of invariant + assert(trKey.get.activeKeys.get(k) == collect(tr).get(k)) + } + } + }.ensuring( + emptyState(tr).globalKeys.keySet.forall(k => + trKey.get.activeKeys.get(k) == emptyState(tr).globalKeys.get(k) + ) == + trFullEmpty.get.globalKeys.keySet.forall(k => + init.get.activeKeys.get(k).forall(m => Some(m) == trFullEmpty.get.globalKeys.get(k)) + ) + ) + + advanceIsDefined(init.get, trFullEmpty.get) + assert( + init.get.advance(trFullEmpty.get).isRight == + trFullEmpty.get.globalKeys.keySet.forall(k => + init.get.activeKeys.get(k).forall(m => Some(m) == trFullEmpty.get.globalKeys.get(k)) + ) + ) + + traverseTransactionAdvanceCondition + assert( + emptyState(tr).globalKeys.keySet.forall(k => + trKey.get.activeKeys.get(k) == emptyState(tr).globalKeys.get(k) + ) == + trFullEmpty.get.globalKeys.keySet.forall(k => + init.get.activeKeys.get(k).forall(m => Some(m) == trFullEmpty.get.globalKeys.get(k)) + ) + ) + + traverseTransactionFullEmpty(tr) + traverseAddKeyContainsAllKeys(tr, init) + traverseTransactionEmptyDefined(tr, trKey) + assert( + emptyState(tr).globalKeys.keySet.forall(k => + trKey.get.activeKeys.get(k) == emptyState(tr).globalKeys.get(k) + ) == + traverseTransaction(tr, trKey).isRight + ) + + traverseTransactionFullCommute(tr, init) + assert( + traverseTransactionFull(tr, init).isRight == + traverseTransaction(tr, trKey).isRight + ) + + }.ensuring( + (init.get + .advance(traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).get) + .isRight == + traverseTransactionFull(tr, init).isRight) + ) + + /** If the advance method is defined, the traversing a transaction from a given initial state is the same as calling + * advance on this state with the final state of the empty state traversal as argument. + * + * @param tr The tree being traversed + * @param init The initial state of the traversal + */ + @pure + @opaque + def traverseTransactionAdvance( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + + require(init.isRight) + require(traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).isRight) + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(traverseAddKey(tr, init))( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require( + stateInvariant(Right[KeyInputError, State](emptyState(tr)))( + traverseUnbound(tr)._1, + traverseLC(tr, Set.empty[ContractId], Set.empty[ContractId], true)._1, + ) + ) + require( + !traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).get.withinRollbackScope + ) + + val remptyNoKey = Right[KeyInputError, State](State.empty) + val trFullEmpty = traverseTransactionFull(tr, remptyNoKey) + val trFull = traverseTransactionFull(tr, init) + val trKey = traverseAddKey(tr, init) + val rempty = Right[KeyInputError, State](emptyState(tr)) + + @pure + @opaque + def trKeyProp: Unit = { + traverseAddKeyIsConcatLeftGlobalKey(tr, init) + unfold(propagatesError(init, trKey)) + + unfold(sameActiveState(init, trKey)) + unfold(sameActiveState(init.get, trKey)) + + unfold(sameConsumed(init, trKey)) + unfold(sameConsumed(init.get, trKey)) + + }.ensuring( + trKey.isRight && + (trKey.get.consumed == init.get.consumed) && + (trKey.get.activeState == init.get.activeState) + ) + + traverseTransactionAdvanceDefined(tr, init) + + if (traverseTransactionFull(tr, init).isRight) { + + trKeyProp + unfold(init.get.advance(trFullEmpty.get)) + traverseTransactionFullCommute(tr, init) + traverseTransactionFullEmpty(tr) + unfold(emptyState(tr)) + + // activeState + traverseActiveStateAdvance(tr, trKey) + traverseActiveStateAdvance(tr, rempty) + unfold(sameActiveState(remptyNoKey, rempty)) + unfold(sameActiveState(State.empty, rempty)) + emptyAdvance(traverseActiveState(tr)._1) + + // key + traverseAddKeyIsConcatLeftGlobalKey(tr, init) + unfold(concatLeftGlobalKeys(init.get, collect(tr))) + traverseTransactionProp(tr, trKey) + unfold(sameGlobalKeys(trKey, trFull)) + unfold(sameGlobalKeys(trKey.get, trFull)) + traverseTransactionFullEmptyKeys(tr) + + // locallyCreated and consumed + traverseTransactionLC(tr, trKey, true) + traverseTransactionLC(tr, rempty, true) + traverseLCExtractInitLC( + tr, + init.get.locallyCreated, + init.get.consumed, + Set.empty[ContractId], + true, + ) + SetProperties.unionCommutativity( + traverseTransaction(tr, rempty).get.locallyCreated, + init.get.locallyCreated, + ) + SetProperties.equalsTransitivity( + init.get.locallyCreated ++ traverseTransaction(tr, rempty).get.locallyCreated, + traverseTransaction(tr, rempty).get.locallyCreated ++ init.get.locallyCreated, + traverseTransactionFull(tr, init).get.locallyCreated, + ) + SetAxioms.extensionality( + init.get.locallyCreated ++ traverseTransaction(tr, rempty).get.locallyCreated, + traverseTransactionFull(tr, init).get.locallyCreated, + ) + + traverseLCExtractInitConsumed( + tr, + init.get.locallyCreated, + Set.empty[ContractId], + init.get.consumed, + true, + ) + SetProperties.unionCommutativity( + traverseTransaction(tr, rempty).get.consumed, + init.get.consumed, + ) + SetProperties.equalsTransitivity( + init.get.consumed ++ traverseTransaction(tr, rempty).get.consumed, + traverseTransaction(tr, rempty).get.consumed ++ init.get.consumed, + traverseTransactionFull(tr, init).get.consumed, + ) + SetAxioms.extensionality( + init.get.consumed ++ traverseTransaction(tr, rempty).get.consumed, + traverseTransactionFull(tr, init).get.consumed, + ) + + // rollbackStack + traverseTransactionProp(tr, trKey) + sameStackTransitivity(init, trKey, trFull) + unfold(sameStack(init, trFull)) + unfold(sameStack(init.get, trFull)) + + } + + }.ensuring( + (traverseTransactionFull(tr, init).isRight ==> + (init.get + .advance(traverseTransactionFull(tr, Right[KeyInputError, State](State.empty)).get) + .get == + traverseTransactionFull(tr, init).get)) + ) + +} diff --git a/daml-lf/verification/tree/TransactionTreeInconsistency.scala b/daml-lf/verification/tree/TransactionTreeInconsistency.scala new file mode 100644 index 0000000000..4be9d7b88f --- /dev/null +++ b/daml-lf/verification/tree/TransactionTreeInconsistency.scala @@ -0,0 +1,1086 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package tree + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ +import utils.TreeProperties._ + +import transaction.{State} +import transaction.CSMInconsistency._ +import transaction.CSMInconsistencyDef._ +import transaction.CSMInvariantDef._ +import transaction.CSMEitherDef._ +import transaction.CSMEither._ +import transaction.CSMHelpers._ +import transaction.CSMKeysProperties._ +import transaction.CSMKeysPropertiesDef._ +import transaction.ContractStateMachine.{KeyMapping} + +import TransactionTreeInvariant._ +import TransactionTreeDef._ +import TransactionTree._ +import TransactionTreeKeysDef._ +import TransactionTreeKeys._ + +import TransactionTreeChecksDef._ +import TransactionTreeChecks._ + +/** The purpose of this file is proving that the advance method is defined if and only if the transaction traversal + * is defined as well. + */ +object TransactionTreeInconsistency { + + /** If an intermediate state of a transaction traversal is well-defined, then it will stay well-defined if it traverses + * a node for the second time. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param i The step number during which we process the node for the second time + */ + @opaque + @pure + def upDefined(tr: Tree[(NodeId, Node)], init: Either[KeyInputError, State], i: BigInt): Unit = { + require(i >= 0) + require(i < 2 * tr.size) + require(tr.isUnique) + require(scanTransaction(tr, init)(i)._1.isRight) + require(scanTransaction(tr, init)(i)._3 == TraversalDirection.Up) + + val s: Either[KeyInputError, State] = scanTransaction(tr, init)(i)._1 + val p = scanTransaction(tr, init)(i)._2 + unfold(traverseOutFun(s, p)) + + if (i == 2 * tr.size - 1) { + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + p._2 match { + case a: Node.Action => Trivial() + case r: Node.Rollback => + val (j, sub) = findBeginRollback(tr, init, i) + traverseTransactionProp(sub, beginRollback(scanTransaction(tr, init)(j)._1)) + scanTransactionProp(tr, init, j, i) + unfold( + sameStack( + beginRollback(scanTransaction(tr, init)(j)._1), + traverseTransaction(sub, beginRollback(scanTransaction(tr, init)(j)._1)), + ) + ) + unfold(propagatesError(scanTransaction(tr, init)(j)._1, s)) + unfold(beginRollback(scanTransaction(tr, init)(j)._1)) + unfold(scanTransaction(tr, init)(j)._1.get.beginRollback()) + unfold( + sameStack( + scanTransaction(tr, init)(j)._1.get.beginRollback(), + traverseTransaction(sub, beginRollback(scanTransaction(tr, init)(j)._1)), + ) + ) + unfold(endRollback(s)) + unfold(s.get.endRollback()) + } + } else { + scanIndexingState(tr, init, traverseInFun, traverseOutFun, i + 1) + p._2 match { + case a: Node.Action => Trivial() + case r: Node.Rollback => + val (j, sub) = findBeginRollback(tr, init, i) + traverseTransactionProp(sub, beginRollback(scanTransaction(tr, init)(j)._1)) + scanTransactionProp(tr, init, j, i) + unfold( + sameStack( + beginRollback(scanTransaction(tr, init)(j)._1), + traverseTransaction(sub, beginRollback(scanTransaction(tr, init)(j)._1)), + ) + ) + unfold(propagatesError(scanTransaction(tr, init)(j)._1, s)) + unfold(beginRollback(scanTransaction(tr, init)(j)._1)) + unfold(scanTransaction(tr, init)(j)._1.get.beginRollback()) + unfold( + sameStack( + scanTransaction(tr, init)(j)._1.get.beginRollback(), + traverseTransaction(sub, beginRollback(scanTransaction(tr, init)(j)._1)), + ) + ) + unfold(endRollback(s)) + unfold(s.get.endRollback()) + + } + } + + }.ensuring( + if (i == 2 * tr.size - 1) { + traverseTransaction(tr, init).isRight + } else { + scanTransaction(tr, init)(i + 1)._1.isRight + } + ) + + /** The actives keys does not change after entering a rollback node, processing a subtree and leaving + * the rollback node. + * + * @param tr The subtree that is being traversed + * @param s The initial state + * @param k The key we are querying in the active keys + */ + @opaque + @pure + def activeKeysGetTraverseTransaction( + tr: Tree[(NodeId, Node)], + s: Either[KeyInputError, State], + k: GlobalKey, + ): Unit = { + require(endRollback(traverseTransaction(tr, beginRollback(s))).isRight) + require(s.isRight) + + traverseTransactionProp(tr, beginRollback(s)) + endRollbackProp(traverseTransaction(tr, beginRollback(s))) + + unfold(propagatesError(beginRollback(s), traverseTransaction(tr, beginRollback(s)))) + unfold( + propagatesError( + traverseTransaction(tr, beginRollback(s)), + endRollback(traverseTransaction(tr, beginRollback(s))), + ) + ) + + sameGlobalKeysTransitivity(s, beginRollback(s), traverseTransaction(tr, beginRollback(s))) + sameGlobalKeysTransitivity( + s, + traverseTransaction(tr, beginRollback(s)), + endRollback(traverseTransaction(tr, beginRollback(s))), + ) + unfold(sameGlobalKeys(s, endRollback(traverseTransaction(tr, beginRollback(s))))) + unfold(sameGlobalKeys(s.get, endRollback(traverseTransaction(tr, beginRollback(s))))) + + unfold(beginRollback(s)) + unfold(s.get.beginRollback()) + unfold(endRollback(traverseTransaction(tr, beginRollback(s)))) + unfold(traverseTransaction(tr, beginRollback(s)).get.endRollback()) + unfold(sameStack(beginRollback(s), traverseTransaction(tr, beginRollback(s)))) + unfold(sameStack(s.get.beginRollback(), traverseTransaction(tr, beginRollback(s)))) + + activeKeysGetSameFields(s.get, endRollback(traverseTransaction(tr, beginRollback(s))).get, k) + }.ensuring( + endRollback(traverseTransaction(tr, beginRollback(s))).get.activeKeys.get(k) == + s.get.activeKeys.get(k) + ) + + /** If a key did not appear at given point in the transaction, then its entry in the active keys will be the same for + * all intermediate states until that point. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param k The key which did not appear yet + * @param j The step number of an other intermediate state that appeared before j + * @param j The step number for which the key did not appear yet + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure @opaque + def doesNotAppearBeforeSameActiveKeysGet( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + k: GlobalKey, + i: BigInt, + j: BigInt, + ): Unit = { + decreases(j) + require(0 <= i) + require(i <= j) + require(j < 2 * tr.size) + + require(init.isRight) + require(scanTransaction(tr, init)(i)._1.isRight) + require(scanTransaction(tr, init)(j)._1.isRight) + + require(doesNotAppearBefore(tr, init, traverseInFun, traverseOutFun, k, j)) + + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + if (i == j) { + Trivial() + } else { + val (sj, n, dir) = scanTransaction(tr, init)(j - 1) + + unfold(doesNotAppearBefore(tr, init, traverseInFun, traverseOutFun, k, j)) + scanIndexingState(tr, init, traverseInFun, traverseOutFun, j) + unfold(propagatesError(sj, scanTransaction(tr, init)(j)._1)) + doesNotAppearBeforeSameActiveKeysGet(tr, init, k, i, j - 1) + + if (dir == TraversalDirection.Down) { + unfold(traverseInFun(sj, n)) + n._2 match { + case a: Node.Action => + scanInvariant(tr, init, j - 1) + scanStateNodeCompatibility(tr, init, j - 1) + containsAllKeysImpliesDown(tr, init, j - 1) + unfold(containsKey(sj)(n._2)) + unfold(containsNodeKey(sj.get)(n._2)) + + // required + assert(!appearsAtIndex(tr, init, traverseInFun, traverseOutFun, k, j - 1)) + + unfold(appearsAtIndex(tr, init, traverseInFun, traverseOutFun, k, j - 1)) + handleNodeActiveKeysGet( + sj.get, + n._1, + a, + k, + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + case r: Node.Rollback => + unfold(beginRollback(scanTransaction(tr, init)(j - 1)._1)) + beginRollbackActiveKeysGet(sj.get, k) + } + } else { + unfold(traverseOutFun(sj, n)) + n._2 match { + case a: Node.Action => Trivial() + case r: Node.Rollback => + val (j2, sub) = findBeginRollback(tr, init, j - 1) + scanTransactionProp(tr, init, j2, j - 1) + unfold(propagatesError(scanTransaction(tr, init)(j2)._1, sj)) + activeKeysGetTraverseTransaction(sub, scanTransaction(tr, init)(j2)._1, k) + doesNotAppearBeforeSameActiveKeysGet(tr, init, k, j2, j - 1) + } + } + } + }.ensuring( + scanTransaction(tr, init)(i)._1.get.activeKeys.get(k) == + scanTransaction(tr, init)(j)._1.get.activeKeys.get(k) + ) + + /** If a key appeared for the first time at given point in the transaction, then its entry in the active keys will be + * the same for all intermediate states until that point. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param k The key that appeared for the first time at step j + * @param j The step number of an other intermediate state that appeared before j + * @param j The step number during which the key appeared + * @see The corresponding latex document for a pen and paper proof + */ + @pure @opaque + def firstAppearsSameActiveKeysGet( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + k: GlobalKey, + i: BigInt, + j: BigInt, + ): Unit = { + decreases(j) + require(0 <= i) + require(i <= j) + require(j < 2 * tr.size) + + require(init.isRight) + require(scanTransaction(tr, init)(i)._1.isRight) + require(scanTransaction(tr, init)(j)._1.isRight) + + require(firstAppears(tr, init, traverseInFun, traverseOutFun, k, j)) + + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + unfold(firstAppears(tr, init, traverseInFun, traverseOutFun, k, j)) + doesNotAppearBeforeSameActiveKeysGet(tr, init, k, i, j) + + }.ensuring( + scanTransaction(tr, init)(i)._1.get.activeKeys.get(k) == + scanTransaction(tr, init)(j)._1.get.activeKeys.get(k) + ) + + /** If a key of a node appeared for the first time at given point in the transaction, then the state after handling + * that node is well-defined if and only if the pair key-mapping is in the active keys of the initial state. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param k The node whose key appeared for the first time. + * @param j The step during which the node is processed + * @see The corresponding latex document for a pen and paper proof + */ + @pure @opaque + def firstAppearsHandleNodeUndefined( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + n: Node.Action, + i: BigInt, + ): Unit = { + require(0 <= i) + require(i < 2 * tr.size) + + require(init.isRight) + require(scanTransaction(tr, init)(i)._1.isRight) + require(scanTransaction(tr, init)(i)._2._2 == n) + + require(n.gkeyOpt.forall(k => firstAppears(tr, init, traverseInFun, traverseOutFun, k, i))) + + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + containsAllKeysImpliesDown(tr, init, i) + unfold(containsKey(scanTransaction(tr, init)(i)._1)(n)) + unfold(containsNodeKey(scanTransaction(tr, init)(i)._1.get)(n)) + + handleNodeUndefined(scanTransaction(tr, init)(i)._1.get, scanTransaction(tr, init)(i)._2._1, n) + + unfold( + inconsistencyCheck(scanTransaction(tr, init)(i)._1.get, n.gkeyOpt, nodeActionKeyMapping(n)) + ) + unfold(inconsistencyCheck(init.get, n.gkeyOpt, nodeActionKeyMapping(n))) + + n.gkeyOpt match { + case Some(k) => + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + firstAppearsSameActiveKeysGet(tr, init, k, 0, i) + case _ => Trivial() + } + + }.ensuring( + inconsistencyCheck(init.get, n.gkeyOpt, nodeActionKeyMapping(n)) == + scanTransaction(tr, init)(i)._1.get.handleNode(scanTransaction(tr, init)(i)._2._1, n).isLeft + ) + + /** If the key of a node already appeared in the transaction traversal, then its entry in the active keys of the state + * before that node, is independent of the initial state of the traversal. + * + * @param tr The tree that is being traversed + * @param init1 The initial state of the first traversal + * @param init2 The initial state of the first traversal + * @param node The node whose key already appeared in the traversals + * @param i The step number during which the node is processed + * @see The corresponding latex document for a pen and paper proof + */ + @pure @opaque + def appearsBeforeSameActiveKeysGet( + tr: Tree[(NodeId, Node)], + init1: Either[KeyInputError, State], + init2: Either[KeyInputError, State], + k: GlobalKey, + i: BigInt, + ): Unit = { + decreases(i) + require(0 <= i) + require(i < 2 * tr.size) + + require(init1.isRight) + require(init2.isRight) + require(scanTransaction(tr, init1)(i)._1.isRight) + require(scanTransaction(tr, init2)(i)._1.isRight) + + require(!doesNotAppearBefore(tr, init1, traverseInFun, traverseOutFun, k, i)) + + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init1.get.locallyCreated, init1.get.consumed, true)._3) + require( + stateInvariant(init1)( + traverseUnbound(tr)._1, + traverseLC(tr, init1.get.locallyCreated, init1.get.consumed, true)._1, + ) + ) + require(traverseLC(tr, init2.get.locallyCreated, init2.get.consumed, true)._3) + require( + stateInvariant(init2)( + traverseUnbound(tr)._1, + traverseLC(tr, init2.get.locallyCreated, init2.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init1)) + require(containsAllKeys(tr, init2)) + + @pure + @opaque + def appearsBeforeSameActiveKeysGetDefined(j: BigInt): Unit = { + require(0 <= j) + require(j <= i) + scanTransactionProp(tr, init1, j, i) + scanTransactionProp(tr, init2, j, i) + unfold(propagatesError(scanTransaction(tr, init1)(j)._1, scanTransaction(tr, init1)(i)._1)) + unfold(propagatesError(scanTransaction(tr, init2)(j)._1, scanTransaction(tr, init2)(i)._1)) + }.ensuring( + scanTransaction(tr, init1)(j)._1.isRight && + scanTransaction(tr, init2)(j)._1.isRight + ) + + unfold(doesNotAppearBefore(tr, init1, traverseInFun, traverseOutFun, k, i)) + + if (i == 0) { + Unreachable() + } else { + val (s1, n, dir) = scanTransaction(tr, init1)(i - 1) + val (s2, n2, dir2) = scanTransaction(tr, init2)(i - 1) + + scanIndexingNode( + tr, + init1, + init2, + traverseInFun, + traverseOutFun, + traverseInFun, + traverseOutFun, + i - 1, + ) + scanIndexingState(tr, init1, traverseInFun, traverseOutFun, i) + scanIndexingState(tr, init2, traverseInFun, traverseOutFun, i) + appearsBeforeSameActiveKeysGetDefined(i - 1) + + unfold(appearsAtIndex(tr, init1, traverseInFun, traverseOutFun, k, i - 1)) + + if (dir == TraversalDirection.Down) { + unfold(traverseInFun(s1, n)) + unfold(traverseInFun(s2, n2)) + n._2 match { + case a: Node.Action => + // required + assert( + appearsAtIndex( + tr, + init1, + traverseInFun, + traverseOutFun, + k, + i - 1, + ) == (a.gkeyOpt == Some(k)) + ) + + scanInvariant(tr, init1, i - 1) + scanStateNodeCompatibility(tr, init1, i - 1) + scanInvariant(tr, init2, i - 1) + scanStateNodeCompatibility(tr, init2, i - 1) + + @pure @opaque + def appearsBeforeSameActiveKeysGetContainsKey: Unit = { + containsAllKeysImpliesDown(tr, init1, i - 1) + unfold(containsKey(s1)(n._2)) + unfold(containsNodeKey(s1.get)(n._2)) + containsAllKeysImpliesDown(tr, init2, i - 1) + unfold(containsKey(s2)(n._2)) + unfold(containsNodeKey(s2.get)(n._2)) + + }.ensuring(containsActionKey(s1.get)(a) && containsActionKey(s2.get)(a)) + + appearsBeforeSameActiveKeysGetContainsKey + + if (a.gkeyOpt == Some(k)) { + handleNodeDifferentStatesActiveKeysGet( + s1.get, + s2.get, + n._1, + a, + traverseUnbound(tr)._1, + traverseLC(tr, init1.get.locallyCreated, init1.get.consumed, true)._1, + traverseUnbound(tr)._1, + traverseLC(tr, init2.get.locallyCreated, init2.get.consumed, true)._1, + ) + } else { + handleNodeActiveKeysGet( + s1.get, + n._1, + a, + k, + traverseUnbound(tr)._1, + traverseLC(tr, init1.get.locallyCreated, init1.get.consumed, true)._1, + ) + handleNodeActiveKeysGet( + s2.get, + n._1, + a, + k, + traverseUnbound(tr)._1, + traverseLC(tr, init2.get.locallyCreated, init2.get.consumed, true)._1, + ) + appearsBeforeSameActiveKeysGet(tr, init1, init2, k, i - 1) + } + + case r: Node.Rollback => + unfold(beginRollback(s1)) + unfold(beginRollback(s2)) + beginRollbackActiveKeysGet(s1.get, k) + beginRollbackActiveKeysGet(s2.get, k) + appearsBeforeSameActiveKeysGet(tr, init1, init2, k, i - 1) + } + } else { + unfold(traverseOutFun(s1, n)) + unfold(traverseOutFun(s2, n2)) + + n._2 match { + case a: Node.Action => appearsBeforeSameActiveKeysGet(tr, init1, init2, k, i - 1) + case r: Node.Rollback => + val (i2, sub) = findBeginRollback(tr, init1, init2, i - 1) + appearsBeforeSameActiveKeysGetDefined(i2) + activeKeysGetTraverseTransaction(sub, scanTransaction(tr, init1)(i2)._1, k) + if (!doesNotAppearBefore(tr, init1, traverseInFun, traverseOutFun, k, i2)) { + appearsBeforeSameActiveKeysGet(tr, init1, init2, k, i2) + activeKeysGetTraverseTransaction(sub, scanTransaction(tr, init2)(i2)._1, k) + } else { + val i1 = findFirstAppears(tr, init1, traverseInFun, traverseOutFun, k, i2, i) + + @pure @opaque + def appearsBeforeSameActiveKeysGetSameFirstAppears: Unit = { + appearsBeforeSameActiveKeysGetDefined(i1) + appearsBeforeSameActiveKeysGetDefined(i1 + 1) + + unfold(firstAppears(tr, init1, traverseInFun, traverseOutFun, k, i1)) + assert(appearsAtIndex(tr, init1, traverseInFun, traverseOutFun, k, i1)) + unfold(appearsAtIndex(tr, init1, traverseInFun, traverseOutFun, k, i1)) + scanIndexingNode( + tr, + init1, + init2, + traverseInFun, + traverseOutFun, + traverseInFun, + traverseOutFun, + i1, + ) + + scanTransaction(tr, init1)(i1)._2._2 match { + case a: Node.Action => + // required + assert( + scanTransaction(tr, init1)(i1)._2._2 == scanTransaction(tr, init2)(i1)._2._2 + ) + + @pure @opaque + def appearsBeforeSameActiveKeysGetSameFirstAppearsContainsKey: Unit = { + containsAllKeysImpliesDown(tr, init1, i1) + unfold(containsKey(scanTransaction(tr, init1)(i1)._1)(a)) + unfold(containsNodeKey(scanTransaction(tr, init1)(i1)._1.get)(a)) + containsAllKeysImpliesDown(tr, init2, i1) + unfold(containsKey(scanTransaction(tr, init2)(i1)._1)(a)) + unfold(containsNodeKey(scanTransaction(tr, init2)(i1)._1.get)(a)) + }.ensuring( + containsActionKey(scanTransaction(tr, init1)(i1)._1.get)(a) && + containsActionKey(scanTransaction(tr, init2)(i1)._1.get)(a) + ) + + appearsBeforeSameActiveKeysGetSameFirstAppearsContainsKey + + scanIndexingState(tr, init1, traverseInFun, traverseOutFun, i1 + 1) + scanIndexingState(tr, init2, traverseInFun, traverseOutFun, i1 + 1) + unfold( + traverseInFun( + scanTransaction(tr, init1)(i1)._1, + scanTransaction(tr, init1)(i1)._2, + ) + ) + unfold( + traverseInFun( + scanTransaction(tr, init2)(i1)._1, + scanTransaction(tr, init2)(i1)._2, + ) + ) + unfold( + handleNode( + scanTransaction(tr, init1)(i1)._1, + scanTransaction(tr, init1)(i1)._2._1, + a, + ) + ) + unfold( + handleNode( + scanTransaction(tr, init2)(i1)._1, + scanTransaction(tr, init2)(i1)._2._1, + a, + ) + ) + handleSameNodeActiveKeys( + scanTransaction(tr, init1)(i1)._1.get, + scanTransaction(tr, init2)(i1)._1.get, + scanTransaction(tr, init1)(i1)._2._1, + a, + ) + + case _ => Unreachable() + } + + }.ensuring( + scanTransaction(tr, init1)(i1)._1.get.activeKeys.get(k) == + scanTransaction(tr, init2)(i1)._1.get.activeKeys.get(k) + ) + + appearsBeforeSameActiveKeysGetDefined(i1) + firstAppearsSameActiveKeysGet(tr, init1, k, i2, i1) + appearsBeforeSameActiveKeysGetSameFirstAppears + firstAppearsSame( + tr, + init1, + init2, + traverseInFun, + traverseOutFun, + traverseInFun, + traverseOutFun, + k, + i1, + ) + firstAppearsSameActiveKeysGet(tr, init2, k, i2, i1) + activeKeysGetTraverseTransaction(sub, scanTransaction(tr, init2)(i2)._1, k) + } + + } + } + } + }.ensuring( + scanTransaction(tr, init1)(i)._1.get.activeKeys.get(k) == + scanTransaction(tr, init2)(i)._1.get.activeKeys.get(k) + ) + + /** If the key of a node already appeared in the transaction traversal, then the well-definedness of the state after + * having processed that node is independent of the initial state of the traversal. + * + * @param tr The tree that is being traversed + * @param init1 The initial state of the first traversal + * @param init2 The initial state of the first traversal + * @param n The node whose key already appeared in the traversals + * @param i The step number during which the node is processed + * @see The corresponding latex document for a pen and paper proof + */ + @pure @opaque + def appearsBeforeSameUndefined( + tr: Tree[(NodeId, Node)], + init1: Either[KeyInputError, State], + init2: Either[KeyInputError, State], + n: Node.Action, + i: BigInt, + ): Unit = { + decreases(i) + require(0 <= i) + require(i < 2 * tr.size) + + require(init1.isRight) + require(init2.isRight) + require(scanTransaction(tr, init1)(i)._1.isRight) + require(scanTransaction(tr, init2)(i)._1.isRight) + require(scanTransaction(tr, init1)(i)._2._2 == n) + + require( + n.gkeyOpt.forall(k => !doesNotAppearBefore(tr, init1, traverseInFun, traverseOutFun, k, i)) + ) + + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init1.get.locallyCreated, init1.get.consumed, true)._3) + require( + stateInvariant(init1)( + traverseUnbound(tr)._1, + traverseLC(tr, init1.get.locallyCreated, init1.get.consumed, true)._1, + ) + ) + require(traverseLC(tr, init2.get.locallyCreated, init2.get.consumed, true)._3) + require( + stateInvariant(init2)( + traverseUnbound(tr)._1, + traverseLC(tr, init2.get.locallyCreated, init2.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init1)) + require(containsAllKeys(tr, init2)) + + val (s1, n1, dir) = scanTransaction(tr, init1)(i) + val (s2, n2, dir2) = scanTransaction(tr, init2)(i) + + scanIndexingNode( + tr, + init1, + init2, + traverseInFun, + traverseOutFun, + traverseInFun, + traverseOutFun, + i, + ) + + @pure @opaque + def appearsBeforeSameUndefinedContainsKey: Unit = { + containsAllKeysImpliesDown(tr, init1, i) + unfold(containsKey(s1)(n)) + unfold(containsNodeKey(s1.get)(n)) + containsAllKeysImpliesDown(tr, init2, i) + unfold(containsKey(s2)(n)) + unfold(containsNodeKey(s2.get)(n)) + }.ensuring(containsActionKey(s1.get)(n) && containsActionKey(s2.get)(n)) + + appearsBeforeSameUndefinedContainsKey + + handleNodeUndefined(s1.get, n1._1, n) + handleNodeUndefined(s2.get, n2._1, n) + + unfold(inconsistencyCheck(s1.get, n.gkeyOpt, nodeActionKeyMapping(n))) + unfold(inconsistencyCheck(s2.get, n.gkeyOpt, nodeActionKeyMapping(n))) + + n.gkeyOpt match { + case Some(k) => + appearsBeforeSameActiveKeysGet(tr, init1, init2, k, i) + unfold(inconsistencyCheck(s1.get, n.gkeyOpt, nodeActionKeyMapping(n))) + unfold(inconsistencyCheck(s2.get, n.gkeyOpt, nodeActionKeyMapping(n))) + case _ => Trivial() + } + }.ensuring( + scanTransaction(tr, init1)(i)._1.get + .handleNode(scanTransaction(tr, init1)(i)._2._1, n) + .isLeft == + scanTransaction(tr, init2)(i)._1.get.handleNode(scanTransaction(tr, init2)(i)._2._1, n).isLeft + ) + + /** If the key of a node appears for the first time at a given point of the transaction traversal and the intermediate + * state at this point is well-defined, then the state after processing the next node will be well-defined as well if + * and only if the key of the node is mapped to same mapping in the activeKeys of the initial state and the global keys of the + * empty state (after having collected the keys of the tree). + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param node The node whose key appears for the first time + * @param i The step number during which the node is processed + * @see The corresponding latex document for a pen and paper proof + */ + @pure + @opaque + def firstAppearsHandleNodeUndefinedEmpty( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + n: Node.Action, + i: BigInt, + ): Unit = { + require(0 <= i) + require(i < 2 * tr.size) + + require(init.isRight) + require(scanTransaction(tr, init)(i)._1.isRight) + require(scanTransaction(tr, init)(i)._2._2 == n) + + require(n.gkeyOpt.isDefined) + require(firstAppears(tr, init, traverseInFun, traverseOutFun, n.gkeyOpt.get, i)) + + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + scanIndexingNode( + tr, + init, + Map.empty[GlobalKey, KeyMapping], + traverseInFun, + traverseOutFun, + collectFun, + (z, t) => z, + i, + ) + firstAppearsSame( + tr, + init, + Map.empty[GlobalKey, KeyMapping], + traverseInFun, + traverseOutFun, + collectFun, + (z, t) => z, + n.gkeyOpt.get, + i, + ) + collectGet(tr, i, n) + unfold(emptyState(tr)) + unfold(emptyState(tr).globalKeys.contains) + unfold(emptyState(tr).globalKeys(n.gkeyOpt.get)) + firstAppearsHandleNodeUndefined(tr, init, n, i) + + }.ensuring( + emptyState(tr).globalKeys.contains(n.gkeyOpt.get) && + (inconsistencyCheck(init.get, n.gkeyOpt, emptyState(tr).globalKeys(n.gkeyOpt.get)) == + scanTransaction(tr, init)(i)._1.get + .handleNode(scanTransaction(tr, init)(i)._2._1, n) + .isLeft) + ) + + /** If the key of a node already appeared in the transaction traversal, the final state of the traversal with the + * empty state as initial state, and the state before processing the node is well-defined as well, then the state + * after processing the node will be well-defined as well. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param node The node whose key already appeared + * @param i The step number during which the node is processed + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure + @opaque + def appearsBeforeHandleNodeUndefinedEmpty( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + n: Node.Action, + i: BigInt, + ): Unit = { + require(0 <= i) + require(i < 2 * tr.size) + + require(init.isRight) + require(scanTransaction(tr, init)(i)._1.isRight) + require(scanTransaction(tr, init)(i)._2._2 == n) + + require(n.gkeyOpt.isDefined) + require(!doesNotAppearBefore(tr, init, traverseInFun, traverseOutFun, n.gkeyOpt.get, i)) + require(traverseTransaction(tr, Right[KeyInputError, State](emptyState(tr))).isRight) + + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require( + stateInvariant(Right[KeyInputError, State](emptyState(tr)))( + traverseUnbound(tr)._1, + traverseLC(tr, Set.empty[ContractId], Set.empty[ContractId], true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + val rempty: Either[KeyInputError, State] = Right[KeyInputError, State](emptyState(tr)) + + scanIndexingNode( + tr, + init, + rempty, + traverseInFun, + traverseOutFun, + traverseInFun, + traverseOutFun, + i, + ) + + val si = scanTransaction(tr, init)(i)._1 + val ni = scanTransaction(tr, init)(i)._2 + if (i == 2 * tr.size - 1) { + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + unfold(traverseOutFun(si, ni)) + } else { + scanIndexingState(tr, init, traverseInFun, traverseOutFun, i + 1) + scanIndexingState(tr, rempty, traverseInFun, traverseOutFun, i + 1) + emptyContainsAllKeys(tr) + unfold(emptyState(tr)) + SetProperties.emptySubsetOf(init.get.locallyCreated) + SetProperties.emptySubsetOf(init.get.consumed) + traverseLCSubsetOf( + tr, + Set.empty[ContractId], + Set.empty[ContractId], + true, + init.get.locallyCreated, + init.get.consumed, + true, + ) + traverseTransactionDefined(tr, rempty, i) + unfold(propagatesError(scanTransaction(tr, rempty)(i)._1, traverseTransaction(tr, rempty))) + traverseTransactionDefined(tr, rempty, i + 1) + unfold( + propagatesError(scanTransaction(tr, rempty)(i + 1)._1, traverseTransaction(tr, rempty)) + ) + + unfold(traverseInFun(si, ni)) + unfold(traverseOutFun(si, ni)) + unfold(traverseInFun(scanTransaction(tr, rempty)(i)._1, ni)) + unfold(traverseOutFun(scanTransaction(tr, rempty)(i)._1, ni)) + + appearsBeforeSameUndefined(tr, init, rempty, n, i) + +// assert(scanTransaction(tr, init)(i + 1)._1.isRight) + } + + }.ensuring( + if (i == 2 * tr.size - 1) { + traverseTransaction(tr, init).isRight + } else { + scanTransaction(tr, init)(i + 1)._1.isRight + } + ) + + /** If a transaction is defined when it is being traversed with emptyState as the initial state, then it is also defined + * when starting with a given state if the map obtained by gathering all the key - key mappings of the tree is a submap + * of the active keys of that state. + * + * @param tr The transaction that is being processed + * @param init The initial state of the traversal + */ + @pure + @opaque + def traverseTransactionEmptyDefined( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Unit = { + + require(init.isRight) + require(traverseTransaction(tr, Right[KeyInputError, State](emptyState(tr))).isRight) + require(tr.isUnique) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require( + stateInvariant(Right[KeyInputError, State](emptyState(tr)))( + traverseUnbound(tr)._1, + traverseLC(tr, Set.empty[ContractId], Set.empty[ContractId], true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + val p: GlobalKey => Boolean = + k => init.get.activeKeys.get(k) == emptyState(tr).globalKeys.get(k) + + if ( + emptyState(tr).globalKeys.keySet.forall(p) && + !traverseTransaction(tr, init).isRight + ) { + val i = traverseNotProp(tr, init, traverseInFun, traverseOutFun, x => x.isRight) + + val (si, n, dir) = scanTransaction(tr, init)(i) + + if (dir == TraversalDirection.Up) { + upDefined(tr, init, i) + } else { + scanIndexingState(tr, init, traverseInFun, traverseOutFun, i + 1) + unfold(traverseInFun(si, n)) + n._2 match { + case a: Node.Action => + a.gkeyOpt match { + case Some(k) => + if (doesNotAppearBefore(tr, init, traverseInFun, traverseOutFun, k, i)) { + unfold(appearsAtIndex(tr, init, traverseInFun, traverseOutFun, k, i)) + unfold(firstAppears(tr, init, traverseInFun, traverseOutFun, k, i)) + firstAppearsHandleNodeUndefinedEmpty(tr, init, a, i) + unfold(inconsistencyCheck(init.get, a.gkeyOpt, emptyState(tr).globalKeys(k))) + unfold(inconsistencyCheck(init.get, k, emptyState(tr).globalKeys(k))) + MapProperties.keySetContains(emptyState(tr).globalKeys, k) + SetProperties.forallContains(emptyState(tr).globalKeys.keySet, p, k) + } else { + appearsBeforeHandleNodeUndefinedEmpty(tr, init, a, i) + } + case None() => + containsAllKeysImpliesDown(tr, init, i) + unfold(containsKey(si)(n._2)) + unfold(containsNodeKey(si.get)(n._2)) + handleNodeUndefined(si.get, n._1, a) + unfold(inconsistencyCheck(si.get, a.gkeyOpt, nodeActionKeyMapping(a))) + } + case r: Node.Rollback => unfold(propagatesError(beginRollback(si), si)) + } + } + } else if ( + !emptyState(tr).globalKeys.keySet.forall(p) && + traverseTransaction(tr, init).isRight + ) { + val k: GlobalKey = SetProperties.notForallWitness(emptyState(tr).globalKeys.keySet, p) + MapProperties.keySetContains(emptyState(tr).globalKeys, k) + unfold(emptyState(tr)) + val i: BigInt = collectContains(tr, k) + + val (si, n, dir) = scanTransaction(tr, init)(i) + firstAppearsSame( + tr, + Map.empty[GlobalKey, KeyMapping], + init, + collectFun, + (z, t) => z, + traverseInFun, + traverseOutFun, + k, + i, + ) + unfold(firstAppears(tr, init, traverseInFun, traverseOutFun, k, i)) + unfold(appearsAtIndex(tr, init, traverseInFun, traverseOutFun, k, i)) + traverseTransactionDefined(tr, init, i) + traverseTransactionDefined(tr, init, i + 1) + unfold(propagatesError(si, traverseTransaction(tr, init))) + unfold(propagatesError(scanTransaction(tr, init)(i + 1)._1, traverseTransaction(tr, init))) + scanIndexingState(tr, init, traverseInFun, traverseOutFun, i + 1) + unfold(traverseInFun(si, n)) + n._2 match { + case a: Node.Action if (dir == TraversalDirection.Down) && (a.gkeyOpt == Some(k)) => + @pure + @opaque + def traverseTransactionEmptyDefinedContains: Unit = { + containsAllKeysImpliesDown(tr, init, i) + unfold(containsKey(si)(n._2)) + scanTransactionProp(tr, init, i) + unfold(sameGlobalKeys(init, si)) + unfold(sameGlobalKeys(init.get, si)) + containsNodeKeySameGlobalKeys(init.get, si.get, n._2) + unfold(containsNodeKey(init.get)(n._2)) + activeKeysContainsKey(init.get, a) + }.ensuring(init.get.activeKeys.contains(k)) + + firstAppearsHandleNodeUndefinedEmpty(tr, init, a, i) + unfold(inconsistencyCheck(init.get, a.gkeyOpt, emptyState(tr).globalKeys(k))) + unfold(inconsistencyCheck(init.get, k, emptyState(tr).globalKeys(k))) + traverseTransactionEmptyDefinedContains + unfold(init.get.activeKeys.contains) + // required + assert(!init.get.activeKeys.get(k).exists(_ != emptyState(tr).globalKeys(k))) + assert(init.get.activeKeys.get(k).isDefined) + case _ => Unreachable() + } + } + + }.ensuring( + emptyState(tr).globalKeys.keySet.forall(k => + init.get.activeKeys.get(k) == emptyState(tr).globalKeys.get(k) + ) + == + traverseTransaction(tr, init).isRight + ) + +} diff --git a/daml-lf/verification/tree/TransactionTreeInvariant.scala b/daml-lf/verification/tree/TransactionTreeInvariant.scala new file mode 100644 index 0000000000..fbcb8a4976 --- /dev/null +++ b/daml-lf/verification/tree/TransactionTreeInvariant.scala @@ -0,0 +1,270 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package tree + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ +import utils.TreeProperties._ +import transaction.CSMHelpers._ + +import transaction.CSMEitherDef._ +import transaction.CSMEither._ + +import transaction.CSMLocallyCreatedProperties._ +import transaction.{State} + +import transaction.CSMKeysPropertiesDef._ + +import transaction.CSMInvariantDef._ +import transaction.CSMInvariant._ + +import transaction.ContractStateMachine.{KeyMapping, ActiveLedgerState} + +import TransactionTreeDef._ + +import TransactionTreeChecksDef._ +import TransactionTreeChecks._ + +import TransactionTreeKeys._ +import TransactionTreeKeysDef._ + +/** Properties on how invariants are preserved throughout transaction traversals. + */ +object TransactionTreeInvariant { + + /** If the traversals in [[TransactionTreeChecks]] did not raise any error, then every pair intermediate state - node + * in the traversal respects the [[transaction.TransactioInvariantDef.stateNodeCompatibility]] condition. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param i The step number during which the node is processed + */ + @pure @opaque + def scanStateNodeCompatibility( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + require(i < 2 * tr.size) + require(0 <= i) + require(init.isRight) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require(scanTransaction(tr, init)(i)._1.isRight) + + val (si, n, dir) = scanTransaction(tr, init)(i) + val ilc = init.get.locallyCreated + val icons = init.get.consumed + + unfold( + stateNodeCompatibility( + si.get, + n._2, + traverseUnbound(tr)._1, + traverseLC(tr, ilc, icons, true)._1, + dir, + ) + ) + + dir match { + case TraversalDirection.Up => Trivial() + case TraversalDirection.Down => + scanIndexingNode( + tr, + init, + (ilc, icons, true), + traverseInFun, + traverseOutFun, + buildLC, + (z, t) => z, + i, + ) + scanIndexingNode( + tr, + init, + (Set.empty[ContractId], Set.empty[ContractId], true), + traverseInFun, + traverseOutFun, + unboundFun, + (z, t) => z, + i, + ) + scanTraverseLCPropDown(tr, ilc, icons, true, i) + scanTraverseUnboundPropDown(tr, i) + scanTransactionLC(tr, init, true, i) + } + + }.ensuring( + stateNodeCompatibility( + scanTransaction(tr, init)(i)._1.get, + scanTransaction(tr, init)(i)._2._2, + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + scanTransaction(tr, init)(i)._3, + ) + ) + + /** If the traversals in [[TransactionTreeChecks]] did not raise any error and if the initial state contains all the + * keys of the tree, then every intermediate state of the transaction traversal preserves the invariants. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param i The step number of the intermediate state + */ + @pure @opaque + def scanInvariant( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + require(i >= 0) + require(i < 2 * tr.size) + require(init.isRight) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + val p: Either[KeyInputError, State] => Boolean = x => + stateInvariant(x)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + + if (!p(scanTransaction(tr, init)(i)._1)) { + val j = scanNotProp(tr, init, traverseInFun, traverseOutFun, p, i) + scanIndexingState(tr, init, traverseInFun, traverseOutFun, j + 1) + + val s = scanTransaction(tr, init)(j)._1 + val n = scanTransaction(tr, init)(j)._2 + val dir = scanTransaction(tr, init)(j)._3 + + s match { + case Left(_) => + if (dir == TraversalDirection.Down) { + unfold(propagatesError(s, traverseInFun(s, n))) + } else { + unfold(propagatesError(s, traverseOutFun(s, n))) + } + case Right(state) => + unfold(traverseInFun(s, n)) + unfold(traverseOutFun(s, n)) + + n._2 match { + case a: Node.Action => + if (dir == TraversalDirection.Down) { + containsAllKeysImpliesDown(tr, init, j) + unfold(containsKey(s)(n._2)) + unfold(containsNodeKey(state)(n._2)) + scanStateNodeCompatibility(tr, init, j) + handleNodeInvariant( + state, + n._1, + a, + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + } + case r: Node.Rollback => + if (dir == TraversalDirection.Down) { + unfold(beginRollback(s)) + stateInvariantBeginRollback( + state, + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + } else { + unfold(endRollback(s)) + stateInvariantEndRollback( + state, + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + } + } + } + } + + }.ensuring( + stateInvariant(scanTransaction(tr, init)(i)._1)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + + /** If the traversals in [[TransactionTreeChecks]] did not raise any error and if the initial state contains all the + * keys of the tree, then the state obtained after processing the transaction preserves the invariants. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + */ + @pure + @opaque + def traverseInvariant(tr: Tree[(NodeId, Node)], init: Either[KeyInputError, State]): Unit = { + require(init.isRight) + require(traverseUnbound(tr)._3) + require(traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._3) + require( + stateInvariant(init)( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + require(containsAllKeys(tr, init)) + + if (tr.size == 0) { + Trivial() + } else { + + scanIndexingState(tr, init, traverseInFun, traverseOutFun, 0) + scanInvariant(tr, init, 2 * tr.size - 1) + + val s = scanTransaction(tr, init)(2 * tr.size - 1)._1 + val n = scanTransaction(tr, init)(2 * tr.size - 1)._2 + + s match { + case Left(_) => unfold(propagatesError(s, traverseOutFun(s, n))) + case Right(state) => + unfold(traverseOutFun(s, n)) + n._2 match { + case a: Node.Action => Trivial() + case r: Node.Rollback => + unfold(endRollback(s)) + stateInvariantEndRollback( + state, + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + } + } + } + + }.ensuring( + stateInvariant(traverseTransaction(tr, init))( + traverseUnbound(tr)._1, + traverseLC(tr, init.get.locallyCreated, init.get.consumed, true)._1, + ) + ) + +} diff --git a/daml-lf/verification/tree/TransactionTreeKeys.scala b/daml-lf/verification/tree/TransactionTreeKeys.scala new file mode 100644 index 0000000000..aa8698013a --- /dev/null +++ b/daml-lf/verification/tree/TransactionTreeKeys.scala @@ -0,0 +1,1109 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package tree + +import stainless.lang.{ + unfold, + decreases, + BooleanDecorations, + Either, + Some, + None, + Option, + Right, + Left, +} +import stainless.annotation._ +import stainless.collection._ +import utils.Value.ContractId +import utils.Transaction.{DuplicateContractKey, InconsistentContractKey, KeyInputError} +import utils._ +import utils.TreeProperties._ + +import transaction.{State} +import transaction.CSMHelpers._ +import transaction.CSMKeysPropertiesDef._ +import transaction.CSMKeysProperties._ +import transaction.ContractStateMachine.{KeyMapping, ActiveLedgerState} + +import TransactionTreeDef._ +import TransactionTree._ + +/** This files introduces two tree traversals: + * - [[TransactionTreeKeysDef.containsAllKeys] takes a state and checks that it contains all the keys + * that appear in the tree. + * - [[TransactionTreeKeysDef.collect]] collects all the keys, with the mapping of the node in which + * they appeared for the first time, into a map. + * + * In fact in [[TransactionTreeFull]] we prove that handling a transaction is equivalent to collecting first all the + * keys into a map, concatenating it with the global keys of the initial state and then finally traversing the tree + * without modifying the globalKeys. + */ +object TransactionTreeKeysDef { + + /** Indicates whether a [[GlobalKey]] k is equals to the key of the [[Node]] that is being traversed for the first time + * during the i-th step of a [[utils.Tree]] traversal. If that's the second time the node is visited or if the node + * has not a well-defined key, returns false. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param f1 The function that is executed when a node of the tree is traversed for the first time + * @param f2 The function that is executed when a node of the tree is traversed for the second time + * @param k The key we are querying + * @param i The step number in the traversal + */ + @pure + @opaque + def appearsAtIndex[Z]( + tr: Tree[(NodeId, Node)], + init: Z, + f1: (Z, (NodeId, Node)) => Z, + f2: (Z, (NodeId, Node)) => Z, + k: GlobalKey, + i: BigInt, + ): Boolean = { + require(i >= 0) + require(i < 2 * tr.size) + + tr.scan(init, f1, f2)(i)._2._2 match { + case a: Node.Action + if (tr.scan(init, f1, f2)(i)._3 == TraversalDirection.Down) && (a.gkeyOpt == Some(k)) => + true + case _ => false + } + } + + /** Indicates whether a [[GlobalKey]] did not yet appear before a given step in a tree traversal. + * + * More precisely, asserts that k is not equals to the key of the [[Node]] that is being traversed for the first + * time during all the steps between 0 and i excluded of a [[utils.Tree]] traversal. If that's the second time a node + * is visited or if the node has not a well-defined key, it is ignored. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param f1 The function that is executed when a node of the tree is traversed for the first time + * @param f2 The function that is executed when a node of the tree is traversed for the second time + * @param k The key we are querying + * @param i The strict upper bound on the checked steps + * @see The corresponding latex document for a pen and paper definition + */ + @pure + @opaque + def doesNotAppearBefore[Z]( + tr: Tree[(NodeId, Node)], + init: Z, + f1: (Z, (NodeId, Node)) => Z, + f2: (Z, (NodeId, Node)) => Z, + k: GlobalKey, + i: BigInt, + ): Boolean = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + + if (i == 0) { + true + } else { + !appearsAtIndex(tr, init, f1, f2, k, i - 1) && doesNotAppearBefore(tr, init, f1, f2, k, i - 1) + } + } + + /** Indicates whether a [[GlobalKey]] apears for the first time at step i in a tree traversal + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param f1 The function that is executed when a node of the tree is traversed for the first time + * @param f2 The function that is executed when a node of the tree is traversed for the second time + * @param k The key we are querying + * @param i The step in which the key appears for the first time + * @see The corresponding latex document for a pen and paper definition + */ + @pure + @opaque + def firstAppears[Z]( + tr: Tree[(NodeId, Node)], + init: Z, + f1: (Z, (NodeId, Node)) => Z, + f2: (Z, (NodeId, Node)) => Z, + k: GlobalKey, + i: BigInt, + ): Boolean = { + require(i >= 0) + require(i < 2 * tr.size) + + appearsAtIndex(tr, init, f1, f2, k, i) && doesNotAppearBefore(tr, init, f1, f2, k, i) + } + + /** Function used the first time we visit a [[Node]] during the [[Tree]] traversal that checks whether a [[State]] + * contains all the [[GlobalKey]] of the tree. + * + * Since the result of the function is cumulative, then the function can returns true only if b is true as well. + * + * @param init The state that is being checked. This is an hyperparameter of the function that does not change + * throughout the traversal + * @param b The status of the condition we are checking. If it is set to false then we already found earlier + * in the tree a key such that init does not contain it + * @param n The node we are currently checking + */ + @pure + @opaque + def containsAllKeysFun( + init: Either[KeyInputError, State] + )(b: Boolean, n: (NodeId, Node)): Boolean = { + b && containsKey(init)(n._2) + }.ensuring(res => (res ==> b)) + + /** Checks whether a state contains all the keys of a tree up to a given step of the traversal. + * + * Note that in the traversal [[containsAllKeysFun]] is used when visiting a [[Node]] for both the first and the second + * time. This choice makes some statements easier to prove. Others are easier to prove with the alternative version + * [[scanContainsAllKeys]] where the identity function is used when visiting a [[Node]] for the second time. + * + * @param tr The tree that is being traversed + * @param init The states for which we are checking the keys + * @param i The step at which we stop the traversal (exclusive) + */ + @pure + def containsAllKeysBefore( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Boolean = { + require(i >= 0) + require(i < 2 * tr.size) + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(i)._1 + } + + /** Checks whether a state contains all the keys of a tree. + * + * Note that in the traversal [[containsAllKeysFun]] is used when visiting a [[Node]] for both the first and the second + * time. This choice makes some statements easier to prove. Others are easier to prove with the alternative version + * [[traverseContainsAllKeys]] where the identity function is used when visiting a [[Node]] for the second time. + * + * @param tr The tree that is being traversed + * @param init The states which we are checking the keys + */ + @pure + def containsAllKeys(tr: Tree[(NodeId, Node)], init: Either[KeyInputError, State]): Boolean = { + tr.traverse(true, containsAllKeysFun(init), containsAllKeysFun(init)) + } + + /** List of triples whose entries are: + * - Whether the state given as argument contains all the keys of a tree up to a given step of the traversal + * - The node that is being processed during this step + * - A traversal direction i.e. whether this is the first or the second time we visit the node + * + * Note that in the traversal [[containsAllKeysFun]] is used when visiting a [[Node]] only the first + * time. This choice makes some statements easier to prove. Others are easier to prove with the alternative version + * [[containsAllKeys]] where the function is used when visiting a [[Node]] both the first and the second time. + * + * @param tr The tree that is being traversed + * @param init The states which we are checking the keys + */ + @pure + def scanContainsAllKeys( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): List[(Boolean, (NodeId, Node), TraversalDirection)] = { + tr.scan(true, containsAllKeysFun(init), (z, t) => z) + } + + /** Checks whether a state contains all the keys of a tree. + * + * Note that in the traversal [[containsAllKeysFun]] is used when visiting a [[Node]] only the first + * time. This choice makes some statements easier to prove. Others are easier to prove with the alternative version + * [[containsAllKeys]] where the function is used when visiting a [[Node]] both the first and the second time. + * + * @param tr The tree that is being traversed + * @param init The states which we are checking the keys + */ + @pure + def traverseContainsAllKeys( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + ): Boolean = { + tr.traverse(true, containsAllKeysFun(init), (z, t) => z) + } + + /** Function used the first time we visit a [[Node]] during the [[Tree]] traversal that collect all the global keys + * of tree. Only collects the key if it did not appear before. + * + * @param m The state of the traversal before reaching the node, i.e. the map with all the previous mappings. + * @param n The node we are currently processing + */ + @pure + @opaque + def collectFun(m: Map[GlobalKey, KeyMapping], n: (NodeId, Node)): Map[GlobalKey, KeyMapping] = { + n._2 match { + case a: Node.Action if a.gkeyOpt.isDefined && !m.contains(a.gkeyOpt.get) => + m.updated(a.gkeyOpt.get, nodeActionKeyMapping(a)) + case _ => m + } + } + + /** Map containing all keys of the tree with the contract bound to them the first time they appeared in the + * traversal. + * + * @param tr The tree that is being traversed + */ + @pure + def collect(tr: Tree[(NodeId, Node)]): Map[GlobalKey, KeyMapping] = { + tr.traverse(Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z) + } + + /** List of triples whose respective entries are: + * - The map of keys-contracts before the i-th step of the traversal + * - The pair node id - node that is handle during the i-th step + * - The direction i.e. if that's the first or the second time we enter the node + * + * @param tr The tree that is being traversed + */ + @pure + def collectTrace( + tr: Tree[(NodeId, Node)] + ): List[(Map[GlobalKey, KeyMapping], (NodeId, Node), TraversalDirection)] = { + tr.scan(Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z) + } + + /** Result of [[State.empty]] after having traversed the tree a first time to collect the keys. + * + * @param tr The tree that has been traversed + */ + @pure + @opaque + def emptyState(tr: Tree[(NodeId, Node)]): State = { + State(Set.empty[ContractId], Set.empty[ContractId], collect(tr), ActiveLedgerState.empty, Nil()) + } + +} + +object TransactionTreeKeys { + + import TransactionTreeKeysDef._ + + /** [[appearsAtIndex]] only depends on the shape of the tree and is independent from the initial state or the + * functions used during the travesal + * + * @param tr The tree that is being traversed + * @param init1 The initial state of the first traversal. + * @param init2 The initial state of the second traversal. + * @param f11 + * The function that is executed when a node of the tree is traversed for the first time in the first traversal + * @param f12 + * The function that is executed when a node of the tree is traversed for the second time in the first traversal + * @param f21 + * The function that is executed when a node of the tree is traversed for the first time in the second traversal + * @param f22 + * The function that is executed when a node of the tree is traversed for the second time in the second traversal + * @param k The key that is being queried + * @param i The step number we are looking at + */ + @pure + @opaque + def appearsAtIndexSame[Z1, Z2]( + tr: Tree[(NodeId, Node)], + init1: Z1, + init2: Z2, + f11: (Z1, (NodeId, Node)) => Z1, + f12: (Z1, (NodeId, Node)) => Z1, + f21: (Z2, (NodeId, Node)) => Z2, + f22: (Z2, (NodeId, Node)) => Z2, + k: GlobalKey, + i: BigInt, + ): Unit = { + require(i >= 0) + require(i < 2 * tr.size) + unfold(appearsAtIndex(tr, init1, f11, f12, k, i)) + unfold(appearsAtIndex(tr, init2, f21, f22, k, i)) + scanIndexingNode(tr, init1, init2, f11, f12, f21, f22, i) + + }.ensuring( + appearsAtIndex(tr, init1, f11, f12, k, i) == appearsAtIndex(tr, init2, f21, f22, k, i) + ) + + /** [[doesNotAppearBefore]] only depends on the shape of the tree and is independent from the initial state or the + * functions used during the travesal + * + * @param tr The tree that is being traversed + * @param init1 The initial state of the first traversal. + * @param init2 The initial state of the second traversal. + * @param f11 + * The function that is executed when a node of the tree is traversed for the first time in the first traversal + * @param f12 + * The function that is executed when a node of the tree is traversed for the second time in the first traversal + * @param f21 + * The function that is executed when a node of the tree is traversed for the first time in the second traversal + * @param f22 + * The function that is executed when a node of the tree is traversed for the second time in the second traversal + * @param k The key that is being queried + * @param i The stric upper bound on the steps + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure @opaque + def doesNotAppearBeforeSame[Z1, Z2]( + tr: Tree[(NodeId, Node)], + init1: Z1, + init2: Z2, + f11: (Z1, (NodeId, Node)) => Z1, + f12: (Z1, (NodeId, Node)) => Z1, + f21: (Z2, (NodeId, Node)) => Z2, + f22: (Z2, (NodeId, Node)) => Z2, + k: GlobalKey, + i: BigInt, + ): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + unfold(doesNotAppearBefore(tr, init1, f11, f12, k, i)) + unfold(doesNotAppearBefore(tr, init2, f21, f22, k, i)) + + if (i == 0) { + Trivial() + } else { + appearsAtIndexSame(tr, init1, init2, f11, f12, f21, f22, k, i - 1) + doesNotAppearBeforeSame(tr, init1, init2, f11, f12, f21, f22, k, i - 1) + } + + }.ensuring( + doesNotAppearBefore(tr, init1, f11, f12, k, i) == doesNotAppearBefore(tr, init2, f21, f22, k, i) + ) + + /** [[firstAppears]] only depends on the shape of the tree and is independent from the initial state or the + * functions used during the travesal + * + * @param tr The tree that is being traversed + * @param init1 The initial state of the first traversal. + * @param init2 The initial state of the second traversal. + * @param f11 + * The function that is executed when a node of the tree is traversed for the first time in the first traversal + * @param f12 + * The function that is executed when a node of the tree is traversed for the second time in the first traversal + * @param f21 + * The function that is executed when a node of the tree is traversed for the first time in the second traversal + * @param f22 + * The function that is executed when a node of the tree is traversed for the second time in the second traversal + * @param k The key that is being queried + * @param i The step number we are looking at + * @see The corresponding latex document for a pen and paper proof + */ + + @pure + @opaque + def firstAppearsSame[Z1, Z2]( + tr: Tree[(NodeId, Node)], + init1: Z1, + init2: Z2, + f11: (Z1, (NodeId, Node)) => Z1, + f12: (Z1, (NodeId, Node)) => Z1, + f21: (Z2, (NodeId, Node)) => Z2, + f22: (Z2, (NodeId, Node)) => Z2, + k: GlobalKey, + i: BigInt, + ): Unit = { + require(i >= 0) + require(i < 2 * tr.size) + unfold(firstAppears(tr, init1, f11, f12, k, i)) + unfold(firstAppears(tr, init2, f21, f22, k, i)) + + appearsAtIndexSame(tr, init1, init2, f11, f12, f21, f22, k, i) + doesNotAppearBeforeSame(tr, init1, init2, f11, f12, f21, f22, k, i) + + }.ensuring( + firstAppears(tr, init1, f11, f12, k, i) == firstAppears(tr, init2, f21, f22, k, i) + ) + + /** If a key k does not appear in a tree traversal before step j then it does not appear either for any 0 ≤ i ≤ j. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param f1 The function that is executed when a node of the tree is traversed for the first time + * @param f2 The function that is executed when a node of the tree is traversed for the second time + */ + @pure + @opaque + def doesNotAppearBeforeDiffIndex[Z]( + tr: Tree[(NodeId, Node)], + init: Z, + f1: (Z, (NodeId, Node)) => Z, + f2: (Z, (NodeId, Node)) => Z, + k: GlobalKey, + i: BigInt, + j: BigInt, + ): Unit = { + decreases(j - i) + require(0 <= i) + require(i <= j) + require(j < 2 * tr.size) + require(doesNotAppearBefore(tr, init, f1, f2, k, j)) + + unfold(doesNotAppearBefore(tr, init, f1, f2, k, j)) + if (j == i) { + Trivial() + } else { + doesNotAppearBeforeDiffIndex(tr, init, f1, f2, k, i, j - 1) + } + }.ensuring(doesNotAppearBefore(tr, init, f1, f2, k, i)) + + /** If a key k does not appear in a tree traversal before the i1-th step but appears before the i2-th step, then + * there exists a j such that: + * - i1 ≤ j < i2 + * - k first appears at step j + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param f1 The function that is executed when a node of the tree is traversed for the first time + * @param f2 The function that is executed when a node of the tree is traversed for the second time + * + * @return j the step during which the key appears for the first time + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure + @opaque + def findFirstAppears[Z]( + tr: Tree[(NodeId, Node)], + init: Z, + f1: (Z, (NodeId, Node)) => Z, + f2: (Z, (NodeId, Node)) => Z, + k: GlobalKey, + i1: BigInt, + i2: BigInt, + ): BigInt = { + decreases(i2 - i1) + require(0 <= i1) + require(i1 < i2) + require(i2 < 2 * tr.size) + require(doesNotAppearBefore(tr, init, f1, f2, k, i1)) + require(!doesNotAppearBefore(tr, init, f1, f2, k, i2)) + + unfold(doesNotAppearBefore(tr, init, f1, f2, k, i1 + 1)) + + if (appearsAtIndex(tr, init, f1, f2, k, i1)) { + unfold(firstAppears(tr, init, f1, f2, k, i1)) + i1 + } else { + findFirstAppears(tr, init, f1, f2, k, i1 + 1, i2) + } + + }.ensuring(j => + i1 <= j && j < i2 && + firstAppears(tr, init, f1, f2, k, j) + ) + + /** If a state contains all the keys of a tree up to a given point, then it contains all the keys of the tree + * up to any previous point. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param i The point such that init contains all the keys before it. + * @param j A point in the traversal before i + */ + @pure + @opaque + def containsAllKeysBeforeImplies( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + j: BigInt, + ): Unit = { + decreases(i - j) + require(0 <= j) + require(j <= i) + require(i < 2 * tr.size) + require(containsAllKeysBefore(tr, init, i)) + + if (j == i) { + Trivial() + } else { + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), i) + unfold( + containsAllKeysFun(init)( + containsAllKeysBefore(tr, init, i - 1), + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(i - 1)._2, + ) + ) + containsAllKeysBeforeImplies(tr, init, i - 1, j) + } + + }.ensuring(containsAllKeysBefore(tr, init, j)) + + /** If a state contains all the keys of a tree, then it contains all the keys of the tree up to any previous point. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param i A point in the traversal. + */ + @pure + @opaque + def containsAllKeysImplies( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + require(containsAllKeys(tr, init)) + require(0 <= i) + require(i < 2 * tr.size) + + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), 0) + unfold( + containsAllKeysFun(init)( + containsAllKeysBefore(tr, init, 2 * tr.size - 1), + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(2 * tr.size - 1)._2, + ) + ) + containsAllKeysBeforeImplies(tr, init, 2 * tr.size - 1, i) + + }.ensuring(containsAllKeysBefore(tr, init, i)) + + /** If a state contains all the keys of a tree and the same state is the inital state of a transaction traversal, then + * any at any step of this traversal the intermediate state contains the node that is visited during the step. + * + * @param tr The transaction that is being processed + * @param init The initial state of the traversal which contains all the keys of the tree. + * @param i A step in the traversal + */ + @pure + @opaque + def containsAllKeysImpliesDown( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + require(containsAllKeys(tr, init)) + require(0 <= i) + require(i < 2 * tr.size) + + scanIndexingNode( + tr, + init, + true, + traverseInFun, + traverseOutFun, + containsAllKeysFun(init), + containsAllKeysFun(init), + i, + ) + unfold( + containsAllKeysFun(init)(containsAllKeysBefore(tr, init, i), scanTransaction(tr, init)(i)._2) + ) + scanTransactionProp(tr, init, i) + containsKeySameGlobalKeys( + init, + scanTransaction(tr, init)(i)._1, + scanTransaction(tr, init)(i)._2._2, + ) + + if (i == 2 * tr.size - 1) { + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), 0) + } else { + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), i + 1) + containsAllKeysImplies(tr, init, i + 1) + } + + }.ensuring( + containsKey(scanTransaction(tr, init)(i)._1)(scanTransaction(tr, init)(i)._2._2) + ) + + /** If a state contains all the keys of a tree up to a given point, then it contains all the keys of the tree + * up to any previous point. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param i The point such that init contains all the keys before it. + * @param j A point in the traversal before i + */ + @pure + @opaque + def containsAllKeysBeforeAltImplies( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + j: BigInt, + ): Unit = { + decreases(i - j) + require(0 <= j) + require(j <= i) + require(i < 2 * tr.size) + require(scanContainsAllKeys(tr, init)(i)._1) + + if (j == i) { + Trivial() + } else { + scanIndexingState(tr, true, containsAllKeysFun(init), (z, t) => z, i) + unfold( + containsAllKeysFun(init)( + tr.scan(true, containsAllKeysFun(init), (z, t) => z)(i - 1)._1, + tr.scan(true, containsAllKeysFun(init), (z, t) => z)(i - 1)._2, + ) + ) + containsAllKeysBeforeAltImplies(tr, init, i - 1, j) + } + + }.ensuring(scanContainsAllKeys(tr, init)(j)._1) + + /** If a state contains all the keys of a tree, then it contains all the keys of the tree up to any previous point. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + * @param i A point in the traversal. + */ + @pure + @opaque + def containsAllKeysAltImplies( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + require(traverseContainsAllKeys(tr, init)) + require(0 <= i) + require(i < 2 * tr.size) + + scanIndexingState(tr, true, containsAllKeysFun(init), (z, t) => z, 0) + containsAllKeysBeforeAltImplies(tr, init, 2 * tr.size - 1, i) + + }.ensuring(scanContainsAllKeys(tr, init)(i)._1) + + /** States the equivalence between [[TransactionTreeKeysDef.containsAllKeysBefore]] and + * [[TransactionTreeKeysDef.scanContainsAllKeys]]. Both functions behave in a similar manner when visiting each node for + * the first time. However, the former also checks that the state given in argument contains the node key when visiting + * it for the second time whereas the latter does not do anything. Some claims are easier to prove using one version + * or the other. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + */ + @pure @opaque + def containsAllKeysBeforeAlt( + tr: Tree[(NodeId, Node)], + init: Either[KeyInputError, State], + i: BigInt, + ): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + scanIndexingNode( + tr, + true, + true, + containsAllKeysFun(init), + containsAllKeysFun(init), + containsAllKeysFun(init), + (z, t) => z, + i, + ) + if (i == 0) { + scanIndexingState(tr, true, containsAllKeysFun(init), (z, t) => z, 0) + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), 0) + } else { + scanIndexingNode( + tr, + true, + true, + containsAllKeysFun(init), + containsAllKeysFun(init), + containsAllKeysFun(init), + (z, t) => z, + i - 1, + ) + containsAllKeysBeforeAlt(tr, init, i - 1) + scanIndexingState(tr, true, containsAllKeysFun(init), (z, t) => z, i) + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), i) + unfold( + containsAllKeysFun(init)( + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(i - 1)._1, + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(i - 1)._2, + ) + ) + if ( + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(i - 1) + ._3 == TraversalDirection.Up + ) { + + val j = findDown(tr, true, containsAllKeysFun(init), (z, t) => z, i - 1) + scanIndexingState(tr, true, containsAllKeysFun(init), (z, t) => z, j + 1) + unfold( + containsAllKeysFun(init)( + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(j)._1, + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(j)._2, + ) + ) + if (tr.scan(true, containsAllKeysFun(init), (z, t) => z)(i)._1) { + containsAllKeysBeforeAltImplies(tr, init, i, j + 1) + } + } else { + Trivial() + } + } + }.ensuring( + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(i) == scanContainsAllKeys( + tr, + init, + )(i) + ) + + /** States the equivalence between [[TransactionTreeKeysDef.containsAllKeys]] and + * [[TransactionTreeKeysDef.traverseContainsAllKeys]]. Both functions behave in a similar manner when visiting each node for + * the first time. However, the former also checks that the state given in argument contains the node key when visiting + * it for the second time whereas the latter does not do anything. Some claims are easier to prove using one version + * or the other. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal. + */ + @pure + @opaque + def containsAllKeysAlt(tr: Tree[(NodeId, Node)], init: Either[KeyInputError, State]): Unit = { + if (tr.size > 0) { + scanIndexingState(tr, true, containsAllKeysFun(init), (z, t) => z, 0) + scanIndexingState(tr, true, containsAllKeysFun(init), containsAllKeysFun(init), 0) + + scanIndexingNode( + tr, + true, + true, + containsAllKeysFun(init), + containsAllKeysFun(init), + containsAllKeysFun(init), + (z, t) => z, + 2 * tr.size - 1, + ) + containsAllKeysBeforeAlt(tr, init, 2 * tr.size - 1) + unfold( + containsAllKeysFun(init)( + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(2 * tr.size - 1)._1, + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(2 * tr.size - 1)._2, + ) + ) + if ( + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(2 * tr.size - 1) + ._3 == TraversalDirection.Up + ) { + + val j = findDown(tr, true, containsAllKeysFun(init), (z, t) => z, 2 * tr.size - 1) + scanIndexingState(tr, true, containsAllKeysFun(init), (z, t) => z, j + 1) + unfold( + containsAllKeysFun(init)( + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(j)._1, + tr.scan(true, containsAllKeysFun(init), containsAllKeysFun(init))(j)._2, + ) + ) + if (tr.traverse(true, containsAllKeysFun(init), (z, t) => z)) { + containsAllKeysAltImplies(tr, init, j + 1) + } + } else { + Trivial() + } + } else { + Trivial() + } + }.ensuring(containsAllKeys(tr, init) == traverseContainsAllKeys(tr, init)) + + /** Expresses [[TransactionTreeKeysDef.collectFun]] as the concatenation of the map given as argument and a map + * representing the node contribution. + * + * @param m The map of the previous state + * @param n The node that is being processed + */ + @pure + @opaque + def collectFunConcat(m: Map[GlobalKey, KeyMapping], n: (NodeId, Node)): Unit = { + unfold(collectFun(m, n)) + unfold(nodeKeyMap(n._2)) + n._2 match { + case a: Node.Action => + unfold(actionKeyMap(a)) + unfold(optionKeyMap(a.gkeyOpt, nodeActionKeyMapping(a))) + if (a.gkeyOpt.isDefined) { + val k = a.gkeyOpt.get + if (!m.contains(a.gkeyOpt.get)) { + MapProperties.updatedCommutativity(m, k, nodeActionKeyMapping(a)) + MapAxioms.extensionality( + Map[GlobalKey, KeyMapping](k -> nodeActionKeyMapping(a)) ++ m, + m.updated(k, nodeActionKeyMapping(a)), + ) + } else { + MapProperties.keySetContains(m, k) + MapProperties.singletonKeySet(k, nodeActionKeyMapping(a)) + SetProperties.singletonSubsetOf(m.keySet, k) + SetProperties.equalsSubsetOfTransitivity( + Set(k), + Map[GlobalKey, KeyMapping](k -> nodeActionKeyMapping(a)).keySet, + m.keySet, + ) + MapProperties.concatSubsetOfEquals( + Map[GlobalKey, KeyMapping](k -> nodeActionKeyMapping(a)), + m, + ) + MapAxioms.extensionality( + Map[GlobalKey, KeyMapping](k -> nodeActionKeyMapping(a)) ++ m, + m, + ) + } + } else { + MapProperties.concatEmpty(m) + MapAxioms.extensionality(Map.empty[GlobalKey, KeyMapping] ++ m, m) + } + case _ => + MapProperties.concatEmpty(m) + MapAxioms.extensionality(Map.empty[GlobalKey, KeyMapping] ++ m, m) + } + }.ensuring(collectFun(m, n) == nodeKeyMap(n._2) ++ m) + + /** Any intermediate state of the map that collects the keys of the tree, is a submap of the map obtained at the end + * of the traversal + * + * @param tr The tree that is being traversed + * @param i The step number of the intermediate map + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure @opaque + def collectTraceProp(tr: Tree[(NodeId, Node)], i: BigInt): Unit = { + require(i >= 0) + require(i < 2 * tr.size) + MapProperties.submapOfReflexivity(collect(tr)) + if (!collectTrace(tr)(i)._1.submapOf(collect(tr))) { + val j = scanNotPropRev( + tr, + Map.empty[GlobalKey, KeyMapping], + collectFun, + (z, t) => z, + x => x.submapOf(collect(tr)), + i, + ) + val sj = collectTrace(tr)(j)._1 + val n = collectTrace(tr)(j)._2 + MapProperties.submapOfReflexivity(sj) + if (j == 2 * tr.size - 1) { + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, 0) + } else { + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, j + 1) + unfold(collectFun(sj, n)) + MapProperties.submapOfTransitivity(sj, collectFun(sj, n), collect(tr)) + } + } + }.ensuring(collectTrace(tr)(i)._1.submapOf(collect(tr))) + + /** A key does not appear in the tree before a given step if and only if the map obtained by collecting all the key + * up to that step does not contain it. + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure + @opaque + def collectTraceDoesNotAppear(tr: Tree[(NodeId, Node)], k: GlobalKey, i: BigInt): Unit = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + + unfold(doesNotAppearBefore(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, k, i)) + if (i == 0) { + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, 0) + MapProperties.emptyContains[GlobalKey, KeyMapping](k) + + } else { + val si = collectTrace(tr)(i - 1)._1 + val n = collectTrace(tr)(i - 1)._2 + + collectTraceDoesNotAppear(tr, k, i - 1) + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, i) + unfold( + appearsAtIndex(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, k, i - 1) + ) + unfold(collectFun(si, n)) + n._2 match { + case a: Node.Action if a.gkeyOpt.isDefined && !si.contains(a.gkeyOpt.get) => + MapProperties.updatedContains(si, a.gkeyOpt.get, nodeActionKeyMapping(a), k) + case a: Node.Action if a.gkeyOpt.isDefined => + MapProperties.updatedContains(si, a.gkeyOpt.get, nodeActionKeyMapping(a), k) + case _ => Trivial() + } + } + }.ensuring( + doesNotAppearBefore(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, k, i) == + !collectTrace(tr)(i)._1.contains(k) + ) + + /** A key does not appear in the tree if and only if the map obtained by collecting all the key does not contain it. + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure + @opaque + def collectDoesNotAppear(tr: Tree[(NodeId, Node)], k: GlobalKey): Unit = { + require(tr.size > 0) + collectTraceDoesNotAppear(tr, k, 2 * tr.size - 1) + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, 0) + unfold(collectFun(collectTrace(tr)(2 * tr.size - 1)._1, collectTrace(tr)(2 * tr.size - 1)._2)) + }.ensuring( + doesNotAppearBefore( + tr, + Map.empty[GlobalKey, KeyMapping], + collectFun, + (z, t) => z, + k, + 2 * tr.size - 1, + ) == + !collect(tr).contains(k) + ) + + /** If the key of a node appears for the first time at a given step of the traversal, then the map containing all the + * keys of the tree, will have an entry with the key and the contract of the node. + * + * @param tr The tree that is being traversed + * @param i The step number during which the node is processed + * @param n The node whose key appeared for the first time at step i + * + * @see The corresponding latex document for a pen and paper proof + */ + @pure + @opaque + def collectGet(tr: Tree[(NodeId, Node)], i: BigInt, n: Node.Action): Unit = { + require(i >= 0) + require(i < 2 * tr.size) + require(n.gkeyOpt.isDefined) + require(collectTrace(tr)(i)._2._2 == n) + require( + firstAppears(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, n.gkeyOpt.get, i) + ) + + unfold( + firstAppears(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, n.gkeyOpt.get, i) + ) + collectTraceDoesNotAppear(tr, n.gkeyOpt.get, i) + + unfold( + appearsAtIndex( + tr, + Map.empty[GlobalKey, KeyMapping], + collectFun, + (z, t) => z, + n.gkeyOpt.get, + i, + ) + ) + if (i == 2 * tr.size - 1) { + Unreachable() + } else { + scanIndexingState(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, i + 1) + unfold(collectFun(collectTrace(tr)(i)._1, collectTrace(tr)(i)._2)) + collectTraceProp(tr, i + 1) + MapAxioms.submapOfGet(collectTrace(tr)(i + 1)._1, collect(tr), n.gkeyOpt.get) + } + + }.ensuring( + collect(tr).get(n.gkeyOpt.get) == Some(nodeActionKeyMapping(n)) + ) + + /** If [[TransactionTreeKeysDef.collect]] contains a key after having traversed a transaction, then there exists a + * step during which the key appeared for the first time. + * @param tr The tree that is being traversed + * @param k The key contained in traversal result + * @return The step number where the key appeared for the first time + * @see The corresponding latex document for a pen and paper proof + */ + @pure + @opaque + def collectContains(tr: Tree[(NodeId, Node)], k: GlobalKey): BigInt = { + require(collect(tr).contains(k)) + + if (tr.size == 0) { + MapProperties.emptyContains[GlobalKey, KeyMapping](k) + Unreachable() + } else { + collectDoesNotAppear(tr, k) + unfold( + doesNotAppearBefore(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, k, 0) + ) + findFirstAppears( + tr, + Map.empty[GlobalKey, KeyMapping], + collectFun, + (z, t) => z, + k, + 0, + 2 * tr.size - 1, + ) + } + }.ensuring(i => + (tr.size > BigInt(0)) + i >= BigInt (0) && i < 2 * tr.size - 1 && + firstAppears(tr, Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z, k, i) + ) + + /** The empty state whose global keys are replaced with the collection of all the keys, contains all the keys of the + * tree. + * + * @param tr The tree whose keys have been collected. + */ + @pure @opaque + def emptyContainsAllKeys(tr: Tree[(NodeId, Node)]): Unit = { + + val rempty: Either[KeyInputError, State] = Right[KeyInputError, State](emptyState(tr)) + + containsAllKeysAlt(tr, rempty) + if (!containsAllKeys(tr, rempty)) { + val j = traverseNotProp(tr, true, containsAllKeysFun(rempty), (z, t) => z, z => z) + + if (j == 2 * tr.size - 1) { + scanIndexingState(tr, true, containsAllKeysFun(rempty), (z, t) => z, 0) + } else { + scanIndexingState(tr, true, containsAllKeysFun(rempty), (z, t) => z, j + 1) + scanIndexingNode( + tr, + true, + Map.empty[GlobalKey, KeyMapping], + containsAllKeysFun(rempty), + (z, t) => z, + collectFun, + (z, t) => z, + j, + ) + unfold( + containsAllKeysFun(rempty)( + tr.scan(true, containsAllKeysFun(rempty), (z, t) => z)(j)._1, + tr.scan(true, containsAllKeysFun(rempty), (z, t) => z)(j)._2, + ) + ) + unfold(containsKey(rempty)(tr.scan(true, containsAllKeysFun(rempty), (z, t) => z)(j)._2._2)) + unfold( + containsNodeKey(emptyState(tr))( + tr.scan(true, containsAllKeysFun(rempty), (z, t) => z)(j)._2._2 + ) + ) + tr.scan(true, containsAllKeysFun(rempty), (z, t) => z)(j)._2._2 match { + case a: Node.Action => + unfold(containsActionKey(emptyState(tr))(a)) + unfold(containsOptionKey(emptyState(tr))(a.gkeyOpt)) + a.gkeyOpt match { + case None() => Trivial() + case Some(k) => + unfold(containsKey(emptyState(tr))(k)) + unfold(emptyState(tr)) + + scanIndexingState( + tr, + Map.empty[GlobalKey, KeyMapping], + collectFun, + (z, t) => z, + j + 1, + ) + unfold( + collectFun( + tr.scan(Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z)(j)._1, + tr.scan(Map.empty[GlobalKey, KeyMapping], collectFun, (z, t) => z)(j)._2, + ) + ) + collectTraceProp(tr, j + 1) + MapProperties.submapOfContains(collectTrace(tr)(j + 1)._1, collect(tr), k) + } + case r: Node.Rollback => Trivial() + } + } + } + }.ensuring(containsAllKeys(tr, Right[KeyInputError, State](emptyState(tr)))) + +} diff --git a/daml-lf/verification/utils/AxiomaticMap.scala b/daml-lf/verification/utils/AxiomaticMap.scala new file mode 100644 index 0000000000..05c9cae03a --- /dev/null +++ b/daml-lf/verification/utils/AxiomaticMap.scala @@ -0,0 +1,257 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.lang.{unfold, Option, None, Some, BooleanDecorations} +import stainless.annotation._ +import scala.annotation.targetName +import scala.collection.{Map => ScalaMap} + +import MapProperties._ +import SetProperties._ + +case class Map[K, V](@pure @extern toScala: ScalaMap[K, V]) { + + import MapAxioms._ + + @pure @extern + def get(k: K): Option[V] = ??? + + @pure @opaque + def getOrElse(k: K, d: V): V = get(k).getOrElse(d) + + @pure + def view: Map[K, V] = this + + @pure + def toMap: Map[K, V] = this + + @pure @extern + def preimage(s: Set[V]): Set[K] = ??? + + @pure + @extern + def find(f: ((K, V)) => Boolean): Option[(K, V)] = ??? + @pure + @extern + def foldLeft[B](z: B)(op: (B, (K, V)) => B): B = ??? + @pure @opaque + def preimage(v: V): Set[K] = preimage(Set(v)) + + @pure @extern + def submapOf(m2: Map[K, V]): Boolean = ??? + + @pure @opaque + def apply(k: K): V = { + require(contains(k)) + unfold(contains) + get(k).get + }.ensuring(Some[V](_) == get(k)) + + @pure @opaque + def contains: K => Boolean = get(_).isDefined + + @pure @extern + def concat(m2: Map[K, V]): Map[K, V] = Map(toScala ++ m2.toScala) + @pure @alias + def ++(s2: Map[K, V]): Map[K, V] = concat(s2) + + @pure @opaque + def updated(k: K, v: V): Map[K, V] = { + val res = concat(Map(k -> v)) + + def updatedProperties: Unit = { + unfold(res.contains) + singletonGet(k, v, k) + concatGet(this, Map(k -> v), k) + + // values + singletonValues(k, v) + concatValues(this, Map(k -> v)) + unfold(values.incl(v)) + unionEqualsRight(values, Map(k -> v).values, Set(v)) + SetProperties.subsetOfEqualsTransitivity( + res.values, + values ++ Map(k -> v).values, + values ++ Set(v), + ) + + singletonKeySet(k, v) + concatKeySet(this, Map(k -> v)) + unionEqualsRight(keySet, Map(k -> v).keySet, Set(k)) + SetProperties.equalsTransitivity(res.keySet, keySet ++ Map(k -> v).keySet, keySet ++ Set(k)) + unfold(keySet.incl(k)) + if (!submapOf(res) && (!contains(k) || (get(k) == Some[V](v)))) { + val w = notSubmapOfWitness(this, res) + concatGet(this, Map(k -> v), w) + singletonGet(k, v, w) + } + if (submapOf(res) && contains(k) && (get(k) != Some[V](v))) { + submapOfGet(this, res, k) + } + }.ensuring( + res.contains(k) && + res.get(k) == Some[V](v) && + res.keySet === keySet + k && + res.values.subsetOf(values + v) && + (submapOf(res) == (!contains(k) || (get(k) == Some[V](v)))) + ) + + updatedProperties + res + }.ensuring(res => + res.contains(k) && + res.get(k) == Some[V](v) && + res.keySet === keySet + k && + res.values.subsetOf(values + v) && + (submapOf(res) == (!contains(k) || get(k) == Some[V](v))) + ) + + @pure @opaque + def keySet: Set[K] = preimage(values) + + @pure @extern + def values: Set[V] = ??? + + @pure @targetName("mapValuesArgSingle") + def mapValues[V2](f: V => V2): Map[K, V2] = { + mapValues((k: K) => f) + } + + @pure @opaque @targetName("mapValuesArgPair") + def mapValues[V2](f: K => V => V2): Map[K, V2] = { + map { case (k, v) => k -> f(k)(v) } + } + + @pure + @extern + def map[K2, V2](f: ((K, V)) => (K2, V2)): Map[K2, V2] = Map(toScala.map(f)) + + @pure @extern + def filter(f: ((K, V)) => Boolean): Map[K, V] = Map(toScala.filter(f)) + + @pure + def ===(m2: Map[K, V]): Boolean = { + submapOf(m2) && m2.submapOf(this) + } + + @pure + def =/=(m2: Map[K, V]): Boolean = !(this === m2) + +} + +object Map { + + @pure @extern + def empty[K, V]: Map[K, V] = Map[K, V](ScalaMap[K, V]()) + + @pure @extern + def apply[K, V](p: (K, V)): Map[K, V] = Map[K, V](ScalaMap[K, V](p)) + +} + +object MapAxioms { + + /** Getting from a concatenation is equivalent to getting in the second map and in case of failure in the first one. + */ + @pure @extern + def concatGet[K, V](m1: Map[K, V], m2: Map[K, V], ks: K): Unit = {}.ensuring( + (m1 ++ m2).get(ks) == m2.get(ks).orElse(m1.get(ks)) + ) + + /** Getting in an empty map will always result in a failure + */ + @pure + @extern + def emptyGet[K, V](ks: K): Unit = {}.ensuring(Map.empty[K, V].get(ks) == None[V]()) + + /** Getting from a singleton will be defined if and only if the query is the key of the pair. In this case the returned + * mapping is the value of the pair. + */ + @pure @extern + def singletonGet[K, V](k: K, v: V, ks: K): Unit = {}.ensuring( + Map(k -> v).get(ks) == (if (k == ks) Some(v) else None[V]()) + ) + + /** Getting in a map after having mapped its value is the same as getting in the original map and afterward applying + * the function on the returned mapping + */ + @pure @extern + def mapValuesGet[K, V, V2](m: Map[K, V], f: K => V => V2, k: K): Unit = {}.ensuring( + m.mapValues(f).get(k) == m.get(k).map(f(k)) + ) + + /** If a map is a submap of another one and it contains a given key, then that key is bound to the same mapping in both + * maps + */ + @pure @extern + def submapOfGet[K, V](m1: Map[K, V], m2: Map[K, V], k: K): Unit = { + require(m1.submapOf(m2)) + require(m1.contains(k)) + }.ensuring(m1.get(k) == m2.get(k)) + + /** If a map is not submap of another one then we can exhibit a key in the first map such that their mapping is different + */ + @pure + @extern + def notSubmapOfWitness[K, V](m1: Map[K, V], m2: Map[K, V]): K = { + require(!m1.submapOf(m2)) + ??? : K + }.ensuring(res => m1.contains(res) && (m1.get(res) != m2.get(res))) + + @pure + @extern + def preimageGet[K, V](m: Map[K, V], s: Set[V], k: K): Unit = {}.ensuring( + m.preimage(s).contains(k) == (m.get(k).exists(s.contains)) + ) + + /** If the values of a map contain a value then we can exhibit a key such that its mapping in the map is equal to the + * value + */ + @pure @extern + def valuesWitness[K, V](m: Map[K, V], v: V): K = { + require(m.values.contains(v)) + (??? : K) + }.ensuring(res => m.get(res) == Some[V](v)) + + /** If map contains a key then the values of the map contain its mapping + */ + @pure + @extern + def valuesContains[K, V](m: Map[K, V], k: K): Unit = { + require(m.contains(k)) + }.ensuring(m.values.contains(m(k))) + + /** If find is defined then + */ + @pure + @extern + def findGet[K, V](m: Map[K, V], f: ((K, V)) => Boolean): Unit = { + require(m.find(f).isDefined) + }.ensuring( + m.get(m.find(f).get._1) == Some[V](m.find(f).get._2) && f((m.find(f).get)) + ) + + /** If if there is a value in the map that satisfies a given predicate then find is defined. + */ + @pure + @extern + def findDefined[K, V](m: Map[K, V], f: ((K, V)) => Boolean, k: K, v: V): Unit = { + require(m.get(k) == Some[V](v)) + require(f(k, v)) + }.ensuring( + m.find(f).isDefined + ) + + /** Extensionality axiom + * + * If two maps are submap of each other then their are equal + */ + @pure @extern + def extensionality[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + require(m1 === m2) + }.ensuring(m1 == m2) + +} diff --git a/daml-lf/verification/utils/AxiomaticSet.scala b/daml-lf/verification/utils/AxiomaticSet.scala new file mode 100644 index 0000000000..7219269d2e --- /dev/null +++ b/daml-lf/verification/utils/AxiomaticSet.scala @@ -0,0 +1,287 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.lang._ +import stainless.annotation._ +import scala.collection.{Set => ScalaSet} + +import SetProperties._ + +case class Set[T](@pure @extern toScala: ScalaSet[T]) { + + import SetAxioms._ + + @pure @opaque + def isEmpty: Boolean = this === Set.empty[T] + + @pure @extern + def size: BigInt = toScala.size + + @pure @opaque + def contains: T => Boolean = e => exists(_ == e) + + @pure @alias + def apply(e: T): Boolean = contains(e) + + @pure @extern + def forall(f: T => Boolean): Boolean = toScala.forall(f) + + @pure @extern + def exists(f: T => Boolean): Boolean = toScala.exists(f) + + @pure @opaque + def subsetOf(s2: Set[T]): Boolean = forall(s2.contains) + + @pure @extern + def union(s2: Set[T]): Set[T] = Set(toScala ++ s2.toScala) + @pure @alias + def ++(s2: Set[T]): Set[T] = union(s2) + + @pure @opaque + def incl(e: T): Set[T] = { + val res = union(Set(e)) + + @pure @opaque + def inclProp: Unit = { + unionContains(this, Set(e), e) + singletonContains(e) + subsetOfUnion(this, Set(e)) + singletonSize(e) + if (contains(e)) { + intersectSingleton(this, e) + unionSize(this, Set(e)) + sizeEquals(intersect(Set(e)), Set(e)) + } else { + disjointSingleton(this, e) + unionDisjointSize(this, Set(e)) + } + + if (res.isEmpty) { + isEmptyContains(res, e) + } + }.ensuring( + res.contains(e) && + subsetOf(res) && + (res.size == (if (contains(e)) size else size + 1)) && + !res.isEmpty + ) + inclProp + res + }.ensuring { res => + res.contains(e) && + subsetOf(res) && + (res.size == (if (contains(e)) size else size + 1)) && + !res.isEmpty + } + + @pure @alias + def +(e: T): Set[T] = incl(e) + + @pure @extern + def filter(f: T => Boolean): Set[T] = Set(toScala.filter(f)) + + @pure + @extern + def map[V](f: T => V): Set[V] = Set(toScala.map(f)) + + @pure + @opaque + def diff(s2: Set[T]): Set[T] = { + filterSubsetOf(this, x => !s2.contains(x)) + filter(!s2.contains(_)) + }.ensuring(res => res.subsetOf(this)) + + @pure @alias + def &~(s2: Set[T]): Set[T] = diff(s2) + + @pure @opaque + def remove(e: T): Set[T] = { + diff(Set[T](e)) + }.ensuring(res => res.subsetOf(this)) + + @pure @alias + def -(e: T): Set[T] = remove(e) + + @pure + @opaque + def symDiff(s2: Set[T]): Set[T] = diff(s2) ++ s2.diff(this) + + @pure @opaque + def intersect(s2: Set[T]): Set[T] = union(s2) &~ symDiff(s2) + + @pure @opaque + def disjoint(s2: Set[T]): Boolean = intersect(s2).isEmpty + + @pure @alias + def &(s2: Set[T]): Set[T] = intersect(s2) + + @pure @extern + def witness(f: T => Boolean): T = { + require(exists(f)) + ??? : T + }.ensuring(w => contains(w) && f(w)) + + /** Extensional equality for finite sets + * + * Should not be opaque in order to have automatic symmetry + */ + @pure @nopaque + def ===(s2: Set[T]): Boolean = subsetOf(s2) && s2.subsetOf(this) + + @pure @alias + def =/=(s2: Set[T]): Boolean = !(this === s2) + +} + +object Set { + + @pure @extern + def empty[T]: Set[T] = Set[T](ScalaSet[T]()) + + @pure @extern + def apply[T](e: T): Set[T] = Set[T](ScalaSet[T](e)) + + @pure + def range(a: BigInt, b: BigInt): Set[BigInt] = { + decreases(b - a) + require(a <= b) + if (a == b) { + Set.empty[BigInt] + } else { + Set.range(a + 1, b) + a + } + } + +} + +object SetAxioms { + + /** Size nonnegativity axiom + * + * The size of any set is always non negative + */ + @pure + @extern + def sizePositive[T](s: Set[T]): Unit = {}.ensuring(s.size >= BigInt(0)) + + /** Singleton size axiom + * + * The size of a singleton is 1 + */ + @pure @extern + def singletonSize[T](e: T): Unit = {}.ensuring(Set(e).size == BigInt(1)) + + /** Disjoint union size axiom + * + * The size of a disjoint union is the sum of the sizes of the set + */ + @pure + @extern + def unionDisjointSize[T](s1: Set[T], s2: Set[T]): Unit = { + require(s1.disjoint(s2)) + }.ensuring((s1 ++ s2).size == s1.size + s2.size) + + /** Congruence size axiom + * + * If two sets are equal then their size is also equal + */ + @pure + @extern + def sizeEquals[T](s1: Set[T], s2: Set[T]): Unit = { + require(s1 === s2) + }.ensuring(s1.size == s2.size) + + /** De Morgan's laws for quantifiers + * + * The second one should not be an axiom if we assume that !!f == f + * which Stainless is not able to prove. + */ + @pure @extern @inlineOnce + def notForallExists[T](s: Set[T], f: T => Boolean): Unit = {}.ensuring( + !s.forall(f) == s.exists(!f(_)) + ) + + @pure + @extern + @inlineOnce + def forallNotExists[T](s: Set[T], f: T => Boolean): Unit = {}.ensuring( + !s.forall(!f(_)) == s.exists(f) + ) + + /** Existential quantifier definition + * + * Should not be an axiom if we assume some properties on lambdas that + * Stainless is not able to prove + */ + @pure @extern + def witnessExists[T](s: Set[T], f: T => Boolean, w: T): Unit = { + require(s.contains(w)) + require(f(w)) + }.ensuring(s.exists(f)) + + /** Axiom of empty set + * + * Empty introduction axiom: any predicate on the empty set is valid. Equivalent + * to the ZF empty set axiom stating that there exists an empty such that it + * contains no element. + */ + @pure + @extern + def forallEmpty[T](f: T => Boolean): Unit = {}.ensuring(Set.empty.forall(f)) + + /** Axiom of union + * + * Union introduction axiom: any predicate is true on the union of two sets if + * and only if it is true for both. + */ + @pure @extern + def forallUnion[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = {}.ensuring( + (s1 ++ s2).forall(f) == (s1.forall(f) && s2.forall(f)) + ) + + /** Axiom of filter + * + * Filter introduction axiom: if a set is filtered by a predicate then all the element + * of the set verify this predicate. + */ + @pure + @extern + @inlineOnce + def forallFilter[T](s: Set[T], f: T => Boolean, p: T => Boolean): Unit = {}.ensuring( + s.filter(f).forall(p) == s.forall(x => f(x) ==> p(x)) + ) + + /** Axiom of map + * + * Map introduction axiom: if a set is filtered by a predicate then all the element + * of the set verify this predicate. + */ + @pure + @extern + @inlineOnce + def forallMap[T, V](s: Set[T], f: T => V, p: V => Boolean): Unit = {}.ensuring( + s.map(f).forall(p) == s.forall(f andThen p) + ) + + /** Axiom of singleton + * + * Singleton introduction axiom: a predicate is true on a singleton if and only if + * it is valid for its unique element. Equivalent to an axiom for incl. + */ + @pure @extern + def forallSingleton[T](e: T, f: T => Boolean): Unit = {}.ensuring(Set(e).forall(f) == f(e)) + + /** Extensionality axiom + * + * If two sets are subset of each other then their are equal + */ + @pure + @extern + def extensionality[T](s1: Set[T], s2: Set[T]): Unit = { + require(s1 === s2) + }.ensuring(s1 == s2) + +} diff --git a/daml-lf/verification/utils/GlobalKey.scala b/daml-lf/verification/utils/GlobalKey.scala new file mode 100644 index 0000000000..28ded6a13a --- /dev/null +++ b/daml-lf/verification/utils/GlobalKey.scala @@ -0,0 +1,14 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +case class GlobalKey(hash: BigInt) + +sealed abstract class ContractKeyUniquenessMode extends Product with Serializable + +object ContractKeyUniquenessMode { + case object Off extends ContractKeyUniquenessMode + case object Strict extends ContractKeyUniquenessMode +} diff --git a/daml-lf/verification/utils/Helpers.scala b/daml-lf/verification/utils/Helpers.scala new file mode 100644 index 0000000000..e1a67cc1c1 --- /dev/null +++ b/daml-lf/verification/utils/Helpers.scala @@ -0,0 +1,45 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.annotation._ +import stainless.lang._ +import scala.annotation.{Annotation} + +object Unreachable { + + @opaque + def apply(): Nothing = { + require(false) + ??? + } +} + +object Trivial { + def apply(): Unit = () +} + +@ignore +class nopaque extends Annotation + +@ignore +class alias extends Annotation + +object Either { + + @pure + def cond[A, B](test: Boolean, right: B, left: A): Either[A, B] = { + if (test) Right[A, B](right) else Left[A, B](left) + }.ensuring((res: Either[A, B]) => res.isInstanceOf[Right[A, B]] == test) +} + +object Option { + + def filterNot[T](o: Option[T], p: T => Boolean): Option[T] = + o match { + case Some(v) if !p(v) => o + case _ => None[T]() + } +} diff --git a/daml-lf/verification/utils/InvListProperties.scala b/daml-lf/verification/utils/InvListProperties.scala new file mode 100644 index 0000000000..850014e5f4 --- /dev/null +++ b/daml-lf/verification/utils/InvListProperties.scala @@ -0,0 +1,184 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.collection._ +import stainless.lang._ +import stainless.annotation._ + +object ListProperties { + + extension[T](l: List[T]) { + def bindexOf(e: T): BigInt = l.indexOf(e) + def blength: BigInt = l.length + def bapply(i: BigInt): T = { + require(i >= 0) + require(i < l.length) + l(i) + } + } + + @pure @opaque + def notForallWitness[T](l: List[T], f: T => Boolean): T = { + decreases(l) + require(!l.forall(f)) + l match { + case Nil() => Unreachable() + case Cons(h, t) => if (!f(h)) h else notForallWitness(t, f) + } + }.ensuring(res => l.contains(res) && !f(res)) + + @pure + @opaque + def concatContains[T](@induct l1: List[T], l2: List[T], e: T): Unit = {}.ensuring( + (l1 ++ l2).contains(e) == (l1.contains(e) || l2.contains(e)) + ) + + @pure + @opaque + def forallConcat[T](l1: List[T], l2: List[T], p: T => Boolean): Unit = { + + if ((l1 ++ l2).forall(p)) { + if (!l1.forall(p)) { + val w = notForallWitness(l1, p) + concatContains(l1, l2, w) + forallContains(l1 ++ l2, p, w) + } + if (!l2.forall(p)) { + val w = notForallWitness(l2, p) + concatContains(l1, l2, w) + forallContains(l1 ++ l2, p, w) + } + } + if (l1.forall(p) && l2.forall(p) && !(l1 ++ l2).forall(p)) { + val w = notForallWitness(l1 ++ l2, p) + concatContains(l1, l2, w) + if (l1.contains(w)) forallContains(l1, p, w) else forallContains(l2, p, w) + } + }.ensuring((l1 ++ l2).forall(p) == (l1.forall(p) && l2.forall(p))) + + @pure @opaque + def forallContains[T](l: List[T], f: T => Boolean, e: T): Unit = { + if (l.forall(f) && l.contains(e)) { + ListSpecs.forallContained(l, f, e) + } + }.ensuring((l.forall(f) && l.contains(e)) ==> f(e)) + + @pure + @opaque + def bapplyContains[T](tr: List[T], i: BigInt): Unit = { + decreases(tr) + require(i >= 0) + require(i < tr.blength) + tr match { + case Nil() => Trivial() + case Cons(h, t) => + if (i == 0) { + Trivial() + } else { + bapplyContains(t, i - 1) + } + } + }.ensuring(tr.contains(tr.bapply(i))) + + def isUnique[T](tr: List[T]): Boolean = { + decreases(tr) + tr match { + case Nil() => true + case Cons(h, t) => !t.contains(h) && isUnique(t) + } + } + + @pure + @opaque + def isUniqueIndex[T](tr: List[T], i1: BigInt, i2: BigInt): Unit = { + require(i1 >= 0) + require(i2 >= 0) + require(i1 < tr.blength) + require(i2 < tr.blength) + require(isUnique(tr)) + decreases(tr) + tr match { + case Nil() => Trivial() + case Cons(h, t) => + if ((i1 == 0) && (i2 == 0)) { + Trivial() + } else if (i1 == 0) { + bapplyContains(t, i2 - 1) + } else if (i2 == 0) { + bapplyContains(t, i1 - 1) + } else { + isUniqueIndex(t, i1 - 1, i2 - 1) + } + } + }.ensuring((tr.bapply(i1) == tr.bapply(i2)) == (i1 == i2)) + + @pure @opaque + def bapplyBindexOf[T](l: List[T], e: T): Unit = { + decreases(l) + require(l.contains(e)) + l match { + case Nil() => Unreachable() + case Cons(h, t) => + if (h == e) { + Trivial() + } else { + bapplyBindexOf(t, e) + } + } + }.ensuring( + l.bapply(l.bindexOf(e)) == e + ) + + @pure + @opaque + def bindexOfLast[T](l: List[T], e: T): Unit = { + require(l.bindexOf(e) >= l.blength - 1) + require(!l.isEmpty) + decreases(l) + require(l.contains(e)) + l match { + case Nil() => Unreachable() + case Cons(h, t) => + if (t.isEmpty) { + Trivial() + } else { + bindexOfLast(t, e) + } + } + }.ensuring( + e == l.last + ) + + @pure + def next[T](l: List[T], e: T): T = { + require(l.contains(e)) + require(e != l.last) + if (l.bindexOf(e) >= l.blength - 1) { + bindexOfLast(l, e) + } + l.bapply(l.bindexOf(e) + 1) + }.ensuring(l.contains) + + @pure + @opaque + def concatIndex[T](l1: List[T], l2: List[T], i: BigInt): Unit = { + decreases(l1) + require(i >= 0) + require(i < l1.size + l2.size) + l1 match { + case Nil() => Trivial() + case Cons(h1, t1) => + if (i == 0) { + Trivial() + } else { + concatIndex(t1, l2, i - 1) + } + } + }.ensuring( + (l1 ++ l2)(i) == (if (i < l1.size) l1(i) else l2(i - l1.size)) + ) + +} diff --git a/daml-lf/verification/utils/MapProperties.scala b/daml-lf/verification/utils/MapProperties.scala new file mode 100644 index 0000000000..4a9598dde9 --- /dev/null +++ b/daml-lf/verification/utils/MapProperties.scala @@ -0,0 +1,1104 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.lang.{unfold, None, Some, BooleanDecorations} +import stainless.annotation._ +import scala.annotation.targetName +import MapAxioms._ +import SetProperties._ + +object MapProperties { + + /** --------------------------------------------------------------------------------------------------------------- * + * -------------------------------------------------EMPTY---------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** The empty map does not contain any element + */ + @pure + @opaque + def emptyContains[K, V](k: K): Unit = { + emptyGet[K, V](k) + unfold(Map.empty[K, V].contains) + }.ensuring(!Map.empty[K, V].contains(k)) + + /** The empty map is submap of any other map. + */ + @pure + @opaque + def emptySubmapOf[K, V](m: Map[K, V]): Unit = { + if (!Map.empty[K, V].submapOf(m)) { + val k = notSubmapOfWitness(Map.empty[K, V], m) + emptyContains[K, V](k) + } + }.ensuring(Map.empty[K, V].submapOf(m)) + + /** The set of values of the empty map is empty. + */ + @pure @opaque + def emptyValues[K, V]: Unit = { + if (Map.empty[K, V].values =/= Set.empty[V]) { + val w = SetProperties.notEqualsWitness(Map.empty[K, V].values, Set.empty[V]) + SetProperties.emptyContains(w) + if (Map.empty[K, V].values.contains(w)) { + val k = valuesWitness(Map.empty[K, V], w) + emptyGet[K, V](k) + } + } + unfold(Map.empty[K, V].values.isEmpty) + }.ensuring(Map.empty[K, V].values.isEmpty) + + /** The keyset of the empty map is empty. + */ + @pure + @opaque + def emptyKeySet[K, V]: Unit = { + unfold(Map.empty[K, V].keySet) + emptyValues[K, V] + preimageIsEmpty(Map.empty[K, V], Map.empty[K, V].values) + }.ensuring(Map.empty[K, V].keySet.isEmpty) + + /** If a map is non empty we can exhibit a key it contains + */ + @pure + @opaque + def notEmptyWitness[K, V](m: Map[K, V]): K = { + require(m =/= Map.empty[K, V]) + val k = notEqualsWitness(m, Map.empty[K, V]) + unfold(m.contains) + unfold(Map.empty[K, V].contains) + emptyContains[K, V](k) + k + }.ensuring(k => m.contains(k)) + + /** --------------------------------------------------------------------------------------------------------------- * + * -------------------------------------------------CONCAT--------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** The empty map is a neutral element wrt concatenation. + */ + @pure @opaque + def concatEmpty[K, V](m: Map[K, V]): Unit = { + if (Map.empty[K, V] ++ m =/= m) { + val k = notEqualsWitness(Map.empty[K, V] ++ m, m) + concatGet(Map.empty[K, V], m, k) + emptyGet[K, V](k) + } + if (m ++ Map.empty[K, V] =/= m) { + val k = notEqualsWitness(m ++ Map.empty[K, V], m) + concatGet(m, Map.empty[K, V], k) + emptyGet[K, V](k) + } + }.ensuring( + Map.empty[K, V] ++ m === m && + m ++ Map.empty[K, V] === m + ) + + /** The concatenation of two maps contains a key if and only if one of the two contains it. + */ + @pure @opaque + def concatContains[K, V](m1: Map[K, V], m2: Map[K, V], k: K): Unit = { + unfold((m1 ++ m2).contains) + unfold(m1.contains) + unfold(m2.contains) + concatGet(m1, m2, k) + }.ensuring((m1 ++ m2).contains(k) == (m1.contains(k) || m2.contains(k))) + + /** The keyset of the concatenation of two maps is the union of both keysets. + */ + @pure @opaque + def concatKeySet[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + val unionKeySet = (m1 ++ m2).keySet + val keySetUnion = m1.keySet ++ m2.keySet + + if (unionKeySet =/= keySetUnion) { + val d = SetProperties.notEqualsWitness(unionKeySet, keySetUnion) + unionContains(m1.keySet, m2.keySet, d) + keySetContains(m1 ++ m2, d) + concatContains(m1, m2, d) + keySetContains(m1, d) + keySetContains(m2, d) + } + }.ensuring((m1 ++ m2).keySet === m1.keySet ++ m2.keySet) + + /** The set of values of the concatenation of two maps is a subset of the union of the values of both sets. + */ + @pure + @opaque + def concatValues[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + val unionValues = (m1 ++ m2).values + val valuesUnion = m1.values ++ m2.values + + if (!unionValues.subsetOf(valuesUnion)) { + val d = SetProperties.notSubsetOfWitness(unionValues, valuesUnion) + val k = valuesWitness(m1 ++ m2, d) + concatGet(m1, m2, k) + unfold(m1.contains) + unfold(m2.contains) + if (m1.get(k) == Some[V](d)) { + valuesContains(m1, k) + } else { + valuesContains(m2, k) + } + unionContains(m1.values, m2.values, d) + } + }.ensuring((m1 ++ m2).values.subsetOf(m1.values ++ m2.values)) + + /** If a map is a submap of an other one then concatenating on the right does not change the subset relationship. + */ + @pure @opaque + def concatSubmapOf[K, V](m11: Map[K, V], m12: Map[K, V], m2: Map[K, V]): Unit = { + require(m11.submapOf(m12)) + if (!(m11 ++ m2).submapOf(m12 ++ m2)) { + val w = notSubmapOfWitness(m11 ++ m2, m12 ++ m2) + concatContains(m11, m2, w) + unfold(m2.contains) + if (m11.contains(w)) { + submapOfGet(m11, m12, w) + } + concatGet(m11, m2, w) + concatGet(m12, m2, w) + } + }.ensuring((m11 ++ m2).submapOf(m12 ++ m2)) + + /** If two maps are equal then concatenating them with a third list maintain the equality. + */ + @pure + @opaque + def concatEqualsRight[K, V](m1: Map[K, V], m21: Map[K, V], m22: Map[K, V]): Unit = { + require(m21 === m22) + if ((m1 ++ m21) =/= (m1 ++ m22)) { + val w = notEqualsWitness(m1 ++ m21, m1 ++ m22) + equalsGet(m21, m22, w) + concatGet(m1, m21, w) + concatGet(m1, m22, w) + } + }.ensuring((m1 ++ m21) === (m1 ++ m22)) + + /** Any map is a submap of itself concatenated with an other map on the left. + */ + @pure + @opaque + def concatSubmapOf[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + concatEmpty(m2) + emptySubmapOf(m1) + concatSubmapOf(Map.empty[K, V], m1, m2) + equalsSubmapOfTransitivity(m2, Map.empty[K, V] ++ m2, m1 ++ m2) + }.ensuring(m2.submapOf(m1 ++ m2)) + + /** If the keySet of a map is a subset of the keyset of another one, the second map is equal to the concatenation + * between the first and the second one. + */ + @pure + @opaque + def concatSubsetOfEquals[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + require(m1.keySet.subsetOf(m2.keySet)) + concatSubmapOf(m1, m2) + if (!(m1 ++ m2).submapOf(m2)) { + val k = notSubmapOfWitness(m1 ++ m2, m2) + keySetContains(m1, k) + keySetContains(m2, k) + concatGet(m1, m2, k) + unfold(m1.contains) + unfold(m2.contains) + subsetOfContains(m1.keySet, m2.keySet, k) + } + }.ensuring(m2 === m1 ++ m2) + + /** Map concatenation is idempotent. That is any map is equal to itself concatenated to itself again. + */ + @pure + @opaque + def concatIdempotence[K, V](m: Map[K, V]): Unit = { + subsetOfReflexivity(m.keySet) + concatSubsetOfEquals(m, m) + }.ensuring(m === m ++ m) + + /** Map concatenation is an associative operation. + */ + @pure + @opaque + def concatAssociativity[K, V](m1: Map[K, V], m2: Map[K, V], m3: Map[K, V]): Unit = { + if (((m1 ++ m2) ++ m3) =/= (m1 ++ (m2 ++ m3))) { + val k = notEqualsWitness((m1 ++ m2) ++ m3, m1 ++ (m2 ++ m3)) + concatGet(m1 ++ m2, m3, k) + concatGet(m1, m2, k) + concatGet(m1, m2 ++ m3, k) + concatGet(m2, m3, k) + } + }.ensuring(((m1 ++ m2) ++ m3) === (m1 ++ (m2 ++ m3))) + + /** If two maps have disjoint keysets then their concatenation is commutative. + */ + @pure + @opaque + def concatCommutativity[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + require(m1.keySet.disjoint(m2.keySet)) + if (m1 ++ m2 =/= m2 ++ m1) { + val k = notEqualsWitness(m1 ++ m2, m2 ++ m1) + concatGet(m1, m2, k) + concatGet(m2, m1, k) + keySetContains(m1, k) + keySetContains(m2, k) + disjointContains(m1.keySet, m2.keySet, k) + unfold(m1.contains) + unfold(m2.contains) + } + }.ensuring(m1 ++ m2 === m2 ++ m1) + + /** --------------------------------------------------------------------------------------------------------------- * + * -------------------------------------------------SINGLETON------------------------------------------------------ * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** A singleton contains a key if and only if it is equal to its first argument + */ + @pure @opaque + def singletonContains[K, V](k: K, v: V, ks: K): Unit = { + singletonGet(k, v, ks) + unfold(Map[K, V](k -> v).contains) + }.ensuring(Map[K, V](k -> v).contains(ks) == (k == ks)) + + /** A singleton map contains the key of its unique pair. + */ + @pure + @opaque + def singletonContains[K, V](k: K, v: V): Unit = { + singletonContains(k, v, k) + }.ensuring(Map[K, V](k -> v).contains(k)) + + /** The values of a singelton is the singleton of its second argument. + */ + @pure @opaque + def singletonKeySet[K, V](k: K, v: V): Unit = { + if (Map[K, V](k -> v).keySet =/= Set(k)) { + val d = SetProperties.notEqualsWitness(Map[K, V](k -> v).keySet, Set(k)) + keySetContains(Map(k -> v), d) + singletonContains(k, v, d) + SetProperties.singletonContains(k, d) + } + }.ensuring(Map[K, V](k -> v).keySet === Set(k)) + + /** The values of a singelton is the singleton of its second argument. + */ + @pure + @opaque + def singletonValues[K, V](k: K, v: V): Unit = { + if (Map[K, V](k -> v).values =/= Set[V](v)) { + val d: V = SetProperties.notEqualsWitness(Map[K, V](k -> v).values, Set[V](v)) + SetProperties.singletonContains[V](v, d) + + if (Map[K, V](k -> v).values.contains(d)) { + val ks = valuesWitness(Map(k -> v), d) + singletonGet(k, v, ks) + } else { + singletonGet(k, v, k) + unfold(Map[K, V](k -> v).contains) + valuesContains(Map[K, V](k -> v), k) + } + } + }.ensuring(Map[K, V](k -> v).values === Set[V](v)) + + /** A singleton whose values has been mapped is equal to the singleton to which the function has been applied to the + * second argument. + */ + @pure + @opaque + @targetName("singletonMapValuesArgPair") + def singletonMapValues[K, V, V2](k: K, v: V, f: K => V => V2): Unit = { + if (Map(k -> v).mapValues[V2](f) =/= Map(k -> f(k)(v))) { + val w = notEqualsWitness(Map(k -> v).mapValues[V2](f), Map(k -> f(k)(v))) + singletonGet(k, v, w) + MapAxioms.mapValuesGet(Map(k -> v), f, w) + singletonGet(k, f(k)(v), w) + } + }.ensuring(Map(k -> v).mapValues[V2](f) === Map(k -> f(k)(v))) + + /** A singleton whose values has been mapped is equal to the singleton to which the function has been applied to the + * second argument. + */ + @pure + @opaque + def singletonMapValues[K, V, V2](k: K, v: V, f: V => V2): Unit = { + singletonMapValues(k, v, (k: K) => f) + }.ensuring(Map(k -> v).mapValues[V2](f) === Map(k -> f(v))) + + /** --------------------------------------------------------------------------------------------------------------- * + * ---------------------------------------------------MAP---------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** The mapping of a map after a map operations on its value is the mapping of the original map on which the function + * has been applied. + */ + @pure + @opaque + def mapValuesGet[K, V, V2](m: Map[K, V], f: V => V2, k: K): Unit = { + MapAxioms.mapValuesGet(m, (k: K) => f, k) + }.ensuring(m.mapValues(f).get(k) == m.get(k).map(f)) + + /** A map after a mapValues operation contains a key if and only if the original map also contains the key + */ + @pure + @opaque + @targetName("mapValuesContainsArgPair") + def mapValuesContains[K, V, V2](m: Map[K, V], f: K => V => V2, k: K): Unit = { + MapAxioms.mapValuesGet(m, f, k) + unfold(m.mapValues(f).contains) + unfold(m.contains) + }.ensuring(m.mapValues(f).contains(k) == m.contains(k)) + + /** A map after a mapValues operation contains a key if and only if the original map also contains the key + */ + @pure + @opaque + def mapValuesContains[K, V, V2](m: Map[K, V], f: V => V2, k: K): Unit = { + mapValuesContains(m, (k: K) => f, k) + }.ensuring(m.mapValues(f).contains(k) == m.contains(k)) + + /** The keyset of a map after a map operations on its value as been performed is the same as the keyset of the original + * map. + */ + @pure + @opaque + @targetName("mapValuesKeySetArgPair") + def mapValuesKeySet[K, V, V2](m: Map[K, V], f: K => V => V2): Unit = { + if (m.mapValues(f).keySet =/= m.keySet) { + val d = SetProperties.notEqualsWitness(m.mapValues(f).keySet, m.keySet) + keySetContains(m.mapValues(f), d) + keySetContains(m, d) + mapValuesContains(m, f, d) + Unreachable() + } + }.ensuring(m.mapValues(f).keySet === m.keySet) + + /** The key set after applying a map operation on values does not change the key set. + */ + @pure + @opaque + def mapValuesKeySet[K, V, V2](m: Map[K, V], f: V => V2): Unit = { + mapValuesKeySet(m, (k: K) => f) + }.ensuring(m.mapValues(f).keySet === m.keySet) + + /** Applying a map values operation on a map twice is the same as applying map once with the composition of both functions. + */ + @pure + @opaque + @inlineOnce + def mapValuesAndThen[K, V, V2, V3](m: Map[K, V], f: V => V2, g: V2 => V3): Unit = { + val mapMap = m.mapValues(f).mapValues(g) + val mapAT = m.mapValues(f andThen g) + if (mapMap =/= mapAT) { + val w = notEqualsWitness(mapMap, mapAT) + mapValuesGet(m.mapValues(f), g, w) + mapValuesGet(m, f, w) + mapValuesGet(m, f andThen g, w) + } + + }.ensuring(m.mapValues(f).mapValues(g) === m.mapValues(f andThen g)) + + /** If a map is submap of an other one then applying a map operation on their values does not change the relationship. + */ + @pure @opaque + def mapValuesSubmapOf[K, V, V2](m1: Map[K, V], m2: Map[K, V], f: K => V => V2): Unit = { + require(m1.submapOf(m2)) + if (!m1.mapValues(f).submapOf(m2.mapValues(f))) { + val w = notSubmapOfWitness(m1.mapValues(f), m2.mapValues(f)) + mapValuesContains(m1, f, w) + submapOfGet(m1, m2, w) + MapAxioms.mapValuesGet(m1, f, w) + MapAxioms.mapValuesGet(m2, f, w) + } + }.ensuring(m1.mapValues(f).submapOf(m2.mapValues(f))) + + /** If a map is submap of an other one then applying a map operation on their values does not change the relationship. + */ + @pure + @opaque + @targetName("mapValuesSubmapOfArgPair") + def mapValuesSubmapOf[K, V, V2](m1: Map[K, V], m2: Map[K, V], f: V => V2): Unit = { + require(m1.submapOf(m2)) + mapValuesSubmapOf(m1, m2, (k: K) => f) + }.ensuring(m1.mapValues(f).submapOf(m2.mapValues(f))) + + /** If a map is equals to an other one then applying a map operation on their values does not change the relationship. + */ + @pure + @opaque + def mapValuesEquals[K, V, V2](m1: Map[K, V], m2: Map[K, V], f: K => V => V2): Unit = { + require(m1 === m2) + mapValuesSubmapOf(m1, m2, f) + mapValuesSubmapOf(m2, m1, f) + }.ensuring(m1.mapValues(f) === m2.mapValues(f)) + + /** If a map is equals to an other one then applying a map operation on their values does not change the relationship. + */ + @pure + @opaque + @targetName("mapValuesEqualsArgPair") + def mapValuesEquals[K, V, V2](m1: Map[K, V], m2: Map[K, V], f: V => V2): Unit = { + require(m1 === m2) + mapValuesSubmapOf(m1, m2, f) + mapValuesSubmapOf(m2, m1, f) + }.ensuring(m1.mapValues(f) === m2.mapValues(f)) + + /** Applying a map operations on the values of the emtpy map gives the empty map + */ + @pure + @opaque + def mapValuesEmpty[K, V, V2](f: K => V => V2): Unit = { + if (Map.empty[K, V].mapValues(f) =/= Map.empty[K, V2]) { + val k = notEqualsWitness(Map.empty[K, V].mapValues(f), Map.empty[K, V2]) + emptyGet[K, V2](k) + emptyGet[K, V](k) + MapAxioms.mapValuesGet(Map.empty[K, V], f, k) + } + }.ensuring(Map.empty[K, V].mapValues(f) === Map.empty[K, V2]) + + /** Applying a map operations on the values of the emtpy map gives the empty map + */ + @pure + @opaque + @targetName("mapValuesEmptyArgPair") + def mapValuesEmpty[K, V, V2](f: V => V2): Unit = { + mapValuesEmpty[K, V, V2]((k: K) => f) + }.ensuring(Map.empty[K, V].mapValues(f) === Map.empty[K, V2]) + + /** The set of values of a map after a map operation has been applied to its values is the set of values of the original map + * mapped by the same function. + */ + @pure + @opaque + def mapValuesValues[K, V, V2](m: Map[K, V], f: V => V2): Unit = { + if (m.mapValues[V2](f).values =/= m.values.map[V2](f)) { + val w: V2 = SetProperties.notEqualsWitness(m.mapValues[V2](f).values, m.values.map[V2](f)) + if (m.mapValues[V2](f).values.contains(w)) { + val k: K = valuesWitness(m.mapValues[V2](f), w) + MapProperties.mapValuesGet(m, f, k) + unfold(m.contains) + valuesContains(m, k) + SetProperties.mapContains(m.values, f, m(k)) + } else { + val v: V = SetProperties.mapContainsWitness(m.values, f, w) + val k: K = valuesWitness(m, v) + MapProperties.mapValuesGet(m, f, k) + unfold(m.mapValues[V2](f).contains) + valuesContains(m.mapValues[V2](f), k) + } + } + }.ensuring(m.mapValues[V2](f).values === m.values.map(f)) + + /** Mapping the values of a map concatenation is the same then mapping both maps and then concatenating them. + */ + @pure + @opaque + @targetName("mapValuesConcatArgPair") + def mapValuesConcat[K, V, V2](m1: Map[K, V], m2: Map[K, V], f: K => V => V2): Unit = { + if ((m1 ++ m2).mapValues[V2](f) =/= (m1.mapValues[V2](f) ++ m2.mapValues[V2](f))) { + val k = + notEqualsWitness((m1 ++ m2).mapValues[V2](f), m1.mapValues[V2](f) ++ m2.mapValues[V2](f)) + MapAxioms.mapValuesGet(m1 ++ m2, f, k) + concatGet(m1, m2, k) + MapAxioms.mapValuesGet(m1, f, k) + MapAxioms.mapValuesGet(m2, f, k) + concatGet(m1.mapValues[V2](f), m2.mapValues[V2](f), k) + } + }.ensuring((m1 ++ m2).mapValues[V2](f) === (m1.mapValues[V2](f) ++ m2.mapValues[V2](f))) + + /** Mapping the values of a map concatenation is the same then mapping both maps and then concatenating them. + */ + @pure + @opaque + def mapValuesConcat[K, V, V2](m1: Map[K, V], m2: Map[K, V], f: V => V2): Unit = { + mapValuesConcat(m1, m2, (k: K) => f) + }.ensuring((m1 ++ m2).mapValues[V2](f) === (m1.mapValues[V2](f) ++ m2.mapValues[V2](f))) + + /** Mapping the values of a map with an added pair is the same then mapping the maps and then applying the function + * to the value in the pair and then adding it. + */ + @pure + @opaque + @targetName("mapValuesUpdatedArgPair") + def mapValuesUpdated[K, V, V2](m: Map[K, V], k: K, v: V, f: K => V => V2): Unit = { + unfold(m.updated(k, v)) + unfold(m.mapValues[V2](f).updated(k, f(k)(v))) + mapValuesConcat(m, Map(k -> v), f) + singletonMapValues(k, v, f) + concatEqualsRight(m.mapValues[V2](f), Map(k -> v).mapValues[V2](f), Map(k -> f(k)(v))) + equalsTransitivity( + m.updated(k, v).mapValues[V2](f), + m.mapValues[V2](f) ++ Map(k -> v).mapValues[V2](f), + m.mapValues[V2](f) ++ Map(k -> f(k)(v)), + ) + }.ensuring(m.updated(k, v).mapValues[V2](f) === m.mapValues[V2](f).updated(k, f(k)(v))) + + /** Mapping the values of a map with an added pair is the same then mapping the maps and then applying the function + * to the value in the pair and then adding it. + */ + @pure + @opaque + def mapValuesUpdated[K, V, V2](m: Map[K, V], k: K, v: V, f: V => V2): Unit = { + mapValuesUpdated(m, k, v, (k: K) => f) + }.ensuring(m.updated(k, v).mapValues[V2](f) === m.mapValues[V2](f).updated(k, f(v))) + + /** --------------------------------------------------------------------------------------------------------------- * + * -----------------------------------------------SUBMAP OF-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If a map contains a key and is submap of another map then it also contains this key. + */ + @pure + @opaque + def submapOfContains[K, V](m1: Map[K, V], m2: Map[K, V], k: K): Unit = { + require(m1.submapOf(m2)) + require(m1.contains(k)) + submapOfGet(m1, m2, k) + unfold(m1.contains) + unfold(m2.contains) + }.ensuring(m2.contains(k)) + + /** Every map is submap of itself + */ + @pure + @opaque + def submapOfReflexivity[K, V](m: Map[K, V]): Unit = { + if (!m.submapOf(m)) { + val w = notSubmapOfWitness(m, m) + Unreachable() + } + }.ensuring(m.submapOf(m)) + + /** Submap is a transitive relation. + */ + @pure + @opaque + def submapOfTransitivity[K, V](m1: Map[K, V], m2: Map[K, V], m3: Map[K, V]): Unit = { + require(m1.submapOf(m2)) + require(m2.submapOf(m3)) + if (!m1.submapOf(m3)) { + val d = notSubmapOfWitness(m1, m3) + submapOfContains(m1, m2, d) + submapOfGet(m1, m2, d) + submapOfGet(m2, m3, d) + Unreachable() + } + }.ensuring(m1.submapOf(m3)) + + /** If two maps are equal then the first is submap of a third map if and only if the second also is. + */ + @pure + @opaque + def equalsSubmapOfTransitivity[K, V](m1: Map[K, V], m2: Map[K, V], m3: Map[K, V]): Unit = { + require(m1 === m2) + + if (m1.submapOf(m3)) { + submapOfTransitivity(m2, m1, m3) + } + if (m2.submapOf(m3)) { + submapOfTransitivity(m1, m2, m3) + } + + }.ensuring( + m1.submapOf(m3) == m2.submapOf(m3) + ) + + /** If two maps are equal then the first is a supermap of a third map if and only if the second also is. + */ + @pure + @opaque + def submapOfEqualsTransitivity[K, V](m1: Map[K, V], m2: Map[K, V], m3: Map[K, V]): Unit = { + require(m2 === m3) + + if (m1.submapOf(m2)) { + submapOfTransitivity(m1, m2, m3) + } + if (m1.submapOf(m3)) { + submapOfTransitivity(m1, m3, m2) + } + + }.ensuring( + m1.submapOf(m2) == m1.submapOf(m3) + ) + + /** If a map is submap of another one then its keySets are also subset of each other. + */ + @pure @opaque + def submapOfKeySet[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + require(m1.submapOf(m2)) + + if (!m1.keySet.subsetOf(m2.keySet)) { + val w = SetProperties.notSubsetOfWitness(m1.keySet, m2.keySet) + keySetContains(m1, w) + keySetContains(m2, w) + submapOfContains(m1, m2, w) + } + + }.ensuring(m1.keySet.subsetOf(m2.keySet)) + + /** --------------------------------------------------------------------------------------------------------------- * + * -------------------------------------------------EQUALS--------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If two maps are equal, they contain an element if an only if the other one does as well. + */ + @pure + @opaque + def equalsContains[K, V](m1: Map[K, V], m2: Map[K, V], k: K): Unit = { + require(m1 === m2) + if (m1.contains(k)) + submapOfContains(m1, m2, k) + if (m2.contains(k)) + submapOfContains(m2, m1, k) + }.ensuring(m1.contains(k) == m2.contains(k)) + + /** If two maps are equal then the mapping associated to any of their key is equal as well. + */ + @pure + @opaque + def equalsGet[K, V](m1: Map[K, V], m2: Map[K, V], k: K): Unit = { + require(m1 === m2) + equalsContains(m1, m2, k) + unfold(m1.contains) + unfold(m2.contains) + if (m1.contains(k)) { + submapOfGet(m1, m2, k) + submapOfGet(m2, m1, k) + } + }.ensuring(m1.get(k) == m2.get(k)) + + /** If two maps are not equal then we can exhibit a key such that they mapping differ. + */ + @pure + @opaque + def notEqualsWitness[K, V](m1: Map[K, V], m2: Map[K, V]): K = { + require(m1 =/= m2) + if (!m1.submapOf(m2)) + notSubmapOfWitness(m1, m2) + else + notSubmapOfWitness(m2, m1) + }.ensuring(res => m1.get(res) != m2.get(res)) + + /** Two equal maps have the same keyset. + */ + @pure @opaque + def equalsKeySet[K, V](m1: Map[K, V], m2: Map[K, V]): Unit = { + require(m1 === m2) + val ks1 = m1.keySet + val ks2 = m2.keySet + if (ks1 =/= ks2) { + val d = SetProperties.notEqualsWitness(ks1, ks2) + keySetContains(m1, d) + keySetContains(m2, d) + equalsContains(m1, m2, d) + } + }.ensuring(m1.keySet === m2.keySet) + + /** Map equality is transitive. + */ + @pure + @opaque + def equalsTransitivityStrong[K, V](m1: Map[K, V], m2: Map[K, V], m3: Map[K, V]): Unit = { + require(m2 === m3) + submapOfEqualsTransitivity(m1, m2, m3) + equalsSubmapOfTransitivity(m2, m3, m1) + }.ensuring((m1 === m2) == (m1 === m3)) + + /** Map equality is transitive. + */ + @pure + @opaque + def equalsTransitivity[K, V](m1: Map[K, V], m2: Map[K, V], m3: Map[K, V]): Unit = { + require(m1 === m2) + require(m2 === m3) + equalsTransitivityStrong(m1, m2, m3) + }.ensuring(m1 === m3) + + /** --------------------------------------------------------------------------------------------------------------- * + * ------------------------------------------------KEY SET--------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** A map contains a key if and only if its key set contains the key as well. + */ + @pure + @opaque + def keySetContains[K, V](m: Map[K, V], ks: K): Unit = { + unfold(m.keySet) + MapAxioms.preimageGet(m, m.values, ks) + unfold(m.contains) + if (m.contains(ks)) { + valuesContains(m, ks) + } + }.ensuring(m.contains(ks) == m.keySet.contains(ks)) + + /** If two maps have the same key set then one contains a key if and only if the other does as well. + */ + @pure + @opaque + def equalsKeySetContains[K, V, V2](m1: Map[K, V], m2: Map[K, V2], k: K): Unit = { + require(m1.keySet === m2.keySet) + SetProperties.equalsContains(m1.keySet, m2.keySet, k) + keySetContains(m1, k) + keySetContains(m2, k) + }.ensuring(m1.contains(k) == m2.contains(k)) + + /** --------------------------------------------------------------------------------------------------------------- * + * ---------------------------------------------------UPDATED------------------------------------------------------ * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** A map which has been updated with a new pair contains a key if and only if either the oiriginal map contains the + * key or if it is equal to the first element of the pair. + */ + @pure @opaque + def updatedContains[K, V](m: Map[K, V], k: K, v: V, e: K): Unit = { + keySetContains(m.updated(k, v), e) + SetProperties.equalsContains(m.updated(k, v).keySet, m.keySet + k, e) + keySetContains(m, e) + inclContains(m.keySet, k, e) + }.ensuring(m.updated(k, v).contains(e) == (m.contains(e) || k == e)) + + /** If a pair is added to a map, then the mapping associated to a key other than the one in the pair is the same mapping + * as in the original map. + */ + @pure + @opaque + def updatedGet[K, V](m: Map[K, V], k: K, v: V, k2: K): Unit = { + require(k != k2) + unfold(m.updated(k, v)) + concatGet(m, Map(k -> v), k2) + singletonGet(k, v, k2) + }.ensuring(m.updated(k, v).get(k2) == m.get(k2)) + + /** If a map does not contain the key of a pair, then adding this pair is equivalent to concatenating it on the left. + */ + @pure + @opaque + def updatedCommutativity[K, V](m: Map[K, V], k: K, v: V): Unit = { + require(!m.contains(k)) + unfold(m.updated(k, v)) + singletonKeySet(k, v) + disjointEquals(m.keySet, Set(k), Map[K, V](k -> v).keySet) + disjointSingleton(m.keySet, k) + keySetContains(m, k) + concatCommutativity(m, Map[K, V](k -> v)) + }.ensuring(m.updated(k, v) === Map[K, V](k -> v) ++ m) + + /** --------------------------------------------------------------------------------------------------------------- * + * ------------------------------------------------VALUES-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** The values of a map to which a mapping has been added are a subset of the valeus of the original map to which the + * value of the pair has been added. + */ + @pure + @opaque + def updatedValues[K, V](m: Map[K, V], k: K, v: V): Unit = { + unfold(m.updated(k, v)) + concatValues(m, Map(k -> v)) + singletonValues(k, v) + unionEqualsRight(m.values, Map(k -> v).values, Set(v)) + unfold(m.values + v) + subsetOfTransitivity(m.updated(k, v).values, m.values ++ Map(k -> v).values, m.values ++ Set(v)) + }.ensuring( + m.updated(k, v).values.subsetOf(m.values + v) + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * ------------------------------------------------PREIMAGE-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If a map is submap of an other one then the preimage of the first wrt a set is subset of the preimage of the second. + */ + @pure + @opaque + def preimageSubsetOf[K, V](m1: Map[K, V], m2: Map[K, V], s: Set[V]): Unit = { + require(m1.submapOf(m2)) + if (!m1.preimage(s).subsetOf(m2.preimage(s))) { + val k = notSubsetOfWitness(m1.preimage(s), m2.preimage(s)) + MapAxioms.preimageGet(m1, s, k) + MapAxioms.preimageGet(m2, s, k) + unfold(m1.contains) + submapOfGet(m1, m2, k) + } + }.ensuring(m1.preimage(s).subsetOf(m2.preimage(s))) + + /** If two maps are equal then their preimages wrt a set are also equal. + */ + @pure + @opaque + def preimageEquals[K, V](m1: Map[K, V], m2: Map[K, V], s: Set[V]): Unit = { + require(m1 === m2) + preimageSubsetOf(m1, m2, s) + preimageSubsetOf(m2, m1, s) + }.ensuring(m1.preimage(s) === m2.preimage(s)) + + /** If a map is submap of an other one then the preimage of the first wrt a value is subset of the preimage of the + * second. + */ + @pure + @opaque + def preimageSubsetOf[K, V](m1: Map[K, V], m2: Map[K, V], s: V): Unit = { + require(m1.submapOf(m2)) + unfold(m1.preimage(s)) + unfold(m2.preimage(s)) + preimageSubsetOf(m1, m2, Set(s)) + }.ensuring(m1.preimage(s).subsetOf(m2.preimage(s))) + + /** If two maps are equal then their preimages wrt an element are also equal. + */ + @pure + @opaque + def preimageEquals[K, V](m1: Map[K, V], m2: Map[K, V], s: V): Unit = { + require(m1 === m2) + unfold(m1.preimage(s)) + unfold(m2.preimage(s)) + preimageEquals(m1, m2, Set(s)) + }.ensuring(m1.preimage(s) === m2.preimage(s)) + + /** The preimage of any set wrt a map is a subset of the keyset of that map + */ + @pure @opaque + def preimageKeySet[K, V](m: Map[K, V], s: Set[V]): Unit = { + if (!m.preimage(s).subsetOf(m.keySet)) { + val k = SetProperties.notSubsetOfWitness(m.preimage(s), m.keySet) + MapAxioms.preimageGet(m, s, k) + keySetContains(m, k) + unfold(m.contains) + } + }.ensuring(m.preimage(s).subsetOf(m.keySet)) + + /** The preimage of a set wrt a map is empty if and only if this set is disjoint from the values of the map. + * In particular, if the set is empty this is true. + */ + @pure + @opaque + def preimageIsEmpty[K, V](m: Map[K, V], s: Set[V]): Unit = { + if (m.preimage(s).isEmpty && !s.disjoint(m.values)) { + val v = notDisjointWitness(s, m.values) + val k = valuesWitness(m, v) + MapAxioms.preimageGet(m, s, k) + isEmptyContains(m.preimage(s), k) + } + if (!m.preimage(s).isEmpty && s.disjoint(m.values)) { + val k = SetProperties.notEmptyWitness(m.preimage(s)) + unfold(m.contains) + MapAxioms.preimageGet(m, s, k) + disjointContains(s, m.values, m(k)) + valuesContains(m, k) + } + if (s.isEmpty) { + disjointIsEmpty(s, m.values) + } + }.ensuring( + (s.disjoint(m.values) == m.preimage(s).isEmpty) && + (s.isEmpty ==> m.preimage(s).isEmpty) + ) + + /** The preimage of a map with respect to a value contains a key if and only if the mapping associated to the key in + * the map is equal to the value. + */ + @pure + @extern + def preimageGet[K, V](m: Map[K, V], v: V, k: K): Unit = {}.ensuring( + m.preimage(v).contains(k) == (m.get(k) == Some[V](v)) + ) + + /** If the values of a map do not contain an element, then the preimage wrt that element is empty. + */ + @pure + @opaque + def preimageIsEmpty[K, V](m: Map[K, V], v: V): Unit = { + unfold(m.preimage(v)) + preimageIsEmpty(m, Set(v)) + SetProperties.disjointSingleton(m.values, v) + }.ensuring(!m.values.contains(v) == m.preimage(v).isEmpty) + + /** If the preimage of a value is a singleton then any other key mapping to this value is equal to the key in the preimage. + */ + @pure + @opaque + def preimageSingletonGet[K, V](m: Map[K, V], v: V, k: K, k2: K): Unit = { + require(m.preimage(v) === Set[K](k)) + SetProperties.singletonSubsetOf(m.preimage(v), k) + preimageGet(m, v, k) + if (m.get(k2) == Some[V](v)) { + preimageGet(m, v, k2) + SetProperties.singletonContains(k, k2) + SetProperties.equalsContains(m.preimage(v), Set(k), k2) + } + }.ensuring((m.get(k2) == Some[V](v)) == (k == k2)) + + /** The preimage of a set wrt to a singleton is either the key if the mapping of the singleton is contained in the set + * or the empty set. + */ + @pure + @opaque + def preimageSingleton[K, V](k: K, v: V, s: Set[V]): Unit = { + val m: Map[K, V] = Map(k -> v) + if (m.preimage(s) =/= (if (s.contains(v)) Set(k) else Set.empty[K])) { + val ks = SetProperties.notEqualsWitness[K]( + m.preimage(s), + if (s.contains(v)) Set(k) else Set.empty[K], + ) + MapAxioms.preimageGet(m, s, ks) + SetProperties.emptyContains(ks) + singletonGet(k, v, ks) + SetProperties.singletonContains(k, ks) + } + }.ensuring( + ((m: Map[K, V]) => m.preimage(s) === (if (s.contains(v)) Set(k) else Set.empty[K]))( + Map[K, V](k -> v) + ) + ) + + /** The preimage of a value wrt to a singleton is either the key if the mapping of the singleton is equal to the value + * or the empty set. + */ + @pure + @opaque + def preimageSingleton[K, V](k: K, v: V, s: V): Unit = { + unfold(Map(k -> v).preimage(s)) + SetProperties.singletonContains(s, v) + preimageSingleton(k, v, Set(s)) + }.ensuring(Map(k -> v).preimage(s) === (if (s == v) Set(k) else Set.empty[K])) + + /** The preimage of a concatenation between two maps is a subset of the union of the preimages. + */ + @pure + @opaque + def concatPreimage[K, V](m1: Map[K, V], m2: Map[K, V], s: Set[V]): Unit = { + if (!(m1 ++ m2).preimage(s).subsetOf(m1.preimage(s) ++ m2.preimage(s))) { + val k: K = + SetProperties.notSubsetOfWitness((m1 ++ m2).preimage(s), m1.preimage(s) ++ m2.preimage(s)) + MapAxioms.preimageGet(m1 ++ m2, s, k) + MapAxioms.preimageGet(m1, s, k) + MapAxioms.preimageGet(m2, s, k) + SetProperties.unionContains(m1.preimage(s), m2.preimage(s), k) + concatContains(m1, m2, k) + concatGet(m1, m2, k) + } + }.ensuring((m1 ++ m2).preimage(s).subsetOf(m1.preimage(s) ++ m2.preimage(s))) + + /** The preimage of a concatenation between two maps is a subset of the union of the preimages. + */ + @pure + @opaque + def concatPreimage[K, V](m1: Map[K, V], m2: Map[K, V], s: V): Unit = { + concatPreimage(m1, m2, Set(s)) + unfold(m1.preimage(s)) + unfold(m2.preimage(s)) + unfold((m1 ++ m2).preimage(s)) + }.ensuring((m1 ++ m2).preimage(s).subsetOf(m1.preimage(s) ++ m2.preimage(s))) + + /** The preimageof a set wrt to a map to which a pair has been added is a subset of the preimage of the original map + * to which the mapping of the pair is added. + */ + @pure + @opaque + def inclPreimage[K, V](m: Map[K, V], k: K, v: V, s: Set[V]): Unit = { + val singl: Map[K, V] = Map(k -> v) + unfold(m.updated(k, v)) + unfold(m.preimage(s).incl(k)) + concatPreimage(m, singl, s) + preimageSingleton(k, v, s) + if (s.contains(v)) { + SetProperties.unionEqualsRight(m.preimage(s), singl.preimage(s), Set(k)) + SetProperties.subsetOfEqualsTransitivity( + m.updated(k, v).preimage(s), + m.preimage(s) ++ singl.preimage(s), + m.preimage(s) + k, + ) + } else { + SetProperties.unionEqualsRight(m.preimage(s), singl.preimage(s), Set.empty) + SetProperties.unionEmpty(m.preimage(s)) + SetProperties.equalsTransitivity( + m.preimage(s) ++ singl.preimage(s), + m.preimage(s) ++ Set.empty, + m.preimage(s), + ) + SetProperties.subsetOfEqualsTransitivity( + m.updated(k, v).preimage(s), + m.preimage(s) ++ singl.preimage(s), + m.preimage(s), + ) + SetProperties.subsetOfTransitivity( + m.updated(k, v).preimage(s), + m.preimage(s), + m.preimage(s) + k, + ) + } + }.ensuring( + if (s.contains(v)) + m.updated(k, v).preimage(s).subsetOf(m.preimage(s) + k) + else + m.updated(k, v).preimage(s).subsetOf(m.preimage(s)) && m + .updated(k, v) + .preimage(s) + .subsetOf(m.preimage(s) + k) + ) + + /** The preimageof a set wrt to a map to which a pair has been added is a subset of the preimage of the original map + * to which the mapping of the pair is added. + */ + @pure + @opaque + def inclPreimage[K, V](m: Map[K, V], k: K, v: V, s: V): Unit = { + unfold(m.updated(k, v).preimage(s)) + unfold(m.preimage(s)) + inclPreimage(m, k, v, Set(s)) + SetProperties.singletonContains(s, v) + }.ensuring( + if (s == v) + m.updated(k, v).preimage(s).subsetOf(m.preimage(s) + k) + else + m.updated(k, v).preimage(s).subsetOf(m.preimage(s)) && m + .updated(k, v) + .preimage(s) + .subsetOf(m.preimage(s) + k) + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * ------------------------------------------------FIND------------------------------------------------------------ * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If a map is empty, then find will always return None + */ + @pure + @opaque + def findEmpty[K, V](m: Map[K, V], f: ((K, V)) => Boolean): Unit = { + require(m === Map.empty[K, V]) + if (m.find(f).isDefined) { + findGet(m, f) + unfold(m.contains) + equalsContains(m, Map.empty[K, V], m.find(f).get._1) + emptyContains[K, V](m.find(f).get._1) + } + }.ensuring(!m.find(f).isDefined) + + @pure + @opaque + def findEquals[K, V](m1: Map[K, V], m2: Map[K, V], f: ((K, V)) => Boolean): Unit = { + require(m1 === m2) + if (m1.find(f).isDefined) { + findGet(m1, f) + equalsGet(m1, m2, m1.find(f).get._1) + findDefined(m2, f, m1.find(f).get._1, m1.find(f).get._2) + } else if (m2.find(f).isDefined) { + findGet(m2, f) + equalsGet(m2, m1, m2.find(f).get._1) + findDefined(m1, f, m2.find(f).get._1, m2.find(f).get._2) + } + }.ensuring(m1.find(f).isDefined == m2.find(f).isDefined) + +} diff --git a/daml-lf/verification/utils/Node.scala b/daml-lf/verification/utils/Node.scala new file mode 100644 index 0000000000..56bb2103bf --- /dev/null +++ b/daml-lf/verification/utils/Node.scala @@ -0,0 +1,57 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.lang._ +import stainless.annotation._ +import stainless.proof._ +import stainless.collection._ + +import Value.ContractId + +/** Generic transaction node type for both update transactions and the + * transaction graph. + */ +sealed trait Node + +object Node { + + sealed trait Action extends Node { + def gkeyOpt: Option[GlobalKey] + + def byKey: Boolean + } + + sealed trait LeafOnlyAction extends Action + + final case class Create(coid: ContractId, override val gkeyOpt: Option[GlobalKey]) + extends LeafOnlyAction { + override def byKey: Boolean = false + } + + final case class Fetch( + coid: ContractId, + override val gkeyOpt: Option[GlobalKey], + override val byKey: Boolean, + ) extends LeafOnlyAction + + final case class Exercise( + targetCoid: ContractId, + consuming: Boolean, + children: List[NodeId], + override val gkeyOpt: Option[GlobalKey], + override val byKey: Boolean, + ) extends Action + + final case class LookupByKey(gkey: GlobalKey, result: Option[ContractId]) extends LeafOnlyAction { + override def gkeyOpt: Option[GlobalKey] = Some(gkey) + + override def byKey: Boolean = true + } + + final case class Rollback(children: List[NodeId]) extends Node +} + +final case class NodeId(index: BigInt) diff --git a/daml-lf/verification/utils/SetProperties.scala b/daml-lf/verification/utils/SetProperties.scala new file mode 100644 index 0000000000..2610af6612 --- /dev/null +++ b/daml-lf/verification/utils/SetProperties.scala @@ -0,0 +1,2106 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.lang.{Set => StainlessSet, BooleanDecorations, unfold} +import stainless.lang +import stainless.annotation._ +import SetAxioms._ + +/** ∀ ∅ ⊆ ∈ ∪ ∃ + */ + +object SetProperties { + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------EMPTY------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** The empty set is a subset of any set. + * + * - ∀s. ∅ ⊆ s + */ + @pure + @opaque + def emptySubsetOf[T](s: Set[T]): Unit = { + unfold(Set.empty.subsetOf(s)) + forallEmpty(s.contains) + }.ensuring(Set.empty.subsetOf(s)) + + /** The empty set has size 0. + * + * - |∅| = 0 + */ + @pure + @opaque + def emptySize[T]: Unit = { + disjointEmpty(Set.empty[T]) + unionEmpty(Set.empty[T]) + sizeEquals(Set.empty[T] ++ Set.empty[T], Set.empty[T]) + unionDisjointSize(Set.empty[T], Set.empty[T]) + sizePositive(Set.empty[T]) + }.ensuring(Set.empty[T].size == BigInt(0)) + + /** There is no element satisfying a given predicate in the empty set. + * + * - ∀p. (!∃x ∈ ∅. p(x)) + */ + @pure + @opaque + def emptyExists[T](p: T => Boolean): Unit = { + forallNotExists(Set.empty, p) + forallEmpty[T](!p(_)) + }.ensuring(!Set.empty.exists(p)) + + /** Adding an element to the empty set gives the singleton of that element. + */ + @pure + @opaque + def inclEmpty[T](e: T): Unit = { + unfold(Set.empty.incl(e)) + unionEmpty(Set[T](e)) + equalsReflexivity(Set(e)) + }.ensuring(Set.empty.incl(e) === Set(e)) + + /** The empty set does not contain any element. + * + * - ∀x. x ∉ ∅ + */ + @pure + @opaque + def emptyContains[T](e: T): Unit = { + if (Set.empty(e)) { + singletonSubsetOf(Set.empty, e) + val neq: T => Boolean = _ != e + forallEmpty(neq) + forallSingleton(e, neq) + forallSubsetOf(Set(e), Set.empty, neq) + } + }.ensuring(!Set.empty(e)) + + /** The intersection with the empty set is the empty set (i.e. the empty set is an absorbing element). + * + * - ∀s. s ∩ ∅ = ∅ + * - ∀s. ∅ ∩ s = ∅ + */ + @pure @opaque + def intersectionEmpty[T](s: Set[T]): Unit = { + intersectAltDefinition(Set.empty[T], s) + emptySubsetOf(s) + intersectCommutativity(s, Set.empty[T]) + equalsTransitivity(s & Set.empty[T], Set.empty[T] & s, Set.empty[T]) + }.ensuring( + ((Set.empty[T] & s) === Set.empty[T]) && + ((s & Set.empty[T]) === Set.empty[T]) + ) + + /** Any set is disjoint with the empty set. + */ + @pure + @opaque + def disjointEmpty[T](s: Set[T]): Unit = { + intersectionEmpty(s) + unfold(Set.empty[T].disjoint(s)) + unfold((Set.empty[T] & s).isEmpty) + disjointSymmetry(Set.empty[T], s) + }.ensuring(Set.empty[T].disjoint(s) && s.disjoint(Set.empty[T])) + + /** The empty set is a neutral element with respect to the union. + * + * - ∀s. s ∪ ∅ = ∅ + * - ∀s. ∅ ∪ s = ∅ + */ + @pure + @opaque + def unionEmpty[T](s: Set[T]): Unit = { + unionAltDefinition(Set.empty[T], s) + emptySubsetOf(s) + unionCommutativity(s, Set.empty[T]) + equalsTransitivity(s ++ Set.empty[T], Set.empty[T] ++ s, s) + }.ensuring( + s ++ Set.empty[T] === s && + Set.empty[T] ++ s === s + ) + + /** The image of the empty set with respect to a function is also the empty set. + * + * - ∀f. f[∅] = ∅ + */ + @pure + @opaque + def mapEmpty[T, U](f: T => U): Unit = { + if (Set.empty[T].map(f) =/= Set.empty[U]) { + val w: U = notEqualsWitness(Set.empty[T].map(f), Set.empty[U]) + emptyContains(w) + assert(!Set.empty[U].contains(w)) + if (Set.empty[T].map[U](f).contains(w)) { + val w2 = mapContainsWitness(Set.empty[T], f, w) + emptyContains(w2) + } + } + }.ensuring(Set.empty[T].map(f) === Set.empty[U]) + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------ISEMPTY----------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If two sets are empty then their are equals (i.e. the empty set is unique). + */ + @pure @opaque + def isEmptyEquals[T](s1: Set[T], s2: Set[T]): Unit = { + if (s1.isEmpty && s2.isEmpty) { + isEmptySubsetOf(s1, s2) + isEmptySubsetOf(s2, s1) + } + if (s1 === s2) { + unfold(s1.isEmpty) + unfold(s2.isEmpty) + equalsTransitivityStrong(Set.empty[T], s1, s2) + } + }.ensuring( + ((s1.isEmpty && s2.isEmpty) ==> (s1 === s2)) && + ((s1 === s2) ==> (s1.isEmpty == s2.isEmpty)) + ) + + /** If a set is empty, then it a subset of any other set. + */ + @pure + @opaque + def isEmptySubsetOf[T](s1: Set[T], s2: Set[T]): Unit = { + require(s1.isEmpty) + unfold(s1.isEmpty) + emptySubsetOf(s2) + equalsSubsetOfTransitivity(s1, Set.empty[T], s2) + }.ensuring(s1.subsetOf(s2)) + + /** If a set is empty if and only if it is a subset of the empty set. + */ + @pure + @opaque + def subsetOfEmpty[T](s: Set[T]): Unit = { + unfold(s.isEmpty) + emptySubsetOf(s) + }.ensuring(s.subsetOf(Set.empty) == s.isEmpty) + + /** If a set is empty then it does not contain any element. + */ + @pure @opaque + def isEmptyContains[T](s: Set[T], e: T): Unit = { + require(s.isEmpty) + unfold(s.isEmpty) + equalsContains(s, Set.empty[T], e) + emptyContains(e) + }.ensuring(!s.contains(e)) + + /** If a set is empty for all elements in it any predicate holds. + */ + @pure + @opaque + def forallIsEmpty[T](s: Set[T], p: T => Boolean): Unit = { + require(s.isEmpty) + unfold(s.isEmpty) + forallEquals(s, Set.empty[T], p) + forallEmpty(p) + }.ensuring(s.forall(p)) + + /** If a set is empty then there exist no element for which a given predicate holds. + */ + @pure + @opaque + def existsIsEmpty[T](s: Set[T], p: T => Boolean): Unit = { + require(s.isEmpty) + unfold(s.isEmpty) + existsEquals(s, Set.empty[T], p) + emptyExists(p) + }.ensuring(!s.exists(p)) + + /** If a set is not empty, then we can exhibit an element such that is contained in it. + * + * @return the element contained in the set + */ + @pure @opaque + def notEmptyWitness[T](s: Set[T]): T = { + require(!s.isEmpty) + unfold(s.isEmpty) + val w = notEqualsWitness(s, Set.empty[T]) + emptyContains(w) + w + }.ensuring(s.contains) + + /** If a set is empty then it is disjoint to any other set. + */ + @pure + @opaque + def disjointIsEmpty[T](s1: Set[T], s2: Set[T]): Unit = { + require(s1.isEmpty || s2.isEmpty) + if (!s1.disjoint(s2)) { + val w = notDisjointWitness(s1, s2) + if (s1.isEmpty) { + isEmptyContains(s1, w) + } else { + isEmptyContains(s2, w) + } + } + disjointSymmetry(s1, s2) + }.ensuring(s1.disjoint(s2) && s2.disjoint(s1)) + + /** If a set is empty then it is a neutral element wrt union. + */ + @pure + @opaque + def unionIsEmpty[T](s1: Set[T], s2: Set[T]): Unit = { + unionEmpty(s1) + unionEmpty(s2) + if (s1.isEmpty) { + unfold(s1.isEmpty) + unionEqualsLeft(s1, Set.empty[T], s2) + equalsTransitivity(s1 ++ s2, Set.empty ++ s2, s2) + } + if (s2.isEmpty) { + unfold(s2.isEmpty) + unionEqualsRight(s1, Set.empty[T], s2) + equalsTransitivity(s1 ++ s2, s1 ++ Set.empty, s1) + } + }.ensuring( + (s1.isEmpty ==> (s1 ++ s2 === s2)) && + (s2.isEmpty ==> (s1 ++ s2 === s1)) + ) + + /** If a set is empty and an element is added to it, then the result is equal to the singleton of that element. + */ + @pure + @opaque + def inclIsEmpty[T](s: Set[T], e: T): Unit = { + require(s.isEmpty) + unfold(s.isEmpty) + equalsIncl(s, Set.empty[T], e) + inclEmpty(e) + unfold(s.incl(e)) + unfold(Set.empty[T].incl(e)) + equalsTransitivity(s.incl(e), Set.empty[T].incl(e), Set(e)) + }.ensuring(s.incl(e) === Set(e)) + + /** If a set is empty then its size is equal to 0. + */ + @pure + @opaque + def isEmptySize[T](s: Set[T]): Unit = { + if (s.isEmpty) { + unfold(s.isEmpty) + sizeEquals(s, Set.empty[T]) + emptySize[T] + } else { + val w = notEmptyWitness(s) + singletonSubsetOf(s, w) + singletonSize(w) + subsetOfSize(Set(w), s) + } + }.ensuring(s.isEmpty == (s.size == BigInt(0))) + + /** The empty set is empty. + */ + @pure + @opaque + def emptyIsEmpty[T](): Unit = { + unfold(Set.empty[T].isEmpty) + equalsReflexivity(Set.empty[T]) + }.ensuring(Set.empty[T].isEmpty) + + /** The image of an empty set is also empty. + */ + @pure + @opaque + def mapIsEmpty[T, U](s: Set[T], f: T => U): Unit = { + mapEmpty(f) + unfold(s.map[U](f).isEmpty) + unfold(s.isEmpty) + if (s.isEmpty) { + mapEquals(s, Set.empty[T], f) + equalsTransitivity(s.map[U](f), Set.empty.map[U](f), Set.empty[U]) + } + if (s.map[U](f).isEmpty && !s.isEmpty) { + val w = notEmptyWitness(s) + mapContains(s, f, w) + isEmptyContains(s.map[U](f), f(w)) + } + }.ensuring(s.map[U](f).isEmpty == s.isEmpty) + + /** --------------------------------------------------------------------------------------------------------------- * + * -----------------------------------------------SINGLETON-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** A singleton contains an element if and only if the element of the singleton is equal to the latter. + */ + @pure @opaque + def singletonContains[T](e: T, o: T): Unit = { + unfold(Set[T](e).contains) + val eq: T => Boolean = _ == o + val neq: T => Boolean = !eq(_) + forallNotExists(Set(e), eq) + forallSingleton(e, neq) + }.ensuring(Set[T](e).contains(o) == (e == o)) + + /** A singleton contains its own element. + */ + @pure @opaque + def singletonContains[T](e: T): Unit = { + singletonContains(e, e) + }.ensuring(Set[T](e).contains(e)) + + /** There exists an element satisfying a predicate in a singleton if and only if its element satisfies the predicate. + */ + @pure + @opaque + def singletonExists[T](e: T, f: T => Boolean): Unit = { + forallNotExists(Set[T](e), f) + forallSingleton(e, !f(_)) + }.ensuring(Set[T](e).exists(f) == f(e)) + + /** A set contains an element if and only if its singleton is subset of the set. + */ + @pure @opaque + def singletonSubsetOf[T](s: Set[T], e: T): Unit = { + unfold(Set[T](e).subsetOf(s)) + forallSingleton(e, s.contains) + }.ensuring(s.contains(e) == Set[T](e).subsetOf(s)) + + /** Two elements are the same if and only if the singleton of the first is subset of the singleton of the other. + */ + @pure + @opaque + def twoSingletonSubsetOf[T](a: T, b: T): Unit = { + singletonSubsetOf(Set(b), a) + singletonContains(b, a) + }.ensuring((a == b) == Set[T](a).subsetOf(Set[T](b))) + + /** Two elements are the same if and only if the singleton of the first is equal of the singleton of the other. + */ + @pure + @opaque + def twoSingletonEquals[T](a: T, b: T): Unit = { + twoSingletonSubsetOf(a, b) + twoSingletonSubsetOf(b, a) + }.ensuring((a == b) == (Set[T](a) === Set[T](b))) + + /** If a set is a subset of a singleton then either they are equal or the set is empty. + */ + @pure + @opaque + def subsetOfSingleton[T](s: Set[T], e: T): Unit = { + + // ==> + if (s.isEmpty) { + isEmptySubsetOf(s, Set[T](e)) + assert(s.subsetOf(Set[T](e))) + } + + // <== + if (s.subsetOf(Set[T](e))) { + subsetOfSize(s, Set[T](e)) + singletonSize(e) + sizePositive(s) + + isEmptySize[T](s) + + if (s.size == BigInt(1)) { + val w = sizeOneWitness(s) + equalsSubsetOfTransitivity(Set[T](w), s, Set[T](e)) + twoSingletonSubsetOf(w, e) + } + } + }.ensuring((s.isEmpty || s === Set[T](e)) == s.subsetOf(Set[T](e))) + + /** Being disjoint from a singleton is equivalent to not containing its element + */ + @pure @opaque + def disjointSingleton[T](s: Set[T], e: T): Unit = { + if (s.disjoint(Set(e))) { + disjointContains(s, Set(e), e) + singletonContains(e) + } else { + val w = notDisjointWitness(s, Set(e)) + singletonContains(e, w) + } + disjointSymmetry(s, Set(e)) + }.ensuring( + (s.disjoint(Set(e)) == !s.contains(e)) && + (Set(e).disjoint(s) == !s.contains(e)) + ) + + /** If two singleton are disjoint then their element are different + */ + @pure + @opaque + def disjointTwoSingleton[T](a: T, b: T): Unit = { + singletonContains(a, b) + disjointSingleton(Set(a), b) + }.ensuring(Set(a).disjoint(Set(b)) == (a != b)) + + /** A set contains an element if and only if the intersection with its singleton is the singleton itself. + */ + @pure + @opaque + def intersectSingleton[T](s: Set[T], e: T): Unit = { + intersectAltDefinition(Set(e), s) + singletonSubsetOf(s, e) + }.ensuring( + (s.intersect(Set(e)) === Set(e)) == s.contains(e) + ) + + /** If the size of a set is positive then it has an element and we can exhibit it. + */ + @pure + @opaque + def sizePositiveWitness[T](s: Set[T]): T = { + require(s.size >= BigInt(1)) + isEmptySize(s) + notEmptyWitness(s) + }.ensuring(res => s.contains(res)) + + /** If the size of a set is equal to one then it is a singleton. Furthermore we can exhibit its unique element. + */ + @pure + @opaque + def sizeOneWitness[T](s: Set[T]): T = { + require(s.size == BigInt(1)) + val res = sizePositiveWitness(s) + singletonSubsetOf(s, res) + if (!s.subsetOf(Set(res))) { + val w = notSubsetOfWitness(s, Set(res)) + + // (Set(res) ++ Set(w)).size == 2 + singletonContains(res, w) + disjointTwoSingleton(res, w) + unionDisjointSize(Set(res), Set(w)) + singletonSize(res) + singletonSize(w) + + // s.size >= 2 + unionSubsetOf(Set(res), Set(w), s) + singletonSubsetOf(s, res) + singletonSubsetOf(s, w) + subsetOfSize(Set(res) ++ Set(w), s) + } + + res + }.ensuring(res => s === Set[T](res)) + + /** If the union of two sets is equal to a singleton then either: + * - Both sets are equal to this singleton + * - A set is equal to this singleton and the other one is empty + */ + @pure @opaque + def unionEqualsSingleton[T](s1: Set[T], s2: Set[T], e: T): Unit = { + require((s1 ++ s2) === Set(e)) + unionSubsetOf(s1, s2, Set(e)) + subsetOfSingleton(s1, e) + subsetOfSingleton(s2, e) + unionIsEmpty(s1, s2) + if (s1.isEmpty) { + equalsTransitivity(Set(e), s1 ++ s2, s2) + } + if (s2.isEmpty) { + equalsTransitivity(Set(e), s1 ++ s2, s1) + } + }.ensuring( + (s1.isEmpty && s2 === Set(e)) || + (s1 === Set(e) && s2.isEmpty) || + (s1 === Set(e) && s2 === Set(e)) + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------SUBSET OF--------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If a set is not a subset of an other one then we can exhibit an element that is in the former but not in the + * latter. + */ + @pure @opaque + def notSubsetOfWitness[T](s1: Set[T], s2: Set[T]): T = { + require(!s1.subsetOf(s2)) + unfold(s1.subsetOf(s2)) + val d = notForallWitness(s1, s2.contains) + d + }.ensuring(res => s1.contains(res) && !s2.contains(res)) + + /** Pingeonhole principle + * + * If a set is bigger then an other one we can find an element that is in the former but not in the latter. + */ + @pure + @opaque + def pigeonhole[T](s1: Set[T], s2: Set[T]): T = { + require(s1.size > s2.size) + intersectSize(s1, s2) + diffSize(s1, s2) + isEmptySize(s1 &~ s2) + val res = notEmptyWitness(s1 &~ s2) + diffContains(s1, s2, res) + res + }.ensuring(res => s1.contains(res) && !s2.contains(res)) + + /** If two sets are not equal then we can find an element that is in one of the two sets but not the other one. + */ + @pure + @opaque + def notEqualsWitness[T](s1: Set[T], s2: Set[T]): T = { + require(s1 =/= s2) + val d = + if (!s1.subsetOf(s2)) + notSubsetOfWitness(s1, s2) + else + notSubsetOfWitness(s2, s1) + d + }.ensuring(res => s1(res) != s2(res)) + + /** If a set contains an element but a second one does not contain it then the former cannot be a subset of the + * latter. + */ + @pure @opaque + def isDifferentiatedByNotSubset[T](s1: Set[T], s2: Set[T], d: T): Unit = { + require(s1.contains(d)) + require(!s2.contains(d)) + + if (s1.subsetOf(s2)) { + subsetOfContains(s1, s2, d) + Unreachable() + } + }.ensuring( + !s1.subsetOf(s2) + ) + + /** All the elements of a subset belong to the superset. + * + * - ∀s1,s2. s1 ⊆ s2 => (∀x. x ∈ s1 => x ∈ s2) + */ + @pure + @opaque + def subsetOfContains[T](s1: Set[T], s2: Set[T], e: T): Unit = { + require(s1.subsetOf(s2)) + unfold(s1.subsetOf(s2)) + if (s1(e)) { + forallContains(s1, s2.contains, e) + } + }.ensuring( + (s1(e) ==> s2(e)) && + (!s2(e) ==> !s1(e)) + ) + + /** The size of a subset is smaller then the size of its superset. + * + * - ∀s1,s2. s1 ⊆ s2 => |s1| ≤ |s2| + */ + @pure @opaque + def subsetOfSize[T](s1: Set[T], s2: Set[T]): Unit = { + require(s1.subsetOf(s2)) + unionAltDefinition(s1, s2) + unionSize(s1, s2) + sizeEquals(s1 ++ s2, s2) + }.ensuring(s1.size <= s2.size) + + /** Every set is subset to itself (i.e. subsetOf is a reflexive relation). + * + * - ∀s. s ⊆ s + */ + @pure @opaque + def subsetOfReflexivity[T](s: Set[T]): Unit = { + if (!s.subsetOf(s)) { + val d = notSubsetOfWitness(s, s) + Unreachable() + } + }.ensuring(s.subsetOf(s)) + + /** SubsetOf is a transitive relation. + * + * - ∀s1,s2,s3. s1 ⊆ s2 /\ s2 ⊆ s3 => s1 ⊆ s3 + */ + @pure @opaque + def subsetOfTransitivity[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + require(s1.subsetOf(s2)) + require(s2.subsetOf(s3)) + if (!s1.subsetOf(s3)) { + val d = notSubsetOfWitness(s1, s3) + subsetOfContains(s1, s2, d) + subsetOfContains(s2, s3, d) + Unreachable() + } + }.ensuring(s1.subsetOf(s3)) + + /** If a set is equal to another then the former is a subset of a third set if and only if the latter also is. + * + * - ∀s1,s2,s3. s1 = s2 => (s1 ⊆ s3 <=> s2 ⊆ s3) + */ + @pure + @opaque + def equalsSubsetOfTransitivity[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + require(s1 === s2) + + if (s1.subsetOf(s3)) { + subsetOfTransitivity(s2, s1, s3) + } + if (s2.subsetOf(s3)) { + subsetOfTransitivity(s1, s2, s3) + } + + }.ensuring( + s1.subsetOf(s3) == s2.subsetOf(s3) + ) + + /** If a set is equal to another then the former is a superset of a third set if and only if the latter also is. + * + * - ∀s1,s2,s3. s2 = s3 => (s1 ⊆ s2 <=> s1 ⊆ s3) + */ + @pure + @opaque + def subsetOfEqualsTransitivity[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + require(s2 === s3) + + if (s1.subsetOf(s2)) { + subsetOfTransitivity(s1, s2, s3) + } + if (s1.subsetOf(s3)) { + subsetOfTransitivity(s1, s3, s2) + } + + }.ensuring( + s1.subsetOf(s2) == s1.subsetOf(s3) + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------EQUALS------------------------------------------------------ * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If two equals are equals then an element belongs to the former if and only if it belongs to the latter (ie. set + * equality is congruent wrt contains). + * + * - ∀s1,s2. s1 = s2 => (∀x. x ∈ s1 <=> x ∈ s2) + */ + @pure + @opaque + def equalsContains[T](s1: Set[T], s2: Set[T], e: T): Unit = { + require(s1 === s2) + subsetOfContains(s1, s2, e) + subsetOfContains(s2, s1, e) + }.ensuring( + s1(e) == s2(e) + ) + + /** Set equality is reflexive + * + * - ∀s. s = s + */ + @pure + @opaque + def equalsReflexivity[T](s: Set[T]): Unit = { + subsetOfReflexivity(s) + }.ensuring( + s === s + ) + + /** Set equality is transitive + * + * - ∀s1,s2,s3. s2 = s3 => (s1 = s2 <=> s1 = s3) + */ + @pure + @opaque + def equalsTransitivityStrong[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + require(s2 === s3) + subsetOfEqualsTransitivity(s1, s2, s3) + equalsSubsetOfTransitivity(s2, s3, s1) + }.ensuring((s1 === s2) == (s1 === s3)) + + /** Set equality is transitive + * + * - ∀s1,s2,s3. s1 = s2 /\ s2 = s3 => s1 = s3 + */ + @pure + @opaque + def equalsTransitivity[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + require(s1 === s2) + require(s2 === s3) + equalsTransitivityStrong(s1, s2, s3) + }.ensuring(s1 === s3) + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------UNION------------------------------------------------------ * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If an union between two sets contains an element then at least one of the two sets contain the element. + * + * - ∀s1,s2,x. x ∈ s1 ∪ s2 => (x ∈ s1 \/ x ∈ s2) + */ + @pure + @opaque + def unionContains[T](s1: Set[T], s2: Set[T], e: T): Unit = { + unfold((s1 ++ s2).contains) + unfold(s1.contains) + unfold(s2.contains) + val eq: T => Boolean = _ == e + val neq: T => Boolean = !eq(_) + forallNotExists(s1 ++ s2, eq) + forallNotExists(s1, eq) + forallNotExists(s2, eq) + forallUnion(s1, s2, neq) + }.ensuring((s1 ++ s2)(e) == (s1(e) || s2(e))) + + /** Alternative definition of the subset relation. + * + * A set is subset of another one if and only if the union of both is equal to the latter. + */ + @pure + @opaque + def unionAltDefinition[T](s1: Set[T], s2: Set[T]): Unit = { + unionSubsetOf(s1, s2, s2) + subsetOfReflexivity(s2) + subsetOfUnion(s1, s2) + unionCommutativity(s1, s2) + equalsTransitivityStrong(s2, s2 ++ s1, s1 ++ s2) + }.ensuring( + (s1.subsetOf(s2) == (s1 ++ s2 === s2)) && + (s1.subsetOf(s2) == (s2 ++ s1 === s2)) + ) + + /** The union between a set and itself gives the original set (i.e. the union is an idempotent operator). + * + * - ∀s. s ∪ s = s + */ + @pure + @opaque + def unionIdempotence[T](s: Set[T]): Unit = { + subsetOfReflexivity(s) + unionAltDefinition(s, s) + }.ensuring(s ++ s === s) + + /** If there exists in the union of two sets, an element satisfying a given predicate then the predicate is either + * satisfied by an element in the first set or the second one. + * + * - ∀s1,s2,f. (∃x. x ∈ s1 ∪ s2 /\ f(x)) <=> ((∃x. x ∈ s1 /\ f(x)) \/ (∃x. x ∈ s2 /\ f(x))) + */ + @pure + @opaque + def unionExists[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + SetAxioms.forallNotExists(s1, f) + SetAxioms.forallNotExists(s2, f) + SetAxioms.forallNotExists(s1 ++ s2, f) + SetAxioms.forallUnion(s1, s2, !f(_)) + }.ensuring((s1 ++ s2).exists(f) == (s1.exists(f) || s2.exists(f))) + + /** The union between two sets is subset of a third set if and only if both are subset of the latter. + * + * - ∀s1,s2,s3. s1 ∪ s2 ⊆ s3 <=> (s1 ⊆ s3 /\ s2 ⊆ s3) + */ + @pure + @opaque + def unionSubsetOf[T](s1: Set[T], s2: Set[T], s: Set[T]): Unit = { + unfold((s1 ++ s2).subsetOf(s)) + unfold(s1.subsetOf(s)) + unfold(s2.subsetOf(s)) + forallUnion(s1, s2, s.contains) + }.ensuring((s1 ++ s2).subsetOf(s) == (s1.subsetOf(s) && s2.subsetOf(s))) + + /** Any set is subset of its union with an other set. + * + * - ∀s1,s2. s1 ⊆ s1 ∪ s2 /\ s2 ⊆ s1 ∪ s2 + */ + @pure @opaque + def subsetOfUnion[T](s1: Set[T], s2: Set[T]): Unit = { + if (!s1.subsetOf(s1 ++ s2)) { + val d = notSubsetOfWitness(s1, s1 ++ s2) + unionContains(s1, s2, d) + } + if (!s2.subsetOf(s1 ++ s2)) { + val d = notSubsetOfWitness(s2, s1 ++ s2) + unionContains(s1, s2, d) + } + }.ensuring( + s1.subsetOf(s1 ++ s2) && + s2.subsetOf(s1 ++ s2) + ) + + /** If a set is subset of another one it is also a subset of the union between latter and any third set. + * + * - ∀s,s1,s2. s ⊆ s1 \/ s ⊆ s2 => s ⊆ s1 ∪ s2 + */ + @pure + @opaque + def subsetOfUnion[T](s: Set[T], s1: Set[T], s2: Set[T]): Unit = { + require(s.subsetOf(s1) || s.subsetOf(s2)) + subsetOfUnion(s1, s2) + if (s.subsetOf(s1)) { + subsetOfTransitivity(s, s1, s1 ++ s2) + } else { + subsetOfTransitivity(s, s2, s1 ++ s2) + } + }.ensuring( + s.subsetOf(s1 ++ s2) + ) + + /** If a set is subset of another then its union with a third set is also subset of the union of the second set. + */ + @pure + @opaque + def unionSubsetOfRight[T](s: Set[T], s1: Set[T], s2: Set[T]): Unit = { + require(s1.subsetOf(s2)) + subsetOfReflexivity(s) + unionSubsetOf(s, s1, s ++ s2) + subsetOfUnion(s, s, s2) + subsetOfUnion(s1, s, s2) + }.ensuring((s ++ s1).subsetOf(s ++ s2)) + + /** If two sets are equal then their union with a third set is equal as well. + */ + @pure + @opaque + def unionEqualsRight[T](s: Set[T], s1: Set[T], s2: Set[T]): Unit = { + require(s1 === s2) + equalsReflexivity(s) + unionEquals(s, s, s1, s2) + }.ensuring((s ++ s1) === (s ++ s2)) + + /** If a set is subset of another then its union with a third set is also subset of the union of the second set. + */ + @pure + @opaque + def unionSubsetOfLeft[T](s1: Set[T], s2: Set[T], s: Set[T]): Unit = { + require(s1.subsetOf(s2)) + subsetOfReflexivity(s) + unionSubsetOf(s1, s, s2 ++ s) + subsetOfUnion(s, s2, s) + subsetOfUnion(s1, s2, s) + }.ensuring((s1 ++ s).subsetOf(s2 ++ s)) + + /** If two sets are equal then their union with a third set is equal as well. + */ + @pure + @opaque + def unionEqualsLeft[T](s1: Set[T], s2: Set[T], s: Set[T]): Unit = { + require(s1 === s2) + equalsReflexivity(s) + unionEquals(s1, s2, s, s) + }.ensuring((s1 ++ s) === (s2 ++ s)) + + /** If two pairs of set are subset one of the other, then their unions are also subset of the other. + */ + @pure + @opaque + def unionSubsetOf[T](s11: Set[T], s12: Set[T], s21: Set[T], s22: Set[T]): Unit = { + require(s11.subsetOf(s12)) + require(s21.subsetOf(s22)) + unionSubsetOfLeft(s11, s12, s21) + unionSubsetOfRight(s12, s21, s22) + subsetOfTransitivity(s11 ++ s21, s12 ++ s21, s12 ++ s22) + }.ensuring((s11 ++ s21).subsetOf(s12 ++ s22)) + + /** If two pairs of set are equal, then their unions also are. + */ + @pure + @opaque + def unionEquals[T](s11: Set[T], s12: Set[T], s21: Set[T], s22: Set[T]): Unit = { + require(s11 === s12) + require(s21 === s22) + unionSubsetOf(s11, s12, s21, s22) + unionSubsetOf(s12, s11, s22, s21) + }.ensuring(s11 ++ s21 === s12 ++ s22) + + /** Union is a commutative operation. + */ + @pure + @opaque + def unionCommutativity[T](s1: Set[T], s2: Set[T]): Unit = { + unionSubsetOf(s1, s2, s2 ++ s1) + subsetOfUnion(s2, s1) + unionSubsetOf(s2, s1, s1 ++ s2) + subsetOfUnion(s1, s2) + }.ensuring(s1 ++ s2 === s2 ++ s1) + + /** Union is an associative operation. + */ + @pure + @opaque + def unionAssociativity[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + // (s1 ++ s2) ++ s3 < s1 ++ (s2 ++ s3) + unionSubsetOf(s1 ++ s2, s3, s1 ++ (s2 ++ s3)) + unionSubsetOf(s1, s2, s1 ++ (s2 ++ s3)) + subsetOfUnion(s2, s3) + subsetOfUnion(s1, s2 ++ s3) + subsetOfTransitivity(s2, s2 ++ s3, s1 ++ (s2 ++ s3)) + subsetOfTransitivity(s3, s2 ++ s3, s1 ++ (s2 ++ s3)) + + // (s1 ++ s2) ++ s3 > s1 ++ (s2 ++ s3) + unionSubsetOf(s1, s2 ++ s3, (s1 ++ s2) ++ s3) + unionSubsetOf(s2, s3, (s1 ++ s2) ++ s3) + subsetOfUnion(s1, s2) + subsetOfUnion(s1 ++ s2, s3) + subsetOfTransitivity(s1, s1 ++ s2, (s1 ++ s2) ++ s3) + subsetOfTransitivity(s2, s1 ++ s2, (s1 ++ s2) ++ s3) + + }.ensuring((s1 ++ s2) ++ s3 === s1 ++ (s2 ++ s3)) + + /** The union between two sets is emtpy if and only if both are empty. + */ + @pure + @opaque + def isEmptyUnion[T](s1: Set[T], s2: Set[T]): Unit = { + subsetOfEmpty(s1 ++ s2) + subsetOfEmpty(s1) + subsetOfEmpty(s2) + unionSubsetOf(s1, s2, Set.empty[T]) + }.ensuring((s1 ++ s2).isEmpty == (s1.isEmpty && s2.isEmpty)) + + /** The union of two sets is equal to the union between the first one and the difference between the second and the + * first one. + */ + @pure + @opaque + def unionDiffDef[T](s1: Set[T], s2: Set[T]): Unit = { + if ((s1 ++ s2) =/= (s1.diff(s2) ++ s2)) { + val w = notEqualsWitness(s1 ++ s2, s1.diff(s2) ++ s2) + unionContains(s1, s2, w) + unionContains(s1.diff(s2), s2, w) + diffContains(s1, s2, w) + } + if ((s1 ++ s2) =/= (s1 ++ s2.diff(s1))) { + val w = notEqualsWitness(s1 ++ s2, s1 ++ s2.diff(s1)) + unionContains(s1, s2, w) + unionContains(s1, s2.diff(s1), w) + diffContains(s2, s1, w) + } + }.ensuring( + ((s1 ++ s2) === (s1.diff(s2) ++ s2)) && + ((s1 ++ s2) === (s1 ++ s2.diff(s1))) + ) + + /** Inclusion-exclusion principle + * + * The size of the union between two sets is greater or equal than the size of both sets but smaller or equal than + * the sum of the sizes. It is equal to the sum of the size of both sets minus the size of the intersection. + */ + @pure @opaque + def unionSize[T](s1: Set[T], s2: Set[T]): Unit = { + + // (s1 ++ s2).size >= s2.size + unionDiffDef(s1, s2) + sizeEquals(s1 ++ s2, s1.diff(s2) ++ s2) + diffDisjoint(s1, s2) + unionDisjointSize(s1.diff(s2), s2) + sizePositive(s1.diff(s2)) + + // (s1 ++ s2).size >= s1.size + sizeEquals(s1 ++ s2, s1 ++ s2.diff(s1)) + diffDisjoint(s2, s1) + disjointSymmetry(s2.diff(s1), s1) + unionDisjointSize(s1, s2.diff(s1)) + sizePositive(s2.diff(s1)) + + // (s1 ++ s2).size <= s1.size + s2.size + diffSize(s1, s2) + sizePositive(s1 & s2) + + }.ensuring( + (s1 ++ s2).size <= s1.size + s2.size && + ((s1 ++ s2).size == s1.size + s2.size - (s1 & s2).size) && + ((s1 ++ s2).size >= s2.size) && + ((s1 ++ s2).size >= s1.size) + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------INCL-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If a set is subset of an other one then adding a value to the second set or on both sides does not change the + * subset relationship. Furthermore, if one the two sets already contains the value, adding it to the first set + * also does not change the relationship. + */ + @pure + @opaque + def subsetOfIncl[T](s1: Set[T], s2: Set[T], e: T): Unit = { + require(s1.subsetOf(s2)) + subsetOfTransitivity(s1, s2, s2 + e) + unfold(s1.incl(e)) + unfold(s2.incl(e)) + unionSubsetOfLeft(s1, s2, Set(e)) + unionSubsetOf(s1, Set(e), s2) + if (s1(e)) { + subsetOfContains(s1, s2, e) + } + if (s1(e) || s2(e)) { + singletonSubsetOf(s2, e) + } + + }.ensuring( + s1.subsetOf(s2 + e) && + (s1 + e).subsetOf(s2 + e) && + ((s1(e) || s2(e)) ==> (s1 + e).subsetOf(s2)) + ) + + /** If two sets are equal then they stay equal when adding a value on both sides. + * Futhermore if one of the sets contains the value, this also holds when adding it on + * one side only. + */ + @pure + @opaque + def equalsIncl[T](s1: Set[T], s2: Set[T], e: T): Unit = { + require(s1 === s2) + subsetOfIncl(s1, s2, e) + subsetOfIncl(s2, s1, e) + }.ensuring( + s1 + e === s2 + e && + ((s1(e) || s2(e)) ==> (s1 + e === s2 && s1 === s2 + e)) + ) + + /** A predicate holds for all elements of a set for which a value has been added to it if and only if it holds for all + * elements of the set and for the value. + */ + @pure + @opaque + def forallIncl[T](s: Set[T], e: T, p: T => Boolean): Unit = { + unfold(s.incl(e)) + forallSingleton(e, p) + forallUnion(s, Set[T](e), p) + }.ensuring((s + e).forall(p) == (s.forall(p) && p(e))) + + /** There exists an element satisfying a predicate in a set with a value added to it if an only if the predicate is + * satisfied for an element in the original set or if is true for the value. + */ + @pure + @opaque + def inclExists[T](s: Set[T], e: T, f: T => Boolean): Unit = { + SetAxioms.forallNotExists(s, f) + SetAxioms.forallNotExists(s + e, f) + forallIncl(s, e, !f(_)) + }.ensuring((s + e).exists(f) == (s.exists(f) || f(e))) + + /** A set with a value added to it contains an element if an only if the set contains this element or if it is equal + * to the value. + */ + @pure @opaque + def inclContains[T](s: Set[T], add: T, e: T): Unit = { + unfold(s.incl(add)) + unionContains(s, Set(add), e) + singletonContains(add, e) + }.ensuring((s + add).contains(e) == (s.contains(e) || add == e)) + + /** Making the union of two sets and then adding an element is equivalent to adding an element to the second set and + * then perform the union. + */ + @pure + @opaque + def unionInclAssociativity[T](s1: Set[T], s2: Set[T], e: T): Unit = { + unfold((s1 ++ s2).incl(e)) + unfold(s2.incl(e)) + unionAssociativity(s1, s2, Set(e)) + }.ensuring((s1 ++ s2) + e === s1 ++ (s2 + e)) + + /** When adding two elements to a set, the order of the addition does not matter. + */ + @pure + @opaque + def inclCommutativity[T](s: Set[T], e1: T, e2: T): Unit = { + + unfold(Set(e1).incl(e2)) + unfold(Set(e2).incl(e1)) + unfold(s.incl(e1)) + unfold(s.incl(e2)) + unionCommutativity(Set(e1), Set(e2)) + unionEqualsRight(s, Set(e1) + e2, Set(e2) + e1) + unionInclAssociativity(s, Set(e1), e2) + unionInclAssociativity(s, Set(e2), e1) + equalsTransitivity((s ++ Set(e1)) + e2, s ++ (Set(e1) + e2), s ++ (Set(e2) + e1)) + equalsTransitivity((s ++ Set(e1)) + e2, s ++ (Set(e2) + e1), (s ++ Set(e2)) + e1) + }.ensuring((s + e1) + e2 === (s + e2) + e1) + + /** --------------------------------------------------------------------------------------------------------------- * + * ---------------------------------------------FORALL/EXISTS------------------------------------------------------ * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If predicate is not satisfied for all the elements of a set, we can exhibit an element in the set such that the + * predicate is not satisfied. + */ + @pure + @opaque + @inlineOnce + def notForallWitness[T](s: Set[T], f: T => Boolean): T = { + require(!s.forall(f)) + notForallExists(s, f) + s.witness(!f(_)) + }.ensuring(w => s.contains(w) && !f(w)) + + /** If a predicate is valid for all elements of a set and the same set contains an element, then this predicate is + * valid for the element. + */ + @pure + @opaque + def forallContains[T](s: Set[T], f: T => Boolean, e: T): Unit = { + require(s.forall(f)) + require(s.contains(e)) + val nf: T => Boolean = !f(_) + if (nf(e)) { + witnessExists(s, nf, e) + notForallExists(s, f) + } + }.ensuring(f(e)) + + /** If a set is subset of an other and a predicate is valid for all elements of the latter, then it also is for all + * elements of the former + */ + @pure @opaque + def forallSubsetOf[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + require(s1.subsetOf(s2)) + require(s2.forall(f)) + if (!s1.forall(f)) { + val nf: T => Boolean = !f(_) + val w = notForallWitness(s1, f) + subsetOfContains(s1, s2, w) + forallContains(s2, f, w) + Unreachable() + } + }.ensuring(s1.forall(f)) + + /** If a set is subset of an other and a predicate is valid for one of the elements of the former, then it also is + * for an element of the latter. + */ + @pure + @opaque + def existsSubsetOf[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + require(s1.subsetOf(s2)) + require(s1.exists(f)) + forallNotExists(s1, f) + forallNotExists(s2, f) + if (!s2.exists(f)) { + forallSubsetOf(s1, s2, x => !f(x)) + } + }.ensuring(s2.exists(f)) + + /** If two sets are equal then a predicate is true for all the elements of the first set if and only if it is true for + * all the elements of the second one. + */ + @pure + @opaque + def forallEquals[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + require(s1 === s2) + if (s2.forall(f)) { + forallSubsetOf(s1, s2, f) + } + if (s1.forall(f)) { + forallSubsetOf(s2, s1, f) + } + + }.ensuring(s1.forall(f) == s2.forall(f)) + + /** If two sets are equal there exist an element satisfying a given predicate in the first one if and only if it is + * also the case in the second one. + */ + @pure + @opaque + def existsEquals[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + require(s1 === s2) + if (s2.exists(f)) { + existsSubsetOf(s2, s1, f) + } + if (s1.exists(f)) { + existsSubsetOf(s1, s2, f) + } + + }.ensuring(s1.exists(f) == s2.exists(f)) + + /** Double negation rule inside a forall predicate. + */ + @pure + @opaque + def notNotForall[T](s: Set[T], f: T => Boolean): Unit = { + val nnf: T => Boolean = x => !(!f(x)) + if (!s.forall(f) && s.forall(nnf)) { + val w = notForallWitness(s, f) + forallContains(s, nnf, w) + } + if (!s.forall(nnf) && s.forall(f)) { + val w = notForallWitness(s, nnf) + forallContains(s, f, w) + } + }.ensuring(s.forall(f) == s.forall(x => !(!f(x)))) + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------FILTER------------------------------------------------------ * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** A set filtered by a predicate contains an element if and only if the element is both in the set and satisfies + * the predicate. + */ + @pure @opaque + def filterContains[T](s: Set[T], f: T => Boolean, e: T): Unit = { + + unfold(s.contains) + unfold(s.filter(f).contains) + forallNotExists(s, _ == e) + forallNotExists(s.filter(f), _ == e) + forallFilter(s, f, x => !(x == e)) + + val g: T => Boolean = x => f(x) ==> !(x == e) + + if (!s.filter(f).contains(e)) { + if (s.contains(e)) { + forallContains(s, g, e) + } + } else { + val w = notForallWitness(s, g) + } + }.ensuring(s.filter(f).contains(e) == (f(e) && s.contains(e))) + + /** A set filtered by a predicate is empty if and only if all the elements do not satisfy the predicate. + */ + @pure @opaque @inlineOnce + def filterIsEmpty[T](s: Set[T], f: T => Boolean): Unit = { + if (!s.forall(!f(_)) && s.filter(f).isEmpty) { + val w = notForallWitness(s, !f(_)) + filterContains(s, f, w) + isEmptyContains(s.filter(f), w) + } + if (!s.filter(f).isEmpty && s.forall(!f(_))) { + val w = notEmptyWitness(s.filter(f)) + filterContains(s, f, w) + forallContains(s, !f(_), w) + } + }.ensuring(s.filter(f).isEmpty == s.forall(!f(_))) + + /** If a set is subset of an other then they stay subset when they are filtered by some predicate. + */ + @pure @opaque + def filterSubsetOf[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + require(s1.subsetOf(s2)) + if (!s1.filter(f).subsetOf(s2.filter(f))) { + val w = notSubsetOfWitness(s1.filter(f), s2.filter(f)) + filterContains(s1, f, w) + filterContains(s2, f, w) + subsetOfContains(s1, s2, w) + } + }.ensuring(s1.filter(f).subsetOf(s2.filter(f))) + + /** If two sets are equal then they stay equal when they are filtered by some predicate. + */ + @pure + @opaque + def filterEquals[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + require(s1 === s2) + filterSubsetOf(s1, s2, f) + filterSubsetOf(s2, s1, f) + }.ensuring(s1.filter(f) === s2.filter(f)) + + /** A set filter by a predicate is always a subset of the original set. + */ + @pure @opaque + def filterSubsetOf[T](s: Set[T], f: T => Boolean): Unit = { + if (!s.filter(f).subsetOf(s)) { + val w = notSubsetOfWitness(s.filter(f), s) + filterContains(s, f, w) + } + }.ensuring(s.filter(f).subsetOf(s)) + + /** A set is equal to its filter if and only if all its elements satisfy the predicate. + */ + @pure + @opaque + def filterEquals[T](s: Set[T], f: T => Boolean): Unit = { + filterSubsetOf(s, f) + if (s.forall(f) && !s.subsetOf(s.filter(f))) { + val w = notSubsetOfWitness(s, s.filter(f)) + filterContains(s, f, w) + forallContains(s, f, w) + } + if (s.subsetOf(s.filter(f)) && !s.forall(f)) { + val w = notForallWitness(s, f) + subsetOfContains(s, s.filter(f), w) + filterContains(s, f, w) + } + }.ensuring((s === s.filter(f)) == s.forall(f)) + + /** --------------------------------------------------------------------------------------------------------------- * + * ---------------------------------------------DIFF/REMOVE-------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If the difference between two sets contains an element then the first set contains it but not the second one. + */ + @pure @opaque + def diffContains[T](s1: Set[T], s2: Set[T], e: T): Unit = { + unfold(s1.diff(s2)) + filterContains(s1, !s2.contains(_), e) + }.ensuring((s1 &~ s2).contains(e) == (s1.contains(e) && !s2.contains(e))) + + /** If a set is subset of an other then their differences with a third set is preserves the subset relationship. + * If the subset relation happens on the other side, then the relationship is inverted. + */ + @pure @opaque + def diffSubsetOf[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + unfold(s1.diff(s3)) + unfold(s2.diff(s3)) + if (s1.subsetOf(s2)) { + filterSubsetOf(s1, s2, x => !s3.contains(x)) + } + if (s2.subsetOf(s3) && !(s1 &~ s3).subsetOf(s1 &~ s2)) { + val w = notSubsetOfWitness(s1 &~ s3, s1 &~ s2) + diffContains(s1, s3, w) + diffContains(s1, s2, w) + subsetOfContains(s2, s3, w) + } + }.ensuring( + (s1.subsetOf(s2) ==> (s1 &~ s3).subsetOf(s2 &~ s3)) && + (s2.subsetOf(s3) ==> (s1 &~ s3).subsetOf(s1 &~ s2)) + ) + + /** If two pairs of set are subset of an other then their differences preserves the subset relationship on the left + * and inverts it on the right. + */ + @pure + @opaque + def diffSubsetOf[T](s1: Set[T], s2: Set[T], s3: Set[T], s4: Set[T]): Unit = { + require(s1.subsetOf(s2)) + require(s4.subsetOf(s3)) + diffSubsetOf(s1, s2, s3) + diffSubsetOf(s2, s4, s3) + subsetOfTransitivity(s1 &~ s3, s2 &~ s3, s2 &~ s4) + }.ensuring( + (s1 &~ s3).subsetOf(s2 &~ s4) + ) + + /** If two sets are equal then their differences with a third set are also equal. + */ + @pure + @opaque + def diffEquals[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + diffSubsetOf(s1, s2, s3) + diffSubsetOf(s2, s1, s3) + diffSubsetOf(s1, s3, s2) + }.ensuring( + ((s1 === s2) ==> ((s1 &~ s3) === (s2 &~ s3))) && + ((s2 === s3) ==> ((s1 &~ s3) === (s1 &~ s2))) + ) + + /** If two pairs of sets are equal then their differences are equal as well. + */ + @pure + @opaque + def diffEquals[T](s1: Set[T], s2: Set[T], s3: Set[T], s4: Set[T]): Unit = { + require(s1 === s2) + require(s3 === s4) + diffSubsetOf(s1, s2, s3, s4) + diffSubsetOf(s2, s1, s4, s3) + }.ensuring( + (s1 &~ s3) === (s2 &~ s4) + ) + + /** If a set is empty, then its difference with any other set is also empty. Furthermore, the difference of a set and + * a second one is empty if and only if the former is a subset of the latter. + */ + @pure + @opaque + def diffIsEmpty[T](s1: Set[T], s2: Set[T]): Unit = { + unfold(s1.diff(s2)) + unfold(s1.subsetOf(s2)) + filterIsEmpty(s1, !s2.contains(_)) + notNotForall(s1, s2.contains) + if (s1.isEmpty) { + isEmptySubsetOf(s1, s2) + } + }.ensuring( + ((s1 &~ s2).isEmpty == s1.subsetOf(s2)) && + (s1.isEmpty ==> (s1 &~ s2).isEmpty) + ) + + /** A set difference is always disjoint with the second set. Furthermore the difference between a first set + * and a second one is disjoint from the difference of the second and the first one. + */ + @pure + @opaque + def diffDisjoint[T](s1: Set[T], s2: Set[T]): Unit = { + if (!(s1 &~ s2).disjoint(s2)) { + val w = notDisjointWitness(s1 &~ s2, s2) + diffContains(s1, s2, w) + } + disjointSubsetOf(s1 &~ s2, s2, s2 &~ s1) + }.ensuring( + (s1 &~ s2).disjoint(s2) && + (s1 &~ s2).disjoint(s2 &~ s1) + ) + + /** The difference of a set and a second one is equal to the difference of the former and the intersection of both + */ + @pure + @opaque + def diffIntersection[T](s1: Set[T], s2: Set[T]): Unit = { + if ((s1 &~ s2) =/= (s1 &~ (s1 & s2))) { + val w = notEqualsWitness(s1 &~ s2, s1 &~ (s1 & s2)) + diffContains(s1, s2, w) + diffContains(s1, s1 & s2, w) + intersectContains(s1, s2, w) + } + }.ensuring( + (s1 &~ s2) === (s1 &~ (s1 & s2)) + ) + + @pure @opaque + def diffUnionDef[T](s1: Set[T], s2: Set[T]): Unit = { + if (s1 =/= (s1 &~ s2) ++ (s1 & s2)) { + val w = notEqualsWitness(s1, (s1 &~ s2) ++ (s1 & s2)) + unionContains(s1 &~ s2, s1 & s2, w) + diffContains(s1, s2, w) + intersectContains(s1, s2, w) + } + diffIntersection(s1, s2) + unionEqualsLeft(s1 &~ s2, s1 &~ (s1 & s2), s1 & s2) + equalsTransitivity(s1, (s1 &~ s2) ++ (s1 & s2), (s1 &~ (s1 & s2)) ++ (s1 & s2)) + }.ensuring( + (s1 === (s1 &~ s2) ++ (s1 & s2)) && + (s1 === (s1 &~ (s1 & s2)) ++ (s1 & s2)) + ) + + /** The size of the difference between two sets is smaller than the size of the first one. It is actually equal to the + * size of the first one minus the size of the intersection of both. + */ + @pure + @opaque + def diffSize[T](s1: Set[T], s2: Set[T]): Unit = { + diffDisjoint(s1, s1 & s2) + unionDisjointSize(s1 &~ (s1 & s2), s1 & s2) + diffUnionDef(s1, s2) + sizeEquals(s1, (s1 &~ (s1 & s2)) ++ (s1 & s2)) + diffIntersection(s1, s2) + sizeEquals(s1 &~ s2, s1 &~ (s1 & s2)) + sizePositive(s1 & s2) + }.ensuring( + ((s1 &~ s2).size == s1.size - (s1 & s2).size) && + ((s1 &~ s2).size <= s1.size) + ) + + /** A set is disjoint to an other one if and only if the difference of the first and the second one is equal to the + * former set. + */ + @pure @opaque + def diffEquals[T](s1: Set[T], s2: Set[T]): Unit = { + if (s1.disjoint(s2) & ((s1 &~ s2) =/= s1)) { + val w = notEqualsWitness(s1 &~ s2, s1) + diffContains(s1, s2, w) + disjointContains(s1, s2, w) + } + if (!s1.disjoint(s2) & ((s1 &~ s2) === s1)) { + val w = notDisjointWitness(s1, s2) + equalsContains(s1 &~ s2, s1, w) + diffContains(s1, s2, w) + } + }.ensuring(((s1 &~ s2) === s1) == s1.disjoint(s2)) + + /** A set minus a value contains an element if and only if the set contains this element and it is different from the value + */ + @pure + @opaque + def removeContains[T](s: Set[T], r: T, e: T): Unit = { + unfold(s.remove(r)) + diffContains(s, Set(r), e) + singletonContains(r, e) + }.ensuring((s - r).contains(e) == (s.contains(e) && r != e)) + + /** A set minus a value does not contain this value + */ + @pure + @opaque + def removeContains[T](s: Set[T], r: T): Unit = { + removeContains(s, r, r) + }.ensuring(!(s - r).contains(r)) + + /** If a set is subset of an other then removing an element on both sides preserves this relationship. + */ + @pure + @opaque + def removeSubsetOf[T](s1: Set[T], s2: Set[T], r: T): Unit = { + require(s1.subsetOf(s2)) + unfold(s1.remove(r)) + unfold(s2.remove(r)) + diffSubsetOf(s1, s2, Set(r)) + }.ensuring((s1 - r).subsetOf(s2 - r)) + + /** If a set are equal then removing an element on both sides preserves this relationship. + */ + @pure + @opaque + def removeEquals[T](s1: Set[T], s2: Set[T], r: T): Unit = { + require(s1 === s2) + unfold(s1.remove(r)) + unfold(s2.remove(r)) + diffEquals(s1, s2, Set(r)) + }.ensuring((s1 - r) === (s2 - r)) + + /** --------------------------------------------------------------------------------------------------------------- * + * ----------------------------------------------------SYMDIFF----------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If the symmetric difference of two sets contains an element then the element cannot belong to the two sets at the + * same time. + */ + @pure + @opaque + def symDiffContains[T](s1: Set[T], s2: Set[T], e: T): Unit = { + unfold(s1.symDiff(s2)) + diffContains(s1, s2, e) + diffContains(s2, s1, e) + unionContains(s1.diff(s2), s2.diff(s1), e) + }.ensuring( + (s1.symDiff(s2)).contains(e) == + ((s1.contains(e) && !s2.contains(e)) || (!s1.contains(e) && s2.contains(e))) + ) + + /** The symmetric difference of two sets is empty if and only if they are equal + */ + @pure + @opaque + def symDiffIsEmpty[T](s1: Set[T], s2: Set[T], f: T => Boolean): Unit = { + unfold(s1.symDiff(s2)) + diffIsEmpty(s1, s2) + diffIsEmpty(s2, s1) + isEmptyUnion(s1.diff(s2), s2.diff(s1)) + }.ensuring((s1.symDiff(s2)).isEmpty == (s1 === s2)) + + /** Symmetric difference is a commutative operation. + */ + @pure @opaque + def symDiffCommutativity[T](s1: Set[T], s2: Set[T]): Unit = { + unfold(s1.symDiff(s2)) + unfold(s2.symDiff(s1)) + unionCommutativity(s1.diff(s2), s2.diff(s1)) + }.ensuring(s1.symDiff(s2) === s2.symDiff(s1)) + + @pure @opaque + def symDiffEquals[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + unfold(s1.symDiff(s3)) + unfold(s2.symDiff(s3)) + unfold(s1.symDiff(s2)) + + diffEquals(s1, s2, s3) + diffEquals(s3, s1, s2) + diffEquals(s2, s3, s1) + if (s1 === s2) { + unionEquals(s1 &~ s3, s2 &~ s3, s3 &~ s1, s3 &~ s2) + } + if (s2 === s3) { + unionEquals(s1 &~ s2, s1 &~ s3, s2 &~ s1, s3 &~ s1) + } + }.ensuring( + ((s1 === s2) ==> (s1.symDiff(s3) === s2.symDiff(s3))) && + ((s2 === s3) ==> (s1.symDiff(s2) === s1.symDiff(s3))) + ) + + /** If two pairs of set are equal then their symmetric differences are also equal + */ + @pure + @opaque + def symDiffEquals[T](s1: Set[T], s2: Set[T], s3: Set[T], s4: Set[T]): Unit = { + require(s1 === s2) + require(s3 === s4) + symDiffEquals(s1, s2, s3) + symDiffEquals(s2, s4, s3) + equalsTransitivity(s1.symDiff(s3), s2.symDiff(s3), s2.symDiff(s4)) + }.ensuring(s1.symDiff(s3) === s2.symDiff(s4)) + + /** The size of the symmetric difference of two sets is smaller than the sum of the sizes. + * It is actually equal to the sum of the sizes minus twice the size of the intersection. + */ + @pure @opaque + def symDiffSize[T](s1: Set[T], s2: Set[T]): Unit = { + unfold(s1.symDiff(s2)) + diffSize(s1, s2) + diffSize(s2, s1) + diffDisjoint(s1, s2) + unionDisjointSize(s1 &~ s2, s2 &~ s1) + intersectCommutativity(s1, s2) + sizeEquals(s1 & s2, s2 & s1) + assert((s1.symDiff(s2).size == s1.size + s2.size - 2 * (s1 & s2).size)) + }.ensuring( + (s1.symDiff(s2).size == s1.size + s2.size - 2 * (s1 & s2).size) && + (s1.symDiff(s2).size <= s1.size + s2.size) + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * --------------------------------------------------INTERSECT----------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** The intersection of two sets contains an element if and only if both sets contain it. + */ + @pure @opaque + def intersectContains[T](s1: Set[T], s2: Set[T], e: T): Unit = { + unfold(s1.intersect(s2)) + unionContains(s1, s2, e) + symDiffContains(s1, s2, e) + diffContains(s1 ++ s2, s1.symDiff(s2), e) + }.ensuring((s1 & s2).contains(e) == (s1.contains(e) && s2.contains(e))) + + /** If a set is subset of another one, then its intersection with a third set is also a subset of the intersection of + * the second and the third one. + * + * ∀s1,s2,s3. s1 ⊆ s2 => s1 ∩ s3 ⊆ s2 ∩ s3 + */ + @pure @opaque + def intersectSubsetOfLeft[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + if (s1.subsetOf(s2) && !(s1 & s3).subsetOf(s2 & s3)) { + val w = notSubsetOfWitness(s1 & s3, s2 & s3) + intersectContains(s1, s3, w) + intersectContains(s2, s3, w) + subsetOfContains(s1, s2, w) + } + }.ensuring((s1.subsetOf(s2)) ==> (s1 & s3).subsetOf(s2 & s3)) + + /** If a set is subset of another one, then its intersection with a third set is also a subset of the intersection of + * the second and the third one. + * + * ∀s1,s2,s3. s2 ⊆ s3 => s1 ∩ s2 ⊆ s1 ∩ s3 + */ + @pure + @opaque + def intersectSubsetOfRight[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + if (s2.subsetOf(s3) && !(s1 & s2).subsetOf(s1 & s3)) { + val w = notSubsetOfWitness(s1 & s2, s1 & s3) + intersectContains(s1, s2, w) + intersectContains(s1, s3, w) + subsetOfContains(s2, s3, w) + } + + }.ensuring(s2.subsetOf(s3) ==> (s1 & s2).subsetOf(s1 & s3)) + + /** A set is subset of the intersection of two sets if and only if it is subset of both sets. + */ + @pure + @opaque + def intersectSubsetOf[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + if (s1.subsetOf(s2 & s3)) { + if (!s1.subsetOf(s2)) { + val w = notSubsetOfWitness(s1, s2) + subsetOfContains(s1, s2 & s3, w) + intersectContains(s2, s3, w) + } + if (!s1.subsetOf(s3)) { + val w = notSubsetOfWitness(s1, s3) + subsetOfContains(s1, s2 & s3, w) + intersectContains(s2, s3, w) + } + } + + if (s1.subsetOf(s2) && s1.subsetOf(s3) && !s1.subsetOf(s2 & s3)) { + val w = notSubsetOfWitness(s1, s2 & s3) + subsetOfContains(s1, s2, w) + subsetOfContains(s1, s3, w) + intersectContains(s2, s3, w) + assert(false) + } + }.ensuring(s1.subsetOf(s2 & s3) == (s1.subsetOf(s2) && s1.subsetOf(s3))) + + /** The intersection of two sets is subset of each of the sets. + */ + @pure + @opaque + def intersectSubsetOf[T](s1: Set[T], s2: Set[T]): Unit = { + if (!(s1 & s2).subsetOf(s1)) { + val w = notSubsetOfWitness(s1 & s2, s1) + intersectContains(s1, s2, w) + } + if (!(s1 & s2).subsetOf(s2)) { + val w = notSubsetOfWitness(s1 & s2, s2) + intersectContains(s1, s2, w) + } + }.ensuring( + (s1 & s2).subsetOf(s1) && (s1 & s2).subsetOf(s2) + ) + + @pure + @opaque + def intersectSubsetOf[T](s1: Set[T], s2: Set[T], s3: Set[T], s4: Set[T]): Unit = { + require(s1.subsetOf(s2)) + require(s3.subsetOf(s4)) + intersectSubsetOfLeft(s1, s2, s3) + intersectSubsetOfRight(s2, s3, s4) + subsetOfTransitivity(s1 & s3, s2 & s3, s2 & s4) + }.ensuring((s1 & s3).subsetOf(s2 & s4)) + + /** If two sets are equal then their intersection with a third set is also equal + */ + @pure + @opaque + def intersectEquals[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + intersectSubsetOfRight(s1, s2, s3) + intersectSubsetOfLeft(s1, s2, s3) + intersectSubsetOfLeft(s2, s1, s3) + intersectSubsetOfRight(s1, s3, s2) + }.ensuring( + ((s1 === s2) ==> ((s1 & s3) === (s2 & s3))) && + ((s2 === s3) ==> ((s1 & s2) === (s1 & s3))) + ) + + /** If two pairs of sets are equal then their intersection is also equal. + */ + @pure + @opaque + def intersectEquals[T](s1: Set[T], s2: Set[T], s3: Set[T], s4: Set[T]): Unit = { + require(s1 === s2) + require(s3 === s4) + intersectSubsetOf(s1, s2, s3, s4) + intersectSubsetOf(s2, s1, s4, s3) + }.ensuring((s1 & s3) === (s2 & s4)) + + /** Interesection is commutative. + * + * ∀s1,s2. s1 ∩ s2 = s2 ∩ s1 + */ + @pure @opaque + def intersectCommutativity[T](s1: Set[T], s2: Set[T]): Unit = { + unfold(s1.intersect(s2)) + unfold(s2.intersect(s1)) + symDiffCommutativity(s1, s2) + unionCommutativity(s1, s2) + diffEquals(s1 ++ s2, s2 ++ s1, s1.symDiff(s2), s2.symDiff(s1)) + }.ensuring((s1 & s2) === (s2 & s1)) + + /** The union of two sets is equal to the union of their intersection and their symmetric difference. + * + * ∀s1,s2. |s1 ∩ s2| ≤ |s1| /\ |s1 ∩ s2| ≤ |s2| + */ + @pure + @opaque + def intersectSize[T](s1: Set[T], s2: Set[T]): Unit = { + unionSize(s1, s2) + }.ensuring( + (s1 & s2).size <= s1.size && + (s1 & s2).size <= s2.size + ) + + /** The union of two sets is equal to the union of their intersection and their symmetric difference. + * + * ∀s1,s2. s1 ∪ s2 = (s1 Δ s2) ∪ (s1 ∩ s2) + */ + @pure + @opaque + def unionIntersectSymDiffDef[T](s1: Set[T], s2: Set[T]): Unit = { + if ((s1 ++ s2) =/= (s1.symDiff(s2) ++ (s1 & s2))) { + val w = notEqualsWitness(s1 ++ s2, s1.symDiff(s2) ++ (s1 & s2)) + unionContains(s1, s2, w) + unionContains(s1.symDiff(s2), s1 & s2, w) + intersectContains(s1, s2, w) + symDiffContains(s1, s2, w) + } + }.ensuring( + (s1 ++ s2) === (s1.symDiff(s2) ++ (s1 & s2)) + ) + + /** The symmetric difference between two sets is disjoint from the intersection. + */ + @pure @opaque + def intersectDisjointSymDiff[T](s1: Set[T], s2: Set[T]): Unit = { + if (!s1.symDiff(s2).disjoint(s1 & s2)) { + val w = notDisjointWitness(s1.symDiff(s2), s1 & s2) + symDiffContains(s1, s2, w) + intersectContains(s1, s2, w) + } + disjointSymmetry(s1.symDiff(s2), s1 & s2) + }.ensuring( + s1.symDiff(s2).disjoint(s1 & s2) && + (s1 & s2).disjoint(s1.symDiff(s2)) + ) + + /** Alternative definition of the subset relation. + * + * A set is subset of another one if and only if the intersection of both is equal to the former. + */ + @pure + @opaque + def intersectAltDefinition[T](s1: Set[T], s2: Set[T]): Unit = { + intersectSubsetOf(s1, s1, s2) + subsetOfReflexivity(s1) + intersectSubsetOf(s1, s2) + intersectCommutativity(s1, s2) + equalsTransitivityStrong(s1, s2 & s1, s1 & s2) + }.ensuring( + (s1.subsetOf(s2) == ((s1 & s2) === s1)) && + (s1.subsetOf(s2) == ((s2 & s1) === s1)) + ) + + /** If two sets are not disjoint then we can exhibit an element such that both contain it. + */ + @pure @opaque + def notDisjointWitness[T](s1: Set[T], s2: Set[T]): T = { + require(!s1.disjoint(s2)) + unfold(s1.disjoint(s2)) + val w = notEmptyWitness(s1.intersect(s2)) + intersectContains(s1, s2, w) + w + }.ensuring(e => s1.contains(e) && s2.contains(e)) + + /** Two disjoint sets cannot contain the same element. + */ + @pure @opaque + def disjointContains[T](s1: Set[T], s2: Set[T], e: T): Unit = { + require(s1.disjoint(s2)) + unfold(s1.disjoint(s2)) + isEmptyContains(s1.intersect(s2), e) + intersectContains(s1, s2, e) + }.ensuring(!s1.contains(e) || !s2.contains(e)) + @pure @opaque + def disjointSubsetOf[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + if (s3.subsetOf(s2) && s1.disjoint(s2) && !s1.disjoint(s3)) { + val w = notDisjointWitness(s1, s3) + disjointContains(s1, s2, w) + subsetOfContains(s3, s2, w) + } + if (s2.subsetOf(s1) && s1.disjoint(s3) && !s2.disjoint(s3)) { + val w = notDisjointWitness(s2, s3) + disjointContains(s1, s3, w) + subsetOfContains(s2, s1, w) + } + }.ensuring( + ((s3.subsetOf(s2) && s1.disjoint(s2)) ==> s1.disjoint(s3)) && + ((s2.subsetOf(s1) && s1.disjoint(s3)) ==> s2.disjoint(s3)) + ) + + /** If two sets are equal then the first is disjoint to a third set if and only if the second is also disjoint to it. + */ + @pure + @opaque + def disjointEquals[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + disjointSubsetOf(s1, s2, s3) + disjointSubsetOf(s1, s3, s2) + disjointSubsetOf(s2, s1, s3) + }.ensuring( + ((s2 === s3) ==> (s1.disjoint(s2) == s1.disjoint(s3))) && + ((s1 === s2) ==> (s1.disjoint(s3) == s2.disjoint(s3))) + ) + + /** Disjointness is a symmetric relation. + */ + @pure + @opaque + def disjointSymmetry[T](s1: Set[T], s2: Set[T]): Unit = { + unfold(s1.disjoint(s2)) + unfold(s2.disjoint(s1)) + intersectCommutativity(s1, s2) + isEmptyEquals(s1.intersect(s2), s2.intersect(s1)) + }.ensuring( + s1.disjoint(s2) == s2.disjoint(s1) + ) + + /** A set is disjoint with the union of two other sets, if and only it is disjoint with both. + */ + @pure @opaque + def disjointUnionRight[T](s1: Set[T], s2: Set[T], s3: Set[T]): Unit = { + if (s1.disjoint(s2 ++ s3)) { + if (!s1.disjoint(s2)) { + val w = notDisjointWitness(s1, s2) + disjointContains(s1, s2 ++ s3, w) + unionContains(s2, s3, w) + } + if (!s1.disjoint(s3)) { + val w = notDisjointWitness(s1, s3) + disjointContains(s1, s2 ++ s3, w) + unionContains(s2, s3, w) + } + } + if (s1.disjoint(s2) && s1.disjoint(s3)) { + if (!s1.disjoint(s2 ++ s3)) { + val w = notDisjointWitness(s1, s2 ++ s3) + disjointContains(s1, s2, w) + disjointContains(s1, s3, w) + unionContains(s2, s3, w) + } + } + }.ensuring( + s1.disjoint(s2 ++ s3) == (s1.disjoint(s2) && s1.disjoint(s3)) + ) + + /** A set is disjoint with an other set which an element has been added to if and only if it is disjoint to set and + * does not contain the element. + */ + @pure + @opaque + def disjointIncl[T](s1: Set[T], s2: Set[T], e: T): Unit = { + unfold(s2.incl(e)) + disjointUnionRight(s1, s2, Set(e)) + disjointSingleton(s1, e) + }.ensuring( + s1.disjoint(s2 + e) == (s1.disjoint(s2) && !s1.contains(e)) + ) + + /** A set is disjoint with an other set which an element has been added to if and only if it is disjoint to set and + * does not contain the element. + */ + @pure + @opaque + def disjointInclRev[T](s1: Set[T], s2: Set[T], e: T): Unit = { + disjointIncl(s2, s1, e) + disjointSymmetry(s1 + e, s2) + disjointSymmetry(s1, s2) + }.ensuring( + (s1 + e).disjoint(s2) == (s1.disjoint(s2) && !s2.contains(e)) + ) + + /** --------------------------------------------------------------------------------------------------------------- * + * ------------------------------------------------------MAP------------------------------------------------------- * + * ---------------------------------------------------------------------------------------------------------------- + */ + + /** If a set contains an element, then the result after a map operation contains the element on which the function + * has been applied. + */ + @pure + @opaque + def mapContains[T, V](s: Set[T], f: T => V, e: T): Unit = { + require(s.contains(e)) + unfold(s.contains) + unfold(s.map[V](f).contains) + + val eqe: T => Boolean = _ == e + val eqfe: V => Boolean = _ == f(e) + + forallNotExists(s, eqe) + forallNotExists(s.map(f), eqfe) + forallMap(s, f, !eqfe(_)) + + val g: T => Boolean = f andThen (!eqfe(_)) + + if (!s.map[V](f).contains(f(e)) & s.contains(e)) { + forallContains(s, g, e) + } + }.ensuring(s.map[V](f).contains(f(e))) + + /** If a set after a map contains an element, then we can exhibit an element in the original set such that applying + * the function to the latter gives the former. + */ + @pure + @opaque + def mapContainsWitness[T, V](s: Set[T], f: T => V, e: V): T = { + require(s.map[V](f).contains(e)) + + if (s.forall(f andThen (_ != e))) { + forallMap(s, f, _ != e) + forallContains(s.map[V](f), _ != e, e) + } + notForallWitness(s, f andThen (_ != e)) + }.ensuring(res => (f(res) == e) && s.contains(res)) + + /** If a set is a subset of an other then applying a map operations, does not change the subset relation. + */ + @pure + @opaque + def mapSubsetOf[T, V](s1: Set[T], s2: Set[T], f: T => V): Unit = { + require(s1.subsetOf(s2)) + if (!s1.map[V](f).subsetOf(s2.map[V](f))) { + val w = notSubsetOfWitness(s1.map[V](f), s2.map[V](f)) + val e = mapContainsWitness(s1, f, w) + subsetOfContains(s1, s2, e) + mapContains(s2, f, e) + } + }.ensuring(s1.map[V](f).subsetOf(s2.map[V](f))) + + /** If two sets are equal then their map wrt a function is equal as well. + */ + @pure + @opaque + def mapEquals[T, V](s1: Set[T], s2: Set[T], f: T => V): Unit = { + require(s1 === s2) + mapSubsetOf(s1, s2, f) + mapSubsetOf(s2, s1, f) + }.ensuring(s1.map[V](f) === s2.map[V](f)) + + @pure @opaque @inlineOnce + def mapExists[T, V](s: Set[T], f: T => V, p: V => Boolean): Unit = { + forallNotExists(s.map[V](f), p) + forallNotExists(s, f andThen p) + + val nfp: T => Boolean = x => !(f andThen p)(x) + val np: V => Boolean = x => !p(x) + + if (s.map[V](f).forall(np) && !s.forall(nfp)) { + val w = notForallWitness(s, nfp) + mapContains(s, f, w) + forallContains(s.map[V](f), np, f(w)) + } + if (!s.map[V](f).forall(np) && s.forall(nfp)) { + val w2 = notForallWitness(s.map[V](f), np) + val w = mapContainsWitness(s, f, w2) + forallContains(s, nfp, w) + } + }.ensuring(s.map[V](f).exists(p) == s.exists(f andThen p)) + + /** The map of an union is the union of the maps. + */ + @pure + @opaque + def mapUnion[T, V](s1: Set[T], s2: Set[T], f: T => V): Unit = { + if ((s1 ++ s2).map[V](f) =/= (s1.map[V](f) ++ s2.map[V](f))) { + val w = notEqualsWitness((s1 ++ s2).map[V](f), s1.map[V](f) ++ s2.map[V](f)) + if ((s1 ++ s2).map[V](f).contains(w)) { + val v = mapContainsWitness(s1 ++ s2, f, w) + unionContains(s1, s2, v) + if (s1.contains(v)) { + mapContains(s1, f, v) + } else { + mapContains(s2, f, v) + } + unionContains(s1.map[V](f), s2.map[V](f), f(v)) + } + if ((s1.map[V](f) ++ s2.map[V](f)).contains(w)) { + unionContains(s1.map[V](f), s2.map[V](f), w) + if (s1.map[V](f).contains(w)) { + val v1 = mapContainsWitness(s1, f, w) + unionContains(s1, s2, v1) + mapContains(s1 ++ s2, f, v1) + } else { + val v2 = mapContainsWitness(s2, f, w) + unionContains(s1, s2, v2) + mapContains(s1 ++ s2, f, v2) + } + } + } + }.ensuring((s1 ++ s2).map[V](f) === (s1.map[V](f) ++ s2.map[V](f))) + + /** Doing a map on a singleton is the singleton of the function applied to the element + */ + @pure + @opaque + def mapSingleton[T, V](e: T, f: T => V): Unit = { + if (Set(e).map[V](f) =/= Set(f(e))) { + val w: V = notEqualsWitness(Set(e).map[V](f), Set(f(e))) + if (Set(e).map[V](f).contains(w)) { + val v = mapContainsWitness(Set(e), f, w) + singletonContains(e, v) + singletonContains(f(e), w) + } + if (Set[V](f(e)).contains(w)) { + singletonContains(f(e), w) + singletonContains(e) + mapContains(Set(e), f, e) + } + } + }.ensuring(Set(e).map(f) === Set(f(e))) + + /** Doing a map operation on a set and an element is the same as mapping the set only and then adding the element + * on which the function has been applied. + */ + @pure + @opaque + def mapIncl[T, V](s: Set[T], e: T, f: T => V): Unit = { + unfold(s.incl(e)) + unfold(s.map[V](f).incl(f(e))) + mapUnion(s, Set(e), f) + mapSingleton(e, f) + unionEqualsRight(s.map[V](f), Set(e).map[V](f), Set(f(e))) + equalsTransitivity((s + e).map[V](f), s.map[V](f) ++ Set(e).map[V](f), s.map[V](f) ++ Set(f(e))) + + }.ensuring((s + e).map[V](f) === (s.map[V](f) + f(e))) + +} diff --git a/daml-lf/verification/utils/Transaction.scala b/daml-lf/verification/utils/Transaction.scala new file mode 100644 index 0000000000..59570d5894 --- /dev/null +++ b/daml-lf/verification/utils/Transaction.scala @@ -0,0 +1,57 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.lang._ +import stainless.annotation._ +import scala.annotation.nowarn +import stainless.proof._ +import stainless.collection._ + +import Value.ContractId + +object Transaction { + + /** The state of a key at the beginning of the transaction. + */ + sealed trait KeyInput extends Product with Serializable { + def toKeyMapping: Option[ContractId] + + def isActive: Boolean + } + + /** No active contract with the given key. + */ + sealed trait KeyInactive extends KeyInput { + override def toKeyMapping: Option[ContractId] = None[ContractId]() + + override def isActive: Boolean = false + } + + /** A contract with the key will be created so the key must be inactive. + */ + @nowarn + final case object KeyCreate extends KeyInactive + + /** Negative key lookup so the key mus tbe inactive. + */ + @nowarn + final case object NegativeKeyLookup extends KeyInactive + + /** Key must be mapped to this active contract. + */ + final case class KeyActive(cid: ContractId) extends KeyInput { + override def toKeyMapping: Option[ContractId] = Some(cid) + + override def isActive: Boolean = true + } + + sealed abstract class TransactionError + + final case class DuplicateContractKey(key: GlobalKey) extends TransactionError + final case class InconsistentContractKey(key: GlobalKey) + + type KeyInputError = Either[InconsistentContractKey, DuplicateContractKey] +} diff --git a/daml-lf/verification/utils/Tree.scala b/daml-lf/verification/utils/Tree.scala new file mode 100644 index 0000000000..7b5ccd91d1 --- /dev/null +++ b/daml-lf/verification/utils/Tree.scala @@ -0,0 +1,655 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.collection._ +import stainless.lang.{unfold, decreases, BooleanDecorations} +import stainless.annotation._ +import SetProperties._ +import ListProperties._ +import SetAxioms._ + +/** Type class representing forests in manner to avoid measure problems. In the rest of the file and the package we will + * make no distinction between a tree and a forest. + * + * A forest can either be: + * - A content node: a forest with a node on top of it, that is a tree + * - An articulation node: the union between a tree (on the left) and a forest (on the right) + * - The empty forest which does not contain any node. + * + * It can be therefore seen as a linear sequence of trees. + */ +sealed trait Tree[T] { + + /** Numbers of node in the tree + */ + @pure + @opaque + def size: BigInt = { + decreases(this) + this match { + case Endpoint() => BigInt(0) + case ArticulationNode(l, r) => l.size + r.size + case ContentNode(n, sub) => sub.size + 1 + } + }.ensuring(res => + res >= 0 && + ((res == 0) ==> (this == Endpoint[T]())) + ) + + /** Whether the tree contains every node at most once + */ + @pure + @opaque + def isUnique: Boolean = { + decreases(this) + this match { + case Endpoint() => true + case ArticulationNode(l, r) => l.isUnique && r.isUnique && l.content.disjoint(r.content) + case ContentNode(n, sub) => sub.isUnique && !sub.contains(n) + } + } + + /** Set of nodes + */ + @pure @opaque + def content: Set[T] = { + decreases(this) + this match { + case Endpoint() => Set.empty[T] + case ArticulationNode(l, r) => l.content ++ r.content + case ContentNode(n, sub) => sub.content + n + } + } + + /** Checks if the tree contains a node + */ + @pure + def contains(e: T): Boolean = { + decreases(this) + unfold(content) + this match { + case ArticulationNode(l, r) => + SetProperties.unionContains(l.content, r.content, e) + l.contains(e) || r.contains(e) + case Endpoint() => + SetProperties.emptyContains(e) + false + case ContentNode(n, sub) => + SetProperties.inclContains(sub.content, n, e) + sub.contains(e) || (e == n) + } + }.ensuring(res => (res == content.contains(e))) + + /** Inorder traversal where each node is visited twice. When entering a node we apply a given function and when exiting + * the node we apply an other function. While doing this we update a state and return it at the end of the traversal. + * @param init The initial state of the traversal + * @param f1 The function that is applied everytime we visit a node for the first time (i.e. we are entering it) + * @param f2 The function that is applied everytime we visit a node for the second time (i.e. we are exiting it) + */ + @pure @opaque + def traverse[Z](init: Z, f1: (Z, T) => Z, f2: (Z, T) => Z): Z = { + decreases(this) + unfold(size) + this match { + case ArticulationNode(l, r) => + r.traverse(l.traverse(init, f1, f2), f1, f2) + case Endpoint() => init + case ContentNode(n, sub) => f2(sub.traverse(f1(init, n), f1, f2), n) + } + }.ensuring(res => (size == 0) ==> (res == init)) + + /** In order traversal where each node is visited twice. When entering a node we apply a given function and when exiting + * the node we apply an other function. While doing this we update a state and return a list of triple whose entries + * are: + * - Each intermediate state + * - The node we have visited at each step + * - The traversal direction of the step, i.e. if that's the first or the second time we are visiting it + * + * @param init The initial state of the traversal + * @param f1 The function that is applied everytime we visit a node for the first time (i.e. we are entering it) + * @param f2 The function that is applied everytime we visit a node for the second time (i.e. we are exiting it) + */ + @pure + @opaque + def scan[Z](init: Z, f1: (Z, T) => Z, f2: (Z, T) => Z): List[(Z, T, TraversalDirection)] = { + decreases(this) + unfold(size) + this match { + case ArticulationNode(l, r) => + concatIndex(l.scan(init, f1, f2), r.scan(l.traverse(init, f1, f2), f1, f2), 2 * size - 1) + l.scan(init, f1, f2) ++ r.scan(l.traverse(init, f1, f2), f1, f2) + case Endpoint() => Nil[(Z, T, TraversalDirection)]() + case ContentNode(n, sub) => + concatIndex( + sub.scan(f1(init, n), f1, f2), + List((sub.traverse(f1(init, n), f1, f2), n, TraversalDirection.Up)), + 2 * size - 2, + ) + (init, n, TraversalDirection.Down) :: (sub.scan(f1(init, n), f1, f2) ++ List( + (sub.traverse(f1(init, n), f1, f2), n, TraversalDirection.Up) + )) + } + }.ensuring(res => + (res.size == 2 * size) && + ((size > 0) ==> (res(2 * size - 1)._3 == TraversalDirection.Up)) + ) + +} + +sealed trait StructuralNode[T] extends Tree[T] + +case class ArticulationNode[T](left: ContentNode[T], right: StructuralNode[T]) + extends StructuralNode[T] + +case class Endpoint[T]() extends StructuralNode[T] + +case class ContentNode[T](nodeContent: T, sub: StructuralNode[T]) extends Tree[T] + +/** Trait describing if the node of a tree is being visited for the first time (Down) or the second time (Up) + */ +sealed trait TraversalDirection + +object TraversalDirection { + case object Up extends TraversalDirection + + case object Down extends TraversalDirection +} + +object TreeProperties { + + /** Express an intermediate state of a tree traversal in function of the one of the subtree when the tree is a [[ContentNode]] + * + * @param n The node of the content tree + * @param sub The subtree of the content tree + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param i The step number of the intermediate state + */ + @pure @opaque + def scanIndexing[T, Z]( + n: T, + sub: StructuralNode[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + i: BigInt, + ): Unit = { + require(i >= 0) + require(i < 2 * ContentNode(n, sub).size) + unfold(ContentNode(n, sub).size) + unfold(ContentNode(n, sub).scan(init, f1, f2)) + if (i != 0) { + concatIndex( + sub.scan(f1(init, n), f1, f2), + List((sub.traverse(f1(init, n), f1, f2), n, TraversalDirection.Up)), + i - 1, + ) + } + }.ensuring( + ContentNode(n, sub).scan(init, f1, f2)(i) == + ( + if (i == BigInt(0)) (init, n, TraversalDirection.Down) + else if (i == 2 * (ContentNode(n, sub).size) - 1) + (sub.traverse(f1(init, n), f1, f2), n, TraversalDirection.Up) + else (sub.scan(f1(init, n), f1, f2))(i - 1) + ) + ) + + /** Express an intermediate state of a tree traversal in function of the one of the left or the right subtree + * when the tree is am [[ArticulationNode]] + * + * @param l The left subtree + * @param r The right subtree + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param i The step number of the intermediate state + */ + @pure + @opaque + def scanIndexing[T, Z]( + l: ContentNode[T], + r: StructuralNode[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + i: BigInt, + ): Unit = { + require(i >= 0) + require(i < 2 * ArticulationNode(l, r).size) + unfold(ArticulationNode(l, r).size) + unfold(ArticulationNode(l, r).scan(init, f1, f2)) + concatIndex(l.scan(init, f1, f2), r.scan(l.traverse(init, f1, f2), f1, f2), i) + }.ensuring( + ArticulationNode(l, r).scan(init, f1, f2)(i) == + ( + if (i < 2 * l.size) l.scan(init, f1, f2)(i) + else r.scan(l.traverse(init, f1, f2), f1, f2)(i - 2 * l.size) + ) + ) + + /** Express an intermediate state of a tree traversal in function of the one before + * + * @param n The node of the content tree + * @param sub The subtree of the content tree + * @param init The initial state of the traversal + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param i The step number of the intermediate state + */ + @pure @opaque + def scanIndexingState[T, Z]( + tr: Tree[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + i: BigInt, + ): Unit = { + decreases(tr) + require(i >= 0) + require(i < 2 * tr.size) + require(tr.size > 0) + + unfold(tr.scan(init, f1, f2)) + unfold(tr.traverse(init, f1, f2)) + unfold(tr.size) + + tr match { + case Endpoint() => Trivial() + case ContentNode(n, sub) => + scanIndexing(n, sub, init, f1, f2, i) + scanIndexing(n, sub, init, f1, f2, 2 * tr.size - 1) + + if (i == 0) { + Trivial() + } else if (i == 2 * tr.size - 1) { + if (sub.size > 0) { + scanIndexingState(sub, f1(init, n), f1, f2, 0) + scanIndexing(n, sub, init, f1, f2, i - 1) + } + } else { + scanIndexingState(sub, f1(init, n), f1, f2, i - 1) + scanIndexing(n, sub, init, f1, f2, i - 1) + } + case ArticulationNode(le, ri) => + scanIndexing(le, ri, init, f1, f2, 2 * tr.size - 1) + scanIndexingState(le, init, f1, f2, 0) // traverse + + if (ri.size == 0) { + scanIndexing(le, ri, init, f1, f2, 2 * le.size - 1) + } else { + scanIndexingState(ri, le.traverse(init, f1, f2), f1, f2, 0) // traverse + } + + scanIndexing(le, ri, init, f1, f2, i) + + if (i == 0) { + scanIndexingState(le, init, f1, f2, i) + } else { + scanIndexing(le, ri, init, f1, f2, i - 1) + if (i < 2 * le.size) { + scanIndexingState(le, init, f1, f2, i) + } else { + scanIndexingState(ri, le.traverse(init, f1, f2), f1, f2, i - 2 * le.size) + scanIndexing(le, ri, init, f1, f2, 2 * le.size - 1) + } + } + + } + + }.ensuring( + (tr.scan(init, f1, f2)(i)._1 == ( + if (i == 0) init + else if (tr.scan(init, f1, f2)(i - 1)._3 == TraversalDirection.Down) { + f1(tr.scan(init, f1, f2)(i - 1)._1, tr.scan(init, f1, f2)(i - 1)._2) + } else { + f2(tr.scan(init, f1, f2)(i - 1)._1, tr.scan(init, f1, f2)(i - 1)._2) + } + )) && + (tr.traverse(init, f1, f2) == f2( + tr.scan(init, f1, f2)(2 * tr.size - 1)._1, + tr.scan(init, f1, f2)(2 * tr.size - 1)._2, + )) + ) + + /** The nodes and the directions of traversal are only dependent of the tree. That is they are independent of the initial + * state and of which functions are used to traverse it. + * + * @param tr The tree that is being traversed + * @param init1 The initial state of the first traversal + * @param init2 The initial state of the second traversal + * @param f11 The functions that is used when visiting the nodes for the first time in the first traversal + * @param f12 The functions that is used when visiting the nodes for the secondS time in the first traversal + * @param f21 The functions that is used when visiting the nodes for the first time in the second traversal + * @param f22 The functions that is used when visiting the nodes for the second time in the second traversal + * @param i The place in which the node appears + */ + @pure + @opaque + def scanIndexingNode[T, Z1, Z2]( + tr: Tree[T], + init1: Z1, + init2: Z2, + f11: (Z1, T) => Z1, + f12: (Z1, T) => Z1, + f21: (Z2, T) => Z2, + f22: (Z2, T) => Z2, + i: BigInt, + ): Unit = { + decreases(tr) + require(i >= 0) + require(i < 2 * tr.size) + + unfold(tr.scan(init1, f11, f12)) + unfold(tr.scan(init2, f21, f22)) + unfold(tr.size) + + tr match { + case Endpoint() => Trivial() + case ContentNode(n, sub) => + scanIndexing(n, sub, init1, f11, f12, i) + scanIndexing(n, sub, init2, f21, f22, i) + if ((i == 0) || (i == 2 * tr.size - 1)) { + Trivial() + } else { + scanIndexingNode(sub, f11(init1, n), f21(init2, n), f11, f12, f21, f22, i - 1) + } + case ArticulationNode(l, r) => + scanIndexing(l, r, init1, f11, f12, i) + scanIndexing(l, r, init2, f21, f22, i) + if (i < 2 * l.size) { + scanIndexingNode(l, init1, init2, f11, f12, f21, f22, i) + } else { + scanIndexingNode( + r, + l.traverse(init1, f11, f12), + l.traverse(init2, f21, f22), + f11, + f12, + f21, + f22, + i - 2 * l.size, + ) + } + + } + + }.ensuring( + (tr.scan(init1, f11, f12)(i)._2 == tr.scan(init2, f21, f22)(i)._2) && + (tr.scan(init1, f11, f12)(i)._3 == tr.scan(init2, f21, f22)(i)._3) + ) + + /** If an intermediate state of a tree traversal does not satisfy a property but the initial state does, then there + * is an intermediate state before that does satisfy this property but whose the next one does not. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param p The propery that is fulfilled by the initial state but not the intermediate one + * @param i The step number of the state that does not satisfy p + */ + @pure + @opaque + def scanNotProp[T, Z]( + tr: Tree[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + p: Z => Boolean, + i: BigInt, + ): BigInt = { + decreases(i) + require(i >= 0) + require(i < 2 * tr.size) + require(p(init)) + require(!p(tr.scan(init, f1, f2)(i)._1)) + + if (i == 0) { + scanIndexingState(tr, init, f1, f2, 0) + Unreachable() + } else if (p(tr.scan(init, f1, f2)(i - 1)._1)) { + i - 1 + } else { + scanNotProp(tr, init, f1, f2, p, i - 1) + } + + }.ensuring(j => + j >= 0 && j < i && p(tr.scan(init, f1, f2)(j)._1) && !p(tr.scan(init, f1, f2)(j + 1)._1) + ) + + /** If an intermediate state of a tree traversal does not satisfy a property but the final state does, then there + * is an intermediate state after that one that does not satisfy this property but whose the next one does. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param p The propery that is fulfilled by the final state but not the intermediate one + * @param i The step number of the state that does not satisfy p + */ + @pure + @opaque + def scanNotPropRev[T, Z]( + tr: Tree[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + p: Z => Boolean, + i: BigInt, + ): BigInt = { + decreases(2 * tr.size - i) + require(i >= 0) + require(i < 2 * tr.size) + require(p(tr.traverse(init, f1, f2))) + require(!p(tr.scan(init, f1, f2)(i)._1)) + + scanIndexingState(tr, init, f1, f2, i) + + if (i == 2 * tr.size - 1) { + i + } else if (p(tr.scan(init, f1, f2)(i + 1)._1)) { + i + } else { + scanNotPropRev(tr, init, f1, f2, p, i + 1) + } + + }.ensuring(j => + (j >= i && j < 2 * tr.size - 1 && !p(tr.scan(init, f1, f2)(j)._1) && p( + tr.scan(init, f1, f2)(j + 1)._1 + )) || + ((j == 2 * tr.size - 1) && !p(tr.scan(init, f1, f2)(j)._1) && p(tr.traverse(init, f1, f2))) + ) + + /** If the state obtained after traversing a tree does not satisfy a property but the initial state does, then there + * is an intermediate state of the traversal such that it does satisfy this property but the next one does not. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param p The propery that is fulfilled by the initial state but not the final one + */ + @pure + @opaque + def traverseNotProp[T, Z]( + tr: Tree[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + p: Z => Boolean, + ): BigInt = { + require(p(init)) + require(!p(tr.traverse(init, f1, f2))) + + scanIndexingState(tr, init, f1, f2, 0) + if (p(tr.scan(init, f1, f2)(2 * tr.size - 1)._1)) { + 2 * tr.size - 1 + } else { + scanNotProp(tr, init, f1, f2, p, 2 * tr.size - 1) + } + }.ensuring(i => + i >= 0 && ((i < 2 * tr.size - 1 && p(tr.scan(init, f1, f2)(i)._1) && !p( + tr.scan(init, f1, f2)(i + 1)._1 + )) + || (i == 2 * tr.size - 1 && p(tr.scan(init, f1, f2)(i)._1))) + ) + + /** If a node appears in a tree traversal then the tree contains it. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param i The step during which the node appears + */ + @pure @opaque + def scanContains[T, Z]( + tr: Tree[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + i: BigInt, + ): Unit = { + decreases(tr) + require(0 <= i) + require(i < 2 * tr.size) + unfold(tr.size) + unfold(tr.contains(tr.scan(init, f1, f2)(i)._2)) + tr match { + case Endpoint() => Unreachable() + case ContentNode(n, sub) => + scanIndexing(n, sub, init, f1, f2, i) + if (i > 0 && i < 2 * tr.size - 1) { + scanContains(sub, f1(init, n), f1, f2, i - 1) + } + case ArticulationNode(l, r) => + scanIndexing(l, r, init, f1, f2, i) + if (i < 2 * l.size) { + scanContains(l, init, f1, f2, i) + } else { + scanContains(r, l.traverse(init, f1, f2), f1, f2, i - 2 * l.size) + } + } + }.ensuring(tr.contains(tr.scan(init, f1, f2)(i)._2)) + + /** If a tree is unique and a node is visited at two locations in the same [[TraversalDirection]], then those two locations + * are the same. + * + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param i The first locationof the node + * @param j The second location of the node + */ + @pure @opaque + def isUniqueIndexing[T, Z]( + tr: Tree[T], + init: Z, + f1: (Z, T) => Z, + f2: (Z, T) => Z, + i: BigInt, + j: BigInt, + ): Unit = { + decreases(tr) + require(tr.isUnique) + require(0 <= i) + require(i < 2 * tr.size) + require(0 <= j) + require(j < 2 * tr.size) + require(tr.scan(init, f1, f2)(i)._2 == tr.scan(init, f1, f2)(j)._2) + require(tr.scan(init, f1, f2)(i)._3 == tr.scan(init, f1, f2)(j)._3) + + unfold(tr.size) + unfold(tr.isUnique) + + tr match { + case Endpoint() => Unreachable() + case ContentNode(n, sub) => + scanIndexing(n, sub, init, f1, f2, i) + scanIndexing(n, sub, init, f1, f2, j) + if ((i == 0) && ((j == 0) || (j == 2 * tr.size - 1))) { + Trivial() + } else if ((i == 2 * tr.size - 1) && ((j == 0) || (j == 2 * tr.size - 1))) { + Trivial() + } else if ((i == 0) || (i == 2 * tr.size - 1)) { + scanContains(sub, f1(init, n), f1, f2, j - 1) + } else if ((j == 0) || (j == 2 * tr.size - 1)) { + scanContains(sub, f1(init, n), f1, f2, i - 1) + } else { + isUniqueIndexing(sub, f1(init, n), f1, f2, i - 1, j - 1) + } + case ArticulationNode(l, r) => + scanIndexing(l, r, init, f1, f2, i) + scanIndexing(l, r, init, f1, f2, j) + if ((i < 2 * l.size) && (j < 2 * l.size)) { + isUniqueIndexing(l, init, f1, f2, i, j) + } else if ((i >= 2 * l.size) && (j >= 2 * l.size)) { + isUniqueIndexing(r, l.traverse(init, f1, f2), f1, f2, i - 2 * l.size, j - 2 * l.size) + } else if (i < 2 * l.size) { + scanContains(l, init, f1, f2, i) + scanContains(r, l.traverse(init, f1, f2), f1, f2, j - 2 * l.size) + disjointContains(l.content, r.content, tr.scan(init, f1, f2)(i)._2) + disjointContains(l.content, r.content, tr.scan(init, f1, f2)(j)._2) + } else { + scanContains(l, init, f1, f2, j) + scanContains(r, l.traverse(init, f1, f2), f1, f2, i - 2 * l.size) + disjointContains(l.content, r.content, tr.scan(init, f1, f2)(j)._2) + disjointContains(l.content, r.content, tr.scan(init, f1, f2)(i)._2) + } + } + }.ensuring(i == j) + + /** Given a node that has been visited for the second time in a traversal, returns when it has been visited for the + * first time. + * @param tr The tree that is being traversed + * @param init The initial state of the traversal + * @param f1 The function that is used when visiting the nodes for the first time + * @param f2 The function that is used when visiting the nodes for the second time + * @param i The step number during whch the node is visited for the second time. + */ + @pure @opaque + def findDown[T, Z](tr: Tree[T], init: Z, f1: (Z, T) => Z, f2: (Z, T) => Z, i: BigInt): BigInt = { + decreases(tr) + require(i >= 0) + require(i < 2 * tr.size) + require(tr.scan(init, f1, f2)(i)._3 == TraversalDirection.Up) + unfold(tr.scan(init, f1, f2)) + unfold(tr.size) + tr match { + case Endpoint() => Unreachable() + case ContentNode(c, sub) => + scanIndexing(c, sub, init, f1, f2, i) + if (i == 2 * tr.size - 1) { + scanIndexing(c, sub, init, f1, f2, 0) + BigInt(0) + } else { + val j = findDown(sub, f1(init, c), f1, f2, i - 1) + scanIndexing(c, sub, init, f1, f2, j + 1) + j + 1 + } + case ArticulationNode(l, r) => + scanIndexing(l, r, init, f1, f2, i) + if (i < 2 * l.size) { + val j = findDown(l, init, f1, f2, i) + scanIndexing(l, r, init, f1, f2, j) + j + } else { + val j = findDown(r, l.traverse(init, f1, f2), f1, f2, i - 2 * l.size) + scanIndexing(l, r, init, f1, f2, j + 2 * l.size) + j + 2 * l.size + } + } + }.ensuring(j => + (j < i) && + (tr.scan(init, f1, f2)(j)._3 == TraversalDirection.Down) && + (tr.scan(init, f1, f2)(i)._2 == tr.scan(init, f1, f2)(j)._2) + ) + +} diff --git a/daml-lf/verification/utils/Value.scala b/daml-lf/verification/utils/Value.scala new file mode 100644 index 0000000000..9574b25877 --- /dev/null +++ b/daml-lf/verification/utils/Value.scala @@ -0,0 +1,14 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package lf.verified +package utils + +import stainless.lang._ +import stainless.annotation._ +import stainless.proof._ +import stainless.collection._ + +object Value { + case class ContractId(coid: String) +}