[Speedy] slight refactoring of contract key handeling (#16097)

* cleanup PartialTransaction API

* Put SValue + GlobalKey in cached contract key

* slight change of the ContractStateMahcine API

* drop (unsafe) builder for GlobalKey
This commit is contained in:
Remy 2023-01-23 15:40:45 +01:00 committed by GitHub
parent 758f1ce8f2
commit 18faa81608
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 180 additions and 189 deletions

View File

@ -1000,25 +1000,20 @@ private[lf] object SBuiltin {
val version = machine.tmplId2TxVersion(cached.templateId)
val createArgValue = cached.value.toNormalizedValue(version)
cached.key match {
case Some(skey) if skey.maintainers.isEmpty =>
case Some(cachedKey) if cachedKey.maintainers.isEmpty =>
Control.Error(
IE.CreateEmptyContractKeyMaintainers(
cached.templateId,
createArgValue,
skey.unnormalizedKeyValue,
cachedKey.lfValue,
)
)
case _ =>
machine.ptx
.insertCreate(
submissionTime = machine.submissionTime,
templateId = cached.templateId,
arg = createArgValue,
agreementText = cached.agreementText,
contract = cached,
optLocation = machine.getLastLocation,
signatories = cached.signatories,
stakeholders = cached.stakeholders,
key = cached.key.map(_.toNormalizedKeyWithMaintainers(version)),
version = version,
) match {
case Right((coid, newPtx)) =>
@ -1061,31 +1056,25 @@ private[lf] object SBuiltin {
machine
.getCachedContract(coid)
.getOrElse(crash(s"Contract ${coid.coid} is missing from cache"))
val sigs = cached.signatories
val templateVersion = machine.tmplId2TxVersion(templateId)
val interfaceVersion = interfaceId.map(machine.tmplId2TxVersion)
val exerciseVersion = interfaceVersion.fold(templateVersion)(_.max(templateVersion))
val chosenValue = args.get(0).toNormalizedValue(exerciseVersion)
val templateObservers = cached.observers
val ctrls = extractParties(NameOf.qualifiedNameOfCurrentFunc, args.get(2))
machine.enforceChoiceControllersLimit(ctrls, coid, templateId, choiceId, chosenValue)
val obsrs = extractParties(NameOf.qualifiedNameOfCurrentFunc, args.get(3))
machine.enforceChoiceObserversLimit(obsrs, coid, templateId, choiceId, chosenValue)
val mbKey = cached.key.map(_.toNormalizedKeyWithMaintainers(exerciseVersion))
machine.ptx
.beginExercises(
targetId = coid,
templateId = templateId,
contract = cached,
interfaceId = interfaceId,
choiceId = choiceId,
optLocation = machine.getLastLocation,
consuming = consuming,
actingParties = ctrls,
signatories = sigs,
stakeholders = sigs union templateObservers,
choiceObservers = obsrs,
mbKey = mbKey,
byKey = byKey,
chosenValue = chosenValue,
version = exerciseVersion,
@ -1484,16 +1473,10 @@ private[lf] object SBuiltin {
.getCachedContract(coid)
.getOrElse(crash(s"Contract ${coid.coid} is missing from cache"))
val version = machine.tmplId2TxVersion(templateId)
val signatories = cached.signatories
val observers = cached.observers
val key = cached.key.map(_.toNormalizedKeyWithMaintainers(version))
machine.ptx.insertFetch(
coid = coid,
templateId = templateId,
contract = cached,
optLocation = machine.getLastLocation,
signatories = signatories,
observers = observers,
key = key,
byKey = byKey,
version = version,
) match {
@ -1516,8 +1499,7 @@ private[lf] object SBuiltin {
args: util.ArrayList[SValue],
machine: UpdateMachine,
): Control[Nothing] = {
val keyWithMaintainers =
extractKeyWithMaintainers(NameOf.qualifiedNameOfCurrentFunc, args.get(0))
val cachedKey = extractKey(NameOf.qualifiedNameOfCurrentFunc, templateId, args.get(0))
val mbCoid = args.get(1) match {
case SOptional(mb) =>
mb.map {
@ -1527,11 +1509,10 @@ private[lf] object SBuiltin {
case _ => crash(s"Non option value when inserting lookup node")
}
val version = machine.tmplId2TxVersion(templateId)
val key = keyWithMaintainers.toNormalizedKeyWithMaintainers(version)
machine.ptx.insertLookup(
templateId = templateId,
optLocation = machine.getLastLocation,
key = key,
key = cachedKey,
result = mbCoid,
version = version,
) match {
@ -1593,19 +1574,18 @@ private[lf] object SBuiltin {
args: util.ArrayList[SValue],
machine: UpdateMachine,
): Control[Question.Update] = {
val skey = args.get(0)
val keyWithMaintainers = extractKeyWithMaintainers(NameOf.qualifiedNameOfCurrentFunc, skey)
val svalue = args.get(0)
val cachedKey = extractKey(NameOf.qualifiedNameOfCurrentFunc, operation.templateId, svalue)
if (keyWithMaintainers.maintainers.isEmpty) {
if (cachedKey.maintainers.isEmpty) {
Control.Error(
IE.FetchEmptyContractKeyMaintainers(
operation.templateId,
keyWithMaintainers.unnormalizedKeyValue,
cachedKey.lfValue,
)
)
} else {
val gkey = GlobalKey(operation.templateId, keyWithMaintainers.unnormalizedKeyValue)
val gkey = cachedKey.globalKey
machine.ptx.contractState.resolveKey(gkey) match {
case Right((keyMapping, next)) =>
machine.ptx = machine.ptx.copy(contractState = next)
@ -1632,7 +1612,7 @@ private[lf] object SBuiltin {
// SBFetchAny will populate machine.cachedContracts with the contract pointed by coid
val e = SEAppAtomic(
SEBuiltin(SBFetchAny),
Array(SEValue(SContractId(coid)), SEValue(SOptional(Some(skey)))),
Array(SEValue(SContractId(coid)), SEValue(SOptional(Some(svalue)))),
)
(Control.Expression(e), true)
}
@ -1649,7 +1629,7 @@ private[lf] object SBuiltin {
case None =>
Control.Question(
Question.Update.NeedKey(
GlobalKeyWithMaintainers(gkey, keyWithMaintainers.maintainers),
GlobalKeyWithMaintainers(gkey, cachedKey.maintainers),
machine.committers,
callback = { res =>
val (control, bool) = continue(res)
@ -2145,15 +2125,11 @@ private[lf] object SBuiltin {
val cachedContract = extractCachedContract(args.get(0))
val templateId = cachedContract.templateId
val optError: Option[Either[IE, Unit]] = for {
keyWithMaintainers <- cachedContract.key
cachedKey <- cachedContract.key
} yield {
for {
keyHash <- crypto.Hash
.hashContractKey(templateId, keyWithMaintainers.unnormalizedKeyValue)
.left
.map(msg => IE.DisclosedContractKeyHashingError(contractId, templateId, msg))
result <- machine.disclosureKeyTable
.addContractKey(templateId, keyHash, contractId)
.addContractKey(templateId, cachedKey.globalKey.hash, contractId)
} yield result
}
@ -2203,17 +2179,25 @@ private[lf] object SBuiltin {
private[this] val keyIdx = keyWithMaintainersStructFields.indexOf(Ast.keyFieldName)
private[this] val maintainerIdx = keyWithMaintainersStructFields.indexOf(Ast.maintainersFieldName)
private[this] def extractKeyWithMaintainers(location: String, v: SValue): SKeyWithMaintainers =
private[this] def extractKey(
location: String,
templateId: Ref.TypeConName,
v: SValue,
): CachedKey =
v match {
case SStruct(_, vals) =>
val skey = SKeyWithMaintainers(
vals.get(keyIdx),
val keyValue = vals.get(keyIdx)
val lfValue = keyValue.toUnnormalizedValue
val gkey = GlobalKey
.build(templateId, lfValue)
.getOrElse(
throw SErrorDamlException(IE.ContractIdInContractKey(keyValue.toUnnormalizedValue))
)
CachedKey(
gkey,
keyValue,
extractParties(NameOf.qualifiedNameOfCurrentFunc, vals.get(maintainerIdx)),
)
skey.unnormalizedKeyValue.foreachCid(_ =>
throw SErrorDamlException(IE.ContractIdInContractKey(skey.unnormalizedKeyValue))
)
skey
case _ => throw SErrorCrash(location, s"Invalid key with maintainers: $v")
}
@ -2249,7 +2233,7 @@ private[lf] object SBuiltin {
}
val mbKey = vals.get(cachedContractKeyIdx) match {
case SOptional(mbKey) =>
mbKey.map(extractKeyWithMaintainers(NameOf.qualifiedNameOfCurrentFunc, _))
mbKey.map(extractKey(NameOf.qualifiedNameOfCurrentFunc, templateId, _))
case v =>
throw SErrorCrash(
NameOf.qualifiedNameOfCurrentFunc,

View File

@ -111,11 +111,14 @@ private[lf] object Speedy {
sealed abstract class LedgerMode extends Product with Serializable
final case class SKeyWithMaintainers(key: SValue, maintainers: Set[Party]) {
def toNormalizedKeyWithMaintainers(version: TxVersion) =
final case class CachedKey(
globalKey: GlobalKey,
key: SValue,
maintainers: Set[Party],
) {
def toNodeKey(version: TxVersion) =
Node.KeyWithMaintainers(key.toNormalizedValue(version), maintainers)
val unnormalizedKeyValue = key.toUnnormalizedValue
val unnormalizedKeyWithMaintainers = Node.KeyWithMaintainers(unnormalizedKeyValue, maintainers)
val lfValue = globalKey.key
}
final case class CachedContract(
@ -124,7 +127,7 @@ private[lf] object Speedy {
agreementText: String,
signatories: Set[Party],
observers: Set[Party],
key: Option[SKeyWithMaintainers],
key: Option[CachedKey],
) {
val stakeholders: Set[Party] = signatories union observers
private[speedy] val any = SValue.SAny(TTyCon(templateId), value)
@ -344,7 +347,7 @@ private[lf] object Speedy {
val transactionVersion = tmplId2TxVersion(cachedContract.templateId)
val maybeKeyWithMaintainers =
cachedContract.key
.map(_.toNormalizedKeyWithMaintainers(transactionVersion))
.map(_.toNodeKey(transactionVersion))
.map { case Node.KeyWithMaintainers(key, maintainers) =>
GlobalKeyWithMaintainers(
globalKey = GlobalKey.assertBuild(disclosedContract.templateId, key),

View File

@ -5,12 +5,12 @@ package com.daml.lf
package speedy
import com.daml.lf.data.Ref.{ChoiceName, Location, Party, TypeConName}
import com.daml.lf.data.{BackStack, ImmArray, Ref, Time}
import com.daml.lf.data.{BackStack, ImmArray, Time}
import com.daml.lf.ledger.Authorize
import com.daml.lf.speedy.Speedy.{CachedContract, CachedKey}
import com.daml.lf.transaction.ContractKeyUniquenessMode
import com.daml.lf.transaction.{
ContractStateMachine,
GlobalKey,
Node,
NodeId,
SubmittedTransaction => SubmittedTx,
@ -336,28 +336,23 @@ private[speedy] case class PartialTransaction(
*/
def insertCreate(
submissionTime: Time.Timestamp,
templateId: Ref.Identifier,
arg: Value,
agreementText: String,
contract: CachedContract,
optLocation: Option[Location],
signatories: Set[Party],
stakeholders: Set[Party],
key: Option[Node.KeyWithMaintainers],
version: TxVersion,
): Either[(PartialTransaction, Tx.TransactionError), (Value.ContractId, PartialTransaction)] = {
val auth = Authorize(context.info.authorizers)
val actionNodeSeed = context.nextActionChildSeed
val discriminator =
crypto.Hash.deriveContractDiscriminator(actionNodeSeed, submissionTime, stakeholders)
crypto.Hash.deriveContractDiscriminator(actionNodeSeed, submissionTime, contract.stakeholders)
val cid = Value.ContractId.V1(discriminator)
val createNode = Node.Create(
cid,
templateId,
arg,
agreementText,
signatories,
stakeholders,
key,
contract.templateId,
contract.value.toNormalizedValue(version),
contract.agreementText,
contract.signatories,
contract.stakeholders,
contract.key.map(_.toNodeKey(version)),
version,
)
val nid = NodeId(nextNodeIdx)
@ -371,7 +366,10 @@ private[speedy] case class PartialTransaction(
authorizationChecker.authorizeCreate(optLocation, createNode)(auth) match {
case fa :: _ => Left((ptx, Tx.AuthFailureDuringExecution(nid, fa)))
case Nil =>
ptx.contractState.visitCreate(templateId, cid, key.map(_.key)) match {
ptx.contractState.visitCreate(
cid,
contract.key.map(_.globalKey),
) match {
case Right(next) =>
val nextPtx = ptx.copy(contractState = next)
Right((cid, nextPtx))
@ -383,33 +381,33 @@ private[speedy] case class PartialTransaction(
def insertFetch(
coid: Value.ContractId,
templateId: TypeConName,
contract: CachedContract,
optLocation: Option[Location],
signatories: Set[Party],
observers: Set[Party],
key: Option[Node.KeyWithMaintainers],
byKey: Boolean,
version: TxVersion,
): Either[Tx.TransactionError, PartialTransaction] = {
val stakeholders = observers union signatories
val contextActors = context.info.authorizers
val actingParties = contextActors intersect stakeholders
val actingParties = contextActors intersect contract.stakeholders
val auth = Authorize(context.info.authorizers)
val nid = NodeId(nextNodeIdx)
val node = Node.Fetch(
coid,
templateId,
contract.templateId,
actingParties,
signatories,
stakeholders,
key,
contract.signatories,
contract.stakeholders,
contract.key.map(_.toNodeKey(version)),
normByKey(version, byKey),
version,
)
mustBeActive(NameOf.qualifiedNameOfCurrentFunc, coid) {
val newContractState = assertRightKey(
// evaluation order tests require visitFetch proceeds authorizeFetch
contractState.visitFetch(templateId, coid, key.map(_.key), byKey)
contractState.visitFetch(
coid,
contract.key.map(_.globalKey),
byKey,
)
)
authorizationChecker.authorizeFetch(optLocation, node)(auth) match {
case fa :: _ => Left(Tx.AuthFailureDuringExecution(nid, fa))
@ -422,25 +420,18 @@ private[speedy] case class PartialTransaction(
def insertLookup(
templateId: TypeConName,
optLocation: Option[Location],
key: Node.KeyWithMaintainers,
key: CachedKey,
result: Option[Value.ContractId],
version: TxVersion,
): Either[Tx.TransactionError, PartialTransaction] = {
val auth = Authorize(context.info.authorizers)
val nid = NodeId(nextNodeIdx)
val node = Node.LookupByKey(
templateId,
key,
result,
version,
)
val gkey = GlobalKey.assertBuild(templateId, key.key)
val node = Node.LookupByKey(templateId, key.toNodeKey(version), result, version)
// This method is only called after we have already resolved the key in com.daml.lf.speedy.SBuiltin.SBUKeyBuiltin.execute
// so the current state's global key inputs must resolve the key.
val keyInput = contractState.globalKeyInputs(gkey)
val newContractState = assertRightKey(
contractState.visitLookup(templateId, key.key, keyInput.toKeyMapping, result)
)
val keyInput = contractState.globalKeyInputs(key.globalKey)
val newContractState =
assertRightKey(contractState.visitLookup(key.globalKey, keyInput.toKeyMapping, result))
authorizationChecker.authorizeLookupByKey(optLocation, node)(auth) match {
case fa :: _ => Left(Tx.AuthFailureDuringExecution(nid, fa))
case Nil =>
@ -453,16 +444,13 @@ private[speedy] case class PartialTransaction(
*/
def beginExercises(
targetId: Value.ContractId,
templateId: TypeConName,
contract: CachedContract,
interfaceId: Option[TypeConName],
choiceId: ChoiceName,
optLocation: Option[Location],
consuming: Boolean,
actingParties: Set[Party],
signatories: Set[Party],
stakeholders: Set[Party],
choiceObservers: Set[Party],
mbKey: Option[Node.KeyWithMaintainers],
byKey: Boolean,
chosenValue: Value,
version: TxVersion,
@ -472,15 +460,15 @@ private[speedy] case class PartialTransaction(
val ec =
ExercisesContextInfo(
targetId = targetId,
templateId = templateId,
templateId = contract.templateId,
interfaceId = interfaceId,
contractKey = mbKey,
contractKey = contract.key.map(_.toNodeKey(version)),
choiceId = choiceId,
consuming = consuming,
actingParties = actingParties,
chosenValue = chosenValue,
signatories = signatories,
stakeholders = stakeholders,
signatories = contract.signatories,
stakeholders = contract.stakeholders,
choiceObservers = choiceObservers,
nodeId = nid,
parent = context,
@ -491,7 +479,13 @@ private[speedy] case class PartialTransaction(
// important: the semantics of Daml dictate that contracts are immediately
// inactive as soon as you exercise it. therefore, mark it as consumed now.
val newContractState = assertRightKey(
contractState.visitExercise(nid, templateId, targetId, mbKey.map(_.key), byKey, consuming)
contractState.visitExercise(
nid,
targetId,
contract.key.map(_.globalKey),
byKey,
consuming,
)
)
authorizationChecker.authorizeExercise(optLocation, makeExNode(ec))(auth) match {
case fa :: _ => Left(Tx.AuthFailureDuringExecution(nid, fa))

View File

@ -605,7 +605,7 @@ object CompilerTest {
if (withKey) {
Some(
GlobalKeyWithMaintainers(
GlobalKey(templateId, key.toUnnormalizedValue),
GlobalKey.assertBuild(templateId, key.toUnnormalizedValue),
Set(maintainer),
)
)

View File

@ -11,7 +11,7 @@ import com.daml.lf.data.{Bytes, FrontStack, ImmArray, Ref, Struct, Time}
import com.daml.lf.language.Ast
import com.daml.lf.speedy.SExpr.SEMakeClo
import com.daml.lf.speedy.SValue.{SContractId, SToken}
import com.daml.lf.speedy.Speedy.{CachedContract, SKeyWithMaintainers}
import com.daml.lf.speedy.Speedy.{CachedContract, CachedKey}
import com.daml.lf.transaction.{GlobalKey, GlobalKeyWithMaintainers, TransactionVersion, Versioned}
import com.daml.lf.value.Value
import com.daml.lf.value.Value.{ContractId, ContractInstance}
@ -219,7 +219,15 @@ object ExplicitDisclosureLib {
),
)
val mbKey =
if (withKey) Some(SKeyWithMaintainers(contract, Set(maintainer))) else None
if (withKey)
Some(
CachedKey(
GlobalKey.assertBuild(templateId, contract.toUnnormalizedValue),
contract,
Set(maintainer),
)
)
else None
CachedContract(
templateId,

View File

@ -7,6 +7,7 @@ package speedy
import com.daml.lf.data.ImmArray
import com.daml.lf.speedy.PartialTransaction
import com.daml.lf.speedy.SValue.{SValue => _, _}
import com.daml.lf.speedy.Speedy.CachedContract
import com.daml.lf.transaction.{ContractKeyUniquenessMode, Node, TransactionVersion}
import com.daml.lf.value.Value
import org.scalatest._
@ -40,17 +41,21 @@ class PartialTransactionSpec extends AnyWordSpec with Matchers with Inside {
private[this] implicit class PartialTransactionExtra(val ptx: PartialTransaction) {
val contract = CachedContract(
templateId = templateId,
value = SValue.SUnit,
agreementText = "agreement",
signatories = Set(party),
observers = Set.empty,
key = None,
)
def insertCreate_ : PartialTransaction =
ptx
.insertCreate(
submissionTime = data.Time.Timestamp.Epoch,
templateId = templateId,
arg = Value.ValueUnit,
agreementText = "agreement",
contract = contract,
optLocation = None,
signatories = Set(party),
stakeholders = Set.empty,
key = None,
version = TransactionVersion.maxVersion,
)
.toOption
@ -61,16 +66,13 @@ class PartialTransactionSpec extends AnyWordSpec with Matchers with Inside {
ptx
.beginExercises(
targetId = cid,
templateId = templateId,
contract = contract,
interfaceId = None,
choiceId = choiceId,
optLocation = None,
consuming = false,
actingParties = Set(party),
signatories = Set(party),
stakeholders = Set.empty,
choiceObservers = Set.empty,
mbKey = None,
byKey = false,
chosenValue = Value.ValueUnit,
version = TransactionVersion.maxVersion,

View File

@ -20,7 +20,7 @@ import com.daml.lf.speedy.SBuiltin.{
import com.daml.lf.speedy.SError.{SError, SErrorCrash}
import com.daml.lf.speedy.SExpr._
import com.daml.lf.speedy.SValue.{SValue => _, _}
import com.daml.lf.speedy.Speedy.{CachedContract, Machine, SKeyWithMaintainers}
import com.daml.lf.speedy.Speedy.{CachedContract, Machine, CachedKey}
import com.daml.lf.testing.parser.Implicits._
import com.daml.lf.transaction.{GlobalKey, GlobalKeyWithMaintainers, TransactionVersion}
import com.daml.lf.value.Value
@ -1693,7 +1693,9 @@ class SBuiltinTest extends AnyFreeSpec with Matchers with TableDrivenPropertyChe
val templateId = Ref.Identifier.assertFromString("-pkgId-:Mod:IouWithKey")
val (disclosedContract, Some((key, keyWithMaintainers, keyHash))) =
buildDisclosedContract(contractId, alice, alice, templateId, withKey = true)
val optionalKey = Some(SKeyWithMaintainers(key, Set(alice)))
val optionalKey = Some(
CachedKey(GlobalKey.assertBuild(templateId, key.toUnnormalizedValue), key, Set(alice))
)
val cachedContract = CachedContract(
templateId,
disclosedContract.argument,
@ -1913,7 +1915,7 @@ object SBuiltinTest {
if (withKey) {
Some(
GlobalKeyWithMaintainers(
GlobalKey(templateId, key.toUnnormalizedValue),
GlobalKey.assertBuild(templateId, key.toUnnormalizedValue),
Set(maintainer),
)
)

View File

@ -422,7 +422,7 @@ object TransactionConversionsSpec {
)
}
private val aChildExerciseNode = exercise(Set(aChildObserver), ImmArray.empty)
private val aGlobalKey = GlobalKey(aTemplateId, aUnitValue)
private val aGlobalKey = GlobalKey.assertBuild(aTemplateId, aUnitValue)
private val aRawChildNode = RawTransaction.Node(aChildNode.toByteString)
private val aRootExerciseNode = exercise(Set(aRootObserver), ImmArray(aChildNodeId))
private val aRootNode = buildProtoNode(aRawRootNodeId.value) { builder =>

View File

@ -4,7 +4,6 @@
package com.daml.lf
package transaction
import com.daml.lf.data.Ref.{Identifier, TypeConName}
import com.daml.lf.transaction.Transaction.{
DuplicateContractKey,
InconsistentContractKey,
@ -126,12 +125,11 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) {
/** Visit a create node */
def handleCreate(node: Node.Create): Either[KeyInputError, State] =
visitCreate(node.templateId, node.coid, node.keyValue).left.map(Right(_))
visitCreate(node.coid, globalKeyOpt(node)).left.map(Right(_))
private[lf] def visitCreate(
templateId: TypeConName,
contractId: ContractId,
mbKey: Option[Value],
mbKey: Option[GlobalKey],
): Either[DuplicateContractKey, State] = {
val me =
this.copy(
@ -145,21 +143,19 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) {
// active keys
mbKey match {
case None => Right(me)
case Some(key) =>
val ck = GlobalKey(templateId, key)
val conflict = lookupActiveKey(ck).exists(_ != KeyInactive)
case Some(gk) =>
val conflict = lookupActiveKey(gk).exists(_ != KeyInactive)
val newKeyInputs =
if (globalKeyInputs.contains(ck)) globalKeyInputs
else globalKeyInputs.updated(ck, KeyCreate)
if (globalKeyInputs.contains(gk)) globalKeyInputs
else globalKeyInputs.updated(gk, KeyCreate)
Either.cond(
!conflict || mode == ContractKeyUniquenessMode.Off,
me.copy(
activeState = me.activeState.createKey(ck, contractId),
activeState = me.activeState.createKey(gk, contractId),
globalKeyInputs = newKeyInputs,
),
DuplicateContractKey(ck),
DuplicateContractKey(gk),
)
}
}
@ -167,9 +163,8 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) {
def handleExercise(nid: Nid, exe: Node.Exercise): Either[KeyInputError, State] =
visitExercise(
nid,
exe.templateId,
exe.targetCoid,
exe.keyValue,
globalKeyOpt(exe),
exe.byKey,
exe.consuming,
).left
@ -181,16 +176,15 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) {
*/
private[lf] def visitExercise(
nodeId: Nid,
templateId: TypeConName,
targetId: ContractId,
mbKey: Option[Value],
mbKey: Option[GlobalKey],
byKey: Boolean,
consuming: Boolean,
): Either[InconsistentContractKey, State] = {
for {
state <-
if (byKey || mode == ContractKeyUniquenessMode.Strict)
assertKeyMapping(templateId, targetId, mbKey)
assertKeyMapping(targetId, mbKey)
else
Right(this)
} yield {
@ -210,7 +204,7 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) {
throw new UnsupportedOperationException(
"handleLookup can only be used if all key nodes are considered"
)
visitLookup(lookup.templateId, lookup.key.key, lookup.result, lookup.result).left.map(Left(_))
visitLookup(globalKey(lookup), lookup.result, lookup.result).left.map(Left(_))
}
/** Must be used to handle lookups iff in [[com.daml.lf.transaction.ContractKeyUniquenessMode.Off]] mode
@ -232,16 +226,14 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) {
throw new UnsupportedOperationException(
"handleLookupWith can only be used if only by-key nodes are considered"
)
visitLookup(lookup.templateId, lookup.key.key, keyInput, lookup.result).left.map(Left(_))
visitLookup(globalKey(lookup), keyInput, lookup.result).left.map(Left(_))
}
private[lf] def visitLookup(
templateId: TypeConName,
key: Value,
gk: GlobalKey,
keyInput: Option[ContractId],
keyResolution: Option[ContractId],
): Either[InconsistentContractKey, State] = {
val gk = GlobalKey.assertBuild(templateId, key)
val (keyMapping, next) = resolveKey(gk) match {
case Right(result) => result
case Left(handle) => handle(keyInput)
@ -282,28 +274,25 @@ class ContractStateMachine[Nid](mode: ContractKeyUniquenessMode) {
}
def handleFetch(node: Node.Fetch): Either[KeyInputError, State] =
visitFetch(node.templateId, node.coid, node.keyValue, node.byKey).left.map(Left(_))
visitFetch(node.coid, globalKeyOpt(node), node.byKey).left.map(Left(_))
private[lf] def visitFetch(
templateId: TypeConName,
contractId: ContractId,
mbKey: Option[Value],
mbKey: Option[GlobalKey],
byKey: Boolean,
): Either[InconsistentContractKey, State] =
if (byKey || mode == ContractKeyUniquenessMode.Strict)
assertKeyMapping(templateId, contractId, mbKey)
assertKeyMapping(contractId, mbKey)
else
Right(this)
private[this] def assertKeyMapping(
templateId: Identifier,
cid: Value.ContractId,
mbKey: Option[Value],
mbKey: Option[GlobalKey],
): Either[InconsistentContractKey, State] =
mbKey match {
case None => Right(this)
case Some(key) =>
val gk = GlobalKey.assertBuild(templateId, key)
case Some(gk) =>
val (keyMapping, next) = resolveKey(gk) match {
case Right(result) => result
case Left(handle) => handle(Some(cid))
@ -542,4 +531,11 @@ object ContractStateMachine {
ActiveLedgerState(Set.empty, Map.empty, Map.empty)
def empty[Nid]: ActiveLedgerState[Nid] = EMPTY
}
private def globalKeyOpt(node: Node.Action) =
node.keyOpt.map(k => GlobalKey.assertBuild(node.templateId, k.key))
private def globalKey(node: Node.LookupByKey) =
GlobalKey.assertBuild(node.templateId, node.key.key)
}

View File

@ -14,7 +14,7 @@ final class GlobalKey private (
val templateId: Ref.TypeConName,
val key: Value,
val hash: crypto.Hash,
) extends {
) extends data.NoCopy {
override def equals(obj: Any): Boolean = obj match {
case that: GlobalKey => this.hash == that.hash
case _ => false
@ -27,9 +27,6 @@ final class GlobalKey private (
object GlobalKey {
def apply(templateId: Ref.TypeConName, key: Value): GlobalKey =
new GlobalKey(templateId, key, crypto.Hash.safeHashContractKey(templateId, key))
// Will fail if key contains contract ids
def build(templateId: Ref.TypeConName, key: Value): Either[String, GlobalKey] =
crypto.Hash.hashContractKey(templateId, key).map(new GlobalKey(templateId, key, _))

View File

@ -379,7 +379,7 @@ class TransactionSpec
val dummyBuilder = TransactionBuilder()
val parties = List("Alice")
def keyValue(s: String) = V.ValueText(s)
def globalKey(s: V.ContractId) = GlobalKey("Mod:T", keyValue(s.coid))
def globalKey(s: V.ContractId) = GlobalKey.assertBuild("Mod:T", keyValue(s.coid))
def create(s: V.ContractId) = dummyBuilder
.create(
id = s,

View File

@ -74,7 +74,9 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val value = complexValue
val hash = "2b1019f99147ca726baa3a12509399327746f1f9c4636a6ec5f5d7af1e7c2942"
KeyHasher.hashKeyString(GlobalKey(templateId("module", "name"), value)) shouldBe hash
KeyHasher.hashKeyString(
GlobalKey.assertBuild(templateId("module", "name"), value)
) shouldBe hash
}
"be deterministic and thread safe" in {
@ -82,7 +84,7 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
// Note: intentionally does not reuse value instances
val hashes = Vector
.range(0, 1000)
.map(_ => GlobalKey(templateId("module", "name"), complexValue))
.map(_ => GlobalKey.assertBuild(templateId("module", "name"), complexValue))
.par
.map(key => KeyHasher.hashKeyString(key))
@ -93,8 +95,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
// Same value but different template ID should produce a different hash
val value = ValueText("A")
val hash1 = KeyHasher.hashKeyString(GlobalKey(templateId("AA", "A"), value))
val hash2 = KeyHasher.hashKeyString(GlobalKey(templateId("A", "AA"), value))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(templateId("AA", "A"), value))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(templateId("A", "AA"), value))
hash1.equals(hash2) shouldBe false
}
@ -106,8 +108,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -121,8 +123,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -146,8 +148,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -160,8 +162,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -174,8 +176,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -202,8 +204,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -230,8 +232,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -242,8 +244,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -254,8 +256,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -266,8 +268,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -280,8 +282,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -294,8 +296,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -306,8 +308,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}
@ -332,8 +334,8 @@ class KeyHasherSpec extends AnyWordSpec with Matchers {
val tid = templateId("module", "name")
val hash1 = KeyHasher.hashKeyString(GlobalKey(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey(tid, value2))
val hash1 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value1))
val hash2 = KeyHasher.hashKeyString(GlobalKey.assertBuild(tid, value2))
hash1.equals(hash2) shouldBe false
}

View File

@ -218,7 +218,10 @@ private object MutableCacheBackedContractStoreRaceTests {
contractsCount: Long,
): Seq[Offset => SimplifiedContractStateEvent] = {
val keys = (0L until keysCount).map { keyIdx =>
keyIdx -> GlobalKey(Identifier.assertFromString("pkgId:module:entity"), ValueInt64(keyIdx))
keyIdx -> GlobalKey.assertBuild(
Identifier.assertFromString("pkgId:module:entity"),
ValueInt64(keyIdx),
)
}.toMap
val keysToContracts = keys.map { case (keyIdx, key) =>

View File

@ -561,7 +561,7 @@ class SequenceSpec
private def contractKey(i: Long) = {
val templateId = Ref.Identifier.assertFromString("pkg:M:T")
GlobalKey(templateId, Value.ValueInt64(i))
GlobalKey.assertBuild(templateId, Value.ValueInt64(i))
}
private def cId(i: Int) = ContractId.V1(Hash.hashPrivateKey(i.toString))

View File

@ -114,7 +114,7 @@ class SequencerStateSpec extends AnyFlatSpec with Matchers {
private def key(i: Long) = {
val templateId = Ref.Identifier.assertFromString("pkg:M:T")
GlobalKey(templateId, Value.ValueInt64(i))
GlobalKey.assertBuild(templateId, Value.ValueInt64(i))
}
private def cid(i: Int): ContractId = ContractId.V1(Hash.hashPrivateKey(i.toString))