Refactor batch trigger to wrap arguments in a record type

This commit is contained in:
Carl Pulley 2023-01-13 17:13:02 +00:00 committed by GitHub
parent 8f5b25fc1f
commit edeadc44a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 20 deletions

View File

@ -277,10 +277,10 @@ runTrigger userTrigger = LowLevel.BatchTrigger
, heartbeat = userTrigger.heartbeat
}
where
initialState party readAs (ActiveContracts createdEvents) =
let acs = foldl (\acs created -> applyEvent (CreatedEvent created) acs) (ACS mempty Map.empty) createdEvents
userState = runTriggerInitializeA userTrigger.initialize (TriggerInitState acs party readAs)
state = TriggerState acs party readAs userState Map.empty
initialState args =
let acs = foldl (\acs created -> applyEvent (CreatedEvent created) acs) (ACS mempty Map.empty) args.acs.activeContracts
userState = runTriggerInitializeA userTrigger.initialize (TriggerInitState acs args.actAs args.readAs)
state = TriggerState acs args.actAs args.readAs userState Map.empty
in TriggerSetup $ execStateT (runTriggerRule $ runRule userTrigger.rule) state
mkUserState state acs msg =

View File

@ -179,7 +179,7 @@ runRule rule = do
-- | Transform the (legacy) low-level trigger type into a batching trigger.
runLegacyTrigger : LowLevel.Trigger s -> BatchTrigger s
runLegacyTrigger userTrigger = BatchTrigger
{ initialState = userTrigger.initialState
{ initialState = \args -> userTrigger.initialState args.actAs args.readAs args.acs
, update = \msgs -> forA_ msgs userTrigger.update
, registeredTemplates = userTrigger.registeredTemplates
, heartbeat = userTrigger.heartbeat

View File

@ -6,6 +6,7 @@ module Daml.Trigger.LowLevel
( Message(..)
, Completion(..)
, CompletionStatus(..)
, TriggerSetupArguments(..)
, Transaction(..)
, AnyContractId
, toAnyContractId
@ -167,6 +168,13 @@ data CompletionStatus
| Succeeded { transactionId : TransactionId }
deriving Show
-- Introduced in version 2.5.1: this definition is used to simplify future extensions of trigger initialState arguments
data TriggerSetupArguments = TriggerSetupArguments
{ actAs : Party
, readAs : [Party]
, acs : ActiveContracts
}
data ActiveContracts = ActiveContracts { activeContracts : [Created] }
-- @WARN use 'BatchTrigger s' instead of 'Trigger s'
@ -180,7 +188,7 @@ data Trigger s = Trigger
-- | Batching trigger is (approximately) a left-fold over `Message` with
-- an accumulator of type `s`.
data BatchTrigger s = BatchTrigger
{ initialState : Party -> [Party] -> ActiveContracts -> TriggerSetup s
{ initialState : TriggerSetupArguments -> TriggerSetup s
, update : [Message] -> TriggerRule s ()
, registeredTemplates : RegisteredTemplates
, heartbeat : Optional RelTime

View File

@ -8,8 +8,9 @@ package trigger
import scalaz.std.either._
import scalaz.std.list._
import scalaz.std.option._
import scalaz.syntax.tag._
import scalaz.syntax.traverse._
import com.daml.lf.data.{FrontStack, ImmArray}
import com.daml.lf.data.{FrontStack, ImmArray, Ref}
import com.daml.lf.data.Ref._
import com.daml.lf.language.Ast._
import com.daml.lf.speedy.{ArrayList, SValue}
@ -63,6 +64,8 @@ final class Converter(
private[this] val anyTemplateTyCon = DA.Internal.Any.AnyTemplate
private[this] val anyViewTyCon = DA.Internal.Interface.AnyView.Types.AnyView
private[this] val activeContractsTy = triggerIds.damlTriggerLowLevel("ActiveContracts")
private[this] val triggerSetupArgumentsTy =
triggerIds.damlTriggerLowLevel("TriggerSetupArguments")
private[this] val anyContractIdTy = triggerIds.damlTriggerLowLevel("AnyContractId")
private[this] val archivedTy = triggerIds.damlTriggerLowLevel("Archived")
private[this] val commandIdTy = triggerIds.damlTriggerLowLevel("CommandId")
@ -291,6 +294,23 @@ final class Converter(
.map(xs => SList(FrontStack.from(xs)))
} yield record(activeContractsTy, "activeContracts" -> events)
def fromTriggerSetupArguments(
parties: TriggerParties,
createdEvents: Seq[CreatedEvent],
): Either[String, SValue] =
for {
acs <- fromACS(createdEvents)
actAs = SParty(Ref.Party.assertFromString(parties.actAs.unwrap))
readAs = SList(
parties.readAs.map(p => SParty(Ref.Party.assertFromString(p.unwrap))).to(FrontStack)
)
} yield record(
triggerSetupArgumentsTy,
"actAs" -> actAs,
"readAs" -> readAs,
"acs" -> acs,
)
def toFiniteDuration(value: SValue): Either[String, FiniteDuration] =
value.expect(
"RelTime",

View File

@ -100,7 +100,35 @@ private[lf] final case class Trigger(
heartbeat: Option[FiniteDuration],
// Whether the trigger supports readAs claims (SDK 1.18 and newer) or not.
hasReadAs: Boolean,
)
) {
def initialStateArguments(
parties: TriggerParties,
acs: Seq[CreatedEvent],
converter: Converter,
): Array[SValue] = {
if (defn.version >= Trigger.Version.`2.5.1`) {
Array(converter.fromTriggerSetupArguments(parties, acs).orConverterException)
} else {
val createdValue: SValue = converter.fromACS(acs).orConverterException
val partyArg = SParty(Ref.Party.assertFromString(parties.actAs.unwrap))
if (hasReadAs) {
// trigger version SDK 1.18 and newer
val readAsArg = SList(
parties.readAs.map(p => SParty(Ref.Party.assertFromString(p.unwrap))).to(FrontStack)
)
Array(partyArg, readAsArg, createdValue)
} else {
// trigger version prior to SDK 1.18
Array(partyArg, createdValue)
}
}
}
}
private final case class InFlightCommandOverflowException(inFlightCommands: Int, crashCount: Int)
extends Exception
// Utilities for interacting with the speedy machine.
private[lf] object Machine {
@ -730,20 +758,11 @@ private[lf] class Runner private (
compiler.unsafeCompile(
ERecProj(trigger.defn.ty, Name.assertFromString("initialState"), trigger.defn.expr)
)
// Convert the ACS to a speedy value.
val createdValue: SValue = converter.fromACS(acs).orConverterException
// Setup an application expression of initialState on the ACS.
val partyArg = SParty(Ref.Party.assertFromString(parties.actAs.unwrap))
val initialStateArgs = if (trigger.hasReadAs) {
val readAsArg = SList(
parties.readAs.map(p => SParty(Ref.Party.assertFromString(p.unwrap))).to(FrontStack)
)
Array(partyArg, readAsArg, createdValue)
} else Array(partyArg, createdValue)
val initialState: SExpr =
makeApp(
getInitialState,
initialStateArgs,
trigger.initialStateArguments(parties, acs, converter),
)
// Prepare a speedy machine for evaluating expressions.
val machine: Speedy.PureMachine =

View File

@ -10,8 +10,8 @@ import Daml.Trigger.LowLevel
test : BatchTrigger [Text]
test = BatchTrigger
{ initialState = \party _readAs _ -> do
submitCommands [createCmd (T party)]
{ initialState = \args -> do
submitCommands [createCmd (T args.actAs)]
pure []
, update = \msgs -> forA_ msgs \msg -> do
case msg of