Remove triggers (#18142)

* remove triggers

* adapt the codebase to missing triggers

run-all-tests: true

* fix build

* Remove oracle tests
This commit is contained in:
mziolekda 2024-01-12 17:55:05 +01:00 committed by GitHub
parent 043dfc55c3
commit 7108f2c76a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
243 changed files with 24 additions and 30407 deletions

View File

@ -25,7 +25,6 @@ NOTICES @garyverhaegen-da @dasormeter
# Language
/daml-assistant/ @remyhaemmerle-da @filmackay
/daml-script/ @remyhaemmerle-da
/triggers/ @remyhaemmerle-da
/compiler/ @basvangijzel-DA @remyhaemmerle-da @akrmn @dylant-da @samuel-williams-da @paulbrauner-da
/libs-haskell/ @basvangijzel-DA @remyhaemmerle-da @akrmn @dylant-da @samuel-williams-da @paulbrauner-da
/ghc-lib/ @basvangijzel-DA @remyhaemmerle-da @akrmn @dylant-da @samuel-williams-da @paulbrauner-da
@ -56,7 +55,6 @@ NOTICES @garyverhaegen-da @dasormeter
# Application Runtime
/ledger-service/ @filmackay
/runtime-components/ @filmackay
/triggers/service/ @filmackay
# Canton code drop
/canton/ @remyhaemmerle-da

View File

@ -635,10 +635,6 @@ load("@canton_maven//:defs.bzl", pinned_canton_maven_install = "pinned_maven_ins
pinned_canton_maven_install()
load("@triggers_maven//:defs.bzl", pinned_triggers_maven_install = "pinned_maven_install")
pinned_triggers_maven_install()
load("@deprecated_maven//:defs.bzl", pinned_deprecated_maven_install = "pinned_maven_install")
pinned_deprecated_maven_install()

View File

@ -313,22 +313,6 @@ def install_java_deps():
version_conflict_policy = "pinned",
)
# Triggers depend on sjsonnet whose latest version still depends transitively on upickle 1.x,
# while ammonite cannot work with upickle < 2.x. So we define the sjsonnet dependency in a
# different maven_install.
maven_install(
name = "triggers_maven",
maven_install_json = "@//:triggers_maven_install.json",
artifacts = [
"com.lihaoyi:sjsonnet_{}:0.3.0".format(scala_major_version),
],
repositories = [
"https://repo1.maven.org/maven2",
],
fetch_sources = True,
version_conflict_policy = "pinned",
)
# Do not use those dependencies in anything new !
maven_install(
name = "deprecated_maven",

View File

@ -147,101 +147,6 @@ jobs:
trigger_sha: '$(trigger_sha)'
- template: report-end.yml
- job: Linux_oracle
timeoutInMinutes: 240
pool:
name: 'ubuntu_20_04'
demands: assignment -equals default
steps:
- template: report-start.yml
- checkout: self
- bash: ci/dev-env-install.sh
displayName: 'Build/Install the Developer Environment'
- template: clean-up.yml
- bash: |
source dev-env/lib/ensure-nix
ci/dev-env-push.py
displayName: 'Push Developer Environment build results'
condition: and(succeeded(), eq(variables['System.PullRequest.IsFork'], 'False'))
env:
# to upload to the Nix cache
GOOGLE_APPLICATION_CREDENTIALS_CONTENT: $(GOOGLE_APPLICATION_CREDENTIALS_CONTENT)
NIX_SECRET_KEY_CONTENT: $(NIX_SECRET_KEY_CONTENT)
- bash: ci/configure-bazel.sh
displayName: 'Configure Bazel'
env:
IS_FORK: $(System.PullRequest.IsFork)
# to upload to the bazel cache
GOOGLE_APPLICATION_CREDENTIALS_CONTENT: $(GOOGLE_APPLICATION_CREDENTIALS_CONTENT)
- bash: |
set -euo pipefail
eval "$(./dev-env/bin/dade-assist)"
docker login --username "$DOCKER_LOGIN" --password "$DOCKER_PASSWORD"
IMAGE=$(cat ci/oracle_image)
docker pull $IMAGE
# Clean up stray containers that might still be running from
# another build that didn't get shut down cleanly.
docker rm -f oracle || true
# Oracle does not like if you connect to it via localhost if it's running in the container.
# Interestingly it works if you use the external IP of the host so the issue is
# not the host it is listening on (it claims for that to be 0.0.0.0).
# --network host is a cheap escape hatch for this.
docker run -d --rm --name oracle --network host -e ORACLE_PWD=$ORACLE_PWD $IMAGE
function cleanup() {
docker rm -f oracle
}
trap cleanup EXIT
testConnection() {
docker exec oracle bash -c 'sqlplus -L '"$ORACLE_USERNAME"'/'"$ORACLE_PWD"'@//localhost:'"$ORACLE_PORT"'/ORCLPDB1 <<< "select * from dba_users;"; exit $?' >/dev/null
}
until testConnection
do
echo "Could not connect to Oracle, trying again..."
sleep 1
done
# Actually run some tests
# Note: Oracle tests all run sequentially because they all access the same Oracle instance,
# and we sometimes observe transient connection issues when running tests in parallel.
bazel test \
--config=oracle \
--test_strategy=exclusive \
--test_tag_filters=+oracle \
//...
oracle_logs=$(Build.StagingDirectory)/oracle-logs
mkdir $oracle_logs
for path in $(docker exec oracle bash -c 'find /opt/oracle/diag/rdbms/ -type f'); do
# $path starts with a slash
mkdir -p $(dirname ${oracle_logs}${path})
docker exec oracle bash -c "cat $path" > ${oracle_logs}${path}
done
env:
DOCKER_LOGIN: $(DOCKER_LOGIN)
DOCKER_PASSWORD: $(DOCKER_PASSWORD)
ARTIFACTORY_USERNAME: $(ARTIFACTORY_USERNAME)
ARTIFACTORY_PASSWORD: $(ARTIFACTORY_PASSWORD)
displayName: 'Build'
condition: and(succeeded(), eq(variables['System.PullRequest.IsFork'], 'False'))
- task: PublishBuildArtifacts@1
condition: failed()
displayName: 'Publish the bazel test logs'
inputs:
pathtoPublish: 'bazel-testlogs/'
artifactName: 'Test logs Oracle'
- task: PublishBuildArtifacts@1
condition: failed()
displayName: 'Publish Oracle image logs'
inputs:
pathtoPublish: '$(Build.StagingDirectory)/oracle-logs'
artifactName: 'Oracle image logs'
- template: tell-slack-failed.yml
parameters:
trigger_sha: '$(trigger_sha)'
- template: report-end.yml
- job: platform_independence_test
condition: and(succeeded(),
eq(dependencies.check_for_release.outputs['out.is_release'], 'false'))
@ -336,7 +241,6 @@ jobs:
condition: failed()
dependsOn:
- Linux
- Linux_oracle
- macOS
- Windows
- release

View File

@ -32,22 +32,6 @@ if [[ "$NAME" == "linux" ]]; then
PROTOS_ZIP=protobufs-$RELEASE_TAG.zip
cp bazel-bin/release/protobufs.zip $OUTPUT_DIR/github/$PROTOS_ZIP
TRIGGER_SERVICE=trigger-service-$RELEASE_TAG.jar
TRIGGER_SERVICE_EE=trigger-service-$RELEASE_TAG-ee.jar
bazel build //triggers/service:trigger-service-binary-ce_distribute.jar
cp bazel-bin/triggers/service/trigger-service-binary-ce_distribute.jar $OUTPUT_DIR/github/$TRIGGER_SERVICE
bazel build //triggers/service:trigger-service-binary-ee_distribute.jar
cp bazel-bin/triggers/service/trigger-service-binary-ee_distribute.jar $OUTPUT_DIR/artifactory/$TRIGGER_SERVICE_EE
OAUTH2_MIDDLEWARE=oauth2-middleware-$RELEASE_TAG.jar
bazel build //triggers/service/auth:oauth2-middleware-binary_distribute.jar
cp bazel-bin/triggers/service/auth/oauth2-middleware-binary_distribute.jar $OUTPUT_DIR/github/$OAUTH2_MIDDLEWARE
TRIGGER=daml-trigger-runner-$RELEASE_TAG.jar
bazel build //triggers/runner:trigger-runner_distribute.jar
cp bazel-bin/triggers/runner/trigger-runner_distribute.jar $OUTPUT_DIR/artifactory/$TRIGGER
SCRIPT=daml-script-$RELEASE_TAG.jar
bazel build //daml-script/runner:daml-script-binary_distribute.jar
cp bazel-bin/daml-script/runner/daml-script-binary_distribute.jar $OUTPUT_DIR/artifactory/$SCRIPT
@ -58,10 +42,6 @@ if [[ "$NAME" == "linux" ]]; then
bazel build //daml-script/daml3:daml3-script-dars
cp bazel-bin/daml-script/daml3/*.dar $OUTPUT_DIR/split-release/daml-libs/daml-script/
mkdir -p $OUTPUT_DIR/split-release/daml-libs/daml-trigger
bazel build //triggers/daml:daml-trigger-dars
cp bazel-bin/triggers/daml/*.dar $OUTPUT_DIR/split-release/daml-libs/daml-trigger/
mkdir -p $OUTPUT_DIR/split-release/docs
bazel build //docs:sphinx-source-tree //docs:pdf-fonts-tar //docs:non-sphinx-html-docs //docs:sphinx-source-tree-deps

View File

@ -25,17 +25,11 @@ push() {
https://digitalasset.jfrog.io/artifactory/${repository}/$RELEASE_TAG/${file}
}
TRIGGER_RUNNER=daml-trigger-runner-$RELEASE_TAG.jar
TRIGGER_SERVICE=trigger-service-$RELEASE_TAG-ee.jar
SCRIPT_RUNNER=daml-script-$RELEASE_TAG.jar
cd $INPUTS
push daml-trigger-runner $TRIGGER_RUNNER
push daml-trigger-runner $TRIGGER_RUNNER.asc
push daml-script-runner $SCRIPT_RUNNER
push daml-script-runner $SCRIPT_RUNNER.asc
push trigger-service $TRIGGER_SERVICE
push trigger-service $TRIGGER_SERVICE.asc
# For the split release process these are not published to artifactory.
if [[ "$#" -lt 3 || $3 != "split" ]]; then

View File

@ -8,11 +8,6 @@ load(
"daml_script_example_dar",
"daml_script_example_test",
)
load(
"//bazel_tools/daml_trigger:daml_trigger.bzl",
"daml_trigger_dar",
"daml_trigger_test",
)
load(
"//bazel_tools/data_dependencies:data_dependencies.bzl",
"data_dependencies_coins",
@ -103,29 +98,6 @@ head = "0.0.0"
if versions.is_at_least(sdk_version, platform_version)
] if not is_windows else None
# Change to `CommandId` generation
first_post_7587_trigger_version = "1.7.0-snapshot.20201012.5405.0.af92198d"
[
daml_trigger_dar(sdk_version)
for sdk_version in sdk_versions
if versions.is_at_least(first_post_7587_trigger_version, sdk_version)
]
[
daml_trigger_test(
compiler_version = sdk_version,
runner_version = platform_version,
)
for sdk_version in sdk_versions
for platform_version in platform_versions
# Test that the Daml trigger runner can run DARs built with an older SDK
# version. I.e. where the runner version is at least the SDK version or
# more recent.
if versions.is_at_least(first_post_7587_trigger_version, sdk_version) and
versions.is_at_least(sdk_version, platform_version)
]
[
data_dependencies_coins(
sdk_version = sdk_version,

View File

@ -54,13 +54,6 @@ We test that the Daml Script runner from a given SDK version can load
DARs built against the Daml Script library from an older SDK. We only
guarantee backwards compatibility here.
#### Backwards-compatibility for Daml Triggers
We test that the Daml Trigger runner from a given SDK version can load
DARs built against the Daml Script library from an older SDK. We only
guarantee backwards compatibility here.
#### Backwards-compatibility for data-dependencies
We test that we can import DARs built in older SDK versions via

View File

@ -1,4 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
exports_files(glob(["example/**"]))

View File

@ -1,239 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load(
"@daml//bazel_tools/client_server:client_server_test.bzl",
"client_server_test",
)
load("//bazel_tools:versions.bzl", "version_to_name", "versions")
load("//bazel_tools:testing.bzl", "extra_tags")
def copy_trigger_src(sdk_version):
# To avoid having to mess with Bazel's escaping, avoid `$` and backticks.
# We can't use CPP to make this nicer, unfortunately, since this doesn't work
# with an installed SDK.
return """
module CopyTrigger where
import DA.List hiding (dedup)
import Templates
import Daml.Trigger
copyTrigger : Trigger ()
copyTrigger = Trigger
{{ initialize = pure ()
, updateState = \\_message {acsArg} -> pure ()
, rule = copyRule
, registeredTemplates = AllInDar
, heartbeat = None
}}
copyRule party {acsArg} {stateArg} = do
subscribers : [(ContractId Subscriber, Subscriber)] <- {query} {acsArg}
originals : [(ContractId Original, Original)] <- {query} {acsArg}
copies : [(ContractId Copy, Copy)] <- {query} {acsArg}
let ownedSubscribers = filter (\\(_, s) -> s.subscribedTo == party) subscribers
let ownedOriginals = filter (\\(_, o) -> o.owner == party) originals
let ownedCopies = filter (\\(_, c) -> c.original.owner == party) copies
let subscribingParties = map (\\(_, s) -> s.subscriber) ownedSubscribers
let groupedCopies : [[(ContractId Copy, Copy)]]
groupedCopies = groupOn snd (sortOn snd ownedCopies)
let copiesToKeep = map head groupedCopies
let archiveDuplicateCopies = concatMap tail groupedCopies
let archiveMissingOriginal = filter (\\(_, c) -> notElem c.original (map snd ownedOriginals)) copiesToKeep
let archiveMissingSubscriber = filter (\\(_, c) -> notElem c.subscriber subscribingParties) copiesToKeep
let archiveCopies = dedup (map fst (archiveDuplicateCopies <> archiveMissingOriginal <> archiveMissingSubscriber))
forA archiveCopies (\\cid -> dedupExercise cid Archive)
let neededCopies = [Copy m o | (_, m) <- ownedOriginals, o <- subscribingParties]
let createCopies = filter (\\c -> notElem c (map snd copiesToKeep)) neededCopies
mapA dedupCreate createCopies
pure ()
dedup : Eq k => [k] -> [k]
dedup [] = []
dedup (x :: xs) = x :: dedup (filter (/= x) xs)
""".format(
stateArg = "_" if versions.is_at_most(last_pre_7674_version, sdk_version) else "",
acsArg = "acs" if versions.is_at_most(last_pre_7632_version, sdk_version) else "",
query = "(pure . getContracts)" if versions.is_at_most(last_pre_7632_version, sdk_version) else "query",
)
# Removal of state argument.
last_pre_7674_version = "1.7.0-snapshot.20201013.5418.0.bda13392"
# Removal of ACS argument
last_pre_7632_version = "1.7.0-snapshot.20201012.5405.0.af92198d"
def daml_trigger_dar(sdk_version):
daml = "@daml-sdk-{sdk_version}//:daml".format(
sdk_version = sdk_version,
)
native.genrule(
name = "trigger-example-dar-{sdk_version}".format(
sdk_version = version_to_name(sdk_version),
),
srcs = [
"//bazel_tools/daml_trigger:example/src/TestScript.daml",
"//bazel_tools/daml_trigger:example/src/Templates.daml",
],
outs = ["trigger-example-{sdk_version}.dar".format(
sdk_version = version_to_name(sdk_version),
)],
tools = [daml],
cmd = """\
set -euo pipefail
TMP_DIR=$$(mktemp -d)
cleanup() {{ rm -rf $$TMP_DIR; }}
trap cleanup EXIT
mkdir -p $$TMP_DIR/src
echo "{copy_trigger}" > $$TMP_DIR/src/CopyTrigger.daml
cp -L $(location //bazel_tools/daml_trigger:example/src/TestScript.daml) $$TMP_DIR/src/
cp -L $(location //bazel_tools/daml_trigger:example/src/Templates.daml) $$TMP_DIR/src/
cat <<EOF >$$TMP_DIR/daml.yaml
sdk-version: {sdk_version}
name: trigger-example
source: src
version: 0.0.1
dependencies:
- daml-prim
- daml-script
- daml-stdlib
- daml-trigger
EOF
$(location {daml}) build --project-root=$$TMP_DIR -o $$PWD/$(OUTS)
""".format(
daml = daml,
sdk_version = sdk_version,
copy_trigger = copy_trigger_src(sdk_version),
),
)
def daml_trigger_test(compiler_version, runner_version):
compiled_dar = "//:trigger-example-dar-{version}".format(
version = version_to_name(compiler_version),
)
daml_runner = "@daml-sdk-{version}//:daml".format(
version = runner_version,
)
name = "daml-trigger-test-compiler-{compiler_version}-runner-{runner_version}".format(
compiler_version = version_to_name(compiler_version),
runner_version = version_to_name(runner_version),
)
# 1.16.0 is the first SDK version that uses LF 1.14, which is the earliest version that canton supports
use_canton = versions.is_at_least("2.0.0", runner_version) and versions.is_at_least("1.16.0", compiler_version)
server = daml_runner
server_args = ["sandbox"] + (["--canton-port-file", "_port_file"] if (use_canton) else [])
server_files = ["$(rootpath {})".format(compiled_dar)]
server_files_prefix = "--dar=" if use_canton else ""
native.genrule(
name = "{}-client-sh".format(name),
outs = ["{}-client.sh".format(name)],
cmd = """\
cat >$(OUTS) <<'EOF'
#!/usr/bin/env bash
set -euo pipefail
canonicalize_rlocation() {{
# Note (MK): This is a fun one: Let's say $$TEST_WORKSPACE is "compatibility"
# and the argument points to a target from an external workspace, e.g.,
# @daml-sdk-0.0.0//:daml. Then the short path will point to
# ../daml-sdk-0.0.0/daml. Putting things together we end up with
# compatibility/../daml-sdk-0.0.0/daml. On Linux and MacOS this works
# just fine. However, on windows we need to normalize the path
# or rlocation will fail to find the path in the manifest file.
rlocation $$(realpath -L -s -m --relative-to=$$PWD $$TEST_WORKSPACE/$$1)
}}
runner=$$(canonicalize_rlocation $(rootpath {runner}))
# Cleanup the trigger runner process but maintain the script runner exit code.
trap 'status=$$?; kill -TERM $$PID; wait $$PID; exit $$status' INT TERM
SCRIPTOUTPUT=$$(mktemp -d)
if [ {wait_for_port_file} -eq 1 ]; then
timeout=60
while [ ! -e _port_file ]; do
if [ "$$timeout" = 0 ]; then
echo "Timed out waiting for Canton startup" >&2
exit 1
fi
sleep 1
timeout=$$((timeout - 1))
done
fi
if [ {upload_dar} -eq 1 ] ; then
$$runner ledger upload-dar \\
--host localhost \\
--port 6865 \\
$$(canonicalize_rlocation $(rootpath {dar}))
fi
$$runner script \\
--ledger-host localhost \\
--ledger-port 6865 \\
--wall-clock-time \\
--dar $$(canonicalize_rlocation $(rootpath {dar})) \\
--script-name TestScript:allocateAlice \\
--output-file $$SCRIPTOUTPUT/alice.json
ALICE=$$(cat $$SCRIPTOUTPUT/alice.json | sed 's/"//g')
rm -rf $$SCRIPTOUTPUT
$$runner trigger \\
--ledger-host localhost \\
--ledger-port 6865 \\
--ledger-party $$ALICE \\
--wall-clock-time \\
--dar $$(canonicalize_rlocation $(rootpath {dar})) \\
--trigger-name CopyTrigger:copyTrigger &
PID=$$!
$$runner script \\
--ledger-host localhost \\
--ledger-port 6865 \\
--wall-clock-time \\
--dar $$(canonicalize_rlocation $(rootpath {dar})) \\
--script-name TestScript:test
EOF
chmod +x $(OUTS)
""".format(
dar = compiled_dar,
runner = daml_runner,
upload_dar = "0",
wait_for_port_file = "1" if use_canton else "0",
),
exec_tools = [
compiled_dar,
daml_runner,
],
)
native.sh_binary(
name = "{}-client".format(name),
srcs = ["{}-client.sh".format(name)],
data = [
compiled_dar,
daml_runner,
],
)
client_server_test(
name = "daml-trigger-test-compiler-{compiler_version}-runner-{runner_version}".format(
compiler_version = version_to_name(compiler_version),
runner_version = version_to_name(runner_version),
),
client = "{}-client".format(name),
client_args = [],
client_files = [],
data = [
compiled_dar,
],
runner = "//bazel_tools/client_server:runner",
runner_args = ["6865"],
server = server,
server_args = server_args,
server_files = server_files,
server_files_prefix = server_files_prefix,
tags = extra_tags(compiler_version, runner_version) + ["exclusive"],
)

View File

@ -1,44 +0,0 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
module Templates where
-- ORIGINAL_TEMPLATE_BEGIN
template Original
with
owner : Party
name : Text
textdata : Text
where
signatory owner
key (owner, name) : (Party, Text)
maintainer key._1
-- ORIGINAL_TEMPLATE_END
deriving instance Ord Original
-- SUBSCRIBER_TEMPLATE_BEGIN
template Subscriber
with
subscriber : Party
subscribedTo : Party
where
signatory subscriber
observer subscribedTo
key (subscriber, subscribedTo) : (Party, Party)
maintainer key._1
-- SUBSCRIBER_TEMPLATE_END
-- COPY_TEMPLATE_BEGIN
template Copy
with
original : Original
subscriber : Party
where
signatory (signatory original)
observer subscriber
-- COPY_TEMPLATE_END
deriving instance Ord Copy

View File

@ -1,66 +0,0 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
module TestScript where
import Templates
import DA.Assert
import DA.Time
import Daml.Script
allocateAlice : Script Party
allocateAlice = do
debug "Creating Alice ..."
alice <- allocatePartyWithHint "Alice" (PartyIdHint "Alice")
debug alice
debug "... done"
pure alice
test : Script ()
test = do
debug "Searching for Alice ..."
let isAlice x = displayName x == Some "Alice"
Some aliceDetails <- find isAlice <$> listKnownParties
let alice = party aliceDetails
debug alice
debug "... done"
debug "Creating Bob ..."
bob <- allocatePartyWithHint "Bob" (PartyIdHint "Bob")
debug alice
debug bob
debug "... done"
debug "Creating Subscriber ..."
submit bob $ do
createCmd (Subscriber bob alice)
debug "... done"
debug "Creating Original ..."
let original = Original alice "original" "data"
submit alice $ do
createCmd original
debug "... done"
debug "Waiting for copy ..."
copy <- until $ do
copies <- query @Copy bob
case copies of
[(_, copy)] -> pure (Some copy)
xs -> do
debug xs
pure None
debug "... done"
debug "Asserting equality ..."
assertEq (Copy original bob) copy
debug "... done"
until : Script (Optional a) -> Script a
until action = do
result <- action
case result of
Some a -> pure a
None -> do
sleep (convertMicrosecondsToRelTime 10000)
until action

File diff suppressed because it is too large Load Diff

View File

@ -492,11 +492,10 @@ fileTest externalAnchors scriptPackageData damlFile = do
other -> error $ "Unsupported file extension " <> other
where
diff ref new = [POSIX_DIFF, "--strip-trailing-cr", ref, new]
-- In cases where daml-script/daml-trigger is used, the version of the package is embedded in the json.
-- In cases where daml-script is used, the version of the package is embedded in the json.
-- When we release, this version changes, which would break the golden file test.
-- Instead, we omit daml-script/daml-trigger versions from .EXPECTED.json files in golden tests.
-- Instead, we omit daml-script versions from .EXPECTED.json files in golden tests.
replaceSdkPackages =
TL.encodeUtf8
. TL.replace (TL.pack $ "daml-script-" <> sdkPackageVersion) "daml-script-UNVERSIONED"
. TL.replace (TL.pack $ "daml-trigger-" <> sdkPackageVersion) "daml-trigger-UNVERSIONED"
. TL.decodeUtf8

View File

@ -1,7 +1,7 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
-- | For compiler level warnings on packages that aren't included as standard. This includes, at the very least, Daml.Script and Triggers
-- | For compiler level warnings on packages that aren't included as standard. This includes, at the very least, Daml.Script
module DA.Daml.LFConversion.ExternalWarnings (topLevelWarnings) where
import DA.Daml.LFConversion.ConvertM

View File

@ -289,9 +289,9 @@ dataDependableExtensions :: ES.EnumSet Extension
dataDependableExtensions = ES.fromList $ xExtensionsSet ++
[ -- useful for beginners to learn about type inference
PartialTypeSignatures
-- needed for script and triggers
-- needed for script
, ApplicativeDo
-- used in daml-stdlib and triggers and a very reasonable
-- used in daml-stdlib and very reasonable
-- extension in general in the presence of TypeApplications
, AllowAmbiguousTypes
-- helpful for documentation purposes
@ -306,7 +306,7 @@ dataDependableExtensions = ES.fromList $ xExtensionsSet ++
-- would be silly
, ExplicitNamespaces
-- there's no way for our users to actually use this and listing it here
-- removes a lot of warning from out stdlib, script and trigger builds
-- removes a lot of warnings from our stdlib and script builds
-- NOTE: This should not appear on any list of extensions that are
-- compatible with data-dependencies since this would spur wrong hopes.
, Cpp
@ -544,14 +544,14 @@ checkDFlags Options {..} dflags@DynFlags {..}
\ or the dependency."
-- Expand SDK package dependencies using the SDK root path.
-- E.g. `daml-trigger` --> `$DAML_SDK/daml-libs/daml-trigger.dar`
-- E.g. `daml-script` --> `$DAML_SDK/daml-libs/daml-script.dar`
-- When invoked outside of the SDK, we will only error out
-- if there is actually an SDK package so that
-- When there is no SDK
expandSdkPackages :: Logger.Handle IO -> LF.Version -> [FilePath] -> IO [FilePath]
expandSdkPackages logger lfVersion dars = do
mbSdkPath <- handleIO (\_ -> pure Nothing) $ Just <$> getSdkPath
mapM (expand mbSdkPath) (nubOrd $ concatMap addDep dars)
mapM (expand mbSdkPath) (nubOrd dars)
where
isSdkPackage fp = takeExtension fp `notElem` [".dar", ".dalf"]
isInvalidDaml3Script = \case
@ -569,14 +569,6 @@ expandSdkPackages logger lfVersion dars = do
pure $ sdkPath </> "daml-libs" </> fp <> sdkSuffix <.> "dar"
Nothing -> fail $ "Cannot resolve SDK dependency '" ++ fp ++ "'. Use daml assistant."
| otherwise = pure fp
-- For `dependencies` you need to specify all transitive dependencies.
-- However, for the packages in the SDK that is an implementation detail
-- so we automagically insert `daml-script` if youve specified `daml-trigger`.
addDep fp
| isSdkPackage fp = fp : Map.findWithDefault [] fp sdkDependencies
| otherwise = [fp]
sdkDependencies = Map.fromList
[ ("daml-trigger", ["daml-script"]) ]
mkPackageFlag :: UnitId -> PackageFlag

View File

@ -28,8 +28,8 @@ error msg = primitive @"EThrow" (GeneralError msg)
-- | `error` stops execution and displays the given error message.
--
-- If called within a transaction, it will abort the current transaction.
-- Outside of a transaction (scenarios, Daml Script or Daml Triggers)
-- it will stop the whole scenario/script/trigger.
-- Outside of a transaction (scenarios and Daml Script)
-- it will stop the whole scenario/script.
error : Text -> a
error = primitive @"BEError"

View File

@ -23,7 +23,7 @@ import DA.Internal.Exception
-- in this context.
class Action m => CanAssert m where
-- | Abort since an assertion has failed. In an Update, Scenario,
-- Script, or Trigger context this will throw an AssertionFailed
-- or Script context this will throw an AssertionFailed
-- exception. In an `Either Text` context, this will return the
-- message as an error.
assertFail : Text -> m t

View File

@ -161,7 +161,7 @@ withVersionedDamlScriptDep packageFlagName darPath mLfVer extraPackages cont = d
withCurrentDirectory dir $ do
let projDir = toNormalizedFilePath' dir
-- Bring in daml-script as previously installed by withDamlScriptDep, must include package db
-- daml-script and daml-triggers use the sdkPackageVersion for their versioning
-- daml-script uses the sdkPackageVersion for its versioning
mkPackageFlag flagName = ExposePackage ("--package " <> flagName) (UnitIdArg $ stringToUnitId flagName) (ModRenaming True [])
toPackageName (name, version) = name <> "-" <> version
packageFlags = mkPackageFlag <$> packageFlagName : (toPackageName <$> extraPackages)

View File

@ -9,10 +9,6 @@ load(
load("@os_info//:os_info.bzl", "is_windows")
load(":util.bzl", "deps")
scala_deps = [
"@maven//:com_typesafe_scala_logging_scala_logging",
]
scala_runtime_deps = [
"@maven//:org_apache_pekko_pekko_slf4j",
"@maven//:org_tpolecat_doobie_postgres",
@ -30,7 +26,7 @@ da_scala_binary(
srcs = glob(["src/main/scala/**/*.scala"]),
main_class = "com.daml.sdk.SdkMain",
resources = glob(["src/main/resources/**/*"]),
scala_deps = scala_deps,
scala_deps = [],
scala_runtime_deps = scala_runtime_deps,
visibility = ["//visibility:public"],
runtime_deps = runtime_deps,
@ -42,7 +38,7 @@ da_scala_binary(
srcs = glob(["src/main/scala/**/*.scala"]),
main_class = "com.daml.sdk.SdkMain",
resources = glob(["src/main/resources/**/*"]),
scala_deps = scala_deps,
scala_deps = [],
scala_runtime_deps = scala_runtime_deps,
tags = ["ee-jar-license"],
visibility = ["//visibility:public"],

View File

@ -5,9 +5,6 @@ package com.daml.sdk
import com.daml.codegen.{CodegenMain => Codegen}
import com.daml.lf.engine.script.{ScriptMain => Script}
import com.daml.lf.engine.trigger.{RunnerMain => Trigger}
import com.daml.lf.engine.trigger.{ServiceMain => TriggerService}
import com.daml.auth.middleware.oauth2.{Main => Oauth2Middleware}
import com.daml.script.export.{Main => Export}
object SdkMain {
@ -15,12 +12,9 @@ object SdkMain {
val command = args(0)
val rest = args.drop(1)
command match {
case "trigger" => Trigger.main(rest)
case "script" => Script.main(rest)
case "export" => Export.main(rest)
case "codegen" => Codegen.main(rest)
case "trigger-service" => TriggerService.main(rest)
case "oauth2-middleware" => Oauth2Middleware.main(rest)
case _ => sys.exit(1)
}
}

View File

@ -6,7 +6,4 @@ def deps(edition):
"//daml-script/runner:script-runner-lib",
"//language-support/codegen-main:codegen-main-lib",
"//daml-script/export",
"//triggers/runner:trigger-runner-lib",
"//triggers/service:trigger-service-binary-{}".format(edition),
"//triggers/service/auth:oauth2-middleware",
]

View File

@ -18,10 +18,3 @@ set -eou pipefail
JAVA=$(rlocation "$TEST_WORKSPACE/$1")
SDK_CE=$(rlocation "$TEST_WORKSPACE/$2")
SDK_EE=$(rlocation "$TEST_WORKSPACE/$3")
if ! ($JAVA -jar $SDK_EE trigger-service --help | grep -q oracle); then
exit 1
fi
if $JAVA -jar $SDK_CE trigger-service --help | grep -q oracle; then
exit 1
fi

View File

@ -370,7 +370,7 @@ argWhitelist = S.fromList
, "install", "latest", "project"
, "uninstall"
, "studio", "never", "always", "published"
, "new", "skeleton", "empty-skeleton", "quickstart-java", "copy-trigger", "gsg-trigger"
, "new", "skeleton", "empty-skeleton", "quickstart-java"
, "daml-intro-1", "daml-intro-2", "daml-intro-3", "daml-intro-4"
, "daml-intro-5", "daml-intro-6", "daml-intro-7", "script-example"
, "daml-intro-13"
@ -386,8 +386,6 @@ argWhitelist = S.fromList
, "codegen", "java", "js"
, "deploy"
, "json-api"
, "trigger", "trigger-service", "list"
, "oauth2-middleware"
, "script"
]

View File

@ -31,7 +31,7 @@ import Test.Tasty.HUnit
import DA.Bazel.Runfiles
import DA.Daml.Assistant.IntegrationTestUtils
import DA.Daml.Helper.Util (waitForHttpServer, tokenFor, decodeCantonSandboxPort)
import DA.Daml.Helper.Util (tokenFor, decodeCantonSandboxPort)
import DA.Test.Daml2jsUtils
import DA.Test.Process (callCommandSilent, callCommandSilentIn, subprocessEnv)
import DA.Test.Util
@ -182,7 +182,6 @@ tests tmpDir =
, testCase "daml new --list" $
callCommandSilentIn tmpDir "daml new --list"
, packagingTests tmpDir
, damlToolTests
, withResource (damlStart (tmpDir </> "sandbox-canton")) stop damlStartTests
, cleanTests cleanDir
, templateTests
@ -200,53 +199,7 @@ packagingTests :: SdkVersioned => FilePath -> TestTree
packagingTests tmpDir =
testGroup
"packaging"
[ testCase "Build copy trigger" $ do
let projDir = tmpDir </> "copy-trigger1"
callCommandSilent $ unwords ["daml", "new", projDir, "--template=copy-trigger"]
callCommandSilentIn projDir "daml build"
let dar = projDir </> ".daml" </> "dist" </> "copy-trigger1-0.0.1.dar"
assertFileExists dar
, testCase "Build copy trigger with LF version 1.dev" $ do
let projDir = tmpDir </> "copy-trigger2"
callCommandSilent $ unwords ["daml", "new", projDir, "--template=copy-trigger"]
callCommandSilentIn projDir "daml build --target 1.dev"
let dar = projDir </> ".daml" </> "dist" </> "copy-trigger2-0.0.1.dar"
assertFileExists dar
, testCase "Build trigger with extra dependency" $ do
let myDepDir = tmpDir </> "mydep"
createDirectoryIfMissing True (myDepDir </> "daml")
writeFileUTF8 (myDepDir </> "daml.yaml") $
unlines
[ "sdk-version: " <> sdkVersion
, "name: mydep"
, "version: \"1.0\""
, "source: daml"
, "dependencies:"
, " - daml-prim"
, " - daml-stdlib"
]
writeFileUTF8 (myDepDir </> "daml" </> "MyDep.daml") $ unlines ["module MyDep where"]
callCommandSilentIn myDepDir "daml build -o mydep.dar"
let myTriggerDir = tmpDir </> "mytrigger"
createDirectoryIfMissing True (myTriggerDir </> "daml")
writeFileUTF8 (myTriggerDir </> "daml.yaml") $
unlines
[ "sdk-version: " <> sdkVersion
, "name: mytrigger"
, "version: \"1.0\""
, "source: daml"
, "dependencies:"
, " - daml-prim"
, " - daml-stdlib"
, " - daml-trigger"
, " - " <> myDepDir </> "mydep.dar"
]
writeFileUTF8 (myTriggerDir </> "daml/Main.daml") $
unlines ["module Main where", "import MyDep ()", "import Daml.Trigger ()"]
callCommandSilentIn myTriggerDir "daml build -o mytrigger.dar"
let dar = myTriggerDir </> "mytrigger.dar"
assertFileExists dar
, testCase "Build Daml script example" $ do
[ testCase "Build Daml script example" $ do
let projDir = tmpDir </> "script-example"
callCommandSilent $ unwords ["daml", "new", projDir, "--template=script-example"]
callCommandSilentIn projDir "daml build"
@ -259,7 +212,7 @@ packagingTests tmpDir =
callCommandSilentIn projDir "daml build --target 1.dev"
let dar = projDir </> ".daml/dist/script-example-0.0.1.dar"
assertFileExists dar -}
, testCase "Package depending on daml-script and daml-trigger can use data-dependencies" $ do
, testCase "Package depending on daml-script can use data-dependencies" $ do
callCommandSilent $ unwords ["daml", "new", tmpDir </> "data-dependency"]
callCommandSilentIn (tmpDir </> "data-dependency") "daml build -o data-dependency.dar"
createDirectoryIfMissing True (tmpDir </> "proj")
@ -269,7 +222,7 @@ packagingTests tmpDir =
, "name: proj"
, "version: 0.0.1"
, "source: ."
, "dependencies: [daml-prim, daml-stdlib, daml-script, daml-trigger]"
, "dependencies: [daml-prim, daml-stdlib, daml-script]"
, "data-dependencies: [" <>
show (tmpDir </> "data-dependency" </> "data-dependency.dar") <>
"]"
@ -285,39 +238,6 @@ packagingTests tmpDir =
callCommandSilentIn (tmpDir </> "proj") "daml build"
]
-- Test tools that can run outside a daml project
damlToolTests :: TestTree
damlToolTests =
testGroup
"daml tools"
[ testCase "OAuth 2.0 middleware startup" $ do
withTempDir $ \tmpDir -> do
middlewarePort <- getFreePort
withDamlServiceIn tmpDir "oauth2-middleware"
[ "--address"
, "localhost"
, "--http-port"
, show middlewarePort
, "--oauth-auth"
, "http://localhost:0/authorize"
, "--oauth-token"
, "http://localhost:0/token"
, "--auth-jwt-hs256-unsafe"
, "jwt-secret"
, "--id"
, "client-id"
, "--secret"
, "client-secret"
] $ \ ph -> do
let endpoint =
"http://localhost:" <> show middlewarePort <> "/livez"
waitForHttpServer 240 ph (threadDelay 500000) endpoint []
req <- parseRequest endpoint
manager <- newManager defaultManagerSettings
resp <- httpLbs req manager
responseBody resp @?= "{\"status\":\"pass\"}"
]
-- We are trying to run as many tests with the same `daml start` process as possible to save time.
damlStartTests :: SdkVersioned => IO DamlStartResource -> TestTree
damlStartTests getDamlStart =
@ -394,24 +314,6 @@ damlStartTests getDamlStart =
didGenerateDamlYaml <- doesFileExist (exportDir </> "daml.yaml")
didGenerateExportDaml @?= True
didGenerateDamlYaml @?= True
subtest "trigger service startup" $ do
DamlStartResource {projDir, sandboxPort} <- getDamlStart
triggerServicePort <- getFreePort
withDamlServiceIn projDir "trigger-service"
[ "--ledger-host"
, "localhost"
, "--ledger-port"
, show sandboxPort
, "--http-port"
, show triggerServicePort
, "--wall-clock-time"
] $ \ ph -> do
let endpoint = "http://localhost:" <> show triggerServicePort <> "/livez"
waitForHttpServer 240 ph (threadDelay 500000) endpoint []
req <- parseRequest endpoint
manager <- newManager defaultManagerSettings
resp <- httpLbs req manager
responseBody resp @?= "{\"status\":\"pass\"}"
subtest "hot reload" $ do
DamlStartResource {projDir, jsonApiPort, startStdin, stdoutChan, alice, aliceHeaders} <- getDamlStart
@ -529,10 +431,8 @@ templateTests = testGroup "templates" $
-- NOTE (MK) We might want to autogenerate this list at some point but for now
-- this should be good enough.
where templateNames =
[ "copy-trigger"
, "gsg-trigger"
-- daml-intro-1 - daml-intro-6 are not full projects.
, "daml-intro-7"
[ -- daml-intro-1 - daml-intro-6 are not full projects.
"daml-intro-7"
, "daml-patterns"
, "quickstart-java"
, "script-example"

View File

@ -343,7 +343,7 @@ private[lf] final class ValueTranslator(
// This does not try to pull missing packages, return an error instead.
// TODO: https://github.com/digital-asset/daml/issues/17082
// This is used by script and trigger, this should probably use ValueTranslator.Config.Strict
// This is used by script, this should probably use ValueTranslator.Config.Strict
def strictTranslateValue(
ty: Type,
value: Value,

View File

@ -35,7 +35,6 @@ da_scala_library(
"//daml-lf:__subpackages__",
"//daml-script:__subpackages__",
"//ledger:__subpackages__",
"//triggers:__subpackages__",
],
deps = [
"//daml-lf/data",

View File

@ -10,7 +10,7 @@ private[lf] sealed abstract class InitialSeeding extends Product with Serializab
private[lf] object InitialSeeding {
// NoSeed may be used to initialize machines that are not intended to create transactions
// e.g. trigger and script runners, tests
// e.g. script runners, tests
final case object NoSeed extends InitialSeeding
final case class TransactionSeed(seed: crypto.Hash) extends InitialSeeding
final case class RootNodeSeeds(seeds: ImmArray[Option[crypto.Hash]]) extends InitialSeeding

View File

@ -21,7 +21,6 @@ da_scala_library(
visibility = [
"//daml-lf:__subpackages__",
"//ledger:__subpackages__",
"//triggers:__subpackages__",
],
deps = [
"//daml-lf/data",

View File

@ -28,14 +28,12 @@ genrule(
srcs = [
"//compiler/damlc:daml-base-hoogle.txt",
"//daml-script/daml:daml-script-hoogle.txt",
"//triggers/daml:daml-trigger-hoogle.txt",
],
outs = ["hoogle_db.tar.gz"],
cmd = """
mkdir hoogle
cp -L $(location //compiler/damlc:daml-base-hoogle.txt) hoogle/
cp -L $(location //daml-script/daml:daml-script-hoogle.txt) hoogle/
cp -L $(location //triggers/daml:daml-trigger-hoogle.txt) hoogle/
$(execpath //bazel_tools/sh:mktgz) $@ hoogle
""",
tools = ["//bazel_tools/sh:mktgz"],
@ -76,7 +74,6 @@ genrule(
name = "sources",
srcs = glob(["source/**"]) + [
"//compiler/damlc:daml-base-rst.tar.gz",
"//triggers/daml:daml-trigger-rst.tar.gz",
"//daml-script/daml:daml-script-rst.tar.gz",
"//canton:ledger-api-docs",
"//:LICENSE",
@ -102,12 +99,6 @@ genrule(
--strip-components 1 \\
-C source/daml/stdlib
# Copy in daml-trigger documentation
mkdir -p source/triggers/api/
tar xf $(location //triggers/daml:daml-trigger-rst.tar.gz) \\
--strip-components 1 \\
-C source/triggers/api/
# Copy in daml-script documentation
mkdir -p source/daml-script/api/
tar xf $(location //daml-script/daml:daml-script-rst.tar.gz) \\
@ -140,7 +131,6 @@ genrule(
":generate-docs-error-codes-inventory-into-rst-file",
":generate-docs-error-categories-inventory-into-rst-file",
"//compiler/damlc:daml-base-rst.tar.gz",
"//triggers/daml:daml-trigger-rst.tar.gz",
"//daml-script/daml:daml-script-rst.tar.gz",
"//canton:ledger-api-docs",
"//:LICENSE",
@ -166,7 +156,6 @@ genrule(
cp -L -- $(location //docs:generate-docs-error-codes-inventory-into-rst-file) $$DIR/deps/
cp -L -- $(location //docs:generate-docs-error-categories-inventory-into-rst-file) $$DIR/deps/
cp $(location //compiler/damlc:daml-base-rst.tar.gz) $$DIR/deps/
cp $(location //triggers/daml:daml-trigger-rst.tar.gz) $$DIR/deps/
cp $(location //daml-script/daml:daml-script-rst.tar.gz) $$DIR/deps/
cp -L $(location //canton:ledger-api-docs) $$DIR/deps/
cp -L $(location //:LICENSE) $$DIR/deps/
@ -622,18 +611,6 @@ daml_build_test(
project_dir = "source/upgrade/example/carbon-initiate-upgrade",
)
daml_build_test(
name = "daml-upgrade-example-upgrade-trigger",
dar_dict = {
":daml-upgrade-example-v1": "path/to/carbon-1.0.0.dar",
":daml-upgrade-example-v2": "path/to/carbon-2.0.0.dar",
":daml-upgrade-example-upgrade": "path/to/carbon-upgrade-1.0.0.dar",
"//triggers/daml:daml-trigger.dar": "daml-trigger.dar",
"//daml-script/daml:daml-script.dar": "daml-script.dar",
},
project_dir = "source/upgrade/example/carbon-upgrade-trigger",
)
filegroup(
name = "daml-intro-1",
srcs = glob(

View File

@ -791,8 +791,6 @@ daml/intro/9_Functional101.html -> /daml/intro/10_Functional101.html
daml/intro/10_StdLib.html -> /daml/intro/11_StdLib.html
daml/intro/11_Testing.html -> /daml/intro/12_Testing.html
daml-script/daml-script-docs.html -> /daml-script/api/index.html
triggers/trigger-docs.html -> /triggers/api/index.html
tools/trigger-service.html -> trigger-service/index.html
support.html -> /support/support.html
release-notes.html -> /support/releases.html#release-notes
support/release-notes.html -> /support/releases.html#release-notes

View File

@ -1,18 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
sdk-version: 0.0.0
# BEGIN
name: carbon-upgrade-trigger
version: 1.0.0
dependencies:
- daml-prim
- daml-stdlib
- daml-trigger
- daml-script
data-dependencies:
- path/to/carbon-upgrade-1.0.0.dar
- path/to/carbon-1.0.0.dar
- path/to/carbon-2.0.0.dar
# END
source: .

View File

@ -1,65 +0,0 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
{-# LANGUAGE CPP #-}
module UpgradeTrigger where
import DA.Assert
import DA.Foldable
import qualified DA.Map as Map
import Daml.Trigger
import Daml.Trigger.Assert
import qualified Daml.Script as Script
import Daml.Script (script)
import CarbonV1
import UpgradeFromCarbonCertV1
-- TRIGGER_BOILERPLATE_BEGIN
upgradeTrigger : Trigger ()
upgradeTrigger = Trigger with
initialize = pure ()
updateState = \_msg -> pure ()
registeredTemplates = AllInDar
heartbeat = None
rule = triggerRule
-- TRIGGER_BOILERPLATE_END
-- TRIGGER_RULE_BEGIN
triggerRule : Party -> TriggerA () ()
triggerRule issuer = do
agreements <-
filter (\(_cid, agreement) -> agreement.issuer == issuer) <$>
query @UpgradeCarbonCertAgreement
allCerts <-
filter (\(_cid, cert) -> cert.issuer == issuer) <$>
query @CarbonCert
forA_ agreements $ \(agreementCid, agreement) -> do
let certsForOwner = filter (\(_cid, cert) -> cert.owner == agreement.owner) allCerts
forA_ certsForOwner $ \(certCid, _) ->
emitCommands
[exerciseCmd agreementCid (Upgrade certCid)]
[toAnyContractId certCid]
-- TRIGGER_RULE_END
-- TODO (MK) The Bazel rule atm doesn't run this script, we should fix that.
test = script do
alice <- Script.allocateParty "Alice"
bob <- Script.allocateParty "Bob"
certProposal <- submit alice $ Script.createCmd (CarbonCertProposal alice bob 10)
cert <- submit bob $ Script.exerciseCmd certProposal CarbonCertProposal_Accept
upgradeProposal <- submit alice $ Script.createCmd (UpgradeCarbonCertProposal alice bob)
upgradeAgreement <- submit bob $ Script.exerciseCmd upgradeProposal Accept
let acs = toACS cert <> toACS upgradeAgreement
(_, commands) <- testRule upgradeTrigger alice [] acs Map.empty ()
let flatCommands = flattenCommands commands
assertExerciseCmd flatCommands $ \(cid, choiceArg) -> do
cid === upgradeAgreement
choiceArg === Upgrade cert
-- TODO (MK) It would be nice to test for the absence of certain commands as well
-- or ideally just assert that the list of emitted commands matches an expected
-- list of commands.

View File

@ -209,8 +209,6 @@
type: jar-scala
- target: //libs-scala/timer-utils:timer-utils
type: jar-scala
- target: //triggers/runner:trigger-runner-lib
type: jar-scala
- target: //libs-scala/nameof:nameof
type: jar-scala
- target: //libs-scala/struct-json/struct-spray-json:struct-spray-json

View File

@ -71,18 +71,6 @@ commands:
- name: ide
path: damlc/damlc
args: ["ide"]
- name: trigger-service
path: daml-helper/daml-helper
desc: "Launch the trigger service"
args: ["run-jar", "--logback-config=daml-sdk/trigger-service-logback.xml", "daml-sdk/daml-sdk.jar", "trigger-service"]
- name: oauth2-middleware
path: daml-helper/daml-helper
desc: "Launch the OAuth 2.0 middleware"
args: ["run-jar", "--logback-config=daml-sdk/oauth2-middleware-logback.xml", "daml-sdk/daml-sdk.jar", "oauth2-middleware"]
- name: trigger
path: daml-helper/daml-helper
args: ["run-jar", "--logback-config=daml-sdk/trigger-logback.xml", "daml-sdk/daml-sdk.jar", "trigger"]
desc: "Run a Daml trigger"
- name: script
path: daml-helper/daml-helper
args: ["run-jar", "--logback-config=daml-sdk/script-logback.xml", "daml-sdk/daml-sdk.jar", "script"]

View File

@ -8,9 +8,6 @@ inputs = {
"sdk_config": ":sdk-config.yaml.tmpl",
"install_sh": ":install.sh",
"install_bat": ":install.bat",
"oauth2_middleware_logback": "//triggers/service/auth:release/oauth2-middleware-logback.xml",
"trigger_service_logback": "//triggers/service:release/trigger-service-logback.xml",
"trigger_logback": "//triggers/runner:src/main/resources/logback.xml",
"java_codegen_logback": "//language-support/java/codegen:src/main/resources/logback.xml",
"daml_script_logback": "//daml-script/runner:src/main/resources/logback.xml",
"export_logback": "//daml-script/export:src/main/resources/logback.xml",
@ -22,7 +19,6 @@ inputs = {
"daml_extension_stylesheet": "//compiler/daml-extension:webview-stylesheet.css",
"daml2js_dist": "//language-support/ts/codegen:daml2js-dist",
"templates": "//templates:templates-tarball.tar.gz",
"trigger_dars": "//triggers/daml:daml-trigger-dars",
"script_dars": "//daml-script/daml:daml-script-dars",
"script3_dars": "//daml-script/daml3:daml3-script-dars",
"canton": "//canton:community_app_deploy.jar",
@ -74,7 +70,6 @@ def sdk_tarball(name, version, config):
tar xf $(location {damlc_dist}) --strip-components=1 -C $$OUT/damlc
mkdir -p $$OUT/daml-libs
cp -t $$OUT/daml-libs $(locations {trigger_dars})
cp -t $$OUT/daml-libs $(locations {script_dars})
cp -t $$OUT/daml-libs $(locations {script3_dars})
@ -96,10 +91,7 @@ def sdk_tarball(name, version, config):
mkdir -p $$OUT/daml-sdk
cp $(location {sdk_deploy_jar}) $$OUT/daml-sdk/daml-sdk.jar
cp -L $(location {trigger_service_logback}) $$OUT/daml-sdk/
cp -L $(location {oauth2_middleware_logback}) $$OUT/daml-sdk/
cp -L $(location {java_codegen_logback}) $$OUT/daml-sdk/codegen-logback.xml
cp -L $(location {trigger_logback}) $$OUT/daml-sdk/trigger-logback.xml
cp -L $(location {daml_script_logback}) $$OUT/daml-sdk/script-logback.xml
cp -L $(location {export_logback}) $$OUT/daml-sdk/export-logback.xml

View File

@ -1,7 +1,6 @@
# Security tests, by category
## Authorization:
- auth and auth-* should not be set together for the trigger service: [CliConfigTest.scala](triggers/service/src/test-suite/scala/com/daml/lf/engine/trigger/CliConfigTest.scala#L40)
- badly-authorized create is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L76)
- badly-authorized exercise is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L177)
- badly-authorized exercise/create (create is unauthorized) is rejected: [AuthPropagationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthPropagationSpec.scala#L267)
@ -11,8 +10,6 @@
- badly-authorized lookup is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L134)
- create with no signatories is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L66)
- create with non-signatory maintainers is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L88)
- error on specifying both authCommonUri and authInternalUri/authExternalUri for the trigger service: [AuthorizationConfigTest.scala](triggers/service/src/test-suite/scala/com/daml/lf/engine/trigger/AuthorizationConfigTest.scala#L24)
- error on specifying only authInternalUri and no authExternalUri for the trigger service: [AuthorizationConfigTest.scala](triggers/service/src/test-suite/scala/com/daml/lf/engine/trigger/AuthorizationConfigTest.scala#L52)
- exercise with no controllers is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L167)
- well-authorized create is accepted: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L59)
- well-authorized exercise is accepted: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L160)

View File

@ -4,7 +4,6 @@ load("@os_info//:os_info.bzl", "is_windows")
load("@build_environment//:configuration.bzl", "sdk_version")
exports_files(glob(["create-daml-app-test-resources/*"]) + [
"copy-trigger/src/CopyTrigger.daml",
"create-daml-app/ui/package.json.template",
])
@ -52,8 +51,6 @@ genrule(
"empty-skeleton/**",
"create-daml-app/**",
"quickstart-java/**",
"copy-trigger/**",
"gsg-trigger.patch",
],
exclude = ["**/NO_AUTO_COPYRIGHT"],
) + [
@ -76,17 +73,11 @@ genrule(
cp -rL templates/* $$SRC/
PATCH_TOOL=$$PWD/$(location @patch_dev_env//:patch)
cp -rL $$SRC/create-daml-app $$SRC/gsg-trigger
"$$PATCH_TOOL" -d $$SRC/gsg-trigger -p1 < $$SRC/gsg-trigger.patch
# templates in templates dir
for d in skeleton \
empty-skeleton \
create-daml-app \
quickstart-java \
copy-trigger \
gsg-trigger; do
quickstart-java; do
mkdir -p $$OUT/$$d
cp -rL $$SRC/$$d/* $$OUT/$$d/
for f in gitattributes gitignore dlint.yaml; do

View File

@ -1,12 +0,0 @@
sdk-version: __VERSION__
name: __PROJECT_NAME__
source: src
version: 0.0.1
# trigger-dependencies-begin
dependencies:
- daml-prim
- daml-stdlib
- daml-trigger
# trigger-dependencies-end
sandbox-options:
- --wall-clock-time

View File

@ -1,109 +0,0 @@
-- Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
module CopyTrigger where
import DA.List hiding (dedup)
import Daml.Trigger
-- ORIGINAL_TEMPLATE_BEGIN
template Original
with
owner : Party
name : Text
textdata : Text
where
signatory owner
key (owner, name) : (Party, Text)
maintainer key._1
-- ORIGINAL_TEMPLATE_END
deriving instance Ord Original
-- SUBSCRIBER_TEMPLATE_BEGIN
template Subscriber
with
subscriber : Party
subscribedTo : Party
where
signatory subscriber
observer subscribedTo
key (subscriber, subscribedTo) : (Party, Party)
maintainer key._1
-- SUBSCRIBER_TEMPLATE_END
-- COPY_TEMPLATE_BEGIN
template Copy
with
original : Original
subscriber : Party
where
signatory (signatory original)
observer subscriber
-- COPY_TEMPLATE_END
deriving instance Ord Copy
-- TRIGGER_BEGIN
copyTrigger : Trigger ()
copyTrigger = Trigger
{ initialize = pure ()
, updateState = \_message -> pure ()
, rule = copyRule
, registeredTemplates = AllInDar
, heartbeat = None
}
-- TRIGGER_END
-- RULE_SIGNATURE_BEGIN
copyRule : Party -> TriggerA () ()
copyRule party = do
-- RULE_SIGNATURE_END
-- ACS_QUERY_BEGIN
subscribers : [(ContractId Subscriber, Subscriber)] <- query @Subscriber
originals : [(ContractId Original, Original)] <- query @Original
copies : [(ContractId Copy, Copy)] <- query @Copy
-- ACS_QUERY_END
-- ACS_FILTER_BEGIN
let ownedSubscribers = filter (\(_, s) -> s.subscribedTo == party) subscribers
let ownedOriginals = filter (\(_, o) -> o.owner == party) originals
let ownedCopies = filter (\(_, c) -> c.original.owner == party) copies
-- ACS_FILTER_END
-- SUBSCRIBING_PARTIES_BEGIN
let subscribingParties = map (\(_, s) -> s.subscriber) ownedSubscribers
-- SUBSCRIBING_PARTIES_END
-- GROUP_COPIES_BEGIN
let groupedCopies : [[(ContractId Copy, Copy)]]
groupedCopies = groupOn snd $ sortOn snd $ ownedCopies
let copiesToKeep = map head groupedCopies
let archiveDuplicateCopies = concatMap tail groupedCopies
-- GROUP_COPIES_END
-- ARCHIVE_COPIES_BEGIN
let archiveMissingOriginal = filter (\(_, c) -> c.original `notElem` map snd ownedOriginals) copiesToKeep
let archiveMissingSubscriber = filter (\(_, c) -> c.subscriber `notElem` subscribingParties) copiesToKeep
let archiveCopies = dedup $ map fst $ archiveDuplicateCopies <> archiveMissingOriginal <> archiveMissingSubscriber
-- ARCHIVE_COPIES_END
-- ARCHIVE_COMMAND_BEGIN
forA archiveCopies $ \cid -> emitCommands [exerciseCmd cid Archive] [toAnyContractId cid]
-- ARCHIVE_COMMAND_END
-- CREATE_COPIES_BEGIN
let neededCopies = [Copy m o | (_, m) <- ownedOriginals, o <- subscribingParties]
let createCopies = filter (\c -> c `notElem` map snd copiesToKeep) neededCopies
mapA dedupCreate createCopies
-- CREATE_COPIES_END
pure ()
-- | The dedup function from DA.List requires an Ord constraint which we do not have for `ContractId k`. Therefore,
-- we resort to the n^2 version for now. Once we have Maps we can use those to implement a more efficient dedup.
dedup : Eq k => [k] -> [k]
dedup [] = []
dedup (x :: xs) = x :: dedup (filter (/= x) xs)

View File

@ -39,17 +39,12 @@ write_scalatest_runpath(
tests = [
"//ledger-service/utils",
"//libs-scala/jwt:tests-lib",
"//triggers/service:test",
"//triggers/service/auth:oauth2-middleware-tests",
],
runtime_deps = [
"//ledger/error:error-test-lib",
"//libs-scala/flyway-testing",
"//libs-scala/jwt",
"//libs-scala/scalatest-utils",
"//triggers/service:trigger-service",
"//triggers/service:trigger-service-tests",
"//triggers/service/auth:oauth2-middleware-tests",
],
deps = [
":ledger-generator-lib",

View File

@ -1,146 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# TODO Once daml_compile uses build instead of package we should use
# daml_compile instead of a genrule.
load("@build_environment//:configuration.bzl", "ghc_version", "sdk_version")
load("//daml-lf/language:daml-lf.bzl", "COMPILER_LF_VERSIONS")
# Build one DAR per LF version to bundle with the SDK.
# Also build one DAR with the default LF version for test-cases.
[
genrule(
name = "daml-trigger{}".format(suffix),
srcs = glob(["**/*.daml"]) + ["//daml-script/daml:daml-script{}".format(suffix)],
outs = ["daml-trigger{}.dar".format(suffix)],
cmd = """
set -eou pipefail
TMP_DIR=$$(mktemp -d)
mkdir -p $$TMP_DIR/daml/Daml/Trigger
cp -L $(location Daml/Trigger.daml) $$TMP_DIR/daml/Daml
cp -L $(location Daml/Trigger/Assert.daml) $$TMP_DIR/daml/Daml/Trigger
cp -L $(location Daml/Trigger/Internal.daml) $$TMP_DIR/daml/Daml/Trigger
cp -L $(location Daml/Trigger/LowLevel.daml) $$TMP_DIR/daml/Daml/Trigger
cp -L $(location {daml_script}) $$TMP_DIR/daml-script.dar
cat << EOF > $$TMP_DIR/daml.yaml
sdk-version: {sdk}
name: daml-trigger
source: daml
version: {ghc}
dependencies:
- daml-stdlib
- daml-prim
- daml-script.dar
build-options: {build_options}
EOF
$(location //compiler/damlc) build --project-root $$TMP_DIR --ghc-option=-Werror --log-level=WARNING \
-o $$PWD/$@
rm -rf $$TMP_DIR
""".format(
build_options = str([
"--target",
lf_version,
] if lf_version else []),
daml_script = "//daml-script/daml:daml-script{}".format(suffix),
ghc = ghc_version,
sdk = sdk_version,
),
tools = [
"//compiler/damlc",
],
visibility = ["//visibility:public"],
)
for lf_version in COMPILER_LF_VERSIONS + [""]
for suffix in [("-" + lf_version) if lf_version else ""]
]
filegroup(
name = "daml-trigger-dars",
srcs = [
"daml-trigger-{}.dar".format(lf_version)
for lf_version in COMPILER_LF_VERSIONS
],
visibility = ["//visibility:public"],
)
genrule(
name = "daml-trigger-json-docs",
srcs = glob(["**/*.daml"]) + [
"//daml-script/daml:daml-script",
],
outs = ["daml-trigger.json"],
cmd = """
TMP_DIR=$$(mktemp -d)
mkdir -p $$TMP_DIR/daml/Daml/Trigger
cp -L $(location Daml/Trigger.daml) $$TMP_DIR/daml/Daml
cp -L $(location Daml/Trigger/Assert.daml) $$TMP_DIR/daml/Daml/Trigger
cp -L $(location Daml/Trigger/Internal.daml) $$TMP_DIR/daml/Daml/Trigger
cp -L $(location Daml/Trigger/LowLevel.daml) $$TMP_DIR/daml/Daml/Trigger
cp -L $$PWD/$(location {daml_script}) $$TMP_DIR/daml-script.dar
cat << EOF > $$TMP_DIR/daml.yaml
sdk-version: {sdk}
name: daml-trigger
source: daml
version: {ghc}
dependencies:
- daml-stdlib
- daml-prim
- daml-script.dar
EOF
DAMLC=$$PWD/$(location //compiler/damlc)
JSON=$$PWD/$(location :daml-trigger.json)
cd $$TMP_DIR
$$DAMLC init
$$DAMLC -- docs \
--combine \
--output=$$JSON \
--format=Json \
--package-name=daml-trigger \
$$TMP_DIR/daml/Daml/Trigger.daml \
$$TMP_DIR/daml/Daml/Trigger/Assert.daml \
$$TMP_DIR/daml/Daml/Trigger/LowLevel.daml
""".format(
daml_script = "//daml-script/daml:daml-script",
ghc = ghc_version,
sdk = sdk_version,
),
tools = [
"//compiler/damlc",
],
visibility = ["//visibility:public"],
)
genrule(
name = "daml-trigger-docs",
srcs = [
":daml-trigger.json",
":daml-trigger-rst-template.rst",
":daml-trigger-index-template.rst",
":daml-trigger-hoogle-template.txt",
],
outs = [
"daml-trigger-rst.tar.gz",
"daml-trigger-hoogle.txt",
"daml-trigger-anchors.json",
],
cmd = """
$(location //compiler/damlc) -- docs \
--output=daml-trigger-rst \
--input-format=json \\
--format=Rst \
--template=$(location :daml-trigger-rst-template.rst) \
--index-template=$(location :daml-trigger-index-template.rst) \\
--hoogle-template=$(location :daml-trigger-hoogle-template.txt) \\
--base-url=https://docs.daml.com/triggers/api/ \\
--output-hoogle=$(location :daml-trigger-hoogle.txt) \\
--output-anchor=$(location :daml-trigger-anchors.json) \\
$(location :daml-trigger.json)
$(execpath //bazel_tools/sh:mktgz) $(location :daml-trigger-rst.tar.gz) daml-trigger-rst
""",
tools = [
"//bazel_tools/sh:mktgz",
"//compiler/damlc",
],
visibility = ["//visibility:public"],
)

View File

@ -1,358 +0,0 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
module Daml.Trigger
( query
, queryFilter
, queryContractId
, queryContractKey
, ActionTriggerAny
, getCommandsInFlight
, ActionTriggerUpdate
, Trigger(..)
, TriggerA
, TriggerUpdateA
, TriggerInitializeA
, get
, put
, modify
, emitCommands
, emitCommandsV2
, runTrigger
, CommandId
, Command(..)
, AnyContractId
, toAnyContractId
, fromAnyContractId
, exerciseCmd
, createCmd
, exerciseByKeyCmd
, createAndExerciseCmd
, dedupExercise
, dedupCreate
, dedupExerciseByKey
, dedupCreateAndExercise
, Message(..)
, Completion(..)
, Transaction(..)
, Event(..)
, Created
, Archived
, fromCreated
, fromArchived
, CompletionStatus(..)
, RegisteredTemplates(..)
, registeredTemplate
, RelTime(..)
, getReadAs
, getActAs
) where
import Prelude hiding (any)
import DA.Action
import DA.Action.State (execState)
import DA.Foldable (any)
import DA.Functor ((<&>))
import DA.Map (Map, size)
import qualified DA.Map as Map
import DA.Optional
import Daml.Trigger.Internal
import Daml.Trigger.LowLevel hiding (BatchTrigger, Trigger)
import qualified Daml.Trigger.LowLevel as LowLevel
-- public API
-- | Extract the contracts of a given template from the ACS.
getContracts : forall a. Template a => ACS -> [(ContractId a, a)]
getContracts acs@(ACS tpls _) = mapOptional fromAny
$ filter (\(cid, _) -> not $ cid `elem` allPending)
$ optional [] Map.toList
$ Map.lookup (templateTypeRep @a) tpls
where
fromAny (cid, tpl) = (,) <$> fromAnyContractId cid <*> fromAnyTemplate tpl
allPending = getPendingContracts acs
getPendingContracts : ACS -> [AnyContractId]
getPendingContracts (ACS _ pending) = concatMap snd $ Map.toList pending
getContractById : forall a. Template a => ContractId a -> ACS -> Optional a
getContractById id (ACS tpls pending) = do
let aid = toAnyContractId id
implSpecific = Map.lookup aid <=< Map.lookup (templateTypeRep @a)
aa <- implSpecific tpls
a <- fromAnyTemplate aa
if any (elem aid) pending then None else Some a
-- | Extract the contracts of a given template from the ACS.
query : forall a m. (Template a, ActionTriggerAny m) => m [(ContractId a, a)]
query = implQuery
-- | Extract the contracts of a given template from the ACS and filter
-- to those that match the predicate.
queryFilter : forall a m. (Functor m, Template a, ActionTriggerAny m) => (a -> Bool) -> m [(ContractId a, a)]
queryFilter pred = filter (\(_, c) -> pred c) <$> implQuery
-- | Find the contract with the given `key` in the ACS, if present.
queryContractKey : forall a k m. (Template a, HasKey a k, Eq k, ActionTriggerAny m, Functor m)
=> k -> m (Optional (ContractId a, a))
queryContractKey k = find (\(_, a) -> k == key a) <$> query
-- | Features possible in `initialize`, `updateState`, and `rule`.
class ActionTriggerAny m where
-- | Extract the contracts of a given template from the ACS. (However, the
-- type parameters are in the 'm a' order, so it is not exported.)
implQuery : forall a. Template a => m [(ContractId a, a)]
-- | Find the contract with the given `id` in the ACS, if present.
queryContractId : Template a => ContractId a -> m (Optional a)
-- | Query the list of currently pending contracts as set by
-- `emitCommands`.
queryPendingContracts : m [AnyContractId]
getReadAs : m [Party]
getActAs : m Party
instance ActionTriggerAny (TriggerA s) where
implQuery = TriggerA $ pure . getContracts
queryContractId id = TriggerA $ pure . getContractById id
queryPendingContracts = TriggerA $ \acs -> pure (getPendingContracts acs)
getReadAs = TriggerA $ \_ -> do
s <- get
pure s.readAs
getActAs = TriggerA $ \_ -> do
s <- get
pure s.actAs
instance ActionTriggerAny (TriggerUpdateA s) where
implQuery = TriggerUpdateA $ \s -> pure (getContracts s.acs)
queryContractId id = TriggerUpdateA $ \s -> pure (getContractById id s.acs)
queryPendingContracts = TriggerUpdateA $ \s -> pure (getPendingContracts s.acs)
getReadAs = TriggerUpdateA $ \s -> pure s.readAs
getActAs = TriggerUpdateA $ \s -> pure s.actAs
instance ActionTriggerAny TriggerInitializeA where
implQuery = TriggerInitializeA (\s -> getContracts s.acs)
queryContractId id = TriggerInitializeA (\s -> getContractById id s.acs)
queryPendingContracts = TriggerInitializeA (\s -> getPendingContracts s.acs)
getReadAs = TriggerInitializeA (\s -> s.readAs)
getActAs = TriggerInitializeA (\s -> s.actAs)
-- | Features possible in `updateState` and `rule`.
class ActionTriggerAny m => ActionTriggerUpdate m where
-- | Retrieve command submissions made by this trigger that have not yet
-- completed. If the trigger has restarted, it will not contain commands from
-- before the restart; therefore, this should be treated as an optimization
-- rather than an absolute authority on ledger state.
getCommandsInFlight : m (Map CommandId [Command])
instance ActionTriggerUpdate (TriggerUpdateA s) where
getCommandsInFlight = TriggerUpdateA $ \s -> pure s.commandsInFlight
instance ActionTriggerUpdate (TriggerA s) where
getCommandsInFlight = liftTriggerRule $ get <&> \s -> s.commandsInFlight
-- | This is the type of your trigger. `s` is the user-defined state type which
-- you can often leave at `()`.
data Trigger s = Trigger
{ initialize : TriggerInitializeA s
-- ^ Initialize the user-defined state based on the ACS.
, updateState : Message -> TriggerUpdateA s ()
-- ^ Update the user-defined state based on a transaction or
-- completion message. It can manipulate the state with `get`, `put`,
-- and `modify`, or query the ACS with `query`.
, rule : Party -> TriggerA s ()
-- ^ The rule defines the main logic of your trigger. It can send commands
-- to the ledger using `emitCommands` to change the ACS.
-- The rule depends on the following arguments:
--
-- * The party your trigger is running as.
-- * The user-defined state.
--
-- and can retrieve other data with functions in `TriggerA`:
--
-- * The current state of the ACS.
-- * The current time (UTC in wallclock mode, Unix epoch in static mode)
-- * The commands in flight.
, registeredTemplates : RegisteredTemplates
-- ^ The templates the trigger will receive events for.
, heartbeat : Optional RelTime
-- ^ Send a heartbeat message at the given interval.
}
-- | Send a transaction consisting of the given commands to the ledger.
-- The second argument can be used to mark a list of contract ids as pending.
-- These contracts will automatically be filtered from getContracts until we
-- either get the corresponding transaction event for this command or
-- a failing completion.
emitCommands : [Command] -> [AnyContractId] -> TriggerA s CommandId
emitCommands cmds pending = do
id <- liftTriggerRule $ submitCommands cmds
let commands = Commands id cmds
liftTriggerRule $ modify $ \s -> s
{ commandsInFlight = addCommands s.commandsInFlight commands
, pendingContracts = Map.insert id pending s.pendingContracts
}
pure id
-- Version of emitCommands that will not perform command submissions once there are too
-- many commands in-flight. When this function does not return a command ID (i.e. it
-- returns None), then client code should manage this scenario. Failing to manage this
-- warning scenario will eventually cause the trigger runner to stop with an
-- InFlightCommandOverflowException exception.
emitCommandsV2 : [Command] -> [AnyContractId] -> TriggerA s (Optional CommandId)
emitCommandsV2 cmds pending = do
mbId <- liftTriggerRule $ mbSubmitCommands cmds
case mbId of
None ->
pure None
Some id -> do
let commands = Commands id cmds
liftTriggerRule $ modify $ \s -> s
{ commandsInFlight = addCommands s.commandsInFlight commands
, pendingContracts = Map.insert id pending s.pendingContracts
}
pure (Some id)
where
mbSubmitCommands cmds = do
state <- get
if size state.commandsInFlight < state.config.maxInFlightCommands then do
id <- submitCommands cmds
pure (Some id)
else do
_ <- debug "WARN: too many commands currently in-flight, so command submission will be dropped"
pure None
-- | Create the template if it's not already in the list of commands
-- in flight (it will still be created if it is in the ACS).
--
-- Note that this will send the create as a single-command transaction.
-- If you need to send multiple commands in one transaction, use
-- `emitCommands` with `createCmd` and handle filtering yourself.
dedupCreate : (Eq t, Template t) => t -> TriggerA s ()
dedupCreate t = do
aState <- liftTriggerRule get
-- This is a very naive approach that is linear in the number of commands in flight.
-- We probably want to change this to express the commands in flight as some kind of
-- map to make these lookups cheaper.
let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
unless (any ((Some t ==) . fromCreate) cmds) $
void $ emitCommands [createCmd t] []
-- | Create the template and exercise a choice on it if it's not already in
-- the list of commands in flight (it will still be created if it is in the
-- ACS).
--
-- Note that this will send the create and exercise as a
-- single-command transaction. If you need to send multiple commands
-- in one transaction, use `emitCommands` with `createAndExerciseCmd`
-- and handle filtering yourself.
dedupCreateAndExercise : (Eq t, Eq c, Template t, Choice t c r) => t -> c -> TriggerA s ()
dedupCreateAndExercise t c = do
  aState <- liftTriggerRule get
  -- This is a very naive approach that is linear in the number of
  -- commands in flight. We probably want to change this to express
  -- the commands in flight as some kind of map to make these lookups
  -- cheaper.
  let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
  -- Deduplicate on the (payload, choice argument) pair.
  unless (any ((Some (t, c) ==) . fromCreateAndExercise) cmds) $
    void $ emitCommands [createAndExerciseCmd t c] []
-- | Exercise the choice on the given contract if it is not already
-- in flight.
--
-- Note that this will send the exercise as a single-command transaction.
-- If you need to send multiple commands in one transaction, use
-- `emitCommands` with `exerciseCmd` and handle filtering yourself.
--
-- If you are calling a consuming choice, you might be better off by using
-- `emitCommands` and adding the contract id to the pending set.
dedupExercise : (Eq c, Choice t c r) => ContractId t -> c -> TriggerA s ()
dedupExercise cid c = do
  aState <- liftTriggerRule get
  -- This is a very naive approach that is linear in the number of commands in flight.
  -- We probably want to change this to express the commands in flight as some kind of
  -- map to make these lookups cheaper.
  let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
  -- Deduplicate on the (contract id, choice argument) pair.
  unless (any ((Some (cid, c) ==) . fromExercise) cmds) $
    void $ emitCommands [exerciseCmd cid c] []
-- | Exercise the choice on the contract identified by the given key, if
-- the same (key, choice) exercise is not already in flight.
--
-- Note that this will send the exercise as a single-command transaction.
-- If you need to send multiple commands in one transaction, use
-- `emitCommands` with `exerciseCmd` and handle filtering yourself.
dedupExerciseByKey : forall t c r k s. (Eq c, Eq k, Choice t c r, TemplateKey t k) => k -> c -> TriggerA s ()
dedupExerciseByKey k c = do
  aState <- liftTriggerRule get
  -- This is a very naive approach that is linear in the number of commands in flight.
  -- We probably want to change this to express the commands in flight as some kind of
  -- map to make these lookups cheaper.
  let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
  -- Deduplicate on the (key, choice argument) pair for template `t`.
  unless (any ((Some (k, c) ==) . fromExerciseByKey @t) cmds) $
    void $ emitCommands [exerciseByKeyCmd @t k c] []
-- | Transform the high-level trigger type into the batching trigger from
-- `Daml.Trigger.LowLevel`. The resulting trigger's accumulator is the full
-- 'TriggerState', and the user rule is re-run after any batch containing a
-- message that changed relevant state.
runTrigger : forall s. Trigger s -> LowLevel.BatchTrigger (TriggerState s)
runTrigger userTrigger = LowLevel.BatchTrigger
  { initialState
  , update = update
  , registeredTemplates = userTrigger.registeredTemplates
  , heartbeat = userTrigger.heartbeat
  }
  where
    -- Build the initial ACS from the snapshot of created events, run the
    -- user's `initialize` to get the initial user state, then run the rule
    -- once over that state as part of setup.
    initialState args =
      let acs = foldl (\acs created -> applyEvent (CreatedEvent created) acs) (ACS mempty Map.empty) args.acs.activeContracts
          userState = runTriggerInitializeA userTrigger.initialize (TriggerInitState acs args.actAs args.readAs)
          state = TriggerState acs args.actAs args.readAs userState Map.empty args.config
      in TriggerSetup $ execStateT (runTriggerRule $ runRule userTrigger.rule) state
    -- Run the user's `updateState` with the PRE-update commands-in-flight
    -- and the given ACS, producing the new user state.
    mkUserState state acs msg =
      let state' = TriggerUpdateState state.commandsInFlight acs state.actAs state.readAs
      in execState (flip runTriggerUpdateA state' $ userTrigger.updateState msg) state.userState
    -- Process the whole batch first; run the user rule once at the end if
    -- any message requires it.
    update msgs = do
      runUserTrigger <- or <$> mapA processMessage msgs
      when runUserTrigger $
        runRule userTrigger.rule
    -- Returns 'True' if the processed message means we need to run the user's trigger.
    processMessage : Message -> TriggerRule (TriggerState s) Bool
    processMessage msg = do
      state <- get
      case msg of
        MCompletion completion ->
          -- NB: the commands-in-flight and ACS updateState sees are those
          -- prior to updates incurred by the msg
          let userState = mkUserState state state.acs msg
          in case completion.status of
               Succeeded {} -> do
                 -- We delete successful completions when we receive the corresponding transaction
                 -- to avoid removing a command from commandsInFlight before we have modified the ACS.
                 put $ state { userState }
                 pure False
               Failed {} -> do
                 -- A failed command will never produce a transaction, so drop
                 -- it from the in-flight map and release its pending contracts.
                 let commandsInFlight = Map.delete completion.commandId state.commandsInFlight
                     acs = state.acs { pendingContracts = Map.delete completion.commandId state.acs.pendingContracts }
                 put $ state { commandsInFlight, userState, acs }
                 pure True
        MTransaction transaction -> do
          let acs = applyTransaction transaction state.acs
              -- again, we use the commands-in-flight and ACS before the update below
              userState = mkUserState state acs msg
              -- See the comment above for why we delete this here instead of when we receive the completion.
              (acs', commandsInFlight) = case transaction.commandId of
                None -> (acs, state.commandsInFlight)
                Some commandId -> (acs { pendingContracts = Map.delete commandId acs.pendingContracts }, Map.delete commandId state.commandsInFlight)
          put $ state { acs = acs', userState, commandsInFlight }
          pure True
        MHeartbeat -> do
          -- Heartbeats carry no ledger data but still update the user state
          -- and re-run the rule.
          let userState = mkUserState state state.acs msg
          put $ state { userState }
          pure True

View File

@ -1,147 +0,0 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
module Daml.Trigger.Assert
( ACSBuilder
, toACS
, testRule
, flattenCommands
, assertCreateCmd
, assertExerciseCmd
, assertExerciseByKeyCmd
) where
import qualified DA.List as List
import DA.Map (Map)
import qualified DA.Map as Map
import qualified DA.Text as Text
import Daml.Trigger hiding (queryContractId)
import Daml.Trigger.Internal
import Daml.Trigger.LowLevel hiding (Trigger)
import Daml.Script (Script, queryContractId)
-- | Used to construct an 'ACS' for 'testRule'.
-- Internally: a function from the executing party to a list of Script
-- actions, each fetching one contract (id plus payload) to include.
newtype ACSBuilder = ACSBuilder (Party -> [Script (AnyContractId, AnyTemplate)])

-- Builders compose by concatenating their fetch lists.
instance Semigroup ACSBuilder where
  ACSBuilder l <> ACSBuilder r =
    ACSBuilder (\p -> l p <> r p)

instance Monoid ACSBuilder where
  mempty = ACSBuilder (const mempty)
-- | Run the builder's fetches as the given party and assemble the results
-- into an 'ACS' with no pending contracts.
buildACS : Party -> ACSBuilder -> Script ACS
buildACS p (ACSBuilder fetches) = do
  contracts <- sequence (fetches p)
  pure ACS with
    activeContracts = groupActiveContracts contracts
    pendingContracts = Map.empty
-- | Include the given contract in the 'ACS'. Note that the `ContractId`
-- must point to an active contract.
toACS : (Template t, HasAgreement t) => ContractId t -> ACSBuilder
toACS cid = ACSBuilder $ \p ->
  [ do t <-
         -- Fetch the payload as the executing party; abort the whole test
         -- if the contract is not visible/active.
         queryContractId p cid >>= \case
           None -> abort ("Failed to fetch contract passed to toACS: " <> show cid)
           Some c -> pure c
       pure (toAnyContractId cid, toAnyTemplate t)
  ]
-- | Execute a trigger's rule once in a scenario.
testRule
  : Trigger s  -- ^ Test this trigger's 'Trigger.rule'.
  -> Party  -- ^ Execute the rule as this 'Party'.
  -> [Party]  -- ^ Execute the rule with these parties as `readAs`
  -> ACSBuilder  -- ^ List these contracts in the 'ACS'.
  -> Map CommandId [Command]  -- ^ The commands in flight.
  -> s  -- ^ The trigger state.
  -> Script (s, [Commands])  -- ^ The 'Commands' and new state emitted by the rule. The 'CommandId's will start from @"0"@.
testRule trigger party readAs acsBuilder commandsInFlight s = do
  time <- getTime
  acs <- buildACS party acsBuilder
  -- NOTE(review): both bounds are hard-coded to 1 here, so rules that use
  -- emitCommandsV2 will drop submissions once a single command is in
  -- flight in a test — confirm this is intended for your scenario.
  let config = TriggerConfig
        { maxInFlightCommands = 1, maxActiveContracts = 1 }
  let state = TriggerState
        { acs = acs
        , actAs = party
        , readAs = readAs
        , userState = s
        , commandsInFlight = commandsInFlight
        , config = config
        }
  -- A sufficiently powerful `TriggerF` command will entail redoing this as a
  -- natural transformation from `Free TriggerF` to `Free ScriptF`, which in
  -- turn will likely require exposing `ScriptF` or both in some "internal"
  -- fashion. Meanwhile, it is worth pushing `simulateRule` as far as possible
  -- to forestall that outcome. -SC
  let (state', commands, _) = simulateRule (runRule trigger.rule) time state
  pure (state'.userState, commands)
-- | Drop 'CommandId's and extract all 'Command's, preserving order.
flattenCommands : [Commands] -> [Command]
flattenCommands batches = [ cmd | batch <- batches, cmd <- batch.commands ]
-- | Fold over the commands, succeeding as soon as one command both matches
-- the extractor and passes the assertion; otherwise accumulate the failure
-- messages produced by the matching-but-failing candidates.
expectCommand
  : [Command]
  -> (Command -> Optional a)
  -> (a -> Either Text ())
  -> Either [Text] ()
expectCommand commands fromCommand assertion = foldl step (Left []) commands
  where
    step : Either [Text] () -> Command -> Either [Text] ()
    -- Once one command has passed, the result stays a success.
    step (Right ()) _ = Right ()
    step (Left msgs) command =
      case assertion <$> fromCommand command of
        -- Command of a different shape/template: not a candidate, no message.
        None -> Left msgs
        Some (Left msg) -> Left (msg :: msgs)
        Some (Right ()) -> Right ()
-- | Check that at least one command is a create command whose payload fulfills the given assertions.
-- Aborts with the collected per-candidate failure messages when nothing matches.
assertCreateCmd
  : (Template t, HasAgreement t, CanAbort m)
  => [Command]  -- ^ Check these commands.
  -> (t -> Either Text ())  -- ^ Perform these assertions.
  -> m ()
assertCreateCmd commands assertion =
  case expectCommand commands fromCreate assertion of
    Right () -> pure ()
    Left msgs ->
      abort $ "Failure, found no matching create command." <> collectMessages msgs
-- | Check that at least one command is an exercise command whose contract id and choice argument fulfill the given assertions.
-- Aborts with the collected per-candidate failure messages when nothing matches.
assertExerciseCmd
  : (Template t, HasAgreement t, Choice t c r, CanAbort m)
  => [Command]  -- ^ Check these commands.
  -> ((ContractId t, c) -> Either Text ())  -- ^ Perform these assertions.
  -> m ()
assertExerciseCmd commands assertion =
  case expectCommand commands fromExercise assertion of
    Right () -> pure ()
    Left msgs ->
      abort $ "Failure, found no matching exercise command." <> collectMessages msgs
-- | Check that at least one command is an exercise by key command whose key and choice argument fulfill the given assertions.
-- Aborts with the collected per-candidate failure messages when nothing matches.
assertExerciseByKeyCmd
  : forall t c r k m
  . (TemplateKey t k, Choice t c r, CanAbort m)
  => [Command]  -- ^ Check these commands.
  -> ((k, c) -> Either Text ())  -- ^ Perform these assertions.
  -> m ()
assertExerciseByKeyCmd commands assertion =
  case expectCommand commands (fromExerciseByKey @t) assertion of
    Right () -> pure ()
    Left msgs ->
      abort $ "Failure, found no matching exerciseByKey command." <> collectMessages msgs
-- | Render accumulated failure messages as an indented bullet list; the
-- empty list renders as the empty string (no trailing newline added).
collectMessages : [Text] -> Text
collectMessages [] = ""
collectMessages msgs = "\n" <> Text.unlines (map (bullet . nest) msgs)
  where
    bullet txt = " * " <> txt
    -- Re-join multi-line messages so continuation lines align under the bullet.
    nest = Text.intercalate "\n" . List.intersperse " " . Text.lines

View File

@ -1,243 +0,0 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
{-# LANGUAGE CPP #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
-- | MOVE Daml.Trigger
module Daml.Trigger.Internal
( ACS (..)
, TriggerA (..)
, TriggerUpdateA (..)
, TriggerInitializeA (..)
, addCommands
, insertTpl
, groupActiveContracts
, deleteTpl
, lookupTpl
, applyEvent
, applyTransaction
, runRule
, liftTriggerRule
, TriggerAState (..)
, TriggerState (..)
, TriggerInitState(..)
, TriggerUpdateState(..)
) where
import DA.Action.State
import DA.Foldable (forA_)
import DA.Functor ((<&>))
import DA.Map (Map)
import qualified DA.Map as Map
import DA.Optional (fromOptional)
import Daml.Trigger.LowLevel hiding (Trigger)
import qualified Daml.Trigger.LowLevel as LowLevel
-- We use this singleton enumeration to track the version of the API between Daml code and Scala code of the trigger
data Version = Version_2_6

-- public API

-- | HIDE Active contract set, you can use `getContracts` to access the templates of
-- a given type.
-- This will change to a Map once we have proper maps in Daml-LF
-- The following variant type should be kept in sync with methods numberOfActiveContracts and numberOfPendingContracts
-- in Runner.scala
data ACS = ACS
  { activeContracts : Map.Map TemplateTypeRep (Map.Map AnyContractId AnyTemplate)
    -- ^ Active contracts, grouped by template type and keyed by contract id.
  , pendingContracts : Map CommandId [AnyContractId]
    -- ^ Contract ids marked pending by an in-flight command, keyed by that command's id.
  }
-- | TriggerA is the type used in the `rule` of a Daml trigger.
-- Its main feature is that you can call `emitCommands` to
-- send commands to the ledger.
newtype TriggerA s a =
  -- | HIDE Reader over the ACS on top of the stateful TriggerRule.
  TriggerA { runTriggerA : ACS -> TriggerRule (TriggerAState s) a }

-- The instances below lift the corresponding TriggerRule instances over
-- the ACS reader via the rlift* helpers.
instance Functor (TriggerA s) where
  fmap f (TriggerA r) = TriggerA $ rliftFmap fmap f r

instance Applicative (TriggerA s) where
  pure = TriggerA . rliftPure pure
  TriggerA ff <*> TriggerA fa = TriggerA $ rliftAp (<*>) ff fa

instance Action (TriggerA s) where
  TriggerA fa >>= f = TriggerA $ rliftBind (>>=) fa (runTriggerA . f)

-- The state exposed to user code is only the `userState` slice of
-- TriggerAState.
instance ActionState s (TriggerA s) where
  get = TriggerA $ const (get <&> \tas -> tas.userState)
  modify f = TriggerA . const . modify $ \tas -> tas { userState = f tas.userState }

instance HasTime (TriggerA s) where
  getTime = TriggerA $ const getTime
-- | HIDE Environment passed to a trigger's `updateState`: the in-flight
-- commands and ACS as they were before the current message was applied.
data TriggerUpdateState = TriggerUpdateState
  with
    commandsInFlight : Map CommandId [Command]
    acs : ACS
    actAs : Party
    readAs : [Party]

-- | TriggerUpdateA is the type used in the `updateState` of a Daml
-- trigger. It has similar actions in common with `TriggerA`, but
-- cannot use `emitCommands` or `getTime`.
newtype TriggerUpdateA s a =
  -- | HIDE
  TriggerUpdateA { runTriggerUpdateA : TriggerUpdateState -> State s a }

-- Instances lift the underlying State instances over the
-- TriggerUpdateState reader via the rlift* helpers.
instance Functor (TriggerUpdateA s) where
  fmap f (TriggerUpdateA r) = TriggerUpdateA $ rliftFmap fmap f r

instance Applicative (TriggerUpdateA s) where
  pure = TriggerUpdateA . rliftPure pure
  TriggerUpdateA ff <*> TriggerUpdateA fa = TriggerUpdateA $ rliftAp (<*>) ff fa

instance Action (TriggerUpdateA s) where
  TriggerUpdateA fa >>= f = TriggerUpdateA $ rliftBind (>>=) fa (runTriggerUpdateA . f)

instance ActionState s (TriggerUpdateA s) where
  get = TriggerUpdateA $ const get
  put = TriggerUpdateA . const . put
  modify = TriggerUpdateA . const . modify

-- | HIDE Environment passed to a trigger's `initialize`.
data TriggerInitState = TriggerInitState
  with
    acs : ACS
    actAs : Party
    readAs : [Party]

-- | TriggerInitializeA is the type used in the `initialize` of a Daml
-- trigger. It can query, but not emit commands or update the state.
newtype TriggerInitializeA a =
  -- | HIDE Plain reader; derives its Action instance from the function type.
  TriggerInitializeA { runTriggerInitializeA : TriggerInitState -> a }
  deriving (Functor, Applicative, Action)
-- Internal API

-- | HIDE Record a submitted batch of commands in the in-flight map.
addCommands : Map CommandId [Command] -> Commands -> Map CommandId [Command]
addCommands m (Commands cid cmds) = Map.insert cid cmds m

-- | HIDE Insert a contract into the ACS under its template type.
insertTpl : AnyContractId -> AnyTemplate -> ACS -> ACS
insertTpl cid tpl acs = acs { activeContracts = Map.alter addct cid.templateId acs.activeContracts }
  where addct = Some . Map.insert cid tpl . fromOptional mempty

-- | HIDE Group a flat list of contracts by template type, in the shape
-- stored by 'ACS.activeContracts'.
groupActiveContracts :
  [(AnyContractId, AnyTemplate)] -> Map.Map TemplateTypeRep (Map.Map AnyContractId AnyTemplate)
groupActiveContracts = foldr (\v@(cid, _) -> Map.alter (addct v) cid.templateId) Map.empty
  where addct (cid, tpl) = Some . Map.insert cid tpl . fromOptional mempty

-- | HIDE Remove a contract from the ACS, dropping the per-template map
-- entirely when it becomes empty.
deleteTpl : AnyContractId -> ACS -> ACS
deleteTpl cid acs = acs { activeContracts = Map.alter rmct cid.templateId acs.activeContracts }
  where rmct om = do
          m <- om
          let m' = Map.delete cid m
          if Map.null m' then None else Some m'

-- | HIDE Look up a contract in the ACS and decode it to the given template.
lookupTpl : Template a => AnyContractId -> ACS -> Optional a
lookupTpl cid acs = do
  tpl <- Map.lookup cid =<< Map.lookup cid.templateId acs.activeContracts
  fromAnyTemplate tpl

-- | HIDE Apply a single create/archive event to the ACS. Created events
-- whose payload is None are ignored.
applyEvent : Event -> ACS -> ACS
applyEvent ev acs = case ev of
  CreatedEvent (Created _ cid (Some tpl) _) -> insertTpl cid tpl acs
  CreatedEvent _ -> acs
  ArchivedEvent (Archived _ cid) -> deleteTpl cid acs

-- | HIDE Apply all events of a transaction to the ACS, in order.
applyTransaction : Transaction -> ACS -> ACS
applyTransaction (Transaction _ _ evs) acs = foldl (flip applyEvent) acs evs
-- | HIDE Run a 'TriggerA' rule against the full 'TriggerState': the state is
-- zoomed down to the 'TriggerAState' slice the rule operates on, and the
-- modified slice is merged back afterwards.
runRule
  : (Party -> TriggerA s a)
  -> TriggerRule (TriggerState s) a
runRule rule = do
  state <- get
  TriggerRule . zoom zoomIn zoomOut . runTriggerRule . flip runTriggerA state.acs
    $ rule state.actAs
  where zoomIn state = TriggerAState state.commandsInFlight state.acs.pendingContracts state.userState state.readAs state.actAs state.config
        -- Write the slice back; pendingContracts is re-embedded into the ACS.
        zoomOut state aState =
          let commandsInFlight = aState.commandsInFlight
              acs = state.acs { pendingContracts = aState.pendingContracts }
              userState = aState.userState
              readAs = aState.readAs
              actAs = aState.actAs
              config = aState.config
          in state { commandsInFlight, acs, userState, readAs, actAs, config }
-- | HIDE
-- | Transform the (legacy) low-level trigger type into a batching trigger.
-- The legacy `update` handles one message at a time, so a batch is handled
-- by folding it over each message in order.
runLegacyTrigger : LowLevel.Trigger s -> BatchTrigger s
runLegacyTrigger userTrigger = BatchTrigger
  { initialState = \args -> userTrigger.initialState args.actAs args.readAs args.acs
  , update = \msgs -> forA_ msgs userTrigger.update
  , registeredTemplates = userTrigger.registeredTemplates
  , heartbeat = userTrigger.heartbeat
  }
-- | HIDE Embed a raw TriggerRule (ignoring the ACS environment) into TriggerA.
liftTriggerRule : TriggerRule (TriggerAState s) a -> TriggerA s a
liftTriggerRule = TriggerA . const

-- | HIDE State slice a rule runs over; see `runRule` for the zooming.
data TriggerAState s = TriggerAState
  { commandsInFlight : Map CommandId [Command]
    -- ^ Zoomed from TriggerState; used for dedupCreateCmd/dedupExerciseCmd
    -- helpers and extended by emitCommands.
  , pendingContracts : Map CommandId [AnyContractId]
    -- ^ Map from command ids to the contract ids marked pending by that command;
    -- zoomed from TriggerState's acs.
  , userState : s
    -- ^ zoomed from TriggerState
  , readAs : [Party]
    -- ^ zoomed from TriggerState
  , actAs : Party
    -- ^ zoomed from TriggerState
  , config : TriggerConfig
    -- ^ zoomed from TriggerState
  }

-- | HIDE
-- The following record type should be kept in sync with method numberOfInFlightCommands in Runner.scala
data TriggerState s = TriggerState
  { acs : ACS
  , actAs : Party
  , readAs : [Party]
  , userState : s
  , commandsInFlight : Map CommandId [Command]
  , config : TriggerConfig
  }
-- | HIDE
--
-- unboxed newtype for common Trigger*A additions: a reader over `r`
-- layered on an underlying functor `f`.
type TriggerAT r f a = r -> f a

-- | HIDE Lift fmap over the reader.
rliftFmap : ((a -> b) -> f a -> f b) -> (a -> b) -> TriggerAT r f a -> TriggerAT r f b
rliftFmap ub f r = ub f . r

-- | HIDE Lift pure: the environment is ignored.
rliftPure : (a -> f a) -> a -> TriggerAT r f a
rliftPure ub = const . ub

-- | HIDE Lift (<*>): the same environment is fed to both sides.
rliftAp : (f (a -> b) -> f a -> f b) -> TriggerAT r f (a -> b) -> TriggerAT r f a -> TriggerAT r f b
rliftAp ub ff fa r = ff r `ub` fa r

-- | HIDE Lift (>>=): the same environment is fed to the continuation.
rliftBind : (f a -> (a -> f b) -> f b) -> TriggerAT r f a -> (a -> TriggerAT r f b) -> TriggerAT r f b
rliftBind ub fa f r = fa r `ub` \a -> f a r

View File

@ -1,364 +0,0 @@
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
{-# LANGUAGE AllowAmbiguousTypes #-}
module Daml.Trigger.LowLevel
( Message(..)
, Completion(..)
, CompletionStatus(..)
, TriggerConfig(..)
, TriggerSetupArguments(..)
, Transaction(..)
, AnyContractId
, toAnyContractId
, fromAnyContractId
, TransactionId(..)
, EventId(..)
, CommandId(..)
, Event(..)
, Created(..)
, fromCreated
, Archived(..)
, fromArchived
, Trigger(..)
, BatchTrigger(..)
, ActiveContracts(..)
, Commands(..)
, Command(..)
, createCmd
, exerciseCmd
, exerciseByKeyCmd
, createAndExerciseCmd
, fromCreate
, fromExercise
, fromExerciseByKey
, fromCreateAndExercise
, RegisteredTemplates(..)
, registeredTemplate
, RelTime(..)
, execStateT
, ActionTrigger(..)
, TriggerSetup(..)
, TriggerRule(..)
, ActionState(..)
, zoom
, submitCommands
, simulateRule
) where
import DA.Action.State
import DA.Action.State.Class
import DA.Functor ((<&>))
import DA.Internal.Interface.AnyView.Types
import DA.Time (RelTime(..))
import Daml.Script.Free (Free(..), lift, foldFree)
-- | This type represents the contract id of an unknown template.
-- You can use `fromAnyContractId` to check which template it corresponds to.
data AnyContractId = AnyContractId
  { templateId : TemplateTypeRep
  , contractId : ContractId ()
  } deriving Eq

deriving instance Ord AnyContractId

-- We can't derive the Show instance since TemplateTypeRep does not have a Show instance
-- but it is useful for debugging so we add one that omits the type.
instance Show AnyContractId where
  showsPrec d (AnyContractId _ cid) = showParen (d > app_prec) $
      showString "AnyContractId " . showsPrec (app_prec +1) cid
    where app_prec = 10

-- | Wrap a `ContractId t` in `AnyContractId`.
toAnyContractId : forall t. Template t => ContractId t -> AnyContractId
toAnyContractId cid = AnyContractId
  { templateId = templateTypeRep @t
  , contractId = coerceContractId cid
  }

-- | Check if a `AnyContractId` corresponds to the given template or return
-- `None` otherwise.
fromAnyContractId : forall t. Template t => AnyContractId -> Optional (ContractId t)
fromAnyContractId cid
  | cid.templateId == templateTypeRep @t = Some (coerceContractId cid.contractId)
  | otherwise = None
newtype TransactionId = TransactionId Text
  deriving (Show, Eq)

newtype EventId = EventId Text
  deriving (Show, Eq)

newtype CommandId = CommandId Text
  deriving (Show, Eq, Ord)

-- | A transaction as delivered to the trigger.
data Transaction = Transaction
  { transactionId : TransactionId
  , commandId : Optional CommandId
    -- ^ Some only for transactions resulting from this trigger's own commands.
  , events : [Event]
  }

-- | A rendered interface view attached to a 'Created' event.
data InterfaceView = InterfaceView {
  interfaceTypeRep : TemplateTypeRep,
  anyView : Optional AnyView
}

-- | An event in a transaction.
-- This definition should be kept consistent with the object `EventVariant` defined in
-- triggers/runner/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Converter.scala
data Event
  = CreatedEvent Created
  | ArchivedEvent Archived

-- | The data in a `Created` event.
data Created = Created
  { eventId : EventId
  , contractId : AnyContractId
  , argument : Optional AnyTemplate
    -- ^ NOTE(review): may be None; `applyEvent` ignores such creates — confirm
    -- under what conditions the runner omits the payload.
  , views : [InterfaceView]
  }

-- | Check if a `Created` event corresponds to the given template.
-- Requires the payload to be present and decodable as `t`.
fromCreated : Template t => Created -> Optional (EventId, ContractId t, t)
fromCreated Created {eventId, contractId, argument}
  | Some contractId' <- fromAnyContractId contractId
  , Some argument' <- argument
  , Some argument'' <- fromAnyTemplate argument'
  = Some (eventId, contractId', argument'')
  | otherwise
  = None

-- | The data in an `Archived` event.
data Archived = Archived
  { eventId : EventId
  , contractId : AnyContractId
  } deriving (Show, Eq)

-- | Check if an `Archived` event corresponds to the given template.
fromArchived : Template t => Archived -> Optional (EventId, ContractId t)
fromArchived Archived {eventId, contractId}
  | Some contractId' <- fromAnyContractId contractId
  = Some (eventId, contractId')
  | otherwise
  = None
-- | Either a transaction or a completion.
-- This definition should be kept consistent with the object `MessageVariant` defined in
-- triggers/runner/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Converter.scala
data Message
  = MTransaction Transaction
  | MCompletion Completion
  | MHeartbeat

-- | A completion message.
-- Note that you will only get completions for commands emitted from the trigger.
-- Contrary to the ledger API completion stream, this also includes
-- synchronous failures.
data Completion = Completion
  { commandId : CommandId
  , status : CompletionStatus
  } deriving Show

-- This definition should be kept consistent with the object `CompletionStatusVariant` defined in
-- triggers/runner/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Converter.scala
data CompletionStatus
  = Failed { status : Int, message : Text }
  | Succeeded { transactionId : TransactionId }
  deriving Show

-- Introduced in version 2.6.0
data TriggerConfig = TriggerConfig
  { maxInFlightCommands : Int
    -- ^ maximum number of commands that should be allowed to be in-flight at any point in time.
    -- Exceeding this value will eventually lead to the trigger run raising an InFlightCommandOverflowException exception.
  , maxActiveContracts : Int
    -- ^ maximum number of active contracts that we will allow to be stored
    -- Exceeding this value will lead to the trigger runner raising an ACSOverflowException exception.
  }

-- Introduced in version 2.5.1: this definition is used to simplify future extensions of trigger initialState arguments
data TriggerSetupArguments = TriggerSetupArguments
  { actAs : Party
  , readAs : [Party]
  , acs : ActiveContracts
  , config : TriggerConfig -- added in version 2.6.0
  }
-- | Snapshot of created events for the contracts active at trigger start.
data ActiveContracts = ActiveContracts { activeContracts : [Created] }

-- @WARN use 'BatchTrigger s' instead of 'Trigger s'
data Trigger s = Trigger
  { initialState : Party -> [Party] -> ActiveContracts -> TriggerSetup s
  , update : Message -> TriggerRule s ()
  , registeredTemplates : RegisteredTemplates
  , heartbeat : Optional RelTime
  }

-- | Batching trigger is (approximately) a left-fold over `Message` with
-- an accumulator of type `s`.
data BatchTrigger s = BatchTrigger
  { initialState : TriggerSetupArguments -> TriggerSetup s
  , update : [Message] -> TriggerRule s ()
  , registeredTemplates : RegisteredTemplates
  , heartbeat : Optional RelTime
  }

-- | A template that the trigger will receive events for.
newtype RegisteredTemplate = RegisteredTemplate TemplateTypeRep

-- This controls which templates the trigger will receive events for.
-- `AllInDar` is a safe default but for performance reasons you might
-- want to limit it to limit the templates that the trigger will receive
-- events for.
data RegisteredTemplates
  = AllInDar -- ^ Listen to events for all templates in the given DAR.
  | RegisteredTemplates [RegisteredTemplate]

-- | Register the template `t` so the trigger receives its events.
registeredTemplate : forall t. Template t => RegisteredTemplate
registeredTemplate = RegisteredTemplate (templateTypeRep @t)
-- | A ledger API command. To construct a command use `createCmd` and `exerciseCmd`.
data Command
  = CreateCommand
      { templateArg : AnyTemplate
      }
  | ExerciseCommand
      { contractId : AnyContractId
      , choiceArg : AnyChoice
      }
  | CreateAndExerciseCommand
      { templateArg : AnyTemplate
      , choiceArg : AnyChoice
      }
  | ExerciseByKeyCommand
      { tplTypeRep : TemplateTypeRep
      , contractKey : AnyContractKey
      , choiceArg : AnyChoice
      }

-- | Create a contract of the given template.
createCmd : Template t => t -> Command
createCmd templateArg =
  CreateCommand (toAnyTemplate templateArg)

-- | Exercise the given choice.
exerciseCmd : forall t c r. Choice t c r => ContractId t -> c -> Command
exerciseCmd contractId choiceArg =
  ExerciseCommand (toAnyContractId contractId) (toAnyChoice @t choiceArg)

-- | Create a contract of the given template and immediately exercise
-- the given choice on it.
createAndExerciseCmd : forall t c r. (Template t, Choice t c r) => t -> c -> Command
createAndExerciseCmd templateArg choiceArg =
  CreateAndExerciseCommand (toAnyTemplate templateArg) (toAnyChoice @t choiceArg)

-- | Exercise the given choice on the contract identified by its key.
exerciseByKeyCmd : forall t c r k. (Choice t c r, TemplateKey t k) => k -> c -> Command
exerciseByKeyCmd contractKey choiceArg =
  ExerciseByKeyCommand (templateTypeRep @t) (toAnyContractKey @t contractKey) (toAnyChoice @t choiceArg)

-- | Check if the command corresponds to a create command
-- for the given template.
fromCreate : Template t => Command -> Optional t
fromCreate (CreateCommand t) = fromAnyTemplate t
fromCreate _ = None

-- | Check if the command corresponds to a create and exercise command
-- for the given template.
fromCreateAndExercise : forall t c r. (Template t, Choice t c r) => Command -> Optional (t, c)
fromCreateAndExercise (CreateAndExerciseCommand t c) = (,) <$> fromAnyTemplate t <*> fromAnyChoice @t c
fromCreateAndExercise _ = None

-- | Check if the command corresponds to an exercise command
-- for the given template.
fromExercise : forall t c r. Choice t c r => Command -> Optional (ContractId t, c)
fromExercise (ExerciseCommand cid c) = (,) <$> fromAnyContractId cid <*> fromAnyChoice @t c
fromExercise _ = None

-- | Check if the command corresponds to an exercise by key command
-- for the given template.
fromExerciseByKey : forall t c r k. (Choice t c r, TemplateKey t k) => Command -> Optional (k, c)
fromExerciseByKey (ExerciseByKeyCommand tyRep k c)
  -- Guard falls through to the catch-all equation when the type rep differs.
  | tyRep == templateTypeRep @t = (,) <$> fromAnyContractKey @t k <*> fromAnyChoice @t c
fromExerciseByKey _ = None
-- | A set of commands that are submitted as a single transaction.
data Commands = Commands
  { commandId : CommandId
  , commands : [Command]
  }

-- Minimal state transformer used to thread trigger state through `Free TriggerF`.
newtype StateT s m a = StateT { runStateT : s -> m (a, s) }
  deriving Functor

-- | Lift an action of the underlying functor, leaving the state untouched.
liftStateT : Functor m => m a -> StateT s m a
liftStateT ma = StateT $ \s -> (,s) <$> ma

instance Action m => Applicative (StateT s m) where
  pure a = StateT (\s -> pure (a, s))
  -- Defined via bind so effects sequence left-to-right.
  f <*> x = f >>= (<$> x)

instance Action m => Action (StateT s m) where
  StateT x >>= f = StateT $ \s -> do
    (x', s') <- x s
    runStateT (f x') s'

instance Applicative m => ActionState s (StateT s m) where
  get = StateT $ \s -> pure (s, s)
  put s = StateT $ const $ pure ((), s)
  modify f = StateT $ \s -> pure ((), f s)

-- | Run the stateful action and keep only the final state.
execStateT : Functor m => StateT s m a -> s -> m s
execStateT (StateT fa) = fmap snd . fa

-- | Focus a StateT on a sub-state via a getter `r` and writer `w`.
zoom : Functor m => (t -> s) -> (t -> s -> t) -> StateT s m a -> StateT t m a
zoom r w (StateT smas) = StateT $ \t ->
  smas (r t) <&> \(a, s) -> (a, w t s)
-- | Base functor of the trigger free monad: read the time, or submit
-- commands and receive the allocated command-id text.
-- Must be kept in sync with Runner#freeTriggerSubmits
data TriggerF a =
  GetTime (Time -> a)
  | Submit ([Command], Text -> a)
  deriving Functor

newtype TriggerSetup a = TriggerSetup { runTriggerSetup : Free TriggerF a }
  deriving (Functor, Applicative, Action)

newtype TriggerRule s a = TriggerRule { runTriggerRule : StateT s (Free TriggerF) a }
  deriving (Functor, Applicative, Action)

-- | Run a rule without running it. May lose information from the rule;
-- meant for testing purposes only.
-- GetTime is answered with the supplied time; each Submit is recorded and
-- answered with a fresh numeric command id counting up from "0".
simulateRule : TriggerRule s a -> Time -> s -> (s, [Commands], a)
simulateRule rule time s = (s', reverse cmds, a)
  where ((a, s'), (cmds, _)) = runState (foldFree sim (runStateT (runTriggerRule rule) s)) ([], 0)
        sim : TriggerF x -> State ([Commands], Int) x
        sim (GetTime f) = pure (f time)
        sim (Submit (cmds, f)) = do
          (pastCmds, nextId) <- get
          let nextIdShown = show nextId
          -- Prepend for O(1); the caller reverses to restore submission order.
          put (Commands (CommandId nextIdShown) cmds :: pastCmds, nextId + 1)
          pure $ f nextIdShown

deriving instance ActionState s (TriggerRule s)

-- | Low-level trigger actions.
class HasTime m => ActionTrigger m where
  liftTF : TriggerF a -> m a

instance ActionTrigger TriggerSetup where
  liftTF = TriggerSetup . lift

instance ActionTrigger (TriggerRule s) where
  liftTF = TriggerRule . liftStateT . lift

instance HasTime TriggerSetup where
  getTime = liftTF (GetTime identity)

instance HasTime (TriggerRule s) where
  getTime = liftTF (GetTime identity)

-- | Submit commands and obtain the runner-assigned 'CommandId'.
submitCommands : ActionTrigger m => [Command] -> m CommandId
submitCommands cmds = liftTF (Submit (cmds, CommandId))

View File

@ -1,12 +0,0 @@
-- Hoogle documentation for Daml, generated by damlc
-- See Hoogle, http://www.haskell.org/hoogle/
-- Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates.
-- All rights reserved. Any unauthorized use, duplication or distribution is strictly prohibited.
-- | Daml Trigger library.
@url {{base-url}}
@package daml-trigger
@version 1.2.0
{{{body}}}

View File

@ -1,11 +0,0 @@
.. Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
.. SPDX-License-Identifier: Apache-2.0
.. _daml-trigger-api-docs:
Daml Trigger Library
====================
The Daml Trigger library defines the API used to declare a Daml trigger. See :doc:`/triggers/index` for more information on Daml triggers.
{{{body}}}

View File

@ -1,4 +0,0 @@
.. Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
.. SPDX-License-Identifier: Apache-2.0
{{{body}}}

View File

@ -1,23 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load(
"//bazel_tools:scala.bzl",
"da_scala_library",
)
# Metrics library for the trigger service, published to Maven as
# com.daml:trigger-metrics. Built purely on the observability layers;
# no Scala third-party deps of its own.
da_scala_library(
    name = "metrics",
    srcs = glob(["src/main/scala/**/*.scala"]),
    scala_deps = [],
    tags = ["maven_coordinates=com.daml:trigger-metrics:__VERSION__"],
    visibility = [
        "//visibility:public",
    ],
    runtime_deps = [],
    deps = [
        "//observability/metrics",
        "//observability/pekko-http-metrics",
        "@maven//:io_opentelemetry_opentelemetry_api",
    ],
)

View File

@ -1,16 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger.metrics
import com.daml.metrics.api.opentelemetry.OpenTelemetryMetricsFactory
import com.daml.metrics.http.DamlHttpMetrics
import io.opentelemetry.api.metrics.{Meter => OtelMeter}
// Metrics container for the trigger service: wraps an OpenTelemetry Meter in
// Daml's metrics-factory abstraction and exposes HTTP request metrics for the
// "trigger_service" component.
case class TriggerServiceMetrics(otelMeter: OtelMeter) {
  // Factory through which all trigger-service metrics are created.
  val openTelemetryFactory = new OpenTelemetryMetricsFactory(otelMeter)
  // Standard Daml HTTP metrics, labelled with the service name.
  val http = new DamlHttpMetrics(openTelemetryFactory, "trigger_service")
}

View File

@ -1,63 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load(
"//bazel_tools:scala.bzl",
"da_scala_binary",
"da_scala_library",
"lf_scalacopts_stricter",
)
# Library containing the trigger interpreter/runner implementation; published
# to Maven as com.daml:trigger-runner.
da_scala_library(
    name = "trigger-runner-lib",
    srcs = glob(["src/main/scala/**/*.scala"]),
    scala_deps = [
        "@maven//:com_github_scopt_scopt",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_scalaz_scalaz_core",
        "@maven//:org_typelevel_paiges_core",
    ],
    scalacopts = lf_scalacopts_stricter,
    tags = ["maven_coordinates=com.daml:trigger-runner:__VERSION__"],
    visibility = ["//visibility:public"],
    deps = [
        "//canton:ledger_api_proto_scala",
        "//daml-lf/archive:daml_lf_archive_reader",
        "//daml-lf/data",
        "//daml-lf/engine",
        "//daml-lf/interpreter",
        "//daml-lf/language",
        "//daml-lf/transaction",
        "//daml-script/converter",
        "//ledger-service/cli-opts",
        "//ledger/ledger-api-client",
        "//ledger/ledger-api-common",
        "//ledger/ledger-api-domain",
        "//libs-scala/auth-utils",
        "//libs-scala/contextualized-logging",
        "//libs-scala/logging-entries",
        "//libs-scala/rs-grpc-bridge",
        "//libs-scala/rs-grpc-pekko",
        "//libs-scala/scala-utils",
        "//observability/tracing",
        "@maven//:ch_qos_logback_logback_classic",
        "@maven//:io_grpc_grpc_api",
        "@maven//:io_netty_netty_handler",
        "@maven//:org_slf4j_slf4j_api",
    ],
)

# Standalone CLI entry point wrapping the library above.
da_scala_binary(
    name = "trigger-runner",
    main_class = "com.daml.lf.engine.trigger.RunnerMain",
    resources = ["src/main/resources/logback.xml"],
    scalacopts = lf_scalacopts_stricter,
    tags = ["ee-jar-license"],
    visibility = ["//visibility:public"],
    deps = [
        ":trigger-runner-lib",
    ],
)

# Allow other packages to reuse this logging configuration file.
exports_files(["src/main/resources/logback.xml"])

View File

@ -1,30 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Logback configuration for the trigger runner. -->
<configuration>
  <!-- Console appender: Logstash JSON output when the LOG_FORMAT_JSON
       environment variable is defined, otherwise a human-readable pattern. -->
  <appender name="console" class="ch.qos.logback.core.ConsoleAppender">
    <if condition='isDefined("LOG_FORMAT_JSON")'>
      <then>
        <encoder class="net.logstash.logback.encoder.LogstashEncoder"/>
      </then>
      <else>
        <encoder>
          <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg %replace(, context: %marker){', context: $', ''} %n</pattern>
        </encoder>
      </else>
    </if>
  </appender>
  <!-- Asynchronous wrapper so logging does not block caller threads. -->
  <appender name="STDOUT" class="net.logstash.logback.appender.LoggingEventAsyncDisruptorAppender">
    <appender-ref ref="console" />
  </appender>
  <!-- Quieten noisy networking/runtime libraries. -->
  <logger name="io.netty" level="WARN" />
  <logger name="io.grpc.netty" level="WARN" />
  <logger name="pekko.event.slf4j" level="WARN" />
  <!-- Trigger-specific loggers kept verbose. -->
  <logger name="daml.tracelog" level="DEBUG" />
  <logger name="daml.warnings" level="WARN" />
  <logger name="com.daml.lf.engine.trigger" level="DEBUG" />
  <!-- Root level defaults to INFO; overridable via LOG_LEVEL_ROOT. -->
  <root level="${LOG_LEVEL_ROOT:-INFO}">
    <appender-ref ref="STDOUT" />
  </root>
</configuration>

View File

@ -1,692 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf
package engine
package trigger
import scalaz.std.either._
import scalaz.std.list._
import scalaz.std.option._
import scalaz.syntax.traverse._
import com.daml.lf.data.{BackStack, FrontStack, ImmArray}
import com.daml.lf.data.Ref._
import com.daml.lf.language.Ast._
import com.daml.lf.speedy.{ArrayList, SValue}
import com.daml.lf.speedy.SValue._
import com.daml.lf.value.Value.ContractId
import com.daml.ledger.api.v1.commands.{
CreateAndExerciseCommand,
CreateCommand,
ExerciseByKeyCommand,
ExerciseCommand,
Command => ApiCommand,
}
import com.daml.ledger.api.v1.completion.Completion
import com.daml.ledger.api.v1.event.{ArchivedEvent, CreatedEvent, Event, InterfaceView}
import com.daml.ledger.api.v1.transaction.Transaction
import com.daml.ledger.api.v1.value
import com.daml.ledger.api.validation.NoLoggingValueValidator
import com.daml.lf.language.LanguageVersionRangeOps.LanguageVersionRange
import com.daml.lf.language.StablePackages
import com.daml.lf.speedy.Command
import com.daml.lf.value.Value
import com.daml.platform.participant.util.LfEngineToApi.{
lfValueToApiRecord,
lfValueToApiValue,
toApiIdentifier,
}
import com.daml.script.converter.ConverterException
import com.daml.script.converter.Converter._
import com.daml.script.converter.Converter.Implicits._
import scala.collection.mutable
import scala.concurrent.duration.{FiniteDuration, MICROSECONDS}
// Convert from a Ledger API transaction to an SValue corresponding to a Message from the Daml.Trigger module
final class Converter(
    compiledPackages: CompiledPackages,
    triggerDef: TriggerDefinition,
) {
  import Converter._

  // Translates LF values into speedy SValues against the package interface.
  private[this] val valueTranslator = new preprocessing.ValueTranslator(
    compiledPackages.pkgInterface,
    requireV1ContractIdSuffix = false,
  )

  // Validates an API record into an LF ValueRecord, rendering errors as strings.
  private[this] def validateRecord(r: value.Record): Either[String, Value.ValueRecord] =
    NoLoggingValueValidator.validateRecord(r).left.map(_.getMessage)

  // Translates an LF value to an SValue of the given type, rendering errors as strings.
  private[this] def translateValue(ty: Type, value: Value): Either[String, SValue] =
    valueTranslator
      .strictTranslateValue(ty, value)
      .left
      .map(res => s"Failure to translate value: $res")

  private[this] val triggerIds: TriggerIds = triggerDef.triggerIds

  private[this] val stablePackages = StablePackages(
    compiledPackages.compilerConfig.allowedLanguageVersions.majorVersion
  )

  // Type constructors from the stable packages.
  private[this] val templateTypeRepTyCon = stablePackages.TemplateTypeRep
  private[this] val anyTemplateTyCon = stablePackages.AnyTemplate
  private[this] val anyViewTyCon = stablePackages.AnyView

  // Type constructors of the Daml-side types, resolved in the trigger package.
  private[this] val activeContractsTy = triggerIds.damlTriggerLowLevel("ActiveContracts")
  private[this] val triggerConfigTy = triggerIds.damlTriggerLowLevel("TriggerConfig")
  private[this] val triggerSetupArgumentsTy =
    triggerIds.damlTriggerLowLevel("TriggerSetupArguments")
  private[this] val triggerStateTy = triggerIds.damlTriggerInternal("TriggerState")
  private[this] val anyContractIdTy = triggerIds.damlTriggerLowLevel("AnyContractId")
  private[this] val archivedTy = triggerIds.damlTriggerLowLevel("Archived")
  private[this] val commandIdTy = triggerIds.damlTriggerLowLevel("CommandId")
  private[this] val completionTy = triggerIds.damlTriggerLowLevel("Completion")
  private[this] val createdTy = triggerIds.damlTriggerLowLevel("Created")
  private[this] val eventIdTy = triggerIds.damlTriggerLowLevel("EventId")
  private[this] val eventTy = triggerIds.damlTriggerLowLevel("Event")
  private[this] val failedTy = triggerIds.damlTriggerLowLevel("CompletionStatus.Failed")
  private[this] val messageTy = triggerIds.damlTriggerLowLevel("Message")
  private[this] val succeedTy = triggerIds.damlTriggerLowLevel("CompletionStatus.Succeeded")
  private[this] val transactionIdTy = triggerIds.damlTriggerLowLevel("TransactionId")
  private[this] val transactionTy = triggerIds.damlTriggerLowLevel("Transaction")
  private[this] val createCommandTy = triggerIds.damlTriggerLowLevel("CreateCommand")
  private[this] val exerciseCommandTy = triggerIds.damlTriggerLowLevel("ExerciseCommand")
  private[this] val createAndExerciseCommandTy =
    triggerIds.damlTriggerLowLevel("CreateAndExerciseCommand")
  private[this] val exerciseByKeyCommandTy = triggerIds.damlTriggerLowLevel("ExerciseByKeyCommand")
  private[this] val acsTy = triggerIds.damlTriggerInternal("ACS")

  // Wraps an API identifier into a Daml `TemplateTypeRep` record (an STypeRep).
  private[this] def fromTemplateTypeRep(tyCon: value.Identifier): SValue =
    record(
      templateTypeRepTyCon,
      "getTemplateTypeRep" -> STypeRep(
        TTyCon(
          TypeConName(
            PackageId.assertFromString(tyCon.packageId),
            QualifiedName(
              DottedName.assertFromString(tyCon.moduleName),
              DottedName.assertFromString(tyCon.entityName),
            ),
          )
        )
      ),
    )

  // Wraps a transaction id string into the Daml `TransactionId` record.
  private[this] def fromTransactionId(transactionId: String): SValue =
    record(transactionIdTy, "unpack" -> SText(transactionId))

  private[this] def fromEventId(eventId: String): SValue =
    record(eventIdTy, "unpack" -> SText(eventId))

  private[trigger] def fromCommandId(commandId: String): SValue =
    record(commandIdTy, "unpack" -> SText(commandId))

  // The ledger API encodes "no command id" as the empty string; map it to None.
  private[this] def fromOptionalCommandId(commandId: String): SValue =
    if (commandId.isEmpty)
      SOptional(None)
    else
      SOptional(Some(fromCommandId(commandId)))

  // Builds the Daml `AnyContractId` record from a template id and contract id.
  private[this] def fromAnyContractId(
      templateId: value.Identifier,
      contractId: String,
  ): SValue =
    record(
      anyContractIdTy,
      "templateId" -> fromTemplateTypeRep(templateId),
      "contractId" -> SContractId(ContractId.assertFromString(contractId)),
    )

  private def fromArchivedEvent(archived: ArchivedEvent): SValue =
    record(
      archivedTy,
      "eventId" -> fromEventId(archived.eventId),
      "contractId" -> fromAnyContractId(archived.getTemplateId, archived.contractId),
    )

  // Validates and translates an API record into an SValue of the given type.
  private[this] def fromRecord(typ: Type, record: value.Record): Either[String, SValue] =
    for {
      record <- validateRecord(record)
      tmplPayload <- translateValue(typ, record)
    } yield tmplPayload

  private[this] def fromAnyTemplate(typ: Type, value: SValue) =
    record(anyTemplateTyCon, "getAnyTemplate" -> SAny(typ, value))

  // `Created` encoding used for trigger-library versions below 2.5.0:
  // the template payload is required and there are no interface views.
  private[this] def fromV20CreatedEvent(
      created: CreatedEvent
  ): Either[String, SValue] =
    for {
      tmplId <- fromIdentifier(created.getTemplateId)
      tmplType = TTyCon(tmplId)
      tmplPayload <- fromRecord(tmplType, created.getCreateArguments)
    } yield {
      record(
        createdTy,
        "eventId" -> fromEventId(created.eventId),
        "contractId" -> fromAnyContractId(created.getTemplateId, created.contractId),
        "argument" -> fromAnyTemplate(tmplType, tmplPayload),
      )
    }

  private[this] def fromAnyView(typ: Type, value: SValue) =
    record(anyViewTyCon, "getAnyView" -> SAny(typ, value))

  // Converts an interface view, looking up the view's type from the
  // interface definition; the result is None when no view value is present.
  private[this] def fromInterfaceView(view: InterfaceView): Either[String, SOptional] =
    for {
      ifaceId <- fromIdentifier(view.getInterfaceId)
      iface <- compiledPackages.pkgInterface.lookupInterface(ifaceId).left.map(_.pretty)
      viewType = iface.view
      viewValue <- view.viewValue.traverseU(fromRecord(viewType, _))
    } yield SOptional(viewValue.map(fromAnyView(viewType, _)))

  // `Created` encoding for trigger-library versions 2.5.0 and up:
  // optional template payload plus the list of interface views.
  private[this] def fromV250CreatedEvent(
      created: CreatedEvent
  ): Either[String, SValue] =
    for {
      tmplId <- fromIdentifier(created.getTemplateId)
      tmplType = TTyCon(tmplId)
      tmplPayload <- created.createArguments.traverseU(fromRecord(tmplType, _))
      views <- created.interfaceViews.toList.traverseU(fromInterfaceView)
    } yield {
      record(
        createdTy,
        "eventId" -> fromEventId(created.eventId),
        "contractId" -> fromAnyContractId(created.getTemplateId, created.contractId),
        "argument" -> SOptional(tmplPayload.map(fromAnyTemplate(tmplType, _))),
        "views" -> SList(views.to(FrontStack)),
      )
    }

  // Select the `Created` encoding matching the trigger's library version.
  private[this] val fromCreatedEvent: CreatedEvent => Either[String, SValue] =
    if (triggerDef.version < Trigger.Version.`2.5.0`) {
      fromV20CreatedEvent
    } else {
      fromV250CreatedEvent
    }

  // Converts a ledger event into the Daml `Event` variant; only Created and
  // Archived events are supported.
  private def fromEvent(ev: Event): Either[String, SValue] =
    ev.event match {
      case Event.Event.Archived(archivedEvent) =>
        Right(
          SVariant(
            id = eventTy,
            variant = EventVariant.ArchiveEventConstructor,
            constructorRank = EventVariant.ArchiveEventConstructorRank,
            value = fromArchivedEvent(archivedEvent),
          )
        )
      case Event.Event.Created(createdEvent) =>
        for {
          event <- fromCreatedEvent(createdEvent)
        } yield SVariant(
          id = eventTy,
          variant = EventVariant.CreatedEventConstructor,
          constructorRank = EventVariant.CreatedEventConstructorRank,
          value = event,
        )
      case _ => Left(s"Expected Archived or Created but got ${ev.event}")
    }

  // Converts a ledger transaction into the `MTransaction` message variant.
  def fromTransaction(t: Transaction): Either[String, SValue] =
    for {
      events <- t.events.to(ImmArray).traverse(fromEvent).map(xs => SList(FrontStack.from(xs)))
      transactionId = fromTransactionId(t.transactionId)
      commandId = fromOptionalCommandId(t.commandId)
    } yield SVariant(
      id = messageTy,
      variant = MessageVariant.MTransactionVariant,
      constructorRank = MessageVariant.MTransactionVariantRank,
      value = record(
        transactionTy,
        "transactionId" -> transactionId,
        "commandId" -> commandId,
        "events" -> events,
      ),
    )

  // Converts a command completion into the `MCompletion` message variant.
  // A status code of 0 (gRPC OK) maps to Succeeded, anything else to Failed.
  def fromCompletion(c: Completion): Either[String, SValue] = {
    val status: SValue =
      if (c.getStatus.code == 0)
        SVariant(
          triggerIds.damlTriggerLowLevel("CompletionStatus"),
          CompletionStatusVariant.SucceedVariantConstructor,
          CompletionStatusVariant.SucceedVariantConstructorRank,
          record(
            succeedTy,
            "transactionId" -> fromTransactionId(c.transactionId),
          ),
        )
      else
        SVariant(
          triggerIds.damlTriggerLowLevel("CompletionStatus"),
          CompletionStatusVariant.FailVariantConstructor,
          CompletionStatusVariant.FailVariantConstructorRank,
          record(
            failedTy,
            // NOTE(review): Int-to-Long conversion via asInstanceOf; `.toLong`
            // would be the idiomatic spelling.
            "status" -> SInt64(c.getStatus.code.asInstanceOf[Long]),
            "message" -> SText(c.getStatus.message),
          ),
        )
    Right(
      SVariant(
        messageTy,
        MessageVariant.MCompletionConstructor,
        MessageVariant.MCompletionConstructorRank,
        record(
          completionTy,
          "commandId" -> fromCommandId(c.commandId),
          "status" -> status,
        ),
      )
    )
  }

  // The heartbeat message carries no payload.
  def fromHeartbeat: SValue =
    SVariant(
      messageTy,
      MessageVariant.MHeartbeatConstructor,
      MessageVariant.MHeartbeatConstructorRank,
      SUnit,
    )

  // Wraps the initial set of created events into an `ActiveContracts` record.
  def fromActiveContracts(createdEvents: Seq[CreatedEvent]): Either[String, SValue] =
    for {
      events <- createdEvents
        .to(ImmArray)
        .traverse(fromCreatedEvent)
        .map(xs => SList(FrontStack.from(xs)))
    } yield record(activeContractsTy, "activeContracts" -> events)

  // Exposes the runner limits relevant to Daml code as a `TriggerConfig` record.
  private[this] def fromTriggerConfig(triggerConfig: TriggerRunnerConfig): SValue =
    record(
      triggerConfigTy,
      "maxInFlightCommands" -> SValue.SInt64(triggerConfig.inFlightCommandBackPressureCount),
      "maxActiveContracts" -> SValue.SInt64(triggerConfig.hardLimit.maximumActiveContracts),
    )

  // Builds the `TriggerSetupArguments` record for trigger initialization:
  // parties, initial ACS and runner configuration.
  def fromTriggerSetupArguments(
      parties: TriggerParties,
      createdEvents: Seq[CreatedEvent],
      triggerConfig: TriggerRunnerConfig,
  ): Either[String, SValue] =
    for {
      acs <- fromActiveContracts(createdEvents)
      actAs = SParty(parties.actAs)
      readAs = SList(parties.readAs.view.map(SParty).to(FrontStack))
      config = fromTriggerConfig(triggerConfig)
    } yield record(
      triggerSetupArgumentsTy,
      "actAs" -> actAs,
      "readAs" -> readAs,
      "acs" -> acs,
      "config" -> config,
    )

  // Converts a speedy command back into its Daml-side record representation.
  // Throws ConverterException on command types triggers never emit.
  def fromCommand(command: Command): SValue = command match {
    case Command.Create(templateId, argument) =>
      record(
        createCommandTy,
        "templateArg" -> fromAnyTemplate(
          TTyCon(templateId),
          argument,
        ),
      )
    case Command.ExerciseTemplate(_, contractId, _, choiceArg) =>
      record(
        exerciseCommandTy,
        "contractId" -> contractId,
        "choiceArg" -> choiceArg,
      )
    case Command.ExerciseInterface(_, contractId, _, choiceArg) =>
      record(
        exerciseCommandTy,
        "contractId" -> contractId,
        "choiceArg" -> choiceArg,
      )
    case Command.CreateAndExercise(templateId, createArg, _, choiceArg) =>
      record(
        createAndExerciseCommandTy,
        "templateArg" -> fromAnyTemplate(TTyCon(templateId), createArg),
        "choiceArg" -> choiceArg,
      )
    case Command.ExerciseByKey(templateId, contractKey, _, choiceArg) =>
      record(
        exerciseByKeyCommandTy,
        "tplTypeRep" -> SValue.STypeRep(TTyCon(templateId)),
        "contractKey" -> contractKey,
        "choiceArg" -> choiceArg,
      )
    case _ =>
      throw new ConverterException(
        s"${command.getClass.getSimpleName} is an unexpected command type"
      )
  }

  def fromCommands(commands: Seq[Command]): SValue = {
    SList(commands.map(fromCommand).to(FrontStack))
  }

  // Builds the Daml-side `ACS` record: active contracts grouped by template
  // id, with an initially empty pending-contracts map.
  def fromACS(activeContracts: Seq[CreatedEvent]): SValue = {
    // Group creates by template id, preserving event order within each group.
    val createMapByTemplateId = mutable.HashMap.empty[value.Identifier, BackStack[CreatedEvent]]
    for (create <- activeContracts) {
      createMapByTemplateId += (create.getTemplateId -> (createMapByTemplateId.getOrElse(
        create.getTemplateId,
        BackStack.empty,
      ) :+ create))
    }
    record(
      acsTy,
      "activeContracts" -> SMap(
        isTextMap = false,
        createMapByTemplateId.iterator.map { case (templateId, creates) =>
          fromTemplateTypeRep(templateId) -> SMap(
            isTextMap = false,
            creates.reverseIterator.map { event =>
              val templateType = TTyCon(fromIdentifier(templateId).orConverterException)
              val template = fromAnyTemplate(
                templateType,
                fromRecord(templateType, event.getCreateArguments).orConverterException,
              )
              fromAnyContractId(templateId, event.contractId) -> template
            },
          )
        },
      ),
      "pendingContracts" -> SMap(isTextMap = false),
    )
  }

  // Builds the internal `TriggerState` record used when evaluating the
  // trigger's update function.
  def fromTriggerUpdateState(
      createdEvents: Seq[CreatedEvent],
      userState: SValue,
      commandsInFlight: Map[String, Seq[Command]] = Map.empty,
      parties: TriggerParties,
      triggerConfig: TriggerRunnerConfig,
  ): SValue = {
    val acs = fromACS(createdEvents)
    val actAs = SParty(parties.actAs)
    val readAs = SList(parties.readAs.map(SParty).to(FrontStack))
    val config = fromTriggerConfig(triggerConfig)
    record(
      triggerStateTy,
      "acs" -> acs,
      "actAs" -> actAs,
      "readAs" -> readAs,
      "userState" -> userState,
      "commandsInFlight" -> SMap(
        isTextMap = false,
        commandsInFlight.iterator.map { case (cmdId, cmds) =>
          (fromCommandId(cmdId), fromCommands(cmds))
        },
      ),
      "config" -> config,
    )
  }
}
// Companion: ledger-API-to-SValue helpers that need no per-trigger state.
object Converter {
  // A template/interface identifier paired with a contract id.
  final case class AnyContractId(templateId: Identifier, contractId: ContractId)
  // A template identifier paired with its payload value.
  final case class AnyTemplate(ty: Identifier, arg: SValue)
  private final case class AnyChoice(name: ChoiceName, arg: SValue)
  private final case class AnyContractKey(key: SValue)

  object EventVariant {
    // Those values should be kept consistent with type `Event` defined in
    // triggers/daml/Daml/Trigger/LowLevel.daml
    val CreatedEventConstructor = Name.assertFromString("CreatedEvent")
    val CreatedEventConstructorRank = 0
    val ArchiveEventConstructor = Name.assertFromString("ArchivedEvent")
    val ArchiveEventConstructorRank = 1
  }
  object MessageVariant {
    // Those values should be kept consistent with type `Message` defined in
    // triggers/daml/Daml/Trigger/LowLevel.daml
    val MTransactionVariant = Name.assertFromString("MTransaction")
    val MTransactionVariantRank = 0
    val MCompletionConstructor = Name.assertFromString("MCompletion")
    val MCompletionConstructorRank = 1
    val MHeartbeatConstructor = Name.assertFromString("MHeartbeat")
    val MHeartbeatConstructorRank = 2
  }
  object CompletionStatusVariant {
    // Those values should be kept consistent with `CompletionStatus` defined in
    // triggers/daml/Daml/Trigger/LowLevel.daml
    val FailVariantConstructor = Name.assertFromString("Failed")
    val FailVariantConstructorRank = 0
    val SucceedVariantConstructor = Name.assertFromString("Succeeded")
    val SucceedVariantConstructorRank = 1
  }

  // Parses an API identifier into an LF identifier, validating each part.
  def fromIdentifier(identifier: value.Identifier): Either[String, Identifier] =
    for {
      pkgId <- PackageId.fromString(identifier.packageId)
      mod <- DottedName.fromString(identifier.moduleName)
      name <- DottedName.fromString(identifier.entityName)
    } yield Identifier(pkgId, QualifiedName(mod, name))

  private def toLedgerRecord(v: SValue): Either[String, value.Record] =
    lfValueToApiRecord(verbose = true, v.toUnnormalizedValue)
  private def toLedgerValue(v: SValue): Either[String, value.Value] =
    lfValueToApiValue(verbose = true, v.toUnnormalizedValue)

  // Extracts the type-constructor identifier out of an STypeRep.
  private def toIdentifier(v: SValue): Either[String, Identifier] =
    v.expect(
      "STypeRep",
      { case STypeRep(TTyCon(id)) =>
        id
      },
    )

  // Decodes a Daml `AnyContractId` record (two fields: type rep, contract id).
  def toAnyContractId(v: SValue): Either[String, AnyContractId] =
    v.expectE(
      "AnyContractId",
      { case SRecord(_, _, ArrayList(stid, scid)) =>
        for {
          templateId <- toTemplateTypeRep(stid)
          contractId <- toContractId(scid)
        } yield AnyContractId(templateId, contractId)
      },
    )

  private def toTemplateTypeRep(v: SValue): Either[String, Identifier] =
    v.expectE(
      "TemplateTypeRep",
      { case SRecord(_, _, ArrayList(id)) =>
        toIdentifier(id)
      },
    )

  // Converts a Daml `RelTime` (stored in microseconds) into a FiniteDuration.
  def toFiniteDuration(value: SValue): Either[String, FiniteDuration] =
    value.expect(
      "RelTime",
      { case SRecord(_, _, ArrayList(SInt64(microseconds))) =>
        FiniteDuration(microseconds, MICROSECONDS)
      },
    )

  private def toRegisteredTemplate(v: SValue): Either[String, Identifier] =
    v.expectE(
      "RegisteredTemplate",
      { case SRecord(_, _, ArrayList(sttr)) =>
        toTemplateTypeRep(sttr)
      },
    )

  def toRegisteredTemplates(v: SValue): Either[String, Seq[Identifier]] =
    v.expectE(
      "list of RegisteredTemplate",
      { case SList(tpls) =>
        tpls.traverse(toRegisteredTemplate).map(_.toImmArray.toSeq)
      },
    )

  // Decodes a Daml `AnyTemplate` record into its identifier and payload.
  def toAnyTemplate(v: SValue): Either[String, AnyTemplate] =
    v match {
      case SRecord(_, _, ArrayList(SAny(TTyCon(tmplId), value))) =>
        Right(AnyTemplate(tmplId, value))
      case _ => Left(s"Expected AnyTemplate but got $v")
    }

  private def choiceArgTypeToChoiceName(choiceCons: TypeConName) = {
    // This exploits the fact that in Daml, choice argument type names
    // and choice names match up.
    assert(choiceCons.qualifiedName.name.segments.length == 1)
    choiceCons.qualifiedName.name.segments.head
  }

  private def toAnyChoice(v: SValue): Either[String, AnyChoice] =
    v match {
      case SRecord(_, _, ArrayList(SAny(TTyCon(choiceCons), choiceVal), _)) =>
        Right(AnyChoice(choiceArgTypeToChoiceName(choiceCons), choiceVal))
      case _ =>
        Left(s"Expected AnyChoice but got $v")
    }

  private def toAnyContractKey(v: SValue): Either[String, AnyContractKey] =
    v.expect(
      "AnyContractKey",
      { case SRecord(_, _, ArrayList(SAny(_, v), _)) =>
        AnyContractKey(v)
      },
    )

  // Decodes a Daml `CreateCommand` record into a ledger API CreateCommand.
  private def toCreate(v: SValue): Either[String, CreateCommand] =
    v.expectE(
      "CreateCommand",
      { case SRecord(_, _, ArrayList(sTpl)) =>
        for {
          anyTmpl <- toAnyTemplate(sTpl)
          templateArg <- toLedgerRecord(anyTmpl.arg)
        } yield CreateCommand(Some(toApiIdentifier(anyTmpl.ty)), Some(templateArg))
      },
    )

  // Decodes a Daml `ExerciseCommand` record into a ledger API ExerciseCommand.
  private def toExercise(v: SValue): Either[String, ExerciseCommand] =
    v.expectE(
      "ExerciseCommand",
      { case SRecord(_, _, ArrayList(sAnyContractId, sChoiceVal)) =>
        for {
          anyContractId <- toAnyContractId(sAnyContractId)
          anyChoice <- toAnyChoice(sChoiceVal)
          choiceArg <- toLedgerValue(anyChoice.arg)
        } yield ExerciseCommand(
          Some(toApiIdentifier(anyContractId.templateId)),
          anyContractId.contractId.coid,
          anyChoice.name,
          Some(choiceArg),
        )
      },
    )

  private def toExerciseByKey(v: SValue): Either[String, ExerciseByKeyCommand] =
    v.expectE(
      "ExerciseByKeyCommand",
      { case SRecord(_, _, ArrayList(stplId, skeyVal, sChoiceVal)) =>
        for {
          tplId <- toTemplateTypeRep(stplId)
          keyVal <- toAnyContractKey(skeyVal)
          keyArg <- toLedgerValue(keyVal.key)
          anyChoice <- toAnyChoice(sChoiceVal)
          choiceArg <- toLedgerValue(anyChoice.arg)
        } yield ExerciseByKeyCommand(
          Some(toApiIdentifier(tplId)),
          Some(keyArg),
          anyChoice.name,
          Some(choiceArg),
        )
      },
    )

  private def toCreateAndExercise(v: SValue): Either[String, CreateAndExerciseCommand] =
    v.expectE(
      "CreateAndExerciseCommand",
      { case SRecord(_, _, ArrayList(sTpl, sChoiceVal)) =>
        for {
          anyTmpl <- toAnyTemplate(sTpl)
          templateArg <- toLedgerRecord(anyTmpl.arg)
          anyChoice <- toAnyChoice(sChoiceVal)
          choiceArg <- toLedgerValue(anyChoice.arg)
        } yield CreateAndExerciseCommand(
          Some(toApiIdentifier(anyTmpl.ty)),
          Some(templateArg),
          anyChoice.name,
          Some(choiceArg),
        )
      },
    )

  // Dispatches on the Daml `Command` variant constructor name to produce the
  // corresponding ledger API command.
  private def toCommand(v: SValue): Either[String, ApiCommand] = {
    v match {
      case SVariant(_, "CreateCommand", _, createVal) =>
        for {
          create <- toCreate(createVal)
        } yield ApiCommand().withCreate(create)
      case SVariant(_, "ExerciseCommand", _, exerciseVal) =>
        for {
          exercise <- toExercise(exerciseVal)
        } yield ApiCommand().withExercise(exercise)
      case SVariant(_, "ExerciseByKeyCommand", _, exerciseByKeyVal) =>
        for {
          exerciseByKey <- toExerciseByKey(exerciseByKeyVal)
        } yield ApiCommand().withExerciseByKey(exerciseByKey)
      case SVariant(_, "CreateAndExerciseCommand", _, createAndExerciseVal) =>
        for {
          createAndExercise <- toCreateAndExercise(createAndExerciseVal)
        } yield ApiCommand().withCreateAndExercise(createAndExercise)
      case _ => Left(s"Expected a Command but got $v")
    }
  }

  // Decodes a Daml `[Command]` list into a sequence of ledger API commands.
  def toCommands(v: SValue): Either[String, Seq[ApiCommand]] =
    for {
      cmdValues <- v.expect(
        "[Command]",
        { case SList(cmdValues) =>
          cmdValues
        },
      )
      commands <- cmdValues.traverse(toCommand)
    } yield commands.toImmArray.toSeq
}
// Helper to create identifiers pointing to the Daml.Trigger module
final case class TriggerIds(triggerPackageId: PackageId) {

  // Builds an identifier for `entity` inside `module` of the trigger package.
  private[this] def identifierIn(module: String, entity: String): Identifier =
    Identifier(
      triggerPackageId,
      QualifiedName(
        ModuleName.assertFromString(module),
        DottedName.assertFromString(entity),
      ),
    )

  // Identifier in the `Daml.Trigger` module.
  def damlTrigger(s: String): Identifier =
    identifierIn("Daml.Trigger", s)

  // Identifier in the `Daml.Trigger.LowLevel` module.
  def damlTriggerLowLevel(s: String): Identifier =
    identifierIn("Daml.Trigger.LowLevel", s)

  // Identifier in the `Daml.Trigger.Internal` module.
  def damlTriggerInternal(s: String): Identifier =
    identifierIn("Daml.Trigger.Internal", s)
}

View File

@ -1,193 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger
import com.daml.ledger.api.v1.{value => api}
import com.daml.lf.speedy.{Pretty, SValue}
import org.typelevel.paiges.Doc
import org.typelevel.paiges.Doc.{char, fill, intercalate, str, text}
import scala.jdk.CollectionConverters._
// Pretty-printers for ledger API values and speedy SValues, used for
// human-readable trigger logging output.
object PrettyPrint {

  // Renders an API identifier as `Module:Entity@packageId`.
  def prettyApiIdentifier(id: api.Identifier): Doc =
    text(id.moduleName) + char(':') + text(id.entityName) + char('@') + text(id.packageId)

  // Pretty-prints a ledger API value. `verbose` includes type identifiers;
  // `maxListWidth`, when set, truncates lists longer than that many elements.
  def prettyApiValue(verbose: Boolean, maxListWidth: Option[Int] = None)(v: api.Value): Doc =
    v.sum match {
      case api.Value.Sum.Empty => Doc.empty
      case api.Value.Sum.Int64(i) => str(i)
      case api.Value.Sum.Numeric(d) => text(d)
      case api.Value.Sum.Record(api.Record(mbId, fs)) =>
        (mbId match {
          case Some(id) if verbose => prettyApiIdentifier(id)
          case _ => Doc.empty
        }) +
          char('{') &
          fill(
            text(", "),
            fs.toList.map {
              case api.RecordField(k, Some(v)) =>
                text(k) & char('=') & prettyApiValue(verbose = true, maxListWidth)(v)
              case _ => Doc.empty
            },
          ) &
          char('}')
      case api.Value.Sum.Variant(api.Variant(mbId, variant, value)) =>
        (mbId match {
          case Some(id) if verbose => prettyApiIdentifier(id) + char(':')
          case _ => Doc.empty
        }) +
          text(variant) + char('(') + value.fold(Doc.empty)(v =>
            prettyApiValue(verbose = true, maxListWidth)(v)
          ) + char(')')
      case api.Value.Sum.Enum(api.Enum(mbId, constructor)) =>
        (mbId match {
          case Some(id) if verbose => prettyApiIdentifier(id) + char(':')
          case _ => Doc.empty
        }) + text(constructor)
      case api.Value.Sum.Text(t) => char('"') + text(t) + char('"')
      case api.Value.Sum.ContractId(acoid) => text(acoid)
      case api.Value.Sum.Unit(_) => text("<unit>")
      case api.Value.Sum.Bool(b) => str(b)
      case api.Value.Sum.List(api.List(lst)) =>
        // Lists beyond maxListWidth are elided with a truncation marker.
        maxListWidth match {
          case Some(maxWidth) if lst.size > maxWidth =>
            char('[') + intercalate(
              text(", "),
              lst.take(maxWidth).map(prettyApiValue(verbose = true, maxListWidth)(_)),
            ) + text(s", ...${lst.size - maxWidth} elements truncated...") + char(']')
          case _ =>
            char('[') + intercalate(
              text(", "),
              lst.map(prettyApiValue(verbose = true, maxListWidth)(_)),
            ) + char(']')
        }
      case api.Value.Sum.Timestamp(t) => str(t)
      case api.Value.Sum.Date(days) => str(days)
      case api.Value.Sum.Party(p) => char('\'') + str(p) + char('\'')
      case api.Value.Sum.Optional(api.Optional(Some(v1))) =>
        text("Option(") + prettyApiValue(verbose, maxListWidth)(v1) + char(')')
      case api.Value.Sum.Optional(api.Optional(None)) => text("None")
      case api.Value.Sum.Map(api.Map(entries)) =>
        val list = entries.map {
          case api.Map.Entry(k, Some(v)) =>
            text(k) + text(" -> ") + prettyApiValue(verbose, maxListWidth)(v)
          case _ => Doc.empty
        }
        text("TextMap(") + intercalate(text(", "), list) + text(")")
      case api.Value.Sum.GenMap(api.GenMap(entries)) =>
        val list = entries.map {
          case api.GenMap.Entry(Some(k), Some(v)) =>
            prettyApiValue(verbose, maxListWidth)(k) + text(" -> ") + prettyApiValue(
              verbose,
              maxListWidth,
            )(v)
          case _ => Doc.empty
        }
        text("GenMap(") + intercalate(text(", "), list) + text(")")
    }

  // Pretty-prints a speedy SValue; closures (SPAP) are elided as "...".
  def prettySValue(v: SValue): Doc = v match {
    case SValue.SPAP(_, _, _) =>
      text("...")
    case r: SValue.SRecord =>
      Pretty.prettyIdentifier(r.id) + char('{') & fill(
        text(", "),
        r.fields.toSeq.zip(r.values.asScala).map { case (k, v) =>
          text(k) & char('=') & prettySValue(v)
        },
      ) & char('}')
    case SValue.SStruct(fieldNames, values) =>
      char('<') + fill(
        text(", "),
        fieldNames.names.zip(values.asScala).toSeq.map { case (k, v) =>
          text(k) + char('=') + prettySValue(v)
        },
      ) + char('>')
    case SValue.SVariant(id, variant, _, value) =>
      Pretty.prettyIdentifier(id) + char(':') + text(variant) + char('(') + prettySValue(
        value
      ) + char(')')
    case SValue.SEnum(id, constructor, _) =>
      Pretty.prettyIdentifier(id) + char(':') + text(constructor)
    case SValue.SOptional(Some(value)) =>
      text("Option(") + prettySValue(value) + char(')')
    case SValue.SOptional(None) =>
      text("None")
    case SValue.SList(list) =>
      char('[') + intercalate(text(", "), list.map(prettySValue).toImmArray.toSeq) + char(']')
    case SValue.SMap(isTextMap, entries) =>
      val list = entries.map { case (k, v) =>
        prettySValue(k) + text(" -> ") + prettySValue(v)
      }
      text(if (isTextMap) "TextMap(" else "GenMap(") + intercalate(text(", "), list) + text(")")
    case SValue.SAny(ty, value) =>
      text("to_any") + char('@') + text(ty.pretty) + prettySValue(value)
    case SValue.SInt64(value) =>
      str(value)
    case SValue.SNumeric(value) =>
      str(value)
    case SValue.SBigNumeric(value) =>
      str(value)
    case SValue.SText(value) =>
      text(s"$value")
    case SValue.STimestamp(value) =>
      str(value)
    case SValue.SParty(value) =>
      char('\'') + str(value) + char('\'')
    case SValue.SBool(value) =>
      str(value)
    case SValue.SUnit =>
      text("<unit>")
    case SValue.SDate(value) =>
      str(value)
    case SValue.SContractId(value) =>
      text(value.coid)
    case SValue.STypeRep(ty) =>
      text(ty.pretty)
    case SValue.SToken =>
      text("Token")
  }
}

View File

@ -1,391 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger
import ch.qos.logback.classic.Level
import java.nio.file.{Path, Paths}
import java.time.Duration
import com.daml.lf.data.Ref
import com.daml.ledger.api.domain
import com.daml.ledger.api.tls.TlsConfiguration
import com.daml.ledger.api.tls.TlsConfigurationCli
import com.daml.ledger.client.LedgerClient
import com.daml.lf.engine.trigger.TriggerRunnerConfig.DefaultTriggerRunnerConfig
import com.daml.lf.language.LanguageMajorVersion
import com.daml.platform.services.time.TimeProviderType
import com.daml.lf.speedy.Compiler
import scala.concurrent.{ExecutionContext, Future}
// Selects the log output encoding for the trigger runner.
sealed trait LogEncoder
object LogEncoder {
  // Human-readable pattern output.
  object Plain extends LogEncoder
  // Structured JSON output.
  object Json extends LogEncoder
}
// Builds a speedy Compiler.Config for a given major Daml-LF language version.
sealed trait CompilerConfigBuilder extends Product with Serializable {
  def build(majorLanguageVersion: LanguageMajorVersion): Compiler.Config
}
object CompilerConfigBuilder {
  // Produces the standard `Compiler.Config.Default` configuration.
  final case object Default extends CompilerConfigBuilder {
    override def build(majorLanguageVersion: LanguageMajorVersion): Compiler.Config =
      Compiler.Config.Default(majorLanguageVersion)
  }
  // Produces the `Compiler.Config.Dev` configuration.
  final case object Dev extends CompilerConfigBuilder {
    override def build(majorLanguageVersion: LanguageMajorVersion): Compiler.Config =
      Compiler.Config.Dev(majorLanguageVersion)
  }
}
// Parsed command-line configuration for the standalone trigger runner.
case class RunnerConfig(
    darPath: Path,
    // If defined, we will only list the triggers in the DAR and exit.
    listTriggers: Option[Boolean],
    triggerIdentifier: String,
    ledgerHost: String,
    ledgerPort: Int,
    // Party- or user-based claims; null until a CLI option sets it (see the
    // null checks in the update helpers below).
    ledgerClaims: ClaimsSpecification,
    maxInboundMessageSize: Int,
    // optional so we can detect if both --static-time and --wall-clock-time are passed.
    timeProviderType: Option[TimeProviderType],
    commandTtl: Duration,
    accessTokenFile: Option[Path],
    applicationId: Option[Ref.ApplicationId],
    tlsConfig: TlsConfiguration,
    compilerConfigBuilder: CompilerConfigBuilder,
    majorLanguageVersion: LanguageMajorVersion,
    triggerConfig: TriggerRunnerConfig,
    rootLoggingLevel: Option[Level],
    logEncoder: LogEncoder,
) {
  // Applies `f` to the current party specification, starting from the empty
  // one when no claims were set yet. Throws if user-based claims were already
  // chosen: party- and user-based claims are mutually exclusive on the CLI.
  private def updatePartySpec(f: TriggerParties => TriggerParties): RunnerConfig =
    if (ledgerClaims == null) {
      copy(ledgerClaims = PartySpecification(f(TriggerParties.Empty)))
    } else
      ledgerClaims match {
        case PartySpecification(claims) =>
          copy(ledgerClaims = PartySpecification(f(claims)))
        case _: UserSpecification =>
          throw new IllegalArgumentException(
            s"Must specify either --ledger-party and --ledger-readas or --ledger-userid but not both"
          )
      }
  // Sets the acting party.
  private def updateActAs(party: Ref.Party): RunnerConfig =
    updatePartySpec(spec => spec.copy(actAsOpt = Some(party)))
  // Adds read-as parties; accumulates across repeated options.
  private def updateReadAs(parties: Seq[Ref.Party]): RunnerConfig =
    updatePartySpec(spec => spec.copy(readAs = spec.readAs ++ parties))
  // Sets user-based claims; throws if party-based claims were already chosen
  // (mirror image of updatePartySpec).
  private def updateUser(userId: Ref.UserId): RunnerConfig =
    if (ledgerClaims == null) {
      copy(ledgerClaims = UserSpecification(userId))
    } else
      ledgerClaims match {
        case UserSpecification(_) => copy(ledgerClaims = UserSpecification(userId))
        case _: PartySpecification =>
          throw new IllegalArgumentException(
            s"Must specify either --ledger-party and --ledger-readas or --ledger-userid but not both"
          )
      }
}
// Abstracts over how the trigger's party claims are determined: given
// directly as parties, or resolved from a ledger user's rights.
sealed abstract class ClaimsSpecification {
  // Resolves this specification to concrete trigger parties, possibly
  // querying the ledger's user management service.
  def resolveClaims(client: LedgerClient)(implicit ec: ExecutionContext): Future[TriggerParties]
}
/** Claims given directly as parties on the command line; resolution needs no
  * ledger round trip.
  */
final case class PartySpecification(claims: TriggerParties) extends ClaimsSpecification {
  override def resolveClaims(client: LedgerClient)(implicit
      ec: ExecutionContext
  ): Future[TriggerParties] = {
    // The parties are already known, so just wrap them in a completed future.
    val alreadyResolved = claims
    Future.successful(alreadyResolved)
  }
}
/** Claims given as a user id; the user's rights are fetched from the ledger's
  * user management service. The user's primary party becomes the actAs party
  * and every other canReadAs/canActAs party becomes a reader.
  */
final case class UserSpecification(userId: Ref.UserId) extends ClaimsSpecification {
  /** Resolve the user's claims.
    *
    * Fails with IllegalArgumentException if the user has no primary party or
    * lacks a canActAs right for it.
    */
  override def resolveClaims(
      client: LedgerClient
  )(implicit ec: ExecutionContext): Future[TriggerParties] = for {
    user <- client.userManagementClient.getUser(userId)
    primaryParty <- user.primaryParty.fold[Future[Ref.Party]](
      Future.failed(
        new IllegalArgumentException(
          s"User $user has no primary party. Specify a party explicitly via --ledger-party"
        )
      )
    )(Future.successful)
    rights <- client.userManagementClient.listUserRights(userId)
    readAs = rights.collect { case domain.UserRight.CanReadAs(party) =>
      party
    }.toSet
    actAs = rights.collect { case domain.UserRight.CanActAs(party) =>
      party
    }.toSet
    // The primary party must itself carry an actAs right, otherwise the
    // trigger could not submit commands on its behalf.
    _ <-
      if (actAs.contains(primaryParty)) {
        Future.unit
      } else {
        Future.failed(
          new IllegalArgumentException(
            s"User $user has primary party $primaryParty but no actAs claims for that party. Either change the user rights or specify a different party via --ledger-party"
          )
        )
      }
    // Everything except the primary party (in either right set) is a reader.
    readers = (readAs ++ actAs) - primaryParty
  } yield TriggerParties(primaryParty, readers)
}
/** The parties a trigger runs with.
  *
  * `actAsOpt` is optional only so the CLI can accumulate options incrementally;
  * the parser's checkConfig rejects configurations where it is still empty.
  */
final case class TriggerParties(
    actAsOpt: Option[Ref.Party],
    readAs: Set[Ref.Party],
) {
  // NOTE: throws (NoSuchElementException) if actAsOpt is None; callers rely on
  // CLI validation having ensured the party is present.
  lazy val actAs = actAsOpt.get
  // All parties the trigger reads for: the explicit readAs set plus actAs.
  lazy val readers: Set[Ref.Party] = readAs + actAs
}
object TriggerParties {
  /** Construct with a definite actAs party. */
  def apply(actAs: Ref.Party, readAs: Set[Ref.Party]): TriggerParties =
    new TriggerParties(actAsOpt = Some(actAs), readAs = readAs)

  /** Starting point for incremental CLI accumulation: no parties yet. */
  def Empty = new TriggerParties(actAsOpt = None, readAs = Set.empty)
}
object RunnerConfig {

  // scopt reader for user ids; rejects syntactically invalid user ids during
  // option parsing.
  implicit val userRead: scopt.Read[Ref.UserId] = scopt.Read.reads { s =>
    Ref.UserId.fromString(s).fold(e => throw new IllegalArgumentException(e), identity)
  }

  // Defaults applied when the corresponding CLI option is absent.
  private[trigger] val DefaultMaxInboundMessageSize: Int = 4194304
  private[trigger] val DefaultTimeProviderType: TimeProviderType = TimeProviderType.WallClock
  private[trigger] val DefaultApplicationId: Some[Ref.ApplicationId] =
    Some(Ref.ApplicationId.assertFromString("daml-trigger"))
  private[trigger] val DefaultCompilerConfigBuilder: CompilerConfigBuilder =
    CompilerConfigBuilder.Default
  private[trigger] val DefaultMajorLanguageVersion: LanguageMajorVersion = LanguageMajorVersion.V1

  // CLI parser: each option copies its parsed value into the RunnerConfig under
  // construction, starting from `Empty` (see `parse` below).
  @SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements")) // scopt builders
  private val parser = new scopt.OptionParser[RunnerConfig]("trigger-runner") {
    head("trigger-runner")

    opt[String]("dar")
      .required()
      .action((f, c) => c.copy(darPath = Paths.get(f)))
      .text("Path to the dar file containing the trigger")

    opt[String]("trigger-name")
      .action((t, c) => c.copy(triggerIdentifier = t))
      .text("Identifier of the trigger that should be run in the format Module.Name:Entity.Name")

    opt[String]("ledger-host")
      .action((t, c) => c.copy(ledgerHost = t))
      .text("Ledger hostname")

    opt[Int]("ledger-port")
      .action((t, c) => c.copy(ledgerPort = t))
      .text("Ledger port")

    // --ledger-party/--ledger-readas and --ledger-user are mutually exclusive;
    // the update* helpers on RunnerConfig enforce this.
    opt[String]("ledger-party")
      .action((t, c) => c.updateActAs(Ref.Party.assertFromString(t)))
      .text("""The party the trigger can act as.
              |Mutually exclusive with --ledger-user.""".stripMargin)

    opt[Seq[String]]("ledger-readas")
      .action((t, c) => c.updateReadAs(t.map(Ref.Party.assertFromString)))
      .unbounded()
      .text(
        """A comma-separated list of parties the trigger can read as.
          |Can be specified multiple-times.
          |Mutually exclusive with --ledger-user.""".stripMargin
      )

    opt[Ref.UserId]("ledger-user")
      .action((u, c) => c.updateUser(u))
      .unbounded()
      .text(
        """The id of the user the trigger should run as.
          |This is equivalent to specifying the primary party
          |of the user as --ledger-party and
          |all actAs and readAs claims of the user other than the
          |primary party as --ledger-readas.
          |The user must have a primary party.
          |Mutually exclusive with --ledger-party and --ledger-readas.""".stripMargin
      )

    opt[Int]("max-inbound-message-size")
      .action((x, c) => c.copy(maxInboundMessageSize = x))
      .optional()
      .text(
        s"Optional max inbound message size in bytes. Defaults to ${DefaultMaxInboundMessageSize}"
      )

    // -w and -s are mutually exclusive; setTimeProviderType (below) throws if
    // both are supplied.
    opt[Unit]('w', "wall-clock-time")
      .action { (_, c) =>
        setTimeProviderType(c, TimeProviderType.WallClock)
      }
      .text("Use wall clock time (UTC).")

    opt[Unit]('s', "static-time")
      .action { (_, c) =>
        setTimeProviderType(c, TimeProviderType.Static)
      }
      .text("Use static time.")

    opt[Long]("ttl")
      .action { (t, c) =>
        c.copy(commandTtl = Duration.ofSeconds(t))
      }
      .text("TTL in seconds used for commands emitted by the trigger. Defaults to 30s.")

    opt[String]("access-token-file")
      .action { (f, c) =>
        c.copy(accessTokenFile = Some(Paths.get(f)))
      }
      .text(
        "File from which the access token will be read, required to interact with an authenticated ledger"
      )

    // An empty --application-id clears the default application id.
    opt[String]("application-id")
      .action { (appId, c) =>
        c.copy(applicationId =
          Some(appId).filterNot(_.isEmpty).map(Ref.ApplicationId.assertFromString)
        )
      }
      .text(s"Application ID used to submit commands. Defaults to ${DefaultApplicationId}")

    // -v/--verbose and --debug are synonyms: both raise the root level to DEBUG.
    opt[Unit]('v', "verbose")
      .text("Root logging level -> DEBUG")
      .action((_, cli) => cli.copy(rootLoggingLevel = Some(Level.DEBUG)))

    opt[Unit]("debug")
      .text("Root logging level -> DEBUG")
      .action((_, cli) => cli.copy(rootLoggingLevel = Some(Level.DEBUG)))

    // scopt reader for logback levels, used by --log-level-root.
    implicit val levelRead: scopt.Read[Level] = scopt.Read.reads(Level.valueOf)
    opt[Level]("log-level-root")
      .text("Log-level of the root logger")
      .valueName("<LEVEL>")
      .action((level, cli) => cli.copy(rootLoggingLevel = Some(level)))

    opt[String]("log-encoder")
      .text("Log encoder: plain|json")
      .action {
        case ("json", cli) => cli.copy(logEncoder = LogEncoder.Json)
        case ("plain", cli) => cli.copy(logEncoder = LogEncoder.Plain)
        case (other, _) =>
          throw new IllegalArgumentException(s"Unsupported logging encoder $other")
      }

    opt[Long]("max-batch-size")
      .optional()
      .text(
        s"maximum number of messages processed between two high-level rule triggers. Defaults to ${DefaultTriggerRunnerConfig.maximumBatchSize}"
      )
      .action((size, cli) =>
        if (size > 0) cli.copy(triggerConfig = cli.triggerConfig.copy(maximumBatchSize = size))
        else throw new IllegalArgumentException(s"batch size must be strictly positive")
      )

    opt[Unit]("dev-mode-unsafe")
      .action((_, c) => c.copy(compilerConfigBuilder = CompilerConfigBuilder.Dev))
      .optional()
      .text(
        "Turns on development mode. Development mode allows development versions of Daml-LF language."
      )
      .hidden()

    // scopt reader for LF major versions, used by --lf-major-version.
    implicit val majorLanguageVersionRead: scopt.Read[LanguageMajorVersion] =
      scopt.Read.reads(s =>
        LanguageMajorVersion.fromString(s) match {
          case Some(v) => v
          case None => throw new IllegalArgumentException(s"$s is not a valid major LF version")
        }
      )
    opt[LanguageMajorVersion]("lf-major-version")
      .action((v, c) => c.copy(majorLanguageVersion = v))
      .optional()
      .text(
        "The major version of LF to use."
      )
      // TODO(#17366): unhide once LF v2 has a stable version
      .hidden()

    // Adds the shared TLS-related options to this parser.
    TlsConfigurationCli.parse(this, colSpacer = " ")((f, c) =>
      c.copy(tlsConfig = f(c.tlsConfig))
    )

    help("help").text("Print this usage text")

    cmd("list")
      .action((_, c) => c.copy(listTriggers = Some(false)))
      .text("List the triggers in the DAR.")

    // Like `list` but listTriggers = Some(true) also prints type/level/version.
    cmd("verbose-list")
      .hidden()
      .action((_, c) => c.copy(listTriggers = Some(true)))

    // Validates option combinations after parsing; `null`/0 sentinels in the
    // config mark options that were never supplied.
    checkConfig(c =>
      c.listTriggers match {
        case Some(_) =>
          // I do not want to break the trigger CLI and require a
          // "run" command so I can't make these options required
          // in general. Therefore, we do this check in checkConfig.
          success
        case None =>
          if (c.triggerIdentifier == null) {
            failure("Missing option --trigger-name")
          } else if (c.ledgerHost == null) {
            failure("Missing option --ledger-host")
          } else if (c.ledgerPort == 0) {
            failure("Missing option --ledger-port")
          } else if (c.ledgerClaims == null) {
            failure("Missing option --ledger-party or --ledger-user")
          } else {
            c.ledgerClaims match {
              case PartySpecification(TriggerParties(actAs, _)) if actAs.isEmpty =>
                failure("Missing option --ledger-party")
              case _ => success
            }
          }
      }
    )
  }

  // Records the requested time provider, rejecting a contradictory combination
  // of --static-time and --wall-clock-time (either order).
  private def setTimeProviderType(
      config: RunnerConfig,
      timeProviderType: TimeProviderType,
  ): RunnerConfig = {
    if (config.timeProviderType.exists(_ != timeProviderType)) {
      throw new IllegalStateException(
        "Static time mode (`-s`/`--static-time`) and wall-clock time mode (`-w`/`--wall-clock-time`) are mutually exclusive. The time mode must be unambiguous."
      )
    }
    config.copy(timeProviderType = Some(timeProviderType))
  }

  /** Parse command line arguments, starting from `Empty`; None if parsing or
    * checkConfig validation fails (scopt reports errors itself).
    */
  def parse(args: Array[String]): Option[RunnerConfig] =
    parser.parse(
      args,
      Empty,
    )

  // Initial configuration all parsing starts from. `null` and 0 act as "not
  // provided" sentinels consumed by checkConfig above.
  val Empty: RunnerConfig = RunnerConfig(
    darPath = null,
    listTriggers = None,
    triggerIdentifier = null,
    ledgerHost = null,
    ledgerPort = 0,
    ledgerClaims = null,
    maxInboundMessageSize = DefaultMaxInboundMessageSize,
    timeProviderType = None,
    commandTtl = Duration.ofSeconds(30L),
    accessTokenFile = None,
    tlsConfig = TlsConfiguration(enabled = false, None, None, None),
    applicationId = DefaultApplicationId,
    compilerConfigBuilder = DefaultCompilerConfigBuilder,
    majorLanguageVersion = DefaultMajorLanguageVersion,
    triggerConfig = DefaultTriggerRunnerConfig,
    rootLoggingLevel = None,
    logEncoder = LogEncoder.Plain,
  )
}

View File

@ -1,135 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger
import java.io.File
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.stream._
import ch.qos.logback.classic.Level
import com.daml.auth.TokenHolder
import com.daml.grpc.adapter.PekkoExecutionSequencerPool
import com.daml.ledger.client.LedgerClient
import com.daml.ledger.client.configuration.{
CommandClientConfiguration,
LedgerClientChannelConfiguration,
LedgerClientConfiguration,
LedgerIdRequirement,
}
import com.daml.lf.archive.{Dar, DarDecoder}
import com.daml.lf.data.Ref.{Identifier, PackageId, QualifiedName}
import com.daml.lf.language.Ast._
import com.daml.lf.language.PackageInterface
import com.daml.scalautil.Statement.discard
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
object RunnerMain {

  /** Print the qualified name of every trigger definition found in the DAR's
    * main package; with `verbose`, also print each trigger's type, level and
    * version.
    */
  private def listTriggers(
      darPath: File,
      dar: Dar[(PackageId, Package)],
      verbose: Boolean,
  ): Unit = {
    println(s"Listing triggers in $darPath:")
    val pkgInterface = PackageInterface(dar.all.toMap)
    val (mainPkgId, mainPkg) = dar.main
    // Scan every definition of every module in the main package, keeping those
    // detectTriggerDefinition recognises as triggers.
    for {
      mod <- mainPkg.modules.values
      defName <- mod.definitions.keys
      qualifiedName = QualifiedName(mod.name, defName)
      triggerId = Identifier(mainPkgId, qualifiedName)
    } {
      Trigger.detectTriggerDefinition(pkgInterface, triggerId).foreach {
        case TriggerDefinition(_, ty, version, level, _) =>
          if (verbose)
            println(
              s" $qualifiedName\t(type = ${ty.pretty}, level = $level, version = $version)"
            )
          else
            println(s" $qualifiedName")
      }
    }
  }

  /** Entry point: parse the CLI, then either list the triggers in the DAR (and
    * exit) or connect to the ledger and run the selected trigger until its
    * stream completes or fails.
    */
  def main(args: Array[String]): Unit = {
    RunnerConfig.parse(args) match {
      case None => sys.exit(1)
      case Some(config) => {
        // Logging setup must happen before anything logs.
        config.rootLoggingLevel.foreach(setLoggingLevel)
        config.logEncoder match {
          case LogEncoder.Plain =>
          case LogEncoder.Json =>
            discard(System.setProperty("LOG_FORMAT_JSON", "true"))
        }

        val dar: Dar[(PackageId, Package)] =
          DarDecoder.assertReadArchiveFromFile(config.darPath.toFile)

        // In "list"/"verbose-list" mode we print and exit before touching the
        // ledger at all.
        config.listTriggers.foreach { verbose =>
          listTriggers(config.darPath.toFile, dar, verbose)
          sys.exit(0)
        }

        val triggerId: Identifier =
          Identifier(dar.main._1, QualifiedName.assertFromString(config.triggerIdentifier))

        val system: ActorSystem = ActorSystem("TriggerRunner")
        implicit val materializer: Materializer = Materializer(system)
        val sequencer = new PekkoExecutionSequencerPool("TriggerRunnerPool")(system)
        implicit val ec: ExecutionContext = system.dispatcher

        val tokenHolder = config.accessTokenFile.map(new TokenHolder(_))
        // We probably want to refresh the token at some point but given that triggers
        // are expected to be written such that they can be killed and restarted at
        // any time it would in principle also be fine to just have the auth failure due
        // to an expired token tear the trigger down and have some external monitoring process (e.g. systemd)
        // restart it.
        val clientConfig = LedgerClientConfiguration(
          applicationId = config.applicationId.getOrElse(""),
          ledgerIdRequirement = LedgerIdRequirement.none,
          commandClient =
            CommandClientConfiguration.default.copy(defaultDeduplicationTime = config.commandTtl),
          token = tokenHolder.flatMap(_.token),
        )
        val channelConfig = LedgerClientChannelConfiguration(
          sslContext = config.tlsConfig.client(),
          maxInboundMessageSize = config.maxInboundMessageSize,
        )

        // Connect, resolve the claims (possibly a user-management lookup), then
        // run the trigger to completion.
        val flow: Future[Unit] = for {
          client <- LedgerClient.singleHost(
            config.ledgerHost,
            config.ledgerPort,
            clientConfig,
            channelConfig,
          )(ec, sequencer)

          parties <- config.ledgerClaims.resolveClaims(client)

          _ <- Runner.run(
            dar,
            triggerId,
            client,
            config.timeProviderType.getOrElse(RunnerConfig.DefaultTimeProviderType),
            config.applicationId,
            parties,
            config.compilerConfigBuilder.build(config.majorLanguageVersion),
            config.triggerConfig,
          )
        } yield ()

        // Tear the actor system down whatever the outcome; Await.result then
        // re-throws any failure of the trigger flow.
        flow.onComplete(_ => system.terminate())
        Await.result(flow, Duration.Inf)
      }
    }
  }

  // Communicates the requested root log level to the logging configuration via
  // a system property.
  private def setLoggingLevel(level: Level): Unit = {
    discard(System.setProperty("LOG_LEVEL_ROOT", level.levelStr))
  }
}

View File

@ -1,15 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger
import org.apache.pekko.stream.{Inlet, Outlet, Shape}
import scala.collection.immutable.Seq
/** A stream Shape with no inlets and exactly two outlets — the two-output
  * analogue of a SourceShape. Used by UnfoldState.toSource to emit elements on
  * one port and the final state on the other.
  */
private[trigger] final case class SourceShape2[L, R](out1: Outlet[L], out2: Outlet[R])
    extends Shape {
  override val inlets: Seq[Inlet[_]] = Seq.empty
  override val outlets: Seq[Outlet[_]] = Seq(out1, out2)
  // Per the Shape contract, a deep copy must carbon-copy its ports.
  override def deepCopy() = copy(out1.carbonCopy(), out2.carbonCopy())
}

View File

@ -1,205 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger
import com.daml.lf.data.{BackStack, NoCopy}
import com.daml.lf.engine.trigger.Runner.Implicits._
import com.daml.lf.engine.trigger.TriggerLogContext._
import com.daml.logging.{ContextualizedLogger, LoggingContextOf}
import com.daml.logging.LoggingContextOf.label
import com.daml.logging.entries.{LoggingEntries, LoggingValue}
import java.util.UUID
object ToLoggingContext {

  /** Implicitly flatten a [[TriggerLogContext]] into a LoggingContextOf[Trigger]:
    * all trigger entries plus the span's name/id/parent information are nested
    * under a single "trigger" key.
    */
  implicit def `TriggerLogContext to LoggingContextOf[Trigger]`(implicit
      triggerContext: TriggerLogContext
  ): LoggingContextOf[Trigger] = {
    // Render the parent span set compactly: omitted when empty, unwrapped when
    // it is a single id, the whole set otherwise.
    val parentEntries = if (triggerContext.span.parent.isEmpty) {
      LoggingEntries.empty
    } else if (triggerContext.span.parent.size == 1) {
      LoggingEntries("parent" -> triggerContext.span.parent.head)
    } else {
      LoggingEntries("parent" -> triggerContext.span.parent)
    }

    val spanEntries = LoggingEntries(
      "span" -> LoggingValue.Nested(
        LoggingEntries(
          // The span path renders as a dotted name rooted at "trigger".
          "name" -> triggerContext.span.path.toImmArray.foldLeft("trigger")((path, name) =>
            s"$path.$name"
          ),
          "id" -> triggerContext.span.id,
        ) ++ parentEntries
      )
    )

    LoggingContextOf
      .withEnrichedLoggingContext(
        label[Trigger],
        "trigger" -> LoggingValue.Nested(LoggingEntries(triggerContext.entries: _*) ++ spanEntries),
      )(triggerContext.loggingContext)
      .run(identity)
  }
}
/** An immutable logging context for trigger code: a set of logging entries plus
  * a span describing where in the trigger's execution we are. Construction goes
  * through the companion object; enrichment returns new instances.
  *
  * `callback` is invoked with the message and the enriched context on every
  * log* call, before the message is handed to the logger.
  */
final class TriggerLogContext private (
    private[trigger] val loggingContext: LoggingContextOf[Trigger],
    private[trigger] val entries: Seq[(String, LoggingValue)],
    private[trigger] val span: TriggerLogSpan,
    private[trigger] val callback: (String, TriggerLogContext) => Unit,
) extends NoCopy {

  import ToLoggingContext._

  /** Run `f` with the given entries appended to this context's entries. */
  def enrichTriggerContext[A](
      additionalEntries: (String, LoggingValue)*
  )(f: TriggerLogContext => A): A = {
    f(new TriggerLogContext(loggingContext, entries ++ additionalEntries, span, callback))
  }

  /** Run `f` in a sibling span named `name` (the last path segment is replaced
    * — see TriggerLogSpan.nextSpan). Span entry/exit are logged at INFO.
    */
  def nextSpan[A](
      name: String,
      additionalEntries: (String, LoggingValue)*
  )(f: TriggerLogContext => A)(implicit
      logger: ContextualizedLogger
  ): A = {
    val context = new TriggerLogContext(
      loggingContext,
      entries ++ additionalEntries,
      span.nextSpan(name),
      callback,
    )
    try {
      context.logInfo("span entry")
      f(context)
    } finally {
      // Logged even if `f` throws.
      context.logInfo("span exit")
    }
  }

  /** Run `f` in a child span named `name` (appended to the span path, with this
    * span as parent). Span entry/exit are logged at INFO.
    */
  def childSpan[A](
      name: String,
      additionalEntries: (String, LoggingValue)*
  )(f: TriggerLogContext => A)(implicit
      logger: ContextualizedLogger
  ): A = {
    val context = new TriggerLogContext(
      loggingContext,
      entries ++ additionalEntries,
      span.childSpan(name),
      callback,
    )
    try {
      context.logInfo("span entry")
      f(context)
    } finally {
      context.logInfo("span exit")
    }
  }

  /** Merge this context with `contexts`: entries are unioned (deduplicated via
    * a set) and the spans are grouped into one span with multiple parents.
    */
  def groupWith(contexts: TriggerLogContext*): TriggerLogContext = {
    val groupEntries = contexts.foldLeft(entries.toSet) { case (entries, context) =>
      entries ++ context.entries.toSet
    }
    val groupSpans = contexts.foldLeft(span) { case (span, context) =>
      span.groupWith(context.span)
    }
    new TriggerLogContext(loggingContext, groupEntries.toSeq, groupSpans, callback)
  }

  // Each log* helper enriches the context with the extra entries, notifies the
  // callback, then logs at the corresponding level. The enriched context is
  // picked up by the implicit conversion in ToLoggingContext.

  def logError(message: String, additionalEntries: (String, LoggingValue)*)(implicit
      logger: ContextualizedLogger
  ): Unit = {
    enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
      callback(message, triggerContext)
      logger.error(message)
    }
  }

  def logWarning(message: String, additionalEntries: (String, LoggingValue)*)(implicit
      logger: ContextualizedLogger
  ): Unit = {
    enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
      callback(message, triggerContext)
      logger.warn(message)
    }
  }

  def logInfo(message: String, additionalEntries: (String, LoggingValue)*)(implicit
      logger: ContextualizedLogger
  ): Unit = {
    enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
      callback(message, triggerContext)
      logger.info(message)
    }
  }

  def logDebug(message: String, additionalEntries: (String, LoggingValue)*)(implicit
      logger: ContextualizedLogger
  ): Unit = {
    enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
      callback(message, triggerContext)
      logger.debug(message)
    }
  }

  def logTrace(message: String, additionalEntries: (String, LoggingValue)*)(implicit
      logger: ContextualizedLogger
  ): Unit = {
    enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
      callback(message, triggerContext)
      logger.trace(message)
    }
  }
}
object TriggerLogContext {

  /** Create a fresh root span named `span` (with no logging callback) and run
    * `f` in the resulting context.
    */
  def newRootSpan[A](
      span: String,
      entries: (String, LoggingValue)*
  )(f: TriggerLogContext => A)(implicit loggingContext: LoggingContextOf[Trigger]): A =
    // Delegate to the callback variant with a no-op callback so the
    // construction logic lives in exactly one place.
    newRootSpanWithCallback(span, (_, _) => (), entries: _*)(f)

  /** As [[newRootSpan]], but additionally registers `callback`, which the
    * resulting context invokes for every message logged through it.
    */
  private[trigger] def newRootSpanWithCallback[A](
      span: String,
      callback: (String, TriggerLogContext) => Unit,
      entries: (String, LoggingValue)*
  )(f: TriggerLogContext => A)(implicit loggingContext: LoggingContextOf[Trigger]): A = {
    new TriggerLogContext(
      loggingContext,
      entries,
      TriggerLogSpan(BackStack(span)),
      callback,
    ).enrichTriggerContext()(f)
  }

  /** Identity of a span within the trigger's logging hierarchy.
    *
    * @param path the span's name path (rendered dotted, rooted at "trigger")
    * @param id unique id for this span instance
    * @param parent ids of parent spans (possibly several after grouping)
    */
  private[trigger] final case class TriggerLogSpan(
      path: BackStack[String],
      id: UUID = UUID.randomUUID(),
      parent: Set[UUID] = Set.empty,
  ) {
    /** Sibling span: drop the last path segment (if any), append `name`, and
      * keep the same parents.
      */
    def nextSpan(name: String): TriggerLogSpan = {
      val basePath = path.pop.fold(BackStack.empty[String])(_._1)
      TriggerLogSpan(basePath :+ name, parent = parent)
    }

    /** Child span: append `name` to the path with this span as sole parent. */
    def childSpan(name: String): TriggerLogSpan = {
      TriggerLogSpan(path :+ name, parent = Set(id))
    }

    /** Record `span` as an additional parent (used when grouping contexts). */
    def groupWith(span: TriggerLogSpan): TriggerLogSpan = {
      copy(parent = parent + span.id)
    }
  }
}

View File

@ -1,83 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger
import scala.concurrent.duration._
/** Trigger hard limits. If any of these values are exceeded, then the current trigger instance will throw a
* `TriggerHardLimitException` and stop running.
*
* @param maximumActiveContracts Maximum number of active contracts that we will store at any point in time.
* @param inFlightCommandOverflowCount When the number of in-flight command submissions exceeds this value, then we
* kill the trigger instance by throwing an InFlightCommandOverflowException.
* @param allowInFlightCommandOverflows flag to control whether we allow in-flight command overflows or not.
* @param ruleEvaluationTimeout If the trigger rule evaluator takes longer than this timeout value, then we throw a
* TriggerRuleEvaluationTimeout.
* @param stepInterpreterTimeout If the trigger rule step evaluator (during rule evaluation) takes longer than this
* timeout value, then we throw a TriggerRuleStepInterpretationTimeout.
* @param allowTriggerTimeouts flag to control whether we allow rule evaluation and step interpreter timeouts or not.
*/
final case class TriggerRunnerHardLimits(
    // Maximum number of active contracts stored at any point in time.
    maximumActiveContracts: Long,
    // In-flight submission count at which the trigger instance is killed.
    inFlightCommandOverflowCount: Int,
    // Whether in-flight command overflows are allowed to kill the trigger.
    allowInFlightCommandOverflows: Boolean,
    // Timeout for a single trigger rule evaluation.
    ruleEvaluationTimeout: FiniteDuration,
    // Timeout for a single step of rule evaluation.
    stepInterpreterTimeout: FiniteDuration,
    // Whether rule-evaluation/step-interpreter timeouts are enforced.
    allowTriggerTimeouts: Boolean,
)
/** @param parallelism The number of submitSingleCommand invocations each trigger will attempt to execute in parallel.
* Note that this does not in any way bound the number of already-submitted, but not completed,
* commands that may be pending.
* @param maxRetries Maximum number of retries when the ledger client fails an API command submission.
* @param maxSubmissionRequests Used to control rate at which we throttle ledger client submission requests.
* @param maxSubmissionDuration Used to control rate at which we throttle ledger client submission requests.
* @param inFlightCommandBackPressureCount When the number of in-flight command submissions exceeds this value, then we
* enable Daml rule evaluation to apply backpressure (by failing emitCommands
* evaluations).
* @param submissionFailureQueueSize Size of the queue holding ledger API command submission failures.
* @param maximumBatchSize Maximum number of messages triggers will batch (for rule evaluation/processing).
* @param batchingDuration Period of time we will wait before emitting a message batch (for rule evaluation/processing).
*/
final case class TriggerRunnerConfig(
    // Number of parallel submitSingleCommand invocations per trigger.
    parallelism: Int,
    // Maximum retries on failed ledger API command submissions.
    maxRetries: Int,
    // Submission throttling: at most maxSubmissionRequests requests ...
    maxSubmissionRequests: Int,
    // ... per maxSubmissionDuration window.
    maxSubmissionDuration: FiniteDuration,
    // In-flight submission count beyond which Daml rule evaluation gets backpressure.
    inFlightCommandBackPressureCount: Long,
    // Size of the queue holding ledger API command submission failures.
    submissionFailureQueueSize: Int,
    // Maximum number of messages batched for rule evaluation/processing.
    maximumBatchSize: Long,
    // How long to wait before emitting a (possibly partial) message batch.
    batchingDuration: FiniteDuration,
    // Hard limits; exceeding them stops the trigger instance.
    hardLimit: TriggerRunnerHardLimits,
)
object TriggerRunnerConfig {

  /** Default runner configuration. The hard-limit timeouts are derived from the
    * submission throttling parameters.
    */
  val DefaultTriggerRunnerConfig: TriggerRunnerConfig = {
    val defaultParallelism = 8
    val requestsPerWindow = 100
    val submissionWindow = 5.seconds
    val defaultHardLimits = TriggerRunnerHardLimits(
      maximumActiveContracts = 20000,
      inFlightCommandOverflowCount = 10000,
      allowInFlightCommandOverflows = true,
      // 50% extra on the maxSubmissionDuration value
      ruleEvaluationTimeout = submissionWindow * 3 / 2,
      // 50% extra on mean time between submission requests
      stepInterpreterTimeout = (submissionWindow / requestsPerWindow.toLong) * 3 / 2,
      allowTriggerTimeouts = false,
    )
    TriggerRunnerConfig(
      parallelism = defaultParallelism,
      maxRetries = 6,
      maxSubmissionRequests = requestsPerWindow,
      maxSubmissionDuration = submissionWindow,
      inFlightCommandBackPressureCount = 1000,
      // 256 here comes from the default ExecutionContext.
      submissionFailureQueueSize = 256 + defaultParallelism,
      maximumBatchSize = 1000,
      batchingDuration = 250.milliseconds,
      hardLimit = defaultHardLimits,
    )
  }
}

View File

@ -1,232 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.engine.trigger
import scalaz.{-\/, Bifunctor, \/, \/-}
import scalaz.syntax.bifunctor._
import scalaz.std.option.some
import scalaz.std.tuple._
import org.apache.pekko.NotUsed
import org.apache.pekko.stream.{BidiShape, FanOutShape2, Graph, Inlet, Outlet}
import org.apache.pekko.stream.scaladsl.{Concat, Flow, GraphDSL, Partition, Source}
import scala.annotation.tailrec
import com.daml.scalautil.Statement.discard
import scala.collection.immutable.{IndexedSeq, Iterable, LinearSeq}
/** A variant of [[scalaz.CorecursiveList]] that emits a final state
* at the end of the list.
*/
private[trigger] sealed abstract class UnfoldState[T, A] {
  // The (existential) state type threaded through `step`.
  type S
  // The state the unfold starts from.
  val init: S
  // One unfold step: either the final T (end of list) or the next A plus the
  // successor state.
  val step: S => T \/ (A, S)

  /** Same unfold, different starting state. */
  def withInit(init: S): UnfoldState.Aux[S, T, A]

  /** Run the unfold to the end, applying `f` to every element, and return the
    * final T.
    */
  final def foreach(f: A => Unit): T = {
    @tailrec def go(s: S): T = step(s) match {
      case -\/(t) => t
      case \/-((a, s2)) =>
        f(a)
        go(s2)
    }
    go(init)
  }

  /** Iterator over the unfold: right As followed by exactly one left T.
    * Single-use and stateful; throws IllegalStateException if read past the
    * final T.
    */
  private[trigger] final def iterator(): Iterator[T \/ A] =
    new Iterator[T \/ A] {
      // Holds the next result to emit; None once the final T has been emitted.
      var last = some(step(init))
      override def hasNext = last.isDefined
      override def next() = last match {
        case Some(\/-((a, s))) =>
          last = Some(step(s))
          \/-(a)
        case Some(et @ -\/(_)) =>
          last = None
          et
        case None =>
          throw new IllegalStateException("iterator read past end")
      }
    }

  /** Collect all elements into a container of type FA and pair it with the
    * final T.
    */
  final def runTo[FA](implicit factory: collection.Factory[A, FA]): (FA, T) = {
    val b = factory.newBuilder
    val t = foreach(a => discard(b += a))
    (b.result(), t)
  }
}
private[trigger] object UnfoldState {
  // UnfoldState with its state type fixed to S0.
  type Aux[S0, T, A] = UnfoldState[T, A] { type S = S0 }

  /** Construct an UnfoldState from an initial state and a step function. */
  def apply[S, T, A](init: S)(step: S => T \/ (A, S)): UnfoldState[T, A] = {
    type S0 = S
    final case class UnfoldStateImpl(init: S, step: S => T \/ (A, S)) extends UnfoldState[T, A] {
      type S = S0
      override def withInit(init: S) = copy(init = init)
    }
    UnfoldStateImpl(init, step)
  }

  // Bifunctor over (final state, element) — maps T with f and A with g.
  implicit def `US bifunctor instance`: Bifunctor[UnfoldState] = new Bifunctor[UnfoldState] {
    override def bimap[A, B, C, D](fab: UnfoldState[A, B])(f: A => C, g: B => D) =
      UnfoldState(fab.init)(fab.step andThen (_.bimap(f, (_ leftMap g))))
  }

  /** Unfold a linear sequence; the final state is Unit. */
  def fromLinearSeq[A](list: LinearSeq[A]): UnfoldState[Unit, A] = {
    type Sr = Unit \/ (A, LinearSeq[A])
    apply(list) {
      case hd +: tl => \/-((hd, tl)): Sr
      case _ => -\/(()): Sr
    }
  }

  /** Unfold an indexed sequence by index; the final state is Unit. */
  def fromIndexedSeq[A](vector: IndexedSeq[A]): UnfoldState[Unit, A] = {
    type Sr = Unit \/ (A, Int)
    apply(0) { n =>
      if (vector.sizeIs > n) \/-((vector(n), n + 1)): Sr
      else -\/(()): Sr
    }
  }

  // Readable accessors for the two outlets of a SourceShape2.
  implicit final class toSourceOps[T, A](private val self: SourceShape2[T, A]) {
    def elemsOut: Outlet[A] = self.out2
    def finalState: Outlet[T] = self.out1
  }

  /** Turn an UnfoldState into a two-output source: elements on one outlet, the
    * final state on the other.
    */
  def toSource[T, A](us: UnfoldState[T, A]): Graph[SourceShape2[T, A], NotUsed] =
    GraphDSL.create() { implicit gb =>
      import GraphDSL.Implicits._
      val split = gb add partition[T, A]
      Source.fromIterator(() => us.iterator()) ~> split.in
      SourceShape2(split.out0, split.out1)
    }

  /** A stateful but pure version of built-in flatMapConcat.
    * (flatMapMerge does not make sense, because parallelism
    * with linear state does not make sense.)
    */
  def flatMapConcat[T, A, B](zero: T)(f: (T, A) => UnfoldState[T, B]): Flow[A, B, NotUsed] =
    flatMapConcatStates(zero)(f) collect { case \/-(b) => b }

  /** Like `flatMapConcat` but emit the new state after each unfolded list.
    * The pattern you will see is a bunch of right Bs, followed by a single
    * left T, then repeat until close, with a final T unless aborted.
    */
  def flatMapConcatStates[T, A, B](zero: T)(
      f: (T, A) => UnfoldState[T, B]
  ): Flow[A, T \/ B, NotUsed] =
    Flow[A].statefulMapConcat(() => mkMapConcatFun(zero, f))

  // Shape of flatMapConcatNode: initial state and elements in, elements and
  // intermediate/final states out.
  type UnfoldStateShape[T, -A, +B] = BidiShape[T, B, A, T]

  // Readable accessors for the four ports of a flatMapConcatNode.
  implicit final class flatMapConcatNodeOps[IT, B, A, OT](
      private val self: BidiShape[IT, B, A, OT]
  ) {
    def initState: Inlet[IT] = self.in1
    def elemsIn: Inlet[A] = self.in2
    def elemsOut: Outlet[B] = self.out1
    def finalStates: Outlet[OT] = self.out2
  }

  /** Accept 1 initial state on in1, fold over in2 elements, emitting the output
    * elements on out1, and each result state after each result of `f` is unfolded
    * on out2.
    */
  def flatMapConcatNode[T, A, B](
      f: (T, A) => UnfoldState[T, B]
  ): Graph[UnfoldStateShape[T, A, B], NotUsed] =
    GraphDSL.create() { implicit gb =>
      import GraphDSL.Implicits._
      val initialT = gb add (Flow fromFunction \/.left[T, A] take 1)
      val as = gb add (Flow fromFunction \/.right[T, A])
      val tas = gb add Concat[T \/ A](2) // ensure that T arrives *before* A
      val splat = gb add (Flow[T \/ A] statefulMapConcat (() => statefulMapConcatFun(f)))
      val split = gb add partition[T, B]
      // format: off
      discard { initialT ~> tas }
      discard {      as ~> tas ~> splat ~> split.in }
      // format: on
      new BidiShape(initialT.in, split.out1, as.in, split.out0)
    }

  // Routes lefts to out0 and rights to out1.
  // TODO factor with ContractsFetch
  private[this] def partition[A, B]: Graph[FanOutShape2[A \/ B, A, B], NotUsed] =
    GraphDSL.create() { implicit b =>
      import GraphDSL.Implicits._
      val split = b.add(
        Partition[A \/ B](
          2,
          {
            case -\/(_) => 0
            case \/-(_) => 1
          },
        )
      )
      val as = b.add(Flow[A \/ B].collect { case -\/(a) => a })
      val bs = b.add(Flow[A \/ B].collect { case \/-(b) => b })
      discard { split ~> as }
      discard { split ~> bs }
      new FanOutShape2(split.in, as.out, bs.out)
    }

  private[this] type FoldL[T, -A, B] = (T, A) => UnfoldState[T, B]

  // Adapter for statefulMapConcat when the initial T arrives in-band as the
  // first (left) element: the left initialises the fold function, each right is
  // then fed through it. Relies on Concat guaranteeing the T comes first.
  private[this] def statefulMapConcatFun[T, A, B](f: FoldL[T, A, B]): T \/ A => Iterable[T \/ B] = {
    var mcFun: A => Iterable[T \/ B] = null
    _.fold(
      zeroT => {
        mcFun = mkMapConcatFun(zeroT, f)
        Iterable.empty
      },
      { a =>
        mcFun(a)
      },
    )
  }

  // Builds the per-element function for statefulMapConcat, threading the T
  // state through a mutable var across invocations.
  private[this] def mkMapConcatFun[T, A, B](zero: T, f: FoldL[T, A, B]): A => Iterable[T \/ B] = {
    var t = zero
    // statefulMapConcat only uses 'iterator'. We preserve the Iterable's
    // immutability by making one strict reference to the 't' var at 'Iterable' creation
    // time, meaning any later 'iterator' call uses the same start state, no matter
    // whether the 't' has been updated
    a =>
      new Iterable[T \/ B] {
        private[this] val bs = f(t, a)
        import bs.step
        override def iterator = new Iterator[T \/ B] {
          private[this] var last: Option[T \/ (B, bs.S)] = {
            val fst = step(bs.init)
            fst.fold(newT => t = newT, _ => ())
            Some(fst)
          }
          // this stream is "odd", i.e. we are always evaluating 1 step ahead
          // of what the client sees. We could improve laziness by making it
          // "even", but it would be a little trickier, as `hasNext` would have
          // a forcing side-effect
          override def hasNext = last.isDefined
          override def next() =
            last match {
              case Some(\/-((b, s))) =>
                val next = step(s)
                // The assumption here is that statefulMapConcat's implementation
                // will always read iterator to end before invoking on the next A
                next.fold(newT => t = newT, _ => ())
                last = Some(next)
                \/-(b)
              case Some(et @ -\/(_)) =>
                last = None
                et
              case None =>
                throw new IllegalStateException("iterator read past end")
            }
        }
      }
  }
}

View File

@ -1,416 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load("@oracle//:index.bzl", "oracle_tags")
load("@build_environment//:configuration.bzl", "sdk_version")
load(
"//daml-lf/language:daml-lf.bzl",
"LF_DEV_VERSIONS",
"LF_MAJOR_VERSIONS",
"lf_version_default_or_latest",
"lf_versions_aggregate",
)
load("@os_info//:os_info.bzl", "is_windows")
load(
"//bazel_tools:scala.bzl",
"da_scala_binary",
"da_scala_library",
"da_scala_test_suite",
"lf_scalacopts_stricter",
)
# Entry point source file; excluded from the library target below and
# compiled only into the per-edition service binaries.
TRIGGER_MAIN = "src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceMain.scala"

# LF versions we build test DARs for: the default/latest of each major
# version plus all dev versions.
target_lf_versions = lf_versions_aggregate(
    [lf_version_default_or_latest(major) for major in LF_MAJOR_VERSIONS] +
    LF_DEV_VERSIONS,
)
# Main trigger-service library: everything under src/main except the
# ServiceMain entry point (which lives in the per-edition binaries).
da_scala_library(
    name = "trigger-service",
    srcs = glob(
        ["src/main/scala/**/*.scala"],
        exclude = [TRIGGER_MAIN],
    ),
    resources = glob(["src/main/resources/**/*"]),
    scala_deps = [
        "@maven//:com_chuusai_shapeless",
        "@maven//:com_github_pureconfig_pureconfig_core",
        "@maven//:com_github_pureconfig_pureconfig_generic",
        "@maven//:com_github_scopt_scopt",
        "@maven//:com_typesafe_scala_logging_scala_logging",
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_actor_typed",
        "@maven//:org_apache_pekko_pekko_http",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_http_spray_json",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalaz_scalaz_core",
        "@maven//:org_tpolecat_doobie_core",
        "@maven//:org_tpolecat_doobie_free",
        "@maven//:org_tpolecat_doobie_postgres",
        "@maven//:org_typelevel_cats_core",
        "@maven//:org_typelevel_cats_effect",
        "@maven//:org_typelevel_cats_free",
        "@maven//:org_typelevel_cats_kernel",
    ],
    scala_runtime_deps = [
        "@maven//:org_apache_pekko_pekko_slf4j",
        "@maven//:org_tpolecat_doobie_postgres",
    ],
    scalacopts = lf_scalacopts_stricter,
    # Uncomment this if/when the target is published to maven.
    # tags = ["maven_coordinates=com.daml:trigger-service:__VERSION__"],
    visibility = ["//visibility:public"],
    runtime_deps = [
        "@maven//:ch_qos_logback_logback_classic",
        "@maven//:ch_qos_logback_logback_core",
        "@maven//:org_postgresql_postgresql",
    ],
    deps = [
        "//canton:ledger_api_proto_scala",
        "//daml-lf/archive:daml_lf_archive_reader",
        "//daml-lf/archive:daml_lf_dev_archive_proto_java",
        "//daml-lf/data",
        "//daml-lf/engine",
        "//daml-lf/interpreter",
        "//daml-lf/language",
        "//ledger-service/cli-opts",
        "//ledger-service/metrics",
        "//ledger-service/pureconfig-utils",
        "//ledger/ledger-api-client",
        "//ledger/ledger-api-common",
        "//libs-scala/contextualized-logging",
        "//libs-scala/db-utils",
        "//libs-scala/doobie-slf4j",
        "//libs-scala/ledger-resources",
        "//libs-scala/resources",
        "//libs-scala/rs-grpc-bridge",
        "//libs-scala/rs-grpc-pekko",
        "//libs-scala/scala-utils",
        "//observability/metrics",
        "//observability/pekko-http-metrics",
        "//triggers/metrics",
        "//triggers/runner:trigger-runner-lib",
        "//triggers/service/auth:middleware-api",
        "@maven//:ch_qos_logback_logback_classic",
        "@maven//:com_typesafe_config",
        "@maven//:io_dropwizard_metrics_metrics_core",
        "@maven//:io_netty_netty_handler",
        "@maven//:io_opentelemetry_opentelemetry_api",
        "@maven//:org_flywaydb_flyway_core",
        "@maven//:org_slf4j_slf4j_api",
    ],
)
# Scala dependencies shared by the CE and EE service binaries.
scala_binary_deps = [
    "@maven//:org_apache_pekko_pekko_actor",
    "@maven//:org_apache_pekko_pekko_actor_typed",
    "@maven//:org_apache_pekko_pekko_http_core",
    "@maven//:com_typesafe_scala_logging_scala_logging",
    "@maven//:org_scalaz_scalaz_core",
]

# Java/Bazel-target dependencies shared by the CE and EE service binaries.
binary_deps = [
    "//daml-lf/archive:daml_lf_archive_reader",
    "//daml-lf/archive:daml_lf_dev_archive_proto_java",
    "//daml-lf/data",
    "//daml-lf/interpreter",
    "//ledger/ledger-api-common",
    "//libs-scala/contextualized-logging",
    "//libs-scala/db-utils",
    "//libs-scala/ports",
    "//libs-scala/scala-utils",
    "//observability/metrics",
    "//triggers/runner:trigger-runner-lib",
    "//triggers/service/auth:middleware-api",
    ":trigger-service",
    "@maven//:ch_qos_logback_logback_classic",
    "@maven//:com_google_protobuf_protobuf_java",
    "@maven//:org_slf4j_slf4j_api",
]

# Edition-specific runtime dependencies: only EE ships the Oracle JDBC driver.
trigger_service_runtime_deps = {
    "ce": [],
    "ee": ["@maven//:com_oracle_database_jdbc_ojdbc8"],
}

# One service binary per edition (community / enterprise), differing only in
# JDBC drivers, license tag, and edition-specific runtime deps.
[
    da_scala_binary(
        name = "trigger-service-binary-{}".format(edition),
        srcs = [TRIGGER_MAIN],
        main_class = "com.daml.lf.engine.trigger.ServiceMain",
        resource_strip_prefix = "triggers/service/release/trigger-service-",
        resources = ["release/trigger-service-logback.xml"],
        scala_deps = scala_binary_deps,
        scalacopts = lf_scalacopts_stricter,
        tags = ["ee-jar-license"] if edition == "ee" else [],
        visibility = ["//visibility:public"],
        runtime_deps = trigger_service_runtime_deps.get(edition),
        deps = binary_deps + [
            "//runtime-components/jdbc-drivers:jdbc-drivers-{}".format(edition),
        ],
    )
    for edition in [
        "ce",
        "ee",
    ]
]
# Shared test helpers and fixtures; consumed by the test suites below and by
# //test-evidence.  Bundles per-LF-version test DARs and the toxiproxy binary
# (platform-specific) as runtime data.
da_scala_library(
    name = "trigger-service-tests",
    srcs = glob(["src/test/scala/com/digitalasset/daml/lf/engine/trigger/*.scala"]),
    data = [
        ":test-model-{}.dar".format(lf_version)
        for lf_version in target_lf_versions
    ] + [
        ":test-model-v{}.dar".format(major)
        for major in LF_MAJOR_VERSIONS
    ] + (
        [
            "@toxiproxy_dev_env//:bin/toxiproxy-server",
        ] if not is_windows else [
            "@toxiproxy_dev_env//:toxiproxy-server-windows-amd64.exe",
        ]
    ),
    resources = glob(["src/test/resources/**/*"]),
    scala_deps = [
        "@maven//:com_lihaoyi_sourcecode",
        "@maven//:com_typesafe_scala_logging_scala_logging",
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_actor_typed",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalactic_scalactic",
        "@maven//:org_scalatest_scalatest_core",
        "@maven//:org_scalatest_scalatest_flatspec",
        "@maven//:org_scalatest_scalatest_matchers_core",
        "@maven//:org_scalatest_scalatest_shouldmatchers",
        "@maven//:org_scalaz_scalaz_core",
    ],
    visibility = ["//test-evidence:__pkg__"],
    deps = [
        ":trigger-service",
        ":trigger-service-binary-ce",
        "//bazel_tools/runfiles:scala_runfiles",
        "//canton:ledger_api_proto_scala",
        "//daml-lf/archive:daml_lf_archive_reader",
        "//daml-lf/archive:daml_lf_dev_archive_proto_java",
        "//daml-lf/data",
        "//daml-lf/interpreter",
        "//daml-lf/language",
        "//ledger/ledger-api-auth",
        "//ledger/ledger-api-client",
        "//ledger/ledger-api-common",
        "//ledger/ledger-api-domain",
        "//libs-scala/adjustable-clock",
        "//libs-scala/db-utils",
        "//libs-scala/jwt",
        "//libs-scala/ledger-resources",
        "//libs-scala/oracle-testing",
        "//libs-scala/ports",
        "//libs-scala/ports:ports-testing",
        "//libs-scala/postgresql-testing",
        "//libs-scala/resources",
        "//libs-scala/rs-grpc-bridge",
        "//libs-scala/scala-utils",
        "//libs-scala/test-evidence/scalatest:test-evidence-scalatest",
        "//libs-scala/test-evidence/tag:test-evidence-tag",
        "//libs-scala/testing-utils",
        "//libs-scala/timer-utils",
        "//observability/metrics",
        "//test-common/canton/it-lib",
        "//triggers/runner:trigger-runner-lib",
        "//triggers/service/auth:middleware-api",
        "//triggers/service/auth:oauth2-middleware",
        "//triggers/service/auth:oauth2-test-server",
        "@maven//:ch_qos_logback_logback_classic",
        "@maven//:ch_qos_logback_logback_core",
        "@maven//:com_auth0_java_jwt",
        "@maven//:com_google_protobuf_protobuf_java",
        "@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_7",
        "@maven//:org_scalatest_scalatest_compatible",
        "@maven//:org_slf4j_slf4j_api",
    ],
)
# Main test suite (Postgres-backed); Oracle specs are excluded here and run
# by the dedicated "test-oracle" suite below.
da_scala_test_suite(
    name = "test",
    timeout = "long",
    srcs = glob(
        ["src/test-suite/scala/**/*.scala"],
        exclude = ["**/*Oracle*"],
    ),
    data = [
        ":src/test-suite/resources/trigger-service.conf",
        ":src/test-suite/resources/trigger-service-minimal.conf",
    ],
    scala_deps = [
        "@maven//:com_github_pureconfig_pureconfig_core",
        "@maven//:com_github_scopt_scopt",
        "@maven//:com_typesafe_scala_logging_scala_logging",
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalactic_scalactic",
        "@maven//:org_scalatest_scalatest_core",
        "@maven//:org_scalatest_scalatest_flatspec",
        "@maven//:org_scalatest_scalatest_matchers_core",
        "@maven//:org_scalatest_scalatest_shouldmatchers",
    ],
    visibility = ["//test-evidence:__pkg__"],
    deps = [
        ":trigger-service",
        ":trigger-service-tests",
        "//bazel_tools/runfiles:scala_runfiles",
        "//canton:ledger_api_proto_scala",
        "//daml-lf/archive:daml_lf_archive_reader",
        "//daml-lf/archive:daml_lf_dev_archive_proto_java",
        "//daml-lf/data",
        "//daml-lf/interpreter",
        "//daml-lf/language",
        "//ledger-service/cli-opts",
        "//ledger-service/pureconfig-utils",
        "//ledger/ledger-api-auth",
        "//ledger/ledger-api-client",
        "//ledger/ledger-api-common",
        "//libs-scala/adjustable-clock",
        "//libs-scala/db-utils",
        "//libs-scala/flyway-testing",
        "//libs-scala/jwt",
        "//libs-scala/ledger-resources",
        "//libs-scala/ports",
        "//libs-scala/postgresql-testing",
        "//libs-scala/resources",
        "//libs-scala/rs-grpc-bridge",
        "//libs-scala/test-evidence/tag:test-evidence-tag",
        "//libs-scala/testing-utils",
        "//libs-scala/timer-utils",
        "//test-common/canton/it-lib",
        "//triggers/runner:trigger-runner-lib",
        "//triggers/service/auth:middleware-api",
        "//triggers/service/auth:oauth2-test-server",
        "@maven//:ch_qos_logback_logback_classic",
        "@maven//:ch_qos_logback_logback_core",
        "@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_7",
        "@maven//:org_flywaydb_flyway_core",
        "@maven//:org_scalatest_scalatest_compatible",
    ],
)

# Oracle-backed variant of the test suite; gated by oracle_tags so it only
# runs where an Oracle database is available.
da_scala_test_suite(
    name = "test-oracle",
    timeout = "long",
    srcs = glob(["src/test-suite/scala/**/*Oracle*.scala"]),
    scala_deps = [
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_scalatest_scalatest_core",
        "@maven//:org_scalatest_scalatest_matchers_core",
        "@maven//:org_scalatest_scalatest_shouldmatchers",
    ],
    tags = oracle_tags,
    runtime_deps = [
        "@maven//:com_oracle_database_jdbc_ojdbc8",
    ],
    deps = [
        ":trigger-service",
        ":trigger-service-tests",
        "//canton:ledger_api_proto_scala",
        "//daml-lf/archive:daml_lf_archive_reader",
        "//daml-lf/archive:daml_lf_dev_archive_proto_java",
        "//daml-lf/data",
        "//daml-lf/language",
        "//ledger/ledger-api-auth",
        "//ledger/ledger-api-client",
        "//ledger/ledger-api-common",
        "//libs-scala/adjustable-clock",
        "//libs-scala/db-utils",
        "//libs-scala/jwt",
        "//libs-scala/ledger-resources",
        "//libs-scala/oracle-testing",
        "//libs-scala/ports",
        "//libs-scala/resources",
        "//libs-scala/rs-grpc-bridge",
        "//libs-scala/test-evidence/tag:test-evidence-tag",
        "//libs-scala/testing-utils",
        "//test-common/canton/it-lib",
        "//triggers/runner:trigger-runner-lib",
        "//triggers/service/auth:oauth2-test-server",
        "@maven//:ch_qos_logback_logback_classic",
        "@maven//:ch_qos_logback_logback_core",
        "@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_7",
        "@maven//:org_scalatest_scalatest_compatible",
    ],
)
# Build one DAR per LF version to bundle with the SDK.
# Each genrule assembles a throwaway Daml project (sources + daml-trigger /
# daml-script DARs + generated daml.yaml) and compiles it with damlc.
[
    genrule(
        name = "test-model-{}".format(lf_version),
        srcs =
            glob(["test-model/*.daml"]) + [
                "//triggers/daml:daml-trigger-{}".format(lf_version),
                "//daml-script/daml:daml-script-{}".format(lf_version),
            ],
        outs = ["test-model-{}.dar".format(lf_version)],
        cmd = """
set -eou pipefail
TMP_DIR=$$(mktemp -d)
mkdir -p $$TMP_DIR/daml
cp -L $(location :test-model/TestTrigger.daml) $$TMP_DIR/daml
cp -L $(location :test-model/ErrorTrigger.daml) $$TMP_DIR/daml
cp -L $(location :test-model/LowLevelErrorTrigger.daml) $$TMP_DIR/daml
cp -L $(location :test-model/ReadAs.daml) $$TMP_DIR/daml
cp -L $(location :test-model/Cats.daml) $$TMP_DIR/daml
cp -L $(location {daml_trigger}) $$TMP_DIR/daml-trigger.dar
cp -L $(location {daml_script}) $$TMP_DIR/daml-script.dar
cat << EOF > $$TMP_DIR/daml.yaml
sdk-version: {sdk}
name: test-model
source: daml
version: 0.0.1
dependencies:
- daml-stdlib
- daml-prim
- daml-trigger.dar
- daml-script.dar
build-options: {build_options}
EOF
$(location //compiler/damlc) build --project-root=$$TMP_DIR --ghc-option=-Werror -o $$PWD/$@
rm -rf $$TMP_DIR
""".format(
            build_options = str([
                "--target",
                lf_version,
            ]),
            daml_script = "//daml-script/daml:daml-script-{}".format(lf_version),
            daml_trigger = "//triggers/daml:daml-trigger-{}".format(lf_version),
            sdk = sdk_version,
        ),
        tools = ["//compiler/damlc"],
        visibility = ["//visibility:public"],
    )
    for lf_version in target_lf_versions
]

# One alias DAR per LF *major* version, copied from that major's
# default/latest LF version built above.
[
    genrule(
        name = "test-model-{}".format(major),
        srcs = [":test-model-{}".format(lf_version_default_or_latest(major))],
        outs = ["test-model-v{}.dar".format(major)],
        cmd = "cp -L $(location :test-model-{}) $$PWD/$@".format(lf_version_default_or_latest(major)),
        visibility = ["//visibility:public"],
    )
    for major in LF_MAJOR_VERSIONS
]

exports_files(["release/trigger-service-logback.xml"])

View File

@ -1,271 +0,0 @@
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load(
"//bazel_tools:scala.bzl",
"da_scala_binary",
"da_scala_library",
"da_scala_test",
"lf_scalacopts_stricter",
)
exports_files(["release/oauth2-middleware-logback.xml"])

# Wartremover check applied to test targets in this package only.
test_scalacopts = ["-P:wartremover:traverser:org.wartremover.warts.OptionPartial"]

# Reusable scopt CLI options for configuring a JWT verifier
# (HMAC256 / RSA256 / ECDSA / JWKS).
da_scala_library(
    name = "jwt-cli-opts",
    srcs = ["src/main/scala/com/daml/JwtVerifierConfigurationCli.scala"],
    scala_deps = [
        "@maven//:com_github_scopt_scopt",
        "@maven//:com_typesafe_scala_logging_scala_logging",
        "@maven//:org_scalaz_scalaz_core",
    ],
    scalacopts = lf_scalacopts_stricter,
    deps = [
        "//libs-scala/jwt",
        "@maven//:com_auth0_java_jwt",
    ],
)

da_scala_test(
    name = "jwt-cli-opts-tests",
    srcs = ["src/test/scala/com/daml/JwtVerifierConfigurationCliSpec.scala"],
    scala_deps = [
        "@maven//:com_github_scopt_scopt",
        "@maven//:org_scalactic_scalactic",
        "@maven//:org_scalatest_scalatest_core",
        "@maven//:org_scalatest_scalatest_matchers_core",
        "@maven//:org_scalatest_scalatest_shouldmatchers",
        "@maven//:org_scalatest_scalatest_wordspec",
    ],
    scalacopts = lf_scalacopts_stricter,
    deps = [
        ":jwt-cli-opts",
        "//ledger/ledger-api-auth",
        "//libs-scala/fs-utils",
        "//libs-scala/http-test-utils",
        "//libs-scala/jwt",
        "//libs-scala/resources",
        "//libs-scala/scala-utils",
        "@maven//:com_auth0_java_jwt",
        "@maven//:io_grpc_grpc_api",
        "@maven//:org_bouncycastle_bcpkix_jdk15on",
        "@maven//:org_bouncycastle_bcprov_jdk15on",
    ],
)
# Data types for the OAuth2 wire protocol (requests/responses and their
# spray-json codecs).
da_scala_library(
    name = "oauth2-api",
    srcs = glob(["src/main/scala/com/daml/auth/oauth2/api/**/*.scala"]),
    scala_deps = [
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
    ],
    scalacopts = lf_scalacopts_stricter,
    visibility = ["//visibility:public"],
)

# Client-facing API of the auth middleware, used by services that delegate
# authentication/authorization to it.
da_scala_library(
    name = "middleware-api",
    srcs = glob(["src/main/scala/com/daml/auth/middleware/api/**/*.scala"]),
    scala_deps = [
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_http_spray_json",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalaz_scalaz_core",
    ],
    scalacopts = lf_scalacopts_stricter,
    visibility = ["//visibility:public"],
    deps = [
        "//daml-lf/data",
    ],
)
# The OAuth2 auth-middleware implementation.  The sjsonnet stack comes from
# the dedicated @triggers_maven repository because it transitively requires
# upickle 1.x, which conflicts with the main @maven resolution.
da_scala_library(
    name = "oauth2-middleware",
    srcs = glob(["src/main/scala/com/daml/auth/middleware/oauth2/**/*.scala"]),
    resources = glob(["src/main/resources/com/daml/auth/middleware/oauth2/**"]),
    scala_deps = [
        "@maven//:com_chuusai_shapeless",
        "@maven//:com_github_pureconfig_pureconfig_core",
        "@maven//:com_github_pureconfig_pureconfig_generic",
        "@maven//:com_github_scopt_scopt",
        "@maven//:com_typesafe_scala_logging_scala_logging",
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_http_spray_json",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalaz_scalaz_core",
        "@triggers_maven//:com_lihaoyi_fastparse",
        "@triggers_maven//:com_lihaoyi_geny",
        "@triggers_maven//:com_lihaoyi_os_lib",
        "@triggers_maven//:com_lihaoyi_sjsonnet",
        "@triggers_maven//:com_lihaoyi_ujson",
        "@triggers_maven//:com_lihaoyi_upickle_core",
    ],
    scalacopts = lf_scalacopts_stricter,
    visibility = ["//visibility:public"],
    deps = [
        ":jwt-cli-opts",
        ":middleware-api",
        ":oauth2-api",
        "//daml-lf/data",
        "//ledger-service/cli-opts",
        "//ledger-service/metrics",
        "//ledger-service/pureconfig-utils",
        "//ledger/ledger-api-auth",
        "//ledger/ledger-api-common",
        "//libs-scala/jwt",
        "//libs-scala/ledger-resources",
        "//libs-scala/ports",
        "//libs-scala/resources",
        "//libs-scala/scala-utils",
        "//observability/metrics",
        "//observability/pekko-http-metrics",
        "@maven//:com_typesafe_config",
        "@maven//:io_dropwizard_metrics_metrics_core",
        "@maven//:io_opentelemetry_opentelemetry_api",
        "@maven//:org_slf4j_slf4j_api",
    ],
)

# Deployable entry point for the middleware.
da_scala_binary(
    name = "oauth2-middleware-binary",
    main_class = "com.daml.auth.middleware.oauth2.Main",
    scalacopts = lf_scalacopts_stricter,
    visibility = ["//visibility:public"],
    runtime_deps = [
        "@maven//:ch_qos_logback_logback_classic",
    ],
    deps = [
        ":oauth2-middleware",
    ],
)
# In-memory OAuth2 provider used as a test double by the trigger service
# and middleware test suites.
da_scala_library(
    name = "oauth2-test-server",
    srcs = glob(["src/main/scala/com/daml/auth/oauth2/test/server/**/*.scala"]),
    scala_deps = [
        "@maven//:com_github_scopt_scopt",
        "@maven//:com_typesafe_scala_logging_scala_logging",
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_http_spray_json",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalaz_scalaz_core",
    ],
    visibility = ["//triggers/service:__subpackages__"],
    deps = [
        ":oauth2-api",
        "//daml-lf/data",
        "//ledger/ledger-api-auth",
        "//libs-scala/jwt",
        "//libs-scala/ports",
        "@maven//:org_slf4j_slf4j_api",
    ],
)

# Standalone runner for the test OAuth2 server (manual testing/demos).
da_scala_binary(
    name = "oauth2-test-server-binary",
    main_class = "com.daml.auth.oauth2.test.server.Main",
    scalacopts = lf_scalacopts_stricter,
    runtime_deps = [
        "@maven//:ch_qos_logback_logback_classic",
    ],
    deps = [
        ":oauth2-test-server",
    ],
)

da_scala_test(
    name = "oauth2-test-server-tests",
    srcs = glob(["src/test/scala/com/daml/auth/oauth2/test/server/**/*.scala"]),
    scala_deps = [
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_http_spray_json",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalaz_scalaz_core",
    ],
    scalacopts = test_scalacopts,
    deps = [
        ":oauth2-api",
        ":oauth2-test-server",
        "//daml-lf/data",
        "//ledger/ledger-api-auth",
        "//libs-scala/adjustable-clock",
        "//libs-scala/jwt",
        "//libs-scala/ledger-resources",
        "//libs-scala/ports",
        "//libs-scala/resources",
        "//libs-scala/rs-grpc-bridge",
        "//libs-scala/testing-utils",
    ],
)
# Middleware test suite; runs the middleware against the in-memory OAuth2
# test server.  Visible to //test-evidence for compliance reporting.
da_scala_test(
    name = "oauth2-middleware-tests",
    srcs = glob(["src/test/scala/com/daml/auth/middleware/oauth2/**/*.scala"]),
    data = [
        ":src/test/resources/oauth2-middleware.conf",
        ":src/test/resources/oauth2-middleware-minimal.conf",
    ],
    scala_deps = [
        "@maven//:com_github_pureconfig_pureconfig_core",
        "@maven//:com_lihaoyi_sourcecode",
        "@maven//:com_typesafe_scala_logging_scala_logging",
        "@maven//:io_spray_spray_json",
        "@maven//:org_apache_pekko_pekko_actor",
        "@maven//:org_apache_pekko_pekko_http",
        "@maven//:org_apache_pekko_pekko_http_core",
        "@maven//:org_apache_pekko_pekko_http_spray_json",
        "@maven//:org_apache_pekko_pekko_http_testkit",
        "@maven//:org_apache_pekko_pekko_stream",
        "@maven//:org_parboiled_parboiled",
        "@maven//:org_scalaz_scalaz_core",
    ],
    scala_runtime_deps = [
        "@maven//:org_apache_pekko_pekko_stream_testkit",
    ],
    visibility = ["//test-evidence:__pkg__"],
    deps = [
        ":middleware-api",
        ":oauth2-api",
        ":oauth2-middleware",
        ":oauth2-test-server",
        "//bazel_tools/runfiles:scala_runfiles",
        "//daml-lf/data",
        "//ledger/ledger-api-auth",
        "//libs-scala/adjustable-clock",
        "//libs-scala/jwt",
        "//libs-scala/ledger-resources",
        "//libs-scala/ports",
        "//libs-scala/resources",
        "//libs-scala/rs-grpc-bridge",
        "//libs-scala/scala-utils",
        "//libs-scala/test-evidence/scalatest:test-evidence-scalatest",
        "//libs-scala/test-evidence/tag:test-evidence-tag",
        "//libs-scala/testing-utils",
        "//observability/metrics",
        "@maven//:com_auth0_java_jwt",
        "@maven//:com_typesafe_config",
    ],
)

View File

@ -1,25 +0,0 @@
<configuration>
  <!-- Console appender: emits Logstash JSON when LOG_FORMAT_JSON is set,
       plain text otherwise.  NOTE(review): the <if> condition requires the
       Janino conditional-processing support on the classpath - confirm. -->
  <appender name="console" class="ch.qos.logback.core.ConsoleAppender">
    <if condition='isDefined("LOG_FORMAT_JSON")'>
      <then>
        <encoder class="net.logstash.logback.encoder.LogstashEncoder"/>
      </then>
      <else>
        <encoder>
          <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
      </else>
    </if>
  </appender>
  <!-- Async wrapper around the console appender. -->
  <appender name="STDOUT" class="net.logstash.logback.appender.LoggingEventAsyncDisruptorAppender">
    <appender-ref ref="console"/>
  </appender>
  <!-- Quieten chatty networking libraries. -->
  <logger name="io.netty" level="WARN"/>
  <logger name="io.grpc.netty" level="WARN"/>
  <!-- Root level overridable via the LOG_LEVEL_ROOT environment variable. -->
  <root level="${LOG_LEVEL_ROOT:-INFO}">
    <appender-ref ref="STDOUT" />
  </root>
</configuration>

View File

@ -1,15 +0,0 @@
// Template producing the OAuth2 authorization-request payload.
// Maps Daml ledger claims onto custom scope strings.
local scope(claims) =
  // `if` without `else` yields null when the claim is absent; the
  // std.join(" ", ...) below is then relied on to drop those nulls --
  // NOTE(review): confirm null handling against the jsonnet std library
  // version in use.
  local admin = if claims.admin then "admin";
  local applicationId = if claims.applicationId != null then "applicationId:" + claims.applicationId;
  local actAs = std.map(function(p) "actAs:" + p, claims.actAs);
  local readAs = std.map(function(p) "readAs:" + p, claims.readAs);
  [admin, applicationId] + actAs + readAs;

// config: middleware configuration (clientId, ...);
// request: login request (redirectUri, claims, state).
function(config, request) {
  "audience": "https://daml.com/ledger-api",
  "client_id": config.clientId,
  "redirect_uri": request.redirectUri,
  "response_type": "code",
  "scope": std.join(" ", ["offline_access"] + scope(request.claims)),
  "state": request.state,
}

View File

@ -1,6 +0,0 @@
// Template producing the token-refresh request payload sent to the OAuth2
// provider's token endpoint.
// Fix: RFC 6749 section 6 defines the grant type for refreshing an access
// token as "refresh_token"; "refresh_code" is not a registered grant type
// and standard-compliant providers reject it.
function(config, request) {
  "client_id": config.clientId,
  "client_secret": config.clientSecret,
  "grant_type": "refresh_token",
  "refresh_token": request.refreshToken,
}

View File

@ -1,7 +0,0 @@
// Template producing the authorization-code exchange payload sent to the
// OAuth2 provider's token endpoint (RFC 6749 section 4.1.3).
// config: middleware configuration (clientId, clientSecret);
// request: callback data (code, redirectUri).
function(config, request) {
  "client_id": config.clientId,
  "client_secret": config.clientSecret,
  "code": request.code,
  "grant_type": "authorization_code",
  "redirect_uri": request.redirectUri,
}

View File

@ -1,95 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.jwt
import java.nio.file.Paths
import com.auth0.jwt.algorithms.Algorithm
import scala.util.Try
/** Registers the standard `--auth-jwt-*` CLI options on a scopt parser,
  * each constructing a [[JwtVerifierBase]] and handing it to the caller.
  */
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
object JwtVerifierConfigurationCli {
  /** Adds all JWT-verifier options to `parser`.
    *
    * @param parser the scopt parser to extend
    * @param setter stores the constructed verifier into the config value `C`
    *
    * Note: each `.action` calls `sys.error` (throws) if verifier
    * construction fails, aborting option parsing.
    */
  def parse[C](parser: scopt.OptionParser[C])(setter: (JwtVerifierBase, C) => C): Unit = {
    def setJwtVerifier(jwtVerifier: JwtVerifierBase, c: C): C = setter(jwtVerifier, c)

    import parser.opt

    // Shared-secret HMAC256 -- testing only, hence hidden() and "unsafe".
    opt[String]("auth-jwt-hs256-unsafe")
      .optional()
      .hidden()
      .validate(v => Either.cond(v.nonEmpty, (), "HMAC secret must be a non-empty string"))
      .text(
        "[UNSAFE] Enables JWT-based authorization with shared secret HMAC256 signing: USE THIS EXCLUSIVELY FOR TESTING"
      )
      .action { (secret, config) =>
        val verifier = HMAC256Verifier(secret)
          .valueOr(err => sys.error(s"Failed to create HMAC256 verifier: $err"))
        setJwtVerifier(verifier, config)
      }

    // RSA256 with the public key read from an X509 certificate file.
    opt[String]("auth-jwt-rs256-crt")
      .optional()
      .validate(
        validatePath(_, "The certificate file specified via --auth-jwt-rs256-crt does not exist")
      )
      .text(
        "Enables JWT-based authorization, where the JWT is signed by RSA256 with a public key loaded from the given X509 certificate file (.crt)"
      )
      .action { (path, config) =>
        val verifier = RSA256Verifier
          .fromCrtFile(path)
          .valueOr(err => sys.error(s"Failed to create RSA256 verifier: $err"))
        setJwtVerifier(verifier, config)
      }

    // ECDSA256 with the public key read from an X509 certificate file.
    opt[String]("auth-jwt-es256-crt")
      .optional()
      .validate(
        validatePath(_, "The certificate file specified via --auth-jwt-es256-crt does not exist")
      )
      .text(
        "Enables JWT-based authorization, where the JWT is signed by ECDSA256 with a public key loaded from the given X509 certificate file (.crt)"
      )
      .action { (path, config) =>
        val verifier = ECDSAVerifier
          .fromCrtFile(path, Algorithm.ECDSA256(_, null))
          .valueOr(err => sys.error(s"Failed to create ECDSA256 verifier: $err"))
        setJwtVerifier(verifier, config)
      }

    // ECDSA512 with the public key read from an X509 certificate file.
    opt[String]("auth-jwt-es512-crt")
      .optional()
      .validate(
        validatePath(_, "The certificate file specified via --auth-jwt-es512-crt does not exist")
      )
      .text(
        "Enables JWT-based authorization, where the JWT is signed by ECDSA512 with a public key loaded from the given X509 certificate file (.crt)"
      )
      .action { (path, config) =>
        val verifier = ECDSAVerifier
          .fromCrtFile(path, Algorithm.ECDSA512(_, null))
          .valueOr(err => sys.error(s"Failed to create ECDSA512 verifier: $err"))
        setJwtVerifier(verifier, config)
      }

    // RSA256 with keys fetched from a JWKS endpoint.
    opt[String]("auth-jwt-rs256-jwks")
      .optional()
      .validate(v => Either.cond(v.length > 0, (), "JWK server URL must be a non-empty string"))
      .text(
        "Enables JWT-based authorization, where the JWT is signed by RSA256 with a public key loaded from the given JWKS URL"
      )
      .action { (url, config) =>
        val verifier = JwksVerifier(url)
        setJwtVerifier(verifier, config)
      }
    ()
  }

  // Returns Right(()) when `path` names a readable file, Left(message) otherwise
  // (exceptions from Paths.get are treated as "not readable").
  private def validatePath(path: String, message: String): Either[String, Unit] = {
    val valid = Try(Paths.get(path).toFile.canRead).getOrElse(false)
    if (valid) Right(()) else Left(message)
  }
}

View File

@ -1,376 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.api
import java.util.UUID
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.http.scaladsl.Http
import org.apache.pekko.http.scaladsl.marshalling.Marshal
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import org.apache.pekko.http.scaladsl.model.Uri.Path
import org.apache.pekko.http.scaladsl.model.{
HttpMethods,
HttpRequest,
MediaTypes,
RequestEntity,
StatusCode,
StatusCodes,
Uri,
headers,
}
import org.apache.pekko.http.scaladsl.server.{
ContentNegotiator,
Directive,
Directive0,
Directive1,
Route,
StandardRoute,
}
import org.apache.pekko.http.scaladsl.server.Directives._
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshal
import com.daml.auth.middleware.api.Client.{AuthException, RedirectToLogin, RefreshException}
import com.daml.auth.middleware.api.Tagged.RefreshToken
import scala.collection.immutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.FiniteDuration
/** Client component for interaction with the auth middleware
*
* A client of the auth middleware is typically itself a web-application that serves HTTP requests.
* Note, a [[Client]] maintains state that needs to persist across such requests.
* In particular, you should not create the [[Client]] instance within a [[Route]].
*
* This may pose a challenge when the client application uses dynamic port binding,
* e.g. for testing purposes, as the login URI's redirect parameter may depend on the port.
* To that end the login and request handler components that may depend on the port
* are provided in a separate [[com.daml.auth.middleware.api.Client.Routes]] class
* that does not need to maintain state across requests and can safely be constructed
* within a [[Route]] using the [[routes]] family of methods.
*/
class Client(config: Client.Config) {
private val callbacks: RequestStore[UUID, Response.Login => Route] = new RequestStore(
config.maxAuthCallbacks,
config.authCallbackTimeout,
)
/** Create a [[Client.Routes]] based on an absolute login redirect URI.
*/
def routes(callbackUri: Uri): Client.Routes = {
assert(
callbackUri.isAbsolute,
"The authorization middleware client callback URI must be absolute.",
)
RoutesImpl(callbackUri)
}
/** Create a [[Client.Routes]] based on a path to be appended to the requests URI scheme and authority.
*
* E.g. given `callbackPath = "cb"` and a request to `http://my.client/foo/bar`
* the redirect URI would take the form `http://my.client/cb`.
*/
def routesFromRequestAuthority(callbackPath: Uri.Path): Directive1[Client.Routes] =
extractUri.map { reqUri =>
RoutesImpl(
Uri()
.withScheme(reqUri.scheme)
.withAuthority(reqUri.authority)
.withPath(callbackPath)
): Client.Routes
}
/** Equivalent to [[routes]] if [[callbackUri]] is an absolute URI,
* otherwise equivalent to [[routesFromRequestAuthority]].
*/
def routesAuto(callbackUri: Uri): Directive1[Client.Routes] = {
if (callbackUri.isAbsolute) {
provide(routes(callbackUri))
} else {
routesFromRequestAuthority(callbackUri.path)
}
}
private case class RoutesImpl(callbackUri: Uri) extends Client.Routes {
val callbackHandler: Route =
parameters(Symbol("state").as[UUID]) { requestId =>
callbacks.pop(requestId) match {
case None =>
complete(StatusCodes.NotFound)
case Some(callback) =>
Response.Login.callbackParameters { callback }
}
}
private val isHtmlRequest: Directive1[Boolean] = extractRequest.map { req =>
val negotiator = ContentNegotiator(req.headers)
val contentTypes = List(
ContentNegotiator.Alternative(MediaTypes.`application/json`),
ContentNegotiator.Alternative(MediaTypes.`text/html`),
)
val preferred = negotiator.pickContentType(contentTypes)
preferred.map(_.mediaType) == Some(MediaTypes.`text/html`)
}
/** Pass control to the inner directive if we should redirect to login on auth failure, reject otherwise.
*/
private val onRedirectToLogin: Directive0 =
config.redirectToLogin match {
case RedirectToLogin.No => reject
case RedirectToLogin.Yes => pass
case RedirectToLogin.Auto =>
isHtmlRequest.flatMap {
case false => reject
case true => pass
}
}
def authorize(claims: Request.Claims): Directive1[Client.AuthorizeResult] = {
auth(claims).flatMap {
// Authorization successful - pass token to continuation
case Some(authorization) => provide(Client.Authorized(authorization))
// Authorization failed - login and retry on callback request.
case None =>
onRedirectToLogin
.tflatMap { _ =>
// Ensure that the request is fully uploaded.
val timeout = config.httpEntityUploadTimeout
val maxBytes = config.maxHttpEntityUploadSize
toStrictEntity(timeout, maxBytes).tflatMap { _ =>
extractRequestContext.flatMap { ctx =>
Directive { (inner: Tuple1[Client.AuthorizeResult] => Route) =>
def continue(result: Client.AuthorizeResult): Route =
mapRequestContext(_ => ctx) {
inner(Tuple1(result))
}
val callback: Response.Login => Route = {
case Response.LoginSuccess =>
auth(claims) {
case None => continue(Client.Unauthorized)
case Some(authorization) => continue(Client.Authorized(authorization))
}
case loginError: Response.LoginError =>
continue(Client.LoginFailed(loginError))
}
login(claims, callback)
}
}
}
}
.or(unauthorized(claims))
}
}
/** This directive attempts to obtain an access token from the middleware's auth endpoint for the given claims.
*
* Forwards the current request's cookies. Completes with 500 on an unexpected response from the auth middleware.
*
* @return `None` if the request was denied otherwise `Some` access and optionally refresh token.
*/
private def auth(claims: Request.Claims): Directive1[Option[Response.Authorize]] =
extractExecutionContext.flatMap { implicit ec =>
extractActorSystem.flatMap { implicit system =>
extract(_.request.headers[headers.Cookie]).flatMap { cookies =>
onSuccess(requestAuth(claims, cookies))
}
}
}
/** Return a 401 Unauthorized response.
*
* Includes a `WWW-Authenticate` header with a custom challenge to login at the auth middleware.
* Lists the required claims in the `realm` and the login URI in the `login` parameter
* and the auth URI in the `auth` parameter.
*
* The challenge is also included in the response body
* as some browsers make it difficult to access the `WWW-Authenticate` header.
*/
private def unauthorized(claims: Request.Claims): StandardRoute = {
import com.daml.auth.middleware.api.JsonProtocol.responseAuthenticateChallengeFormat
val challenge = Response.AuthenticateChallenge(
claims,
loginUri(claims, None, false),
authUri(claims),
)
complete(
status = StatusCodes.Unauthorized,
headers = immutable.Seq(challenge.toHeader),
challenge,
)
}
def login(claims: Request.Claims, callback: Response.Login => Route): Route = {
  // Associate a fresh request id with the continuation; the callback handler
  // resumes it once the login flow completes.
  val requestId = UUID.randomUUID()
  if (!callbacks.put(requestId, callback)) {
    // Callback store is at capacity (see Client.Config.maxAuthCallbacks).
    complete(StatusCodes.ServiceUnavailable)
  } else {
    redirect(loginUri(claims, Some(requestId)), StatusCodes.Found)
  }
}
def loginUri(
    claims: Request.Claims,
    requestId: Option[UUID] = None,
    redirect: Boolean = true,
): Uri = {
  // Only hand the middleware a redirect target when the flow should return
  // to our own callback endpoint.
  val redirectTarget = Option.when(redirect)(callbackUri)
  appendToUri(
    config.authMiddlewareExternalUri,
    Path./("login"),
    Request.Login(redirectTarget, claims, requestId.map(_.toString)).toQuery,
  )
}
}
/** Request authentication/authorization on the auth middleware's auth endpoint.
*
* @return `None` if the request was denied otherwise `Some` access and optionally refresh token.
*/
def requestAuth(claims: Request.Claims, cookies: immutable.Seq[headers.Cookie])(implicit
    ec: ExecutionContext,
    system: ActorSystem,
): Future[Option[Response.Authorize]] = {
  val request = HttpRequest(
    method = HttpMethods.GET,
    uri = authUri(claims),
    headers = cookies,
  )
  Http().singleRequest(request).flatMap { response =>
    response.status match {
      case StatusCodes.OK =>
        // 200 OK carries the access token (and optional refresh token).
        import JsonProtocol.responseAuthorizeFormat
        Unmarshal(response.entity).to[Response.Authorize].map(Some(_))
      case StatusCodes.Unauthorized =>
        // The middleware denied the requested claims.
        Future.successful(None)
      case unexpected =>
        // Anything else is an error; surface the response body in the exception.
        Unmarshal(response).to[String].flatMap { body =>
          Future.failed(AuthException(unexpected, body))
        }
    }
  }
}
/** Request a token refresh on the auth middleware's refresh endpoint.
*/
def requestRefresh(
    refreshToken: RefreshToken
)(implicit ec: ExecutionContext, system: ActorSystem): Future[Response.Authorize] = {
  import JsonProtocol.requestRefreshFormat
  // Marshal the refresh token, POST it to the refresh endpoint and
  // unmarshal the resulting token pair.
  Marshal(Request.Refresh(refreshToken))
    .to[RequestEntity]
    .flatMap { requestEntity =>
      Http().singleRequest(
        HttpRequest(
          method = HttpMethods.POST,
          uri = refreshUri,
          entity = requestEntity,
        )
      )
    }
    .flatMap { response =>
      response.status match {
        case StatusCodes.OK =>
          import JsonProtocol._
          Unmarshal(response.entity).to[Response.Authorize]
        case unexpected =>
          // Surface the response body of an unexpected status in the exception.
          Unmarshal(response).to[String].flatMap { body =>
            Future.failed(RefreshException(unexpected, body))
          }
      }
    }
}
// Append `path` to the URI's path and merge `query` into its query string,
// keeping any parameters already present on `uri`.
private def appendToUri(uri: Uri, path: Uri.Path, query: Uri.Query = Uri.Query.Empty): Uri = {
  val mergedParams: Seq[(String, String)] = uri.query().toSeq ++ query.toSeq
  uri
    .withPath(uri.path ++ path)
    .withQuery(Uri.Query(mergedParams: _*))
}
// URI of the middleware's `/auth` endpoint for the given claims, based on
// the internally reachable middleware address.
def authUri(claims: Request.Claims): Uri =
  appendToUri(
    config.authMiddlewareInternalUri,
    Path./("auth"),
    Request.Auth(claims).toQuery,
  )

// URI of the middleware's `/refresh` endpoint (internal address).
val refreshUri: Uri =
  appendToUri(config.authMiddlewareInternalUri, Path./("refresh"))
}
object Client {
  /** Outcome of a [[Routes.authorize]] directive. */
  sealed trait AuthorizeResult
  /** The middleware granted a token for the requested claims. */
  final case class Authorized(authorization: Response.Authorize) extends AuthorizeResult
  /** The middleware denied the requested claims. */
  object Unauthorized extends AuthorizeResult
  /** The login flow failed before authorization could be decided. */
  final case class LoginFailed(loginError: Response.LoginError) extends AuthorizeResult

  abstract class ClientException(message: String) extends RuntimeException(message)
  /** Unexpected response from the middleware's auth endpoint. */
  case class AuthException(status: StatusCode, message: String)
      extends ClientException(s"Failed to authorize with middleware ($status): $message")
  /** Unexpected response from the middleware's refresh endpoint. */
  case class RefreshException(status: StatusCode, message: String)
      extends ClientException(s"Failed to refresh token on middleware ($status): $message")

  /** Whether to automatically redirect to the login endpoint when authorization fails.
    *
    * [[RedirectToLogin.Auto]] redirects for HTML requests (`text/html`)
    * and returns 401 Unauthorized for JSON requests (`application/json`).
    */
  sealed trait RedirectToLogin
  object RedirectToLogin {
    object No extends RedirectToLogin
    object Yes extends RedirectToLogin
    object Auto extends RedirectToLogin
  }

  /** Client configuration.
    *
    * @param authMiddlewareInternalUri Middleware address used by the client itself (auth/refresh requests).
    * @param authMiddlewareExternalUri Middleware address used in URIs handed to the caller (login redirects).
    * @param redirectToLogin Behavior on failed authorization, see [[RedirectToLogin]].
    * @param maxAuthCallbacks Capacity of the stored login-callback continuations.
    * @param authCallbackTimeout Eviction timeout for stored login callbacks — presumably
    *   backing the callback store's expiry; confirm against the Client constructor.
    * @param maxHttpEntityUploadSize Maximum request entity size when strictifying requests.
    * @param httpEntityUploadTimeout Timeout when strictifying request entities.
    */
  case class Config(
      authMiddlewareInternalUri: Uri,
      authMiddlewareExternalUri: Uri,
      redirectToLogin: RedirectToLogin,
      maxAuthCallbacks: Int,
      authCallbackTimeout: FiniteDuration,
      maxHttpEntityUploadSize: Long,
      httpEntityUploadTimeout: FiniteDuration,
  )

  trait Routes {
    /** Handler for the callback in a login flow.
      *
      * Note, a GET request on the `callbackUri` must map to this route.
      */
    def callbackHandler: Route

    /** This directive requires authorization for the given claims via the auth middleware.
      *
      * Authorization follows the steps defined in `triggers/service/authentication.md`.
      * 1. Ask for a token on the `/auth` endpoint and return it if granted.
      * 2a. Return 401 Unauthorized if denied and [[Client.Config.redirectToLogin]]
      *     indicates not to redirect to the login endpoint.
      * 2b. Redirect to the login endpoint if denied and [[Client.Config.redirectToLogin]]
      *     indicates to redirect to the login endpoint.
      *     In this case this will store the current continuation to proceed
      *     once the login flow completed and authentication succeeded.
      *     A route for the [[callbackHandler]] must be configured.
      */
    def authorize(claims: Request.Claims): Directive1[Client.AuthorizeResult]

    /** Redirect the client to login with the auth middleware.
      *
      * Will respond with 503 if the callback store is full ([[Client.Config.maxAuthCallbacks]]).
      *
      * @param callback Will be stored and executed once the login flow completed.
      */
    def login(claims: Request.Claims, callback: Response.Login => Route): Route

    /** URI of the middleware's login endpoint for the given claims. */
    def loginUri(
        claims: Request.Claims,
        requestId: Option[UUID] = None,
        redirect: Boolean = true,
    ): Uri
  }

  def apply(config: Config): Client = new Client(config)
}

View File

@ -1,226 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.api
import org.apache.pekko.http.scaladsl.marshalling.Marshaller
import org.apache.pekko.http.scaladsl.model.headers.HttpChallenge
import org.apache.pekko.http.scaladsl.model.{HttpHeader, Uri, headers}
import org.apache.pekko.http.scaladsl.server.Directive1
import org.apache.pekko.http.scaladsl.server.Directives._
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshaller
import com.daml.lf.data.Ref
import scalaz.{@@, Tag}
import spray.json._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent._
import scala.util.Try
import scala.language.postfixOps
/** Scalaz tagged-type wrappers that distinguish access and refresh tokens
  * from plain strings at compile time.
  */
object Tagged {
  sealed trait AccessTokenTag
  // An OAuth2 access token (tagged String).
  type AccessToken = String @@ AccessTokenTag
  val AccessToken = Tag.of[AccessTokenTag]

  sealed trait RefreshTokenTag
  // An OAuth2 refresh token (tagged String).
  type RefreshToken = String @@ RefreshTokenTag
  val RefreshToken = Tag.of[RefreshTokenTag]
}
object Request {
  import Tagged._

  /** Ledger claims required by a request.
    *
    * applicationId = None makes no guarantees about the application ID. You can use this
    * if you don't use the token for requests that use the application ID.
    * applicationId = Some(appId) will return a token that is valid for
    * appId, i.e., either a wildcard token or a token with applicationId set to appId.
    */
  case class Claims(
      admin: Boolean,
      actAs: List[Ref.Party],
      readAs: List[Ref.Party],
      applicationId: Option[Ref.ApplicationId],
  ) {
    /** Render the claims as the space-separated string parsed by `Claims(String)`. */
    def toQueryString() = {
      val adminS = if (admin) LazyList("admin") else LazyList()
      val actAsS = actAs.to(LazyList).map(party => s"actAs:$party")
      val readAsS = readAs.to(LazyList).map(party => s"readAs:$party")
      val applicationIdS = applicationId.toList.to(LazyList).map(appId => s"applicationId:$appId")
      (adminS ++ actAsS ++ readAsS ++ applicationIdS).mkString(" ")
    }
  }
  object Claims {
    def apply(
        admin: Boolean = false,
        actAs: List[Ref.Party] = List(),
        readAs: List[Ref.Party] = List(),
        applicationId: Option[Ref.ApplicationId] = None,
    ): Claims =
      new Claims(admin, actAs, readAs, applicationId)

    /** Parse a space-separated claims string, e.g. `"admin actAs:Alice readAs:Bob"`.
      *
      * Empty words are ignored, so an empty string or consecutive spaces parse
      * cleanly and `Claims(c.toQueryString())` round-trips for any claims `c`.
      *
      * @throws IllegalArgumentException on an unrecognized word or a repeated
      *   applicationId claim (and from `assertFromString` on invalid identifiers).
      */
    def apply(s: String): Claims = {
      var admin = false
      val actAs = ArrayBuffer[Ref.Party]()
      val readAs = ArrayBuffer[Ref.Party]()
      var applicationId: Option[Ref.ApplicationId] = None
      s.split(' ').foreach { w =>
        if (w.isEmpty) {
          // Skip empty words: `"".split(' ')` yields `Array("")` and repeated
          // spaces also produce empty words. Previously an empty claims string
          // (as produced by `Claims().toQueryString()`) failed to parse.
          ()
        } else if (w == "admin") {
          admin = true
        } else if (w.startsWith("actAs:")) {
          actAs.append(Ref.Party.assertFromString(w.stripPrefix("actAs:")))
        } else if (w.startsWith("readAs:")) {
          readAs.append(Ref.Party.assertFromString(w.stripPrefix("readAs:")))
        } else if (w.startsWith("applicationId:")) {
          applicationId match {
            case None =>
              applicationId =
                Some(Ref.ApplicationId.assertFromString(w.stripPrefix("applicationId:")))
            case Some(_) =>
              throw new IllegalArgumentException(
                "applicationId claim can only be specified once"
              )
          }
        } else {
          throw new IllegalArgumentException(s"Expected claim but got $w")
        }
      }
      Claims(admin, actAs.toList, readAs.toList, applicationId)
    }

    implicit val marshalRequestEntity: Marshaller[Claims, String] =
      Marshaller.opaque(_.toQueryString())
    // Parse failures surface as failed futures via Try.
    implicit val unmarshalHttpEntity: Unmarshaller[String, Claims] =
      Unmarshaller { _ => s => Future.fromTry(Try(apply(s))) }
  }

  /** Auth endpoint query parameters. */
  case class Auth(claims: Claims) {
    def toQuery: Uri.Query = Uri.Query("claims" -> claims.toQueryString())
  }

  /** Login endpoint query parameters.
    *
    * @param redirectUri Redirect target after the login flow completed. I.e. the original request URI on the trigger service.
    * @param claims Required ledger claims.
    * @param state State that will be forwarded to the callback URI after authentication and authorization.
    */
  case class Login(redirectUri: Option[Uri], claims: Claims, state: Option[String]) {
    def toQuery: Uri.Query = {
      var params = Seq("claims" -> claims.toQueryString())
      redirectUri.foreach(x => params ++= Seq("redirect_uri" -> x.toString()))
      state.foreach(x => params ++= Seq("state" -> x))
      Uri.Query(params: _*)
    }
  }

  /** Refresh endpoint request entity. */
  case class Refresh(refreshToken: RefreshToken)
}
object Response {
  import Tagged._

  /** Successful authorization: an access token plus an optional refresh token. */
  case class Authorize(accessToken: AccessToken, refreshToken: Option[RefreshToken])

  /** Result of a login flow as reported to the callback endpoint. */
  sealed abstract class Login
  final case class LoginError(error: String, errorDescription: Option[String]) extends Login
  object LoginSuccess extends Login
  object Login {
    // Extract the login result from the callback request's query parameters:
    // an `error` parameter (with optional `error_description`) signals failure,
    // its absence signals success.
    val callbackParameters: Directive1[Login] =
      parameters(Symbol("error"), Symbol("error_description") ?)
        .as[LoginError](LoginError)
        .or(provide(LoginSuccess))
  }

  // Challenge scheme name used in the `WWW-Authenticate` header.
  val authenticateChallengeName: String = "DamlAuthMiddleware"

  /** Custom `WWW-Authenticate` challenge pointing clients at the middleware.
    *
    * @param realm The required claims (in claims query-string syntax).
    * @param login The middleware's login URI.
    * @param auth The middleware's auth URI.
    */
  case class AuthenticateChallenge(
      realm: Request.Claims,
      login: Uri,
      auth: Uri,
  ) {
    // Render the challenge as a `WWW-Authenticate` header with the login and
    // auth URIs carried as challenge parameters.
    def toHeader: HttpHeader = headers.`WWW-Authenticate`(
      HttpChallenge(
        authenticateChallengeName,
        realm.toQueryString(),
        Map(
          "login" -> login.toString(),
          "auth" -> auth.toString(),
        ),
      )
    )
  }
}
/** spray-json formats for the middleware's request and response types. */
object JsonProtocol extends DefaultJsonProtocol {
  import Tagged._

  // URIs are serialized as plain JSON strings.
  implicit object UriFormat extends JsonFormat[Uri] {
    def read(value: JsValue) = value match {
      case JsString(s) => Uri(s)
      case _ => deserializationError(s"Expected Uri string but got $value")
    }
    def write(uri: Uri) = JsString(uri.toString)
  }

  // Tagged access tokens (de)serialize as their underlying string.
  implicit object AccessTokenJsonFormat extends JsonFormat[AccessToken] {
    def write(x: AccessToken) = {
      JsString(AccessToken.unwrap(x))
    }
    def read(value: JsValue) = value match {
      case JsString(x) => AccessToken(x)
      case x => deserializationError(s"Expected AccessToken as JsString, but got $x")
    }
  }

  // Tagged refresh tokens (de)serialize as their underlying string.
  implicit object RefreshTokenJsonFormat extends JsonFormat[RefreshToken] {
    def write(x: RefreshToken) = {
      JsString(RefreshToken.unwrap(x))
    }
    def read(value: JsValue) = value match {
      case JsString(x) => RefreshToken(x)
      case x => deserializationError(s"Expected RefreshToken as JsString, but got $x")
    }
  }

  // Claims use their space-separated query-string syntax as the JSON encoding.
  implicit object RequestClaimsFormat extends JsonFormat[Request.Claims] {
    def write(claims: Request.Claims) = {
      JsString(claims.toQueryString())
    }
    def read(value: JsValue) = value match {
      case JsString(s) =>
        try {
          Request.Claims(s)
        } catch {
          // Claims parsing signals failure with IllegalArgumentException;
          // translate to a spray-json deserialization error.
          case ex: IllegalArgumentException =>
            deserializationError(ex.getMessage, ex)
        }
      case x => deserializationError(s"Expected Claims as JsString, but got $x")
    }
  }

  implicit val requestRefreshFormat: RootJsonFormat[Request.Refresh] =
    jsonFormat(Request.Refresh, "refresh_token")
  implicit val responseAuthorizeFormat: RootJsonFormat[Response.Authorize] =
    jsonFormat(Response.Authorize, "access_token", "refresh_token")

  // LoginSuccess is encoded as JSON null, LoginError as an error object.
  implicit object ResponseLoginFormat extends RootJsonFormat[Response.Login] {
    implicit private val errorFormat: RootJsonFormat[Response.LoginError] = jsonFormat2(
      Response.LoginError
    )
    def write(login: Response.Login) = login match {
      case error: Response.LoginError => error.toJson
      case Response.LoginSuccess => JsNull
    }
    def read(value: JsValue) = value.convertTo(safeReader[Response.LoginError]) match {
      case Right(error) => error
      case Left(_) =>
        value match {
          case JsNull => Response.LoginSuccess
          case _ => deserializationError(s"Expected null or error object but got $value")
        }
    }
  }

  implicit val responseAuthenticateChallengeFormat: RootJsonFormat[Response.AuthenticateChallenge] =
    jsonFormat(Response.AuthenticateChallenge, "realm", "login", "auth")
}

View File

@ -1,73 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.api
import scala.collection.mutable.LinkedHashMap
import scala.concurrent.duration.FiniteDuration
/** A key-value store with a maximum capacity and maximum storage duration.
  * @param maxCapacity Maximum number of requests that can be stored.
  * @param timeout Duration after which requests will be evicted.
  * @param monotonicClock Determines the current timestamp. The underlying clock must be monotonic.
  *   The JVM will use a monotonic clock for [[System.nanoTime]], if available, according to
  *   [[https://bugs.openjdk.java.net/browse/JDK-6458294?focusedCommentId=13823604&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-13823604 JDK bug 6458294]]
  * @tparam K The key type
  * @tparam V The value type
  */
private[middleware] class RequestStore[K, V](
    maxCapacity: Int,
    timeout: FiniteDuration,
    monotonicClock: () => Long = () => System.nanoTime,
) {

  /** Mapping from key to insertion timestamp and value.
    * The timestamp of later inserted elements must be greater or equal to the timestamp of earlier inserted elements.
    */
  private val store: LinkedHashMap[K, (Long, V)] = LinkedHashMap.empty

  /** Check whether the given `timestamp` timed out relative to the current time `now`. */
  private def timedOut(now: Long, timestamp: Long): Boolean =
    now - timestamp >= timeout.toNanos

  /** Remove all entries whose timestamp timed out relative to `now`.
    * Entries are in insertion order and timestamps are non-decreasing,
    * so only a leading prefix of the map can be expired.
    */
  private def evictTimedOut(now: Long): Unit = {
    // Materialize the expired keys before removing them: the previous
    // implementation interleaved `store.remove` with a live iterator over
    // `store`, and mutating a mutable LinkedHashMap while iterating over it
    // is not supported.
    val expired: List[K] = store.iterator
      .takeWhile { case (_, (t, _)) => timedOut(now, t) }
      .map { case (k, _) => k }
      .toList
    expired.foreach(store.remove)
  }

  /** Insert a new key-value pair unless the maximum capacity is reached.
    * Evicts timed out elements before attempting insertion.
    * @return whether the key-value pair was inserted.
    */
  def put(key: K, value: => V): Boolean =
    synchronized {
      val now = monotonicClock()
      evictTimedOut(now)
      if (store.size >= maxCapacity) {
        false
      } else {
        store.update(key, (now, value))
        true
      }
    }

  /** Remove and return the value under the given key, if present and not timed out. */
  def pop(key: K): Option[V] =
    synchronized {
      store.remove(key).flatMap {
        case (t, _) if timedOut(monotonicClock(), t) => None
        case (_, v) => Some(v)
      }
    }
}

View File

@ -1,241 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import org.apache.pekko.http.scaladsl.model.Uri
import com.daml.auth.middleware.oauth2.Config.{
DefaultCookieSecure,
DefaultHttpPort,
DefaultLoginTimeout,
DefaultMaxLoginRequests,
}
import com.daml.cliopts
import com.daml.jwt.{JwtVerifierBase, JwtVerifierConfigurationCli}
import com.daml.metrics.api.reporters.MetricsReporter
import com.typesafe.scalalogging.StrictLogging
import pureconfig.ConfigSource
import pureconfig.error.ConfigReaderFailures
import scopt.OptionParser
import java.io.File
import java.nio.file.{Path, Paths}
import scala.concurrent.duration
import scala.concurrent.duration.FiniteDuration
import scalaz.syntax.std.option._
/** Configuration of the OAuth2 middleware as collected from the command line.
  * Prefer providing a config file via `--config`; the remaining options are deprecated.
  */
private[oauth2] final case class Cli(
    configFile: Option[File] = None,
    // Host and port the middleware listens on
    address: String = cliopts.Http.defaultAddress,
    port: Int = DefaultHttpPort,
    portFile: Option[Path] = None,
    // The URI to which the OAuth2 server will redirect after a completed login flow.
    // Must map to the `/cb` endpoint of the auth middleware.
    callbackUri: Option[Uri] = None,
    maxLoginRequests: Int = DefaultMaxLoginRequests,
    loginTimeout: FiniteDuration = DefaultLoginTimeout,
    cookieSecure: Boolean = DefaultCookieSecure,
    // OAuth2 server endpoints
    oauthAuth: Uri,
    oauthToken: Uri,
    // OAuth2 server request templates
    oauthAuthTemplate: Option[Path],
    oauthTokenTemplate: Option[Path],
    oauthRefreshTemplate: Option[Path],
    // OAuth2 client properties
    clientId: String,
    clientSecret: SecretString,
    // Token verification
    tokenVerifier: JwtVerifierBase,
    metricsReporter: Option[MetricsReporter] = None,
    metricsReportingInterval: FiniteDuration = FiniteDuration(10, duration.SECONDS),
) extends StrictLogging {

  /** Load the configuration from the `--config` file, if one was given. */
  def loadFromConfigFile: Option[Either[ConfigReaderFailures, FileConfig]] = {
    configFile.map(cf => ConfigSource.file(cf).load[FileConfig])
  }

  /** Build a [[Config]] from the (deprecated) CLI arguments and validate it.
    * @throws IllegalArgumentException if a required field is missing (see Config.validate).
    */
  def loadFromCliArgs: Config = {
    val cfg = Config(
      address,
      port,
      portFile,
      callbackUri,
      maxLoginRequests,
      loginTimeout,
      cookieSecure,
      oauthAuth,
      oauthToken,
      oauthAuthTemplate,
      oauthTokenTemplate,
      oauthRefreshTemplate,
      clientId,
      clientSecret,
      tokenVerifier,
      metricsReporter,
      metricsReportingInterval,
      Seq.empty,
    )
    cfg.validate()
    cfg
  }

  /** Resolve the final configuration: from the config file when given,
    * otherwise from the deprecated CLI arguments.
    * @return `None` if a config file was given but failed to load.
    */
  def loadConfig: Option[Config] = {
    loadFromConfigFile.cata(
      {
        case Right(cfg) => Some(cfg.toConfig())
        case Left(ex) =>
          // Embed the failure details in the message itself: scala-logging's
          // `error(message, args*)` uses SLF4J `{}` placeholders, so passing
          // `ex.prettyPrint()` as a bare second argument (as before) silently
          // dropped the failure details from the log output.
          logger.error(
            s"Error loading oauth2-middleware config from file $configFile:\n${ex.prettyPrint()}"
          )
          None
      }, {
        logger.warn("Using cli opts for running oauth2-middleware is deprecated")
        Some(loadFromCliArgs)
      },
    )
  }
}
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
private[oauth2] object Cli {
  // Baseline Cli value for the parser. Required fields default to `null` here;
  // they are enforced by the `checkConfig` blocks below (CLI path) or by the
  // config-file loader.
  private[oauth2] val Default =
    Cli(
      configFile = None,
      address = cliopts.Http.defaultAddress,
      port = DefaultHttpPort,
      portFile = None,
      callbackUri = None,
      maxLoginRequests = DefaultMaxLoginRequests,
      loginTimeout = DefaultLoginTimeout,
      cookieSecure = DefaultCookieSecure,
      oauthAuth = null,
      oauthToken = null,
      oauthAuthTemplate = None,
      oauthTokenTemplate = None,
      oauthRefreshTemplate = None,
      clientId = null,
      clientSecret = null,
      tokenVerifier = null,
    )

  // scopt parser for the command line. Note that most options are deprecated
  // in favor of the `--config` file.
  private val parser: OptionParser[Cli] = new scopt.OptionParser[Cli]("oauth-middleware") {
    help('h', "help").text("Print usage")

    opt[Option[File]]('c', "config")
      .text(
        "This is the recommended way to provide an app config file, the remaining cli-args are deprecated"
      )
      .valueName("<file>")
      .action((file, cli) => cli.copy(configFile = file))

    // Standard HTTP server options (address, port, port file).
    cliopts.Http.serverParse(this, serviceName = "OAuth2 Middleware")(
      address = (f, c) => c.copy(address = f(c.address)),
      httpPort = (f, c) => c.copy(port = f(c.port)),
      defaultHttpPort = Some(DefaultHttpPort),
      portFile = Some((f, c) => c.copy(portFile = f(c.portFile))),
    )

    opt[String]("callback")
      .action((x, c) => c.copy(callbackUri = Some(Uri(x))))
      .text(
        "URI to the auth middleware's callback endpoint `/cb`. By default constructed from the incoming login request."
      )

    opt[Int]("max-pending-login-requests")
      .action((x, c) => c.copy(maxLoginRequests = x))
      .text(
        "Maximum number of simultaneously pending login requests. Requests will be denied when exceeded until earlier requests have been completed or timed out."
      )

    opt[Boolean]("cookie-secure")
      .action((x, c) => c.copy(cookieSecure = x))
      .text(
        "Enable the Secure attribute on the cookie that stores the token. Defaults to true. Only disable this for testing and development purposes."
      )

    opt[Long]("login-request-timeout")
      .action((x, c) => c.copy(loginTimeout = FiniteDuration(x, duration.SECONDS)))
      .text(
        "Login request timeout. Requests will be evicted if the callback endpoint receives no corresponding request in time."
      )

    // OAuth2 server endpoints.
    opt[String]("oauth-auth")
      .action((x, c) => c.copy(oauthAuth = Uri(x)))
      .text("URI of the OAuth2 authorization endpoint")

    opt[String]("oauth-token")
      .action((x, c) => c.copy(oauthToken = Uri(x)))
      .text("URI of the OAuth2 token endpoint")

    // Optional Jsonnet templates overriding the builtin OAuth2 request shapes.
    opt[String]("oauth-auth-template")
      .action((x, c) => c.copy(oauthAuthTemplate = Some(Paths.get(x))))
      .text("OAuth2 authorization request Jsonnet template")

    opt[String]("oauth-token-template")
      .action((x, c) => c.copy(oauthTokenTemplate = Some(Paths.get(x))))
      .text("OAuth2 token request Jsonnet template")

    opt[String]("oauth-refresh-template")
      .action((x, c) => c.copy(oauthRefreshTemplate = Some(Paths.get(x))))
      .text("OAuth2 refresh request Jsonnet template")

    // Client credentials; hidden options that fall back to environment variables.
    opt[String]("id")
      .hidden()
      .action((x, c) => c.copy(clientId = x))
      .withFallback(() => sys.env.getOrElse("DAML_CLIENT_ID", ""))

    opt[String]("secret")
      .hidden()
      .action((x, c) => c.copy(clientSecret = SecretString(x)))
      .withFallback(() => sys.env.getOrElse("DAML_CLIENT_SECRET", ""))

    // Metrics reporter and reporting interval options.
    cliopts.Metrics.metricsReporterParse(this)(
      (f, c) => c.copy(metricsReporter = f(c.metricsReporter)),
      (f, c) => c.copy(metricsReportingInterval = f(c.metricsReportingInterval)),
    )

    // --auth-jwt-* token verification options.
    JwtVerifierConfigurationCli.parse(this)((v, c) => c.copy(tokenVerifier = v))

    // The following requirements only apply to the deprecated CLI path,
    // i.e. when no config file is given.
    checkConfig { cfg =>
      if (cfg.configFile.isEmpty && cfg.tokenVerifier == null)
        Left("You must specify one of the --auth-jwt-* flags for token verification.")
      else
        Right(())
    }

    checkConfig { cfg =>
      if (cfg.configFile.isEmpty && (cfg.clientId.isEmpty || cfg.clientSecret.value.isEmpty))
        Left("Environment variable DAML_CLIENT_ID AND DAML_CLIENT_SECRET must not be empty")
      else
        Right(())
    }

    checkConfig { cfg =>
      if (cfg.configFile.isEmpty && (cfg.oauthAuth == null || cfg.oauthToken == null))
        Left("oauth-auth and oauth-token values must not be empty")
      else
        Right(())
    }

    // Reject mixing a config file with the deprecated CLI options.
    checkConfig { cfg =>
      val cliOptionsAreDefined =
        cfg.oauthToken != null || cfg.oauthAuth != null || cfg.tokenVerifier != null
      if (cfg.configFile.isDefined && cliOptionsAreDefined) {
        Left("Found both config file and cli opts for the app, please provide only one of them")
      } else Right(())
    }

    override def showUsageOnError: Option[Boolean] = Some(true)
  }

  // Parse the command line into a Cli value (None on parse error).
  def parse(args: Array[String]): Option[Cli] = parser.parse(args, Default)

  // Parse the command line and resolve it into the final Config.
  def parseConfig(args: Array[String]): Option[Config] = {
    val cli = parse(args)
    cli.flatMap(_.loadConfig)
  }
}

View File

@ -1,114 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import java.nio.file.Path
import org.apache.pekko.http.scaladsl.model.Uri
import com.daml.auth.middleware.oauth2.Config._
import com.daml.cliopts
import com.daml.jwt.JwtVerifierBase
import com.daml.metrics.{HistogramDefinition, MetricsConfig}
import com.daml.metrics.api.reporters.MetricsReporter
import com.daml.pureconfigutils.SharedConfigReaders._
import pureconfig.{ConfigReader, ConvertHelpers}
import pureconfig.generic.semiauto.deriveReader
import scala.concurrent.duration._
/** Runtime configuration of the OAuth2 middleware. */
final case class Config(
    // Host and port the middleware listens on
    address: String = cliopts.Http.defaultAddress,
    port: Int = DefaultHttpPort,
    portFile: Option[Path] = None,
    // The URI to which the OAuth2 server will redirect after a completed login flow.
    // Must map to the `/cb` endpoint of the auth middleware.
    callbackUri: Option[Uri] = None,
    maxLoginRequests: Int = DefaultMaxLoginRequests,
    loginTimeout: FiniteDuration = DefaultLoginTimeout,
    cookieSecure: Boolean = DefaultCookieSecure,
    // OAuth2 server endpoints
    oauthAuth: Uri,
    oauthToken: Uri,
    // OAuth2 server request templates
    oauthAuthTemplate: Option[Path] = None,
    oauthTokenTemplate: Option[Path] = None,
    oauthRefreshTemplate: Option[Path] = None,
    // OAuth2 client properties
    clientId: String,
    clientSecret: SecretString,
    // Token verification
    tokenVerifier: JwtVerifierBase,
    metricsReporter: Option[MetricsReporter] = None,
    metricsReportingInterval: FiniteDuration = 10.seconds,
    histograms: Seq[HistogramDefinition],
) {
  /** Fail fast on missing required settings. Fields can be `null` because the
    * deprecated CLI path constructs this class via placeholders (see `Cli.Default`).
    * @throws IllegalArgumentException (via `require`) if a required field is null or empty.
    */
  def validate(): Unit = {
    require(oauthToken != null, "Oauth token value on config cannot be null")
    require(oauthAuth != null, "Oauth auth value on config cannot be null")
    require(clientId.nonEmpty, "DAML_CLIENT_ID cannot be empty")
    require(clientSecret.value.nonEmpty, "DAML_CLIENT_SECRET cannot be empty")
    require(tokenVerifier != null, "token verifier must be defined")
  }
}
@scala.annotation.nowarn("msg=Block result was adapted via implicit conversion")
object FileConfig {
  // Read the client secret from a plain config string, wrapping it in
  // SecretString so it is masked when printed.
  implicit val clientSecretReader: ConfigReader[SecretString] =
    ConfigReader.fromString[SecretString](ConvertHelpers.catchReadError(s => SecretString(s)))
  // Derived pureconfig reader for the whole config file.
  implicit val cfgReader: ConfigReader[FileConfig] = deriveReader[FileConfig]
}
/** Shape of the on-disk configuration file as read by pureconfig.
  * Converted to the runtime [[Config]] via [[toConfig]].
  */
final case class FileConfig(
    address: String = cliopts.Http.defaultAddress,
    port: Int = DefaultHttpPort,
    portFile: Option[Path] = None,
    callbackUri: Option[Uri] = None,
    maxLoginRequests: Int = DefaultMaxLoginRequests,
    loginTimeout: FiniteDuration = DefaultLoginTimeout,
    cookieSecure: Boolean = DefaultCookieSecure,
    oauthAuth: Uri,
    oauthToken: Uri,
    oauthAuthTemplate: Option[Path] = None,
    oauthTokenTemplate: Option[Path] = None,
    oauthRefreshTemplate: Option[Path] = None,
    clientId: String,
    clientSecret: SecretString,
    tokenVerifier: JwtVerifierBase,
    metrics: Option[MetricsConfig] = None,
) {
  /** Translate the file representation into the runtime [[Config]],
    * expanding the optional `metrics` section into its individual fields
    * (falling back to the default reporting interval when absent).
    */
  def toConfig(): Config = Config(
    address = address,
    port = port,
    portFile = portFile,
    callbackUri = callbackUri,
    maxLoginRequests = maxLoginRequests,
    loginTimeout = loginTimeout,
    cookieSecure = cookieSecure,
    oauthAuth = oauthAuth,
    oauthToken = oauthToken,
    oauthAuthTemplate = oauthAuthTemplate,
    oauthTokenTemplate = oauthTokenTemplate,
    oauthRefreshTemplate = oauthRefreshTemplate,
    clientId = clientId,
    clientSecret = clientSecret,
    tokenVerifier = tokenVerifier,
    metricsReporter = metrics.map(_.reporter),
    metricsReportingInterval =
      metrics.map(_.reportingInterval).getOrElse(MetricsConfig.DefaultMetricsReportingInterval),
    histograms = metrics.toList.flatMap(_.histograms),
  )
}
/** Wrapper for secret values whose `toString` is masked so the secret does
  * not leak into logs or error messages.
  */
final case class SecretString(value: String) {
  override def toString: String = "###"
}
/** Default values shared by [[Config]], [[FileConfig]] and the CLI parser. */
object Config {
  val DefaultHttpPort: Int = 3000
  val DefaultCookieSecure: Boolean = true
  val DefaultMaxLoginRequests: Int = 100
  val DefaultLoginTimeout: FiniteDuration = 5.minutes
}

View File

@ -1,46 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import org.apache.pekko.actor.ActorSystem
import com.daml.scalautil.Statement.discard
import com.typesafe.scalalogging.StrictLogging
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext}
import scala.util.{Failure, Success}
/** Entry point of the OAuth2 middleware server. */
object Main extends StrictLogging {
  // Parse configuration (config file or deprecated CLI args) and start the
  // server; exit with status 1 on a configuration error.
  def main(args: Array[String]): Unit = {
    Cli.parseConfig(args) match {
      case Some(config) => main(config)
      case None => sys.exit(1)
    }
  }
  private def main(config: Config): Unit = {
    implicit val system: ActorSystem = ActorSystem("system")
    implicit val executionContext: ExecutionContext = system.dispatcher
    // Shut down the actor system, blocking up to 10 seconds for termination.
    def terminate() = Await.result(system.terminate(), 10.seconds)
    val bindingFuture = Server.start(config, registerGlobalOpenTelemetry = true)
    // On JVM shutdown, stop the server first, then the actor system.
    discard(
      sys.addShutdownHook(
        Server.stop(bindingFuture).onComplete(_ => terminate())
      )
    )
    logger.debug(s"Configuration $config")
    bindingFuture.onComplete {
      case Success(binding) =>
        logger.info(s"Started server: $binding")
      case Failure(e) =>
        // Startup failed: log the cause and tear down the actor system.
        logger.error(s"Failed to start server: $e")
        terminate()
    }
  }
}

View File

@ -1,16 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import com.daml.metrics.api.opentelemetry.OpenTelemetryMetricsFactory
import com.daml.metrics.http.DamlHttpMetrics
import io.opentelemetry.api.metrics.{Meter => OtelMeter}
/** Metrics of the OAuth2 middleware, backed by an OpenTelemetry meter. */
case class Oauth2MiddlewareMetrics(otelMeter: OtelMeter) {
  val openTelemetryFactory = new OpenTelemetryMetricsFactory(otelMeter)
  // HTTP request metrics labelled with the "oauth2-middleware" service name.
  val http = new DamlHttpMetrics(openTelemetryFactory, "oauth2-middleware")
}

View File

@ -1,111 +0,0 @@
# Trigger Service OAuth 2.0 Middleware
Implements an OAuth2 middleware according to the trigger service
authentication/authorization specification in
`triggers/service/authentication.md`.
## Manual Testing against Auth0
Apart from the automated tests defined in this repository, the middleware can
be tested manually against an auth0 OAuth2 setup. The necessary steps are
extracted and adapted from the [Secure Daml Infrastructure
repository](https://github.com/digital-asset/ex-secure-daml-infra).
### Setup
* Sign up for an account on [Auth0](https://auth0.com).
* Create a new API.
- Provide a name (`ex-daml-api`).
- Provide an Identifier (`https://daml.com/ledger-api`).
- Select Signing Algorithm of `RS256`.
- Allow offline access to enable refresh token generation.
This allows the OAuth2 client, i.e. the auth middleware, to request access through a refresh token when the resource owner, i.e. the user, is offline.
* Create a new native application.
- Provide a name (`ex-daml-auth-middleware`).
- Select the authorized API (`ex-daml-api`).
- Configure the allowed callback URLs in the settings (`http://localhost:3000/cb`).
- Note the "Client ID" and "Client Secret" displayed in the "Basic
Information" pane of the application settings.
- Note the "OAuth Authorization URL" and the "OAuth Token URL" in the
"Endpoints" tab of the advanced settings.
* Create a new empty rule.
- Provide a name (`ex-daml-claims`).
- Provide a script
``` javascript
function (user, context, callback) {
// Only handle ledger-api audience.
const audience = context.request.query && context.request.query.audience || "";
if (audience !== "https://daml.com/ledger-api") {
return callback(null, user, context);
}
// Grant all requested claims
const scope = (context.request.query && context.request.query.scope || "").split(" ");
var actAs = [];
var readAs = [];
var admin = false;
scope.forEach(s => {
if (s.startsWith("actAs:")) {
actAs.push(s.slice(6));
} else if (s.startsWith("readAs:")) {
readAs.push(s.slice(7));
} else if (s === "admin") {
admin = true;
}
});
// Construct access token.
const namespace = 'https://daml.com/ledger-api';
context.accessToken[namespace] = {
// NOTE change the ledger ID to match your deployment.
"ledgerId": "2D105384-CE61-4CCC-8E0E-37248BA935A3",
"actAs": actAs,
"readAs": readAs,
"admin": admin
};
return callback(null, user, context);
}
```
* Create a new user.
- Provide an email address (`alice@localhost`)
- Provide a secure password
- Mark the email address as verified on the user's "Details" page.
### Testing
* Start the middleware by executing the following command.
```
$ DAML_CLIENT_ID=CLIENTID \
DAML_CLIENT_SECRET=CLIENTSECRET \
bazel run //triggers/service/auth:oauth-middleware-binary -- \
--config oauth-middleware.conf
```
- Replace `CLIENTID` and `CLIENTSECRET` by the "Client ID" and "Client
Secret" from above.
The basic minimal config that needs to be supplied needs to have appropriate
`callback-uri`,`oauth-auth` and `oauth-token` urls defined,
along with the `token-verifier`,`client-id` and `client-secret` fields. e.g
```
{
callback-uri = "https://example.com/auth/cb"
oauth-auth = "https://XYZ.auth0.com/authorize"
oauth-token = "https://XYZ.auth0.com/oauth/token"
client-id = ${DAML_CLIENT_ID}
client-secret = ${DAML_CLIENT_SECRET}
// type can be one of rs256-crt, es256-crt, es512-crt, rs256-jwks
// uri is the uri to the cert file or the jwks url
token-verifier {
type = "rs256-jwks"
uri = "https://example.com/.well-known/jwks.json"
}
}
```
- Browse to the middleware's login endpoint.
- URL `http://localhost:3000/login?redirect_uri=callback&claims=actAs:Alice`
- Login as the new user created above.
- Authorize the middleware application to access the tenant.
- You should be redirected to `http://localhost:3000/callback`.

View File

@ -1,171 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import java.nio.file.Path
import java.util.UUID
import org.apache.pekko.http.scaladsl.model.Uri
import com.daml.auth.middleware.api.Request
import com.daml.auth.middleware.api.Tagged.RefreshToken
import scala.collection.concurrent.TrieMap
import scala.io.{BufferedSource, Source}
import scala.util.Try
// Renders the OAuth2 request parameters (authorization, token, refresh) by
// evaluating Jsonnet templates. Each template is loaded either from a
// user-supplied file path or, if absent, from a builtin JAR resource.
private[oauth2] class RequestTemplates(
clientId: String,
clientSecret: SecretString,
authTemplate: Option[Path],
tokenTemplate: Option[Path],
refreshTemplate: Option[Path],
) {
// Builtin fallback templates bundled as JAR resources.
private val authResourcePath: String = "auth0_request_authorization.jsonnet"
private val tokenResourcePath: String = "auth0_request_token.jsonnet"
private val refreshResourcePath: String = "auth0_request_refresh.jsonnet"
/** Load a Jsonnet source file.
* @param optFilePath Load from this file path, if provided.
* @param resourcePath Load from this JAR resource, if no file is provided.
* @return Content and file path (for error reporting) of the loaded Jsonnet file.
*/
private def jsonnetSource(
optFilePath: Option[Path],
resourcePath: String,
): (String, sjsonnet.Path) = {
// Read the whole source and always close the underlying handle.
def readSource(source: BufferedSource): String = {
try { source.mkString }
finally { source.close() }
}
optFilePath match {
case Some(filePath) =>
val content: String = readSource(Source.fromFile(filePath.toString))
val path: sjsonnet.Path = sjsonnet.OsPath(os.Path(filePath.toAbsolutePath))
(content, path)
case None =>
val resource = getClass.getResource(resourcePath)
val content: String = readSource(Source.fromInputStream(resource.openStream()))
// This path is only used for error reporting and a builtin template should not raise any errors.
// However, if it does it should be clear that the path refers to a builtin file.
// Paths are reported relative to `$PWD`, we prefix `$PWD` to avoid `../../` noise.
val path: sjsonnet.Path =
sjsonnet.OsPath(os.RelPath(s"BUILTIN/$resourcePath").resolveFrom(os.pwd))
(content, path)
}
}
// Shared parse cache across interpreter runs, keyed by source text.
private val jsonnetParseCache
: TrieMap[String, fastparse.Parsed[(sjsonnet.Expr, Map[String, Int])]] = TrieMap.empty
/** Interpret the given Jsonnet code.
* @param source The Jsonnet source code.
* @param sourcePath The Jsonnet source file path (for error reporting).
* @param arguments Top-level arguments to pass to the Jsonnet code.
* @return The resulting JSON value.
*/
private def interpretJsonnet(
source: String,
sourcePath: sjsonnet.Path,
arguments: Map[String, ujson.Value],
): Try[ujson.Value] = {
val interp = new sjsonnet.Interpreter(
jsonnetParseCache,
Map(),
arguments,
sjsonnet.OsPath(os.pwd),
importer = sjsonnet.SjsonnetMain.resolveImport(Nil, None),
)
// Interpretation errors surface as InterpretTemplateException inside the Try.
interp
.interpret(source, sourcePath)
.left
.map(new RequestTemplates.InterpretTemplateException(_))
.toTry
}
/** Convert a JSON value to a string mapping representing request parameters.
*/
private def toRequestParams(value: ujson.Value): Try[Map[String, String]] =
Try(value.obj.view.mapValues(_.str).toMap)
// Evaluate one template with the given top-level arguments and flatten the
// resulting JSON object into request parameters.
private def createRequest(
template: (String, sjsonnet.Path),
args: Map[String, ujson.Value],
): Try[Map[String, String]] = {
val (jsonnet_src, jsonnet_path) = template
interpretJsonnet(jsonnet_src, jsonnet_path, args).flatMap(toRequestParams)
}
// Client credentials passed to every template as the `config` argument.
private lazy val config: ujson.Value = ujson.Obj(
"clientId" -> clientId,
"clientSecret" -> clientSecret.value,
)
private lazy val authJsonnetSource: (String, sjsonnet.Path) =
jsonnetSource(authTemplate, authResourcePath)
// Arguments for the authorization template: requested claims, redirect URI,
// and the request id carried through OAuth2 as `state`.
private def authArguments(
claims: Request.Claims,
requestId: UUID,
redirectUri: Uri,
): Map[String, ujson.Value] =
Map(
"config" -> config,
"request" -> ujson.Obj(
"claims" -> ujson.Obj(
"admin" -> claims.admin,
"applicationId" -> (claims.applicationId match {
case Some(appId) => appId
case None => ujson.Null
}),
"actAs" -> claims.actAs,
"readAs" -> claims.readAs,
),
"redirectUri" -> redirectUri.toString,
"state" -> requestId.toString,
),
)
// Parameters for the authorization-endpoint request.
def createAuthRequest(
claims: Request.Claims,
requestId: UUID,
redirectUri: Uri,
): Try[Map[String, String]] = {
createRequest(authJsonnetSource, authArguments(claims, requestId, redirectUri))
}
private lazy val tokenJsonnetSource: (String, sjsonnet.Path) =
jsonnetSource(tokenTemplate, tokenResourcePath)
private def tokenArguments(code: String, redirectUri: Uri): Map[String, ujson.Value] = Map(
"config" -> config,
"request" -> ujson.Obj(
"code" -> code,
"redirectUri" -> redirectUri.toString,
),
)
// Parameters for exchanging an authorization code for a token.
def createTokenRequest(code: String, redirectUri: Uri): Try[Map[String, String]] =
createRequest(tokenJsonnetSource, tokenArguments(code, redirectUri))
private lazy val refreshJsonnetSource: (String, sjsonnet.Path) =
jsonnetSource(refreshTemplate, refreshResourcePath)
private def refreshArguments(refreshToken: RefreshToken): Map[String, ujson.Value] = Map(
"config" -> config,
"request" -> ujson.Obj(
"refreshToken" -> RefreshToken.unwrap(refreshToken)
),
)
// Parameters for the refresh-token grant request.
def createRefreshRequest(refreshToken: RefreshToken): Try[Map[String, String]] =
createRequest(refreshJsonnetSource, refreshArguments(refreshToken))
}
object RequestTemplates {
// Raised when evaluating a Jsonnet request template fails.
class InterpretTemplateException(msg: String) extends RuntimeException(msg)

// Factory forwarding to the package-private constructor.
def apply(
clientId: String,
clientSecret: SecretString,
authTemplate: Option[Path],
tokenTemplate: Option[Path],
refreshTemplate: Option[Path],
): RequestTemplates =
new RequestTemplates(
clientId = clientId,
clientSecret = clientSecret,
authTemplate = authTemplate,
tokenTemplate = tokenTemplate,
refreshTemplate = refreshTemplate,
)
}

View File

@ -1,395 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import org.apache.pekko.Done
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.http.scaladsl.Http
import org.apache.pekko.http.scaladsl.Http.ServerBinding
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import org.apache.pekko.http.scaladsl.model._
import org.apache.pekko.http.scaladsl.model.headers.{HttpCookie, HttpCookiePair}
import org.apache.pekko.http.scaladsl.server.{Directive1, Route}
import org.apache.pekko.http.scaladsl.server.Directives._
import org.apache.pekko.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
import com.daml.auth.oauth2.api.{JsonProtocol => OAuthJsonProtocol, Response => OAuthResponse}
import com.daml.ledger.api.{auth => lapiauth}
import com.daml.ledger.resources.ResourceContext
import com.daml.metrics.api.reporters.MetricsReporting
import com.daml.metrics.pekkohttp.HttpMetricsInterceptor
import com.typesafe.scalalogging.StrictLogging
import java.util.UUID
import com.daml.auth.middleware.api.{Request, RequestStore, Response}
import com.daml.jwt.{JwtDecoder, JwtVerifierBase}
import com.daml.jwt.domain.Jwt
import com.daml.auth.middleware.api.Tagged.{AccessToken, RefreshToken}
import com.daml.ports.{Port, PortFiles}
import scalaz.{-\/, \/-}
import spray.json._
import scala.concurrent.{ExecutionContext, Future}
import scala.language.postfixOps
import scala.util.{Failure, Success, Try}
// This is an implementation of the trigger service auth middleware
// for OAuth2 as specified in `/triggers/service/authentication.md`
// HTTP server exposing the middleware routes: /auth, /login, /cb (OAuth2
// callback), /refresh plus /livez and /readyz health endpoints. Tokens are
// stored client-side in a base64-encoded cookie.
class Server(config: Config) extends StrictLogging {
import com.daml.auth.middleware.api.JsonProtocol._
import com.daml.auth.oauth2.api.JsonProtocol._
import Server.rightsProvideClaims
implicit private val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
// Use the configured callback URI if present, otherwise derive `/cb` from
// the incoming request's scheme and authority.
private def toRedirectUri(uri: Uri) =
config.callbackUri.getOrElse {
Uri()
.withScheme(uri.scheme)
.withAuthority(uri.authority)
.withPath(Uri.Path./("cb"))
}
private val cookieName = "daml-ledger-token"
// Extract and decode the token cookie, if any.
private def optionalToken: Directive1[Option[OAuthResponse.Token]] = {
def f(x: HttpCookiePair) = OAuthResponse.Token.fromCookieValue(x.value)
optionalCookie(cookieName).map(_.flatMap(f))
}
// Check whether the provided token's signature is valid.
private def tokenIsValid(accessToken: String, verifier: JwtVerifierBase): Boolean = {
verifier.verify(Jwt(accessToken)).isRight
}
// Check whether the provided access token grants at least the requested claims.
private def tokenProvidesClaims(accessToken: String, claims: Request.Claims): Boolean = {
for {
decodedJwt <- JwtDecoder.decode(Jwt(accessToken)).toOption
tokenPayload <- lapiauth.AuthServiceJWTCodec
.readFromString(decodedJwt.payload)
.toOption
} yield rightsProvideClaims(tokenPayload, claims)
} getOrElse false
// Jsonnet-based templates used to build the OAuth2 request parameters.
private val requestTemplates: RequestTemplates = RequestTemplates(
config.clientId,
config.clientSecret,
config.oauthAuthTemplate,
config.oauthTokenTemplate,
config.oauthRefreshTemplate,
)
// Turn a template result into a directive: 500 on template failure,
// otherwise provide the rendered request parameters.
private def onTemplateSuccess(
request: String,
tryParams: Try[Map[String, String]],
): Directive1[Map[String, String]] =
tryParams match {
case Failure(exception) =>
logger.error(s"Failed to interpret $request request template: ${exception.getMessage}")
complete(StatusCodes.InternalServerError, s"Failed to construct $request request")
case Success(params) => provide(params)
}
// GET /auth: succeed with the cookie's tokens iff the cookie token is valid
// and covers the requested claims; otherwise 401.
private val auth: Route =
parameters(Symbol("claims").as[Request.Claims])
.as[Request.Auth](Request.Auth) { auth =>
optionalToken {
case Some(token)
if tokenIsValid(token.accessToken, config.tokenVerifier) &&
tokenProvidesClaims(token.accessToken, auth.claims) =>
complete(
Response
.Authorize(
accessToken = AccessToken(token.accessToken),
refreshToken = RefreshToken.subst(token.refreshToken),
)
)
// TODO[AH] Include a `WWW-Authenticate` header.
case _ => complete(StatusCodes.Unauthorized)
}
}
// Pending login requests keyed by request id; bounded and time-limited.
private val requests: RequestStore[UUID, Option[Uri]] =
new RequestStore(config.maxLoginRequests, config.loginTimeout)
// GET /login: store the caller's redirect URI (with `state` appended) under
// a fresh request id and redirect to the OAuth2 authorization endpoint.
// 503 when the request store is full.
private val login: Route =
parameters(
Symbol("redirect_uri").as[Uri] ?,
Symbol("claims").as[Request.Claims],
Symbol("state") ?,
)
.as[Request.Login](Request.Login) { login =>
extractRequest { request =>
val requestId = UUID.randomUUID
val stored = requests.put(
requestId,
login.redirectUri.map { redirectUri =>
var query = redirectUri.query().to(Seq)
login.state.foreach(x => query ++= Seq("state" -> x))
redirectUri.withQuery(Uri.Query(query: _*))
},
)
if (stored) {
onTemplateSuccess(
"authorization",
requestTemplates.createAuthRequest(
login.claims,
requestId,
toRedirectUri(request.uri),
),
) { params =>
val query = Uri.Query(params)
val uri = config.oauthAuth.withQuery(query)
redirect(uri, StatusCodes.Found)
}
} else {
complete(StatusCodes.ServiceUnavailable)
}
}
}
// GET /cb: OAuth2 redirect endpoint. On success exchanges the code for a
// token, stores it in the cookie and redirects back to the original caller;
// on error forwards the error to the caller's redirect URI (or 403).
private val loginCallback: Route = {
extractActorSystem { implicit sys =>
extractExecutionContext { implicit ec =>
// Resolve `state` back to the stored redirect URI; 404 for unknown ids.
def popRequest(optState: Option[String]): Directive1[Option[Uri]] = {
val redirectUri = for {
state <- optState
requestId <- Try(UUID.fromString(state)).toOption
redirectUri <- requests.pop(requestId)
} yield redirectUri
redirectUri match {
case Some(redirectUri) => provide(redirectUri)
case None => complete(StatusCodes.NotFound)
}
}
concat(
parameters(Symbol("code"), Symbol("state") ?)
.as[OAuthResponse.Authorize](OAuthResponse.Authorize) { authorize =>
popRequest(authorize.state) { redirectUri =>
extractRequest { request =>
onTemplateSuccess(
"token",
requestTemplates.createTokenRequest(authorize.code, toRedirectUri(request.uri)),
) { params =>
val entity = FormData(params).toEntity
val req = HttpRequest(
uri = config.oauthToken,
entity = entity,
method = HttpMethods.POST,
)
val tokenRequest =
for {
resp <- Http().singleRequest(req)
tokenResp <-
if (resp.status != StatusCodes.OK) {
Unmarshal(resp).to[String].flatMap { msg =>
Future.failed(
new RuntimeException(
s"Failed to retrieve token at ${req.uri} (${resp.status}): $msg"
)
)
}
} else {
Unmarshal(resp).to[OAuthResponse.Token]
}
} yield tokenResp
onSuccess(tokenRequest) { token =>
// Persist the token client-side; cookie expiry mirrors the
// token's expires_in when present.
setCookie(
HttpCookie(
name = cookieName,
value = token.toCookieValue,
path = Some("/"),
maxAge = token.expiresIn.map(_.toLong),
secure = config.cookieSecure,
httpOnly = true,
)
) {
redirectUri match {
case Some(uri) =>
redirect(uri, StatusCodes.Found)
case None =>
complete(StatusCodes.OK)
}
}
}
}
}
}
},
parameters(
Symbol("error"),
Symbol("error_description") ?,
Symbol("error_uri").as[Uri] ?,
Symbol("state") ?,
)
.as[OAuthResponse.Error](OAuthResponse.Error) { error =>
popRequest(error.state) {
case Some(redirectUri) =>
val uri = redirectUri.withQuery {
var params = redirectUri.query().to(Seq)
params ++= Seq("error" -> error.error)
error.errorDescription.foreach(x => params ++= Seq("error_description" -> x))
Uri.Query(params: _*)
}
redirect(uri, StatusCodes.Found)
case None =>
import OAuthJsonProtocol.errorRespFormat
complete(StatusCodes.Forbidden, error)
}
},
)
}
}
}
// POST /refresh: forward a refresh-token grant to the OAuth2 token endpoint.
// 200 yields new tokens, client errors are forwarded, anything else fails.
private val refresh: Route = {
extractActorSystem { implicit sys =>
extractExecutionContext { implicit ec =>
entity(as[Request.Refresh]) { refresh =>
onTemplateSuccess(
"refresh",
requestTemplates.createRefreshRequest(refresh.refreshToken),
) { params =>
val entity = FormData(params).toEntity
val req =
HttpRequest(uri = config.oauthToken, entity = entity, method = HttpMethods.POST)
val tokenRequest = Http().singleRequest(req)
onSuccess(tokenRequest) { resp =>
resp.status match {
// Return access and refresh token on success.
case StatusCodes.OK =>
val authResponse = Unmarshal(resp).to[OAuthResponse.Token].map { token =>
Response.Authorize(
accessToken = AccessToken(token.accessToken),
refreshToken = RefreshToken.subst(token.refreshToken),
)
}
complete(authResponse)
// Forward client errors.
case status: StatusCodes.ClientError =>
complete(HttpResponse.apply(status = status, entity = resp.entity))
// Fail on unexpected responses.
case _ =>
onSuccess(Unmarshal(resp).to[String]) { msg =>
failWith(
new RuntimeException(
s"Failed to retrieve refresh token (${resp.status}): $msg"
)
)
}
}
}
}
}
}
}
}
// Top-level route combining all endpoints.
def route: Route = concat(
path("auth") {
get {
auth
}
},
path("login") {
get {
login
}
},
path("cb") {
get {
loginCallback
}
},
path("refresh") {
post {
refresh
}
},
// Health endpoints always report "pass" while the server is running.
path("livez") {
complete(StatusCodes.OK, JsObject("status" -> JsString("pass")))
},
path("readyz") {
complete(StatusCodes.OK, JsObject("status" -> JsString("pass")))
},
)
}
object Server extends StrictLogging {
// Start the middleware: set up metrics reporting, bind the HTTP server on
// the configured address/port and, if configured, write a port file.
def start(config: Config, registerGlobalOpenTelemetry: Boolean)(implicit
sys: ActorSystem
): Future[ServerBinding] = {
implicit val ec: ExecutionContext = sys.getDispatcher
implicit val rc: ResourceContext = ResourceContext(ec)
val metricsReporting = new MetricsReporting(
getClass.getName,
config.metricsReporter,
config.metricsReportingInterval,
registerGlobalOpenTelemetry,
config.histograms,
)((_, otelMeter) => Oauth2MiddlewareMetrics(otelMeter))
val metricsResource = metricsReporting.acquire()
// Wrap the route with rate/duration/size metrics once metrics are up.
val rateDurationSizeMetrics = metricsResource.asFuture.map { implicit metrics =>
HttpMetricsInterceptor.rateDurationSizeMetrics(
metrics.http
)
}
val route = new Server(config).route
for {
metricsInterceptor <- rateDurationSizeMetrics
binding <- Http()
.newServerAt(config.address, config.port)
.bind(metricsInterceptor apply route)
// Expose the dynamically chosen port via a port file, if requested.
_ <- config.portFile match {
case Some(portFile) =>
PortFiles.write(portFile, Port(binding.localAddress.getPort)) match {
case -\/(err) =>
Future.failed(new RuntimeException(s"Failed to create port file: ${err.toString}"))
case \/-(()) => Future.successful(())
}
case None => Future.successful(())
}
} yield binding
}
// Unbind the HTTP server.
def stop(f: Future[ServerBinding])(implicit ec: ExecutionContext): Future[Done] =
f.flatMap(_.unbind())
// Decide whether a decoded token payload satisfies the requested claims.
// Custom payloads are checked for admin/actAs/readAs coverage; standard
// (user) payloads only have their user id compared to the applicationId claim.
private[oauth2] def rightsProvideClaims(
r: lapiauth.AuthServiceJWTPayload,
claims: Request.Claims,
): Boolean = {
val (precond, userId) = r match {
case tp: lapiauth.CustomDamlJWTPayload =>
(
(tp.admin || !claims.admin) &&
claims.actAs
.toSet[String]
.subsetOf(tp.actAs.toSet) &&
claims.readAs
// readAs is satisfied by either a readAs or an actAs right.
.toSet[String]
.subsetOf(tp.readAs.toSet ++ tp.actAs),
tp.applicationId,
)
case tp: lapiauth.StandardJWTPayload =>
// NB: in this mode we check the applicationId claim (if supplied)
// and ignore everything else
(true, Some(tp.userId))
}
precond && ((claims.applicationId, userId) match {
// No requirement on app id
case (None, _) => true
// Token valid for all app ids.
case (_, None) => true
case (Some(expectedAppId), Some(actualAppId)) => expectedAppId == actualAppId
})
}
}

View File

@ -1,190 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.api
import java.util.Base64
import org.apache.pekko.http.scaladsl.model.{FormData, HttpEntity, RequestEntity, Uri}
import org.apache.pekko.http.scaladsl.model.Uri.Query
import org.apache.pekko.http.scaladsl.marshalling._
import org.apache.pekko.http.scaladsl.unmarshalling._
import spray.json._
import scala.util.Try
// OAuth2 request messages (RFC 6749) with pekko-http (un)marshalling.
object Request {
// https://tools.ietf.org/html/rfc6749#section-4.1.1
case class Authorize(
responseType: String,
clientId: String,
redirectUri: Uri, // optional in oauth but we require it
scope: Option[String],
state: Option[String],
audience: Option[Uri],
) { // required by auth0 to obtain an access_token
// Render the authorization request as URL query parameters; optional
// fields are omitted when absent.
def toQuery: Query = {
var params: Seq[(String, String)] =
Seq(
("response_type", responseType),
("client_id", clientId),
("redirect_uri", redirectUri.toString),
)
scope.foreach { scope =>
params ++= Seq(("scope", scope))
}
state.foreach { state =>
params ++= Seq(("state", state))
}
audience.foreach { audience =>
params ++= Seq(("audience", audience.toString))
}
Query(params: _*)
}
}
// https://tools.ietf.org/html/rfc6749#section-4.1.3
case class Token(
grantType: String,
code: String,
redirectUri: Uri,
clientId: String,
clientSecret: String,
)
object Token {
// Marshal to application/x-www-form-urlencoded as required by the spec.
implicit val marshalRequestEntity: Marshaller[Token, RequestEntity] =
Marshaller.combined { token =>
FormData(
"grant_type" -> token.grantType,
"code" -> token.code,
"redirect_uri" -> token.redirectUri.toString,
"client_id" -> token.clientId,
"client_secret" -> token.clientSecret,
)
}
// NOTE(review): `.get` on each form field throws on missing fields — the
// unmarshaller assumes a well-formed token request.
implicit val unmarshalHttpEntity: Unmarshaller[HttpEntity, Token] =
Unmarshaller.defaultUrlEncodedFormDataUnmarshaller.map { form =>
Token(
grantType = form.fields.get("grant_type").get,
code = form.fields.get("code").get,
redirectUri = form.fields.get("redirect_uri").get,
clientId = form.fields.get("client_id").get,
clientSecret = form.fields.get("client_secret").get,
)
}
}
// https://tools.ietf.org/html/rfc6749#section-6
case class Refresh(
grantType: String,
refreshToken: String,
clientId: String,
clientSecret: String,
)
object Refresh {
implicit val marshalRequestEntity: Marshaller[Refresh, RequestEntity] =
Marshaller.combined { refresh =>
FormData(
"grant_type" -> refresh.grantType,
"refresh_token" -> refresh.refreshToken,
"client_id" -> refresh.clientId,
"client_secret" -> refresh.clientSecret,
)
}
// Same caveat as Token: missing form fields throw via `.get`.
implicit val unmarshalHttpEntity: Unmarshaller[HttpEntity, Refresh] =
Unmarshaller.defaultUrlEncodedFormDataUnmarshaller.map { form =>
Refresh(
grantType = form.fields.get("grant_type").get,
refreshToken = form.fields.get("refresh_token").get,
clientId = form.fields.get("client_id").get,
clientSecret = form.fields.get("client_secret").get,
)
}
}
}
// OAuth2 response messages (RFC 6749) plus cookie (de)serialization helpers.
object Response {
// https://tools.ietf.org/html/rfc6749#section-4.1.2
case class Authorize(code: String, state: Option[String]) {
// Render as redirect query parameters; `state` only when present.
def toQuery: Query = state match {
case None => Query(("code", code))
case Some(state) => Query(("code", code), ("state", state))
}
}
// https://tools.ietf.org/html/rfc6749#section-4.1.2.1
case class Error(
error: String,
errorDescription: Option[String],
errorUri: Option[Uri],
state: Option[String],
) {
def toQuery: Query = {
var params: Seq[(String, String)] = Seq("error" -> error)
errorDescription.foreach(x => params ++= Seq("error_description" -> x))
errorUri.foreach(x => params ++= Seq("error_uri" -> x.toString))
state.foreach(x => params ++= Seq("state" -> x))
Query(params: _*)
}
}
// https://tools.ietf.org/html/rfc6749#section-5.1
case class Token(
accessToken: String,
tokenType: String,
expiresIn: Option[Int],
refreshToken: Option[String],
scope: Option[String],
) {
// Encode the token as base64url(JSON) for storage in an HTTP cookie.
def toCookieValue: String = {
import JsonProtocol._
Base64.getUrlEncoder().encodeToString(this.toJson.compactPrint.getBytes)
}
}
object Token {
// Inverse of toCookieValue; None on any decode/parse failure.
def fromCookieValue(s: String): Option[Token] = {
import JsonProtocol._
for {
bytes <- Try(Base64.getUrlDecoder().decode(s))
json <- Try(new String(bytes).parseJson)
token <- Try(json.convertTo[Token])
} yield token
}.toOption
}
}
// spray-json formats for the OAuth2 response types, using the snake_case
// field names mandated by RFC 6749.
object JsonProtocol extends DefaultJsonProtocol {
// Uri as a plain JSON string.
implicit object UriFormat extends JsonFormat[Uri] {
def read(value: JsValue) = value match {
case JsString(s) => Uri(s)
case _ => deserializationError(s"Expected Uri string but got $value")
}
def write(uri: Uri) = JsString(uri.toString)
}
implicit val tokenRespFormat: RootJsonFormat[Response.Token] = {
jsonFormat(
Response.Token.apply,
"access_token",
"token_type",
"expires_in",
"refresh_token",
"scope",
)
}
implicit val errorRespFormat: RootJsonFormat[Response.Error] = {
jsonFormat(
Response.Error.apply,
"error",
"error_description",
"error_uri",
"state",
)
}
}

View File

@ -1,58 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.test.server
import java.time.Clock
import com.daml.ports.Port
// Configuration of the OAuth2 test authorization server.
case class Config(
// Port the authorization server listens on
port: Port,
// Ledger ID of issued tokens
ledgerId: String,
// Secret used to sign JWTs
jwtSecret: String,
// Use the provided clock instead of system time for token generation.
clock: Option[Clock],
// produce user tokens instead of claim tokens
yieldUserTokens: Boolean,
)
object Config {
// Placeholder defaults used as the scopt starting value; `null` fields are
// expected to be overwritten by the corresponding CLI options.
private val Empty =
Config(
port = Port.Dynamic,
ledgerId = null,
jwtSecret = null,
clock = None,
yieldUserTokens = false,
)
// Parse CLI arguments; None when parsing fails (scopt prints the error).
def parseConfig(args: collection.Seq[String]): Option[Config] =
configParser.parse(args, Empty)
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
val configParser: scopt.OptionParser[Config] =
new scopt.OptionParser[Config]("oauth-test-server") {
head("OAuth2 TestServer")
opt[Int]("port")
.action((x, c) => c.copy(port = Port(x)))
.required()
.text("Port to listen on")
opt[String]("ledger-id")
.action((x, c) => c.copy(ledgerId = x))
opt[String]("secret")
.action((x, c) => c.copy(jwtSecret = x))
opt[Unit]("yield-user-tokens")
.optional()
.action((_, c) => c.copy(yieldUserTokens = true))
help("help").text("Print this usage text")
}
}

View File

@ -1,45 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.test.server
import org.apache.pekko.actor.ActorSystem
import com.typesafe.scalalogging.StrictLogging
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext}
import scala.util.{Failure, Success}
// Entry point of the OAuth2 test authorization server.
object Main extends StrictLogging {
// Parse CLI args and start the server; exit code 1 on bad arguments.
def main(args: Array[String]): Unit = {
Config.parseConfig(args) match {
case Some(config) => main(config)
case None => sys.exit(1)
}
}
private def main(config: Config): Unit = {
implicit val system: ActorSystem = ActorSystem("system")
implicit val executionContext: ExecutionContext = system.dispatcher
// Block until the actor system has shut down (bounded by 10s).
def terminate() = Await.result(system.terminate(), 10.seconds)
val bindingFuture = Server(config).start()
// On JVM shutdown, unbind the server first, then stop the actor system.
sys.addShutdownHook {
Server
.stop(bindingFuture)
.onComplete { _ =>
terminate()
}
}
bindingFuture.onComplete {
case Success(binding) =>
logger.info(s"Started server: $binding")
case Failure(e) =>
logger.error(s"Failed to start server: $e")
terminate()
}
}
}

View File

@ -1,249 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.test.server
import java.time.Instant
import java.util.UUID
import org.apache.pekko.Done
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.http.scaladsl.Http
import org.apache.pekko.http.scaladsl.Http.ServerBinding
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import org.apache.pekko.http.scaladsl.model.{StatusCodes, Uri}
import org.apache.pekko.http.scaladsl.server.Directives._
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshaller
import com.daml.auth.oauth2.api.{Request, Response}
import com.daml.jwt.JwtSigner
import com.daml.jwt.domain.DecodedJwt
import com.daml.ledger.api.auth.{
AuthServiceJWTCodec,
AuthServiceJWTPayload,
CustomDamlJWTPayload,
StandardJWTPayload,
StandardJWTTokenFormat,
}
import com.daml.lf.data.Ref
import scala.collection.concurrent.TrieMap
import scala.concurrent.{ExecutionContext, Future}
import scala.language.postfixOps
import scala.util.{Failure, Success, Try}
// This is a test authorization server that implements the OAuth2 authorization code flow.
// This is primarily intended for use in the trigger service tests but could also serve
// as a useful ground for experimentation.
// Given scopes of the form `actAs:$party`, the authorization server will issue
// tokens with the respective claims. Requests for authorized parties will be accepted and
// request to /authorize are immediately redirected to the redirect_uri.
// Test OAuth2 authorization server implementing the authorization-code flow.
// /authorize immediately redirects back with a code (no consent screen);
// /token exchanges that code (or a refresh token) for a signed JWT.
class Server(config: Config) {
import Server.withExp
private val jwtHeader = """{"alg": "HS256", "typ": "JWT"}"""
// Issued tokens are valid for 24h.
val tokenLifetimeSeconds = 24 * 60 * 60
// Mutable test knobs controlling which grants are denied.
private var unauthorizedParties: Set[Ref.Party] = Set()
// Remove the given party from the set of unauthorized parties.
def authorizeParty(party: Ref.Party): Unit = {
unauthorizedParties = unauthorizedParties - party
}
// Add the given party to the set of unauthorized parties.
def revokeParty(party: Ref.Party): Unit = {
unauthorizedParties = unauthorizedParties + party
}
// Clear the set of unauthorized parties.
def resetAuthorizedParties(): Unit = {
unauthorizedParties = Set()
}
private var allowAdmin = true
def authorizeAdmin(): Unit = {
allowAdmin = true
}
def revokeAdmin(): Unit = {
allowAdmin = false
}
def resetAdmin(): Unit = {
allowAdmin = true
}
// To keep things as simple as possible, we use a UUID as the authorization code and refresh token
// and in the /authorize request we already pre-compute the JWT payload based on the scope.
// The token request then only does a lookup and signs the token.
private val requests = TrieMap.empty[UUID, AuthServiceJWTPayload]
// Expiry instant for a freshly issued token, honoring the configured clock.
private def tokenExpiry(): Instant = {
val now = config.clock match {
case Some(clock) => Instant.now(clock)
case None => Instant.now()
}
now.plusSeconds(tokenLifetimeSeconds.asInstanceOf[Long])
}
// Translate the space-separated scope (actAs:P, readAs:P, admin,
// applicationId:X) into a JWT payload; user vs. custom token depends on
// config.yieldUserTokens.
private def toPayload(req: Request.Authorize): AuthServiceJWTPayload = {
var actAs: Seq[String] = Seq()
var readAs: Seq[String] = Seq()
var admin: Boolean = false
var applicationId: Option[String] = None
req.scope.foreach(_.split(" ").foreach {
case s if s.startsWith("actAs:") => actAs ++= Seq(s.stripPrefix("actAs:"))
case s if s.startsWith("readAs:") => readAs ++= Seq(s.stripPrefix("readAs:"))
case s if s == "admin" => admin = true
// Given that this is only for testing,
// we don't guard against multiple application id claims.
case s if s.startsWith("applicationId:") =>
applicationId = Some(s.stripPrefix("applicationId:"))
case _ => ()
})
if (config.yieldUserTokens) // ignore everything but the applicationId
StandardJWTPayload(
issuer = None,
userId = applicationId getOrElse "",
participantId = None,
exp = None,
format = StandardJWTTokenFormat.Scope,
audiences = List.empty,
scope = Some("daml_ledger_api"),
)
else
CustomDamlJWTPayload(
ledgerId = Some(config.ledgerId),
applicationId = applicationId,
// Not required by the default auth service
participantId = None,
// Expiry is set when the token is retrieved
exp = None,
// no admin claim for now.
admin = admin,
actAs = actAs.toList,
readAs = readAs.toList,
)
}
// Whether the current configuration of unauthorized parties and admin rights allows to grant the given token payload.
private def authorize(payload: AuthServiceJWTPayload): Either[String, Unit] = payload match {
case payload: CustomDamlJWTPayload =>
val parties = (payload.readAs ++ payload.actAs).toSet
val deniedParties = parties & unauthorizedParties.toSet[String]
val deniedAdmin: Boolean = payload.admin && !allowAdmin
if (deniedParties.nonEmpty) {
Left(s"Access to parties ${deniedParties.mkString(" ")} denied")
} else if (deniedAdmin) {
Left("Admin access denied")
} else {
Right(())
}
// Standard (user) tokens are always granted.
case _: StandardJWTPayload => Right(())
}
import Request.Refresh.unmarshalHttpEntity
implicit val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
// GET /authorize and POST /token routes.
val route = concat(
path("authorize") {
get {
parameters(
Symbol("response_type"),
Symbol("client_id"),
Symbol("redirect_uri").as[Uri],
Symbol("scope") ?,
Symbol("state") ?,
Symbol("audience").as[Uri] ?,
)
.as[Request.Authorize](Request.Authorize) { request =>
val payload = toPayload(request)
authorize(payload) match {
// Denied: redirect back with an access_denied error.
case Left(msg) =>
val params =
Response
.Error(
error = "access_denied",
errorDescription = Some(msg),
errorUri = None,
state = request.state,
)
.toQuery
redirect(request.redirectUri.withQuery(params), StatusCodes.Found)
// Granted: stash the payload under a fresh code and redirect back.
case Right(()) =>
val authorizationCode = UUID.randomUUID()
val params =
Response
.Authorize(code = authorizationCode.toString, state = request.state)
.toQuery
requests.update(authorizationCode, payload)
// We skip any actual consent screen since this is only intended for testing and
// this is outside of the scope of the trigger service anyway.
redirect(request.redirectUri.withQuery(params), StatusCodes.Found)
}
}
}
},
path("token") {
post {
// Look up the code/refresh UUID, rotate the refresh token and return a
// signed access token with expiry. 400 on malformed UUIDs, 404 when
// the code/refresh token is unknown (codes are single-use).
def returnToken(uuid: String) =
Try(UUID.fromString(uuid)) match {
case Failure(_) =>
complete((StatusCodes.BadRequest, "Malformed code or refresh token"))
case Success(uuid) =>
requests.remove(uuid) match {
case Some(payload) =>
// Generate refresh token
val refreshCode = UUID.randomUUID()
requests.update(refreshCode, payload)
// Construct access token with expiry
val accessToken = JwtSigner.HMAC256
.sign(
DecodedJwt(
jwtHeader,
AuthServiceJWTCodec.compactPrint(withExp(payload, Some(tokenExpiry()))),
),
config.jwtSecret,
)
.getOrElse(throw new IllegalArgumentException("Failed to sign a token"))
.value
import com.daml.auth.oauth2.api.JsonProtocol._
complete(
Response.Token(
accessToken = accessToken,
refreshToken = Some(refreshCode.toString),
expiresIn = Some(tokenLifetimeSeconds),
scope = None,
tokenType = "bearer",
)
)
case None =>
complete(StatusCodes.NotFound)
}
}
// Accept either an authorization-code or a refresh-token grant body.
concat(
entity(as[Request.Token]) { request =>
returnToken(request.code)
},
entity(as[Request.Refresh]) { request =>
returnToken(request.refreshToken)
},
)
}
},
)
// Bind the route on localhost at the configured port.
def start()(implicit system: ActorSystem): Future[ServerBinding] = {
Http().newServerAt("localhost", config.port.value).bind(route)
}
}
object Server {
// Factory mirroring the class constructor.
def apply(config: Config): Server = new Server(config)

// Release the HTTP binding once the caller is done serving.
def stop(f: Future[ServerBinding])(implicit ec: ExecutionContext): Future[Done] =
f.flatMap(binding => binding.unbind())

// Stamp the given expiry onto whichever payload variant we hold.
private def withExp(payload: AuthServiceJWTPayload, exp: Option[Instant]): AuthServiceJWTPayload =
payload match {
case custom: CustomDamlJWTPayload => custom.copy(exp = exp)
case standard: StandardJWTPayload => standard.copy(exp = exp)
}
}

View File

@ -1,18 +0,0 @@
{
callback-uri = "https://example.com/auth/cb"
oauth-auth = "https://oauth2/uri"
oauth-token = "https://oauth2/token"
// client-id = ${DAML_CLIENT_ID}
// client-secret = ${DAML_CLIENT_SECRET}
// can be set via env variables; dummy values for test purposes
client-id = foo
client-secret = bar
// type can be one of rs256-crt, es256-crt, es512-crt, rs256-jwks
// uri is the uri to the cert file or the jwks url
token-verifier {
type = "rs256-jwks"
uri = "https://example.com/.well-known/jwks.json"
}
}

View File

@ -1,31 +0,0 @@
{
address = "127.0.0.1"
port = 3000
callback-uri = "https://example.com/auth/cb"
max-login-requests = 10
login-timeout = 60s
cookie-secure = false
oauth-auth = "https://oauth2/uri"
oauth-token = "https://oauth2/token"
oauth-auth-template = "auth_template"
oauth-token-template = "token_template"
oauth-refresh-template = "refresh_template"
// client-id = ${DAML_CLIENT_ID}
// client-secret = ${DAML_CLIENT_SECRET}
// can be set via env variables; dummy values for test purposes
client-id = foo
client-secret = bar
// type can be one of rs256-crt, es256-crt, es512-crt, rs256-jwks
// uri is the uri to the cert file or the jwks url
token-verifier {
type = "rs256-jwks"
uri = "https://example.com/.well-known/jwks.json"
}
metrics {
reporter = "prometheus://0.0.0.0:5104"
reporting-interval = 30s
}
}

View File

@ -1,181 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.jwt
import java.math.BigInteger
import java.nio.file.{Files, Path}
import java.security.interfaces.{ECPrivateKey, ECPublicKey, RSAPrivateKey, RSAPublicKey}
import java.security.{KeyPair, KeyPairGenerator, PrivateKey, PublicKey, Security}
import java.time.Instant
import java.util.concurrent.atomic.AtomicReference
import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import com.daml.fs.TemporaryDirectory
import com.daml.http.test.SimpleHttpServer
import com.daml.scalautil.Statement.discard
import com.daml.jwt.JwtVerifierConfigurationCliSpec._
import com.daml.ledger.api.auth.ClaimSet.Claims
import com.daml.ledger.api.auth.{AuthService, AuthServiceJWT, AuthServiceWildcard, ClaimPublic}
import io.grpc.Metadata
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.cert.jcajce.{JcaX509CertificateConverter, JcaX509v3CertificateBuilder}
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import scopt.OptionParser
import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.FutureConverters._
// End-to-end checks for JwtVerifierConfigurationCli: each supported
// `--auth-jwt-*` flag must yield an AuthService that accepts a JWT signed
// with the corresponding algorithm and key material.
class JwtVerifierConfigurationCliSpec extends AsyncWordSpec with Matchers {
  // Register BouncyCastle once; it supplies the RSA/ECDSA key-pair and X.509
  // certificate primitives used by the helpers in the companion object.
  discard(Security.addProvider(new BouncyCastleProvider))

  "auth command-line parsers" should {
    "parse and configure the authorisation mechanism correctly when `--auth-jwt-hs256-unsafe <secret>` is passed" in {
      val secret = "someSecret"
      val authService = parseConfig(Array("--auth-jwt-hs256-unsafe", secret))
      // A token signed with the same shared secret must decode successfully.
      val token = JWT.create().sign(Algorithm.HMAC256(secret))
      val metadata = createAuthMetadata(token)
      decodeAndCheckMetadata(authService, metadata)
    }

    "parse and configure the authorisation mechanism correctly when `--auth-jwt-rs256-crt <PK.crt>` is passed" in
      new TemporaryDirectory(getClass.getSimpleName).use { directory =>
        // Write a self-signed RSA certificate to disk and point the CLI at it.
        val (publicKey, privateKey) = newRsaKeyPair()
        val certificatePath = newCertificate("SHA256WithRSA", directory, publicKey, privateKey)
        val token = JWT.create().sign(Algorithm.RSA256(publicKey, privateKey))
        val authService = parseConfig(Array("--auth-jwt-rs256-crt", certificatePath.toString))
        val metadata = createAuthMetadata(token)
        decodeAndCheckMetadata(authService, metadata)
      }

    "parse and configure the authorisation mechanism correctly when `--auth-jwt-es256-crt <PK.crt>` is passed" in
      new TemporaryDirectory(getClass.getSimpleName).use { directory =>
        // Same flow as above but with an ECDSA key and certificate.
        val (publicKey, privateKey) = newEcdsaKeyPair()
        val certificatePath = newCertificate("SHA256WithECDSA", directory, publicKey, privateKey)
        val token = JWT.create().sign(Algorithm.ECDSA256(publicKey, privateKey))
        val authService = parseConfig(Array("--auth-jwt-es256-crt", certificatePath.toString))
        val metadata = createAuthMetadata(token)
        decodeAndCheckMetadata(authService, metadata)
      }

    "parse and configure the authorisation mechanism correctly when `--auth-jwt-es512-crt <PK.crt>` is passed" in
      new TemporaryDirectory(getClass.getSimpleName).use { directory =>
        // ECDSA with SHA-512; the key pair is generated the same way as for
        // ES256 — only the certificate signature algorithm differs.
        val (publicKey, privateKey) = newEcdsaKeyPair()
        val certificatePath = newCertificate("SHA512WithECDSA", directory, publicKey, privateKey)
        val token = JWT.create().sign(Algorithm.ECDSA512(publicKey, privateKey))
        val authService = parseConfig(Array("--auth-jwt-es512-crt", certificatePath.toString))
        val metadata = createAuthMetadata(token)
        decodeAndCheckMetadata(authService, metadata)
      }

    "parse and configure the authorisation mechanism correctly when `--auth-jwt-rs256-jwks <URL>` is passed" in {
      val (publicKey, privateKey) = newRsaKeyPair()
      val keyId = "test-key-1"
      val token = JWT.create().withKeyId(keyId).sign(Algorithm.RSA256(publicKey, privateKey))
      // Start a JWKS server and create a verifier using the JWKS server
      val jwks = KeyUtils.generateJwks(
        Map(
          keyId -> publicKey
        )
      )
      val server = SimpleHttpServer.start(jwks)
      Future {
        val url = SimpleHttpServer.responseUrl(server)
        val authService = parseConfig(Array("--auth-jwt-rs256-jwks", url))
        val metadata = createAuthMetadata(token)
        (authService, metadata)
      }.flatMap { case (authService, metadata) =>
        decodeAndCheckMetadata(authService, metadata)
      }.andThen { case _ =>
        // Shut the JWKS server down whether or not the assertion succeeded.
        SimpleHttpServer.stop(server)
      }
    }
  }
}
// Helpers shared by the CLI parsing tests above.
object JwtVerifierConfigurationCliSpec {
  // Runs JwtVerifierConfigurationCli over `args` and returns the resulting
  // AuthService. The parser mutates an AtomicReference that starts out as
  // AuthServiceWildcard; a recognised flag replaces it with AuthServiceJWT.
  private def parseConfig(args: Array[String]): AuthService = {
    val parser = new OptionParser[AtomicReference[AuthService]]("test") {}
    JwtVerifierConfigurationCli.parse(parser) { (verifier, config) =>
      config.set(AuthServiceJWT(verifier, targetAudience = None, targetScope = None))
      config
    }
    parser.parse(args, new AtomicReference[AuthService](AuthServiceWildcard)).get.get()
  }

  // gRPC metadata carrying `token` as an HTTP bearer credential.
  private def createAuthMetadata(token: String) = {
    val metadata = new Metadata()
    metadata.put(AuthService.AUTHORIZATION_KEY, s"Bearer $token")
    metadata
  }

  // Decodes `metadata` through `authService` and asserts the resulting claim
  // set is exactly the single public claim.
  private def decodeAndCheckMetadata(
      authService: AuthService,
      metadata: Metadata,
  )(implicit executionContext: ExecutionContext): Future[Assertion] = {
    import org.scalatest.Inside._
    import org.scalatest.matchers.should.Matchers._
    authService.decodeMetadata(metadata).asScala.map { auth =>
      inside(auth) { case claims: Claims =>
        claims.claims should be(List(ClaimPublic))
      }
    }
  }

  // Fresh 2048-bit RSA key pair.
  private def newRsaKeyPair(): (RSAPublicKey, RSAPrivateKey) = {
    val keyPair = newKeyPair("RSA", 2048)
    val publicKey = keyPair.getPublic.asInstanceOf[RSAPublicKey]
    val privateKey = keyPair.getPrivate.asInstanceOf[RSAPrivateKey]
    (publicKey, privateKey)
  }

  // Fresh 256-bit ECDSA key pair.
  private def newEcdsaKeyPair(): (ECPublicKey, ECPrivateKey) = {
    val keyPair = newKeyPair("ECDSA", 256)
    val publicKey = keyPair.getPublic.asInstanceOf[ECPublicKey]
    val privateKey = keyPair.getPrivate.asInstanceOf[ECPrivateKey]
    (publicKey, privateKey)
  }

  // Generates a key pair of the given algorithm/size via BouncyCastle.
  private def newKeyPair(algorithm: String, keySize: Int): KeyPair = {
    val generator = KeyPairGenerator.getInstance(algorithm, BouncyCastleProvider.PROVIDER_NAME)
    generator.initialize(keySize)
    generator.generateKeyPair()
  }

  // Writes a short-lived (60 s) self-signed X.509 certificate for `publicKey`
  // into `directory` and returns its path.
  private def newCertificate(
      signatureAlgorithm: String,
      directory: Path,
      publicKey: PublicKey,
      privateKey: PrivateKey,
  ): Path = {
    val now = Instant.now()
    val dnName = new X500Name(s"CN=${getClass.getSimpleName}")
    val contentSigner = new JcaContentSignerBuilder(signatureAlgorithm).build(privateKey)
    val certBuilder = new JcaX509v3CertificateBuilder(
      dnName,
      // Serial number only needs to be unique within this test run.
      BigInteger.valueOf(now.toEpochMilli),
      java.util.Date.from(now),
      java.util.Date.from(now.plusSeconds(60)),
      dnName,
      publicKey,
    )
    val certificate =
      new JcaX509CertificateConverter()
        .setProvider(BouncyCastleProvider.PROVIDER_NAME)
        .getCertificate(certBuilder.build(contentSigner))
    val certificatePath = directory.resolve("certificate")
    discard(Files.write(certificatePath, certificate.getEncoded))
    certificatePath
  }
}

View File

@ -1,130 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import org.apache.pekko.http.scaladsl.model.Uri
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import com.daml.bazeltools.BazelRunfiles.requiredResource
import com.daml.jwt.JwksVerifier
import com.daml.metrics.MetricsConfig
import com.daml.metrics.api.reporters.MetricsReporter
import org.scalatest.Inside.inside
import pureconfig.error.{CannotReadFile, ConfigReaderFailures}
import java.nio.file.Paths
import java.net.InetSocketAddress
import scala.concurrent.duration._
// Checks Cli/FileConfig loading for the OAuth2 middleware: HOCON files are
// parsed correctly, defaults apply for a minimal file, and CLI arguments are
// honoured (or rejected) when no config file is available.
class CliSpec extends AsyncWordSpec with Matchers {
  // Expected result of loading the minimal config file; tokenVerifier is
  // checked separately (by type), hence the null placeholder here.
  val minimalCfg = FileConfig(
    oauthAuth = Uri("https://oauth2/uri"),
    oauthToken = Uri("https://oauth2/token"),
    callbackUri = Some(Uri("https://example.com/auth/cb")),
    clientId = sys.env.getOrElse("DAML_CLIENT_ID", "foo"),
    clientSecret = SecretString(sys.env.getOrElse("DAML_CLIENT_SECRET", "bar")),
    tokenVerifier = null,
  )
  // NOTE(review): resource path still points into the triggers/ tree — verify
  // it exists in the current source layout.
  val confFile = "triggers/service/auth/src/test/resources/oauth2-middleware.conf"

  // Parses only `--config <file>`; fails the test if parsing itself fails.
  def loadCli(file: String): Cli = {
    Cli.parse(Array("--config", file)).getOrElse(fail("Could not load Cli on parse"))
  }

  "should pickup the config file provided" in {
    val file = requiredResource(confFile)
    val cli = loadCli(file.getAbsolutePath)
    cli.configFile should not be empty
  }

  "should take default values on loading minimal config" in {
    val file =
      requiredResource("triggers/service/auth/src/test/resources/oauth2-middleware-minimal.conf")
    val cli = loadCli(file.getAbsolutePath)
    cli.configFile should not be empty
    val cfg = cli.loadFromConfigFile
    inside(cfg) { case Some(Right(c)) =>
      c.copy(tokenVerifier = null) shouldBe minimalCfg
      // token verifier needs to be set.
      c.tokenVerifier shouldBe a[JwksVerifier]
    }
  }

  "should be able to successfully load the config based on the file provided" in {
    val file = requiredResource(confFile)
    val cli = loadCli(file.getAbsolutePath)
    cli.configFile should not be empty
    val cfg = cli.loadFromConfigFile
    inside(cfg) { case Some(Right(c)) =>
      // Every optional field must reflect the explicit value from the file.
      c.copy(tokenVerifier = null) shouldBe minimalCfg.copy(
        port = 3000,
        maxLoginRequests = 10,
        loginTimeout = 60.seconds,
        cookieSecure = false,
        oauthAuthTemplate = Some(Paths.get("auth_template")),
        oauthTokenTemplate = Some(Paths.get("token_template")),
        oauthRefreshTemplate = Some(Paths.get("refresh_template")),
        metrics = Some(
          MetricsConfig(
            MetricsReporter.Prometheus(new InetSocketAddress("0.0.0.0", 5104)),
            30.seconds,
            Seq.empty,
          )
        ),
      )
      // token verifier needs to be set.
      c.tokenVerifier shouldBe a[JwksVerifier]
    }
  }

  "parse should raise error on non-existent config file" in {
    val cli = loadCli("missingFile.conf")
    cli.configFile should not be empty
    val cfg = cli.loadFromConfigFile
    inside(cfg) { case Some(Left(ConfigReaderFailures(head))) =>
      head shouldBe a[CannotReadFile]
    }
    // parseConfig for non-existent file should return a None
    Cli.parseConfig(
      Array(
        "--config",
        "missingFile.conf",
      )
    ) shouldBe None
  }

  "should load config from cli args on missing conf file " in {
    Cli
      .parseConfig(
        Array(
          "--oauth-auth",
          "file://foo",
          "--oauth-token",
          "file://bar",
          "--id",
          "foo",
          "--secret",
          "bar",
          "--auth-jwt-hs256-unsafe",
          "unsafe",
        )
      ) should not be empty
  }

  "should fail to load config from cli args on incomplete cli args" in {
    // --oauth-token is missing, so pure-CLI configuration must be rejected.
    Cli
      .parseConfig(
        Array(
          "--oauth-auth",
          "file://foo",
          "--id",
          "foo",
          "--secret",
          "bar",
          "--auth-jwt-hs256-unsafe",
          "unsafe",
        )
      ) shouldBe None
  }
}

View File

@ -1,108 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import java.io.File
import java.nio.file.Files
import java.time.{Clock, Duration, Instant, ZoneId}
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.http.scaladsl.Http
import org.apache.pekko.http.scaladsl.Http.ServerBinding
import org.apache.pekko.http.scaladsl.model.{StatusCodes, Uri}
import org.apache.pekko.http.scaladsl.server.Directives._
import com.daml.auth.middleware.api.Client
import com.daml.auth.middleware.api.Request.Claims
import com.daml.clock.AdjustableClock
import com.daml.ledger.resources.{Resource, ResourceContext, ResourceOwner}
import com.daml.auth.oauth2.test.server.{Server => OAuthServer}
import com.daml.scalautil.Statement.discard
import scala.concurrent.Future
// ResourceOwner wrappers that acquire/release the moving parts of the OAuth2
// middleware test setup (clock, temp dir, servers, client routes).
object Resources {
  // Adjustable clock pinned to `start`; released with a no-op.
  def clock(start: Instant, zoneId: ZoneId): ResourceOwner[AdjustableClock] =
    new ResourceOwner[AdjustableClock] {
      override def acquire()(implicit context: ResourceContext): Resource[AdjustableClock] = {
        Resource(Future(AdjustableClock(Clock.fixed(start, zoneId), Duration.ZERO)))(_ =>
          Future(())
        )
      }
    }

  // Temporary directory deleted on release.
  // NOTE(review): File.delete() on a non-empty directory fails and the result
  // is discarded, so leftover contents are not cleaned up recursively.
  def temporaryDirectory(): ResourceOwner[File] =
    new ResourceOwner[File] {
      override def acquire()(implicit context: ResourceContext): Resource[File] =
        Resource(Future(Files.createTempDirectory("daml-oauth2-middleware").toFile))(dir =>
          Future(discard { dir.delete() })
        )
    }

  // Binds the fake OAuth2 authorization server; unbound on release.
  def authServerBinding(
      server: OAuthServer
  )(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
    new ResourceOwner[ServerBinding] {
      override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
        Resource(server.start())(_.unbind().map(_ => ()))
    }

  // Binds the OAuth2 middleware under test; unbound on release.
  def authMiddlewareBinding(
      config: Config
  )(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
    new ResourceOwner[ServerBinding] {
      override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
        Resource(Server.start(config, registerGlobalOpenTelemetry = false))(_.unbind().map(_ => ()))
    }

  // A small test web app wired to the middleware `client`: /authorize and
  // /login drive the client's auth flows and `callbackPath` receives the
  // middleware redirect. Bound to a dynamic port on localhost.
  def authMiddlewareClientBinding(client: Client, callbackPath: Uri.Path)(implicit
      sys: ActorSystem
  ): ResourceOwner[ServerBinding] =
    new ResourceOwner[ServerBinding] {
      override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
        Resource {
          Http()
            .newServerAt("localhost", 0)
            .bind {
              extractUri { reqUri =>
                // Rebuild the callback URI from the incoming request so it
                // matches whatever host/port this server was bound to.
                val callbackUri = Uri()
                  .withScheme(reqUri.scheme)
                  .withAuthority(reqUri.authority)
                  .withPath(callbackPath)
                val clientRoutes = client.routes(callbackUri)
                concat(
                  path("authorize") {
                    get {
                      parameters(Symbol("claims").as[Claims]) { claims =>
                        clientRoutes.authorize(claims) {
                          case Client.Authorized(authorization) =>
                            import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
                            import com.daml.auth.middleware.api.JsonProtocol.responseAuthorizeFormat
                            complete(StatusCodes.OK, authorization)
                          case Client.Unauthorized =>
                            complete(StatusCodes.Unauthorized)
                          case Client.LoginFailed(loginError) =>
                            import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
                            import com.daml.auth.middleware.api.JsonProtocol.ResponseLoginFormat
                            complete(
                              StatusCodes.Forbidden,
                              loginError: com.daml.auth.middleware.api.Response.Login,
                            )
                        }
                      }
                    }
                  },
                  path("login") {
                    get {
                      parameters(Symbol("claims").as[Claims]) { claims =>
                        import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
                        import com.daml.auth.middleware.api.JsonProtocol.ResponseLoginFormat
                        clientRoutes.login(claims, login => complete(StatusCodes.OK, login))
                      }
                    }
                  },
                  path("cb") { get { clientRoutes.callbackHandler } },
                )
              }
            }
        } {
          _.unbind().map(_ => ())
        }
    }
}

View File

@ -1,156 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import java.io.File
import java.time.{Instant, ZoneId}
import org.apache.pekko.http.scaladsl.Http.ServerBinding
import org.apache.pekko.http.scaladsl.model.Uri
import com.auth0.jwt.JWTVerifier.BaseVerification
import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import com.daml.auth.middleware.api.Client
import com.daml.clock.AdjustableClock
import com.daml.jwt.JwtVerifier
import com.daml.ledger.api.testing.utils.{
PekkoBeforeAndAfterAll,
OwnedResource,
Resource,
SuiteResource,
}
import com.daml.ledger.resources.ResourceContext
import com.daml.auth.oauth2.test.server.{Config => OAuthConfig, Server => OAuthServer}
import com.daml.ports.Port
import org.scalatest.{BeforeAndAfterEach, Suite}
import scala.concurrent.duration
import scala.concurrent.duration.FiniteDuration
// Bundle of everything the TestFixture suite acquires: the adjustable clock,
// the fake OAuth2 server, the middleware under test (plus its port file), and
// a client app bound against that middleware.
case class TestResources(
    clock: AdjustableClock,
    authServer: OAuthServer,
    authServerBinding: ServerBinding,
    authMiddlewarePortFile: File,
    authMiddlewareBinding: ServerBinding,
    authMiddlewareClient: Client,
    authMiddlewareClientBinding: ServerBinding,
)
// Shared suite fixture: stands up the fake OAuth2 server, the OAuth2
// middleware and a middleware client app, and exposes them (plus an
// adjustable clock) to the concrete test classes. Subclasses tune behaviour
// via the overridable protected vals below.
trait TestFixture
    extends PekkoBeforeAndAfterAll
    with BeforeAndAfterEach
    with SuiteResource[TestResources] {
  self: Suite =>
  protected val ledgerId: String = "test-ledger"
  protected val jwtSecret: String = "secret"
  protected val maxMiddlewareLogins: Int = Config.DefaultMaxLoginRequests
  protected val maxClientAuthCallbacks: Int = 1000
  protected val middlewareCallbackUri: Option[Uri] = None
  protected val middlewareClientCallbackPath: Uri.Path = Uri.Path./("cb")
  protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.Yes

  // Convenience accessors over the acquired TestResources.
  lazy protected val clock: AdjustableClock = suiteResource.value.clock
  lazy protected val server: OAuthServer = suiteResource.value.authServer
  lazy protected val serverBinding: ServerBinding = suiteResource.value.authServerBinding
  lazy protected val middlewarePortFile: File = suiteResource.value.authMiddlewarePortFile
  lazy protected val middlewareBinding: ServerBinding = suiteResource.value.authMiddlewareBinding
  lazy protected val middlewareClient: Client = suiteResource.value.authMiddlewareClient
  lazy protected val middlewareClientBinding: ServerBinding = {
    suiteResource.value.authMiddlewareClientBinding
  }
  // Callback URI of the client app, derived from its dynamically bound port.
  lazy protected val middlewareClientCallbackUri: Uri = {
    val host = middlewareClientBinding.localAddress
    Uri()
      .withScheme("http")
      .withAuthority("localhost", host.getPort)
      .withPath(middlewareClientCallbackPath)
  }
  lazy protected val middlewareClientRoutes: Client.Routes =
    middlewareClient.routes(middlewareClientCallbackUri)

  // Whether the fake OAuth2 server issues user (StandardJWT) tokens rather
  // than claims tokens; overridden per concrete suite.
  protected def oauthYieldsUserTokens: Boolean = true

  // Acquisition order matters: clock -> OAuth2 server -> temp dir/port file
  // -> middleware (pointed at the server) -> client app (pointed at the
  // middleware). Each step feeds the next.
  override protected lazy val suiteResource: Resource[TestResources] = {
    implicit val resourceContext: ResourceContext = ResourceContext(system.dispatcher)
    new OwnedResource[ResourceContext, TestResources](
      for {
        clock <- Resources.clock(Instant.now(), ZoneId.systemDefault())
        server = OAuthServer(
          OAuthConfig(
            port = Port.Dynamic,
            ledgerId = ledgerId,
            jwtSecret = jwtSecret,
            clock = Some(clock),
            yieldUserTokens = oauthYieldsUserTokens,
          )
        )
        serverBinding <- Resources.authServerBinding(server)
        serverUri = Uri()
          .withScheme("http")
          .withAuthority(
            serverBinding.localAddress.getHostString,
            serverBinding.localAddress.getPort,
          )
        tempDir <- Resources.temporaryDirectory()
        middlewarePortFile = new File(tempDir, "port")
        middlewareBinding <- Resources.authMiddlewareBinding(
          Config(
            address = "localhost",
            port = 0,
            portFile = Some(middlewarePortFile.toPath),
            callbackUri = middlewareCallbackUri,
            maxLoginRequests = maxMiddlewareLogins,
            loginTimeout = Config.DefaultLoginTimeout,
            cookieSecure = Config.DefaultCookieSecure,
            oauthAuth = serverUri.withPath(Uri.Path./("authorize")),
            oauthToken = serverUri.withPath(Uri.Path./("token")),
            oauthAuthTemplate = None,
            oauthTokenTemplate = None,
            oauthRefreshTemplate = None,
            clientId = "middleware",
            clientSecret = SecretString("middleware-secret"),
            // Verify tokens with the same HMAC secret the fake server uses,
            // evaluated against the adjustable test clock.
            tokenVerifier = new JwtVerifier(
              JWT
                .require(Algorithm.HMAC256(jwtSecret))
                .asInstanceOf[BaseVerification]
                .build(clock)
            ),
            histograms = Seq.empty,
          )
        )
        authUri = Uri()
          .withScheme("http")
          .withAuthority(
            middlewareBinding.localAddress.getHostName,
            middlewareBinding.localAddress.getPort,
          )
        middlewareClientConfig = Client.Config(
          authMiddlewareInternalUri = authUri,
          authMiddlewareExternalUri = authUri,
          redirectToLogin = redirectToLogin,
          maxAuthCallbacks = maxClientAuthCallbacks,
          authCallbackTimeout = FiniteDuration(1, duration.MINUTES),
          maxHttpEntityUploadSize = 4194304,
          httpEntityUploadTimeout = FiniteDuration(1, duration.MINUTES),
        )
        middlewareClient = Client(middlewareClientConfig)
        middlewareClientBinding <- Resources
          .authMiddlewareClientBinding(middlewareClient, middlewareClientCallbackPath)
      } yield TestResources(
        clock = clock,
        authServer = server,
        authServerBinding = serverBinding,
        authMiddlewarePortFile = middlewarePortFile,
        authMiddlewareBinding = middlewareBinding,
        authMiddlewareClient = middlewareClient,
        authMiddlewareClientBinding = middlewareClientBinding,
      )
    )
  }

  // Reset fake-server state so party/admin revocations don't leak across tests.
  override protected def afterEach(): Unit = {
    server.resetAuthorizedParties()
    server.resetAdmin()
    super.afterEach()
  }
}

View File

@ -1,779 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import java.time.Duration
import org.apache.pekko.http.scaladsl.Http
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import org.apache.pekko.http.scaladsl.model._
import org.apache.pekko.http.scaladsl.model.headers.{Cookie, Location, `Set-Cookie`}
import org.apache.pekko.http.scaladsl.testkit.ScalatestRouteTest
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshal
import com.daml.auth.middleware.api.{Client, Request, Response}
import com.daml.auth.middleware.api.Request.Claims
import com.daml.auth.middleware.api.Tagged.{AccessToken, RefreshToken}
import com.daml.jwt.JwtSigner
import com.daml.jwt.domain.DecodedJwt
import com.daml.ledger.api.auth.{
AuthServiceJWTCodec,
AuthServiceJWTPayload,
CustomDamlJWTPayload,
StandardJWTPayload,
StandardJWTTokenFormat,
}
import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
import com.daml.lf.data.Ref
import com.daml.auth.oauth2.api.{Response => OAuthResponse}
import com.daml.test.evidence.tag.Security.SecurityTest.Property.{Authenticity, Authorization}
import com.daml.test.evidence.tag.Security.{Attack, SecurityTest}
import com.daml.test.evidence.scalatest.ScalaTestSupport.Implicits._
import org.scalatest.{OptionValues, TryValues}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.{AnyWordSpec, AsyncWordSpec}
import scala.collection.immutable
import scala.collection.immutable.Seq
import scala.concurrent.duration
import scala.concurrent.duration.FiniteDuration
import scala.io.Source
import scala.util.{Failure, Success}
// Common test suite for the OAuth2 middleware's /auth, /login and /refresh
// endpoints. Token construction is abstract (makeJwt) so the same cases run
// for both claims tokens and user tokens (see the concrete subclasses).
abstract class TestMiddleware
    extends AsyncWordSpec
    with TestFixture
    with SuiteResourceManagementAroundAll
    with Matchers
    with OptionValues {
  val authenticationSecurity: SecurityTest =
    SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
  val authorizationSecurity: SecurityTest =
    SecurityTest(property = Authorization, asset = "OAuth2 Middleware")

  // Builds the JWT payload for `claims`; provided by the concrete subclass.
  protected[this] def makeJwt(
      claims: Request.Claims,
      expiresIn: Option[Duration],
  ): AuthServiceJWTPayload

  // Signs a makeJwt payload (HMAC-SHA256 with `secret`) and wraps it in an
  // OAuth token response; passing a wrong secret yields an invalid token.
  protected[this] def makeToken(
      claims: Request.Claims,
      secret: String = "secret",
      expiresIn: Option[Duration] = None,
  ): OAuthResponse.Token = {
    val jwtHeader = """{"alg": "HS256", "typ": "JWT"}"""
    val jwtPayload = makeJwt(claims, expiresIn)
    OAuthResponse.Token(
      accessToken = JwtSigner.HMAC256
        .sign(DecodedJwt(jwtHeader, AuthServiceJWTCodec.compactPrint(jwtPayload)), secret)
        .getOrElse(
          throw new IllegalArgumentException("Failed to sign a token")
        )
        .value,
      tokenType = "bearer",
      expiresIn = expiresIn.map(in => in.getSeconds.toInt),
      refreshToken = None,
      scope = Some(claims.toQueryString()),
    )
  }

  "the port file" should {
    "list the HTTP port" in {
      // The port written to the port file must match the actual binding.
      val bindingPort = middlewareBinding.localAddress.getPort.toString
      val filePort = {
        val source = Source.fromFile(middlewarePortFile)
        try {
          source.mkString.stripLineEnd
        } finally {
          source.close()
        }
      }
      bindingPort should ===(filePort)
    }
  }

  "the /auth endpoint" should {
    "return unauthorized without cookie" taggedAs authorizationSecurity in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
      for {
        result <- middlewareClient.requestAuth(claims, Nil)
      } yield {
        result should ===(None)
      }
    }
    "return the token from a cookie" taggedAs authorizationSecurity in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
      val token = makeToken(claims)
      val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
      for {
        result <- middlewareClient.requestAuth(claims, List(cookieHeader))
        auth = result.value
      } yield {
        auth.accessToken should ===(token.accessToken)
        auth.refreshToken should ===(token.refreshToken)
      }
    }
    "return unauthorized on insufficient app id claims" taggedAs authorizationSecurity in {
      // Token is for "my-app-id" but "other-id" is requested.
      val claims = Request.Claims(
        actAs = List(Ref.Party.assertFromString("Alice")),
        applicationId = Some(Ref.ApplicationId.assertFromString("other-id")),
      )
      val token = makeToken(
        Request.Claims(
          actAs = List(Ref.Party.assertFromString("Alice")),
          applicationId = Some(Ref.ApplicationId.assertFromString("my-app-id")),
        )
      )
      val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
      for {
        result <- middlewareClient.requestAuth(claims, List(cookieHeader))
      } yield {
        result should ===(None)
      }
    }
    "return unauthorized on an invalid token" taggedAs authorizationSecurity in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
      val token = makeToken(claims, "wrong-secret")
      val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
      for {
        result <- middlewareClient.requestAuth(claims, List(cookieHeader))
      } yield {
        result should ===(None)
      }
    }
    "return unauthorized on an expired token" taggedAs authorizationSecurity in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
      val token = makeToken(claims, expiresIn = Some(Duration.ZERO))
      // Advance the adjustable clock past the expiry before presenting it.
      val _ = clock.fastForward(Duration.ofSeconds(1))
      val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
      for {
        result <- middlewareClient.requestAuth(claims, List(cookieHeader))
      } yield {
        result should ===(None)
      }
    }
    "accept user tokens" taggedAs authorizationSecurity in {
      import com.daml.auth.middleware.oauth2.Server.rightsProvideClaims
      rightsProvideClaims(
        StandardJWTPayload(
          None,
          "foo",
          None,
          None,
          StandardJWTTokenFormat.Scope,
          List.empty,
          Some("daml_ledger_api"),
        ),
        Claims(
          admin = true,
          actAs = List(Ref.Party.assertFromString("Alice")),
          readAs = List(Ref.Party.assertFromString("Bob")),
          applicationId = Some(Ref.ApplicationId.assertFromString("foo")),
        ),
      ) should ===(true)
    }
  }

  "the /login endpoint" should {
    "redirect and set cookie" taggedAs authenticationSecurity.setHappyCase(
      "A valid request to /login redirects to client callback and sets cookie"
    ) in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
      val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
      for {
        resp <- Http().singleRequest(req)
        // Redirect to /authorize on authorization server
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
        // Redirect to /cb on middleware
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
      } yield {
        // Redirect to client callback
        resp.status should ===(StatusCodes.Found)
        resp.header[Location].value.uri should ===(middlewareClientCallbackUri)
        // Store token in cookie
        val cookie = resp.header[`Set-Cookie`].value.cookie
        cookie.name should ===("daml-ledger-token")
        val token = OAuthResponse.Token.fromCookieValue(cookie.value).value
        token.tokenType should ===("bearer")
      }
    }
    "return OK and set cookie without redirectUri" taggedAs authenticationSecurity.setHappyCase(
      "A valid request to /login returns OK and sets cookie when redirect is off"
    ) in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
      val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None, redirect = false))
      for {
        resp <- Http().singleRequest(req)
        // Redirect to /authorize on authorization server
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
        // Redirect to /cb on middleware
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
      } yield {
        // Return OK
        resp.status should ===(StatusCodes.OK)
        // Store token in cookie
        val cookie = resp.header[`Set-Cookie`].value.cookie
        cookie.name should ===("daml-ledger-token")
        val token = OAuthResponse.Token.fromCookieValue(cookie.value).value
        token.tokenType should ===("bearer")
      }
    }
  }

  "the /refresh endpoint" should {
    "return a new access token" taggedAs authorizationSecurity.setHappyCase(
      "A valid request to /refresh returns a new access token"
    ) in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
      val loginReq = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
      for {
        resp <- Http().singleRequest(loginReq)
        // Redirect to /authorize on authorization server
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
        // Redirect to /cb on middleware
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
        // Extract token from cookie
        (token1, refreshToken) = {
          val cookie = resp.header[`Set-Cookie`].value.cookie
          val token = OAuthResponse.Token.fromCookieValue(cookie.value).value
          (AccessToken(token.accessToken), RefreshToken(token.refreshToken.value))
        }
        // Advance time
        _ = clock.fastForward(Duration.ofSeconds(1))
        // Request /refresh
        authorize <- middlewareClient.requestRefresh(refreshToken)
      } yield {
        // Test that we got a new access token
        authorize.accessToken should !==(token1)
        // Test that we got a new refresh token
        authorize.refreshToken.value should !==(refreshToken)
      }
    }
    "fail on an invalid refresh token" taggedAs authorizationSecurity.setAttack(
      Attack("HTTP Client", "Presents an invalid refresh token", "refuse request with CLIENT_ERROR")
    ) in {
      for {
        exception <- middlewareClient.requestRefresh(RefreshToken("made-up-token")).transform {
          case Failure(exception: Client.RefreshException) => Success(exception)
          case value => fail(s"Expected failure with RefreshException but got $value")
        }
      } yield {
        exception.status shouldBe a[StatusCodes.ClientError]
      }
    }
  }
}
// Runs the shared TestMiddleware cases with legacy claims (CustomDamlJWT)
// tokens, plus claims-specific /auth and /login checks.
class TestMiddlewareClaimsToken extends TestMiddleware {
  // The fake OAuth2 server must issue claims tokens here, not user tokens.
  override protected[this] def oauthYieldsUserTokens = false

  override protected[this] def makeJwt(
      claims: Request.Claims,
      expiresIn: Option[Duration],
  ): AuthServiceJWTPayload =
    CustomDamlJWTPayload(
      ledgerId = Some("test-ledger"),
      applicationId = Some("test-application"),
      participantId = None,
      exp = expiresIn.map(in => clock.instant.plus(in)),
      admin = claims.admin,
      actAs = claims.actAs,
      readAs = claims.readAs,
    )

  "the /auth endpoint given claim token" should {
    "return unauthorized on insufficient party claims" taggedAs authorizationSecurity in {
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Bob")))
      // Requests auth for "actAs Bob" with a token carrying the given
      // actAs/readAs parties.
      def r(actAs: String*)(readAs: String*) =
        middlewareClient
          .requestAuth(
            claims,
            Seq(
              Cookie(
                "daml-ledger-token",
                makeToken(
                  Request.Claims(
                    actAs = actAs.map(Ref.Party.assertFromString).toList,
                    readAs = readAs.map(Ref.Party.assertFromString).toList,
                  )
                ).toCookieValue,
              )
            ),
          )
      for {
        aliceA <- r("Alice")()
        nothing <- r()()
        aliceA_bobA <- r("Alice", "Bob")()
        aliceA_bobR <- r("Alice")("Bob")
        aliceR_bobA <- r("Bob")("Alice")
        aliceR_bobR <- r()("Alice", "Bob")
        bobA <- r("Bob")()
        bobR <- r()("Bob")
        bobAR <- r("Bob")("Bob")
      } yield {
        // Only tokens whose actAs parties include Bob satisfy the request.
        aliceA shouldBe empty
        nothing shouldBe empty
        aliceA_bobA should not be empty
        aliceA_bobR shouldBe empty
        aliceR_bobA should not be empty
        aliceR_bobR shouldBe empty
        bobA should not be empty
        bobR shouldBe empty
        bobAR should not be empty
      }
    }
  }

  "the /login endpoint with an oauth server checking claims" should {
    "not authorize unauthorized parties" taggedAs authorizationSecurity in {
      server.revokeParty(Ref.Party.assertFromString("Eve"))
      val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Eve")))
      ensureDisallowed(claims)
    }
    "not authorize disallowed admin claims" taggedAs authorizationSecurity in {
      server.revokeAdmin()
      val claims = Request.Claims(admin = true)
      ensureDisallowed(claims)
    }
    // Drives the full login redirect chain and asserts it ends at the client
    // callback with error=access_denied and no token cookie.
    def ensureDisallowed(claims: Request.Claims) = {
      val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
      for {
        resp <- Http().singleRequest(req)
        // Redirect to /authorize on authorization server
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
        // Redirect to /cb on middleware
        resp <- {
          resp.status should ===(StatusCodes.Found)
          val req = HttpRequest(uri = resp.header[Location].value.uri)
          Http().singleRequest(req)
        }
      } yield {
        // Redirect to client callback
        resp.status should ===(StatusCodes.Found)
        resp.header[Location].value.uri.withQuery(Uri.Query()) should ===(
          middlewareClientCallbackUri
        )
        // with error parameter set
        resp.header[Location].value.uri.query().toMap.get("error") should ===(Some("access_denied"))
        // Without token in cookie
        val cookie = resp.header[`Set-Cookie`]
        cookie should ===(None)
      }
    }
  }
}
// Runs the shared TestMiddleware cases with user (StandardJWT) tokens.
class TestMiddlewareUserToken extends TestMiddleware {
  // Builds a StandardJWTPayload for the requested claims; the expiry, when
  // requested, is derived from the suite's adjustable clock.
  override protected[this] def makeJwt(
      claims: Request.Claims,
      expiresIn: Option[Duration],
  ): AuthServiceJWTPayload = {
    val expiry = expiresIn.map(in => clock.instant.plus(in))
    StandardJWTPayload(
      issuer = None,
      userId = "test-application",
      participantId = None,
      exp = expiry,
      format = StandardJWTTokenFormat.Scope,
      audiences = List.empty,
      scope = Some("daml_ledger_api"),
    )
  }
}
// Verifies that a statically configured middleware callback URI overrides the
// request-derived one during the OAuth2 login flow.
class TestMiddlewareCallbackUriOverride
extends AsyncWordSpec
with Matchers
with OptionValues
with TestFixture
with SuiteResourceManagementAroundAll {
val authenticationSecurity: SecurityTest =
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
// Fixed callback URI; the /login flow should redirect here.
override protected val middlewareCallbackUri = Some(Uri("http://localhost/MIDDLEWARE_CALLBACK"))
"the /login endpoint with an oauth server checking claims" should {
"redirect to the configured middleware callback URI" taggedAs authenticationSecurity
.setHappyCase(
"A valid request to /login redirects to middleware callback"
) in {
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
for {
resp <- Http().singleRequest(req)
// Redirect to /authorize on authorization server
// (pekko-http does not follow redirects automatically, so each hop is explicit)
resp <- {
resp.status should ===(StatusCodes.Found)
val req = HttpRequest(uri = resp.header[Location].value.uri)
Http().singleRequest(req)
}
} yield {
// Redirect to configured callback URI on middleware
resp.status should ===(StatusCodes.Found)
// Compare without the query string: the callback carries dynamic parameters.
resp.header[Location].value.uri.withQuery(Uri.Query()) should ===(
middlewareCallbackUri.value
)
}
}
}
}
// Verifies that the middleware's pending-login store is bounded: once
// `maxMiddlewareLogins` flows are in flight, further /login requests are
// refused with 503 until a pending flow completes and frees a slot.
class TestMiddlewareLimitedCallbackStore
extends AsyncWordSpec
with Matchers
with OptionValues
with TestFixture
with SuiteResourceManagementAroundAll {
val authenticationSecurity: SecurityTest =
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
// Allow at most two concurrent pending login flows.
override protected val maxMiddlewareLogins = 2
"the /login endpoint with an oauth server checking claims" should {
"refuse requests when max capacity is reached" taggedAs authenticationSecurity.setAttack(
Attack(
"HTTP client",
"Issues too many requests to the /login endpoint",
"Refuse request with SERVICE_UNAVAILABLE",
)
) in {
// Starts a login flow for the given party; returns the first response.
def login(actAs: Ref.Party) = {
val claims = Request.Claims(actAs = List(actAs))
val uri = middlewareClientRoutes.loginUri(claims, None)
val req = HttpRequest(uri = uri)
Http().singleRequest(req)
}
// Follows a single 302 redirect (pekko-http has no automatic redirect handling).
def followRedirect(resp: HttpResponse) = {
resp.status should ===(StatusCodes.Found)
val uri = resp.header[Location].value.uri
val req = HttpRequest(uri = uri)
Http().singleRequest(req)
}
for {
// Follow login flows up to redirect to middleware callback.
redirectAlice <- login(Ref.Party.assertFromString("Alice"))
.flatMap(followRedirect)
redirectBob <- login(Ref.Party.assertFromString("Bob"))
.flatMap(followRedirect)
// The store should be full
refusedCarol <- login(Ref.Party.assertFromString("Carol"))
_ = refusedCarol.status should ===(StatusCodes.ServiceUnavailable)
// Follow first redirect to middleware callback.
resultAlice <- followRedirect(redirectAlice)
_ = resultAlice.status should ===(StatusCodes.Found)
// The store should have space again
redirectCarol <- login(Ref.Party.assertFromString("Carol"))
.flatMap(followRedirect)
// Follow redirects to middleware callback.
resultBob <- followRedirect(redirectBob)
resultCarol <- followRedirect(redirectCarol)
} yield {
resultBob.status should ===(StatusCodes.Found)
resultCarol.status should ===(StatusCodes.Found)
}
}
}
}
// Same bounded-store property as TestMiddlewareLimitedCallbackStore, but for
// the middleware *client's* auth-callback store (`maxClientAuthCallbacks`).
// The client's login flow needs three redirect hops before reaching the
// client callback, hence the triple `followRedirect` chains below.
class TestMiddlewareClientLimitedCallbackStore
extends AsyncWordSpec
with Matchers
with OptionValues
with TestFixture
with SuiteResourceManagementAroundAll {
val authenticationSecurity: SecurityTest =
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
// Allow at most two concurrent pending client auth callbacks.
override protected val maxClientAuthCallbacks = 2
"the /login client with an oauth server checking claims" should {
"refuse requests when max capacity is reached" taggedAs authenticationSecurity.setAttack(
Attack(
"HTTP client",
"Issues too many requests to the /login endpoint",
"Refuse request with SERVICE_UNAVAILABLE",
)
) in {
// Starts a login flow for the given party against the client's /login endpoint.
def login(actAs: Ref.Party) = {
val claims = Request.Claims(actAs = List(actAs))
val host = middlewareClientBinding.localAddress
val uri = Uri()
.withScheme("http")
.withAuthority(host.getHostName, host.getPort)
.withPath(Uri.Path./("login"))
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
val req = HttpRequest(uri = uri)
Http().singleRequest(req)
}
// Follows a single 302 redirect (pekko-http has no automatic redirect handling).
def followRedirect(resp: HttpResponse) = {
resp.status should ===(StatusCodes.Found)
val uri = resp.header[Location].value.uri
val req = HttpRequest(uri = uri)
Http().singleRequest(req)
}
for {
// Follow login flows up to last redirect to middleware client.
redirectAlice <- login(Ref.Party.assertFromString("Alice"))
.flatMap(followRedirect)
.flatMap(followRedirect)
.flatMap(followRedirect)
redirectBob <- login(Ref.Party.assertFromString("Bob"))
.flatMap(followRedirect)
.flatMap(followRedirect)
.flatMap(followRedirect)
// The store should be full
refusedCarol <- login(Ref.Party.assertFromString("Carol"))
_ = refusedCarol.status should ===(StatusCodes.ServiceUnavailable)
// Follow first redirect to middleware client.
resultAlice <- followRedirect(redirectAlice)
_ = resultAlice.status should ===(StatusCodes.OK)
// The store should have space again
redirectCarol <- login(Ref.Party.assertFromString("Carol"))
.flatMap(followRedirect)
.flatMap(followRedirect)
.flatMap(followRedirect)
resultBob <- followRedirect(redirectBob)
resultCarol <- followRedirect(redirectCarol)
} yield {
resultBob.status should ===(StatusCodes.OK)
resultCarol.status should ===(StatusCodes.OK)
}
}
}
}
// With RedirectToLogin.No, an unauthorized request to the client must NOT be
// redirected; instead the client answers 401 with a WWW-Authenticate
// challenge carrying the auth and login URIs, duplicated in the JSON body.
class TestMiddlewareClientNoRedirectToLogin
extends AsyncWordSpec
with Matchers
with OptionValues
with TryValues
with TestFixture
with SuiteResourceManagementAroundAll {
val authenticationSecurity: SecurityTest =
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
override protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.No
"the TestMiddlewareClientNoRedirectToLogin client" should {
"not redirect to /login" taggedAs authenticationSecurity.setHappyCase(
"An unauthorized request should not redirect to /login using the JSON protocol"
) in {
import com.daml.auth.middleware.api.JsonProtocol.responseAuthenticateChallengeFormat
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
val host = middlewareClientBinding.localAddress
val uri = Uri()
.withScheme("http")
.withAuthority(host.getHostName, host.getPort)
.withPath(Uri.Path./("authorize"))
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
val req = HttpRequest(uri = uri)
for {
resp <- Http().singleRequest(req)
// Unauthorized with WWW-Authenticate header
_ = resp.status should ===(StatusCodes.Unauthorized)
wwwAuthenticate = resp.header[headers.`WWW-Authenticate`].value
// Pick out the middleware's own challenge scheme among any others.
challenge = wwwAuthenticate.challenges
.find(_.scheme == Response.authenticateChallengeName)
.value
_ = challenge.params.keys should contain.allOf("auth", "login")
authUri = challenge.params.get("auth").value
loginUri = challenge.params.get("login").value
// Reassemble the challenge from the header for comparison with the body.
headerChallenge = Response.AuthenticateChallenge(
Request.Claims(challenge.realm),
loginUri,
authUri,
)
// The body should include the same challenge
bodyChallenge <- Unmarshal(resp).to[Response.AuthenticateChallenge]
} yield {
headerChallenge.auth should ===(middlewareClient.authUri(claims))
// Compare login URIs without their query, then the claims parameter separately.
headerChallenge.login.withQuery(Uri.Query.Empty) should ===(
middlewareClientRoutes.loginUri(claims).withQuery(Uri.Query.Empty)
)
headerChallenge.login.query().get("claims").value should ===(claims.toQueryString())
bodyChallenge should ===(headerChallenge)
}
}
}
}
// With RedirectToLogin.Yes, an unauthorized request to the client is always
// answered with a 302 redirect to the middleware's /login endpoint.
class TestMiddlewareClientYesRedirectToLogin
extends AsyncWordSpec
with Matchers
with OptionValues
with TestFixture
with SuiteResourceManagementAroundAll {
val authenticationSecurity: SecurityTest =
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
override protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.Yes
"the TestMiddlewareClientYesRedirectToLogin client" should {
"redirect to /login" taggedAs authenticationSecurity.setHappyCase(
"A valid HTTP request should redirect to /login"
) in {
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
val host = middlewareClientBinding.localAddress
val uri = Uri()
.withScheme("http")
.withAuthority(host.getHostName, host.getPort)
.withPath(Uri.Path./("authorize"))
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
val req = HttpRequest(uri = uri)
for {
resp <- Http().singleRequest(req)
} yield {
// Redirect to /login on middleware
resp.status should ===(StatusCodes.Found)
val loginUri = resp.header[headers.Location].value.uri
// Compare without the query, then check the claims parameter explicitly.
loginUri.withQuery(Uri.Query.Empty) should ===(
middlewareClientRoutes
.loginUri(claims)
.withQuery(Uri.Query.Empty)
)
loginUri.query().get("claims").value should ===(claims.toQueryString())
}
}
}
}
// With RedirectToLogin.Auto, the client content-negotiates: HTML requests
// (browsers) get a 302 redirect to /login, JSON requests (programmatic
// callers) get a 401 with a WWW-Authenticate challenge.
class TestMiddlewareClientAutoRedirectToLogin
extends AsyncWordSpec
with Matchers
with TestFixture
with SuiteResourceManagementAroundAll {
val authenticationSecurity: SecurityTest =
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
override protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.Auto
"the TestMiddlewareClientAutoRedirectToLogin client" should {
"redirect to /login for HTML request" taggedAs authenticationSecurity.setHappyCase(
"A HTML request is redirected to /login"
) in {
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
val host = middlewareClientBinding.localAddress
val uri = Uri()
.withScheme("http")
.withAuthority(host.getHostName, host.getPort)
.withPath(Uri.Path./("authorize"))
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
// Accept: text/html marks this as a browser-style request.
val acceptHtml: HttpHeader = headers.Accept(MediaTypes.`text/html`)
val req = HttpRequest(uri = uri, headers = immutable.Seq(acceptHtml))
for {
resp <- Http().singleRequest(req)
} yield {
// Redirect to /login on middleware
resp.status should ===(StatusCodes.Found)
}
}
"not redirect to /login for JSON request" taggedAs authenticationSecurity.setHappyCase(
"A JSON request is not redirected to /login"
) in {
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
val host = middlewareClientBinding.localAddress
val uri = Uri()
.withScheme("http")
.withAuthority(host.getHostName, host.getPort)
.withPath(Uri.Path./("authorize"))
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
// Accept: application/json marks this as a programmatic request.
val acceptHtml: HttpHeader = headers.Accept(MediaTypes.`application/json`)
val req = HttpRequest(uri = uri, headers = immutable.Seq(acceptHtml))
for {
resp <- Http().singleRequest(req)
} yield {
// Unauthorized with WWW-Authenticate header
resp.status should ===(StatusCodes.Unauthorized)
}
}
}
}
// Unit tests for the three ways the client derives its login callback URI:
// fixed absolute URI, per-request authority, and automatic selection.
class TestMiddlewareClientLoginCallbackUri
extends AnyWordSpec
with Matchers
with ScalatestRouteTest {
private val client = Client(
Client.Config(
authMiddlewareInternalUri = Uri("http://auth.internal"),
authMiddlewareExternalUri = Uri("http://auth.external"),
redirectToLogin = Client.RedirectToLogin.Yes,
maxAuthCallbacks = 1000,
authCallbackTimeout = FiniteDuration(1, duration.MINUTES),
maxHttpEntityUploadSize = 4194304,
httpEntityUploadTimeout = FiniteDuration(1, duration.MINUTES),
)
)
"fixed callback URI" should {
"be absolute" in {
// A relative URI is a programming error and must be rejected eagerly.
an[AssertionError] should be thrownBy client.routes(Uri().withPath(Uri.Path./("cb")))
}
"be used in login URI" in {
val routes = client.routes(Uri("http://client.domain/cb"))
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
routes.loginUri(claims = claims) shouldBe
Uri(
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
)
}
}
"callback URI from request" should {
"be used in login URI" in {
// The callback authority is taken from the incoming request's Host.
val routes = client.routesFromRequestAuthority(Uri.Path./("cb"))
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
import org.apache.pekko.http.scaladsl.server.directives.RouteDirectives._
Get("http://client.domain") ~> routes { routes =>
complete(routes.loginUri(claims).toString)
} ~> check {
responseAs[String] shouldEqual
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
}
}
}
"automatic callback URI" should {
"be fixed when absolute" in {
val routes = client.routesAuto(Uri("http://client.domain/cb"))
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
import org.apache.pekko.http.scaladsl.server.directives.RouteDirectives._
Get() ~> routes { routes => complete(routes.loginUri(claims).toString) } ~> check {
responseAs[String] shouldEqual
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
}
}
"be from request when relative" in {
// Relative path falls back to the request authority, as above.
val routes = client.routesAuto(Uri().withPath(Uri.Path./("cb")))
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
import org.apache.pekko.http.scaladsl.server.directives.RouteDirectives._
Get("http://client.domain") ~> routes { routes =>
complete(routes.loginUri(claims).toString)
} ~> check {
responseAs[String] shouldEqual
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
}
}
}
}

View File

@ -1,68 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import com.daml.auth.middleware.api.RequestStore
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import scala.concurrent.duration._
/** Unit tests for [[RequestStore]]: a bounded, time-limited store whose
  * entries are removed on `pop` and expire after the configured duration.
  */
class TestRequestStore extends AsyncWordSpec with Matchers {
  "return None on missing element" in {
    val requests = new RequestStore[Int, String](1, 1.day)
    requests.pop(0) should ===(None)
  }
  "return previously put element" in {
    val requests = new RequestStore[Int, String](1, 1.day)
    requests.put(0, "zero")
    requests.pop(0) should ===(Some("zero"))
  }
  "return None on previously popped element" in {
    // pop is one-shot: the element is removed on first retrieval.
    val requests = new RequestStore[Int, String](1, 1.day)
    requests.put(0, "zero")
    requests.pop(0)
    requests.pop(0) should ===(None)
  }
  "store multiple elements" in {
    val requests = new RequestStore[Int, String](3, 1.day)
    requests.put(0, "zero")
    requests.put(1, "one")
    requests.put(2, "two")
    requests.pop(0) should ===(Some("zero"))
    requests.pop(1) should ===(Some("one"))
    requests.pop(2) should ===(Some("two"))
  }
  "store no more than max capacity" in {
    val requests = new RequestStore[Int, String](2, 1.day)
    assert(requests.put(0, "zero"))
    assert(requests.put(1, "one"))
    // Third insert must be rejected: capacity is 2.
    assert(!requests.put(2, "two"))
    requests.pop(0) should ===(Some("zero"))
    requests.pop(1) should ===(Some("one"))
    requests.pop(2) should ===(None)
  }
  "return None on timed out element" in {
    // Inject a fake nanosecond clock so expiry can be driven deterministically.
    var now: Long = 0
    val requests = new RequestStore[Int, String](1, 1.day, () => now)
    requests.put(0, "zero")
    now += 1.day.toNanos
    requests.pop(0) should ===(None)
  }
  "free capacity for timed out elements" in {
    var now: Long = 0
    val requests = new RequestStore[Int, String](1, 1.day, () => now)
    assert(requests.put(0, "zero"))
    assert(!requests.put(1, "one"))
    // Once the first entry expires, its slot becomes available again.
    now += 1.day.toNanos
    assert(requests.put(2, "two"))
    requests.pop(2) should ===(Some("two"))
  }
}

View File

@ -1,236 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.middleware.oauth2
import java.io._
import java.nio.file.Path
import java.util.UUID
import org.apache.pekko.http.scaladsl.model.Uri
import com.daml.auth.middleware.api.Request.Claims
import com.daml.auth.middleware.api.Tagged.RefreshToken
import com.daml.lf.data.Ref
import com.daml.scalautil.Statement.discard
import org.scalatest.{PartialFunctionValues, TryValues}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
// Tests the Jsonnet-based request templates that build the parameter maps for
// the OAuth2 authorization, token, and refresh requests — both the builtin
// templates and user-supplied overrides.
class TestRequestTemplates
extends AnyWordSpec
with Matchers
with TryValues
with PartialFunctionValues {
private val clientId = "client-id"
private val clientSecret = SecretString("client-secret")
// Builds a RequestTemplates with the fixed client credentials and optional
// user-supplied template overrides (None means use the builtin template).
private def getTemplates(
authTemplate: Option[Path] = None,
tokenTemplate: Option[Path] = None,
refreshTemplate: Option[Path] = None,
): RequestTemplates =
RequestTemplates(
clientId = clientId,
clientSecret = clientSecret,
authTemplate = authTemplate,
tokenTemplate = tokenTemplate,
refreshTemplate = refreshTemplate,
)
// Writes `content` to a temporary .jsonnet file, runs `testCode` against its
// path, and deletes the file afterwards.
private def withJsonnetFile(content: String)(testCode: Path => Any): Unit = {
val file = File.createTempFile("test-request-template", ".jsonnet")
val writer = new FileWriter(file)
try {
writer.write(content)
writer.close()
discard(testCode(file.toPath))
} finally discard(file.delete())
}
"the builtin auth template" should {
"handle empty claims" in {
val templates = getTemplates()
val claims = Claims(admin = false, actAs = Nil, readAs = Nil, applicationId = None)
val requestId = UUID.randomUUID()
val redirectUri = Uri("https://localhost/cb")
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
params.keys should contain.only(
"audience",
"client_id",
"redirect_uri",
"response_type",
"scope",
"state",
)
params should contain.allOf(
"audience" -> "https://daml.com/ledger-api",
"client_id" -> clientId,
"redirect_uri" -> redirectUri.toString,
"response_type" -> "code",
"scope" -> "offline_access",
"state" -> requestId.toString,
)
}
"handle an admin claim" in {
val templates = getTemplates()
val claims = Claims(admin = true, actAs = Nil, readAs = Nil, applicationId = None)
val requestId = UUID.randomUUID()
val redirectUri = Uri("https://localhost/cb")
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
params.keys should contain.only(
"audience",
"client_id",
"redirect_uri",
"response_type",
"scope",
"state",
)
params should contain.allOf(
"audience" -> "https://daml.com/ledger-api",
"client_id" -> clientId,
"redirect_uri" -> redirectUri.toString,
"response_type" -> "code",
"state" -> requestId.toString,
)
// scope is space-separated; order is not guaranteed, so check membership.
val scope = params.valueAt("scope").split(" ")
scope should contain.allOf("admin", "offline_access")
}
"handle actAs claims" in {
val templates = getTemplates()
val claims = Claims(
admin = false,
actAs = List("Alice", "Bob").map(Ref.Party.assertFromString),
readAs = Nil,
applicationId = None,
)
val requestId = UUID.randomUUID()
val redirectUri = Uri("https://localhost/cb")
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
params.keys should contain.only(
"audience",
"client_id",
"redirect_uri",
"response_type",
"scope",
"state",
)
params should contain.allOf(
"audience" -> "https://daml.com/ledger-api",
"client_id" -> clientId,
"redirect_uri" -> redirectUri.toString,
"response_type" -> "code",
"state" -> requestId.toString,
)
// Each actAs party becomes an `actAs:<party>` scope entry.
val scope = params.valueAt("scope").split(" ")
scope should contain.allOf("actAs:Alice", "actAs:Bob", "offline_access")
}
"handle readAs claims" in {
val templates = getTemplates()
val claims = Claims(
admin = false,
actAs = Nil,
readAs = List("Alice", "Bob").map(Ref.Party.assertFromString),
applicationId = None,
)
val requestId = UUID.randomUUID()
val redirectUri = Uri("https://localhost/cb")
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
params.keys should contain.only(
"audience",
"client_id",
"redirect_uri",
"response_type",
"scope",
"state",
)
params should contain.allOf(
"audience" -> "https://daml.com/ledger-api",
"client_id" -> clientId,
"redirect_uri" -> redirectUri.toString,
"response_type" -> "code",
"state" -> requestId.toString,
)
// Each readAs party becomes a `readAs:<party>` scope entry.
val scope = params.valueAt("scope").split(" ")
scope should contain.allOf("offline_access", "readAs:Alice", "readAs:Bob")
}
"handle an applicationId claim" in {
val templates = getTemplates()
val claims = Claims(
admin = false,
actAs = Nil,
readAs = Nil,
applicationId = Some(Ref.ApplicationId.assertFromString("application-id")),
)
val requestId = UUID.randomUUID()
val redirectUri = Uri("https://localhost/cb")
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
params.keys should contain.only(
"audience",
"client_id",
"redirect_uri",
"response_type",
"scope",
"state",
)
params should contain.allOf(
"audience" -> "https://daml.com/ledger-api",
"client_id" -> clientId,
"redirect_uri" -> redirectUri.toString,
"response_type" -> "code",
"state" -> requestId.toString,
)
val scope = params.valueAt("scope").split(" ")
scope should contain.allOf("applicationId:application-id", "offline_access")
}
}
"the builtin token template" should {
"be complete" in {
val templates = getTemplates()
val code = "request-code"
val redirectUri = Uri("https://localhost/cb")
val params = templates.createTokenRequest(code, redirectUri).success.value
params shouldBe Map(
"client_id" -> clientId,
"client_secret" -> clientSecret.value,
"code" -> code,
"grant_type" -> "authorization_code",
"redirect_uri" -> redirectUri.toString,
)
}
}
"the builtin refresh template" should {
"be complete" in {
val templates = getTemplates()
val refreshToken = RefreshToken("refresh-token")
val params = templates.createRefreshRequest(refreshToken).success.value
// NOTE(review): RFC 6749 names this grant type "refresh_token", not
// "refresh_code"; this expectation mirrors the builtin template — confirm
// the deviation is intentional for the test auth server.
params shouldBe Map(
"client_id" -> clientId,
"client_secret" -> clientSecret.value,
"grant_type" -> "refresh_code",
"refresh_token" -> refreshToken,
)
}
}
"user defined templates" should {
// A user-supplied template fully replaces the builtin one; the resulting
// map is whatever the Jsonnet file evaluates to.
"override the auth template" in withJsonnetFile("""{"key": "value"}""") { templatePath =>
val templates = getTemplates(authTemplate = Some(templatePath))
val claims = Claims(admin = false, actAs = Nil, readAs = Nil, applicationId = None)
val requestId = UUID.randomUUID()
val redirectUri = Uri("https://localhost/cb")
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
params shouldBe Map("key" -> "value")
}
"override the token template" in withJsonnetFile("""{"key": "value"}""") { templatePath =>
val templates = getTemplates(tokenTemplate = Some(templatePath))
val code = "request-code"
val redirectUri = Uri("https://localhost/cb")
val params = templates.createTokenRequest(code, redirectUri).success.value
params shouldBe Map("key" -> "value")
}
"override the refresh template" in withJsonnetFile("""{"key": "value"}""") { templatePath =>
val templates = getTemplates(refreshTemplate = Some(templatePath))
val refreshToken = RefreshToken("refresh-token")
val params = templates.createRefreshRequest(refreshToken).success.value
params shouldBe Map("key" -> "value")
}
}
}

View File

@ -1,198 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.test.server
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.http.scaladsl.Http
import org.apache.pekko.http.scaladsl.Http.ServerBinding
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import org.apache.pekko.http.scaladsl.marshalling.Marshal
import org.apache.pekko.http.scaladsl.model.Uri.Path
import org.apache.pekko.http.scaladsl.model._
import org.apache.pekko.http.scaladsl.server.Directives._
import org.apache.pekko.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
import com.daml.auth.oauth2.api.{Request, Response}
import com.daml.ports.Port
import spray.json._
import scala.concurrent.{ExecutionContext, Future}
import scala.language.postfixOps
// This is a test client ("client" in OAuth 2.0 terminology).
// A real application authenticating against the auth server would take the same role.
// A minimal HTTP server playing the OAuth2 "client" role for tests. It
// exposes /access (start an authorization-code flow), /cb (the redirect
// callback that exchanges the code for a token), and /refresh (exchange a
// refresh token for a new access token).
object Client {
import com.daml.auth.oauth2.api.JsonProtocol._
// port: where this client listens; authServerUrl: the auth server under test;
// clientId/clientSecret: credentials presented to the auth server.
case class Config(
port: Port,
authServerUrl: Uri,
clientId: String,
clientSecret: String,
)
object JsonProtocol extends DefaultJsonProtocol {
implicit val accessParamsFormat: RootJsonFormat[AccessParams] = jsonFormat3(AccessParams)
implicit val refreshParamsFormat: RootJsonFormat[RefreshParams] = jsonFormat1(RefreshParams)
// Response is a sealed trait; try to decode as AccessResponse first, then
// ErrorResponse, and fail with both error messages if neither matches.
implicit object ResponseJsonFormat extends RootJsonFormat[Response] {
implicit private val accessFormat: RootJsonFormat[AccessResponse] = jsonFormat2(
AccessResponse
)
implicit private val errorFormat: RootJsonFormat[ErrorResponse] = jsonFormat1(ErrorResponse)
def write(resp: Response) = resp match {
case resp @ AccessResponse(_, _) => resp.toJson
case resp @ ErrorResponse(_) => resp.toJson
}
def read(value: JsValue) =
(
value.convertTo(safeReader[AccessResponse]),
value.convertTo(safeReader[ErrorResponse]),
) match {
case (Right(a), _) => a
case (_, Right(b)) => b
case (Left(ea), Left(eb)) =>
deserializationError(s"Could not read Response value:\n$ea\n$eb")
}
}
}
// Request body for /access: which claims to request from the auth server.
case class AccessParams(parties: Seq[String], admin: Boolean, applicationId: Option[String])
// Request body for /refresh: the refresh token to exchange.
case class RefreshParams(refreshToken: String)
sealed trait Response
final case class AccessResponse(token: String, refresh: String) extends Response
final case class ErrorResponse(error: String) extends Response
// The redirect (callback) URI is always /cb on this client's own authority.
def toRedirectUri(uri: Uri): Uri = uri.withPath(Path./("cb"))
def start(
config: Config
)(implicit asys: ActorSystem, ec: ExecutionContext): Future[ServerBinding] = {
import JsonProtocol._
implicit val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
val route = concat(
// Endpoint that requires authorization for some parties. This will in the end return the token
// produced by the authorization server.
path("access") {
post {
entity(as[AccessParams]) { params =>
extractRequest { request =>
val redirectUri = toRedirectUri(request.uri)
// Encode the requested claims as a space-separated OAuth scope.
val scope =
(params.parties.map(p => "actAs:" + p) ++
(if (params.admin) List("admin") else Nil) ++
params.applicationId.toList.map(id => "applicationId:" + id)).mkString(" ")
val authParams = Request.Authorize(
responseType = "code",
clientId = config.clientId,
redirectUri = redirectUri,
scope = Some(scope),
state = None,
audience = Some("https://daml.com/ledger-api"),
)
// Send the user agent to the auth server's /authorize endpoint.
redirect(
config.authServerUrl
.withQuery(authParams.toQuery)
.withPath(Path./("authorize")),
StatusCodes.Found,
)
}
}
}
},
path("cb") {
get {
// Success branch: the auth server redirected back with a code.
parameters(Symbol("code"), Symbol("state") ?).as[Response.Authorize](Response.Authorize) {
resp =>
extractRequest { request =>
// We got the code, now request a token
val body = Request.Token(
grantType = "authorization_code",
code = resp.code,
redirectUri = toRedirectUri(request.uri),
clientId = config.clientId,
clientSecret = config.clientSecret,
)
val f = for {
entity <- Marshal(body).to[RequestEntity]
req = HttpRequest(
uri = config.authServerUrl.withPath(Path./("token")),
entity = entity,
method = HttpMethods.POST,
)
resp <- Http().singleRequest(req)
tokenResp <- Unmarshal(resp).to[Response.Token]
} yield tokenResp
onSuccess(f) { tokenResp =>
// Now we have the access_token and potentially the refresh token. At this point,
// a real client would start its authorized workload.
complete(
AccessResponse(
tokenResp.accessToken,
tokenResp.refreshToken.getOrElse(
sys.error("/token endpoint failed to return a refresh token")
),
): Response
)
}
}
} ~
// Error branch: the auth server redirected back with an error.
parameters(
Symbol("error"),
Symbol("error_description") ?,
Symbol("error_uri").as[Uri] ?,
Symbol("state") ?,
)
.as[Response.Error](Response.Error) { resp =>
complete(ErrorResponse(resp.error): Response)
}
}
},
path("refresh") {
post {
entity(as[RefreshParams]) { params =>
val body = Request.Refresh(
grantType = "refresh_token",
refreshToken = params.refreshToken,
clientId = config.clientId,
clientSecret = config.clientSecret,
)
val f =
for {
entity <- Marshal(body).to[RequestEntity]
req = HttpRequest(
uri = config.authServerUrl.withPath(Path./("token")),
entity = entity,
method = HttpMethods.POST,
)
resp <- Http().singleRequest(req)
tokenResp <-
// Surface non-OK responses with the server's message for easier debugging.
if (resp.status != StatusCodes.OK) {
Unmarshal(resp).to[String].flatMap { msg =>
throw new RuntimeException(
s"Failed to fetch refresh token (${resp.status}): $msg."
)
}
} else {
Unmarshal(resp).to[Response.Token]
}
} yield tokenResp
onSuccess(f) { tokenResp =>
// Now we have the access_token and potentially the refresh token. At this point,
// a real client would start its authorized workload.
complete(
AccessResponse(
tokenResp.accessToken,
tokenResp.refreshToken.getOrElse(
sys.error("/token endpoint failed to return a refresh token")
),
): Response
)
}
}
}
},
)
Http().newServerAt("localhost", config.port.value).bind(route)
}
}

View File

@ -1,36 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.test.server
import java.time.{Clock, Duration, Instant, ZoneId}
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.http.scaladsl.Http.ServerBinding
import com.daml.clock.AdjustableClock
import com.daml.ledger.resources.{Resource, ResourceContext, ResourceOwner}
import scala.concurrent.Future
// ResourceOwner factories for the test fixtures: an adjustable clock, the
// auth server binding, and the test client binding. Each owner's release
// unbinds/discards the underlying resource.
object Resources {
// An AdjustableClock fixed at `start` in `zoneId`; tests advance it manually.
def clock(start: Instant, zoneId: ZoneId): ResourceOwner[AdjustableClock] =
new ResourceOwner[AdjustableClock] {
override def acquire()(implicit context: ResourceContext): Resource[AdjustableClock] = {
Resource(Future(AdjustableClock(Clock.fixed(start, zoneId), Duration.ZERO)))(_ =>
Future(())
)
}
}
// Starts the auth server under test; unbinds it on release.
def authServerBinding(server: Server)(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
new ResourceOwner[ServerBinding] {
override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
Resource(server.start())(_.unbind().map(_ => ()))
}
// Starts the test OAuth client (see Client.start); unbinds it on release.
def authClientBinding(
config: Client.Config
)(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
new ResourceOwner[ServerBinding] {
override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
Resource(Client.start(config))(_.unbind().map(_ => ()))
}
}

View File

@ -1,285 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.test.server
import org.apache.pekko.http.scaladsl.Http
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import org.apache.pekko.http.scaladsl.model.Uri.Path
import org.apache.pekko.http.scaladsl.model._
import org.apache.pekko.http.scaladsl.model.headers.Location
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshal
import com.daml.jwt.JwtDecoder
import com.daml.jwt.domain.Jwt
import com.daml.ledger.api.auth.{
AuthServiceJWTCodec,
CustomDamlJWTPayload,
AuthServiceJWTPayload,
StandardJWTPayload,
}
import com.daml.lf.data.Ref
import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
import org.scalatest.OptionValues
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import spray.json._
import java.time.Instant
import scala.concurrent.Future
import scala.util.Try
// Shared specs for the OAuth2 test auth server: drives the authorization-code
// flow through the test client, decodes the resulting JWT, and asserts on its
// payload. Concrete suites fix the payload type `Tok` (custom claim token or
// standard user token).
abstract class Test
    extends AsyncWordSpec
    with Matchers
    with OptionValues
    with TestFixture
    with SuiteResourceManagementAroundAll {
  import Client.JsonProtocol._
  import Test._

  // Payload type issued by the server in this suite.
  type Tok <: AuthServiceJWTPayload
  // Accessors over `Tok` (user id, expiry, expiry-stripped copy); see TokenCompat.
  protected[this] val Tok: TokenCompat[Tok]
  // Narrows a generic decoded payload to `Tok`; provided by the concrete suite.
  implicit def `default Token`: Token[Tok]

  // Parse a serialized JWT payload and narrow it to `A` via the implicit Token.
  private def readJWTTokenFromString[A](
      serializedPayload: String
  )(implicit A: Token[A]): Try[A] =
    AuthServiceJWTCodec.readFromString(serializedPayload).flatMap { t => Try(A.run(t)) }

  // Turn a client response into either an OAuth2 error code or the decoded
  // payload plus refresh token. Shared by the access and refresh flows below
  // (this logic was previously duplicated in both).
  private def decodeTokenResponse[A: Token](
      body: Client.Response
  ): Future[Either[String, (A, String)]] =
    body match {
      case Client.AccessResponse(token, refreshToken) =>
        for {
          decodedJwt <- JwtDecoder
            .decode(Jwt(token))
            .fold(
              e => Future.failed(new IllegalArgumentException(e.toString)),
              Future.successful(_),
            )
          payload <- Future.fromTry(readJWTTokenFromString[A](decodedJwt.payload))
        } yield Right((payload, refreshToken))
      case Client.ErrorResponse(error) => Future.successful(Left(error))
    }

  // Run the full authorization-code grant via the client's /access endpoint,
  // manually following the two redirects (pekko-http performs no automatic
  // redirect handling).
  private def requestToken[A: Token](
      parties: Seq[String],
      admin: Boolean,
      applicationId: Option[String],
  ): Future[Either[String, (A, String)]] = {
    lazy val clientUri = Uri()
      .withAuthority(clientBinding.localAddress.getHostString, clientBinding.localAddress.getPort)
    val req = HttpRequest(
      uri = clientUri.withPath(Path./("access")).withScheme("http"),
      method = HttpMethods.POST,
      entity = HttpEntity(
        MediaTypes.`application/json`,
        Client.AccessParams(parties, admin, applicationId).toJson.compactPrint,
      ),
    )
    for {
      resp <- Http().singleRequest(req)
      // Redirect to /authorize on the authorization server.
      resp <- {
        resp.status should ===(StatusCodes.Found)
        val req = HttpRequest(uri = resp.header[Location].value.uri)
        Http().singleRequest(req)
      }
      // Redirect to /cb on the client.
      resp <- {
        resp.status should ===(StatusCodes.Found)
        val req = HttpRequest(uri = resp.header[Location].value.uri)
        Http().singleRequest(req)
      }
      // Actual token response (proxied from auth server to us via the client).
      body <- Unmarshal(resp).to[Client.Response]
      result <- decodeTokenResponse[A](body)
    } yield result
  }

  // Exchange a refresh token for a new access token via the client's /refresh endpoint.
  private def requestRefresh[A: Token](
      refreshToken: String
  ): Future[Either[String, (A, String)]] = {
    lazy val clientUri = Uri()
      .withAuthority(clientBinding.localAddress.getHostString, clientBinding.localAddress.getPort)
    val req = HttpRequest(
      uri = clientUri.withPath(Path./("refresh")).withScheme("http"),
      method = HttpMethods.POST,
      entity = HttpEntity(
        MediaTypes.`application/json`,
        Client.RefreshParams(refreshToken).toJson.compactPrint,
      ),
    )
    for {
      resp <- Http().singleRequest(req)
      // Token response (proxied from auth server to us via the client).
      body <- Unmarshal(resp).to[Client.Response]
      result <- decodeTokenResponse[A](body)
    } yield result
  }

  // Request a token, failing the test if the server answers with an error code.
  protected[this] def expectToken(
      parties: Seq[String],
      admin: Boolean = false,
      applicationId: Option[String] = None,
  ): Future[(Tok, String)] =
    requestToken(parties, admin, applicationId).flatMap {
      case Left(error) => fail(s"Expected token but got error-code $error")
      case Right(token) => Future.successful(token)
    }

  // Request a token, failing the test unless the server answers with an error code.
  protected[this] def expectError(
      parties: Seq[String],
      admin: Boolean = false,
      applicationId: Option[String] = None,
  ): Future[String] =
    requestToken[AuthServiceJWTPayload](parties, admin, applicationId).flatMap {
      case Left(error) => Future.successful(error)
      case Right(_) => fail("Expected an error but got a token")
    }

  // Refresh a token, failing the test if the server answers with an error code.
  private def expectRefresh(refreshToken: String): Future[(Tok, String)] =
    requestRefresh(refreshToken).flatMap {
      case Left(error) => fail(s"Expected token but got error-code $error")
      case Right(token) => Future.successful(token)
    }

  "the auth server" should {
    "refresh a token" in {
      for {
        (token1, refresh1) <- expectToken(Seq())
        // Advance the adjustable clock past the first token's expiry so the
        // refreshed token must carry a later expiry.
        _ <- Future(clock.set((Tok exp token1) plusSeconds 1))
        (token2, _) <- expectRefresh(refresh1)
      } yield {
        (Tok exp token2) should be > (Tok exp token1)
        (Tok withoutExp token1) should ===((Tok withoutExp token2))
      }
    }
    "return a token with the requested app id" in {
      for {
        (token, _) <- expectToken(Seq(), applicationId = Some("my-app-id"))
      } yield {
        Tok.userId(token) should ===(Some("my-app-id"))
      }
    }
    "return a token with no app id if non is requested" in {
      for {
        (token, _) <- expectToken(Seq(), applicationId = None)
      } yield {
        Tok.userId(token) should ===(None)
      }
    }
  }
}
// Suite variant where the server issues custom Daml claim tokens
// (actAs parties plus an admin flag) rather than standard user tokens.
class ClaimTokenTest extends Test {
  import Test._
  override def yieldUserTokens = false
  type Tok = CustomDamlJWTPayload
  override object Tok extends TokenCompat[Tok] {
    // Custom claim tokens carry the application id directly.
    override def userId(t: Tok) = t.applicationId
    override def exp(t: Tok) = t.exp.value
    override def withoutExp(t: Tok) = t copy (exp = None)
  }
  // Narrow the generic decoded payload to a custom claim token; standard user
  // tokens are unexpected in this suite and indicate a server misconfiguration.
  implicit override def `default Token`: Token[Tok] = new Token({
    case _: StandardJWTPayload =>
      throw new IllegalStateException(
        "auth-middleware: user access tokens are not expected here"
      )
    case payload: CustomDamlJWTPayload => payload
  })
  "the auth server with claim tokens" should {
    "issue a token with no parties" in {
      for {
        (token, _) <- expectToken(Seq())
      } yield {
        token.actAs should ===(Seq())
      }
    }
    "issue a token with 1 party" in {
      for {
        (token, _) <- expectToken(Seq("Alice"))
      } yield {
        token.actAs should ===(Seq("Alice"))
      }
    }
    "issue a token with multiple parties" in {
      for {
        (token, _) <- expectToken(Seq("Alice", "Bob"))
      } yield {
        token.actAs should ===(Seq("Alice", "Bob"))
      }
    }
    "deny access to unauthorized parties" in {
      // Revoke Eve up-front; the fixture's afterEach resets authorized parties.
      server.revokeParty(Ref.Party.assertFromString("Eve"))
      for {
        error <- expectError(Seq("Alice", "Eve"))
      } yield {
        error should ===("access_denied")
      }
    }
    "issue a token with admin access" in {
      for {
        (token, _) <- expectToken(Seq(), admin = true)
      } yield {
        assert(token.admin)
      }
    }
    "deny admin access if unauthorized" in {
      // Revoke admin up-front; the fixture's afterEach resets admin rights.
      server.revokeAdmin()
      for {
        error <- expectError(Seq(), admin = true)
      } yield {
        error should ===("access_denied")
      }
    }
  }
}
// Suite variant where the server issues standard user access tokens
// (a user id, no explicit party claims).
class UserTokenTest extends Test {
  import Test._
  override def yieldUserTokens = true
  type Tok = StandardJWTPayload
  override object Tok extends TokenCompat[Tok] {
    // Standard tokens store the user id directly; treat an empty id as absent.
    override def userId(t: Tok) = Some(t.userId).filter(_.nonEmpty)
    override def exp(t: Tok) = t.exp.value
    override def withoutExp(t: Tok) = t copy (exp = None)
  }
  // Narrow the generic decoded payload to a standard user token; custom claim
  // tokens are unexpected in this suite and indicate a server misconfiguration.
  implicit override def `default Token`: Token[Tok] = new Token({
    case payload: StandardJWTPayload => payload
    case _: CustomDamlJWTPayload =>
      throw new IllegalStateException(
        "auth-middleware: custom tokens are not expected here"
      )
  })
}
object Test {
  // Projection from the generic decoded JWT payload to a concrete payload
  // type A; `run` throws if the payload has the wrong shape (see the suites).
  final class Token[A](val run: AuthServiceJWTPayload => A) extends AnyVal
  object Token {
    // Identity projection: accept any payload unchanged.
    implicit val any: Token[AuthServiceJWTPayload] = new Token(identity)
  }
  // Uniform accessors over the two payload shapes used by the concrete suites.
  private[server] abstract class TokenCompat[Tok] {
    // User/application id carried by the token, if any.
    def userId(t: Tok): Option[String]
    // Token expiry instant.
    def exp(t: Tok): Instant
    // Copy of the token with its expiry removed (for expiry-insensitive comparison).
    def withoutExp(t: Tok): Tok
  }
}

View File

@ -1,73 +0,0 @@
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.auth.oauth2.test.server
import java.time.{Instant, ZoneId}
import org.apache.pekko.http.scaladsl.Http.ServerBinding
import org.apache.pekko.http.scaladsl.model.Uri
import com.daml.clock.AdjustableClock
import com.daml.ledger.api.testing.utils.{
PekkoBeforeAndAfterAll,
OwnedResource,
Resource,
SuiteResource,
}
import com.daml.ledger.resources.ResourceContext
import com.daml.ports.Port
import org.scalatest.{BeforeAndAfterEach, Suite}
// Suite fixture that stands up, for the lifetime of the suite: an adjustable
// clock, the OAuth2 test auth server, and the test client, exposing their
// bindings to the specs.
trait TestFixture
    extends PekkoBeforeAndAfterAll
    with BeforeAndAfterEach
    with SuiteResource[(AdjustableClock, Server, ServerBinding, ServerBinding)] {
  self: Suite =>
  protected val ledgerId: String = "test-ledger"
  protected val applicationId: String = "test-application"
  // Symmetric secret the test server uses to sign JWTs.
  protected val jwtSecret: String = "secret"
  // Components of the suite resource; valid once the suite resource is acquired.
  lazy protected val clock: AdjustableClock = suiteResource.value._1
  lazy protected val server: Server = suiteResource.value._2
  lazy protected val serverBinding: ServerBinding = suiteResource.value._3
  lazy protected val clientBinding: ServerBinding = suiteResource.value._4
  // Whether the server issues standard user tokens (true) or custom claim
  // tokens (false); fixed by the concrete suite.
  protected[this] def yieldUserTokens: Boolean
  // Acquire clock, then server, then the client pointed at the server's
  // dynamically-assigned address; released in reverse order by OwnedResource.
  override protected lazy val suiteResource
      : Resource[(AdjustableClock, Server, ServerBinding, ServerBinding)] = {
    implicit val resourceContext: ResourceContext = ResourceContext(system.dispatcher)
    new OwnedResource[ResourceContext, (AdjustableClock, Server, ServerBinding, ServerBinding)](
      for {
        clock <- Resources.clock(Instant.now(), ZoneId.systemDefault())
        server = Server(
          Config(
            port = Port.Dynamic,
            ledgerId = ledgerId,
            jwtSecret = jwtSecret,
            clock = Some(clock),
            yieldUserTokens = yieldUserTokens,
          )
        )
        serverBinding <- Resources.authServerBinding(server)
        clientBinding <- Resources.authClientBinding(
          Client.Config(
            port = Port.Dynamic,
            authServerUrl = Uri()
              .withScheme("http")
              .withAuthority(
                serverBinding.localAddress.getHostString,
                serverBinding.localAddress.getPort,
              ),
            clientId = "test-client",
            clientSecret = "test-client-secret",
          )
        )
      } yield { (clock, server, serverBinding, clientBinding) }
    )
  }
  // Reset mutable authorization state on the server so that revocations made
  // by one test do not leak into the next.
  override protected def afterEach(): Unit = {
    server.resetAuthorizedParties()
    server.resetAdmin()
    super.afterEach()
  }
}

View File

@ -1,90 +0,0 @@
# Design for Trigger Service Authentication/Authorization
## Goals
- Be compatible with an OAuth2 authorization code grant
https://tools.ietf.org/html/rfc6749#section-4.1
- Do not require OAuth2 or any other specific
authentication/authorization protocol from the IAM. In other words,
the communication with the IAM must be pluggable.
- Do not rely on wildcard access for the trigger service, it should
only be able to start triggers on behalf of a party if a user that
controls that party has given consent.
- Support long-running triggers without constant user
interaction. Since auth tokens are often short-lived (e.g., expire
after 1h), this implies some mechanism for token refresh.
## Design
This involves 3 components:
1. The trigger service provided by DA.
2. An auth middleware. DA provides an implementation of this for at
least the OAuth2 authorization code grant but this is completely
pluggable so if the DA-provided middleware does not cover the IAM
infrastructure of a client, they can implement their own.
3. The IAM. This is the entity that signs Ledger API tokens. This is
not provided by DA. The Ledger is configured to trust this entity.
### Auth Middleware API
The auth middleware provides a few endpoints (the names don't matter
all that much, they just need to be fixed once).
1. /auth The trigger service, will contact this endpoint with a set of
claims passing along all cookies in the original request. If
the user has already authenticated and is authorized for those
claims, it will return an access token (an opaque blob to the
trigger service) for at least those claims and a refresh token
(another opaque blob). If not, it will return an unauthorized
status code.
2. /login If /auth returned unauthorized, the trigger service will
redirect users to this.
For HTML requests, this is done via an HTTP redirect; otherwise via a custom WWW-Authenticate challenge in a 401 response.
The parameters will include the requested claims as well as an optional callback URL (note that this is not the OAuth2 callback url but a callback URL on the trigger service). This will start an auth flow,
e.g., an OAuth2 authorization code grant. If the flow succeeds the
auth service will set a cookie with the access and refresh token
and redirect to the callback URL if present or return status code 200.
At this point, a request to
/auth will succeed (based on the cookie). If the flow failed the
auth service will not set a cookie and redirect to the callback URL
with an additional error and optional error_description parameter
or return 403 with error and optional error_description in the response body.
3. /refresh This accepts a refresh token and returns a new access
token and optionally a new refresh token (or fails).
### Auth Middleware Implementation based on OAuth2 Authorization Code Grant
1. /auth checks for the presence of a cookie with the tokens in it.
2. /login starts an OAuth2 authorization code grant flow. After the
redirect URI is called by the authorization server, the middleware
makes a request to get the tokens, sets them in cookies and
redirects back to the callback URI. Upon failure the middleware
forwards the error and error_description to the callback URI.
3. /refresh simply proxies to the refresh endpoint on the
authorization server adding the client id and secret.
Note that the auth middleware does not need to persist any state in
this model. The trigger service does need to persist at least the
refresh token and potentially the access token.
## Related Projects
The design here is very close to existing OAuth2 middlewares/proxies
such as [Vouch](https://github.com/vouch/vouch-proxy) or [OAuth2
Proxy](https://github.com/oauth2-proxy/oauth2-proxy). There are two
main differences:
1. The trigger service extracts the required claims from the original
request. This means that the first request goes to the trigger
service and not to the proxy/middleware. The middleware shouldn't
have to know how to map a request to the set of required claims so
this seems the only workable option.
2. The /auth request takes a custom set of claims. The existing
proxies focus on OIDC and don't support any custom claims.
Nevertheless, the fact that the design is very close to existing
proxies seems like a good thing.

View File

@ -1,12 +0,0 @@
#!/usr/bin/env bash
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Feed every trigger-service Flyway migration script to the shared
# hash-migrations helper — presumably to (re)compute digests that guard
# released migrations against accidental edits; confirm against
# libs-scala/flyway-testing/hash-migrations.sh.

set -eou pipefail
# Enable ** recursive globbing for the migration path below.
shopt -s globstar

# Absolute directory of this script, independent of the caller's CWD.
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"

"$DIR"/../../libs-scala/flyway-testing/hash-migrations.sh \
  "$DIR"/src/main/resources/com/daml/lf/engine/trigger/db/migration/**/*.sql

View File

@ -1,26 +0,0 @@
<configuration>
  <!-- Console appender: Logstash JSON output when the LOG_FORMAT_JSON property
       is defined (Janino conditional), otherwise a human-readable pattern. -->
  <appender name="console" class="ch.qos.logback.core.ConsoleAppender">
    <if condition='isDefined("LOG_FORMAT_JSON")'>
      <then>
        <encoder class="net.logstash.logback.encoder.LogstashEncoder"/>
      </then>
      <else>
        <encoder>
          <!-- The %replace strips the ", context: ..." suffix when the marker is empty. -->
          <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg %replace(, context: %marker){', context: $', ''} %n</pattern>
        </encoder>
      </else>
    </if>
  </appender>
  <!-- Async wrapper around the console appender so log calls do not block. -->
  <appender name="STDOUT" class="net.logstash.logback.appender.LoggingEventAsyncDisruptorAppender">
    <appender-ref ref="console"/>
  </appender>
  <!-- Quieten chatty networking libraries. -->
  <logger name="io.netty" level="WARN"/>
  <logger name="io.grpc.netty" level="WARN"/>
  <logger name="com.daml.lf.engine.trigger" level="INFO"/>
  <!-- Root level overridable via the LOG_LEVEL_ROOT property; defaults to INFO. -->
  <root level="${LOG_LEVEL_ROOT:-INFO}">
    <appender-ref ref="STDOUT" />
  </root>
</configuration>

Some files were not shown because too many files have changed in this diff Show More