mirror of
https://github.com/digital-asset/daml.git
synced 2024-11-08 21:34:22 +03:00
Remove triggers (#18142)
* remove triggers * adopt the codebase to missing triggers run-all-tests: true * fix builid * Remove oracle tests
This commit is contained in:
parent
043dfc55c3
commit
7108f2c76a
@ -25,7 +25,6 @@ NOTICES @garyverhaegen-da @dasormeter
|
||||
# Language
|
||||
/daml-assistant/ @remyhaemmerle-da @filmackay
|
||||
/daml-script/ @remyhaemmerle-da
|
||||
/triggers/ @remyhaemmerle-da
|
||||
/compiler/ @basvangijzel-DA @remyhaemmerle-da @akrmn @dylant-da @samuel-williams-da @paulbrauner-da
|
||||
/libs-haskell/ @basvangijzel-DA @remyhaemmerle-da @akrmn @dylant-da @samuel-williams-da @paulbrauner-da
|
||||
/ghc-lib/ @basvangijzel-DA @remyhaemmerle-da @akrmn @dylant-da @samuel-williams-da @paulbrauner-da
|
||||
@ -56,7 +55,6 @@ NOTICES @garyverhaegen-da @dasormeter
|
||||
# Application Runtime
|
||||
/ledger-service/ @filmackay
|
||||
/runtime-components/ @filmackay
|
||||
/triggers/service/ @filmackay
|
||||
|
||||
# Canton code drop
|
||||
/canton/ @remyhaemmerle-da
|
||||
|
@ -635,10 +635,6 @@ load("@canton_maven//:defs.bzl", pinned_canton_maven_install = "pinned_maven_ins
|
||||
|
||||
pinned_canton_maven_install()
|
||||
|
||||
load("@triggers_maven//:defs.bzl", pinned_triggers_maven_install = "pinned_maven_install")
|
||||
|
||||
pinned_triggers_maven_install()
|
||||
|
||||
load("@deprecated_maven//:defs.bzl", pinned_deprecated_maven_install = "pinned_maven_install")
|
||||
|
||||
pinned_deprecated_maven_install()
|
||||
|
@ -313,22 +313,6 @@ def install_java_deps():
|
||||
version_conflict_policy = "pinned",
|
||||
)
|
||||
|
||||
# Triggers depend on sjsonnet whose latest version still depends transitively on upickle 1.x,
|
||||
# while ammonite cannot work with upickle < 2.x. So we define the sjsonnet dependency in a
|
||||
# different maven_install.
|
||||
maven_install(
|
||||
name = "triggers_maven",
|
||||
maven_install_json = "@//:triggers_maven_install.json",
|
||||
artifacts = [
|
||||
"com.lihaoyi:sjsonnet_{}:0.3.0".format(scala_major_version),
|
||||
],
|
||||
repositories = [
|
||||
"https://repo1.maven.org/maven2",
|
||||
],
|
||||
fetch_sources = True,
|
||||
version_conflict_policy = "pinned",
|
||||
)
|
||||
|
||||
# Do not use those dependencies in anything new !
|
||||
maven_install(
|
||||
name = "deprecated_maven",
|
||||
|
96
ci/build.yml
96
ci/build.yml
@ -147,101 +147,6 @@ jobs:
|
||||
trigger_sha: '$(trigger_sha)'
|
||||
- template: report-end.yml
|
||||
|
||||
- job: Linux_oracle
|
||||
timeoutInMinutes: 240
|
||||
pool:
|
||||
name: 'ubuntu_20_04'
|
||||
demands: assignment -equals default
|
||||
steps:
|
||||
- template: report-start.yml
|
||||
- checkout: self
|
||||
- bash: ci/dev-env-install.sh
|
||||
displayName: 'Build/Install the Developer Environment'
|
||||
- template: clean-up.yml
|
||||
- bash: |
|
||||
source dev-env/lib/ensure-nix
|
||||
ci/dev-env-push.py
|
||||
displayName: 'Push Developer Environment build results'
|
||||
condition: and(succeeded(), eq(variables['System.PullRequest.IsFork'], 'False'))
|
||||
env:
|
||||
# to upload to the Nix cache
|
||||
GOOGLE_APPLICATION_CREDENTIALS_CONTENT: $(GOOGLE_APPLICATION_CREDENTIALS_CONTENT)
|
||||
NIX_SECRET_KEY_CONTENT: $(NIX_SECRET_KEY_CONTENT)
|
||||
- bash: ci/configure-bazel.sh
|
||||
displayName: 'Configure Bazel'
|
||||
env:
|
||||
IS_FORK: $(System.PullRequest.IsFork)
|
||||
# to upload to the bazel cache
|
||||
GOOGLE_APPLICATION_CREDENTIALS_CONTENT: $(GOOGLE_APPLICATION_CREDENTIALS_CONTENT)
|
||||
- bash: |
|
||||
set -euo pipefail
|
||||
eval "$(./dev-env/bin/dade-assist)"
|
||||
docker login --username "$DOCKER_LOGIN" --password "$DOCKER_PASSWORD"
|
||||
IMAGE=$(cat ci/oracle_image)
|
||||
docker pull $IMAGE
|
||||
# Cleanup stray containers that might still be running from
|
||||
# another build that didn’t get shut down cleanly.
|
||||
docker rm -f oracle || true
|
||||
# Oracle does not like if you connect to it via localhost if it’s running in the container.
|
||||
# Interestingly it works if you use the external IP of the host so the issue is
|
||||
# not the host it is listening on (it claims for that to be 0.0.0.0).
|
||||
# --network host is a cheap escape hatch for this.
|
||||
docker run -d --rm --name oracle --network host -e ORACLE_PWD=$ORACLE_PWD $IMAGE
|
||||
function cleanup() {
|
||||
docker rm -f oracle
|
||||
}
|
||||
trap cleanup EXIT
|
||||
testConnection() {
|
||||
docker exec oracle bash -c 'sqlplus -L '"$ORACLE_USERNAME"'/'"$ORACLE_PWD"'@//localhost:'"$ORACLE_PORT"'/ORCLPDB1 <<< "select * from dba_users;"; exit $?' >/dev/null
|
||||
}
|
||||
until testConnection
|
||||
do
|
||||
echo "Could not connect to Oracle, trying again..."
|
||||
sleep 1
|
||||
done
|
||||
# Actually run some tests
|
||||
# Note: Oracle tests all run sequentially because they all access the same Oracle instance,
|
||||
# and we sometimes observe transient connection issues when running tests in parallel.
|
||||
bazel test \
|
||||
--config=oracle \
|
||||
--test_strategy=exclusive \
|
||||
--test_tag_filters=+oracle \
|
||||
//...
|
||||
|
||||
oracle_logs=$(Build.StagingDirectory)/oracle-logs
|
||||
mkdir $oracle_logs
|
||||
for path in $(docker exec oracle bash -c 'find /opt/oracle/diag/rdbms/ -type f'); do
|
||||
# $path starts with a slash
|
||||
mkdir -p $(dirname ${oracle_logs}${path})
|
||||
docker exec oracle bash -c "cat $path" > ${oracle_logs}${path}
|
||||
done
|
||||
env:
|
||||
DOCKER_LOGIN: $(DOCKER_LOGIN)
|
||||
DOCKER_PASSWORD: $(DOCKER_PASSWORD)
|
||||
ARTIFACTORY_USERNAME: $(ARTIFACTORY_USERNAME)
|
||||
ARTIFACTORY_PASSWORD: $(ARTIFACTORY_PASSWORD)
|
||||
displayName: 'Build'
|
||||
condition: and(succeeded(), eq(variables['System.PullRequest.IsFork'], 'False'))
|
||||
|
||||
- task: PublishBuildArtifacts@1
|
||||
condition: failed()
|
||||
displayName: 'Publish the bazel test logs'
|
||||
inputs:
|
||||
pathtoPublish: 'bazel-testlogs/'
|
||||
artifactName: 'Test logs Oracle'
|
||||
|
||||
- task: PublishBuildArtifacts@1
|
||||
condition: failed()
|
||||
displayName: 'Publish Oracle image logs'
|
||||
inputs:
|
||||
pathtoPublish: '$(Build.StagingDirectory)/oracle-logs'
|
||||
artifactName: 'Oracle image logs'
|
||||
|
||||
- template: tell-slack-failed.yml
|
||||
parameters:
|
||||
trigger_sha: '$(trigger_sha)'
|
||||
- template: report-end.yml
|
||||
|
||||
- job: platform_independence_test
|
||||
condition: and(succeeded(),
|
||||
eq(dependencies.check_for_release.outputs['out.is_release'], 'false'))
|
||||
@ -336,7 +241,6 @@ jobs:
|
||||
condition: failed()
|
||||
dependsOn:
|
||||
- Linux
|
||||
- Linux_oracle
|
||||
- macOS
|
||||
- Windows
|
||||
- release
|
||||
|
@ -32,22 +32,6 @@ if [[ "$NAME" == "linux" ]]; then
|
||||
PROTOS_ZIP=protobufs-$RELEASE_TAG.zip
|
||||
cp bazel-bin/release/protobufs.zip $OUTPUT_DIR/github/$PROTOS_ZIP
|
||||
|
||||
TRIGGER_SERVICE=trigger-service-$RELEASE_TAG.jar
|
||||
TRIGGER_SERVICE_EE=trigger-service-$RELEASE_TAG-ee.jar
|
||||
bazel build //triggers/service:trigger-service-binary-ce_distribute.jar
|
||||
cp bazel-bin/triggers/service/trigger-service-binary-ce_distribute.jar $OUTPUT_DIR/github/$TRIGGER_SERVICE
|
||||
bazel build //triggers/service:trigger-service-binary-ee_distribute.jar
|
||||
cp bazel-bin/triggers/service/trigger-service-binary-ee_distribute.jar $OUTPUT_DIR/artifactory/$TRIGGER_SERVICE_EE
|
||||
|
||||
OAUTH2_MIDDLEWARE=oauth2-middleware-$RELEASE_TAG.jar
|
||||
bazel build //triggers/service/auth:oauth2-middleware-binary_distribute.jar
|
||||
cp bazel-bin/triggers/service/auth/oauth2-middleware-binary_distribute.jar $OUTPUT_DIR/github/$OAUTH2_MIDDLEWARE
|
||||
|
||||
|
||||
TRIGGER=daml-trigger-runner-$RELEASE_TAG.jar
|
||||
bazel build //triggers/runner:trigger-runner_distribute.jar
|
||||
cp bazel-bin/triggers/runner/trigger-runner_distribute.jar $OUTPUT_DIR/artifactory/$TRIGGER
|
||||
|
||||
SCRIPT=daml-script-$RELEASE_TAG.jar
|
||||
bazel build //daml-script/runner:daml-script-binary_distribute.jar
|
||||
cp bazel-bin/daml-script/runner/daml-script-binary_distribute.jar $OUTPUT_DIR/artifactory/$SCRIPT
|
||||
@ -58,10 +42,6 @@ if [[ "$NAME" == "linux" ]]; then
|
||||
bazel build //daml-script/daml3:daml3-script-dars
|
||||
cp bazel-bin/daml-script/daml3/*.dar $OUTPUT_DIR/split-release/daml-libs/daml-script/
|
||||
|
||||
mkdir -p $OUTPUT_DIR/split-release/daml-libs/daml-trigger
|
||||
bazel build //triggers/daml:daml-trigger-dars
|
||||
cp bazel-bin/triggers/daml/*.dar $OUTPUT_DIR/split-release/daml-libs/daml-trigger/
|
||||
|
||||
mkdir -p $OUTPUT_DIR/split-release/docs
|
||||
|
||||
bazel build //docs:sphinx-source-tree //docs:pdf-fonts-tar //docs:non-sphinx-html-docs //docs:sphinx-source-tree-deps
|
||||
|
@ -25,17 +25,11 @@ push() {
|
||||
https://digitalasset.jfrog.io/artifactory/${repository}/$RELEASE_TAG/${file}
|
||||
}
|
||||
|
||||
TRIGGER_RUNNER=daml-trigger-runner-$RELEASE_TAG.jar
|
||||
TRIGGER_SERVICE=trigger-service-$RELEASE_TAG-ee.jar
|
||||
SCRIPT_RUNNER=daml-script-$RELEASE_TAG.jar
|
||||
|
||||
cd $INPUTS
|
||||
push daml-trigger-runner $TRIGGER_RUNNER
|
||||
push daml-trigger-runner $TRIGGER_RUNNER.asc
|
||||
push daml-script-runner $SCRIPT_RUNNER
|
||||
push daml-script-runner $SCRIPT_RUNNER.asc
|
||||
push trigger-service $TRIGGER_SERVICE
|
||||
push trigger-service $TRIGGER_SERVICE.asc
|
||||
|
||||
# For the split release process these are not published to artifactory.
|
||||
if [[ "$#" -lt 3 || $3 != "split" ]]; then
|
||||
|
@ -8,11 +8,6 @@ load(
|
||||
"daml_script_example_dar",
|
||||
"daml_script_example_test",
|
||||
)
|
||||
load(
|
||||
"//bazel_tools/daml_trigger:daml_trigger.bzl",
|
||||
"daml_trigger_dar",
|
||||
"daml_trigger_test",
|
||||
)
|
||||
load(
|
||||
"//bazel_tools/data_dependencies:data_dependencies.bzl",
|
||||
"data_dependencies_coins",
|
||||
@ -103,29 +98,6 @@ head = "0.0.0"
|
||||
if versions.is_at_least(sdk_version, platform_version)
|
||||
] if not is_windows else None
|
||||
|
||||
# Change to `CommandId` generation
|
||||
first_post_7587_trigger_version = "1.7.0-snapshot.20201012.5405.0.af92198d"
|
||||
|
||||
[
|
||||
daml_trigger_dar(sdk_version)
|
||||
for sdk_version in sdk_versions
|
||||
if versions.is_at_least(first_post_7587_trigger_version, sdk_version)
|
||||
]
|
||||
|
||||
[
|
||||
daml_trigger_test(
|
||||
compiler_version = sdk_version,
|
||||
runner_version = platform_version,
|
||||
)
|
||||
for sdk_version in sdk_versions
|
||||
for platform_version in platform_versions
|
||||
# Test that the Daml trigger runner can run DARs built with an older SDK
|
||||
# version. I.e. where the runner version is at least the SDK version or
|
||||
# more recent.
|
||||
if versions.is_at_least(first_post_7587_trigger_version, sdk_version) and
|
||||
versions.is_at_least(sdk_version, platform_version)
|
||||
]
|
||||
|
||||
[
|
||||
data_dependencies_coins(
|
||||
sdk_version = sdk_version,
|
||||
|
@ -54,13 +54,6 @@ We test that the Daml Script runner from a given SDK version can load
|
||||
DARs built against the Daml Script library from an older SDK. We only
|
||||
guarantee backwards compatibility here.
|
||||
|
||||
|
||||
#### Backwards-compatibility for Daml Triggers
|
||||
|
||||
We test that the Daml Trigger runner from a given SDK version can load
|
||||
DARs built against the Daml Script library from an older SDK. We only
|
||||
guarantee backwards compatibility here.
|
||||
|
||||
#### Backwards-compatibility for data-dependencies
|
||||
|
||||
We test that we can import DARs built in older SDK versions via
|
||||
|
@ -1,4 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
exports_files(glob(["example/**"]))
|
@ -1,239 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
load(
|
||||
"@daml//bazel_tools/client_server:client_server_test.bzl",
|
||||
"client_server_test",
|
||||
)
|
||||
load("//bazel_tools:versions.bzl", "version_to_name", "versions")
|
||||
load("//bazel_tools:testing.bzl", "extra_tags")
|
||||
|
||||
def copy_trigger_src(sdk_version):
|
||||
# To avoid having to mess with Bazel’s escaping, avoid `$` and backticks.
|
||||
# We can’t use CPP to make this nicer unfortunately since this doesn’t work
|
||||
# with an installed SDK.
|
||||
return """
|
||||
module CopyTrigger where
|
||||
|
||||
import DA.List hiding (dedup)
|
||||
import Templates
|
||||
import Daml.Trigger
|
||||
|
||||
copyTrigger : Trigger ()
|
||||
copyTrigger = Trigger
|
||||
{{ initialize = pure ()
|
||||
, updateState = \\_message {acsArg} -> pure ()
|
||||
, rule = copyRule
|
||||
, registeredTemplates = AllInDar
|
||||
, heartbeat = None
|
||||
}}
|
||||
|
||||
copyRule party {acsArg} {stateArg} = do
|
||||
subscribers : [(ContractId Subscriber, Subscriber)] <- {query} {acsArg}
|
||||
originals : [(ContractId Original, Original)] <- {query} {acsArg}
|
||||
copies : [(ContractId Copy, Copy)] <- {query} {acsArg}
|
||||
|
||||
let ownedSubscribers = filter (\\(_, s) -> s.subscribedTo == party) subscribers
|
||||
let ownedOriginals = filter (\\(_, o) -> o.owner == party) originals
|
||||
let ownedCopies = filter (\\(_, c) -> c.original.owner == party) copies
|
||||
|
||||
let subscribingParties = map (\\(_, s) -> s.subscriber) ownedSubscribers
|
||||
|
||||
let groupedCopies : [[(ContractId Copy, Copy)]]
|
||||
groupedCopies = groupOn snd (sortOn snd ownedCopies)
|
||||
let copiesToKeep = map head groupedCopies
|
||||
let archiveDuplicateCopies = concatMap tail groupedCopies
|
||||
|
||||
let archiveMissingOriginal = filter (\\(_, c) -> notElem c.original (map snd ownedOriginals)) copiesToKeep
|
||||
let archiveMissingSubscriber = filter (\\(_, c) -> notElem c.subscriber subscribingParties) copiesToKeep
|
||||
let archiveCopies = dedup (map fst (archiveDuplicateCopies <> archiveMissingOriginal <> archiveMissingSubscriber))
|
||||
|
||||
forA archiveCopies (\\cid -> dedupExercise cid Archive)
|
||||
|
||||
let neededCopies = [Copy m o | (_, m) <- ownedOriginals, o <- subscribingParties]
|
||||
let createCopies = filter (\\c -> notElem c (map snd copiesToKeep)) neededCopies
|
||||
mapA dedupCreate createCopies
|
||||
pure ()
|
||||
|
||||
dedup : Eq k => [k] -> [k]
|
||||
dedup [] = []
|
||||
dedup (x :: xs) = x :: dedup (filter (/= x) xs)
|
||||
""".format(
|
||||
stateArg = "_" if versions.is_at_most(last_pre_7674_version, sdk_version) else "",
|
||||
acsArg = "acs" if versions.is_at_most(last_pre_7632_version, sdk_version) else "",
|
||||
query = "(pure . getContracts)" if versions.is_at_most(last_pre_7632_version, sdk_version) else "query",
|
||||
)
|
||||
|
||||
# Removal of state argument.
|
||||
last_pre_7674_version = "1.7.0-snapshot.20201013.5418.0.bda13392"
|
||||
|
||||
# Removal of ACS argument
|
||||
last_pre_7632_version = "1.7.0-snapshot.20201012.5405.0.af92198d"
|
||||
|
||||
def daml_trigger_dar(sdk_version):
|
||||
daml = "@daml-sdk-{sdk_version}//:daml".format(
|
||||
sdk_version = sdk_version,
|
||||
)
|
||||
native.genrule(
|
||||
name = "trigger-example-dar-{sdk_version}".format(
|
||||
sdk_version = version_to_name(sdk_version),
|
||||
),
|
||||
srcs = [
|
||||
"//bazel_tools/daml_trigger:example/src/TestScript.daml",
|
||||
"//bazel_tools/daml_trigger:example/src/Templates.daml",
|
||||
],
|
||||
outs = ["trigger-example-{sdk_version}.dar".format(
|
||||
sdk_version = version_to_name(sdk_version),
|
||||
)],
|
||||
tools = [daml],
|
||||
cmd = """\
|
||||
set -euo pipefail
|
||||
TMP_DIR=$$(mktemp -d)
|
||||
cleanup() {{ rm -rf $$TMP_DIR; }}
|
||||
trap cleanup EXIT
|
||||
mkdir -p $$TMP_DIR/src
|
||||
echo "{copy_trigger}" > $$TMP_DIR/src/CopyTrigger.daml
|
||||
cp -L $(location //bazel_tools/daml_trigger:example/src/TestScript.daml) $$TMP_DIR/src/
|
||||
cp -L $(location //bazel_tools/daml_trigger:example/src/Templates.daml) $$TMP_DIR/src/
|
||||
cat <<EOF >$$TMP_DIR/daml.yaml
|
||||
sdk-version: {sdk_version}
|
||||
name: trigger-example
|
||||
source: src
|
||||
version: 0.0.1
|
||||
dependencies:
|
||||
- daml-prim
|
||||
- daml-script
|
||||
- daml-stdlib
|
||||
- daml-trigger
|
||||
EOF
|
||||
$(location {daml}) build --project-root=$$TMP_DIR -o $$PWD/$(OUTS)
|
||||
""".format(
|
||||
daml = daml,
|
||||
sdk_version = sdk_version,
|
||||
copy_trigger = copy_trigger_src(sdk_version),
|
||||
),
|
||||
)
|
||||
|
||||
def daml_trigger_test(compiler_version, runner_version):
|
||||
compiled_dar = "//:trigger-example-dar-{version}".format(
|
||||
version = version_to_name(compiler_version),
|
||||
)
|
||||
daml_runner = "@daml-sdk-{version}//:daml".format(
|
||||
version = runner_version,
|
||||
)
|
||||
name = "daml-trigger-test-compiler-{compiler_version}-runner-{runner_version}".format(
|
||||
compiler_version = version_to_name(compiler_version),
|
||||
runner_version = version_to_name(runner_version),
|
||||
)
|
||||
|
||||
# 1.16.0 is the first SDK version that uses LF 1.14, which is the earliest version that canton supports
|
||||
use_canton = versions.is_at_least("2.0.0", runner_version) and versions.is_at_least("1.16.0", compiler_version)
|
||||
|
||||
server = daml_runner
|
||||
server_args = ["sandbox"] + (["--canton-port-file", "_port_file"] if (use_canton) else [])
|
||||
server_files = ["$(rootpath {})".format(compiled_dar)]
|
||||
server_files_prefix = "--dar=" if use_canton else ""
|
||||
|
||||
native.genrule(
|
||||
name = "{}-client-sh".format(name),
|
||||
outs = ["{}-client.sh".format(name)],
|
||||
cmd = """\
|
||||
cat >$(OUTS) <<'EOF'
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
canonicalize_rlocation() {{
|
||||
# Note (MK): This is a fun one: Let's say $$TEST_WORKSPACE is "compatibility"
|
||||
# and the argument points to a target from an external workspace, e.g.,
|
||||
# @daml-sdk-0.0.0//:daml. Then the short path will point to
|
||||
# ../daml-sdk-0.0.0/daml. Putting things together we end up with
|
||||
# compatibility/../daml-sdk-0.0.0/daml. On Linux and MacOS this works
|
||||
# just fine. However, on windows we need to normalize the path
|
||||
# or rlocation will fail to find the path in the manifest file.
|
||||
rlocation $$(realpath -L -s -m --relative-to=$$PWD $$TEST_WORKSPACE/$$1)
|
||||
}}
|
||||
runner=$$(canonicalize_rlocation $(rootpath {runner}))
|
||||
# Cleanup the trigger runner process but maintain the script runner exit code.
|
||||
trap 'status=$$?; kill -TERM $$PID; wait $$PID; exit $$status' INT TERM
|
||||
|
||||
SCRIPTOUTPUT=$$(mktemp -d)
|
||||
if [ {wait_for_port_file} -eq 1 ]; then
|
||||
timeout=60
|
||||
while [ ! -e _port_file ]; do
|
||||
if [ "$$timeout" = 0 ]; then
|
||||
echo "Timed out waiting for Canton startup" >&2
|
||||
exit 1
|
||||
fi
|
||||
sleep 1
|
||||
timeout=$$((timeout - 1))
|
||||
done
|
||||
fi
|
||||
if [ {upload_dar} -eq 1 ] ; then
|
||||
$$runner ledger upload-dar \\
|
||||
--host localhost \\
|
||||
--port 6865 \\
|
||||
$$(canonicalize_rlocation $(rootpath {dar}))
|
||||
fi
|
||||
$$runner script \\
|
||||
--ledger-host localhost \\
|
||||
--ledger-port 6865 \\
|
||||
--wall-clock-time \\
|
||||
--dar $$(canonicalize_rlocation $(rootpath {dar})) \\
|
||||
--script-name TestScript:allocateAlice \\
|
||||
--output-file $$SCRIPTOUTPUT/alice.json
|
||||
ALICE=$$(cat $$SCRIPTOUTPUT/alice.json | sed 's/"//g')
|
||||
rm -rf $$SCRIPTOUTPUT
|
||||
$$runner trigger \\
|
||||
--ledger-host localhost \\
|
||||
--ledger-port 6865 \\
|
||||
--ledger-party $$ALICE \\
|
||||
--wall-clock-time \\
|
||||
--dar $$(canonicalize_rlocation $(rootpath {dar})) \\
|
||||
--trigger-name CopyTrigger:copyTrigger &
|
||||
PID=$$!
|
||||
$$runner script \\
|
||||
--ledger-host localhost \\
|
||||
--ledger-port 6865 \\
|
||||
--wall-clock-time \\
|
||||
--dar $$(canonicalize_rlocation $(rootpath {dar})) \\
|
||||
--script-name TestScript:test
|
||||
EOF
|
||||
chmod +x $(OUTS)
|
||||
""".format(
|
||||
dar = compiled_dar,
|
||||
runner = daml_runner,
|
||||
upload_dar = "0",
|
||||
wait_for_port_file = "1" if use_canton else "0",
|
||||
),
|
||||
exec_tools = [
|
||||
compiled_dar,
|
||||
daml_runner,
|
||||
],
|
||||
)
|
||||
native.sh_binary(
|
||||
name = "{}-client".format(name),
|
||||
srcs = ["{}-client.sh".format(name)],
|
||||
data = [
|
||||
compiled_dar,
|
||||
daml_runner,
|
||||
],
|
||||
)
|
||||
|
||||
client_server_test(
|
||||
name = "daml-trigger-test-compiler-{compiler_version}-runner-{runner_version}".format(
|
||||
compiler_version = version_to_name(compiler_version),
|
||||
runner_version = version_to_name(runner_version),
|
||||
),
|
||||
client = "{}-client".format(name),
|
||||
client_args = [],
|
||||
client_files = [],
|
||||
data = [
|
||||
compiled_dar,
|
||||
],
|
||||
runner = "//bazel_tools/client_server:runner",
|
||||
runner_args = ["6865"],
|
||||
server = server,
|
||||
server_args = server_args,
|
||||
server_files = server_files,
|
||||
server_files_prefix = server_files_prefix,
|
||||
tags = extra_tags(compiler_version, runner_version) + ["exclusive"],
|
||||
)
|
@ -1,44 +0,0 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
module Templates where
|
||||
|
||||
-- ORIGINAL_TEMPLATE_BEGIN
|
||||
template Original
|
||||
with
|
||||
owner : Party
|
||||
name : Text
|
||||
textdata : Text
|
||||
where
|
||||
signatory owner
|
||||
|
||||
key (owner, name) : (Party, Text)
|
||||
maintainer key._1
|
||||
-- ORIGINAL_TEMPLATE_END
|
||||
|
||||
deriving instance Ord Original
|
||||
|
||||
-- SUBSCRIBER_TEMPLATE_BEGIN
|
||||
template Subscriber
|
||||
with
|
||||
subscriber : Party
|
||||
subscribedTo : Party
|
||||
where
|
||||
signatory subscriber
|
||||
observer subscribedTo
|
||||
key (subscriber, subscribedTo) : (Party, Party)
|
||||
maintainer key._1
|
||||
-- SUBSCRIBER_TEMPLATE_END
|
||||
|
||||
-- COPY_TEMPLATE_BEGIN
|
||||
template Copy
|
||||
with
|
||||
original : Original
|
||||
subscriber : Party
|
||||
where
|
||||
signatory (signatory original)
|
||||
observer subscriber
|
||||
-- COPY_TEMPLATE_END
|
||||
|
||||
deriving instance Ord Copy
|
||||
|
@ -1,66 +0,0 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
module TestScript where
|
||||
|
||||
import Templates
|
||||
import DA.Assert
|
||||
import DA.Time
|
||||
import Daml.Script
|
||||
|
||||
allocateAlice : Script Party
|
||||
allocateAlice = do
|
||||
debug "Creating Alice ..."
|
||||
alice <- allocatePartyWithHint "Alice" (PartyIdHint "Alice")
|
||||
debug alice
|
||||
debug "... done"
|
||||
pure alice
|
||||
|
||||
test : Script ()
|
||||
test = do
|
||||
debug "Searching for Alice ..."
|
||||
let isAlice x = displayName x == Some "Alice"
|
||||
Some aliceDetails <- find isAlice <$> listKnownParties
|
||||
let alice = party aliceDetails
|
||||
debug alice
|
||||
debug "... done"
|
||||
|
||||
debug "Creating Bob ..."
|
||||
bob <- allocatePartyWithHint "Bob" (PartyIdHint "Bob")
|
||||
debug alice
|
||||
debug bob
|
||||
debug "... done"
|
||||
|
||||
debug "Creating Subscriber ..."
|
||||
submit bob $ do
|
||||
createCmd (Subscriber bob alice)
|
||||
debug "... done"
|
||||
|
||||
debug "Creating Original ..."
|
||||
let original = Original alice "original" "data"
|
||||
submit alice $ do
|
||||
createCmd original
|
||||
debug "... done"
|
||||
|
||||
debug "Waiting for copy ..."
|
||||
copy <- until $ do
|
||||
copies <- query @Copy bob
|
||||
case copies of
|
||||
[(_, copy)] -> pure (Some copy)
|
||||
xs -> do
|
||||
debug xs
|
||||
pure None
|
||||
debug "... done"
|
||||
|
||||
debug "Asserting equality ..."
|
||||
assertEq (Copy original bob) copy
|
||||
debug "... done"
|
||||
|
||||
until : Script (Optional a) -> Script a
|
||||
until action = do
|
||||
result <- action
|
||||
case result of
|
||||
Some a -> pure a
|
||||
None -> do
|
||||
sleep (convertMicrosecondsToRelTime 10000)
|
||||
until action
|
File diff suppressed because it is too large
Load Diff
@ -492,11 +492,10 @@ fileTest externalAnchors scriptPackageData damlFile = do
|
||||
other -> error $ "Unsupported file extension " <> other
|
||||
where
|
||||
diff ref new = [POSIX_DIFF, "--strip-trailing-cr", ref, new]
|
||||
-- In cases where daml-script/daml-trigger is used, the version of the package is embedded in the json.
|
||||
-- In cases where daml-script is used, the version of the package is embedded in the json.
|
||||
-- When we release, this version changes, which would break the golden file test.
|
||||
-- Instead, we omit daml-script/daml-trigger versions from .EXPECTED.json files in golden tests.
|
||||
-- Instead, we omit daml-script versions from .EXPECTED.json files in golden tests.
|
||||
replaceSdkPackages =
|
||||
TL.encodeUtf8
|
||||
. TL.replace (TL.pack $ "daml-script-" <> sdkPackageVersion) "daml-script-UNVERSIONED"
|
||||
. TL.replace (TL.pack $ "daml-trigger-" <> sdkPackageVersion) "daml-trigger-UNVERSIONED"
|
||||
. TL.decodeUtf8
|
||||
|
@ -1,7 +1,7 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
-- | For compiler level warnings on packages that aren't included as standard. This includes, at the very least, Daml.Script and Triggers
|
||||
-- | For compiler level warnings on packages that aren't included as standard. This includes, at the very least Daml.Script
|
||||
module DA.Daml.LFConversion.ExternalWarnings (topLevelWarnings) where
|
||||
|
||||
import DA.Daml.LFConversion.ConvertM
|
||||
|
@ -289,9 +289,9 @@ dataDependableExtensions :: ES.EnumSet Extension
|
||||
dataDependableExtensions = ES.fromList $ xExtensionsSet ++
|
||||
[ -- useful for beginners to learn about type inference
|
||||
PartialTypeSignatures
|
||||
-- needed for script and triggers
|
||||
-- needed for script
|
||||
, ApplicativeDo
|
||||
-- used in daml-stdlib and triggers and a very reasonable
|
||||
-- used in daml-stdlib and very reasonable
|
||||
-- extension in general in the presence of TypeApplications
|
||||
, AllowAmbiguousTypes
|
||||
-- helpful for documentation purposes
|
||||
@ -306,7 +306,7 @@ dataDependableExtensions = ES.fromList $ xExtensionsSet ++
|
||||
-- would be silly
|
||||
, ExplicitNamespaces
|
||||
-- there's no way for our users to actually use this and listing it here
|
||||
-- removes a lot of warning from out stdlib, script and trigger builds
|
||||
-- removes a lot of warning from out stdlib and script builds
|
||||
-- NOTE: This should not appear on any list of extensions that are
|
||||
-- compatible with data-dependencies since this would spur wrong hopes.
|
||||
, Cpp
|
||||
@ -544,14 +544,14 @@ checkDFlags Options {..} dflags@DynFlags {..}
|
||||
\ or the dependency."
|
||||
|
||||
-- Expand SDK package dependencies using the SDK root path.
|
||||
-- E.g. `daml-trigger` --> `$DAML_SDK/daml-libs/daml-trigger.dar`
|
||||
-- E.g. `daml-script` --> `$DAML_SDK/daml-libs/daml-script.dar`
|
||||
-- When invoked outside of the SDK, we will only error out
|
||||
-- if there is actually an SDK package so that
|
||||
-- When there is no SDK
|
||||
expandSdkPackages :: Logger.Handle IO -> LF.Version -> [FilePath] -> IO [FilePath]
|
||||
expandSdkPackages logger lfVersion dars = do
|
||||
mbSdkPath <- handleIO (\_ -> pure Nothing) $ Just <$> getSdkPath
|
||||
mapM (expand mbSdkPath) (nubOrd $ concatMap addDep dars)
|
||||
mapM (expand mbSdkPath) (nubOrd dars)
|
||||
where
|
||||
isSdkPackage fp = takeExtension fp `notElem` [".dar", ".dalf"]
|
||||
isInvalidDaml3Script = \case
|
||||
@ -569,14 +569,6 @@ expandSdkPackages logger lfVersion dars = do
|
||||
pure $ sdkPath </> "daml-libs" </> fp <> sdkSuffix <.> "dar"
|
||||
Nothing -> fail $ "Cannot resolve SDK dependency '" ++ fp ++ "'. Use daml assistant."
|
||||
| otherwise = pure fp
|
||||
-- For `dependencies` you need to specify all transitive dependencies.
|
||||
-- However, for the packages in the SDK that is an implementation detail
|
||||
-- so we automagically insert `daml-script` if you’ve specified `daml-trigger`.
|
||||
addDep fp
|
||||
| isSdkPackage fp = fp : Map.findWithDefault [] fp sdkDependencies
|
||||
| otherwise = [fp]
|
||||
sdkDependencies = Map.fromList
|
||||
[ ("daml-trigger", ["daml-script"]) ]
|
||||
|
||||
|
||||
mkPackageFlag :: UnitId -> PackageFlag
|
||||
|
@ -28,8 +28,8 @@ error msg = primitive @"EThrow" (GeneralError msg)
|
||||
-- | `error` stops execution and displays the given error message.
|
||||
--
|
||||
-- If called within a transaction, it will abort the current transaction.
|
||||
-- Outside of a transaction (scenarios, Daml Script or Daml Triggers)
|
||||
-- it will stop the whole scenario/script/trigger.
|
||||
-- Outside of a transaction (scenarios and Daml Script)
|
||||
-- it will stop the whole scenario/script.
|
||||
error : Text -> a
|
||||
error = primitive @"BEError"
|
||||
|
||||
|
@ -23,7 +23,7 @@ import DA.Internal.Exception
|
||||
-- in this context.
|
||||
class Action m => CanAssert m where
|
||||
-- | Abort since an assertion has failed. In an Update, Scenario,
|
||||
-- Script, or Trigger context this will throw an AssertionFailed
|
||||
-- or Script context this will throw an AssertionFailed
|
||||
-- exception. In an `Either Text` context, this will return the
|
||||
-- message as an error.
|
||||
assertFail : Text -> m t
|
||||
|
@ -161,7 +161,7 @@ withVersionedDamlScriptDep packageFlagName darPath mLfVer extraPackages cont = d
|
||||
withCurrentDirectory dir $ do
|
||||
let projDir = toNormalizedFilePath' dir
|
||||
-- Bring in daml-script as previously installed by withDamlScriptDep, must include package db
|
||||
-- daml-script and daml-triggers use the sdkPackageVersion for their versioning
|
||||
-- daml-script use the sdkPackageVersion for their versioning
|
||||
mkPackageFlag flagName = ExposePackage ("--package " <> flagName) (UnitIdArg $ stringToUnitId flagName) (ModRenaming True [])
|
||||
toPackageName (name, version) = name <> "-" <> version
|
||||
packageFlags = mkPackageFlag <$> packageFlagName : (toPackageName <$> extraPackages)
|
||||
|
@ -9,10 +9,6 @@ load(
|
||||
load("@os_info//:os_info.bzl", "is_windows")
|
||||
load(":util.bzl", "deps")
|
||||
|
||||
scala_deps = [
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
]
|
||||
|
||||
scala_runtime_deps = [
|
||||
"@maven//:org_apache_pekko_pekko_slf4j",
|
||||
"@maven//:org_tpolecat_doobie_postgres",
|
||||
@ -30,7 +26,7 @@ da_scala_binary(
|
||||
srcs = glob(["src/main/scala/**/*.scala"]),
|
||||
main_class = "com.daml.sdk.SdkMain",
|
||||
resources = glob(["src/main/resources/**/*"]),
|
||||
scala_deps = scala_deps,
|
||||
scala_deps = [],
|
||||
scala_runtime_deps = scala_runtime_deps,
|
||||
visibility = ["//visibility:public"],
|
||||
runtime_deps = runtime_deps,
|
||||
@ -42,7 +38,7 @@ da_scala_binary(
|
||||
srcs = glob(["src/main/scala/**/*.scala"]),
|
||||
main_class = "com.daml.sdk.SdkMain",
|
||||
resources = glob(["src/main/resources/**/*"]),
|
||||
scala_deps = scala_deps,
|
||||
scala_deps = [],
|
||||
scala_runtime_deps = scala_runtime_deps,
|
||||
tags = ["ee-jar-license"],
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -5,9 +5,6 @@ package com.daml.sdk
|
||||
|
||||
import com.daml.codegen.{CodegenMain => Codegen}
|
||||
import com.daml.lf.engine.script.{ScriptMain => Script}
|
||||
import com.daml.lf.engine.trigger.{RunnerMain => Trigger}
|
||||
import com.daml.lf.engine.trigger.{ServiceMain => TriggerService}
|
||||
import com.daml.auth.middleware.oauth2.{Main => Oauth2Middleware}
|
||||
import com.daml.script.export.{Main => Export}
|
||||
|
||||
object SdkMain {
|
||||
@ -15,12 +12,9 @@ object SdkMain {
|
||||
val command = args(0)
|
||||
val rest = args.drop(1)
|
||||
command match {
|
||||
case "trigger" => Trigger.main(rest)
|
||||
case "script" => Script.main(rest)
|
||||
case "export" => Export.main(rest)
|
||||
case "codegen" => Codegen.main(rest)
|
||||
case "trigger-service" => TriggerService.main(rest)
|
||||
case "oauth2-middleware" => Oauth2Middleware.main(rest)
|
||||
case _ => sys.exit(1)
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,4 @@ def deps(edition):
|
||||
"//daml-script/runner:script-runner-lib",
|
||||
"//language-support/codegen-main:codegen-main-lib",
|
||||
"//daml-script/export",
|
||||
"//triggers/runner:trigger-runner-lib",
|
||||
"//triggers/service:trigger-service-binary-{}".format(edition),
|
||||
"//triggers/service/auth:oauth2-middleware",
|
||||
]
|
||||
|
@ -18,10 +18,3 @@ set -eou pipefail
|
||||
JAVA=$(rlocation "$TEST_WORKSPACE/$1")
|
||||
SDK_CE=$(rlocation "$TEST_WORKSPACE/$2")
|
||||
SDK_EE=$(rlocation "$TEST_WORKSPACE/$3")
|
||||
|
||||
if ! ($JAVA -jar $SDK_EE trigger-service --help | grep -q oracle); then
|
||||
exit 1
|
||||
fi
|
||||
if $JAVA -jar $SDK_CE trigger-service --help | grep -q oracle; then
|
||||
exit 1
|
||||
fi
|
||||
|
@ -370,7 +370,7 @@ argWhitelist = S.fromList
|
||||
, "install", "latest", "project"
|
||||
, "uninstall"
|
||||
, "studio", "never", "always", "published"
|
||||
, "new", "skeleton", "empty-skeleton", "quickstart-java", "copy-trigger", "gsg-trigger"
|
||||
, "new", "skeleton", "empty-skeleton", "quickstart-java"
|
||||
, "daml-intro-1", "daml-intro-2", "daml-intro-3", "daml-intro-4"
|
||||
, "daml-intro-5", "daml-intro-6", "daml-intro-7", "script-example"
|
||||
, "daml-intro-13"
|
||||
@ -386,8 +386,6 @@ argWhitelist = S.fromList
|
||||
, "codegen", "java", "js"
|
||||
, "deploy"
|
||||
, "json-api"
|
||||
, "trigger", "trigger-service", "list"
|
||||
, "oauth2-middleware"
|
||||
, "script"
|
||||
]
|
||||
|
||||
|
@ -31,7 +31,7 @@ import Test.Tasty.HUnit
|
||||
|
||||
import DA.Bazel.Runfiles
|
||||
import DA.Daml.Assistant.IntegrationTestUtils
|
||||
import DA.Daml.Helper.Util (waitForHttpServer, tokenFor, decodeCantonSandboxPort)
|
||||
import DA.Daml.Helper.Util (tokenFor, decodeCantonSandboxPort)
|
||||
import DA.Test.Daml2jsUtils
|
||||
import DA.Test.Process (callCommandSilent, callCommandSilentIn, subprocessEnv)
|
||||
import DA.Test.Util
|
||||
@ -182,7 +182,6 @@ tests tmpDir =
|
||||
, testCase "daml new --list" $
|
||||
callCommandSilentIn tmpDir "daml new --list"
|
||||
, packagingTests tmpDir
|
||||
, damlToolTests
|
||||
, withResource (damlStart (tmpDir </> "sandbox-canton")) stop damlStartTests
|
||||
, cleanTests cleanDir
|
||||
, templateTests
|
||||
@ -200,53 +199,7 @@ packagingTests :: SdkVersioned => FilePath -> TestTree
|
||||
packagingTests tmpDir =
|
||||
testGroup
|
||||
"packaging"
|
||||
[ testCase "Build copy trigger" $ do
|
||||
let projDir = tmpDir </> "copy-trigger1"
|
||||
callCommandSilent $ unwords ["daml", "new", projDir, "--template=copy-trigger"]
|
||||
callCommandSilentIn projDir "daml build"
|
||||
let dar = projDir </> ".daml" </> "dist" </> "copy-trigger1-0.0.1.dar"
|
||||
assertFileExists dar
|
||||
, testCase "Build copy trigger with LF version 1.dev" $ do
|
||||
let projDir = tmpDir </> "copy-trigger2"
|
||||
callCommandSilent $ unwords ["daml", "new", projDir, "--template=copy-trigger"]
|
||||
callCommandSilentIn projDir "daml build --target 1.dev"
|
||||
let dar = projDir </> ".daml" </> "dist" </> "copy-trigger2-0.0.1.dar"
|
||||
assertFileExists dar
|
||||
, testCase "Build trigger with extra dependency" $ do
|
||||
let myDepDir = tmpDir </> "mydep"
|
||||
createDirectoryIfMissing True (myDepDir </> "daml")
|
||||
writeFileUTF8 (myDepDir </> "daml.yaml") $
|
||||
unlines
|
||||
[ "sdk-version: " <> sdkVersion
|
||||
, "name: mydep"
|
||||
, "version: \"1.0\""
|
||||
, "source: daml"
|
||||
, "dependencies:"
|
||||
, " - daml-prim"
|
||||
, " - daml-stdlib"
|
||||
]
|
||||
writeFileUTF8 (myDepDir </> "daml" </> "MyDep.daml") $ unlines ["module MyDep where"]
|
||||
callCommandSilentIn myDepDir "daml build -o mydep.dar"
|
||||
let myTriggerDir = tmpDir </> "mytrigger"
|
||||
createDirectoryIfMissing True (myTriggerDir </> "daml")
|
||||
writeFileUTF8 (myTriggerDir </> "daml.yaml") $
|
||||
unlines
|
||||
[ "sdk-version: " <> sdkVersion
|
||||
, "name: mytrigger"
|
||||
, "version: \"1.0\""
|
||||
, "source: daml"
|
||||
, "dependencies:"
|
||||
, " - daml-prim"
|
||||
, " - daml-stdlib"
|
||||
, " - daml-trigger"
|
||||
, " - " <> myDepDir </> "mydep.dar"
|
||||
]
|
||||
writeFileUTF8 (myTriggerDir </> "daml/Main.daml") $
|
||||
unlines ["module Main where", "import MyDep ()", "import Daml.Trigger ()"]
|
||||
callCommandSilentIn myTriggerDir "daml build -o mytrigger.dar"
|
||||
let dar = myTriggerDir </> "mytrigger.dar"
|
||||
assertFileExists dar
|
||||
, testCase "Build Daml script example" $ do
|
||||
[ testCase "Build Daml script example" $ do
|
||||
let projDir = tmpDir </> "script-example"
|
||||
callCommandSilent $ unwords ["daml", "new", projDir, "--template=script-example"]
|
||||
callCommandSilentIn projDir "daml build"
|
||||
@ -259,7 +212,7 @@ packagingTests tmpDir =
|
||||
callCommandSilentIn projDir "daml build --target 1.dev"
|
||||
let dar = projDir </> ".daml/dist/script-example-0.0.1.dar"
|
||||
assertFileExists dar -}
|
||||
, testCase "Package depending on daml-script and daml-trigger can use data-dependencies" $ do
|
||||
, testCase "Package depending on daml-script can use data-dependencies" $ do
|
||||
callCommandSilent $ unwords ["daml", "new", tmpDir </> "data-dependency"]
|
||||
callCommandSilentIn (tmpDir </> "data-dependency") "daml build -o data-dependency.dar"
|
||||
createDirectoryIfMissing True (tmpDir </> "proj")
|
||||
@ -269,7 +222,7 @@ packagingTests tmpDir =
|
||||
, "name: proj"
|
||||
, "version: 0.0.1"
|
||||
, "source: ."
|
||||
, "dependencies: [daml-prim, daml-stdlib, daml-script, daml-trigger]"
|
||||
, "dependencies: [daml-prim, daml-stdlib, daml-script]"
|
||||
, "data-dependencies: [" <>
|
||||
show (tmpDir </> "data-dependency" </> "data-dependency.dar") <>
|
||||
"]"
|
||||
@ -285,39 +238,6 @@ packagingTests tmpDir =
|
||||
callCommandSilentIn (tmpDir </> "proj") "daml build"
|
||||
]
|
||||
|
||||
-- Test tools that can run outside a daml project
|
||||
damlToolTests :: TestTree
|
||||
damlToolTests =
|
||||
testGroup
|
||||
"daml tools"
|
||||
[ testCase "OAuth 2.0 middleware startup" $ do
|
||||
withTempDir $ \tmpDir -> do
|
||||
middlewarePort <- getFreePort
|
||||
withDamlServiceIn tmpDir "oauth2-middleware"
|
||||
[ "--address"
|
||||
, "localhost"
|
||||
, "--http-port"
|
||||
, show middlewarePort
|
||||
, "--oauth-auth"
|
||||
, "http://localhost:0/authorize"
|
||||
, "--oauth-token"
|
||||
, "http://localhost:0/token"
|
||||
, "--auth-jwt-hs256-unsafe"
|
||||
, "jwt-secret"
|
||||
, "--id"
|
||||
, "client-id"
|
||||
, "--secret"
|
||||
, "client-secret"
|
||||
] $ \ ph -> do
|
||||
let endpoint =
|
||||
"http://localhost:" <> show middlewarePort <> "/livez"
|
||||
waitForHttpServer 240 ph (threadDelay 500000) endpoint []
|
||||
req <- parseRequest endpoint
|
||||
manager <- newManager defaultManagerSettings
|
||||
resp <- httpLbs req manager
|
||||
responseBody resp @?= "{\"status\":\"pass\"}"
|
||||
]
|
||||
|
||||
-- We are trying to run as many tests with the same `daml start` process as possible to safe time.
|
||||
damlStartTests :: SdkVersioned => IO DamlStartResource -> TestTree
|
||||
damlStartTests getDamlStart =
|
||||
@ -394,24 +314,6 @@ damlStartTests getDamlStart =
|
||||
didGenerateDamlYaml <- doesFileExist (exportDir </> "daml.yaml")
|
||||
didGenerateExportDaml @?= True
|
||||
didGenerateDamlYaml @?= True
|
||||
subtest "trigger service startup" $ do
|
||||
DamlStartResource {projDir, sandboxPort} <- getDamlStart
|
||||
triggerServicePort <- getFreePort
|
||||
withDamlServiceIn projDir "trigger-service"
|
||||
[ "--ledger-host"
|
||||
, "localhost"
|
||||
, "--ledger-port"
|
||||
, show sandboxPort
|
||||
, "--http-port"
|
||||
, show triggerServicePort
|
||||
, "--wall-clock-time"
|
||||
] $ \ ph -> do
|
||||
let endpoint = "http://localhost:" <> show triggerServicePort <> "/livez"
|
||||
waitForHttpServer 240 ph (threadDelay 500000) endpoint []
|
||||
req <- parseRequest endpoint
|
||||
manager <- newManager defaultManagerSettings
|
||||
resp <- httpLbs req manager
|
||||
responseBody resp @?= "{\"status\":\"pass\"}"
|
||||
|
||||
subtest "hot reload" $ do
|
||||
DamlStartResource {projDir, jsonApiPort, startStdin, stdoutChan, alice, aliceHeaders} <- getDamlStart
|
||||
@ -529,10 +431,8 @@ templateTests = testGroup "templates" $
|
||||
-- NOTE (MK) We might want to autogenerate this list at some point but for now
|
||||
-- this should be good enough.
|
||||
where templateNames =
|
||||
[ "copy-trigger"
|
||||
, "gsg-trigger"
|
||||
-- daml-intro-1 - daml-intro-6 are not full projects.
|
||||
, "daml-intro-7"
|
||||
[ -- daml-intro-1 - daml-intro-6 are not full projects.
|
||||
"daml-intro-7"
|
||||
, "daml-patterns"
|
||||
, "quickstart-java"
|
||||
, "script-example"
|
||||
|
@ -343,7 +343,7 @@ private[lf] final class ValueTranslator(
|
||||
|
||||
// This does not try to pull missing packages, return an error instead.
|
||||
// TODO: https://github.com/digital-asset/daml/issues/17082
|
||||
// This is used by script and trigger, this should problaby use ValueTranslator.Config.Strict
|
||||
// This is used by script, this should problaby use ValueTranslator.Config.Strict
|
||||
def strictTranslateValue(
|
||||
ty: Type,
|
||||
value: Value,
|
||||
|
@ -35,7 +35,6 @@ da_scala_library(
|
||||
"//daml-lf:__subpackages__",
|
||||
"//daml-script:__subpackages__",
|
||||
"//ledger:__subpackages__",
|
||||
"//triggers:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//daml-lf/data",
|
||||
|
@ -10,7 +10,7 @@ private[lf] sealed abstract class InitialSeeding extends Product with Serializab
|
||||
|
||||
private[lf] object InitialSeeding {
|
||||
// NoSeed may be used to initialize machines that are not intended to create transactions
|
||||
// e.g. trigger and script runners, tests
|
||||
// e.g. script runners, tests
|
||||
final case object NoSeed extends InitialSeeding
|
||||
final case class TransactionSeed(seed: crypto.Hash) extends InitialSeeding
|
||||
final case class RootNodeSeeds(seeds: ImmArray[Option[crypto.Hash]]) extends InitialSeeding
|
||||
|
@ -21,7 +21,6 @@ da_scala_library(
|
||||
visibility = [
|
||||
"//daml-lf:__subpackages__",
|
||||
"//ledger:__subpackages__",
|
||||
"//triggers:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//daml-lf/data",
|
||||
|
@ -28,14 +28,12 @@ genrule(
|
||||
srcs = [
|
||||
"//compiler/damlc:daml-base-hoogle.txt",
|
||||
"//daml-script/daml:daml-script-hoogle.txt",
|
||||
"//triggers/daml:daml-trigger-hoogle.txt",
|
||||
],
|
||||
outs = ["hoogle_db.tar.gz"],
|
||||
cmd = """
|
||||
mkdir hoogle
|
||||
cp -L $(location //compiler/damlc:daml-base-hoogle.txt) hoogle/
|
||||
cp -L $(location //daml-script/daml:daml-script-hoogle.txt) hoogle/
|
||||
cp -L $(location //triggers/daml:daml-trigger-hoogle.txt) hoogle/
|
||||
$(execpath //bazel_tools/sh:mktgz) $@ hoogle
|
||||
""",
|
||||
tools = ["//bazel_tools/sh:mktgz"],
|
||||
@ -76,7 +74,6 @@ genrule(
|
||||
name = "sources",
|
||||
srcs = glob(["source/**"]) + [
|
||||
"//compiler/damlc:daml-base-rst.tar.gz",
|
||||
"//triggers/daml:daml-trigger-rst.tar.gz",
|
||||
"//daml-script/daml:daml-script-rst.tar.gz",
|
||||
"//canton:ledger-api-docs",
|
||||
"//:LICENSE",
|
||||
@ -102,12 +99,6 @@ genrule(
|
||||
--strip-components 1 \\
|
||||
-C source/daml/stdlib
|
||||
|
||||
# Copy in daml-trigger documentation
|
||||
mkdir -p source/triggers/api/
|
||||
tar xf $(location //triggers/daml:daml-trigger-rst.tar.gz) \\
|
||||
--strip-components 1 \\
|
||||
-C source/triggers/api/
|
||||
|
||||
# Copy in daml-script documentation
|
||||
mkdir -p source/daml-script/api/
|
||||
tar xf $(location //daml-script/daml:daml-script-rst.tar.gz) \\
|
||||
@ -140,7 +131,6 @@ genrule(
|
||||
":generate-docs-error-codes-inventory-into-rst-file",
|
||||
":generate-docs-error-categories-inventory-into-rst-file",
|
||||
"//compiler/damlc:daml-base-rst.tar.gz",
|
||||
"//triggers/daml:daml-trigger-rst.tar.gz",
|
||||
"//daml-script/daml:daml-script-rst.tar.gz",
|
||||
"//canton:ledger-api-docs",
|
||||
"//:LICENSE",
|
||||
@ -166,7 +156,6 @@ genrule(
|
||||
cp -L -- $(location //docs:generate-docs-error-codes-inventory-into-rst-file) $$DIR/deps/
|
||||
cp -L -- $(location //docs:generate-docs-error-categories-inventory-into-rst-file) $$DIR/deps/
|
||||
cp $(location //compiler/damlc:daml-base-rst.tar.gz) $$DIR/deps/
|
||||
cp $(location //triggers/daml:daml-trigger-rst.tar.gz) $$DIR/deps/
|
||||
cp $(location //daml-script/daml:daml-script-rst.tar.gz) $$DIR/deps/
|
||||
cp -L $(location //canton:ledger-api-docs) $$DIR/deps/
|
||||
cp -L $(location //:LICENSE) $$DIR/deps/
|
||||
@ -622,18 +611,6 @@ daml_build_test(
|
||||
project_dir = "source/upgrade/example/carbon-initiate-upgrade",
|
||||
)
|
||||
|
||||
daml_build_test(
|
||||
name = "daml-upgrade-example-upgrade-trigger",
|
||||
dar_dict = {
|
||||
":daml-upgrade-example-v1": "path/to/carbon-1.0.0.dar",
|
||||
":daml-upgrade-example-v2": "path/to/carbon-2.0.0.dar",
|
||||
":daml-upgrade-example-upgrade": "path/to/carbon-upgrade-1.0.0.dar",
|
||||
"//triggers/daml:daml-trigger.dar": "daml-trigger.dar",
|
||||
"//daml-script/daml:daml-script.dar": "daml-script.dar",
|
||||
},
|
||||
project_dir = "source/upgrade/example/carbon-upgrade-trigger",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "daml-intro-1",
|
||||
srcs = glob(
|
||||
|
@ -791,8 +791,6 @@ daml/intro/9_Functional101.html -> /daml/intro/10_Functional101.html
|
||||
daml/intro/10_StdLib.html -> /daml/intro/11_StdLib.html
|
||||
daml/intro/11_Testing.html -> /daml/intro/12_Testing.html
|
||||
daml-script/daml-script-docs.html -> /daml-script/api/index.html
|
||||
triggers/trigger-docs.html -> /triggers/api/index.html
|
||||
tools/trigger-service.html -> trigger-service/index.html
|
||||
support.html -> /support/support.html
|
||||
release-notes.html -> /support/releases.html#release-notes
|
||||
support/release-notes.html -> /support/releases.html#release-notes
|
||||
|
@ -1,18 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
sdk-version: 0.0.0
|
||||
# BEGIN
|
||||
name: carbon-upgrade-trigger
|
||||
version: 1.0.0
|
||||
dependencies:
|
||||
- daml-prim
|
||||
- daml-stdlib
|
||||
- daml-trigger
|
||||
- daml-script
|
||||
data-dependencies:
|
||||
- path/to/carbon-upgrade-1.0.0.dar
|
||||
- path/to/carbon-1.0.0.dar
|
||||
- path/to/carbon-2.0.0.dar
|
||||
# END
|
||||
source: .
|
@ -1,65 +0,0 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
{-# LANGUAGE CPP #-}
|
||||
|
||||
module UpgradeTrigger where
|
||||
|
||||
import DA.Assert
|
||||
import DA.Foldable
|
||||
import qualified DA.Map as Map
|
||||
import Daml.Trigger
|
||||
import Daml.Trigger.Assert
|
||||
import qualified Daml.Script as Script
|
||||
import Daml.Script (script)
|
||||
|
||||
import CarbonV1
|
||||
import UpgradeFromCarbonCertV1
|
||||
|
||||
-- TRIGGER_BOILERPLATE_BEGIN
|
||||
upgradeTrigger : Trigger ()
|
||||
upgradeTrigger = Trigger with
|
||||
initialize = pure ()
|
||||
updateState = \_msg -> pure ()
|
||||
registeredTemplates = AllInDar
|
||||
heartbeat = None
|
||||
rule = triggerRule
|
||||
-- TRIGGER_BOILERPLATE_END
|
||||
|
||||
-- TRIGGER_RULE_BEGIN
|
||||
triggerRule : Party -> TriggerA () ()
|
||||
triggerRule issuer = do
|
||||
agreements <-
|
||||
filter (\(_cid, agreement) -> agreement.issuer == issuer) <$>
|
||||
query @UpgradeCarbonCertAgreement
|
||||
allCerts <-
|
||||
filter (\(_cid, cert) -> cert.issuer == issuer) <$>
|
||||
query @CarbonCert
|
||||
forA_ agreements $ \(agreementCid, agreement) -> do
|
||||
let certsForOwner = filter (\(_cid, cert) -> cert.owner == agreement.owner) allCerts
|
||||
forA_ certsForOwner $ \(certCid, _) ->
|
||||
emitCommands
|
||||
[exerciseCmd agreementCid (Upgrade certCid)]
|
||||
[toAnyContractId certCid]
|
||||
-- TRIGGER_RULE_END
|
||||
|
||||
-- TODO (MK) The Bazel rule atm doesn’t run this script, we should fix that.
|
||||
test = script do
|
||||
alice <- Script.allocateParty "Alice"
|
||||
bob <- Script.allocateParty "Bob"
|
||||
|
||||
certProposal <- submit alice $ Script.createCmd (CarbonCertProposal alice bob 10)
|
||||
cert <- submit bob $ Script.exerciseCmd certProposal CarbonCertProposal_Accept
|
||||
|
||||
upgradeProposal <- submit alice $ Script.createCmd (UpgradeCarbonCertProposal alice bob)
|
||||
upgradeAgreement <- submit bob $ Script.exerciseCmd upgradeProposal Accept
|
||||
|
||||
let acs = toACS cert <> toACS upgradeAgreement
|
||||
|
||||
(_, commands) <- testRule upgradeTrigger alice [] acs Map.empty ()
|
||||
let flatCommands = flattenCommands commands
|
||||
assertExerciseCmd flatCommands $ \(cid, choiceArg) -> do
|
||||
cid === upgradeAgreement
|
||||
choiceArg === Upgrade cert
|
||||
-- TODO (MK) It would be nice to test for the absence of certain commands as well
|
||||
-- or ideally just assert that the list of emitted commands matches an expected
|
||||
-- list of commands.
|
@ -209,8 +209,6 @@
|
||||
type: jar-scala
|
||||
- target: //libs-scala/timer-utils:timer-utils
|
||||
type: jar-scala
|
||||
- target: //triggers/runner:trigger-runner-lib
|
||||
type: jar-scala
|
||||
- target: //libs-scala/nameof:nameof
|
||||
type: jar-scala
|
||||
- target: //libs-scala/struct-json/struct-spray-json:struct-spray-json
|
||||
|
@ -71,18 +71,6 @@ commands:
|
||||
- name: ide
|
||||
path: damlc/damlc
|
||||
args: ["ide"]
|
||||
- name: trigger-service
|
||||
path: daml-helper/daml-helper
|
||||
desc: "Launch the trigger service"
|
||||
args: ["run-jar", "--logback-config=daml-sdk/trigger-service-logback.xml", "daml-sdk/daml-sdk.jar", "trigger-service"]
|
||||
- name: oauth2-middleware
|
||||
path: daml-helper/daml-helper
|
||||
desc: "Launch the OAuth 2.0 middleware"
|
||||
args: ["run-jar", "--logback-config=daml-sdk/oauth2-middleware-logback.xml", "daml-sdk/daml-sdk.jar", "oauth2-middleware"]
|
||||
- name: trigger
|
||||
path: daml-helper/daml-helper
|
||||
args: ["run-jar", "--logback-config=daml-sdk/trigger-logback.xml", "daml-sdk/daml-sdk.jar", "trigger"]
|
||||
desc: "Run a Daml trigger"
|
||||
- name: script
|
||||
path: daml-helper/daml-helper
|
||||
args: ["run-jar", "--logback-config=daml-sdk/script-logback.xml", "daml-sdk/daml-sdk.jar", "script"]
|
||||
|
@ -8,9 +8,6 @@ inputs = {
|
||||
"sdk_config": ":sdk-config.yaml.tmpl",
|
||||
"install_sh": ":install.sh",
|
||||
"install_bat": ":install.bat",
|
||||
"oauth2_middleware_logback": "//triggers/service/auth:release/oauth2-middleware-logback.xml",
|
||||
"trigger_service_logback": "//triggers/service:release/trigger-service-logback.xml",
|
||||
"trigger_logback": "//triggers/runner:src/main/resources/logback.xml",
|
||||
"java_codegen_logback": "//language-support/java/codegen:src/main/resources/logback.xml",
|
||||
"daml_script_logback": "//daml-script/runner:src/main/resources/logback.xml",
|
||||
"export_logback": "//daml-script/export:src/main/resources/logback.xml",
|
||||
@ -22,7 +19,6 @@ inputs = {
|
||||
"daml_extension_stylesheet": "//compiler/daml-extension:webview-stylesheet.css",
|
||||
"daml2js_dist": "//language-support/ts/codegen:daml2js-dist",
|
||||
"templates": "//templates:templates-tarball.tar.gz",
|
||||
"trigger_dars": "//triggers/daml:daml-trigger-dars",
|
||||
"script_dars": "//daml-script/daml:daml-script-dars",
|
||||
"script3_dars": "//daml-script/daml3:daml3-script-dars",
|
||||
"canton": "//canton:community_app_deploy.jar",
|
||||
@ -74,7 +70,6 @@ def sdk_tarball(name, version, config):
|
||||
tar xf $(location {damlc_dist}) --strip-components=1 -C $$OUT/damlc
|
||||
|
||||
mkdir -p $$OUT/daml-libs
|
||||
cp -t $$OUT/daml-libs $(locations {trigger_dars})
|
||||
cp -t $$OUT/daml-libs $(locations {script_dars})
|
||||
cp -t $$OUT/daml-libs $(locations {script3_dars})
|
||||
|
||||
@ -96,10 +91,7 @@ def sdk_tarball(name, version, config):
|
||||
|
||||
mkdir -p $$OUT/daml-sdk
|
||||
cp $(location {sdk_deploy_jar}) $$OUT/daml-sdk/daml-sdk.jar
|
||||
cp -L $(location {trigger_service_logback}) $$OUT/daml-sdk/
|
||||
cp -L $(location {oauth2_middleware_logback}) $$OUT/daml-sdk/
|
||||
cp -L $(location {java_codegen_logback}) $$OUT/daml-sdk/codegen-logback.xml
|
||||
cp -L $(location {trigger_logback}) $$OUT/daml-sdk/trigger-logback.xml
|
||||
cp -L $(location {daml_script_logback}) $$OUT/daml-sdk/script-logback.xml
|
||||
cp -L $(location {export_logback}) $$OUT/daml-sdk/export-logback.xml
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Security tests, by category
|
||||
|
||||
## Authorization:
|
||||
- auth and auth-* should not be set together for the trigger service: [CliConfigTest.scala](triggers/service/src/test-suite/scala/com/daml/lf/engine/trigger/CliConfigTest.scala#L40)
|
||||
- badly-authorized create is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L76)
|
||||
- badly-authorized exercise is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L177)
|
||||
- badly-authorized exercise/create (create is unauthorized) is rejected: [AuthPropagationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthPropagationSpec.scala#L267)
|
||||
@ -11,8 +10,6 @@
|
||||
- badly-authorized lookup is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L134)
|
||||
- create with no signatories is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L66)
|
||||
- create with non-signatory maintainers is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L88)
|
||||
- error on specifying both authCommonUri and authInternalUri/authExternalUri for the trigger service: [AuthorizationConfigTest.scala](triggers/service/src/test-suite/scala/com/daml/lf/engine/trigger/AuthorizationConfigTest.scala#L24)
|
||||
- error on specifying only authInternalUri and no authExternalUri for the trigger service: [AuthorizationConfigTest.scala](triggers/service/src/test-suite/scala/com/daml/lf/engine/trigger/AuthorizationConfigTest.scala#L52)
|
||||
- exercise with no controllers is rejected: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L167)
|
||||
- well-authorized create is accepted: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L59)
|
||||
- well-authorized exercise is accepted: [AuthorizationSpec.scala](daml-lf/engine/src/test/scala/com/digitalasset/daml/lf/engine/AuthorizationSpec.scala#L160)
|
||||
|
@ -4,7 +4,6 @@ load("@os_info//:os_info.bzl", "is_windows")
|
||||
load("@build_environment//:configuration.bzl", "sdk_version")
|
||||
|
||||
exports_files(glob(["create-daml-app-test-resources/*"]) + [
|
||||
"copy-trigger/src/CopyTrigger.daml",
|
||||
"create-daml-app/ui/package.json.template",
|
||||
])
|
||||
|
||||
@ -52,8 +51,6 @@ genrule(
|
||||
"empty-skeleton/**",
|
||||
"create-daml-app/**",
|
||||
"quickstart-java/**",
|
||||
"copy-trigger/**",
|
||||
"gsg-trigger.patch",
|
||||
],
|
||||
exclude = ["**/NO_AUTO_COPYRIGHT"],
|
||||
) + [
|
||||
@ -76,17 +73,11 @@ genrule(
|
||||
|
||||
cp -rL templates/* $$SRC/
|
||||
|
||||
PATCH_TOOL=$$PWD/$(location @patch_dev_env//:patch)
|
||||
cp -rL $$SRC/create-daml-app $$SRC/gsg-trigger
|
||||
"$$PATCH_TOOL" -d $$SRC/gsg-trigger -p1 < $$SRC/gsg-trigger.patch
|
||||
|
||||
# templates in templates dir
|
||||
for d in skeleton \
|
||||
empty-skeleton \
|
||||
create-daml-app \
|
||||
quickstart-java \
|
||||
copy-trigger \
|
||||
gsg-trigger; do
|
||||
quickstart-java; do
|
||||
mkdir -p $$OUT/$$d
|
||||
cp -rL $$SRC/$$d/* $$OUT/$$d/
|
||||
for f in gitattributes gitignore dlint.yaml; do
|
||||
|
@ -1,12 +0,0 @@
|
||||
sdk-version: __VERSION__
|
||||
name: __PROJECT_NAME__
|
||||
source: src
|
||||
version: 0.0.1
|
||||
# trigger-dependencies-begin
|
||||
dependencies:
|
||||
- daml-prim
|
||||
- daml-stdlib
|
||||
- daml-trigger
|
||||
# trigger-dependencies-end
|
||||
sandbox-options:
|
||||
- --wall-clock-time
|
@ -1,109 +0,0 @@
|
||||
-- Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
module CopyTrigger where
|
||||
|
||||
import DA.List hiding (dedup)
|
||||
|
||||
import Daml.Trigger
|
||||
|
||||
-- ORIGINAL_TEMPLATE_BEGIN
|
||||
template Original
|
||||
with
|
||||
owner : Party
|
||||
name : Text
|
||||
textdata : Text
|
||||
where
|
||||
signatory owner
|
||||
|
||||
key (owner, name) : (Party, Text)
|
||||
maintainer key._1
|
||||
-- ORIGINAL_TEMPLATE_END
|
||||
|
||||
deriving instance Ord Original
|
||||
|
||||
-- SUBSCRIBER_TEMPLATE_BEGIN
|
||||
template Subscriber
|
||||
with
|
||||
subscriber : Party
|
||||
subscribedTo : Party
|
||||
where
|
||||
signatory subscriber
|
||||
observer subscribedTo
|
||||
key (subscriber, subscribedTo) : (Party, Party)
|
||||
maintainer key._1
|
||||
-- SUBSCRIBER_TEMPLATE_END
|
||||
|
||||
-- COPY_TEMPLATE_BEGIN
|
||||
template Copy
|
||||
with
|
||||
original : Original
|
||||
subscriber : Party
|
||||
where
|
||||
signatory (signatory original)
|
||||
observer subscriber
|
||||
-- COPY_TEMPLATE_END
|
||||
|
||||
deriving instance Ord Copy
|
||||
|
||||
-- TRIGGER_BEGIN
|
||||
copyTrigger : Trigger ()
|
||||
copyTrigger = Trigger
|
||||
{ initialize = pure ()
|
||||
, updateState = \_message -> pure ()
|
||||
, rule = copyRule
|
||||
, registeredTemplates = AllInDar
|
||||
, heartbeat = None
|
||||
}
|
||||
-- TRIGGER_END
|
||||
|
||||
-- RULE_SIGNATURE_BEGIN
|
||||
copyRule : Party -> TriggerA () ()
|
||||
copyRule party = do
|
||||
-- RULE_SIGNATURE_END
|
||||
-- ACS_QUERY_BEGIN
|
||||
subscribers : [(ContractId Subscriber, Subscriber)] <- query @Subscriber
|
||||
originals : [(ContractId Original, Original)] <- query @Original
|
||||
copies : [(ContractId Copy, Copy)] <- query @Copy
|
||||
-- ACS_QUERY_END
|
||||
|
||||
-- ACS_FILTER_BEGIN
|
||||
let ownedSubscribers = filter (\(_, s) -> s.subscribedTo == party) subscribers
|
||||
let ownedOriginals = filter (\(_, o) -> o.owner == party) originals
|
||||
let ownedCopies = filter (\(_, c) -> c.original.owner == party) copies
|
||||
-- ACS_FILTER_END
|
||||
|
||||
-- SUBSCRIBING_PARTIES_BEGIN
|
||||
let subscribingParties = map (\(_, s) -> s.subscriber) ownedSubscribers
|
||||
-- SUBSCRIBING_PARTIES_END
|
||||
|
||||
-- GROUP_COPIES_BEGIN
|
||||
let groupedCopies : [[(ContractId Copy, Copy)]]
|
||||
groupedCopies = groupOn snd $ sortOn snd $ ownedCopies
|
||||
let copiesToKeep = map head groupedCopies
|
||||
let archiveDuplicateCopies = concatMap tail groupedCopies
|
||||
-- GROUP_COPIES_END
|
||||
|
||||
-- ARCHIVE_COPIES_BEGIN
|
||||
let archiveMissingOriginal = filter (\(_, c) -> c.original `notElem` map snd ownedOriginals) copiesToKeep
|
||||
let archiveMissingSubscriber = filter (\(_, c) -> c.subscriber `notElem` subscribingParties) copiesToKeep
|
||||
let archiveCopies = dedup $ map fst $ archiveDuplicateCopies <> archiveMissingOriginal <> archiveMissingSubscriber
|
||||
-- ARCHIVE_COPIES_END
|
||||
|
||||
-- ARCHIVE_COMMAND_BEGIN
|
||||
forA archiveCopies $ \cid -> emitCommands [exerciseCmd cid Archive] [toAnyContractId cid]
|
||||
-- ARCHIVE_COMMAND_END
|
||||
|
||||
-- CREATE_COPIES_BEGIN
|
||||
let neededCopies = [Copy m o | (_, m) <- ownedOriginals, o <- subscribingParties]
|
||||
let createCopies = filter (\c -> c `notElem` map snd copiesToKeep) neededCopies
|
||||
mapA dedupCreate createCopies
|
||||
-- CREATE_COPIES_END
|
||||
pure ()
|
||||
|
||||
-- | The dedup function from DA.List requires an Ord constraint which we do not have for `ContractId k`. Therefore,
|
||||
-- we resort to the n^2 version for now. Once we have Maps we can use those to implement a more efficient dedup.
|
||||
dedup : Eq k => [k] -> [k]
|
||||
dedup [] = []
|
||||
dedup (x :: xs) = x :: dedup (filter (/= x) xs)
|
@ -39,17 +39,12 @@ write_scalatest_runpath(
|
||||
tests = [
|
||||
"//ledger-service/utils",
|
||||
"//libs-scala/jwt:tests-lib",
|
||||
"//triggers/service:test",
|
||||
"//triggers/service/auth:oauth2-middleware-tests",
|
||||
],
|
||||
runtime_deps = [
|
||||
"//ledger/error:error-test-lib",
|
||||
"//libs-scala/flyway-testing",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/scalatest-utils",
|
||||
"//triggers/service:trigger-service",
|
||||
"//triggers/service:trigger-service-tests",
|
||||
"//triggers/service/auth:oauth2-middleware-tests",
|
||||
],
|
||||
deps = [
|
||||
":ledger-generator-lib",
|
||||
|
@ -1,146 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# TODO Once daml_compile uses build instead of package we should use
|
||||
# daml_compile instead of a genrule.
|
||||
|
||||
load("@build_environment//:configuration.bzl", "ghc_version", "sdk_version")
|
||||
load("//daml-lf/language:daml-lf.bzl", "COMPILER_LF_VERSIONS")
|
||||
|
||||
# Build one DAR per LF version to bundle with the SDK.
|
||||
# Also build one DAR with the default LF version for test-cases.
|
||||
[
|
||||
genrule(
|
||||
name = "daml-trigger{}".format(suffix),
|
||||
srcs = glob(["**/*.daml"]) + ["//daml-script/daml:daml-script{}".format(suffix)],
|
||||
outs = ["daml-trigger{}.dar".format(suffix)],
|
||||
cmd = """
|
||||
set -eou pipefail
|
||||
TMP_DIR=$$(mktemp -d)
|
||||
mkdir -p $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $(location Daml/Trigger.daml) $$TMP_DIR/daml/Daml
|
||||
cp -L $(location Daml/Trigger/Assert.daml) $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $(location Daml/Trigger/Internal.daml) $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $(location Daml/Trigger/LowLevel.daml) $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $(location {daml_script}) $$TMP_DIR/daml-script.dar
|
||||
cat << EOF > $$TMP_DIR/daml.yaml
|
||||
sdk-version: {sdk}
|
||||
name: daml-trigger
|
||||
source: daml
|
||||
version: {ghc}
|
||||
dependencies:
|
||||
- daml-stdlib
|
||||
- daml-prim
|
||||
- daml-script.dar
|
||||
build-options: {build_options}
|
||||
EOF
|
||||
$(location //compiler/damlc) build --project-root $$TMP_DIR --ghc-option=-Werror --log-level=WARNING \
|
||||
-o $$PWD/$@
|
||||
rm -rf $$TMP_DIR
|
||||
""".format(
|
||||
build_options = str([
|
||||
"--target",
|
||||
lf_version,
|
||||
] if lf_version else []),
|
||||
daml_script = "//daml-script/daml:daml-script{}".format(suffix),
|
||||
ghc = ghc_version,
|
||||
sdk = sdk_version,
|
||||
),
|
||||
tools = [
|
||||
"//compiler/damlc",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
for lf_version in COMPILER_LF_VERSIONS + [""]
|
||||
for suffix in [("-" + lf_version) if lf_version else ""]
|
||||
]
|
||||
|
||||
filegroup(
|
||||
name = "daml-trigger-dars",
|
||||
srcs = [
|
||||
"daml-trigger-{}.dar".format(lf_version)
|
||||
for lf_version in COMPILER_LF_VERSIONS
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "daml-trigger-json-docs",
|
||||
srcs = glob(["**/*.daml"]) + [
|
||||
"//daml-script/daml:daml-script",
|
||||
],
|
||||
outs = ["daml-trigger.json"],
|
||||
cmd = """
|
||||
TMP_DIR=$$(mktemp -d)
|
||||
mkdir -p $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $(location Daml/Trigger.daml) $$TMP_DIR/daml/Daml
|
||||
cp -L $(location Daml/Trigger/Assert.daml) $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $(location Daml/Trigger/Internal.daml) $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $(location Daml/Trigger/LowLevel.daml) $$TMP_DIR/daml/Daml/Trigger
|
||||
cp -L $$PWD/$(location {daml_script}) $$TMP_DIR/daml-script.dar
|
||||
cat << EOF > $$TMP_DIR/daml.yaml
|
||||
sdk-version: {sdk}
|
||||
name: daml-trigger
|
||||
source: daml
|
||||
version: {ghc}
|
||||
dependencies:
|
||||
- daml-stdlib
|
||||
- daml-prim
|
||||
- daml-script.dar
|
||||
EOF
|
||||
DAMLC=$$PWD/$(location //compiler/damlc)
|
||||
JSON=$$PWD/$(location :daml-trigger.json)
|
||||
cd $$TMP_DIR
|
||||
$$DAMLC init
|
||||
$$DAMLC -- docs \
|
||||
--combine \
|
||||
--output=$$JSON \
|
||||
--format=Json \
|
||||
--package-name=daml-trigger \
|
||||
$$TMP_DIR/daml/Daml/Trigger.daml \
|
||||
$$TMP_DIR/daml/Daml/Trigger/Assert.daml \
|
||||
$$TMP_DIR/daml/Daml/Trigger/LowLevel.daml
|
||||
""".format(
|
||||
daml_script = "//daml-script/daml:daml-script",
|
||||
ghc = ghc_version,
|
||||
sdk = sdk_version,
|
||||
),
|
||||
tools = [
|
||||
"//compiler/damlc",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "daml-trigger-docs",
|
||||
srcs = [
|
||||
":daml-trigger.json",
|
||||
":daml-trigger-rst-template.rst",
|
||||
":daml-trigger-index-template.rst",
|
||||
":daml-trigger-hoogle-template.txt",
|
||||
],
|
||||
outs = [
|
||||
"daml-trigger-rst.tar.gz",
|
||||
"daml-trigger-hoogle.txt",
|
||||
"daml-trigger-anchors.json",
|
||||
],
|
||||
cmd = """
|
||||
$(location //compiler/damlc) -- docs \
|
||||
--output=daml-trigger-rst \
|
||||
--input-format=json \\
|
||||
--format=Rst \
|
||||
--template=$(location :daml-trigger-rst-template.rst) \
|
||||
--index-template=$(location :daml-trigger-index-template.rst) \\
|
||||
--hoogle-template=$(location :daml-trigger-hoogle-template.txt) \\
|
||||
--base-url=https://docs.daml.com/triggers/api/ \\
|
||||
--output-hoogle=$(location :daml-trigger-hoogle.txt) \\
|
||||
--output-anchor=$(location :daml-trigger-anchors.json) \\
|
||||
$(location :daml-trigger.json)
|
||||
$(execpath //bazel_tools/sh:mktgz) $(location :daml-trigger-rst.tar.gz) daml-trigger-rst
|
||||
""",
|
||||
tools = [
|
||||
"//bazel_tools/sh:mktgz",
|
||||
"//compiler/damlc",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
@ -1,358 +0,0 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
|
||||
module Daml.Trigger
|
||||
( query
|
||||
, queryFilter
|
||||
, queryContractId
|
||||
, queryContractKey
|
||||
, ActionTriggerAny
|
||||
, getCommandsInFlight
|
||||
, ActionTriggerUpdate
|
||||
, Trigger(..)
|
||||
, TriggerA
|
||||
, TriggerUpdateA
|
||||
, TriggerInitializeA
|
||||
, get
|
||||
, put
|
||||
, modify
|
||||
, emitCommands
|
||||
, emitCommandsV2
|
||||
, runTrigger
|
||||
, CommandId
|
||||
, Command(..)
|
||||
, AnyContractId
|
||||
, toAnyContractId
|
||||
, fromAnyContractId
|
||||
, exerciseCmd
|
||||
, createCmd
|
||||
, exerciseByKeyCmd
|
||||
, createAndExerciseCmd
|
||||
, dedupExercise
|
||||
, dedupCreate
|
||||
, dedupExerciseByKey
|
||||
, dedupCreateAndExercise
|
||||
, Message(..)
|
||||
, Completion(..)
|
||||
, Transaction(..)
|
||||
, Event(..)
|
||||
, Created
|
||||
, Archived
|
||||
, fromCreated
|
||||
, fromArchived
|
||||
, CompletionStatus(..)
|
||||
, RegisteredTemplates(..)
|
||||
, registeredTemplate
|
||||
, RelTime(..)
|
||||
, getReadAs
|
||||
, getActAs
|
||||
) where
|
||||
|
||||
import Prelude hiding (any)
|
||||
import DA.Action
|
||||
import DA.Action.State (execState)
|
||||
import DA.Foldable (any)
|
||||
import DA.Functor ((<&>))
|
||||
import DA.Map (Map, size)
|
||||
import qualified DA.Map as Map
|
||||
import DA.Optional
|
||||
|
||||
import Daml.Trigger.Internal
|
||||
import Daml.Trigger.LowLevel hiding (BatchTrigger, Trigger)
|
||||
import qualified Daml.Trigger.LowLevel as LowLevel
|
||||
|
||||
-- public API
|
||||
|
||||
-- | Extract the contracts of a given template from the ACS.
|
||||
getContracts : forall a. Template a => ACS -> [(ContractId a, a)]
|
||||
getContracts acs@(ACS tpls _) = mapOptional fromAny
|
||||
$ filter (\(cid, _) -> not $ cid `elem` allPending)
|
||||
$ optional [] Map.toList
|
||||
$ Map.lookup (templateTypeRep @a) tpls
|
||||
where
|
||||
fromAny (cid, tpl) = (,) <$> fromAnyContractId cid <*> fromAnyTemplate tpl
|
||||
allPending = getPendingContracts acs
|
||||
|
||||
getPendingContracts : ACS -> [AnyContractId]
|
||||
getPendingContracts (ACS _ pending) = concatMap snd $ Map.toList pending
|
||||
|
||||
getContractById : forall a. Template a => ContractId a -> ACS -> Optional a
|
||||
getContractById id (ACS tpls pending) = do
|
||||
let aid = toAnyContractId id
|
||||
implSpecific = Map.lookup aid <=< Map.lookup (templateTypeRep @a)
|
||||
aa <- implSpecific tpls
|
||||
a <- fromAnyTemplate aa
|
||||
if any (elem aid) pending then None else Some a
|
||||
|
||||
-- | Extract the contracts of a given template from the ACS.
|
||||
query : forall a m. (Template a, ActionTriggerAny m) => m [(ContractId a, a)]
|
||||
query = implQuery
|
||||
|
||||
-- | Extract the contracts of a given template from the ACS and filter
|
||||
-- to those that match the predicate.
|
||||
queryFilter : forall a m. (Functor m, Template a, ActionTriggerAny m) => (a -> Bool) -> m [(ContractId a, a)]
|
||||
queryFilter pred = filter (\(_, c) -> pred c) <$> implQuery
|
||||
|
||||
-- | Find the contract with the given `key` in the ACS, if present.
|
||||
queryContractKey : forall a k m. (Template a, HasKey a k, Eq k, ActionTriggerAny m, Functor m)
|
||||
=> k -> m (Optional (ContractId a, a))
|
||||
queryContractKey k = find (\(_, a) -> k == key a) <$> query
|
||||
|
||||
-- | Features possible in `initialize`, `updateState`, and `rule`.
|
||||
class ActionTriggerAny m where
|
||||
-- | Extract the contracts of a given template from the ACS. (However, the
|
||||
-- type parameters are in the 'm a' order, so it is not exported.)
|
||||
implQuery : forall a. Template a => m [(ContractId a, a)]
|
||||
|
||||
-- | Find the contract with the given `id` in the ACS, if present.
|
||||
queryContractId : Template a => ContractId a -> m (Optional a)
|
||||
|
||||
-- | Query the list of currently pending contracts as set by
|
||||
-- `emitCommands`.
|
||||
queryPendingContracts : m [AnyContractId]
|
||||
|
||||
getReadAs : m [Party]
|
||||
|
||||
getActAs : m Party
|
||||
|
||||
instance ActionTriggerAny (TriggerA s) where
|
||||
implQuery = TriggerA $ pure . getContracts
|
||||
queryContractId id = TriggerA $ pure . getContractById id
|
||||
queryPendingContracts = TriggerA $ \acs -> pure (getPendingContracts acs)
|
||||
getReadAs = TriggerA $ \_ -> do
|
||||
s <- get
|
||||
pure s.readAs
|
||||
|
||||
getActAs = TriggerA $ \_ -> do
|
||||
s <- get
|
||||
pure s.actAs
|
||||
|
||||
instance ActionTriggerAny (TriggerUpdateA s) where
|
||||
implQuery = TriggerUpdateA $ \s -> pure (getContracts s.acs)
|
||||
queryContractId id = TriggerUpdateA $ \s -> pure (getContractById id s.acs)
|
||||
queryPendingContracts = TriggerUpdateA $ \s -> pure (getPendingContracts s.acs)
|
||||
getReadAs = TriggerUpdateA $ \s -> pure s.readAs
|
||||
getActAs = TriggerUpdateA $ \s -> pure s.actAs
|
||||
|
||||
instance ActionTriggerAny TriggerInitializeA where
|
||||
implQuery = TriggerInitializeA (\s -> getContracts s.acs)
|
||||
queryContractId id = TriggerInitializeA (\s -> getContractById id s.acs)
|
||||
queryPendingContracts = TriggerInitializeA (\s -> getPendingContracts s.acs)
|
||||
getReadAs = TriggerInitializeA (\s -> s.readAs)
|
||||
getActAs = TriggerInitializeA (\s -> s.actAs)
|
||||
|
||||
-- | Features possible in `updateState` and `rule`.
|
||||
class ActionTriggerAny m => ActionTriggerUpdate m where
|
||||
-- | Retrieve command submissions made by this trigger that have not yet
|
||||
-- completed. If the trigger has restarted, it will not contain commands from
|
||||
-- before the restart; therefore, this should be treated as an optimization
|
||||
-- rather than an absolute authority on ledger state.
|
||||
getCommandsInFlight : m (Map CommandId [Command])
|
||||
|
||||
instance ActionTriggerUpdate (TriggerUpdateA s) where
|
||||
getCommandsInFlight = TriggerUpdateA $ \s -> pure s.commandsInFlight
|
||||
|
||||
instance ActionTriggerUpdate (TriggerA s) where
|
||||
getCommandsInFlight = liftTriggerRule $ get <&> \s -> s.commandsInFlight
|
||||
|
||||
-- | This is the type of your trigger. `s` is the user-defined state type which
|
||||
-- you can often leave at `()`.
|
||||
data Trigger s = Trigger
|
||||
{ initialize : TriggerInitializeA s
|
||||
-- ^ Initialize the user-defined state based on the ACS.
|
||||
, updateState : Message -> TriggerUpdateA s ()
|
||||
-- ^ Update the user-defined state based on a transaction or
|
||||
-- completion message. It can manipulate the state with `get`, `put`,
|
||||
-- and `modify`, or query the ACS with `query`.
|
||||
, rule : Party -> TriggerA s ()
|
||||
-- ^ The rule defines the main logic of your trigger. It can send commands
|
||||
-- to the ledger using `emitCommands` to change the ACS.
|
||||
-- The rule depends on the following arguments:
|
||||
--
|
||||
-- * The party your trigger is running as.
|
||||
-- * The user-defined state.
|
||||
--
|
||||
-- and can retrieve other data with functions in `TriggerA`:
|
||||
--
|
||||
-- * The current state of the ACS.
|
||||
-- * The current time (UTC in wallclock mode, Unix epoch in static mode)
|
||||
-- * The commands in flight.
|
||||
, registeredTemplates : RegisteredTemplates
|
||||
-- ^ The templates the trigger will receive events for.
|
||||
, heartbeat : Optional RelTime
|
||||
-- ^ Send a heartbeat message at the given interval.
|
||||
}
|
||||
|
||||
-- | Send a transaction consisting of the given commands to the ledger.
|
||||
-- The second argument can be used to mark a list of contract ids as pending.
|
||||
-- These contracts will automatically be filtered from getContracts until we
|
||||
-- either get the corresponding transaction event for this command or
|
||||
-- a failing completion.
|
||||
emitCommands : [Command] -> [AnyContractId] -> TriggerA s CommandId
|
||||
emitCommands cmds pending = do
|
||||
id <- liftTriggerRule $ submitCommands cmds
|
||||
let commands = Commands id cmds
|
||||
liftTriggerRule $ modify $ \s -> s
|
||||
{ commandsInFlight = addCommands s.commandsInFlight commands
|
||||
, pendingContracts = Map.insert id pending s.pendingContracts
|
||||
}
|
||||
pure id
|
||||
|
||||
-- Version of emitCommands that will not perform command submissions once there are too
|
||||
-- many commands in-flight. When this function does not to return a command ID (i.e. it
|
||||
-- returns None), then client code should manage this scenario. Failing to manage this
|
||||
-- warning scenario will eventually cause the trigger runner to stop with an
|
||||
-- InFlightCommandOverflowException exception.
|
||||
emitCommandsV2 : [Command] -> [AnyContractId] -> TriggerA s (Optional CommandId)
|
||||
emitCommandsV2 cmds pending = do
|
||||
mbId <- liftTriggerRule $ mbSubmitCommands cmds
|
||||
case mbId of
|
||||
None ->
|
||||
pure None
|
||||
Some id -> do
|
||||
let commands = Commands id cmds
|
||||
liftTriggerRule $ modify $ \s -> s
|
||||
{ commandsInFlight = addCommands s.commandsInFlight commands
|
||||
, pendingContracts = Map.insert id pending s.pendingContracts
|
||||
}
|
||||
pure (Some id)
|
||||
where
|
||||
mbSubmitCommands cmds = do
|
||||
state <- get
|
||||
if size state.commandsInFlight < state.config.maxInFlightCommands then do
|
||||
id <- submitCommands cmds
|
||||
pure (Some id)
|
||||
else do
|
||||
_ <- debug "WARN: too many commands currently in-flight, so command submission will be dropped"
|
||||
pure None
|
||||
|
||||
-- | Create the template if it’s not already in the list of commands
|
||||
-- in flight (it will still be created if it is in the ACS).
|
||||
--
|
||||
-- Note that this will send the create as a single-command transaction.
|
||||
-- If you need to send multiple commands in one transaction, use
|
||||
-- `emitCommands` with `createCmd` and handle filtering yourself.
|
||||
dedupCreate : (Eq t, Template t) => t -> TriggerA s ()
|
||||
dedupCreate t = do
|
||||
aState <- liftTriggerRule get
|
||||
-- This is a very naive approach that is linear in the number of commands in flight.
|
||||
-- We probably want to change this to express the commands in flight as some kind of
|
||||
-- map to make these lookups cheaper.
|
||||
let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
|
||||
unless (any ((Some t ==) . fromCreate) cmds) $
|
||||
void $ emitCommands [createCmd t] []
|
||||
|
||||
-- | Create the template and exercise a choice on it if it’s not already in the list of commands
|
||||
-- in flight (it will still be created if it is in the ACS).
|
||||
--
|
||||
-- Note that this will send the create and exercise as a
|
||||
-- single-command transaction. If you need to send multiple commands
|
||||
-- in one transaction, use `emitCommands` with `createAndExerciseCmd`
|
||||
-- and handle filtering yourself.
|
||||
dedupCreateAndExercise : (Eq t, Eq c, Template t, Choice t c r) => t -> c -> TriggerA s ()
|
||||
dedupCreateAndExercise t c = do
|
||||
aState <- liftTriggerRule get
|
||||
-- This is a very naive approach that is linear in the number of
|
||||
-- commands in flight. We probably want to change this to express
|
||||
-- the commands in flight as some kind of map to make these lookups
|
||||
-- cheaper.
|
||||
let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
|
||||
unless (any ((Some (t, c) ==) . fromCreateAndExercise) cmds) $
|
||||
void $ emitCommands [createAndExerciseCmd t c] []
|
||||
|
||||
-- | Exercise the choice on the given contract if it is not already
|
||||
-- in flight.
|
||||
--
|
||||
-- Note that this will send the exercise as a single-command transaction.
|
||||
-- If you need to send multiple commands in one transaction, use
|
||||
-- `emitCommands` with `exerciseCmd` and handle filtering yourself.
|
||||
--
|
||||
-- If you are calling a consuming choice, you might be better off by using
|
||||
-- `emitCommands` and adding the contract id to the pending set.
|
||||
dedupExercise : (Eq c, Choice t c r) => ContractId t -> c -> TriggerA s ()
|
||||
dedupExercise cid c = do
|
||||
aState <- liftTriggerRule get
|
||||
-- This is a very naive approach that is linear in the number of commands in flight.
|
||||
-- We probably want to change this to express the commands in flight as some kind of
|
||||
-- map to make these lookups cheaper.
|
||||
let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
|
||||
unless (any ((Some (cid, c) ==) . fromExercise) cmds) $
|
||||
void $ emitCommands [exerciseCmd cid c] []
|
||||
|
||||
-- | Exercise the choice on the given contract if it is not already
|
||||
-- in flight.
|
||||
--
|
||||
-- Note that this will send the exercise as a single-command transaction.
|
||||
-- If you need to send multiple commands in one transaction, use
|
||||
-- `emitCommands` with `exerciseCmd` and handle filtering yourself.
|
||||
dedupExerciseByKey : forall t c r k s. (Eq c, Eq k, Choice t c r, TemplateKey t k) => k -> c -> TriggerA s ()
|
||||
dedupExerciseByKey k c = do
|
||||
aState <- liftTriggerRule get
|
||||
-- This is a very naive approach that is linear in the number of commands in flight.
|
||||
-- We probably want to change this to express the commands in flight as some kind of
|
||||
-- map to make these lookups cheaper.
|
||||
let cmds = concat $ map snd (Map.toList aState.commandsInFlight)
|
||||
unless (any ((Some (k, c) ==) . fromExerciseByKey @t) cmds) $
|
||||
void $ emitCommands [exerciseByKeyCmd @t k c] []
|
||||
|
||||
-- | Transform the high-level trigger type into the batching trigger from `Daml.Trigger.LowLevel`.
|
||||
runTrigger : forall s. Trigger s -> LowLevel.BatchTrigger (TriggerState s)
|
||||
runTrigger userTrigger = LowLevel.BatchTrigger
|
||||
{ initialState
|
||||
, update = update
|
||||
, registeredTemplates = userTrigger.registeredTemplates
|
||||
, heartbeat = userTrigger.heartbeat
|
||||
}
|
||||
where
|
||||
initialState args =
|
||||
let acs = foldl (\acs created -> applyEvent (CreatedEvent created) acs) (ACS mempty Map.empty) args.acs.activeContracts
|
||||
userState = runTriggerInitializeA userTrigger.initialize (TriggerInitState acs args.actAs args.readAs)
|
||||
state = TriggerState acs args.actAs args.readAs userState Map.empty args.config
|
||||
in TriggerSetup $ execStateT (runTriggerRule $ runRule userTrigger.rule) state
|
||||
|
||||
mkUserState state acs msg =
|
||||
let state' = TriggerUpdateState state.commandsInFlight acs state.actAs state.readAs
|
||||
in execState (flip runTriggerUpdateA state' $ userTrigger.updateState msg) state.userState
|
||||
|
||||
update msgs = do
|
||||
runUserTrigger <- or <$> mapA processMessage msgs
|
||||
when runUserTrigger $
|
||||
runRule userTrigger.rule
|
||||
|
||||
-- Returns 'True' if the processed message means we need to run the user's trigger.
|
||||
processMessage : Message -> TriggerRule (TriggerState s) Bool
|
||||
processMessage msg = do
|
||||
state <- get
|
||||
case msg of
|
||||
MCompletion completion ->
|
||||
-- NB: the commands-in-flight and ACS updateState sees are those
|
||||
-- prior to updates incurred by the msg
|
||||
let userState = mkUserState state state.acs msg
|
||||
in case completion.status of
|
||||
Succeeded {} -> do
|
||||
-- We delete successful completions when we receive the corresponding transaction
|
||||
-- to avoid removing a command from commandsInFlight before we have modified the ACS.
|
||||
put $ state { userState }
|
||||
pure False
|
||||
Failed {} -> do
|
||||
let commandsInFlight = Map.delete completion.commandId state.commandsInFlight
|
||||
acs = state.acs { pendingContracts = Map.delete completion.commandId state.acs.pendingContracts }
|
||||
put $ state { commandsInFlight, userState, acs }
|
||||
pure True
|
||||
MTransaction transaction -> do
|
||||
let acs = applyTransaction transaction state.acs
|
||||
-- again, we use the commands-in-flight and ACS before the update below
|
||||
userState = mkUserState state acs msg
|
||||
-- See the comment above for why we delete this here instead of when we receive the completion.
|
||||
(acs', commandsInFlight) = case transaction.commandId of
|
||||
None -> (acs, state.commandsInFlight)
|
||||
Some commandId -> (acs { pendingContracts = Map.delete commandId acs.pendingContracts }, Map.delete commandId state.commandsInFlight)
|
||||
put $ state { acs = acs', userState, commandsInFlight }
|
||||
pure True
|
||||
MHeartbeat -> do
|
||||
let userState = mkUserState state state.acs msg
|
||||
put $ state { userState }
|
||||
pure True
|
||||
|
@ -1,147 +0,0 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
|
||||
module Daml.Trigger.Assert
|
||||
( ACSBuilder
|
||||
, toACS
|
||||
, testRule
|
||||
, flattenCommands
|
||||
, assertCreateCmd
|
||||
, assertExerciseCmd
|
||||
, assertExerciseByKeyCmd
|
||||
) where
|
||||
|
||||
import qualified DA.List as List
|
||||
import DA.Map (Map)
|
||||
import qualified DA.Map as Map
|
||||
import qualified DA.Text as Text
|
||||
|
||||
import Daml.Trigger hiding (queryContractId)
|
||||
import Daml.Trigger.Internal
|
||||
import Daml.Trigger.LowLevel hiding (Trigger)
|
||||
|
||||
import Daml.Script (Script, queryContractId)
|
||||
|
||||
-- | Used to construct an 'ACS' for 'testRule'.
|
||||
newtype ACSBuilder = ACSBuilder (Party -> [Script (AnyContractId, AnyTemplate)])
|
||||
|
||||
instance Semigroup ACSBuilder where
|
||||
ACSBuilder l <> ACSBuilder r =
|
||||
ACSBuilder (\p -> l p <> r p)
|
||||
|
||||
instance Monoid ACSBuilder where
|
||||
mempty = ACSBuilder (const mempty)
|
||||
|
||||
buildACS : Party -> ACSBuilder -> Script ACS
|
||||
buildACS party (ACSBuilder fetches) = do
|
||||
activeContracts <- sequence (fetches party)
|
||||
pure ACS
|
||||
{ activeContracts = groupActiveContracts activeContracts
|
||||
, pendingContracts = Map.empty
|
||||
}
|
||||
|
||||
-- | Include the given contract in the 'ACS'. Note that the `ContractId`
|
||||
-- must point to an active contract.
|
||||
toACS : (Template t, HasAgreement t) => ContractId t -> ACSBuilder
|
||||
toACS cid = ACSBuilder $ \p ->
|
||||
[ do t <-
|
||||
queryContractId p cid >>= \case
|
||||
None -> abort ("Failed to fetch contract passed to toACS: " <> show cid)
|
||||
Some c -> pure c
|
||||
pure (toAnyContractId cid, toAnyTemplate t)
|
||||
]
|
||||
|
||||
-- | Execute a trigger's rule once in a scenario.
|
||||
testRule
|
||||
: Trigger s -- ^ Test this trigger's 'Trigger.rule'.
|
||||
-> Party -- ^ Execute the rule as this 'Party'.
|
||||
-> [Party] -- ^ Execute the rule with these parties as `readAs`
|
||||
-> ACSBuilder -- ^ List these contracts in the 'ACS'.
|
||||
-> Map CommandId [Command] -- ^ The commands in flight.
|
||||
-> s -- ^ The trigger state.
|
||||
-> Script (s, [Commands]) -- ^ The 'Commands' and new state emitted by the rule. The 'CommandId's will start from @"0"@.
|
||||
testRule trigger party readAs acsBuilder commandsInFlight s = do
|
||||
time <- getTime
|
||||
acs <- buildACS party acsBuilder
|
||||
let config = TriggerConfig
|
||||
{ maxInFlightCommands = 1, maxActiveContracts = 1 }
|
||||
let state = TriggerState
|
||||
{ acs = acs
|
||||
, actAs = party
|
||||
, readAs = readAs
|
||||
, userState = s
|
||||
, commandsInFlight = commandsInFlight
|
||||
, config = config
|
||||
}
|
||||
-- A sufficiently powerful `TriggerF` command will entail redoing this as a
|
||||
-- natural transformation from `Free TriggerF` to `Free ScriptF`, which in
|
||||
-- turn will likely require exposing `ScriptF` or both in some "internal"
|
||||
-- fashion. Meanwhile, it is worth pushing `simulateRule` as far as possible
|
||||
-- to forestall that outcome. -SC
|
||||
let (state', commands, _) = simulateRule (runRule trigger.rule) time state
|
||||
pure (state'.userState, commands)
|
||||
|
||||
-- | Drop 'CommandId's and extract all 'Command's.
|
||||
flattenCommands : [Commands] -> [Command]
|
||||
flattenCommands = concatMap commands
|
||||
|
||||
expectCommand
|
||||
: [Command]
|
||||
-> (Command -> Optional a)
|
||||
-> (a -> Either Text ())
|
||||
-> Either [Text] ()
|
||||
expectCommand commands fromCommand assertion = foldl step (Left []) commands
|
||||
where
|
||||
step : Either [Text] () -> Command -> Either [Text] ()
|
||||
step (Right ()) _ = Right ()
|
||||
step (Left msgs) command =
|
||||
case assertion <$> fromCommand command of
|
||||
None -> Left msgs
|
||||
Some (Left msg) -> Left (msg :: msgs)
|
||||
Some (Right ()) -> Right ()
|
||||
|
||||
-- | Check that at least one command is a create command whose payload fulfills the given assertions.
|
||||
assertCreateCmd
|
||||
: (Template t, HasAgreement t, CanAbort m)
|
||||
=> [Command] -- ^ Check these commands.
|
||||
-> (t -> Either Text ()) -- ^ Perform these assertions.
|
||||
-> m ()
|
||||
assertCreateCmd commands assertion =
|
||||
case expectCommand commands fromCreate assertion of
|
||||
Right () -> pure ()
|
||||
Left msgs ->
|
||||
abort $ "Failure, found no matching create command." <> collectMessages msgs
|
||||
|
||||
-- | Check that at least one command is an exercise command whose contract id and choice argument fulfill the given assertions.
|
||||
assertExerciseCmd
|
||||
: (Template t, HasAgreement t, Choice t c r, CanAbort m)
|
||||
=> [Command] -- ^ Check these commands.
|
||||
-> ((ContractId t, c) -> Either Text ()) -- ^ Perform these assertions.
|
||||
-> m ()
|
||||
assertExerciseCmd commands assertion =
|
||||
case expectCommand commands fromExercise assertion of
|
||||
Right () -> pure ()
|
||||
Left msgs ->
|
||||
abort $ "Failure, found no matching exercise command." <> collectMessages msgs
|
||||
|
||||
-- | Check that at least one command is an exercise by key command whose key and choice argument fulfill the given assertions.
|
||||
assertExerciseByKeyCmd
|
||||
: forall t c r k m
|
||||
. (TemplateKey t k, Choice t c r, CanAbort m)
|
||||
=> [Command] -- ^ Check these commands.
|
||||
-> ((k, c) -> Either Text ()) -- ^ Perform these assertions.
|
||||
-> m ()
|
||||
assertExerciseByKeyCmd commands assertion =
|
||||
case expectCommand commands (fromExerciseByKey @t) assertion of
|
||||
Right () -> pure ()
|
||||
Left msgs ->
|
||||
abort $ "Failure, found no matching exerciseByKey command." <> collectMessages msgs
|
||||
|
||||
collectMessages : [Text] -> Text
|
||||
collectMessages [] = ""
|
||||
collectMessages msgs = "\n" <> Text.unlines (map (bullet . nest) msgs)
|
||||
where
|
||||
bullet txt = " * " <> txt
|
||||
nest = Text.intercalate "\n" . List.intersperse " " . Text.lines
|
@ -1,243 +0,0 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
|
||||
-- | MOVE Daml.Trigger
|
||||
module Daml.Trigger.Internal
|
||||
( ACS (..)
|
||||
, TriggerA (..)
|
||||
, TriggerUpdateA (..)
|
||||
, TriggerInitializeA (..)
|
||||
, addCommands
|
||||
, insertTpl
|
||||
, groupActiveContracts
|
||||
, deleteTpl
|
||||
, lookupTpl
|
||||
, applyEvent
|
||||
, applyTransaction
|
||||
, runRule
|
||||
, liftTriggerRule
|
||||
, TriggerAState (..)
|
||||
, TriggerState (..)
|
||||
, TriggerInitState(..)
|
||||
, TriggerUpdateState(..)
|
||||
) where
|
||||
|
||||
import DA.Action.State
|
||||
import DA.Foldable (forA_)
|
||||
import DA.Functor ((<&>))
|
||||
import DA.Map (Map)
|
||||
import qualified DA.Map as Map
|
||||
import DA.Optional (fromOptional)
|
||||
|
||||
import Daml.Trigger.LowLevel hiding (Trigger)
|
||||
import qualified Daml.Trigger.LowLevel as LowLevel
|
||||
|
||||
-- We use this singleton enumeration to track the version of the API between Daml code and Scala code of the trigger
|
||||
data Version = Version_2_6
|
||||
|
||||
-- public API
|
||||
|
||||
-- | HIDE Active contract set, you can use `getContracts` to access the templates of
|
||||
-- a given type.
|
||||
|
||||
-- This will change to a Map once we have proper maps in Daml-LF
|
||||
-- The following variant type should be kept in sync with methods numberOfActiveContracts and numberOfPendingContracts
|
||||
-- in Runner.scala
|
||||
data ACS = ACS
|
||||
{ activeContracts : Map.Map TemplateTypeRep (Map.Map AnyContractId AnyTemplate)
|
||||
, pendingContracts : Map CommandId [AnyContractId]
|
||||
}
|
||||
|
||||
-- | TriggerA is the type used in the `rule` of a Daml trigger.
|
||||
-- Its main feature is that you can call `emitCommands` to
|
||||
-- send commands to the ledger.
|
||||
newtype TriggerA s a =
|
||||
-- | HIDE
|
||||
TriggerA { runTriggerA : ACS -> TriggerRule (TriggerAState s) a }
|
||||
|
||||
instance Functor (TriggerA s) where
|
||||
fmap f (TriggerA r) = TriggerA $ rliftFmap fmap f r
|
||||
|
||||
instance Applicative (TriggerA s) where
|
||||
pure = TriggerA . rliftPure pure
|
||||
TriggerA ff <*> TriggerA fa = TriggerA $ rliftAp (<*>) ff fa
|
||||
|
||||
instance Action (TriggerA s) where
|
||||
TriggerA fa >>= f = TriggerA $ rliftBind (>>=) fa (runTriggerA . f)
|
||||
|
||||
instance ActionState s (TriggerA s) where
|
||||
get = TriggerA $ const (get <&> \tas -> tas.userState)
|
||||
modify f = TriggerA . const . modify $ \tas -> tas { userState = f tas.userState }
|
||||
|
||||
instance HasTime (TriggerA s) where
|
||||
getTime = TriggerA $ const getTime
|
||||
|
||||
-- | HIDE
|
||||
data TriggerUpdateState = TriggerUpdateState
|
||||
with
|
||||
commandsInFlight : Map CommandId [Command]
|
||||
acs : ACS
|
||||
actAs : Party
|
||||
readAs : [Party]
|
||||
|
||||
-- | TriggerUpdateA is the type used in the `updateState` of a Daml
|
||||
-- trigger. It has similar actions in common with `TriggerA`, but
|
||||
-- cannot use `emitCommands` or `getTime`.
|
||||
newtype TriggerUpdateA s a =
|
||||
-- | HIDE
|
||||
TriggerUpdateA { runTriggerUpdateA : TriggerUpdateState -> State s a }
|
||||
|
||||
instance Functor (TriggerUpdateA s) where
|
||||
fmap f (TriggerUpdateA r) = TriggerUpdateA $ rliftFmap fmap f r
|
||||
|
||||
instance Applicative (TriggerUpdateA s) where
|
||||
pure = TriggerUpdateA . rliftPure pure
|
||||
TriggerUpdateA ff <*> TriggerUpdateA fa = TriggerUpdateA $ rliftAp (<*>) ff fa
|
||||
|
||||
instance Action (TriggerUpdateA s) where
|
||||
TriggerUpdateA fa >>= f = TriggerUpdateA $ rliftBind (>>=) fa (runTriggerUpdateA . f)
|
||||
|
||||
instance ActionState s (TriggerUpdateA s) where
|
||||
get = TriggerUpdateA $ const get
|
||||
put = TriggerUpdateA . const . put
|
||||
modify = TriggerUpdateA . const . modify
|
||||
|
||||
-- | HIDE
|
||||
data TriggerInitState = TriggerInitState
|
||||
with
|
||||
acs : ACS
|
||||
actAs : Party
|
||||
readAs : [Party]
|
||||
|
||||
-- | TriggerInitializeA is the type used in the `initialize` of a Daml
|
||||
-- trigger. It can query, but not emit commands or update the state.
|
||||
newtype TriggerInitializeA a =
|
||||
-- | HIDE
|
||||
TriggerInitializeA { runTriggerInitializeA : TriggerInitState -> a }
|
||||
deriving (Functor, Applicative, Action)
|
||||
|
||||
-- Internal API
|
||||
|
||||
-- | HIDE
|
||||
addCommands : Map CommandId [Command] -> Commands -> Map CommandId [Command]
|
||||
addCommands m (Commands cid cmds) = Map.insert cid cmds m
|
||||
|
||||
-- | HIDE
|
||||
insertTpl : AnyContractId -> AnyTemplate -> ACS -> ACS
|
||||
insertTpl cid tpl acs = acs { activeContracts = Map.alter addct cid.templateId acs.activeContracts }
|
||||
where addct = Some . Map.insert cid tpl . fromOptional mempty
|
||||
|
||||
-- | HIDE
|
||||
groupActiveContracts :
|
||||
[(AnyContractId, AnyTemplate)] -> Map.Map TemplateTypeRep (Map.Map AnyContractId AnyTemplate)
|
||||
groupActiveContracts = foldr (\v@(cid, _) -> Map.alter (addct v) cid.templateId) Map.empty
|
||||
where addct (cid, tpl) = Some . Map.insert cid tpl . fromOptional mempty
|
||||
|
||||
-- | HIDE
|
||||
deleteTpl : AnyContractId -> ACS -> ACS
|
||||
deleteTpl cid acs = acs { activeContracts = Map.alter rmct cid.templateId acs.activeContracts }
|
||||
where rmct om = do
|
||||
m <- om
|
||||
let m' = Map.delete cid m
|
||||
if Map.null m' then None else Some m'
|
||||
|
||||
-- | HIDE
|
||||
lookupTpl : Template a => AnyContractId -> ACS -> Optional a
|
||||
lookupTpl cid acs = do
|
||||
tpl <- Map.lookup cid =<< Map.lookup cid.templateId acs.activeContracts
|
||||
fromAnyTemplate tpl
|
||||
|
||||
-- | HIDE
|
||||
applyEvent : Event -> ACS -> ACS
|
||||
applyEvent ev acs = case ev of
|
||||
CreatedEvent (Created _ cid (Some tpl) _) -> insertTpl cid tpl acs
|
||||
CreatedEvent _ -> acs
|
||||
ArchivedEvent (Archived _ cid) -> deleteTpl cid acs
|
||||
|
||||
-- | HIDE
|
||||
applyTransaction : Transaction -> ACS -> ACS
|
||||
applyTransaction (Transaction _ _ evs) acs = foldl (flip applyEvent) acs evs
|
||||
|
||||
-- | HIDE
|
||||
runRule
|
||||
: (Party -> TriggerA s a)
|
||||
-> TriggerRule (TriggerState s) a
|
||||
runRule rule = do
|
||||
state <- get
|
||||
TriggerRule . zoom zoomIn zoomOut . runTriggerRule . flip runTriggerA state.acs
|
||||
$ rule state.actAs
|
||||
where zoomIn state = TriggerAState state.commandsInFlight state.acs.pendingContracts state.userState state.readAs state.actAs state.config
|
||||
zoomOut state aState =
|
||||
let commandsInFlight = aState.commandsInFlight
|
||||
acs = state.acs { pendingContracts = aState.pendingContracts }
|
||||
userState = aState.userState
|
||||
readAs = aState.readAs
|
||||
actAs = aState.actAs
|
||||
config = aState.config
|
||||
in state { commandsInFlight, acs, userState, readAs, actAs, config }
|
||||
|
||||
-- | HIDE
|
||||
-- | Transform the (legacy) low-level trigger type into a batching trigger.
|
||||
runLegacyTrigger : LowLevel.Trigger s -> BatchTrigger s
|
||||
runLegacyTrigger userTrigger = BatchTrigger
|
||||
{ initialState = \args -> userTrigger.initialState args.actAs args.readAs args.acs
|
||||
, update = \msgs -> forA_ msgs userTrigger.update
|
||||
, registeredTemplates = userTrigger.registeredTemplates
|
||||
, heartbeat = userTrigger.heartbeat
|
||||
}
|
||||
|
||||
-- | HIDE
|
||||
liftTriggerRule : TriggerRule (TriggerAState s) a -> TriggerA s a
|
||||
liftTriggerRule = TriggerA . const
|
||||
|
||||
-- | HIDE
|
||||
data TriggerAState s = TriggerAState
|
||||
{ commandsInFlight : Map CommandId [Command]
|
||||
-- ^ Zoomed from TriggerState; used for dedupCreateCmd/dedupExerciseCmd
|
||||
-- helpers and extended by emitCommands.
|
||||
, pendingContracts : Map CommandId [AnyContractId]
|
||||
-- ^ Map from command ids to the contract ids marked pending by that command;
|
||||
-- zoomed from TriggerState's acs.
|
||||
, userState : s
|
||||
-- ^ zoomed from TriggerState
|
||||
, readAs : [Party]
|
||||
-- ^ zoomed from TriggerState
|
||||
, actAs : Party
|
||||
-- ^ zoomed from TriggerState
|
||||
, config : TriggerConfig
|
||||
-- ^ zoomed from TriggerState
|
||||
}
|
||||
|
||||
-- | HIDE
|
||||
-- The following record type should be kept in sync with method numberOfInFlightCommands in Runner.scala
|
||||
data TriggerState s = TriggerState
|
||||
{ acs : ACS
|
||||
, actAs : Party
|
||||
, readAs : [Party]
|
||||
, userState : s
|
||||
, commandsInFlight : Map CommandId [Command]
|
||||
, config : TriggerConfig
|
||||
}
|
||||
|
||||
-- | HIDE
|
||||
--
|
||||
-- unboxed newtype for common Trigger*A additions
|
||||
type TriggerAT r f a = r -> f a
|
||||
|
||||
-- | HIDE
|
||||
rliftFmap : ((a -> b) -> f a -> f b) -> (a -> b) -> TriggerAT r f a -> TriggerAT r f b
|
||||
rliftFmap ub f r = ub f . r
|
||||
|
||||
-- | HIDE
|
||||
rliftPure : (a -> f a) -> a -> TriggerAT r f a
|
||||
rliftPure ub = const . ub
|
||||
|
||||
-- | HIDE
|
||||
rliftAp : (f (a -> b) -> f a -> f b) -> TriggerAT r f (a -> b) -> TriggerAT r f a -> TriggerAT r f b
|
||||
rliftAp ub ff fa r = ff r `ub` fa r
|
||||
|
||||
-- | HIDE
|
||||
rliftBind : (f a -> (a -> f b) -> f b) -> TriggerAT r f a -> (a -> TriggerAT r f b) -> TriggerAT r f b
|
||||
rliftBind ub fa f r = fa r `ub` \a -> f a r
|
@ -1,364 +0,0 @@
|
||||
-- Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
|
||||
module Daml.Trigger.LowLevel
|
||||
( Message(..)
|
||||
, Completion(..)
|
||||
, CompletionStatus(..)
|
||||
, TriggerConfig(..)
|
||||
, TriggerSetupArguments(..)
|
||||
, Transaction(..)
|
||||
, AnyContractId
|
||||
, toAnyContractId
|
||||
, fromAnyContractId
|
||||
, TransactionId(..)
|
||||
, EventId(..)
|
||||
, CommandId(..)
|
||||
, Event(..)
|
||||
, Created(..)
|
||||
, fromCreated
|
||||
, Archived(..)
|
||||
, fromArchived
|
||||
, Trigger(..)
|
||||
, BatchTrigger(..)
|
||||
, ActiveContracts(..)
|
||||
, Commands(..)
|
||||
, Command(..)
|
||||
, createCmd
|
||||
, exerciseCmd
|
||||
, exerciseByKeyCmd
|
||||
, createAndExerciseCmd
|
||||
, fromCreate
|
||||
, fromExercise
|
||||
, fromExerciseByKey
|
||||
, fromCreateAndExercise
|
||||
, RegisteredTemplates(..)
|
||||
, registeredTemplate
|
||||
, RelTime(..)
|
||||
, execStateT
|
||||
, ActionTrigger(..)
|
||||
, TriggerSetup(..)
|
||||
, TriggerRule(..)
|
||||
, ActionState(..)
|
||||
, zoom
|
||||
, submitCommands
|
||||
, simulateRule
|
||||
) where
|
||||
|
||||
import DA.Action.State
|
||||
import DA.Action.State.Class
|
||||
import DA.Functor ((<&>))
|
||||
import DA.Internal.Interface.AnyView.Types
|
||||
import DA.Time (RelTime(..))
|
||||
import Daml.Script.Free (Free(..), lift, foldFree)
|
||||
|
||||
-- | This type represents the contract id of an unknown template.
|
||||
-- You can use `fromAnyContractId` to check which template it corresponds to.
|
||||
data AnyContractId = AnyContractId
|
||||
{ templateId : TemplateTypeRep
|
||||
, contractId : ContractId ()
|
||||
} deriving Eq
|
||||
|
||||
deriving instance Ord AnyContractId
|
||||
|
||||
-- We can’t derive the Show instance since TemplateTypeRep does not have a Show instance
|
||||
-- but it is useful for debugging so we add one that omits the type.
|
||||
instance Show AnyContractId where
|
||||
showsPrec d (AnyContractId _ cid) = showParen (d > app_prec) $
|
||||
showString "AnyContractId " . showsPrec (app_prec +1) cid
|
||||
where app_prec = 10
|
||||
|
||||
|
||||
-- | Wrap a `ContractId t` in `AnyContractId`.
|
||||
toAnyContractId : forall t. Template t => ContractId t -> AnyContractId
|
||||
toAnyContractId cid = AnyContractId
|
||||
{ templateId = templateTypeRep @t
|
||||
, contractId = coerceContractId cid
|
||||
}
|
||||
|
||||
-- | Check if a `AnyContractId` corresponds to the given template or return
|
||||
-- `None` otherwise.
|
||||
fromAnyContractId : forall t. Template t => AnyContractId -> Optional (ContractId t)
|
||||
fromAnyContractId cid
|
||||
| cid.templateId == templateTypeRep @t = Some (coerceContractId cid.contractId)
|
||||
| otherwise = None
|
||||
|
||||
newtype TransactionId = TransactionId Text
|
||||
deriving (Show, Eq)
|
||||
|
||||
newtype EventId = EventId Text
|
||||
deriving (Show, Eq)
|
||||
|
||||
newtype CommandId = CommandId Text
|
||||
deriving (Show, Eq, Ord)
|
||||
|
||||
data Transaction = Transaction
|
||||
{ transactionId : TransactionId
|
||||
, commandId : Optional CommandId
|
||||
, events : [Event]
|
||||
}
|
||||
|
||||
data InterfaceView = InterfaceView {
|
||||
interfaceTypeRep : TemplateTypeRep,
|
||||
anyView: Optional AnyView
|
||||
}
|
||||
|
||||
-- | An event in a transaction.
|
||||
-- This definition should be kept consistent with the object `EventVariant` defined in
|
||||
-- triggers/runner/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Converter.scala
|
||||
data Event
|
||||
= CreatedEvent Created
|
||||
| ArchivedEvent Archived
|
||||
|
||||
-- | The data in a `Created` event.
|
||||
data Created = Created
|
||||
{ eventId : EventId
|
||||
, contractId : AnyContractId
|
||||
, argument : Optional AnyTemplate
|
||||
, views : [InterfaceView]
|
||||
}
|
||||
|
||||
-- | Check if a `Created` event corresponds to the given template.
|
||||
fromCreated : Template t => Created -> Optional (EventId, ContractId t, t)
|
||||
fromCreated Created {eventId, contractId, argument}
|
||||
| Some contractId' <- fromAnyContractId contractId
|
||||
, Some argument' <- argument
|
||||
, Some argument'' <- fromAnyTemplate argument'
|
||||
= Some (eventId, contractId', argument'')
|
||||
| otherwise
|
||||
= None
|
||||
|
||||
-- | The data in an `Archived` event.
|
||||
data Archived = Archived
|
||||
{ eventId : EventId
|
||||
, contractId : AnyContractId
|
||||
} deriving (Show, Eq)
|
||||
|
||||
-- | Check if an `Archived` event corresponds to the given template.
|
||||
fromArchived : Template t => Archived -> Optional (EventId, ContractId t)
|
||||
fromArchived Archived {eventId, contractId}
|
||||
| Some contractId' <- fromAnyContractId contractId
|
||||
= Some (eventId, contractId')
|
||||
| otherwise
|
||||
= None
|
||||
|
||||
-- | Either a transaction or a completion.
|
||||
-- This definition should be kept consistent with the object `MessageVariant` defined in
|
||||
-- triggers/runner/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Converter.scala
|
||||
data Message
|
||||
= MTransaction Transaction
|
||||
| MCompletion Completion
|
||||
| MHeartbeat
|
||||
|
||||
-- | A completion message.
|
||||
-- Note that you will only get completions for commands emitted from the trigger.
|
||||
-- Contrary to the ledger API completion stream, this also includes
|
||||
-- synchronous failures.
|
||||
|
||||
data Completion = Completion
|
||||
{ commandId : CommandId
|
||||
, status : CompletionStatus
|
||||
} deriving Show
|
||||
|
||||
|
||||
-- This definition should be kept consistent with the object `CompletionStatusVariant` defined in
|
||||
-- triggers/runner/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Converter.scala
|
||||
data CompletionStatus
|
||||
= Failed { status : Int, message : Text }
|
||||
| Succeeded { transactionId : TransactionId }
|
||||
deriving Show
|
||||
|
||||
-- Introduced in version 2.6.0
|
||||
data TriggerConfig = TriggerConfig
|
||||
{ maxInFlightCommands : Int
|
||||
-- ^ maximum number of commands that should be allowed to be in-flight at any point in time.
|
||||
-- Exceeding this value will eventually lead to the trigger run raising an InFlightCommandOverflowException exception.
|
||||
, maxActiveContracts : Int
|
||||
-- ^ maximum number of active contracts that we will allow to be stored
|
||||
-- Exceeding this value will lead to the trigger runner raising an ACSOverflowException exception.
|
||||
}
|
||||
|
||||
-- Introduced in version 2.5.1: this definition is used to simplify future extensions of trigger initialState arguments
|
||||
data TriggerSetupArguments = TriggerSetupArguments
|
||||
{ actAs : Party
|
||||
, readAs : [Party]
|
||||
, acs : ActiveContracts
|
||||
, config : TriggerConfig -- added in version 2.6.0
|
||||
}
|
||||
|
||||
data ActiveContracts = ActiveContracts { activeContracts : [Created] }
|
||||
|
||||
-- @WARN use 'BatchTrigger s' instead of 'Trigger s'
|
||||
data Trigger s = Trigger
|
||||
{ initialState : Party -> [Party] -> ActiveContracts -> TriggerSetup s
|
||||
, update : Message -> TriggerRule s ()
|
||||
, registeredTemplates : RegisteredTemplates
|
||||
, heartbeat : Optional RelTime
|
||||
}
|
||||
|
||||
-- | Batching trigger is (approximately) a left-fold over `Message` with
|
||||
-- an accumulator of type `s`.
|
||||
data BatchTrigger s = BatchTrigger
|
||||
{ initialState : TriggerSetupArguments -> TriggerSetup s
|
||||
, update : [Message] -> TriggerRule s ()
|
||||
, registeredTemplates : RegisteredTemplates
|
||||
, heartbeat : Optional RelTime
|
||||
}
|
||||
|
||||
-- | A template that the trigger will receive events for.
|
||||
newtype RegisteredTemplate = RegisteredTemplate TemplateTypeRep
|
||||
|
||||
-- This controls which templates the trigger will receive events for.
|
||||
-- `AllInDar` is a safe default but for performance reasons you might
|
||||
-- want to limit it to limit the templates that the trigger will receive
|
||||
-- events for.
|
||||
data RegisteredTemplates
|
||||
= AllInDar -- ^ Listen to events for all templates in the given DAR.
|
||||
| RegisteredTemplates [RegisteredTemplate]
|
||||
|
||||
registeredTemplate : forall t. Template t => RegisteredTemplate
|
||||
registeredTemplate = RegisteredTemplate (templateTypeRep @t)
|
||||
|
||||
-- | A ledger API command. To construct a command use `createCmd` and `exerciseCmd`.
|
||||
data Command
|
||||
= CreateCommand
|
||||
{ templateArg : AnyTemplate
|
||||
}
|
||||
| ExerciseCommand
|
||||
{ contractId : AnyContractId
|
||||
, choiceArg : AnyChoice
|
||||
}
|
||||
| CreateAndExerciseCommand
|
||||
{ templateArg : AnyTemplate
|
||||
, choiceArg : AnyChoice
|
||||
}
|
||||
| ExerciseByKeyCommand
|
||||
{ tplTypeRep : TemplateTypeRep
|
||||
, contractKey : AnyContractKey
|
||||
, choiceArg : AnyChoice
|
||||
}
|
||||
|
||||
-- | Create a contract of the given template.
|
||||
createCmd : Template t => t -> Command
|
||||
createCmd templateArg =
|
||||
CreateCommand (toAnyTemplate templateArg)
|
||||
|
||||
-- | Exercise the given choice.
|
||||
exerciseCmd : forall t c r. Choice t c r => ContractId t -> c -> Command
|
||||
exerciseCmd contractId choiceArg =
|
||||
ExerciseCommand (toAnyContractId contractId) (toAnyChoice @t choiceArg)
|
||||
|
||||
-- | Create a contract of the given template and immediately exercise
|
||||
-- the given choice on it.
|
||||
createAndExerciseCmd : forall t c r. (Template t, Choice t c r) => t -> c -> Command
|
||||
createAndExerciseCmd templateArg choiceArg =
|
||||
CreateAndExerciseCommand (toAnyTemplate templateArg) (toAnyChoice @t choiceArg)
|
||||
|
||||
exerciseByKeyCmd : forall t c r k. (Choice t c r, TemplateKey t k) => k -> c -> Command
|
||||
exerciseByKeyCmd contractKey choiceArg =
|
||||
ExerciseByKeyCommand (templateTypeRep @t) (toAnyContractKey @t contractKey) (toAnyChoice @t choiceArg)
|
||||
|
||||
-- | Check if the command corresponds to a create command
|
||||
-- for the given template.
|
||||
fromCreate : Template t => Command -> Optional t
|
||||
fromCreate (CreateCommand t) = fromAnyTemplate t
|
||||
fromCreate _ = None
|
||||
|
||||
-- | Check if the command corresponds to a create and exercise command
|
||||
-- for the given template.
|
||||
fromCreateAndExercise : forall t c r. (Template t, Choice t c r) => Command -> Optional (t, c)
|
||||
fromCreateAndExercise (CreateAndExerciseCommand t c) = (,) <$> fromAnyTemplate t <*> fromAnyChoice @t c
|
||||
fromCreateAndExercise _ = None
|
||||
|
||||
-- | Check if the command corresponds to an exercise command
|
||||
-- for the given template.
|
||||
fromExercise : forall t c r. Choice t c r => Command -> Optional (ContractId t, c)
|
||||
fromExercise (ExerciseCommand cid c) = (,) <$> fromAnyContractId cid <*> fromAnyChoice @t c
|
||||
fromExercise _ = None
|
||||
|
||||
-- | Check if the command corresponds to an exercise by key command
|
||||
-- for the given template.
|
||||
fromExerciseByKey : forall t c r k. (Choice t c r, TemplateKey t k) => Command -> Optional (k, c)
|
||||
fromExerciseByKey (ExerciseByKeyCommand tyRep k c)
|
||||
| tyRep == templateTypeRep @t = (,) <$> fromAnyContractKey @t k <*> fromAnyChoice @t c
|
||||
fromExerciseByKey _ = None
|
||||
|
||||
-- | A set of commands that are submitted as a single transaction.
|
||||
data Commands = Commands
|
||||
{ commandId : CommandId
|
||||
, commands : [Command]
|
||||
}
|
||||
|
||||
newtype StateT s m a = StateT { runStateT : s -> m (a, s) }
|
||||
deriving Functor
|
||||
|
||||
liftStateT : Functor m => m a -> StateT s m a
|
||||
liftStateT ma = StateT $ \s -> (,s) <$> ma
|
||||
|
||||
instance Action m => Applicative (StateT s m) where
|
||||
pure a = StateT (\s -> pure (a, s))
|
||||
f <*> x = f >>= (<$> x)
|
||||
|
||||
instance Action m => Action (StateT s m) where
|
||||
StateT x >>= f = StateT $ \s -> do
|
||||
(x', s') <- x s
|
||||
runStateT (f x') s'
|
||||
|
||||
instance Applicative m => ActionState s (StateT s m) where
|
||||
get = StateT $ \s -> pure (s, s)
|
||||
put s = StateT $ const $ pure ((), s)
|
||||
modify f = StateT $ \s -> pure ((), f s)
|
||||
|
||||
execStateT : Functor m => StateT s m a -> s -> m s
|
||||
execStateT (StateT fa) = fmap snd . fa
|
||||
|
||||
zoom : Functor m => (t -> s) -> (t -> s -> t) -> StateT s m a -> StateT t m a
|
||||
zoom r w (StateT smas) = StateT $ \t ->
|
||||
smas (r t) <&> \(a, s) -> (a, w t s)
|
||||
|
||||
-- Must be kept in sync with Runner#freeTriggerSubmits
|
||||
data TriggerF a =
|
||||
GetTime (Time -> a)
|
||||
| Submit ([Command], Text -> a)
|
||||
deriving Functor
|
||||
|
||||
newtype TriggerSetup a = TriggerSetup { runTriggerSetup : Free TriggerF a }
|
||||
deriving (Functor, Applicative, Action)
|
||||
|
||||
newtype TriggerRule s a = TriggerRule { runTriggerRule : StateT s (Free TriggerF) a }
|
||||
deriving (Functor, Applicative, Action)
|
||||
|
||||
-- | Run a rule without running it. May lose information from the rule;
|
||||
-- meant for testing purposes only.
|
||||
simulateRule : TriggerRule s a -> Time -> s -> (s, [Commands], a)
|
||||
simulateRule rule time s = (s', reverse cmds, a)
|
||||
where ((a, s'), (cmds, _)) = runState (foldFree sim (runStateT (runTriggerRule rule) s)) ([], 0)
|
||||
sim : TriggerF x -> State ([Commands], Int) x
|
||||
sim (GetTime f) = pure (f time)
|
||||
sim (Submit (cmds, f)) = do
|
||||
(pastCmds, nextId) <- get
|
||||
let nextIdShown = show nextId
|
||||
put (Commands (CommandId nextIdShown) cmds :: pastCmds, nextId + 1)
|
||||
pure $ f nextIdShown
|
||||
|
||||
deriving instance ActionState s (TriggerRule s)
|
||||
|
||||
-- | Low-level trigger actions.
|
||||
class HasTime m => ActionTrigger m where
|
||||
liftTF : TriggerF a -> m a
|
||||
|
||||
instance ActionTrigger TriggerSetup where
|
||||
liftTF = TriggerSetup . lift
|
||||
|
||||
instance ActionTrigger (TriggerRule s) where
|
||||
liftTF = TriggerRule . liftStateT . lift
|
||||
|
||||
instance HasTime TriggerSetup where
|
||||
getTime = liftTF (GetTime identity)
|
||||
|
||||
instance HasTime (TriggerRule s) where
|
||||
getTime = liftTF (GetTime identity)
|
||||
|
||||
submitCommands : ActionTrigger m => [Command] -> m CommandId
|
||||
submitCommands cmds = liftTF (Submit (cmds, CommandId))
|
@ -1,12 +0,0 @@
|
||||
-- Hoogle documentation for Daml, generated by damlc
|
||||
-- See Hoogle, http://www.haskell.org/hoogle/
|
||||
|
||||
-- Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates.
|
||||
-- All rights reserved. Any unauthorized use, duplication or distribution is strictly prohibited.
|
||||
|
||||
-- | Daml Trigger library.
|
||||
@url {{base-url}}
|
||||
@package daml-trigger
|
||||
@version 1.2.0
|
||||
|
||||
{{{body}}}
|
@ -1,11 +0,0 @@
|
||||
.. Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
.. SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
.. _daml-trigger-api-docs:
|
||||
|
||||
Daml Trigger Library
|
||||
====================
|
||||
|
||||
The Daml Trigger library defines the API used to declare a Daml trigger. See :doc:`/triggers/index`:: for more information on Daml triggers.
|
||||
|
||||
{{{body}}}
|
@ -1,4 +0,0 @@
|
||||
.. Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
.. SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
{{{body}}}
|
@ -1,23 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
load(
|
||||
"//bazel_tools:scala.bzl",
|
||||
"da_scala_library",
|
||||
)
|
||||
|
||||
da_scala_library(
|
||||
name = "metrics",
|
||||
srcs = glob(["src/main/scala/**/*.scala"]),
|
||||
scala_deps = [],
|
||||
tags = ["maven_coordinates=com.daml:trigger-metrics:__VERSION__"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
runtime_deps = [],
|
||||
deps = [
|
||||
"//observability/metrics",
|
||||
"//observability/pekko-http-metrics",
|
||||
"@maven//:io_opentelemetry_opentelemetry_api",
|
||||
],
|
||||
)
|
@ -1,16 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger.metrics
|
||||
|
||||
import com.daml.metrics.api.opentelemetry.OpenTelemetryMetricsFactory
|
||||
import com.daml.metrics.http.DamlHttpMetrics
|
||||
import io.opentelemetry.api.metrics.{Meter => OtelMeter}
|
||||
|
||||
case class TriggerServiceMetrics(otelMeter: OtelMeter) {
|
||||
|
||||
val openTelemetryFactory = new OpenTelemetryMetricsFactory(otelMeter)
|
||||
|
||||
val http = new DamlHttpMetrics(openTelemetryFactory, "trigger_service")
|
||||
|
||||
}
|
@ -1,63 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
load(
|
||||
"//bazel_tools:scala.bzl",
|
||||
"da_scala_binary",
|
||||
"da_scala_library",
|
||||
"lf_scalacopts_stricter",
|
||||
)
|
||||
|
||||
da_scala_library(
|
||||
name = "trigger-runner-lib",
|
||||
srcs = glob(["src/main/scala/**/*.scala"]),
|
||||
scala_deps = [
|
||||
"@maven//:com_github_scopt_scopt",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
"@maven//:org_typelevel_paiges_core",
|
||||
],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
tags = ["maven_coordinates=com.daml:trigger-runner:__VERSION__"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//canton:ledger_api_proto_scala",
|
||||
"//daml-lf/archive:daml_lf_archive_reader",
|
||||
"//daml-lf/data",
|
||||
"//daml-lf/engine",
|
||||
"//daml-lf/interpreter",
|
||||
"//daml-lf/language",
|
||||
"//daml-lf/transaction",
|
||||
"//daml-script/converter",
|
||||
"//ledger-service/cli-opts",
|
||||
"//ledger/ledger-api-client",
|
||||
"//ledger/ledger-api-common",
|
||||
"//ledger/ledger-api-domain",
|
||||
"//libs-scala/auth-utils",
|
||||
"//libs-scala/contextualized-logging",
|
||||
"//libs-scala/logging-entries",
|
||||
"//libs-scala/rs-grpc-bridge",
|
||||
"//libs-scala/rs-grpc-pekko",
|
||||
"//libs-scala/scala-utils",
|
||||
"//observability/tracing",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:io_grpc_grpc_api",
|
||||
"@maven//:io_netty_netty_handler",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_binary(
|
||||
name = "trigger-runner",
|
||||
main_class = "com.daml.lf.engine.trigger.RunnerMain",
|
||||
resources = ["src/main/resources/logback.xml"],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
tags = ["ee-jar-license"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":trigger-runner-lib",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(["src/main/resources/logback.xml"])
|
@ -1,30 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<configuration>
|
||||
<appender name="console" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<if condition='isDefined("LOG_FORMAT_JSON")'>
|
||||
<then>
|
||||
<encoder class="net.logstash.logback.encoder.LogstashEncoder"/>
|
||||
</then>
|
||||
<else>
|
||||
<encoder>
|
||||
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg %replace(, context: %marker){', context: $', ''} %n</pattern>
|
||||
</encoder>
|
||||
</else>
|
||||
</if>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="net.logstash.logback.appender.LoggingEventAsyncDisruptorAppender">
|
||||
<appender-ref ref="console" />
|
||||
</appender>
|
||||
|
||||
<logger name="io.netty" level="WARN" />
|
||||
<logger name="io.grpc.netty" level="WARN" />
|
||||
<logger name="pekko.event.slf4j" level="WARN" />
|
||||
<logger name="daml.tracelog" level="DEBUG" />
|
||||
<logger name="daml.warnings" level="WARN" />
|
||||
<logger name="com.daml.lf.engine.trigger" level="DEBUG" />
|
||||
|
||||
<root level="${LOG_LEVEL_ROOT:-INFO}">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</root>
|
||||
</configuration>
|
@ -1,692 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf
|
||||
package engine
|
||||
package trigger
|
||||
|
||||
import scalaz.std.either._
|
||||
import scalaz.std.list._
|
||||
import scalaz.std.option._
|
||||
import scalaz.syntax.traverse._
|
||||
import com.daml.lf.data.{BackStack, FrontStack, ImmArray}
|
||||
import com.daml.lf.data.Ref._
|
||||
import com.daml.lf.language.Ast._
|
||||
import com.daml.lf.speedy.{ArrayList, SValue}
|
||||
import com.daml.lf.speedy.SValue._
|
||||
import com.daml.lf.value.Value.ContractId
|
||||
import com.daml.ledger.api.v1.commands.{
|
||||
CreateAndExerciseCommand,
|
||||
CreateCommand,
|
||||
ExerciseByKeyCommand,
|
||||
ExerciseCommand,
|
||||
Command => ApiCommand,
|
||||
}
|
||||
import com.daml.ledger.api.v1.completion.Completion
|
||||
import com.daml.ledger.api.v1.event.{ArchivedEvent, CreatedEvent, Event, InterfaceView}
|
||||
import com.daml.ledger.api.v1.transaction.Transaction
|
||||
import com.daml.ledger.api.v1.value
|
||||
import com.daml.ledger.api.validation.NoLoggingValueValidator
|
||||
import com.daml.lf.language.LanguageVersionRangeOps.LanguageVersionRange
|
||||
import com.daml.lf.language.StablePackages
|
||||
import com.daml.lf.speedy.Command
|
||||
import com.daml.lf.value.Value
|
||||
import com.daml.platform.participant.util.LfEngineToApi.{
|
||||
lfValueToApiRecord,
|
||||
lfValueToApiValue,
|
||||
toApiIdentifier,
|
||||
}
|
||||
import com.daml.script.converter.ConverterException
|
||||
import com.daml.script.converter.Converter._
|
||||
import com.daml.script.converter.Converter.Implicits._
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.concurrent.duration.{FiniteDuration, MICROSECONDS}
|
||||
|
||||
// Convert from a Ledger API transaction to an SValue corresponding to a Message from the Daml.Trigger module
|
||||
final class Converter(
|
||||
compiledPackages: CompiledPackages,
|
||||
triggerDef: TriggerDefinition,
|
||||
) {
|
||||
|
||||
import Converter._
|
||||
|
||||
private[this] val valueTranslator = new preprocessing.ValueTranslator(
|
||||
compiledPackages.pkgInterface,
|
||||
requireV1ContractIdSuffix = false,
|
||||
)
|
||||
|
||||
private[this] def validateRecord(r: value.Record): Either[String, Value.ValueRecord] =
|
||||
NoLoggingValueValidator.validateRecord(r).left.map(_.getMessage)
|
||||
|
||||
private[this] def translateValue(ty: Type, value: Value): Either[String, SValue] =
|
||||
valueTranslator
|
||||
.strictTranslateValue(ty, value)
|
||||
.left
|
||||
.map(res => s"Failure to translate value: $res")
|
||||
|
||||
private[this] val triggerIds: TriggerIds = triggerDef.triggerIds
|
||||
|
||||
private[this] val stablePackages = StablePackages(
|
||||
compiledPackages.compilerConfig.allowedLanguageVersions.majorVersion
|
||||
)
|
||||
private[this] val templateTypeRepTyCon = stablePackages.TemplateTypeRep
|
||||
private[this] val anyTemplateTyCon = stablePackages.AnyTemplate
|
||||
private[this] val anyViewTyCon = stablePackages.AnyView
|
||||
private[this] val activeContractsTy = triggerIds.damlTriggerLowLevel("ActiveContracts")
|
||||
private[this] val triggerConfigTy = triggerIds.damlTriggerLowLevel("TriggerConfig")
|
||||
private[this] val triggerSetupArgumentsTy =
|
||||
triggerIds.damlTriggerLowLevel("TriggerSetupArguments")
|
||||
private[this] val triggerStateTy = triggerIds.damlTriggerInternal("TriggerState")
|
||||
private[this] val anyContractIdTy = triggerIds.damlTriggerLowLevel("AnyContractId")
|
||||
private[this] val archivedTy = triggerIds.damlTriggerLowLevel("Archived")
|
||||
private[this] val commandIdTy = triggerIds.damlTriggerLowLevel("CommandId")
|
||||
private[this] val completionTy = triggerIds.damlTriggerLowLevel("Completion")
|
||||
private[this] val createdTy = triggerIds.damlTriggerLowLevel("Created")
|
||||
private[this] val eventIdTy = triggerIds.damlTriggerLowLevel("EventId")
|
||||
private[this] val eventTy = triggerIds.damlTriggerLowLevel("Event")
|
||||
private[this] val failedTy = triggerIds.damlTriggerLowLevel("CompletionStatus.Failed")
|
||||
private[this] val messageTy = triggerIds.damlTriggerLowLevel("Message")
|
||||
private[this] val succeedTy = triggerIds.damlTriggerLowLevel("CompletionStatus.Succeeded")
|
||||
private[this] val transactionIdTy = triggerIds.damlTriggerLowLevel("TransactionId")
|
||||
private[this] val transactionTy = triggerIds.damlTriggerLowLevel("Transaction")
|
||||
private[this] val createCommandTy = triggerIds.damlTriggerLowLevel("CreateCommand")
|
||||
private[this] val exerciseCommandTy = triggerIds.damlTriggerLowLevel("ExerciseCommand")
|
||||
private[this] val createAndExerciseCommandTy =
|
||||
triggerIds.damlTriggerLowLevel("CreateAndExerciseCommand")
|
||||
private[this] val exerciseByKeyCommandTy = triggerIds.damlTriggerLowLevel("ExerciseByKeyCommand")
|
||||
private[this] val acsTy = triggerIds.damlTriggerInternal("ACS")
|
||||
|
||||
private[this] def fromTemplateTypeRep(tyCon: value.Identifier): SValue =
|
||||
record(
|
||||
templateTypeRepTyCon,
|
||||
"getTemplateTypeRep" -> STypeRep(
|
||||
TTyCon(
|
||||
TypeConName(
|
||||
PackageId.assertFromString(tyCon.packageId),
|
||||
QualifiedName(
|
||||
DottedName.assertFromString(tyCon.moduleName),
|
||||
DottedName.assertFromString(tyCon.entityName),
|
||||
),
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
private[this] def fromTransactionId(transactionId: String): SValue =
|
||||
record(transactionIdTy, "unpack" -> SText(transactionId))
|
||||
|
||||
private[this] def fromEventId(eventId: String): SValue =
|
||||
record(eventIdTy, "unpack" -> SText(eventId))
|
||||
|
||||
private[trigger] def fromCommandId(commandId: String): SValue =
|
||||
record(commandIdTy, "unpack" -> SText(commandId))
|
||||
|
||||
private[this] def fromOptionalCommandId(commandId: String): SValue =
|
||||
if (commandId.isEmpty)
|
||||
SOptional(None)
|
||||
else
|
||||
SOptional(Some(fromCommandId(commandId)))
|
||||
|
||||
private[this] def fromAnyContractId(
|
||||
templateId: value.Identifier,
|
||||
contractId: String,
|
||||
): SValue =
|
||||
record(
|
||||
anyContractIdTy,
|
||||
"templateId" -> fromTemplateTypeRep(templateId),
|
||||
"contractId" -> SContractId(ContractId.assertFromString(contractId)),
|
||||
)
|
||||
|
||||
private def fromArchivedEvent(archived: ArchivedEvent): SValue =
|
||||
record(
|
||||
archivedTy,
|
||||
"eventId" -> fromEventId(archived.eventId),
|
||||
"contractId" -> fromAnyContractId(archived.getTemplateId, archived.contractId),
|
||||
)
|
||||
|
||||
private[this] def fromRecord(typ: Type, record: value.Record): Either[String, SValue] =
|
||||
for {
|
||||
record <- validateRecord(record)
|
||||
tmplPayload <- translateValue(typ, record)
|
||||
} yield tmplPayload
|
||||
|
||||
private[this] def fromAnyTemplate(typ: Type, value: SValue) =
|
||||
record(anyTemplateTyCon, "getAnyTemplate" -> SAny(typ, value))
|
||||
|
||||
private[this] def fromV20CreatedEvent(
|
||||
created: CreatedEvent
|
||||
): Either[String, SValue] =
|
||||
for {
|
||||
tmplId <- fromIdentifier(created.getTemplateId)
|
||||
tmplType = TTyCon(tmplId)
|
||||
tmplPayload <- fromRecord(tmplType, created.getCreateArguments)
|
||||
} yield {
|
||||
record(
|
||||
createdTy,
|
||||
"eventId" -> fromEventId(created.eventId),
|
||||
"contractId" -> fromAnyContractId(created.getTemplateId, created.contractId),
|
||||
"argument" -> fromAnyTemplate(tmplType, tmplPayload),
|
||||
)
|
||||
}
|
||||
|
||||
private[this] def fromAnyView(typ: Type, value: SValue) =
|
||||
record(anyViewTyCon, "getAnyView" -> SAny(typ, value))
|
||||
|
||||
private[this] def fromInterfaceView(view: InterfaceView): Either[String, SOptional] =
|
||||
for {
|
||||
ifaceId <- fromIdentifier(view.getInterfaceId)
|
||||
iface <- compiledPackages.pkgInterface.lookupInterface(ifaceId).left.map(_.pretty)
|
||||
viewType = iface.view
|
||||
viewValue <- view.viewValue.traverseU(fromRecord(viewType, _))
|
||||
} yield SOptional(viewValue.map(fromAnyView(viewType, _)))
|
||||
|
||||
private[this] def fromV250CreatedEvent(
|
||||
created: CreatedEvent
|
||||
): Either[String, SValue] =
|
||||
for {
|
||||
tmplId <- fromIdentifier(created.getTemplateId)
|
||||
tmplType = TTyCon(tmplId)
|
||||
tmplPayload <- created.createArguments.traverseU(fromRecord(tmplType, _))
|
||||
views <- created.interfaceViews.toList.traverseU(fromInterfaceView)
|
||||
} yield {
|
||||
record(
|
||||
createdTy,
|
||||
"eventId" -> fromEventId(created.eventId),
|
||||
"contractId" -> fromAnyContractId(created.getTemplateId, created.contractId),
|
||||
"argument" -> SOptional(tmplPayload.map(fromAnyTemplate(tmplType, _))),
|
||||
"views" -> SList(views.to(FrontStack)),
|
||||
)
|
||||
}
|
||||
|
||||
private[this] val fromCreatedEvent: CreatedEvent => Either[String, SValue] =
|
||||
if (triggerDef.version < Trigger.Version.`2.5.0`) {
|
||||
fromV20CreatedEvent
|
||||
} else {
|
||||
fromV250CreatedEvent
|
||||
}
|
||||
|
||||
private def fromEvent(ev: Event): Either[String, SValue] =
|
||||
ev.event match {
|
||||
case Event.Event.Archived(archivedEvent) =>
|
||||
Right(
|
||||
SVariant(
|
||||
id = eventTy,
|
||||
variant = EventVariant.ArchiveEventConstructor,
|
||||
constructorRank = EventVariant.ArchiveEventConstructorRank,
|
||||
value = fromArchivedEvent(archivedEvent),
|
||||
)
|
||||
)
|
||||
case Event.Event.Created(createdEvent) =>
|
||||
for {
|
||||
event <- fromCreatedEvent(createdEvent)
|
||||
} yield SVariant(
|
||||
id = eventTy,
|
||||
variant = EventVariant.CreatedEventConstructor,
|
||||
constructorRank = EventVariant.CreatedEventConstructorRank,
|
||||
value = event,
|
||||
)
|
||||
case _ => Left(s"Expected Archived or Created but got ${ev.event}")
|
||||
}
|
||||
|
||||
def fromTransaction(t: Transaction): Either[String, SValue] =
|
||||
for {
|
||||
events <- t.events.to(ImmArray).traverse(fromEvent).map(xs => SList(FrontStack.from(xs)))
|
||||
transactionId = fromTransactionId(t.transactionId)
|
||||
commandId = fromOptionalCommandId(t.commandId)
|
||||
} yield SVariant(
|
||||
id = messageTy,
|
||||
variant = MessageVariant.MTransactionVariant,
|
||||
constructorRank = MessageVariant.MTransactionVariantRank,
|
||||
value = record(
|
||||
transactionTy,
|
||||
"transactionId" -> transactionId,
|
||||
"commandId" -> commandId,
|
||||
"events" -> events,
|
||||
),
|
||||
)
|
||||
|
||||
def fromCompletion(c: Completion): Either[String, SValue] = {
|
||||
val status: SValue =
|
||||
if (c.getStatus.code == 0)
|
||||
SVariant(
|
||||
triggerIds.damlTriggerLowLevel("CompletionStatus"),
|
||||
CompletionStatusVariant.SucceedVariantConstructor,
|
||||
CompletionStatusVariant.SucceedVariantConstructorRank,
|
||||
record(
|
||||
succeedTy,
|
||||
"transactionId" -> fromTransactionId(c.transactionId),
|
||||
),
|
||||
)
|
||||
else
|
||||
SVariant(
|
||||
triggerIds.damlTriggerLowLevel("CompletionStatus"),
|
||||
CompletionStatusVariant.FailVariantConstructor,
|
||||
CompletionStatusVariant.FailVariantConstructorRank,
|
||||
record(
|
||||
failedTy,
|
||||
"status" -> SInt64(c.getStatus.code.asInstanceOf[Long]),
|
||||
"message" -> SText(c.getStatus.message),
|
||||
),
|
||||
)
|
||||
Right(
|
||||
SVariant(
|
||||
messageTy,
|
||||
MessageVariant.MCompletionConstructor,
|
||||
MessageVariant.MCompletionConstructorRank,
|
||||
record(
|
||||
completionTy,
|
||||
"commandId" -> fromCommandId(c.commandId),
|
||||
"status" -> status,
|
||||
),
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def fromHeartbeat: SValue =
|
||||
SVariant(
|
||||
messageTy,
|
||||
MessageVariant.MHeartbeatConstructor,
|
||||
MessageVariant.MHeartbeatConstructorRank,
|
||||
SUnit,
|
||||
)
|
||||
|
||||
def fromActiveContracts(createdEvents: Seq[CreatedEvent]): Either[String, SValue] =
|
||||
for {
|
||||
events <- createdEvents
|
||||
.to(ImmArray)
|
||||
.traverse(fromCreatedEvent)
|
||||
.map(xs => SList(FrontStack.from(xs)))
|
||||
} yield record(activeContractsTy, "activeContracts" -> events)
|
||||
|
||||
private[this] def fromTriggerConfig(triggerConfig: TriggerRunnerConfig): SValue =
|
||||
record(
|
||||
triggerConfigTy,
|
||||
"maxInFlightCommands" -> SValue.SInt64(triggerConfig.inFlightCommandBackPressureCount),
|
||||
"maxActiveContracts" -> SValue.SInt64(triggerConfig.hardLimit.maximumActiveContracts),
|
||||
)
|
||||
|
||||
def fromTriggerSetupArguments(
|
||||
parties: TriggerParties,
|
||||
createdEvents: Seq[CreatedEvent],
|
||||
triggerConfig: TriggerRunnerConfig,
|
||||
): Either[String, SValue] =
|
||||
for {
|
||||
acs <- fromActiveContracts(createdEvents)
|
||||
actAs = SParty(parties.actAs)
|
||||
readAs = SList(parties.readAs.view.map(SParty).to(FrontStack))
|
||||
config = fromTriggerConfig(triggerConfig)
|
||||
} yield record(
|
||||
triggerSetupArgumentsTy,
|
||||
"actAs" -> actAs,
|
||||
"readAs" -> readAs,
|
||||
"acs" -> acs,
|
||||
"config" -> config,
|
||||
)
|
||||
|
||||
def fromCommand(command: Command): SValue = command match {
|
||||
case Command.Create(templateId, argument) =>
|
||||
record(
|
||||
createCommandTy,
|
||||
"templateArg" -> fromAnyTemplate(
|
||||
TTyCon(templateId),
|
||||
argument,
|
||||
),
|
||||
)
|
||||
|
||||
case Command.ExerciseTemplate(_, contractId, _, choiceArg) =>
|
||||
record(
|
||||
exerciseCommandTy,
|
||||
"contractId" -> contractId,
|
||||
"choiceArg" -> choiceArg,
|
||||
)
|
||||
|
||||
case Command.ExerciseInterface(_, contractId, _, choiceArg) =>
|
||||
record(
|
||||
exerciseCommandTy,
|
||||
"contractId" -> contractId,
|
||||
"choiceArg" -> choiceArg,
|
||||
)
|
||||
|
||||
case Command.CreateAndExercise(templateId, createArg, _, choiceArg) =>
|
||||
record(
|
||||
createAndExerciseCommandTy,
|
||||
"templateArg" -> fromAnyTemplate(TTyCon(templateId), createArg),
|
||||
"choiceArg" -> choiceArg,
|
||||
)
|
||||
|
||||
case Command.ExerciseByKey(templateId, contractKey, _, choiceArg) =>
|
||||
record(
|
||||
exerciseByKeyCommandTy,
|
||||
"tplTypeRep" -> SValue.STypeRep(TTyCon(templateId)),
|
||||
"contractKey" -> contractKey,
|
||||
"choiceArg" -> choiceArg,
|
||||
)
|
||||
|
||||
case _ =>
|
||||
throw new ConverterException(
|
||||
s"${command.getClass.getSimpleName} is an unexpected command type"
|
||||
)
|
||||
}
|
||||
|
||||
def fromCommands(commands: Seq[Command]): SValue = {
|
||||
SList(commands.map(fromCommand).to(FrontStack))
|
||||
}
|
||||
|
||||
def fromACS(activeContracts: Seq[CreatedEvent]): SValue = {
|
||||
val createMapByTemplateId = mutable.HashMap.empty[value.Identifier, BackStack[CreatedEvent]]
|
||||
for (create <- activeContracts) {
|
||||
createMapByTemplateId += (create.getTemplateId -> (createMapByTemplateId.getOrElse(
|
||||
create.getTemplateId,
|
||||
BackStack.empty,
|
||||
) :+ create))
|
||||
}
|
||||
|
||||
record(
|
||||
acsTy,
|
||||
"activeContracts" -> SMap(
|
||||
isTextMap = false,
|
||||
createMapByTemplateId.iterator.map { case (templateId, creates) =>
|
||||
fromTemplateTypeRep(templateId) -> SMap(
|
||||
isTextMap = false,
|
||||
creates.reverseIterator.map { event =>
|
||||
val templateType = TTyCon(fromIdentifier(templateId).orConverterException)
|
||||
val template = fromAnyTemplate(
|
||||
templateType,
|
||||
fromRecord(templateType, event.getCreateArguments).orConverterException,
|
||||
)
|
||||
|
||||
fromAnyContractId(templateId, event.contractId) -> template
|
||||
},
|
||||
)
|
||||
},
|
||||
),
|
||||
"pendingContracts" -> SMap(isTextMap = false),
|
||||
)
|
||||
}
|
||||
|
||||
def fromTriggerUpdateState(
|
||||
createdEvents: Seq[CreatedEvent],
|
||||
userState: SValue,
|
||||
commandsInFlight: Map[String, Seq[Command]] = Map.empty,
|
||||
parties: TriggerParties,
|
||||
triggerConfig: TriggerRunnerConfig,
|
||||
): SValue = {
|
||||
val acs = fromACS(createdEvents)
|
||||
val actAs = SParty(parties.actAs)
|
||||
val readAs = SList(parties.readAs.map(SParty).to(FrontStack))
|
||||
val config = fromTriggerConfig(triggerConfig)
|
||||
|
||||
record(
|
||||
triggerStateTy,
|
||||
"acs" -> acs,
|
||||
"actAs" -> actAs,
|
||||
"readAs" -> readAs,
|
||||
"userState" -> userState,
|
||||
"commandsInFlight" -> SMap(
|
||||
isTextMap = false,
|
||||
commandsInFlight.iterator.map { case (cmdId, cmds) =>
|
||||
(fromCommandId(cmdId), fromCommands(cmds))
|
||||
},
|
||||
),
|
||||
"config" -> config,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
object Converter {
|
||||
|
||||
final case class AnyContractId(templateId: Identifier, contractId: ContractId)
|
||||
final case class AnyTemplate(ty: Identifier, arg: SValue)
|
||||
|
||||
private final case class AnyChoice(name: ChoiceName, arg: SValue)
|
||||
private final case class AnyContractKey(key: SValue)
|
||||
|
||||
object EventVariant {
|
||||
// Those values should be kept consistent with type `Event` defined in
|
||||
// triggers/daml/Daml/Trigger/LowLevel.daml
|
||||
val CreatedEventConstructor = Name.assertFromString("CreatedEvent")
|
||||
val CreatedEventConstructorRank = 0
|
||||
val ArchiveEventConstructor = Name.assertFromString("ArchivedEvent")
|
||||
val ArchiveEventConstructorRank = 1
|
||||
}
|
||||
|
||||
object MessageVariant {
|
||||
// Those values should be kept consistent with type `Message` defined in
|
||||
// triggers/daml/Daml/Trigger/LowLevel.daml
|
||||
val MTransactionVariant = Name.assertFromString("MTransaction")
|
||||
val MTransactionVariantRank = 0
|
||||
val MCompletionConstructor = Name.assertFromString("MCompletion")
|
||||
val MCompletionConstructorRank = 1
|
||||
val MHeartbeatConstructor = Name.assertFromString("MHeartbeat")
|
||||
val MHeartbeatConstructorRank = 2
|
||||
}
|
||||
|
||||
object CompletionStatusVariant {
|
||||
// Those values should be kept consistent `CompletionStatus` defined in
|
||||
// triggers/daml/Daml/Trigger/LowLevel.daml
|
||||
val FailVariantConstructor = Name.assertFromString("Failed")
|
||||
val FailVariantConstructorRank = 0
|
||||
val SucceedVariantConstructor = Name.assertFromString("Succeeded")
|
||||
val SucceedVariantConstructorRank = 1
|
||||
}
|
||||
|
||||
def fromIdentifier(identifier: value.Identifier): Either[String, Identifier] =
|
||||
for {
|
||||
pkgId <- PackageId.fromString(identifier.packageId)
|
||||
mod <- DottedName.fromString(identifier.moduleName)
|
||||
name <- DottedName.fromString(identifier.entityName)
|
||||
} yield Identifier(pkgId, QualifiedName(mod, name))
|
||||
|
||||
private def toLedgerRecord(v: SValue): Either[String, value.Record] =
|
||||
lfValueToApiRecord(verbose = true, v.toUnnormalizedValue)
|
||||
|
||||
private def toLedgerValue(v: SValue): Either[String, value.Value] =
|
||||
lfValueToApiValue(verbose = true, v.toUnnormalizedValue)
|
||||
|
||||
private def toIdentifier(v: SValue): Either[String, Identifier] =
|
||||
v.expect(
|
||||
"STypeRep",
|
||||
{ case STypeRep(TTyCon(id)) =>
|
||||
id
|
||||
},
|
||||
)
|
||||
|
||||
def toAnyContractId(v: SValue): Either[String, AnyContractId] =
|
||||
v.expectE(
|
||||
"AnyContractId",
|
||||
{ case SRecord(_, _, ArrayList(stid, scid)) =>
|
||||
for {
|
||||
templateId <- toTemplateTypeRep(stid)
|
||||
contractId <- toContractId(scid)
|
||||
} yield AnyContractId(templateId, contractId)
|
||||
},
|
||||
)
|
||||
|
||||
private def toTemplateTypeRep(v: SValue): Either[String, Identifier] =
|
||||
v.expectE(
|
||||
"TemplateTypeRep",
|
||||
{ case SRecord(_, _, ArrayList(id)) =>
|
||||
toIdentifier(id)
|
||||
},
|
||||
)
|
||||
|
||||
def toFiniteDuration(value: SValue): Either[String, FiniteDuration] =
|
||||
value.expect(
|
||||
"RelTime",
|
||||
{ case SRecord(_, _, ArrayList(SInt64(microseconds))) =>
|
||||
FiniteDuration(microseconds, MICROSECONDS)
|
||||
},
|
||||
)
|
||||
|
||||
private def toRegisteredTemplate(v: SValue): Either[String, Identifier] =
|
||||
v.expectE(
|
||||
"RegisteredTemplate",
|
||||
{ case SRecord(_, _, ArrayList(sttr)) =>
|
||||
toTemplateTypeRep(sttr)
|
||||
},
|
||||
)
|
||||
|
||||
def toRegisteredTemplates(v: SValue): Either[String, Seq[Identifier]] =
|
||||
v.expectE(
|
||||
"list of RegisteredTemplate",
|
||||
{ case SList(tpls) =>
|
||||
tpls.traverse(toRegisteredTemplate).map(_.toImmArray.toSeq)
|
||||
},
|
||||
)
|
||||
|
||||
def toAnyTemplate(v: SValue): Either[String, AnyTemplate] =
|
||||
v match {
|
||||
case SRecord(_, _, ArrayList(SAny(TTyCon(tmplId), value))) =>
|
||||
Right(AnyTemplate(tmplId, value))
|
||||
case _ => Left(s"Expected AnyTemplate but got $v")
|
||||
}
|
||||
|
||||
private def choiceArgTypeToChoiceName(choiceCons: TypeConName) = {
|
||||
// This exploits the fact that in Daml, choice argument type names
|
||||
// and choice names match up.
|
||||
assert(choiceCons.qualifiedName.name.segments.length == 1)
|
||||
choiceCons.qualifiedName.name.segments.head
|
||||
}
|
||||
|
||||
private def toAnyChoice(v: SValue): Either[String, AnyChoice] =
|
||||
v match {
|
||||
case SRecord(_, _, ArrayList(SAny(TTyCon(choiceCons), choiceVal), _)) =>
|
||||
Right(AnyChoice(choiceArgTypeToChoiceName(choiceCons), choiceVal))
|
||||
case _ =>
|
||||
Left(s"Expected AnyChoice but got $v")
|
||||
}
|
||||
|
||||
private def toAnyContractKey(v: SValue): Either[String, AnyContractKey] =
|
||||
v.expect(
|
||||
"AnyContractKey",
|
||||
{ case SRecord(_, _, ArrayList(SAny(_, v), _)) =>
|
||||
AnyContractKey(v)
|
||||
},
|
||||
)
|
||||
|
||||
private def toCreate(v: SValue): Either[String, CreateCommand] =
|
||||
v.expectE(
|
||||
"CreateCommand",
|
||||
{ case SRecord(_, _, ArrayList(sTpl)) =>
|
||||
for {
|
||||
anyTmpl <- toAnyTemplate(sTpl)
|
||||
templateArg <- toLedgerRecord(anyTmpl.arg)
|
||||
} yield CreateCommand(Some(toApiIdentifier(anyTmpl.ty)), Some(templateArg))
|
||||
},
|
||||
)
|
||||
|
||||
private def toExercise(v: SValue): Either[String, ExerciseCommand] =
|
||||
v.expectE(
|
||||
"ExerciseCommand",
|
||||
{ case SRecord(_, _, ArrayList(sAnyContractId, sChoiceVal)) =>
|
||||
for {
|
||||
anyContractId <- toAnyContractId(sAnyContractId)
|
||||
anyChoice <- toAnyChoice(sChoiceVal)
|
||||
choiceArg <- toLedgerValue(anyChoice.arg)
|
||||
} yield ExerciseCommand(
|
||||
Some(toApiIdentifier(anyContractId.templateId)),
|
||||
anyContractId.contractId.coid,
|
||||
anyChoice.name,
|
||||
Some(choiceArg),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
private def toExerciseByKey(v: SValue): Either[String, ExerciseByKeyCommand] =
|
||||
v.expectE(
|
||||
"ExerciseByKeyCommand",
|
||||
{ case SRecord(_, _, ArrayList(stplId, skeyVal, sChoiceVal)) =>
|
||||
for {
|
||||
tplId <- toTemplateTypeRep(stplId)
|
||||
keyVal <- toAnyContractKey(skeyVal)
|
||||
keyArg <- toLedgerValue(keyVal.key)
|
||||
anyChoice <- toAnyChoice(sChoiceVal)
|
||||
choiceArg <- toLedgerValue(anyChoice.arg)
|
||||
} yield ExerciseByKeyCommand(
|
||||
Some(toApiIdentifier(tplId)),
|
||||
Some(keyArg),
|
||||
anyChoice.name,
|
||||
Some(choiceArg),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
private def toCreateAndExercise(v: SValue): Either[String, CreateAndExerciseCommand] =
|
||||
v.expectE(
|
||||
"CreateAndExerciseCommand",
|
||||
{ case SRecord(_, _, ArrayList(sTpl, sChoiceVal)) =>
|
||||
for {
|
||||
anyTmpl <- toAnyTemplate(sTpl)
|
||||
templateArg <- toLedgerRecord(anyTmpl.arg)
|
||||
anyChoice <- toAnyChoice(sChoiceVal)
|
||||
choiceArg <- toLedgerValue(anyChoice.arg)
|
||||
} yield CreateAndExerciseCommand(
|
||||
Some(toApiIdentifier(anyTmpl.ty)),
|
||||
Some(templateArg),
|
||||
anyChoice.name,
|
||||
Some(choiceArg),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
private def toCommand(v: SValue): Either[String, ApiCommand] = {
|
||||
v match {
|
||||
case SVariant(_, "CreateCommand", _, createVal) =>
|
||||
for {
|
||||
create <- toCreate(createVal)
|
||||
} yield ApiCommand().withCreate(create)
|
||||
case SVariant(_, "ExerciseCommand", _, exerciseVal) =>
|
||||
for {
|
||||
exercise <- toExercise(exerciseVal)
|
||||
} yield ApiCommand().withExercise(exercise)
|
||||
case SVariant(_, "ExerciseByKeyCommand", _, exerciseByKeyVal) =>
|
||||
for {
|
||||
exerciseByKey <- toExerciseByKey(exerciseByKeyVal)
|
||||
} yield ApiCommand().withExerciseByKey(exerciseByKey)
|
||||
case SVariant(_, "CreateAndExerciseCommand", _, createAndExerciseVal) =>
|
||||
for {
|
||||
createAndExercise <- toCreateAndExercise(createAndExerciseVal)
|
||||
} yield ApiCommand().withCreateAndExercise(createAndExercise)
|
||||
case _ => Left(s"Expected a Command but got $v")
|
||||
}
|
||||
}
|
||||
|
||||
def toCommands(v: SValue): Either[String, Seq[ApiCommand]] =
|
||||
for {
|
||||
cmdValues <- v.expect(
|
||||
"[Command]",
|
||||
{ case SList(cmdValues) =>
|
||||
cmdValues
|
||||
},
|
||||
)
|
||||
commands <- cmdValues.traverse(toCommand)
|
||||
} yield commands.toImmArray.toSeq
|
||||
}
|
||||
|
||||
// Helper to create identifiers pointing to the Daml.Trigger module
|
||||
final case class TriggerIds(triggerPackageId: PackageId) {
|
||||
def damlTrigger(s: String): Identifier =
|
||||
Identifier(
|
||||
triggerPackageId,
|
||||
QualifiedName(ModuleName.assertFromString("Daml.Trigger"), DottedName.assertFromString(s)),
|
||||
)
|
||||
|
||||
def damlTriggerLowLevel(s: String): Identifier =
|
||||
Identifier(
|
||||
triggerPackageId,
|
||||
QualifiedName(
|
||||
ModuleName.assertFromString("Daml.Trigger.LowLevel"),
|
||||
DottedName.assertFromString(s),
|
||||
),
|
||||
)
|
||||
|
||||
def damlTriggerInternal(s: String): Identifier =
|
||||
Identifier(
|
||||
triggerPackageId,
|
||||
QualifiedName(
|
||||
ModuleName.assertFromString("Daml.Trigger.Internal"),
|
||||
DottedName.assertFromString(s),
|
||||
),
|
||||
)
|
||||
}
|
@ -1,193 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import com.daml.ledger.api.v1.{value => api}
|
||||
import com.daml.lf.speedy.{Pretty, SValue}
|
||||
import org.typelevel.paiges.Doc
|
||||
import org.typelevel.paiges.Doc.{char, fill, intercalate, str, text}
|
||||
|
||||
import scala.jdk.CollectionConverters._
|
||||
|
||||
object PrettyPrint {
|
||||
|
||||
def prettyApiIdentifier(id: api.Identifier): Doc =
|
||||
text(id.moduleName) + char(':') + text(id.entityName) + char('@') + text(id.packageId)
|
||||
|
||||
def prettyApiValue(verbose: Boolean, maxListWidth: Option[Int] = None)(v: api.Value): Doc =
|
||||
v.sum match {
|
||||
case api.Value.Sum.Empty => Doc.empty
|
||||
|
||||
case api.Value.Sum.Int64(i) => str(i)
|
||||
|
||||
case api.Value.Sum.Numeric(d) => text(d)
|
||||
|
||||
case api.Value.Sum.Record(api.Record(mbId, fs)) =>
|
||||
(mbId match {
|
||||
case Some(id) if verbose => prettyApiIdentifier(id)
|
||||
case _ => Doc.empty
|
||||
}) +
|
||||
char('{') &
|
||||
fill(
|
||||
text(", "),
|
||||
fs.toList.map {
|
||||
case api.RecordField(k, Some(v)) =>
|
||||
text(k) & char('=') & prettyApiValue(verbose = true, maxListWidth)(v)
|
||||
case _ => Doc.empty
|
||||
},
|
||||
) &
|
||||
char('}')
|
||||
|
||||
case api.Value.Sum.Variant(api.Variant(mbId, variant, value)) =>
|
||||
(mbId match {
|
||||
case Some(id) if verbose => prettyApiIdentifier(id) + char(':')
|
||||
case _ => Doc.empty
|
||||
}) +
|
||||
text(variant) + char('(') + value.fold(Doc.empty)(v =>
|
||||
prettyApiValue(verbose = true, maxListWidth)(v)
|
||||
) + char(')')
|
||||
|
||||
case api.Value.Sum.Enum(api.Enum(mbId, constructor)) =>
|
||||
(mbId match {
|
||||
case Some(id) if verbose => prettyApiIdentifier(id) + char(':')
|
||||
case _ => Doc.empty
|
||||
}) + text(constructor)
|
||||
|
||||
case api.Value.Sum.Text(t) => char('"') + text(t) + char('"')
|
||||
|
||||
case api.Value.Sum.ContractId(acoid) => text(acoid)
|
||||
|
||||
case api.Value.Sum.Unit(_) => text("<unit>")
|
||||
|
||||
case api.Value.Sum.Bool(b) => str(b)
|
||||
|
||||
case api.Value.Sum.List(api.List(lst)) =>
|
||||
maxListWidth match {
|
||||
case Some(maxWidth) if lst.size > maxWidth =>
|
||||
char('[') + intercalate(
|
||||
text(", "),
|
||||
lst.take(maxWidth).map(prettyApiValue(verbose = true, maxListWidth)(_)),
|
||||
) + text(s", ...${lst.size - maxWidth} elements truncated...") + char(']')
|
||||
|
||||
case _ =>
|
||||
char('[') + intercalate(
|
||||
text(", "),
|
||||
lst.map(prettyApiValue(verbose = true, maxListWidth)(_)),
|
||||
) + char(']')
|
||||
}
|
||||
|
||||
case api.Value.Sum.Timestamp(t) => str(t)
|
||||
|
||||
case api.Value.Sum.Date(days) => str(days)
|
||||
|
||||
case api.Value.Sum.Party(p) => char('\'') + str(p) + char('\'')
|
||||
|
||||
case api.Value.Sum.Optional(api.Optional(Some(v1))) =>
|
||||
text("Option(") + prettyApiValue(verbose, maxListWidth)(v1) + char(')')
|
||||
|
||||
case api.Value.Sum.Optional(api.Optional(None)) => text("None")
|
||||
|
||||
case api.Value.Sum.Map(api.Map(entries)) =>
|
||||
val list = entries.map {
|
||||
case api.Map.Entry(k, Some(v)) =>
|
||||
text(k) + text(" -> ") + prettyApiValue(verbose, maxListWidth)(v)
|
||||
case _ => Doc.empty
|
||||
}
|
||||
text("TextMap(") + intercalate(text(", "), list) + text(")")
|
||||
|
||||
case api.Value.Sum.GenMap(api.GenMap(entries)) =>
|
||||
val list = entries.map {
|
||||
case api.GenMap.Entry(Some(k), Some(v)) =>
|
||||
prettyApiValue(verbose, maxListWidth)(k) + text(" -> ") + prettyApiValue(
|
||||
verbose,
|
||||
maxListWidth,
|
||||
)(v)
|
||||
case _ => Doc.empty
|
||||
}
|
||||
text("GenMap(") + intercalate(text(", "), list) + text(")")
|
||||
}
|
||||
|
||||
def prettySValue(v: SValue): Doc = v match {
|
||||
case SValue.SPAP(_, _, _) =>
|
||||
text("...")
|
||||
|
||||
case r: SValue.SRecord =>
|
||||
Pretty.prettyIdentifier(r.id) + char('{') & fill(
|
||||
text(", "),
|
||||
r.fields.toSeq.zip(r.values.asScala).map { case (k, v) =>
|
||||
text(k) & char('=') & prettySValue(v)
|
||||
},
|
||||
) & char('}')
|
||||
|
||||
case SValue.SStruct(fieldNames, values) =>
|
||||
char('<') + fill(
|
||||
text(", "),
|
||||
fieldNames.names.zip(values.asScala).toSeq.map { case (k, v) =>
|
||||
text(k) + char('=') + prettySValue(v)
|
||||
},
|
||||
) + char('>')
|
||||
|
||||
case SValue.SVariant(id, variant, _, value) =>
|
||||
Pretty.prettyIdentifier(id) + char(':') + text(variant) + char('(') + prettySValue(
|
||||
value
|
||||
) + char(')')
|
||||
|
||||
case SValue.SEnum(id, constructor, _) =>
|
||||
Pretty.prettyIdentifier(id) + char(':') + text(constructor)
|
||||
|
||||
case SValue.SOptional(Some(value)) =>
|
||||
text("Option(") + prettySValue(value) + char(')')
|
||||
|
||||
case SValue.SOptional(None) =>
|
||||
text("None")
|
||||
|
||||
case SValue.SList(list) =>
|
||||
char('[') + intercalate(text(", "), list.map(prettySValue).toImmArray.toSeq) + char(']')
|
||||
|
||||
case SValue.SMap(isTextMap, entries) =>
|
||||
val list = entries.map { case (k, v) =>
|
||||
prettySValue(k) + text(" -> ") + prettySValue(v)
|
||||
}
|
||||
text(if (isTextMap) "TextMap(" else "GenMap(") + intercalate(text(", "), list) + text(")")
|
||||
|
||||
case SValue.SAny(ty, value) =>
|
||||
text("to_any") + char('@') + text(ty.pretty) + prettySValue(value)
|
||||
|
||||
case SValue.SInt64(value) =>
|
||||
str(value)
|
||||
|
||||
case SValue.SNumeric(value) =>
|
||||
str(value)
|
||||
|
||||
case SValue.SBigNumeric(value) =>
|
||||
str(value)
|
||||
|
||||
case SValue.SText(value) =>
|
||||
text(s"$value")
|
||||
|
||||
case SValue.STimestamp(value) =>
|
||||
str(value)
|
||||
|
||||
case SValue.SParty(value) =>
|
||||
char('\'') + str(value) + char('\'')
|
||||
|
||||
case SValue.SBool(value) =>
|
||||
str(value)
|
||||
|
||||
case SValue.SUnit =>
|
||||
text("<unit>")
|
||||
|
||||
case SValue.SDate(value) =>
|
||||
str(value)
|
||||
|
||||
case SValue.SContractId(value) =>
|
||||
text(value.coid)
|
||||
|
||||
case SValue.STypeRep(ty) =>
|
||||
text(ty.pretty)
|
||||
|
||||
case SValue.SToken =>
|
||||
text("Token")
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,391 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import ch.qos.logback.classic.Level
|
||||
|
||||
import java.nio.file.{Path, Paths}
|
||||
import java.time.Duration
|
||||
import com.daml.lf.data.Ref
|
||||
import com.daml.ledger.api.domain
|
||||
import com.daml.ledger.api.tls.TlsConfiguration
|
||||
import com.daml.ledger.api.tls.TlsConfigurationCli
|
||||
import com.daml.ledger.client.LedgerClient
|
||||
import com.daml.lf.engine.trigger.TriggerRunnerConfig.DefaultTriggerRunnerConfig
|
||||
import com.daml.lf.language.LanguageMajorVersion
|
||||
import com.daml.platform.services.time.TimeProviderType
|
||||
import com.daml.lf.speedy.Compiler
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
|
||||
sealed trait LogEncoder
|
||||
object LogEncoder {
|
||||
object Plain extends LogEncoder
|
||||
object Json extends LogEncoder
|
||||
}
|
||||
|
||||
sealed trait CompilerConfigBuilder extends Product with Serializable {
|
||||
def build(majorLanguageVersion: LanguageMajorVersion): Compiler.Config
|
||||
}
|
||||
object CompilerConfigBuilder {
|
||||
final case object Default extends CompilerConfigBuilder {
|
||||
override def build(majorLanguageVersion: LanguageMajorVersion): Compiler.Config =
|
||||
Compiler.Config.Default(majorLanguageVersion)
|
||||
}
|
||||
final case object Dev extends CompilerConfigBuilder {
|
||||
override def build(majorLanguageVersion: LanguageMajorVersion): Compiler.Config =
|
||||
Compiler.Config.Dev(majorLanguageVersion)
|
||||
}
|
||||
}
|
||||
|
||||
case class RunnerConfig(
|
||||
darPath: Path,
|
||||
// If defined, we will only list the triggers in the DAR and exit.
|
||||
listTriggers: Option[Boolean],
|
||||
triggerIdentifier: String,
|
||||
ledgerHost: String,
|
||||
ledgerPort: Int,
|
||||
ledgerClaims: ClaimsSpecification,
|
||||
maxInboundMessageSize: Int,
|
||||
// optional so we can detect if both --static-time and --wall-clock-time are passed.
|
||||
timeProviderType: Option[TimeProviderType],
|
||||
commandTtl: Duration,
|
||||
accessTokenFile: Option[Path],
|
||||
applicationId: Option[Ref.ApplicationId],
|
||||
tlsConfig: TlsConfiguration,
|
||||
compilerConfigBuilder: CompilerConfigBuilder,
|
||||
majorLanguageVersion: LanguageMajorVersion,
|
||||
triggerConfig: TriggerRunnerConfig,
|
||||
rootLoggingLevel: Option[Level],
|
||||
logEncoder: LogEncoder,
|
||||
) {
|
||||
private def updatePartySpec(f: TriggerParties => TriggerParties): RunnerConfig =
|
||||
if (ledgerClaims == null) {
|
||||
copy(ledgerClaims = PartySpecification(f(TriggerParties.Empty)))
|
||||
} else
|
||||
ledgerClaims match {
|
||||
case PartySpecification(claims) =>
|
||||
copy(ledgerClaims = PartySpecification(f(claims)))
|
||||
case _: UserSpecification =>
|
||||
throw new IllegalArgumentException(
|
||||
s"Must specify either --ledger-party and --ledger-readas or --ledger-userid but not both"
|
||||
)
|
||||
}
|
||||
private def updateActAs(party: Ref.Party): RunnerConfig =
|
||||
updatePartySpec(spec => spec.copy(actAsOpt = Some(party)))
|
||||
|
||||
private def updateReadAs(parties: Seq[Ref.Party]): RunnerConfig =
|
||||
updatePartySpec(spec => spec.copy(readAs = spec.readAs ++ parties))
|
||||
|
||||
private def updateUser(userId: Ref.UserId): RunnerConfig =
|
||||
if (ledgerClaims == null) {
|
||||
copy(ledgerClaims = UserSpecification(userId))
|
||||
} else
|
||||
ledgerClaims match {
|
||||
case UserSpecification(_) => copy(ledgerClaims = UserSpecification(userId))
|
||||
case _: PartySpecification =>
|
||||
throw new IllegalArgumentException(
|
||||
s"Must specify either --ledger-party and --ledger-readas or --ledger-userid but not both"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
sealed abstract class ClaimsSpecification {
|
||||
def resolveClaims(client: LedgerClient)(implicit ec: ExecutionContext): Future[TriggerParties]
|
||||
}
|
||||
|
||||
final case class PartySpecification(claims: TriggerParties) extends ClaimsSpecification {
|
||||
override def resolveClaims(client: LedgerClient)(implicit
|
||||
ec: ExecutionContext
|
||||
): Future[TriggerParties] =
|
||||
Future.successful(claims)
|
||||
}
|
||||
|
||||
final case class UserSpecification(userId: Ref.UserId) extends ClaimsSpecification {
|
||||
override def resolveClaims(
|
||||
client: LedgerClient
|
||||
)(implicit ec: ExecutionContext): Future[TriggerParties] = for {
|
||||
user <- client.userManagementClient.getUser(userId)
|
||||
primaryParty <- user.primaryParty.fold[Future[Ref.Party]](
|
||||
Future.failed(
|
||||
new IllegalArgumentException(
|
||||
s"User $user has no primary party. Specify a party explicitly via --ledger-party"
|
||||
)
|
||||
)
|
||||
)(Future.successful)
|
||||
rights <- client.userManagementClient.listUserRights(userId)
|
||||
readAs = rights.collect { case domain.UserRight.CanReadAs(party) =>
|
||||
party
|
||||
}.toSet
|
||||
actAs = rights.collect { case domain.UserRight.CanActAs(party) =>
|
||||
party
|
||||
}.toSet
|
||||
_ <-
|
||||
if (actAs.contains(primaryParty)) {
|
||||
Future.unit
|
||||
} else {
|
||||
Future.failed(
|
||||
new IllegalArgumentException(
|
||||
s"User $user has primary party $primaryParty but no actAs claims for that party. Either change the user rights or specify a different party via --ledger-party"
|
||||
)
|
||||
)
|
||||
}
|
||||
readers = (readAs ++ actAs) - primaryParty
|
||||
} yield TriggerParties(primaryParty, readers)
|
||||
}
|
||||
|
||||
final case class TriggerParties(
|
||||
actAsOpt: Option[Ref.Party],
|
||||
readAs: Set[Ref.Party],
|
||||
) {
|
||||
lazy val actAs = actAsOpt.get
|
||||
lazy val readers: Set[Ref.Party] = readAs + actAs
|
||||
}
|
||||
|
||||
object TriggerParties {
|
||||
def apply(actAs: Ref.Party, readAs: Set[Ref.Party]): TriggerParties =
|
||||
new TriggerParties(Some(actAs), readAs)
|
||||
|
||||
def Empty = new TriggerParties(None, Set.empty)
|
||||
}
|
||||
|
||||
object RunnerConfig {
|
||||
|
||||
implicit val userRead: scopt.Read[Ref.UserId] = scopt.Read.reads { s =>
|
||||
Ref.UserId.fromString(s).fold(e => throw new IllegalArgumentException(e), identity)
|
||||
}
|
||||
|
||||
private[trigger] val DefaultMaxInboundMessageSize: Int = 4194304
|
||||
private[trigger] val DefaultTimeProviderType: TimeProviderType = TimeProviderType.WallClock
|
||||
private[trigger] val DefaultApplicationId: Some[Ref.ApplicationId] =
|
||||
Some(Ref.ApplicationId.assertFromString("daml-trigger"))
|
||||
private[trigger] val DefaultCompilerConfigBuilder: CompilerConfigBuilder =
|
||||
CompilerConfigBuilder.Default
|
||||
private[trigger] val DefaultMajorLanguageVersion: LanguageMajorVersion = LanguageMajorVersion.V1
|
||||
|
||||
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements")) // scopt builders
|
||||
private val parser = new scopt.OptionParser[RunnerConfig]("trigger-runner") {
|
||||
head("trigger-runner")
|
||||
|
||||
opt[String]("dar")
|
||||
.required()
|
||||
.action((f, c) => c.copy(darPath = Paths.get(f)))
|
||||
.text("Path to the dar file containing the trigger")
|
||||
|
||||
opt[String]("trigger-name")
|
||||
.action((t, c) => c.copy(triggerIdentifier = t))
|
||||
.text("Identifier of the trigger that should be run in the format Module.Name:Entity.Name")
|
||||
|
||||
opt[String]("ledger-host")
|
||||
.action((t, c) => c.copy(ledgerHost = t))
|
||||
.text("Ledger hostname")
|
||||
|
||||
opt[Int]("ledger-port")
|
||||
.action((t, c) => c.copy(ledgerPort = t))
|
||||
.text("Ledger port")
|
||||
|
||||
opt[String]("ledger-party")
|
||||
.action((t, c) => c.updateActAs(Ref.Party.assertFromString(t)))
|
||||
.text("""The party the trigger can act as.
|
||||
|Mutually exclusive with --ledger-user.""".stripMargin)
|
||||
|
||||
opt[Seq[String]]("ledger-readas")
|
||||
.action((t, c) => c.updateReadAs(t.map(Ref.Party.assertFromString)))
|
||||
.unbounded()
|
||||
.text(
|
||||
"""A comma-separated list of parties the trigger can read as.
|
||||
|Can be specified multiple-times.
|
||||
|Mutually exclusive with --ledger-user.""".stripMargin
|
||||
)
|
||||
|
||||
opt[Ref.UserId]("ledger-user")
|
||||
.action((u, c) => c.updateUser(u))
|
||||
.unbounded()
|
||||
.text(
|
||||
"""The id of the user the trigger should run as.
|
||||
|This is equivalent to specifying the primary party
|
||||
|of the user as --ledger-party and
|
||||
|all actAs and readAs claims of the user other than the
|
||||
|primary party as --ledger-readas.
|
||||
|The user must have a primary party.
|
||||
|Mutually exclusive with --ledger-party and --ledger-readas.""".stripMargin
|
||||
)
|
||||
|
||||
opt[Int]("max-inbound-message-size")
|
||||
.action((x, c) => c.copy(maxInboundMessageSize = x))
|
||||
.optional()
|
||||
.text(
|
||||
s"Optional max inbound message size in bytes. Defaults to ${DefaultMaxInboundMessageSize}"
|
||||
)
|
||||
|
||||
opt[Unit]('w', "wall-clock-time")
|
||||
.action { (_, c) =>
|
||||
setTimeProviderType(c, TimeProviderType.WallClock)
|
||||
}
|
||||
.text("Use wall clock time (UTC).")
|
||||
|
||||
opt[Unit]('s', "static-time")
|
||||
.action { (_, c) =>
|
||||
setTimeProviderType(c, TimeProviderType.Static)
|
||||
}
|
||||
.text("Use static time.")
|
||||
|
||||
opt[Long]("ttl")
|
||||
.action { (t, c) =>
|
||||
c.copy(commandTtl = Duration.ofSeconds(t))
|
||||
}
|
||||
.text("TTL in seconds used for commands emitted by the trigger. Defaults to 30s.")
|
||||
|
||||
opt[String]("access-token-file")
|
||||
.action { (f, c) =>
|
||||
c.copy(accessTokenFile = Some(Paths.get(f)))
|
||||
}
|
||||
.text(
|
||||
"File from which the access token will be read, required to interact with an authenticated ledger"
|
||||
)
|
||||
|
||||
opt[String]("application-id")
|
||||
.action { (appId, c) =>
|
||||
c.copy(applicationId =
|
||||
Some(appId).filterNot(_.isEmpty).map(Ref.ApplicationId.assertFromString)
|
||||
)
|
||||
}
|
||||
.text(s"Application ID used to submit commands. Defaults to ${DefaultApplicationId}")
|
||||
|
||||
opt[Unit]('v', "verbose")
|
||||
.text("Root logging level -> DEBUG")
|
||||
.action((_, cli) => cli.copy(rootLoggingLevel = Some(Level.DEBUG)))
|
||||
|
||||
opt[Unit]("debug")
|
||||
.text("Root logging level -> DEBUG")
|
||||
.action((_, cli) => cli.copy(rootLoggingLevel = Some(Level.DEBUG)))
|
||||
|
||||
implicit val levelRead: scopt.Read[Level] = scopt.Read.reads(Level.valueOf)
|
||||
opt[Level]("log-level-root")
|
||||
.text("Log-level of the root logger")
|
||||
.valueName("<LEVEL>")
|
||||
.action((level, cli) => cli.copy(rootLoggingLevel = Some(level)))
|
||||
|
||||
opt[String]("log-encoder")
|
||||
.text("Log encoder: plain|json")
|
||||
.action {
|
||||
case ("json", cli) => cli.copy(logEncoder = LogEncoder.Json)
|
||||
case ("plain", cli) => cli.copy(logEncoder = LogEncoder.Plain)
|
||||
case (other, _) =>
|
||||
throw new IllegalArgumentException(s"Unsupported logging encoder $other")
|
||||
}
|
||||
|
||||
opt[Long]("max-batch-size")
|
||||
.optional()
|
||||
.text(
|
||||
s"maximum number of messages processed between two high-level rule triggers. Defaults to ${DefaultTriggerRunnerConfig.maximumBatchSize}"
|
||||
)
|
||||
.action((size, cli) =>
|
||||
if (size > 0) cli.copy(triggerConfig = cli.triggerConfig.copy(maximumBatchSize = size))
|
||||
else throw new IllegalArgumentException(s"batch size must be strictly positive")
|
||||
)
|
||||
|
||||
opt[Unit]("dev-mode-unsafe")
|
||||
.action((_, c) => c.copy(compilerConfigBuilder = CompilerConfigBuilder.Dev))
|
||||
.optional()
|
||||
.text(
|
||||
"Turns on development mode. Development mode allows development versions of Daml-LF language."
|
||||
)
|
||||
.hidden()
|
||||
|
||||
implicit val majorLanguageVersionRead: scopt.Read[LanguageMajorVersion] =
|
||||
scopt.Read.reads(s =>
|
||||
LanguageMajorVersion.fromString(s) match {
|
||||
case Some(v) => v
|
||||
case None => throw new IllegalArgumentException(s"$s is not a valid major LF version")
|
||||
}
|
||||
)
|
||||
opt[LanguageMajorVersion]("lf-major-version")
|
||||
.action((v, c) => c.copy(majorLanguageVersion = v))
|
||||
.optional()
|
||||
.text(
|
||||
"The major version of LF to use."
|
||||
)
|
||||
// TODO(#17366): unhide once LF v2 has a stable version
|
||||
.hidden()
|
||||
|
||||
TlsConfigurationCli.parse(this, colSpacer = " ")((f, c) =>
|
||||
c.copy(tlsConfig = f(c.tlsConfig))
|
||||
)
|
||||
|
||||
help("help").text("Print this usage text")
|
||||
|
||||
cmd("list")
|
||||
.action((_, c) => c.copy(listTriggers = Some(false)))
|
||||
.text("List the triggers in the DAR.")
|
||||
|
||||
cmd("verbose-list")
|
||||
.hidden()
|
||||
.action((_, c) => c.copy(listTriggers = Some(true)))
|
||||
|
||||
checkConfig(c =>
|
||||
c.listTriggers match {
|
||||
case Some(_) =>
|
||||
// I do not want to break the trigger CLI and require a
|
||||
// "run" command so I can’t make these options required
|
||||
// in general. Therefore, we do this check in checkConfig.
|
||||
success
|
||||
case None =>
|
||||
if (c.triggerIdentifier == null) {
|
||||
failure("Missing option --trigger-name")
|
||||
} else if (c.ledgerHost == null) {
|
||||
failure("Missing option --ledger-host")
|
||||
} else if (c.ledgerPort == 0) {
|
||||
failure("Missing option --ledger-port")
|
||||
} else if (c.ledgerClaims == null) {
|
||||
failure("Missing option --ledger-party or --ledger-user")
|
||||
} else {
|
||||
c.ledgerClaims match {
|
||||
case PartySpecification(TriggerParties(actAs, _)) if actAs.isEmpty =>
|
||||
failure("Missing option --ledger-party")
|
||||
case _ => success
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
private def setTimeProviderType(
|
||||
config: RunnerConfig,
|
||||
timeProviderType: TimeProviderType,
|
||||
): RunnerConfig = {
|
||||
if (config.timeProviderType.exists(_ != timeProviderType)) {
|
||||
throw new IllegalStateException(
|
||||
"Static time mode (`-s`/`--static-time`) and wall-clock time mode (`-w`/`--wall-clock-time`) are mutually exclusive. The time mode must be unambiguous."
|
||||
)
|
||||
}
|
||||
config.copy(timeProviderType = Some(timeProviderType))
|
||||
}
|
||||
|
||||
def parse(args: Array[String]): Option[RunnerConfig] =
|
||||
parser.parse(
|
||||
args,
|
||||
Empty,
|
||||
)
|
||||
|
||||
val Empty: RunnerConfig = RunnerConfig(
|
||||
darPath = null,
|
||||
listTriggers = None,
|
||||
triggerIdentifier = null,
|
||||
ledgerHost = null,
|
||||
ledgerPort = 0,
|
||||
ledgerClaims = null,
|
||||
maxInboundMessageSize = DefaultMaxInboundMessageSize,
|
||||
timeProviderType = None,
|
||||
commandTtl = Duration.ofSeconds(30L),
|
||||
accessTokenFile = None,
|
||||
tlsConfig = TlsConfiguration(enabled = false, None, None, None),
|
||||
applicationId = DefaultApplicationId,
|
||||
compilerConfigBuilder = DefaultCompilerConfigBuilder,
|
||||
majorLanguageVersion = DefaultMajorLanguageVersion,
|
||||
triggerConfig = DefaultTriggerRunnerConfig,
|
||||
rootLoggingLevel = None,
|
||||
logEncoder = LogEncoder.Plain,
|
||||
)
|
||||
}
|
@ -1,135 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import java.io.File
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import org.apache.pekko.stream._
|
||||
import ch.qos.logback.classic.Level
|
||||
import com.daml.auth.TokenHolder
|
||||
import com.daml.grpc.adapter.PekkoExecutionSequencerPool
|
||||
import com.daml.ledger.client.LedgerClient
|
||||
import com.daml.ledger.client.configuration.{
|
||||
CommandClientConfiguration,
|
||||
LedgerClientChannelConfiguration,
|
||||
LedgerClientConfiguration,
|
||||
LedgerIdRequirement,
|
||||
}
|
||||
import com.daml.lf.archive.{Dar, DarDecoder}
|
||||
import com.daml.lf.data.Ref.{Identifier, PackageId, QualifiedName}
|
||||
import com.daml.lf.language.Ast._
|
||||
import com.daml.lf.language.PackageInterface
|
||||
import com.daml.scalautil.Statement.discard
|
||||
|
||||
import scala.concurrent.duration.Duration
|
||||
import scala.concurrent.{Await, ExecutionContext, Future}
|
||||
|
||||
object RunnerMain {
|
||||
|
||||
private def listTriggers(
|
||||
darPath: File,
|
||||
dar: Dar[(PackageId, Package)],
|
||||
verbose: Boolean,
|
||||
): Unit = {
|
||||
println(s"Listing triggers in $darPath:")
|
||||
val pkgInterface = PackageInterface(dar.all.toMap)
|
||||
val (mainPkgId, mainPkg) = dar.main
|
||||
for {
|
||||
mod <- mainPkg.modules.values
|
||||
defName <- mod.definitions.keys
|
||||
qualifiedName = QualifiedName(mod.name, defName)
|
||||
triggerId = Identifier(mainPkgId, qualifiedName)
|
||||
} {
|
||||
Trigger.detectTriggerDefinition(pkgInterface, triggerId).foreach {
|
||||
case TriggerDefinition(_, ty, version, level, _) =>
|
||||
if (verbose)
|
||||
println(
|
||||
s" $qualifiedName\t(type = ${ty.pretty}, level = $level, version = $version)"
|
||||
)
|
||||
else
|
||||
println(s" $qualifiedName")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
|
||||
RunnerConfig.parse(args) match {
|
||||
case None => sys.exit(1)
|
||||
case Some(config) => {
|
||||
config.rootLoggingLevel.foreach(setLoggingLevel)
|
||||
config.logEncoder match {
|
||||
case LogEncoder.Plain =>
|
||||
case LogEncoder.Json =>
|
||||
discard(System.setProperty("LOG_FORMAT_JSON", "true"))
|
||||
}
|
||||
|
||||
val dar: Dar[(PackageId, Package)] =
|
||||
DarDecoder.assertReadArchiveFromFile(config.darPath.toFile)
|
||||
|
||||
config.listTriggers.foreach { verbose =>
|
||||
listTriggers(config.darPath.toFile, dar, verbose)
|
||||
sys.exit(0)
|
||||
}
|
||||
|
||||
val triggerId: Identifier =
|
||||
Identifier(dar.main._1, QualifiedName.assertFromString(config.triggerIdentifier))
|
||||
|
||||
val system: ActorSystem = ActorSystem("TriggerRunner")
|
||||
implicit val materializer: Materializer = Materializer(system)
|
||||
val sequencer = new PekkoExecutionSequencerPool("TriggerRunnerPool")(system)
|
||||
implicit val ec: ExecutionContext = system.dispatcher
|
||||
|
||||
val tokenHolder = config.accessTokenFile.map(new TokenHolder(_))
|
||||
// We probably want to refresh the token at some point but given that triggers
|
||||
// are expected to be written such that they can be killed and restarted at
|
||||
// any time it would in principle also be fine to just have the auth failure due
|
||||
// to an expired token tear the trigger down and have some external monitoring process (e.g. systemd)
|
||||
// restart it.
|
||||
val clientConfig = LedgerClientConfiguration(
|
||||
applicationId = config.applicationId.getOrElse(""),
|
||||
ledgerIdRequirement = LedgerIdRequirement.none,
|
||||
commandClient =
|
||||
CommandClientConfiguration.default.copy(defaultDeduplicationTime = config.commandTtl),
|
||||
token = tokenHolder.flatMap(_.token),
|
||||
)
|
||||
|
||||
val channelConfig = LedgerClientChannelConfiguration(
|
||||
sslContext = config.tlsConfig.client(),
|
||||
maxInboundMessageSize = config.maxInboundMessageSize,
|
||||
)
|
||||
|
||||
val flow: Future[Unit] = for {
|
||||
client <- LedgerClient.singleHost(
|
||||
config.ledgerHost,
|
||||
config.ledgerPort,
|
||||
clientConfig,
|
||||
channelConfig,
|
||||
)(ec, sequencer)
|
||||
|
||||
parties <- config.ledgerClaims.resolveClaims(client)
|
||||
|
||||
_ <- Runner.run(
|
||||
dar,
|
||||
triggerId,
|
||||
client,
|
||||
config.timeProviderType.getOrElse(RunnerConfig.DefaultTimeProviderType),
|
||||
config.applicationId,
|
||||
parties,
|
||||
config.compilerConfigBuilder.build(config.majorLanguageVersion),
|
||||
config.triggerConfig,
|
||||
)
|
||||
} yield ()
|
||||
|
||||
flow.onComplete(_ => system.terminate())
|
||||
|
||||
Await.result(flow, Duration.Inf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def setLoggingLevel(level: Level): Unit = {
|
||||
discard(System.setProperty("LOG_LEVEL_ROOT", level.levelStr))
|
||||
}
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import org.apache.pekko.stream.{Inlet, Outlet, Shape}
|
||||
|
||||
import scala.collection.immutable.Seq
|
||||
|
||||
private[trigger] final case class SourceShape2[L, R](out1: Outlet[L], out2: Outlet[R])
|
||||
extends Shape {
|
||||
override val inlets: Seq[Inlet[_]] = Seq.empty
|
||||
override val outlets: Seq[Outlet[_]] = Seq(out1, out2)
|
||||
override def deepCopy() = copy(out1.carbonCopy(), out2.carbonCopy())
|
||||
}
|
@ -1,205 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import com.daml.lf.data.{BackStack, NoCopy}
|
||||
import com.daml.lf.engine.trigger.Runner.Implicits._
|
||||
import com.daml.lf.engine.trigger.TriggerLogContext._
|
||||
import com.daml.logging.{ContextualizedLogger, LoggingContextOf}
|
||||
import com.daml.logging.LoggingContextOf.label
|
||||
import com.daml.logging.entries.{LoggingEntries, LoggingValue}
|
||||
|
||||
import java.util.UUID
|
||||
|
||||
object ToLoggingContext {
|
||||
implicit def `TriggerLogContext to LoggingContextOf[Trigger]`(implicit
|
||||
triggerContext: TriggerLogContext
|
||||
): LoggingContextOf[Trigger] = {
|
||||
val parentEntries = if (triggerContext.span.parent.isEmpty) {
|
||||
LoggingEntries.empty
|
||||
} else if (triggerContext.span.parent.size == 1) {
|
||||
LoggingEntries("parent" -> triggerContext.span.parent.head)
|
||||
} else {
|
||||
LoggingEntries("parent" -> triggerContext.span.parent)
|
||||
}
|
||||
val spanEntries = LoggingEntries(
|
||||
"span" -> LoggingValue.Nested(
|
||||
LoggingEntries(
|
||||
"name" -> triggerContext.span.path.toImmArray.foldLeft("trigger")((path, name) =>
|
||||
s"$path.$name"
|
||||
),
|
||||
"id" -> triggerContext.span.id,
|
||||
) ++ parentEntries
|
||||
)
|
||||
)
|
||||
|
||||
LoggingContextOf
|
||||
.withEnrichedLoggingContext(
|
||||
label[Trigger],
|
||||
"trigger" -> LoggingValue.Nested(LoggingEntries(triggerContext.entries: _*) ++ spanEntries),
|
||||
)(triggerContext.loggingContext)
|
||||
.run(identity)
|
||||
}
|
||||
}
|
||||
|
||||
final class TriggerLogContext private (
|
||||
private[trigger] val loggingContext: LoggingContextOf[Trigger],
|
||||
private[trigger] val entries: Seq[(String, LoggingValue)],
|
||||
private[trigger] val span: TriggerLogSpan,
|
||||
private[trigger] val callback: (String, TriggerLogContext) => Unit,
|
||||
) extends NoCopy {
|
||||
|
||||
import ToLoggingContext._
|
||||
|
||||
def enrichTriggerContext[A](
|
||||
additionalEntries: (String, LoggingValue)*
|
||||
)(f: TriggerLogContext => A): A = {
|
||||
f(new TriggerLogContext(loggingContext, entries ++ additionalEntries, span, callback))
|
||||
}
|
||||
|
||||
def nextSpan[A](
|
||||
name: String,
|
||||
additionalEntries: (String, LoggingValue)*
|
||||
)(f: TriggerLogContext => A)(implicit
|
||||
logger: ContextualizedLogger
|
||||
): A = {
|
||||
val context = new TriggerLogContext(
|
||||
loggingContext,
|
||||
entries ++ additionalEntries,
|
||||
span.nextSpan(name),
|
||||
callback,
|
||||
)
|
||||
|
||||
try {
|
||||
context.logInfo("span entry")
|
||||
f(context)
|
||||
} finally {
|
||||
context.logInfo("span exit")
|
||||
}
|
||||
}
|
||||
|
||||
def childSpan[A](
|
||||
name: String,
|
||||
additionalEntries: (String, LoggingValue)*
|
||||
)(f: TriggerLogContext => A)(implicit
|
||||
logger: ContextualizedLogger
|
||||
): A = {
|
||||
val context = new TriggerLogContext(
|
||||
loggingContext,
|
||||
entries ++ additionalEntries,
|
||||
span.childSpan(name),
|
||||
callback,
|
||||
)
|
||||
|
||||
try {
|
||||
context.logInfo("span entry")
|
||||
f(context)
|
||||
} finally {
|
||||
context.logInfo("span exit")
|
||||
}
|
||||
}
|
||||
|
||||
def groupWith(contexts: TriggerLogContext*): TriggerLogContext = {
|
||||
val groupEntries = contexts.foldLeft(entries.toSet) { case (entries, context) =>
|
||||
entries ++ context.entries.toSet
|
||||
}
|
||||
val groupSpans = contexts.foldLeft(span) { case (span, context) =>
|
||||
span.groupWith(context.span)
|
||||
}
|
||||
|
||||
new TriggerLogContext(loggingContext, groupEntries.toSeq, groupSpans, callback)
|
||||
}
|
||||
|
||||
def logError(message: String, additionalEntries: (String, LoggingValue)*)(implicit
|
||||
logger: ContextualizedLogger
|
||||
): Unit = {
|
||||
enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
|
||||
callback(message, triggerContext)
|
||||
logger.error(message)
|
||||
}
|
||||
}
|
||||
|
||||
def logWarning(message: String, additionalEntries: (String, LoggingValue)*)(implicit
|
||||
logger: ContextualizedLogger
|
||||
): Unit = {
|
||||
enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
|
||||
callback(message, triggerContext)
|
||||
logger.warn(message)
|
||||
}
|
||||
}
|
||||
|
||||
def logInfo(message: String, additionalEntries: (String, LoggingValue)*)(implicit
|
||||
logger: ContextualizedLogger
|
||||
): Unit = {
|
||||
enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
|
||||
callback(message, triggerContext)
|
||||
logger.info(message)
|
||||
}
|
||||
}
|
||||
|
||||
def logDebug(message: String, additionalEntries: (String, LoggingValue)*)(implicit
|
||||
logger: ContextualizedLogger
|
||||
): Unit = {
|
||||
enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
|
||||
callback(message, triggerContext)
|
||||
logger.debug(message)
|
||||
}
|
||||
}
|
||||
|
||||
def logTrace(message: String, additionalEntries: (String, LoggingValue)*)(implicit
|
||||
logger: ContextualizedLogger
|
||||
): Unit = {
|
||||
enrichTriggerContext(additionalEntries: _*) { implicit triggerContext: TriggerLogContext =>
|
||||
callback(message, triggerContext)
|
||||
logger.trace(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object TriggerLogContext {
|
||||
def newRootSpan[A](
|
||||
span: String,
|
||||
entries: (String, LoggingValue)*
|
||||
)(f: TriggerLogContext => A)(implicit loggingContext: LoggingContextOf[Trigger]): A = {
|
||||
new TriggerLogContext(
|
||||
loggingContext,
|
||||
entries,
|
||||
TriggerLogSpan(BackStack(span)),
|
||||
(_, _) => (),
|
||||
).enrichTriggerContext()(f)
|
||||
}
|
||||
|
||||
private[trigger] def newRootSpanWithCallback[A](
|
||||
span: String,
|
||||
callback: (String, TriggerLogContext) => Unit,
|
||||
entries: (String, LoggingValue)*
|
||||
)(f: TriggerLogContext => A)(implicit loggingContext: LoggingContextOf[Trigger]): A = {
|
||||
new TriggerLogContext(
|
||||
loggingContext,
|
||||
entries,
|
||||
TriggerLogSpan(BackStack(span)),
|
||||
callback,
|
||||
).enrichTriggerContext()(f)
|
||||
}
|
||||
|
||||
private[trigger] final case class TriggerLogSpan(
|
||||
path: BackStack[String],
|
||||
id: UUID = UUID.randomUUID(),
|
||||
parent: Set[UUID] = Set.empty,
|
||||
) {
|
||||
def nextSpan(name: String): TriggerLogSpan = {
|
||||
val basePath = path.pop.fold(BackStack.empty[String])(_._1)
|
||||
|
||||
TriggerLogSpan(basePath :+ name, parent = parent)
|
||||
}
|
||||
|
||||
def childSpan(name: String): TriggerLogSpan = {
|
||||
TriggerLogSpan(path :+ name, parent = Set(id))
|
||||
}
|
||||
|
||||
def groupWith(span: TriggerLogSpan): TriggerLogSpan = {
|
||||
copy(parent = parent + span.id)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,83 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import scala.concurrent.duration._
|
||||
|
||||
/** Trigger hard limits. If any of these values are exceeded, then the current trigger instance will throw a
|
||||
* `TriggerHardLimitException` and stop running.
|
||||
*
|
||||
* @param maximumActiveContracts Maximum number of active contracts that we will store at any point in time.
|
||||
* @param inFlightCommandOverflowCount When the number of in-flight command submissions exceeds this value, then we
|
||||
* kill the trigger instance by throwing an InFlightCommandOverflowException.
|
||||
* @param allowInFlightCommandOverflows flag to control whether we allow in-flight command overflows or not.
|
||||
* @param ruleEvaluationTimeout If the trigger rule evaluator takes longer than this timeout value, then we throw a
|
||||
* TriggerRuleEvaluationTimeout.
|
||||
* @param stepInterpreterTimeout If the trigger rule step evaluator (during rule evaluation) takes longer than this
|
||||
* timeout value, then we throw a TriggerRuleStepInterpretationTimeout.
|
||||
* @param allowTriggerTimeouts flag to control whether we allow rule evaluation and step interpreter timeouts or not.
|
||||
*/
|
||||
final case class TriggerRunnerHardLimits(
|
||||
maximumActiveContracts: Long,
|
||||
inFlightCommandOverflowCount: Int,
|
||||
allowInFlightCommandOverflows: Boolean,
|
||||
ruleEvaluationTimeout: FiniteDuration,
|
||||
stepInterpreterTimeout: FiniteDuration,
|
||||
allowTriggerTimeouts: Boolean,
|
||||
)
|
||||
|
||||
/** @param parallelism The number of submitSingleCommand invocations each trigger will attempt to execute in parallel.
|
||||
* Note that this does not in any way bound the number of already-submitted, but not completed,
|
||||
* commands that may be pending.
|
||||
* @param maxRetries Maximum number of retries when the ledger client fails an API command submission.
|
||||
* @param maxSubmissionRequests Used to control rate at which we throttle ledger client submission requests.
|
||||
* @param maxSubmissionDuration Used to control rate at which we throttle ledger client submission requests.
|
||||
* @param inFlightCommandBackPressureCount When the number of in-flight command submissions exceeds this value, then we
|
||||
* enable Daml rule evaluation to apply backpressure (by failing emitCommands
|
||||
* evaluations).
|
||||
* @param submissionFailureQueueSize Size of the queue holding ledger API command submission failures.
|
||||
* @param maximumBatchSize Maximum number of messages triggers will batch (for rule evaluation/processing).
|
||||
* @param batchingDuration Period of time we will wait before emitting a message batch (for rule evaluation/processing).
|
||||
*/
|
||||
final case class TriggerRunnerConfig(
|
||||
parallelism: Int,
|
||||
maxRetries: Int,
|
||||
maxSubmissionRequests: Int,
|
||||
maxSubmissionDuration: FiniteDuration,
|
||||
inFlightCommandBackPressureCount: Long,
|
||||
submissionFailureQueueSize: Int,
|
||||
maximumBatchSize: Long,
|
||||
batchingDuration: FiniteDuration,
|
||||
hardLimit: TriggerRunnerHardLimits,
|
||||
)
|
||||
|
||||
object TriggerRunnerConfig {
|
||||
val DefaultTriggerRunnerConfig: TriggerRunnerConfig = {
|
||||
val parallelism = 8
|
||||
val maxSubmissionRequests = 100
|
||||
val maxSubmissionDuration = 5.seconds
|
||||
|
||||
TriggerRunnerConfig(
|
||||
parallelism = parallelism,
|
||||
maxRetries = 6,
|
||||
maxSubmissionRequests = maxSubmissionRequests,
|
||||
maxSubmissionDuration = maxSubmissionDuration,
|
||||
inFlightCommandBackPressureCount = 1000,
|
||||
// 256 here comes from the default ExecutionContext.
|
||||
submissionFailureQueueSize = 256 + parallelism,
|
||||
maximumBatchSize = 1000,
|
||||
batchingDuration = 250.milliseconds,
|
||||
hardLimit = TriggerRunnerHardLimits(
|
||||
maximumActiveContracts = 20000,
|
||||
inFlightCommandOverflowCount = 10000,
|
||||
allowInFlightCommandOverflows = true,
|
||||
// 50% extra on the maxSubmissionDuration value
|
||||
ruleEvaluationTimeout = maxSubmissionDuration * 3 / 2,
|
||||
// 50% extra on mean time between submission requests
|
||||
stepInterpreterTimeout = (maxSubmissionDuration / maxSubmissionRequests.toLong) * 3 / 2,
|
||||
allowTriggerTimeouts = false,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
@ -1,232 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import scalaz.{-\/, Bifunctor, \/, \/-}
|
||||
import scalaz.syntax.bifunctor._
|
||||
import scalaz.std.option.some
|
||||
import scalaz.std.tuple._
|
||||
import org.apache.pekko.NotUsed
|
||||
import org.apache.pekko.stream.{BidiShape, FanOutShape2, Graph, Inlet, Outlet}
|
||||
import org.apache.pekko.stream.scaladsl.{Concat, Flow, GraphDSL, Partition, Source}
|
||||
|
||||
import scala.annotation.tailrec
|
||||
|
||||
import com.daml.scalautil.Statement.discard
|
||||
|
||||
import scala.collection.immutable.{IndexedSeq, Iterable, LinearSeq}
|
||||
|
||||
/** A variant of [[scalaz.CorecursiveList]] that emits a final state
|
||||
* at the end of the list.
|
||||
*/
|
||||
private[trigger] sealed abstract class UnfoldState[T, A] {
|
||||
type S
|
||||
val init: S
|
||||
val step: S => T \/ (A, S)
|
||||
|
||||
def withInit(init: S): UnfoldState.Aux[S, T, A]
|
||||
|
||||
final def foreach(f: A => Unit): T = {
|
||||
@tailrec def go(s: S): T = step(s) match {
|
||||
case -\/(t) => t
|
||||
case \/-((a, s2)) =>
|
||||
f(a)
|
||||
go(s2)
|
||||
}
|
||||
go(init)
|
||||
}
|
||||
|
||||
private[trigger] final def iterator(): Iterator[T \/ A] =
|
||||
new Iterator[T \/ A] {
|
||||
var last = some(step(init))
|
||||
override def hasNext = last.isDefined
|
||||
override def next() = last match {
|
||||
case Some(\/-((a, s))) =>
|
||||
last = Some(step(s))
|
||||
\/-(a)
|
||||
case Some(et @ -\/(_)) =>
|
||||
last = None
|
||||
et
|
||||
case None =>
|
||||
throw new IllegalStateException("iterator read past end")
|
||||
}
|
||||
}
|
||||
|
||||
final def runTo[FA](implicit factory: collection.Factory[A, FA]): (FA, T) = {
|
||||
val b = factory.newBuilder
|
||||
val t = foreach(a => discard(b += a))
|
||||
(b.result(), t)
|
||||
}
|
||||
}
|
||||
|
||||
private[trigger] object UnfoldState {
|
||||
type Aux[S0, T, A] = UnfoldState[T, A] { type S = S0 }
|
||||
|
||||
def apply[S, T, A](init: S)(step: S => T \/ (A, S)): UnfoldState[T, A] = {
|
||||
type S0 = S
|
||||
final case class UnfoldStateImpl(init: S, step: S => T \/ (A, S)) extends UnfoldState[T, A] {
|
||||
type S = S0
|
||||
override def withInit(init: S) = copy(init = init)
|
||||
}
|
||||
UnfoldStateImpl(init, step)
|
||||
}
|
||||
|
||||
implicit def `US bifunctor instance`: Bifunctor[UnfoldState] = new Bifunctor[UnfoldState] {
|
||||
override def bimap[A, B, C, D](fab: UnfoldState[A, B])(f: A => C, g: B => D) =
|
||||
UnfoldState(fab.init)(fab.step andThen (_.bimap(f, (_ leftMap g))))
|
||||
}
|
||||
|
||||
def fromLinearSeq[A](list: LinearSeq[A]): UnfoldState[Unit, A] = {
|
||||
type Sr = Unit \/ (A, LinearSeq[A])
|
||||
apply(list) {
|
||||
case hd +: tl => \/-((hd, tl)): Sr
|
||||
case _ => -\/(()): Sr
|
||||
}
|
||||
}
|
||||
|
||||
def fromIndexedSeq[A](vector: IndexedSeq[A]): UnfoldState[Unit, A] = {
|
||||
type Sr = Unit \/ (A, Int)
|
||||
apply(0) { n =>
|
||||
if (vector.sizeIs > n) \/-((vector(n), n + 1)): Sr
|
||||
else -\/(()): Sr
|
||||
}
|
||||
}
|
||||
|
||||
implicit final class toSourceOps[T, A](private val self: SourceShape2[T, A]) {
|
||||
def elemsOut: Outlet[A] = self.out2
|
||||
def finalState: Outlet[T] = self.out1
|
||||
}
|
||||
|
||||
def toSource[T, A](us: UnfoldState[T, A]): Graph[SourceShape2[T, A], NotUsed] =
|
||||
GraphDSL.create() { implicit gb =>
|
||||
import GraphDSL.Implicits._
|
||||
val split = gb add partition[T, A]
|
||||
Source.fromIterator(() => us.iterator()) ~> split.in
|
||||
SourceShape2(split.out0, split.out1)
|
||||
}
|
||||
|
||||
/** A stateful but pure version of built-in flatMapConcat.
|
||||
* (flatMapMerge does not make sense, because parallelism
|
||||
* with linear state does not make sense.)
|
||||
*/
|
||||
def flatMapConcat[T, A, B](zero: T)(f: (T, A) => UnfoldState[T, B]): Flow[A, B, NotUsed] =
|
||||
flatMapConcatStates(zero)(f) collect { case \/-(b) => b }
|
||||
|
||||
/** Like `flatMapConcat` but emit the new state after each unfolded list.
|
||||
* The pattern you will see is a bunch of right Bs, followed by a single
|
||||
* left T, then repeat until close, with a final T unless aborted.
|
||||
*/
|
||||
def flatMapConcatStates[T, A, B](zero: T)(
|
||||
f: (T, A) => UnfoldState[T, B]
|
||||
): Flow[A, T \/ B, NotUsed] =
|
||||
Flow[A].statefulMapConcat(() => mkMapConcatFun(zero, f))
|
||||
|
||||
type UnfoldStateShape[T, -A, +B] = BidiShape[T, B, A, T]
|
||||
implicit final class flatMapConcatNodeOps[IT, B, A, OT](
|
||||
private val self: BidiShape[IT, B, A, OT]
|
||||
) {
|
||||
def initState: Inlet[IT] = self.in1
|
||||
def elemsIn: Inlet[A] = self.in2
|
||||
def elemsOut: Outlet[B] = self.out1
|
||||
def finalStates: Outlet[OT] = self.out2
|
||||
}
|
||||
|
||||
/** Accept 1 initial state on in1, fold over in2 elements, emitting the output
|
||||
* elements on out1, and each result state after each result of `f` is unfolded
|
||||
* on out2.
|
||||
*/
|
||||
def flatMapConcatNode[T, A, B](
|
||||
f: (T, A) => UnfoldState[T, B]
|
||||
): Graph[UnfoldStateShape[T, A, B], NotUsed] =
|
||||
GraphDSL.create() { implicit gb =>
|
||||
import GraphDSL.Implicits._
|
||||
val initialT = gb add (Flow fromFunction \/.left[T, A] take 1)
|
||||
val as = gb add (Flow fromFunction \/.right[T, A])
|
||||
val tas = gb add Concat[T \/ A](2) // ensure that T arrives *before* A
|
||||
val splat = gb add (Flow[T \/ A] statefulMapConcat (() => statefulMapConcatFun(f)))
|
||||
val split = gb add partition[T, B]
|
||||
// format: off
|
||||
discard { initialT ~> tas }
|
||||
discard { as ~> tas ~> splat ~> split.in }
|
||||
// format: on
|
||||
new BidiShape(initialT.in, split.out1, as.in, split.out0)
|
||||
}
|
||||
|
||||
// TODO factor with ContractsFetch
|
||||
private[this] def partition[A, B]: Graph[FanOutShape2[A \/ B, A, B], NotUsed] =
|
||||
GraphDSL.create() { implicit b =>
|
||||
import GraphDSL.Implicits._
|
||||
val split = b.add(
|
||||
Partition[A \/ B](
|
||||
2,
|
||||
{
|
||||
case -\/(_) => 0
|
||||
case \/-(_) => 1
|
||||
},
|
||||
)
|
||||
)
|
||||
val as = b.add(Flow[A \/ B].collect { case -\/(a) => a })
|
||||
val bs = b.add(Flow[A \/ B].collect { case \/-(b) => b })
|
||||
discard { split ~> as }
|
||||
discard { split ~> bs }
|
||||
new FanOutShape2(split.in, as.out, bs.out)
|
||||
}
|
||||
|
||||
private[this] type FoldL[T, -A, B] = (T, A) => UnfoldState[T, B]
|
||||
|
||||
private[this] def statefulMapConcatFun[T, A, B](f: FoldL[T, A, B]): T \/ A => Iterable[T \/ B] = {
|
||||
var mcFun: A => Iterable[T \/ B] = null
|
||||
_.fold(
|
||||
zeroT => {
|
||||
mcFun = mkMapConcatFun(zeroT, f)
|
||||
Iterable.empty
|
||||
},
|
||||
{ a =>
|
||||
mcFun(a)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
private[this] def mkMapConcatFun[T, A, B](zero: T, f: FoldL[T, A, B]): A => Iterable[T \/ B] = {
|
||||
var t = zero
|
||||
// statefulMapConcat only uses 'iterator'. We preserve the Iterable's
|
||||
// immutability by making one strict reference to the 't' var at 'Iterable' creation
|
||||
// time, meaning any later 'iterator' call uses the same start state, no matter
|
||||
// whether the 't' has been updated
|
||||
a =>
|
||||
new Iterable[T \/ B] {
|
||||
private[this] val bs = f(t, a)
|
||||
import bs.step
|
||||
override def iterator = new Iterator[T \/ B] {
|
||||
private[this] var last: Option[T \/ (B, bs.S)] = {
|
||||
val fst = step(bs.init)
|
||||
fst.fold(newT => t = newT, _ => ())
|
||||
Some(fst)
|
||||
}
|
||||
|
||||
// this stream is "odd", i.e. we are always evaluating 1 step ahead
|
||||
// of what the client sees. We could improve laziness by making it
|
||||
// "even", but it would be a little trickier, as `hasNext` would have
|
||||
// a forcing side-effect
|
||||
override def hasNext = last.isDefined
|
||||
|
||||
override def next() =
|
||||
last match {
|
||||
case Some(\/-((b, s))) =>
|
||||
val next = step(s)
|
||||
// The assumption here is that statefulMapConcat's implementation
|
||||
// will always read iterator to end before invoking on the next A
|
||||
next.fold(newT => t = newT, _ => ())
|
||||
last = Some(next)
|
||||
\/-(b)
|
||||
case Some(et @ -\/(_)) =>
|
||||
last = None
|
||||
et
|
||||
case None =>
|
||||
throw new IllegalStateException("iterator read past end")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,416 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
load("@oracle//:index.bzl", "oracle_tags")
|
||||
load("@build_environment//:configuration.bzl", "sdk_version")
|
||||
load(
|
||||
"//daml-lf/language:daml-lf.bzl",
|
||||
"LF_DEV_VERSIONS",
|
||||
"LF_MAJOR_VERSIONS",
|
||||
"lf_version_default_or_latest",
|
||||
"lf_versions_aggregate",
|
||||
)
|
||||
load("@os_info//:os_info.bzl", "is_windows")
|
||||
load(
|
||||
"//bazel_tools:scala.bzl",
|
||||
"da_scala_binary",
|
||||
"da_scala_library",
|
||||
"da_scala_test_suite",
|
||||
"lf_scalacopts_stricter",
|
||||
)
|
||||
|
||||
TRIGGER_MAIN = "src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceMain.scala"
|
||||
|
||||
target_lf_versions = lf_versions_aggregate(
|
||||
[lf_version_default_or_latest(major) for major in LF_MAJOR_VERSIONS] +
|
||||
LF_DEV_VERSIONS,
|
||||
)
|
||||
|
||||
da_scala_library(
|
||||
name = "trigger-service",
|
||||
srcs = glob(
|
||||
["src/main/scala/**/*.scala"],
|
||||
exclude = [TRIGGER_MAIN],
|
||||
),
|
||||
resources = glob(["src/main/resources/**/*"]),
|
||||
scala_deps = [
|
||||
"@maven//:com_chuusai_shapeless",
|
||||
"@maven//:com_github_pureconfig_pureconfig_core",
|
||||
"@maven//:com_github_pureconfig_pureconfig_generic",
|
||||
"@maven//:com_github_scopt_scopt",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_actor_typed",
|
||||
"@maven//:org_apache_pekko_pekko_http",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_http_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
"@maven//:org_tpolecat_doobie_core",
|
||||
"@maven//:org_tpolecat_doobie_free",
|
||||
"@maven//:org_tpolecat_doobie_postgres",
|
||||
"@maven//:org_typelevel_cats_core",
|
||||
"@maven//:org_typelevel_cats_effect",
|
||||
"@maven//:org_typelevel_cats_free",
|
||||
"@maven//:org_typelevel_cats_kernel",
|
||||
],
|
||||
scala_runtime_deps = [
|
||||
"@maven//:org_apache_pekko_pekko_slf4j",
|
||||
"@maven//:org_tpolecat_doobie_postgres",
|
||||
],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
# Uncomment this if/when the target is published to maven.
|
||||
# tags = ["maven_coordinates=com.daml:trigger-service:__VERSION__"],
|
||||
visibility = ["//visibility:public"],
|
||||
runtime_deps = [
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:ch_qos_logback_logback_core",
|
||||
"@maven//:org_postgresql_postgresql",
|
||||
],
|
||||
deps = [
|
||||
"//canton:ledger_api_proto_scala",
|
||||
"//daml-lf/archive:daml_lf_archive_reader",
|
||||
"//daml-lf/archive:daml_lf_dev_archive_proto_java",
|
||||
"//daml-lf/data",
|
||||
"//daml-lf/engine",
|
||||
"//daml-lf/interpreter",
|
||||
"//daml-lf/language",
|
||||
"//ledger-service/cli-opts",
|
||||
"//ledger-service/metrics",
|
||||
"//ledger-service/pureconfig-utils",
|
||||
"//ledger/ledger-api-client",
|
||||
"//ledger/ledger-api-common",
|
||||
"//libs-scala/contextualized-logging",
|
||||
"//libs-scala/db-utils",
|
||||
"//libs-scala/doobie-slf4j",
|
||||
"//libs-scala/ledger-resources",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/rs-grpc-bridge",
|
||||
"//libs-scala/rs-grpc-pekko",
|
||||
"//libs-scala/scala-utils",
|
||||
"//observability/metrics",
|
||||
"//observability/pekko-http-metrics",
|
||||
"//triggers/metrics",
|
||||
"//triggers/runner:trigger-runner-lib",
|
||||
"//triggers/service/auth:middleware-api",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:com_typesafe_config",
|
||||
"@maven//:io_dropwizard_metrics_metrics_core",
|
||||
"@maven//:io_netty_netty_handler",
|
||||
"@maven//:io_opentelemetry_opentelemetry_api",
|
||||
"@maven//:org_flywaydb_flyway_core",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
],
|
||||
)
|
||||
|
||||
scala_binary_deps = [
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_actor_typed",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
]
|
||||
|
||||
binary_deps = [
|
||||
"//daml-lf/archive:daml_lf_archive_reader",
|
||||
"//daml-lf/archive:daml_lf_dev_archive_proto_java",
|
||||
"//daml-lf/data",
|
||||
"//daml-lf/interpreter",
|
||||
"//ledger/ledger-api-common",
|
||||
"//libs-scala/contextualized-logging",
|
||||
"//libs-scala/db-utils",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/scala-utils",
|
||||
"//observability/metrics",
|
||||
"//triggers/runner:trigger-runner-lib",
|
||||
"//triggers/service/auth:middleware-api",
|
||||
":trigger-service",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:com_google_protobuf_protobuf_java",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
]
|
||||
|
||||
trigger_service_runtime_deps = {
|
||||
"ce": [],
|
||||
"ee": ["@maven//:com_oracle_database_jdbc_ojdbc8"],
|
||||
}
|
||||
|
||||
[
|
||||
da_scala_binary(
|
||||
name = "trigger-service-binary-{}".format(edition),
|
||||
srcs = [TRIGGER_MAIN],
|
||||
main_class = "com.daml.lf.engine.trigger.ServiceMain",
|
||||
resource_strip_prefix = "triggers/service/release/trigger-service-",
|
||||
resources = ["release/trigger-service-logback.xml"],
|
||||
scala_deps = scala_binary_deps,
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
tags = ["ee-jar-license"] if edition == "ee" else [],
|
||||
visibility = ["//visibility:public"],
|
||||
runtime_deps = trigger_service_runtime_deps.get(edition),
|
||||
deps = binary_deps + [
|
||||
"//runtime-components/jdbc-drivers:jdbc-drivers-{}".format(edition),
|
||||
],
|
||||
)
|
||||
for edition in [
|
||||
"ce",
|
||||
"ee",
|
||||
]
|
||||
]
|
||||
|
||||
da_scala_library(
|
||||
name = "trigger-service-tests",
|
||||
srcs = glob(["src/test/scala/com/digitalasset/daml/lf/engine/trigger/*.scala"]),
|
||||
data = [
|
||||
":test-model-{}.dar".format(lf_version)
|
||||
for lf_version in target_lf_versions
|
||||
] + [
|
||||
":test-model-v{}.dar".format(major)
|
||||
for major in LF_MAJOR_VERSIONS
|
||||
] + (
|
||||
[
|
||||
"@toxiproxy_dev_env//:bin/toxiproxy-server",
|
||||
] if not is_windows else [
|
||||
"@toxiproxy_dev_env//:toxiproxy-server-windows-amd64.exe",
|
||||
]
|
||||
),
|
||||
resources = glob(["src/test/resources/**/*"]),
|
||||
scala_deps = [
|
||||
"@maven//:com_lihaoyi_sourcecode",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_actor_typed",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalactic_scalactic",
|
||||
"@maven//:org_scalatest_scalatest_core",
|
||||
"@maven//:org_scalatest_scalatest_flatspec",
|
||||
"@maven//:org_scalatest_scalatest_matchers_core",
|
||||
"@maven//:org_scalatest_scalatest_shouldmatchers",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
],
|
||||
visibility = ["//test-evidence:__pkg__"],
|
||||
deps = [
|
||||
":trigger-service",
|
||||
":trigger-service-binary-ce",
|
||||
"//bazel_tools/runfiles:scala_runfiles",
|
||||
"//canton:ledger_api_proto_scala",
|
||||
"//daml-lf/archive:daml_lf_archive_reader",
|
||||
"//daml-lf/archive:daml_lf_dev_archive_proto_java",
|
||||
"//daml-lf/data",
|
||||
"//daml-lf/interpreter",
|
||||
"//daml-lf/language",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//ledger/ledger-api-client",
|
||||
"//ledger/ledger-api-common",
|
||||
"//ledger/ledger-api-domain",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/db-utils",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/ledger-resources",
|
||||
"//libs-scala/oracle-testing",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/ports:ports-testing",
|
||||
"//libs-scala/postgresql-testing",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/rs-grpc-bridge",
|
||||
"//libs-scala/scala-utils",
|
||||
"//libs-scala/test-evidence/scalatest:test-evidence-scalatest",
|
||||
"//libs-scala/test-evidence/tag:test-evidence-tag",
|
||||
"//libs-scala/testing-utils",
|
||||
"//libs-scala/timer-utils",
|
||||
"//observability/metrics",
|
||||
"//test-common/canton/it-lib",
|
||||
"//triggers/runner:trigger-runner-lib",
|
||||
"//triggers/service/auth:middleware-api",
|
||||
"//triggers/service/auth:oauth2-middleware",
|
||||
"//triggers/service/auth:oauth2-test-server",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:ch_qos_logback_logback_core",
|
||||
"@maven//:com_auth0_java_jwt",
|
||||
"@maven//:com_google_protobuf_protobuf_java",
|
||||
"@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_7",
|
||||
"@maven//:org_scalatest_scalatest_compatible",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_test_suite(
|
||||
name = "test",
|
||||
timeout = "long",
|
||||
srcs = glob(
|
||||
["src/test-suite/scala/**/*.scala"],
|
||||
exclude = ["**/*Oracle*"],
|
||||
),
|
||||
data = [
|
||||
":src/test-suite/resources/trigger-service.conf",
|
||||
":src/test-suite/resources/trigger-service-minimal.conf",
|
||||
],
|
||||
scala_deps = [
|
||||
"@maven//:com_github_pureconfig_pureconfig_core",
|
||||
"@maven//:com_github_scopt_scopt",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalactic_scalactic",
|
||||
"@maven//:org_scalatest_scalatest_core",
|
||||
"@maven//:org_scalatest_scalatest_flatspec",
|
||||
"@maven//:org_scalatest_scalatest_matchers_core",
|
||||
"@maven//:org_scalatest_scalatest_shouldmatchers",
|
||||
],
|
||||
visibility = ["//test-evidence:__pkg__"],
|
||||
deps = [
|
||||
":trigger-service",
|
||||
":trigger-service-tests",
|
||||
"//bazel_tools/runfiles:scala_runfiles",
|
||||
"//canton:ledger_api_proto_scala",
|
||||
"//daml-lf/archive:daml_lf_archive_reader",
|
||||
"//daml-lf/archive:daml_lf_dev_archive_proto_java",
|
||||
"//daml-lf/data",
|
||||
"//daml-lf/interpreter",
|
||||
"//daml-lf/language",
|
||||
"//ledger-service/cli-opts",
|
||||
"//ledger-service/pureconfig-utils",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//ledger/ledger-api-client",
|
||||
"//ledger/ledger-api-common",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/db-utils",
|
||||
"//libs-scala/flyway-testing",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/ledger-resources",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/postgresql-testing",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/rs-grpc-bridge",
|
||||
"//libs-scala/test-evidence/tag:test-evidence-tag",
|
||||
"//libs-scala/testing-utils",
|
||||
"//libs-scala/timer-utils",
|
||||
"//test-common/canton/it-lib",
|
||||
"//triggers/runner:trigger-runner-lib",
|
||||
"//triggers/service/auth:middleware-api",
|
||||
"//triggers/service/auth:oauth2-test-server",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:ch_qos_logback_logback_core",
|
||||
"@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_7",
|
||||
"@maven//:org_flywaydb_flyway_core",
|
||||
"@maven//:org_scalatest_scalatest_compatible",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_test_suite(
|
||||
name = "test-oracle",
|
||||
timeout = "long",
|
||||
srcs = glob(["src/test-suite/scala/**/*Oracle*.scala"]),
|
||||
scala_deps = [
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_scalatest_scalatest_core",
|
||||
"@maven//:org_scalatest_scalatest_matchers_core",
|
||||
"@maven//:org_scalatest_scalatest_shouldmatchers",
|
||||
],
|
||||
tags = oracle_tags,
|
||||
runtime_deps = [
|
||||
"@maven//:com_oracle_database_jdbc_ojdbc8",
|
||||
],
|
||||
deps = [
|
||||
":trigger-service",
|
||||
":trigger-service-tests",
|
||||
"//canton:ledger_api_proto_scala",
|
||||
"//daml-lf/archive:daml_lf_archive_reader",
|
||||
"//daml-lf/archive:daml_lf_dev_archive_proto_java",
|
||||
"//daml-lf/data",
|
||||
"//daml-lf/language",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//ledger/ledger-api-client",
|
||||
"//ledger/ledger-api-common",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/db-utils",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/ledger-resources",
|
||||
"//libs-scala/oracle-testing",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/rs-grpc-bridge",
|
||||
"//libs-scala/test-evidence/tag:test-evidence-tag",
|
||||
"//libs-scala/testing-utils",
|
||||
"//test-common/canton/it-lib",
|
||||
"//triggers/runner:trigger-runner-lib",
|
||||
"//triggers/service/auth:oauth2-test-server",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:ch_qos_logback_logback_core",
|
||||
"@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_7",
|
||||
"@maven//:org_scalatest_scalatest_compatible",
|
||||
],
|
||||
)
|
||||
|
||||
# Build one DAR per LF version to bundle with the SDK.
|
||||
[
|
||||
genrule(
|
||||
name = "test-model-{}".format(lf_version),
|
||||
srcs =
|
||||
glob(["test-model/*.daml"]) + [
|
||||
"//triggers/daml:daml-trigger-{}".format(lf_version),
|
||||
"//daml-script/daml:daml-script-{}".format(lf_version),
|
||||
],
|
||||
outs = ["test-model-{}.dar".format(lf_version)],
|
||||
cmd = """
|
||||
set -eou pipefail
|
||||
TMP_DIR=$$(mktemp -d)
|
||||
mkdir -p $$TMP_DIR/daml
|
||||
cp -L $(location :test-model/TestTrigger.daml) $$TMP_DIR/daml
|
||||
cp -L $(location :test-model/ErrorTrigger.daml) $$TMP_DIR/daml
|
||||
cp -L $(location :test-model/LowLevelErrorTrigger.daml) $$TMP_DIR/daml
|
||||
cp -L $(location :test-model/ReadAs.daml) $$TMP_DIR/daml
|
||||
cp -L $(location :test-model/Cats.daml) $$TMP_DIR/daml
|
||||
cp -L $(location {daml_trigger}) $$TMP_DIR/daml-trigger.dar
|
||||
cp -L $(location {daml_script}) $$TMP_DIR/daml-script.dar
|
||||
cat << EOF > $$TMP_DIR/daml.yaml
|
||||
sdk-version: {sdk}
|
||||
name: test-model
|
||||
source: daml
|
||||
version: 0.0.1
|
||||
dependencies:
|
||||
- daml-stdlib
|
||||
- daml-prim
|
||||
- daml-trigger.dar
|
||||
- daml-script.dar
|
||||
build-options: {build_options}
|
||||
EOF
|
||||
$(location //compiler/damlc) build --project-root=$$TMP_DIR --ghc-option=-Werror -o $$PWD/$@
|
||||
rm -rf $$TMP_DIR
|
||||
""".format(
|
||||
build_options = str([
|
||||
"--target",
|
||||
lf_version,
|
||||
]),
|
||||
daml_script = "//daml-script/daml:daml-script-{}".format(lf_version),
|
||||
daml_trigger = "//triggers/daml:daml-trigger-{}".format(lf_version),
|
||||
sdk = sdk_version,
|
||||
),
|
||||
tools = ["//compiler/damlc"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
for lf_version in target_lf_versions
|
||||
]
|
||||
|
||||
[
|
||||
genrule(
|
||||
name = "test-model-{}".format(major),
|
||||
srcs = [":test-model-{}".format(lf_version_default_or_latest(major))],
|
||||
outs = ["test-model-v{}.dar".format(major)],
|
||||
cmd = "cp -L $(location :test-model-{}) $$PWD/$@".format(lf_version_default_or_latest(major)),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
for major in LF_MAJOR_VERSIONS
|
||||
]
|
||||
|
||||
exports_files(["release/trigger-service-logback.xml"])
|
@ -1,271 +0,0 @@
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
load(
|
||||
"//bazel_tools:scala.bzl",
|
||||
"da_scala_binary",
|
||||
"da_scala_library",
|
||||
"da_scala_test",
|
||||
"lf_scalacopts_stricter",
|
||||
)
|
||||
|
||||
exports_files(["release/oauth2-middleware-logback.xml"])
|
||||
|
||||
test_scalacopts = ["-P:wartremover:traverser:org.wartremover.warts.OptionPartial"]
|
||||
|
||||
da_scala_library(
|
||||
name = "jwt-cli-opts",
|
||||
srcs = ["src/main/scala/com/daml/JwtVerifierConfigurationCli.scala"],
|
||||
scala_deps = [
|
||||
"@maven//:com_github_scopt_scopt",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
deps = [
|
||||
"//libs-scala/jwt",
|
||||
"@maven//:com_auth0_java_jwt",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_test(
|
||||
name = "jwt-cli-opts-tests",
|
||||
srcs = ["src/test/scala/com/daml/JwtVerifierConfigurationCliSpec.scala"],
|
||||
scala_deps = [
|
||||
"@maven//:com_github_scopt_scopt",
|
||||
"@maven//:org_scalactic_scalactic",
|
||||
"@maven//:org_scalatest_scalatest_core",
|
||||
"@maven//:org_scalatest_scalatest_matchers_core",
|
||||
"@maven//:org_scalatest_scalatest_shouldmatchers",
|
||||
"@maven//:org_scalatest_scalatest_wordspec",
|
||||
],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
deps = [
|
||||
":jwt-cli-opts",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//libs-scala/fs-utils",
|
||||
"//libs-scala/http-test-utils",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/scala-utils",
|
||||
"@maven//:com_auth0_java_jwt",
|
||||
"@maven//:io_grpc_grpc_api",
|
||||
"@maven//:org_bouncycastle_bcpkix_jdk15on",
|
||||
"@maven//:org_bouncycastle_bcprov_jdk15on",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_library(
|
||||
name = "oauth2-api",
|
||||
srcs = glob(["src/main/scala/com/daml/auth/oauth2/api/**/*.scala"]),
|
||||
scala_deps = [
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
da_scala_library(
|
||||
name = "middleware-api",
|
||||
srcs = glob(["src/main/scala/com/daml/auth/middleware/api/**/*.scala"]),
|
||||
scala_deps = [
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_http_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//daml-lf/data",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_library(
|
||||
name = "oauth2-middleware",
|
||||
srcs = glob(["src/main/scala/com/daml/auth/middleware/oauth2/**/*.scala"]),
|
||||
resources = glob(["src/main/resources/com/daml/auth/middleware/oauth2/**"]),
|
||||
scala_deps = [
|
||||
"@maven//:com_chuusai_shapeless",
|
||||
"@maven//:com_github_pureconfig_pureconfig_core",
|
||||
"@maven//:com_github_pureconfig_pureconfig_generic",
|
||||
"@maven//:com_github_scopt_scopt",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_http_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
"@triggers_maven//:com_lihaoyi_fastparse",
|
||||
"@triggers_maven//:com_lihaoyi_geny",
|
||||
"@triggers_maven//:com_lihaoyi_os_lib",
|
||||
"@triggers_maven//:com_lihaoyi_sjsonnet",
|
||||
"@triggers_maven//:com_lihaoyi_ujson",
|
||||
"@triggers_maven//:com_lihaoyi_upickle_core",
|
||||
],
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":jwt-cli-opts",
|
||||
":middleware-api",
|
||||
":oauth2-api",
|
||||
"//daml-lf/data",
|
||||
"//ledger-service/cli-opts",
|
||||
"//ledger-service/metrics",
|
||||
"//ledger-service/pureconfig-utils",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//ledger/ledger-api-common",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/ledger-resources",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/scala-utils",
|
||||
"//observability/metrics",
|
||||
"//observability/pekko-http-metrics",
|
||||
"@maven//:com_typesafe_config",
|
||||
"@maven//:io_dropwizard_metrics_metrics_core",
|
||||
"@maven//:io_opentelemetry_opentelemetry_api",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_binary(
|
||||
name = "oauth2-middleware-binary",
|
||||
main_class = "com.daml.auth.middleware.oauth2.Main",
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
visibility = ["//visibility:public"],
|
||||
runtime_deps = [
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
],
|
||||
deps = [
|
||||
":oauth2-middleware",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_library(
|
||||
name = "oauth2-test-server",
|
||||
srcs = glob(["src/main/scala/com/daml/auth/oauth2/test/server/**/*.scala"]),
|
||||
scala_deps = [
|
||||
"@maven//:com_github_scopt_scopt",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_http_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
],
|
||||
visibility = ["//triggers/service:__subpackages__"],
|
||||
deps = [
|
||||
":oauth2-api",
|
||||
"//daml-lf/data",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/ports",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_binary(
|
||||
name = "oauth2-test-server-binary",
|
||||
main_class = "com.daml.auth.oauth2.test.server.Main",
|
||||
scalacopts = lf_scalacopts_stricter,
|
||||
runtime_deps = [
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
],
|
||||
deps = [
|
||||
":oauth2-test-server",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_test(
|
||||
name = "oauth2-test-server-tests",
|
||||
srcs = glob(["src/test/scala/com/daml/auth/oauth2/test/server/**/*.scala"]),
|
||||
scala_deps = [
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_http_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
],
|
||||
scalacopts = test_scalacopts,
|
||||
deps = [
|
||||
":oauth2-api",
|
||||
":oauth2-test-server",
|
||||
"//daml-lf/data",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/ledger-resources",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/rs-grpc-bridge",
|
||||
"//libs-scala/testing-utils",
|
||||
],
|
||||
)
|
||||
|
||||
da_scala_test(
|
||||
name = "oauth2-middleware-tests",
|
||||
srcs = glob(["src/test/scala/com/daml/auth/middleware/oauth2/**/*.scala"]),
|
||||
data = [
|
||||
":src/test/resources/oauth2-middleware.conf",
|
||||
":src/test/resources/oauth2-middleware-minimal.conf",
|
||||
],
|
||||
scala_deps = [
|
||||
"@maven//:com_github_pureconfig_pureconfig_core",
|
||||
"@maven//:com_lihaoyi_sourcecode",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging",
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_actor",
|
||||
"@maven//:org_apache_pekko_pekko_http",
|
||||
"@maven//:org_apache_pekko_pekko_http_core",
|
||||
"@maven//:org_apache_pekko_pekko_http_spray_json",
|
||||
"@maven//:org_apache_pekko_pekko_http_testkit",
|
||||
"@maven//:org_apache_pekko_pekko_stream",
|
||||
"@maven//:org_parboiled_parboiled",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
],
|
||||
scala_runtime_deps = [
|
||||
"@maven//:org_apache_pekko_pekko_stream_testkit",
|
||||
],
|
||||
visibility = ["//test-evidence:__pkg__"],
|
||||
deps = [
|
||||
":middleware-api",
|
||||
":oauth2-api",
|
||||
":oauth2-middleware",
|
||||
":oauth2-test-server",
|
||||
"//bazel_tools/runfiles:scala_runfiles",
|
||||
"//daml-lf/data",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/jwt",
|
||||
"//libs-scala/ledger-resources",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/rs-grpc-bridge",
|
||||
"//libs-scala/scala-utils",
|
||||
"//libs-scala/test-evidence/scalatest:test-evidence-scalatest",
|
||||
"//libs-scala/test-evidence/tag:test-evidence-tag",
|
||||
"//libs-scala/testing-utils",
|
||||
"//observability/metrics",
|
||||
"@maven//:com_auth0_java_jwt",
|
||||
"@maven//:com_typesafe_config",
|
||||
],
|
||||
)
|
@ -1,25 +0,0 @@
|
||||
<configuration>
|
||||
<appender name="console" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<if condition='isDefined("LOG_FORMAT_JSON")'>
|
||||
<then>
|
||||
<encoder class="net.logstash.logback.encoder.LogstashEncoder"/>
|
||||
</then>
|
||||
<else>
|
||||
<encoder>
|
||||
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
|
||||
</encoder>
|
||||
</else>
|
||||
</if>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="net.logstash.logback.appender.LoggingEventAsyncDisruptorAppender">
|
||||
<appender-ref ref="console"/>
|
||||
</appender>
|
||||
|
||||
<logger name="io.netty" level="WARN"/>
|
||||
<logger name="io.grpc.netty" level="WARN"/>
|
||||
|
||||
<root level="${LOG_LEVEL_ROOT:-INFO}">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</root>
|
||||
</configuration>
|
@ -1,15 +0,0 @@
|
||||
local scope(claims) =
|
||||
local admin = if claims.admin then "admin";
|
||||
local applicationId = if claims.applicationId != null then "applicationId:" + claims.applicationId;
|
||||
local actAs = std.map(function(p) "actAs:" + p, claims.actAs);
|
||||
local readAs = std.map(function(p) "readAs:" + p, claims.readAs);
|
||||
[admin, applicationId] + actAs + readAs;
|
||||
|
||||
function(config, request) {
|
||||
"audience": "https://daml.com/ledger-api",
|
||||
"client_id": config.clientId,
|
||||
"redirect_uri": request.redirectUri,
|
||||
"response_type": "code",
|
||||
"scope": std.join(" ", ["offline_access"] + scope(request.claims)),
|
||||
"state": request.state,
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
function(config, request) {
|
||||
"client_id": config.clientId,
|
||||
"client_secret": config.clientSecret,
|
||||
"grant_type": "refresh_code",
|
||||
"refresh_token": request.refreshToken,
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
function(config, request) {
|
||||
"client_id": config.clientId,
|
||||
"client_secret": config.clientSecret,
|
||||
"code": request.code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": request.redirectUri,
|
||||
}
|
@ -1,95 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.jwt
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
import com.auth0.jwt.algorithms.Algorithm
|
||||
|
||||
import scala.util.Try
|
||||
|
||||
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
|
||||
object JwtVerifierConfigurationCli {
|
||||
def parse[C](parser: scopt.OptionParser[C])(setter: (JwtVerifierBase, C) => C): Unit = {
|
||||
def setJwtVerifier(jwtVerifier: JwtVerifierBase, c: C): C = setter(jwtVerifier, c)
|
||||
|
||||
import parser.opt
|
||||
|
||||
opt[String]("auth-jwt-hs256-unsafe")
|
||||
.optional()
|
||||
.hidden()
|
||||
.validate(v => Either.cond(v.nonEmpty, (), "HMAC secret must be a non-empty string"))
|
||||
.text(
|
||||
"[UNSAFE] Enables JWT-based authorization with shared secret HMAC256 signing: USE THIS EXCLUSIVELY FOR TESTING"
|
||||
)
|
||||
.action { (secret, config) =>
|
||||
val verifier = HMAC256Verifier(secret)
|
||||
.valueOr(err => sys.error(s"Failed to create HMAC256 verifier: $err"))
|
||||
setJwtVerifier(verifier, config)
|
||||
}
|
||||
|
||||
opt[String]("auth-jwt-rs256-crt")
|
||||
.optional()
|
||||
.validate(
|
||||
validatePath(_, "The certificate file specified via --auth-jwt-rs256-crt does not exist")
|
||||
)
|
||||
.text(
|
||||
"Enables JWT-based authorization, where the JWT is signed by RSA256 with a public key loaded from the given X509 certificate file (.crt)"
|
||||
)
|
||||
.action { (path, config) =>
|
||||
val verifier = RSA256Verifier
|
||||
.fromCrtFile(path)
|
||||
.valueOr(err => sys.error(s"Failed to create RSA256 verifier: $err"))
|
||||
setJwtVerifier(verifier, config)
|
||||
}
|
||||
|
||||
opt[String]("auth-jwt-es256-crt")
|
||||
.optional()
|
||||
.validate(
|
||||
validatePath(_, "The certificate file specified via --auth-jwt-es256-crt does not exist")
|
||||
)
|
||||
.text(
|
||||
"Enables JWT-based authorization, where the JWT is signed by ECDSA256 with a public key loaded from the given X509 certificate file (.crt)"
|
||||
)
|
||||
.action { (path, config) =>
|
||||
val verifier = ECDSAVerifier
|
||||
.fromCrtFile(path, Algorithm.ECDSA256(_, null))
|
||||
.valueOr(err => sys.error(s"Failed to create ECDSA256 verifier: $err"))
|
||||
setJwtVerifier(verifier, config)
|
||||
}
|
||||
|
||||
opt[String]("auth-jwt-es512-crt")
|
||||
.optional()
|
||||
.validate(
|
||||
validatePath(_, "The certificate file specified via --auth-jwt-es512-crt does not exist")
|
||||
)
|
||||
.text(
|
||||
"Enables JWT-based authorization, where the JWT is signed by ECDSA512 with a public key loaded from the given X509 certificate file (.crt)"
|
||||
)
|
||||
.action { (path, config) =>
|
||||
val verifier = ECDSAVerifier
|
||||
.fromCrtFile(path, Algorithm.ECDSA512(_, null))
|
||||
.valueOr(err => sys.error(s"Failed to create ECDSA512 verifier: $err"))
|
||||
setJwtVerifier(verifier, config)
|
||||
}
|
||||
|
||||
opt[String]("auth-jwt-rs256-jwks")
|
||||
.optional()
|
||||
.validate(v => Either.cond(v.length > 0, (), "JWK server URL must be a non-empty string"))
|
||||
.text(
|
||||
"Enables JWT-based authorization, where the JWT is signed by RSA256 with a public key loaded from the given JWKS URL"
|
||||
)
|
||||
.action { (url, config) =>
|
||||
val verifier = JwksVerifier(url)
|
||||
setJwtVerifier(verifier, config)
|
||||
}
|
||||
|
||||
()
|
||||
}
|
||||
|
||||
private def validatePath(path: String, message: String): Either[String, Unit] = {
|
||||
val valid = Try(Paths.get(path).toFile.canRead).getOrElse(false)
|
||||
if (valid) Right(()) else Left(message)
|
||||
}
|
||||
}
|
@ -1,376 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.api
|
||||
|
||||
import java.util.UUID
|
||||
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import org.apache.pekko.http.scaladsl.Http
|
||||
import org.apache.pekko.http.scaladsl.marshalling.Marshal
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import org.apache.pekko.http.scaladsl.model.Uri.Path
|
||||
import org.apache.pekko.http.scaladsl.model.{
|
||||
HttpMethods,
|
||||
HttpRequest,
|
||||
MediaTypes,
|
||||
RequestEntity,
|
||||
StatusCode,
|
||||
StatusCodes,
|
||||
Uri,
|
||||
headers,
|
||||
}
|
||||
import org.apache.pekko.http.scaladsl.server.{
|
||||
ContentNegotiator,
|
||||
Directive,
|
||||
Directive0,
|
||||
Directive1,
|
||||
Route,
|
||||
StandardRoute,
|
||||
}
|
||||
import org.apache.pekko.http.scaladsl.server.Directives._
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshal
|
||||
import com.daml.auth.middleware.api.Client.{AuthException, RedirectToLogin, RefreshException}
|
||||
import com.daml.auth.middleware.api.Tagged.RefreshToken
|
||||
|
||||
import scala.collection.immutable
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
|
||||
/** Client component for interaction with the auth middleware
|
||||
*
|
||||
* A client of the auth middleware is typically itself a web-application that serves HTTP requests.
|
||||
* Note, a [[Client]] maintains state that needs to persist across such requests.
|
||||
* In particular, you should not create the [[Client]] instance within a [[Route]].
|
||||
*
|
||||
* This may pose a challenge when the client application uses dynamic port binding,
|
||||
* e.g. for testing purposes, as the login URI's redirect parameter may depend on the port.
|
||||
* To that end the login and request handler components that may depend on the port
|
||||
* are provided in a separate [[com.daml.auth.middleware.api.Client.Routes]] class
|
||||
* that does not need to maintain state across requests and can safely be constructed
|
||||
* within a [[Route]] using the [[routes]] family of methods.
|
||||
*/
|
||||
class Client(config: Client.Config) {
|
||||
private val callbacks: RequestStore[UUID, Response.Login => Route] = new RequestStore(
|
||||
config.maxAuthCallbacks,
|
||||
config.authCallbackTimeout,
|
||||
)
|
||||
|
||||
/** Create a [[Client.Routes]] based on an absolute login redirect URI.
|
||||
*/
|
||||
def routes(callbackUri: Uri): Client.Routes = {
|
||||
assert(
|
||||
callbackUri.isAbsolute,
|
||||
"The authorization middleware client callback URI must be absolute.",
|
||||
)
|
||||
RoutesImpl(callbackUri)
|
||||
}
|
||||
|
||||
/** Create a [[Client.Routes]] based on a path to be appended to the requests URI scheme and authority.
|
||||
*
|
||||
* E.g. given `callbackPath = "cb"` and a request to `http://my.client/foo/bar`
|
||||
* the redirect URI would take the form `http://my.client/cb`.
|
||||
*/
|
||||
def routesFromRequestAuthority(callbackPath: Uri.Path): Directive1[Client.Routes] =
|
||||
extractUri.map { reqUri =>
|
||||
RoutesImpl(
|
||||
Uri()
|
||||
.withScheme(reqUri.scheme)
|
||||
.withAuthority(reqUri.authority)
|
||||
.withPath(callbackPath)
|
||||
): Client.Routes
|
||||
}
|
||||
|
||||
/** Equivalent to [[routes]] if [[callbackUri]] is an absolute URI,
|
||||
* otherwise equivalent to [[routesFromRequestAuthority]].
|
||||
*/
|
||||
def routesAuto(callbackUri: Uri): Directive1[Client.Routes] = {
|
||||
if (callbackUri.isAbsolute) {
|
||||
provide(routes(callbackUri))
|
||||
} else {
|
||||
routesFromRequestAuthority(callbackUri.path)
|
||||
}
|
||||
}
|
||||
|
||||
private case class RoutesImpl(callbackUri: Uri) extends Client.Routes {
|
||||
val callbackHandler: Route =
|
||||
parameters(Symbol("state").as[UUID]) { requestId =>
|
||||
callbacks.pop(requestId) match {
|
||||
case None =>
|
||||
complete(StatusCodes.NotFound)
|
||||
case Some(callback) =>
|
||||
Response.Login.callbackParameters { callback }
|
||||
}
|
||||
}
|
||||
|
||||
private val isHtmlRequest: Directive1[Boolean] = extractRequest.map { req =>
|
||||
val negotiator = ContentNegotiator(req.headers)
|
||||
val contentTypes = List(
|
||||
ContentNegotiator.Alternative(MediaTypes.`application/json`),
|
||||
ContentNegotiator.Alternative(MediaTypes.`text/html`),
|
||||
)
|
||||
val preferred = negotiator.pickContentType(contentTypes)
|
||||
preferred.map(_.mediaType) == Some(MediaTypes.`text/html`)
|
||||
}
|
||||
|
||||
/** Pass control to the inner directive if we should redirect to login on auth failure, reject otherwise.
|
||||
*/
|
||||
private val onRedirectToLogin: Directive0 =
|
||||
config.redirectToLogin match {
|
||||
case RedirectToLogin.No => reject
|
||||
case RedirectToLogin.Yes => pass
|
||||
case RedirectToLogin.Auto =>
|
||||
isHtmlRequest.flatMap {
|
||||
case false => reject
|
||||
case true => pass
|
||||
}
|
||||
}
|
||||
|
||||
def authorize(claims: Request.Claims): Directive1[Client.AuthorizeResult] = {
|
||||
auth(claims).flatMap {
|
||||
// Authorization successful - pass token to continuation
|
||||
case Some(authorization) => provide(Client.Authorized(authorization))
|
||||
// Authorization failed - login and retry on callback request.
|
||||
case None =>
|
||||
onRedirectToLogin
|
||||
.tflatMap { _ =>
|
||||
// Ensure that the request is fully uploaded.
|
||||
val timeout = config.httpEntityUploadTimeout
|
||||
val maxBytes = config.maxHttpEntityUploadSize
|
||||
toStrictEntity(timeout, maxBytes).tflatMap { _ =>
|
||||
extractRequestContext.flatMap { ctx =>
|
||||
Directive { (inner: Tuple1[Client.AuthorizeResult] => Route) =>
|
||||
def continue(result: Client.AuthorizeResult): Route =
|
||||
mapRequestContext(_ => ctx) {
|
||||
inner(Tuple1(result))
|
||||
}
|
||||
val callback: Response.Login => Route = {
|
||||
case Response.LoginSuccess =>
|
||||
auth(claims) {
|
||||
case None => continue(Client.Unauthorized)
|
||||
case Some(authorization) => continue(Client.Authorized(authorization))
|
||||
}
|
||||
case loginError: Response.LoginError =>
|
||||
continue(Client.LoginFailed(loginError))
|
||||
}
|
||||
login(claims, callback)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.or(unauthorized(claims))
|
||||
}
|
||||
}
|
||||
|
||||
/** This directive attempts to obtain an access token from the middleware's auth endpoint for the given claims.
|
||||
*
|
||||
* Forwards the current request's cookies. Completes with 500 on an unexpected response from the auth middleware.
|
||||
*
|
||||
* @return `None` if the request was denied otherwise `Some` access and optionally refresh token.
|
||||
*/
|
||||
private def auth(claims: Request.Claims): Directive1[Option[Response.Authorize]] =
|
||||
extractExecutionContext.flatMap { implicit ec =>
|
||||
extractActorSystem.flatMap { implicit system =>
|
||||
extract(_.request.headers[headers.Cookie]).flatMap { cookies =>
|
||||
onSuccess(requestAuth(claims, cookies))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Return a 401 Unauthorized response.
|
||||
*
|
||||
* Includes a `WWW-Authenticate` header with a custom challenge to login at the auth middleware.
|
||||
* Lists the required claims in the `realm` and the login URI in the `login` parameter
|
||||
* and the auth URI in the `auth` parameter.
|
||||
*
|
||||
* The challenge is also included in the response body
|
||||
* as some browsers make it difficult to access the `WWW-Authenticate` header.
|
||||
*/
|
||||
private def unauthorized(claims: Request.Claims): StandardRoute = {
|
||||
import com.daml.auth.middleware.api.JsonProtocol.responseAuthenticateChallengeFormat
|
||||
val challenge = Response.AuthenticateChallenge(
|
||||
claims,
|
||||
loginUri(claims, None, false),
|
||||
authUri(claims),
|
||||
)
|
||||
complete(
|
||||
status = StatusCodes.Unauthorized,
|
||||
headers = immutable.Seq(challenge.toHeader),
|
||||
challenge,
|
||||
)
|
||||
}
|
||||
|
||||
def login(claims: Request.Claims, callback: Response.Login => Route): Route = {
|
||||
val requestId = UUID.randomUUID()
|
||||
if (callbacks.put(requestId, callback)) {
|
||||
redirect(loginUri(claims, Some(requestId)), StatusCodes.Found)
|
||||
} else {
|
||||
complete(StatusCodes.ServiceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
def loginUri(
|
||||
claims: Request.Claims,
|
||||
requestId: Option[UUID] = None,
|
||||
redirect: Boolean = true,
|
||||
): Uri = {
|
||||
val redirectUri =
|
||||
if (redirect) { Some(callbackUri) }
|
||||
else { None }
|
||||
appendToUri(
|
||||
config.authMiddlewareExternalUri,
|
||||
Path./("login"),
|
||||
Request.Login(redirectUri, claims, requestId.map(_.toString)).toQuery,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/** Request authentication/authorization on the auth middleware's auth endpoint.
|
||||
*
|
||||
* @return `None` if the request was denied otherwise `Some` access and optionally refresh token.
|
||||
*/
|
||||
def requestAuth(claims: Request.Claims, cookies: immutable.Seq[headers.Cookie])(implicit
|
||||
ec: ExecutionContext,
|
||||
system: ActorSystem,
|
||||
): Future[Option[Response.Authorize]] =
|
||||
for {
|
||||
response <- Http().singleRequest(
|
||||
HttpRequest(
|
||||
method = HttpMethods.GET,
|
||||
uri = authUri(claims),
|
||||
headers = cookies,
|
||||
)
|
||||
)
|
||||
authorize <- response.status match {
|
||||
case StatusCodes.OK =>
|
||||
import JsonProtocol.responseAuthorizeFormat
|
||||
Unmarshal(response.entity).to[Response.Authorize].map(Some(_))
|
||||
case StatusCodes.Unauthorized =>
|
||||
Future.successful(None)
|
||||
case status =>
|
||||
Unmarshal(response).to[String].flatMap { msg =>
|
||||
Future.failed(AuthException(status, msg))
|
||||
}
|
||||
}
|
||||
} yield authorize
|
||||
|
||||
/** Request a token refresh on the auth middleware's refresh endpoint.
|
||||
*/
|
||||
def requestRefresh(
|
||||
refreshToken: RefreshToken
|
||||
)(implicit ec: ExecutionContext, system: ActorSystem): Future[Response.Authorize] =
|
||||
for {
|
||||
requestEntity <- {
|
||||
import JsonProtocol.requestRefreshFormat
|
||||
Marshal(Request.Refresh(refreshToken))
|
||||
.to[RequestEntity]
|
||||
}
|
||||
response <- Http().singleRequest(
|
||||
HttpRequest(
|
||||
method = HttpMethods.POST,
|
||||
uri = refreshUri,
|
||||
entity = requestEntity,
|
||||
)
|
||||
)
|
||||
authorize <- response.status match {
|
||||
case StatusCodes.OK =>
|
||||
import JsonProtocol._
|
||||
Unmarshal(response.entity).to[Response.Authorize]
|
||||
case status =>
|
||||
Unmarshal(response).to[String].flatMap { msg =>
|
||||
Future.failed(RefreshException(status, msg))
|
||||
}
|
||||
}
|
||||
} yield authorize
|
||||
|
||||
private def appendToUri(uri: Uri, path: Uri.Path, query: Uri.Query = Uri.Query.Empty): Uri = {
|
||||
val newPath: Uri.Path = uri.path ++ path
|
||||
val newQueryParams: Seq[(String, String)] = uri.query().toSeq ++ query.toSeq
|
||||
val newQuery = Uri.Query(newQueryParams: _*)
|
||||
uri.withPath(newPath).withQuery(newQuery)
|
||||
}
|
||||
|
||||
def authUri(claims: Request.Claims): Uri =
|
||||
appendToUri(
|
||||
config.authMiddlewareInternalUri,
|
||||
Path./("auth"),
|
||||
Request.Auth(claims).toQuery,
|
||||
)
|
||||
|
||||
val refreshUri: Uri =
|
||||
appendToUri(config.authMiddlewareInternalUri, Path./("refresh"))
|
||||
}
|
||||
|
||||
object Client {
|
||||
sealed trait AuthorizeResult
|
||||
final case class Authorized(authorization: Response.Authorize) extends AuthorizeResult
|
||||
object Unauthorized extends AuthorizeResult
|
||||
final case class LoginFailed(loginError: Response.LoginError) extends AuthorizeResult
|
||||
|
||||
abstract class ClientException(message: String) extends RuntimeException(message)
|
||||
case class AuthException(status: StatusCode, message: String)
|
||||
extends ClientException(s"Failed to authorize with middleware ($status): $message")
|
||||
case class RefreshException(status: StatusCode, message: String)
|
||||
extends ClientException(s"Failed to refresh token on middleware ($status): $message")
|
||||
|
||||
/** Whether to automatically redirect to the login endpoint when authorization fails.
|
||||
*
|
||||
* [[RedirectToLogin.Auto]] redirects for HTML requests (`text/html`)
|
||||
* and returns 401 Unauthorized for JSON requests (`application/json`).
|
||||
*/
|
||||
sealed trait RedirectToLogin
|
||||
object RedirectToLogin {
|
||||
object No extends RedirectToLogin
|
||||
object Yes extends RedirectToLogin
|
||||
object Auto extends RedirectToLogin
|
||||
}
|
||||
|
||||
case class Config(
|
||||
authMiddlewareInternalUri: Uri,
|
||||
authMiddlewareExternalUri: Uri,
|
||||
redirectToLogin: RedirectToLogin,
|
||||
maxAuthCallbacks: Int,
|
||||
authCallbackTimeout: FiniteDuration,
|
||||
maxHttpEntityUploadSize: Long,
|
||||
httpEntityUploadTimeout: FiniteDuration,
|
||||
)
|
||||
|
||||
trait Routes {
|
||||
|
||||
/** Handler for the callback in a login flow.
|
||||
*
|
||||
* Note, a GET request on the `callbackUri` must map to this route.
|
||||
*/
|
||||
def callbackHandler: Route
|
||||
|
||||
/** This directive requires authorization for the given claims via the auth middleware.
|
||||
*
|
||||
* Authorization follows the steps defined in `triggers/service/authentication.md`.
|
||||
* 1. Ask for a token on the `/auth` endpoint and return it if granted.
|
||||
* 2a. Return 401 Unauthorized if denied and [[Client.Config.redirectToLogin]]
|
||||
* indicates not to redirect to the login endpoint.
|
||||
* 2b. Redirect to the login endpoint if denied and [[Client.Config.redirectToLogin]]
|
||||
* indicates to redirect to the login endpoint.
|
||||
* In this case this will store the current continuation to proceed
|
||||
* once the login flow completed and authentication succeeded.
|
||||
* A route for the [[callbackHandler]] must be configured.
|
||||
*/
|
||||
def authorize(claims: Request.Claims): Directive1[Client.AuthorizeResult]
|
||||
|
||||
/** Redirect the client to login with the auth middleware.
|
||||
*
|
||||
* Will respond with 503 if the callback store is full ([[Client.Config.maxAuthCallbacks]]).
|
||||
*
|
||||
* @param callback Will be stored and executed once the login flow completed.
|
||||
*/
|
||||
def login(claims: Request.Claims, callback: Response.Login => Route): Route
|
||||
|
||||
def loginUri(
|
||||
claims: Request.Claims,
|
||||
requestId: Option[UUID] = None,
|
||||
redirect: Boolean = true,
|
||||
): Uri
|
||||
}
|
||||
|
||||
def apply(config: Config): Client = new Client(config)
|
||||
}
|
@ -1,226 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.api
|
||||
|
||||
import org.apache.pekko.http.scaladsl.marshalling.Marshaller
|
||||
import org.apache.pekko.http.scaladsl.model.headers.HttpChallenge
|
||||
import org.apache.pekko.http.scaladsl.model.{HttpHeader, Uri, headers}
|
||||
import org.apache.pekko.http.scaladsl.server.Directive1
|
||||
import org.apache.pekko.http.scaladsl.server.Directives._
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshaller
|
||||
import com.daml.lf.data.Ref
|
||||
import scalaz.{@@, Tag}
|
||||
import spray.json._
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.concurrent._
|
||||
import scala.util.Try
|
||||
import scala.language.postfixOps
|
||||
|
||||
object Tagged {
|
||||
sealed trait AccessTokenTag
|
||||
type AccessToken = String @@ AccessTokenTag
|
||||
val AccessToken = Tag.of[AccessTokenTag]
|
||||
|
||||
sealed trait RefreshTokenTag
|
||||
type RefreshToken = String @@ RefreshTokenTag
|
||||
val RefreshToken = Tag.of[RefreshTokenTag]
|
||||
}
|
||||
|
||||
object Request {
|
||||
import Tagged._
|
||||
|
||||
// applicationId = None makes no guarantees about the application ID. You can use this
|
||||
// if you don’t use the token for requests that use the application ID.
|
||||
// applicationId = Some(appId) will return a token that is valid for
|
||||
// appId, i.e., either a wildcard token or a token with applicationId set to appId.
|
||||
case class Claims(
|
||||
admin: Boolean,
|
||||
actAs: List[Ref.Party],
|
||||
readAs: List[Ref.Party],
|
||||
applicationId: Option[Ref.ApplicationId],
|
||||
) {
|
||||
def toQueryString() = {
|
||||
val adminS = if (admin) LazyList("admin") else LazyList()
|
||||
val actAsS = actAs.to(LazyList).map(party => s"actAs:$party")
|
||||
val readAsS = readAs.to(LazyList).map(party => s"readAs:$party")
|
||||
val applicationIdS = applicationId.toList.to(LazyList).map(appId => s"applicationId:$appId")
|
||||
(adminS ++ actAsS ++ readAsS ++ applicationIdS).mkString(" ")
|
||||
}
|
||||
}
|
||||
object Claims {
|
||||
def apply(
|
||||
admin: Boolean = false,
|
||||
actAs: List[Ref.Party] = List(),
|
||||
readAs: List[Ref.Party] = List(),
|
||||
applicationId: Option[Ref.ApplicationId] = None,
|
||||
): Claims =
|
||||
new Claims(admin, actAs, readAs, applicationId)
|
||||
def apply(s: String): Claims = {
|
||||
var admin = false
|
||||
val actAs = ArrayBuffer[Ref.Party]()
|
||||
val readAs = ArrayBuffer[Ref.Party]()
|
||||
var applicationId: Option[Ref.ApplicationId] = None
|
||||
s.split(' ').foreach { w =>
|
||||
if (w == "admin") {
|
||||
admin = true
|
||||
} else if (w.startsWith("actAs:")) {
|
||||
actAs.append(Ref.Party.assertFromString(w.stripPrefix("actAs:")))
|
||||
} else if (w.startsWith("readAs:")) {
|
||||
readAs.append(Ref.Party.assertFromString(w.stripPrefix("readAs:")))
|
||||
} else if (w.startsWith("applicationId:")) {
|
||||
applicationId match {
|
||||
case None =>
|
||||
applicationId =
|
||||
Some(Ref.ApplicationId.assertFromString(w.stripPrefix("applicationId:")))
|
||||
case Some(_) =>
|
||||
throw new IllegalArgumentException(
|
||||
"applicationId claim can only be specified once"
|
||||
)
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException(s"Expected claim but got $w")
|
||||
}
|
||||
}
|
||||
Claims(admin, actAs.toList, readAs.toList, applicationId)
|
||||
}
|
||||
implicit val marshalRequestEntity: Marshaller[Claims, String] =
|
||||
Marshaller.opaque(_.toQueryString())
|
||||
implicit val unmarshalHttpEntity: Unmarshaller[String, Claims] =
|
||||
Unmarshaller { _ => s => Future.fromTry(Try(apply(s))) }
|
||||
}
|
||||
|
||||
/** Auth endpoint query parameters
|
||||
*/
|
||||
case class Auth(claims: Claims) {
|
||||
def toQuery: Uri.Query = Uri.Query("claims" -> claims.toQueryString())
|
||||
}
|
||||
|
||||
/** Login endpoint query parameters
|
||||
*
|
||||
* @param redirectUri Redirect target after the login flow completed. I.e. the original request URI on the trigger service.
|
||||
* @param claims Required ledger claims.
|
||||
* @param state State that will be forwarded to the callback URI after authentication and authorization.
|
||||
*/
|
||||
case class Login(redirectUri: Option[Uri], claims: Claims, state: Option[String]) {
|
||||
def toQuery: Uri.Query = {
|
||||
var params = Seq("claims" -> claims.toQueryString())
|
||||
redirectUri.foreach(x => params ++= Seq("redirect_uri" -> x.toString()))
|
||||
state.foreach(x => params ++= Seq("state" -> x))
|
||||
Uri.Query(params: _*)
|
||||
}
|
||||
}
|
||||
|
||||
/** Refresh endpoint request entity
|
||||
*/
|
||||
case class Refresh(refreshToken: RefreshToken)
|
||||
|
||||
}
|
||||
|
||||
object Response {
|
||||
import Tagged._
|
||||
|
||||
case class Authorize(accessToken: AccessToken, refreshToken: Option[RefreshToken])
|
||||
|
||||
sealed abstract class Login
|
||||
final case class LoginError(error: String, errorDescription: Option[String]) extends Login
|
||||
object LoginSuccess extends Login
|
||||
|
||||
object Login {
|
||||
val callbackParameters: Directive1[Login] =
|
||||
parameters(Symbol("error"), Symbol("error_description") ?)
|
||||
.as[LoginError](LoginError)
|
||||
.or(provide(LoginSuccess))
|
||||
}
|
||||
|
||||
val authenticateChallengeName: String = "DamlAuthMiddleware"
|
||||
|
||||
case class AuthenticateChallenge(
|
||||
realm: Request.Claims,
|
||||
login: Uri,
|
||||
auth: Uri,
|
||||
) {
|
||||
def toHeader: HttpHeader = headers.`WWW-Authenticate`(
|
||||
HttpChallenge(
|
||||
authenticateChallengeName,
|
||||
realm.toQueryString(),
|
||||
Map(
|
||||
"login" -> login.toString(),
|
||||
"auth" -> auth.toString(),
|
||||
),
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object JsonProtocol extends DefaultJsonProtocol {
|
||||
import Tagged._
|
||||
|
||||
implicit object UriFormat extends JsonFormat[Uri] {
|
||||
def read(value: JsValue) = value match {
|
||||
case JsString(s) => Uri(s)
|
||||
case _ => deserializationError(s"Expected Uri string but got $value")
|
||||
}
|
||||
def write(uri: Uri) = JsString(uri.toString)
|
||||
}
|
||||
implicit object AccessTokenJsonFormat extends JsonFormat[AccessToken] {
|
||||
def write(x: AccessToken) = {
|
||||
JsString(AccessToken.unwrap(x))
|
||||
}
|
||||
def read(value: JsValue) = value match {
|
||||
case JsString(x) => AccessToken(x)
|
||||
case x => deserializationError(s"Expected AccessToken as JsString, but got $x")
|
||||
}
|
||||
}
|
||||
implicit object RefreshTokenJsonFormat extends JsonFormat[RefreshToken] {
|
||||
def write(x: RefreshToken) = {
|
||||
JsString(RefreshToken.unwrap(x))
|
||||
}
|
||||
def read(value: JsValue) = value match {
|
||||
case JsString(x) => RefreshToken(x)
|
||||
case x => deserializationError(s"Expected RefreshToken as JsString, but got $x")
|
||||
}
|
||||
}
|
||||
implicit object RequestClaimsFormat extends JsonFormat[Request.Claims] {
|
||||
def write(claims: Request.Claims) = {
|
||||
JsString(claims.toQueryString())
|
||||
}
|
||||
def read(value: JsValue) = value match {
|
||||
case JsString(s) =>
|
||||
try {
|
||||
Request.Claims(s)
|
||||
} catch {
|
||||
case ex: IllegalArgumentException =>
|
||||
deserializationError(ex.getMessage, ex)
|
||||
}
|
||||
case x => deserializationError(s"Expected Claims as JsString, but got $x")
|
||||
}
|
||||
}
|
||||
implicit val requestRefreshFormat: RootJsonFormat[Request.Refresh] =
|
||||
jsonFormat(Request.Refresh, "refresh_token")
|
||||
implicit val responseAuthorizeFormat: RootJsonFormat[Response.Authorize] =
|
||||
jsonFormat(Response.Authorize, "access_token", "refresh_token")
|
||||
|
||||
implicit object ResponseLoginFormat extends RootJsonFormat[Response.Login] {
|
||||
implicit private val errorFormat: RootJsonFormat[Response.LoginError] = jsonFormat2(
|
||||
Response.LoginError
|
||||
)
|
||||
def write(login: Response.Login) = login match {
|
||||
case error: Response.LoginError => error.toJson
|
||||
case Response.LoginSuccess => JsNull
|
||||
}
|
||||
def read(value: JsValue) = value.convertTo(safeReader[Response.LoginError]) match {
|
||||
case Right(error) => error
|
||||
case Left(_) =>
|
||||
value match {
|
||||
case JsNull => Response.LoginSuccess
|
||||
case _ => deserializationError(s"Expected null or error object but got $value")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
implicit val responseAuthenticateChallengeFormat: RootJsonFormat[Response.AuthenticateChallenge] =
|
||||
jsonFormat(Response.AuthenticateChallenge, "realm", "login", "auth")
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.api
|
||||
|
||||
import scala.collection.mutable.LinkedHashMap
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
|
||||
/** A key-value store with a maximum capacity and maximum storage duration.
|
||||
* @param maxCapacity Maximum number of requests that can be stored.
|
||||
* @param timeout Duration after which requests will be evicted.
|
||||
* @param monotonicClock Determines the current timestamp. The underlying clock must be monotonic.
|
||||
* The JVM will use a monotonic clock for [[System.nanoTime]], if available, according to
|
||||
* [[https://bugs.openjdk.java.net/browse/JDK-6458294?focusedCommentId=13823604&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-13823604 JDK bug 6458294]]
|
||||
* @tparam K The key type
|
||||
* @tparam V The value type
|
||||
*/
|
||||
private[middleware] class RequestStore[K, V](
|
||||
maxCapacity: Int,
|
||||
timeout: FiniteDuration,
|
||||
monotonicClock: () => Long = () => System.nanoTime,
|
||||
) {
|
||||
|
||||
/** Mapping from key to insertion timestamp and value.
|
||||
* The timestamp of later inserted elements must be greater or equal to the timestamp of earlier inserted elements.
|
||||
*/
|
||||
private val store: LinkedHashMap[K, (Long, V)] = LinkedHashMap.empty
|
||||
|
||||
/** Check whether the given [[timestamp]] timed out relative to the current time [[now]].
|
||||
*/
|
||||
private def timedOut(now: Long, timestamp: Long): Boolean = {
|
||||
now - timestamp >= timeout.toNanos
|
||||
}
|
||||
|
||||
private def evictTimedOut(now: Long): Unit = {
|
||||
// Remove items until their timestamp is more recent than the configured timeout.
|
||||
store.iterator
|
||||
.takeWhile { case (_, (t, _)) =>
|
||||
timedOut(now, t)
|
||||
}
|
||||
.foreach { case (k, _) =>
|
||||
store.remove(k)
|
||||
}
|
||||
}
|
||||
|
||||
/** Insert a new key-value pair unless the maximum capacity is reached.
|
||||
* Evicts timed out elements before attempting insertion.
|
||||
* @return whether the key-value pair was inserted.
|
||||
*/
|
||||
def put(key: K, value: => V): Boolean = {
|
||||
synchronized {
|
||||
val now = monotonicClock()
|
||||
evictTimedOut(now)
|
||||
if (store.size >= maxCapacity) {
|
||||
false
|
||||
} else {
|
||||
store.update(key, (now, value))
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Remove and return the value under the given key, if present and not timed out.
|
||||
*/
|
||||
def pop(key: K): Option[V] = {
|
||||
synchronized {
|
||||
store.remove(key).flatMap {
|
||||
case (t, _) if timedOut(monotonicClock(), t) => None
|
||||
case (_, v) => Some(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,241 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import org.apache.pekko.http.scaladsl.model.Uri
|
||||
import com.daml.auth.middleware.oauth2.Config.{
|
||||
DefaultCookieSecure,
|
||||
DefaultHttpPort,
|
||||
DefaultLoginTimeout,
|
||||
DefaultMaxLoginRequests,
|
||||
}
|
||||
import com.daml.cliopts
|
||||
import com.daml.jwt.{JwtVerifierBase, JwtVerifierConfigurationCli}
|
||||
import com.daml.metrics.api.reporters.MetricsReporter
|
||||
import com.typesafe.scalalogging.StrictLogging
|
||||
import pureconfig.ConfigSource
|
||||
import pureconfig.error.ConfigReaderFailures
|
||||
import scopt.OptionParser
|
||||
|
||||
import java.io.File
|
||||
import java.nio.file.{Path, Paths}
|
||||
|
||||
import scala.concurrent.duration
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
import scalaz.syntax.std.option._
|
||||
|
||||
private[oauth2] final case class Cli(
|
||||
configFile: Option[File] = None,
|
||||
// Host and port the middleware listens on
|
||||
address: String = cliopts.Http.defaultAddress,
|
||||
port: Int = DefaultHttpPort,
|
||||
portFile: Option[Path] = None,
|
||||
// The URI to which the OAuth2 server will redirect after a completed login flow.
|
||||
// Must map to the `/cb` endpoint of the auth middleware.
|
||||
callbackUri: Option[Uri] = None,
|
||||
maxLoginRequests: Int = DefaultMaxLoginRequests,
|
||||
loginTimeout: FiniteDuration = DefaultLoginTimeout,
|
||||
cookieSecure: Boolean = DefaultCookieSecure,
|
||||
// OAuth2 server endpoints
|
||||
oauthAuth: Uri,
|
||||
oauthToken: Uri,
|
||||
// OAuth2 server request templates
|
||||
oauthAuthTemplate: Option[Path],
|
||||
oauthTokenTemplate: Option[Path],
|
||||
oauthRefreshTemplate: Option[Path],
|
||||
// OAuth2 client properties
|
||||
clientId: String,
|
||||
clientSecret: SecretString,
|
||||
// Token verification
|
||||
tokenVerifier: JwtVerifierBase,
|
||||
metricsReporter: Option[MetricsReporter] = None,
|
||||
metricsReportingInterval: FiniteDuration = FiniteDuration(10, duration.SECONDS),
|
||||
) extends StrictLogging {
|
||||
|
||||
def loadFromConfigFile: Option[Either[ConfigReaderFailures, FileConfig]] = {
|
||||
configFile.map(cf => ConfigSource.file(cf).load[FileConfig])
|
||||
}
|
||||
|
||||
def loadFromCliArgs: Config = {
|
||||
val cfg = Config(
|
||||
address,
|
||||
port,
|
||||
portFile,
|
||||
callbackUri,
|
||||
maxLoginRequests,
|
||||
loginTimeout,
|
||||
cookieSecure,
|
||||
oauthAuth,
|
||||
oauthToken,
|
||||
oauthAuthTemplate,
|
||||
oauthTokenTemplate,
|
||||
oauthRefreshTemplate,
|
||||
clientId,
|
||||
clientSecret,
|
||||
tokenVerifier,
|
||||
metricsReporter,
|
||||
metricsReportingInterval,
|
||||
Seq.empty,
|
||||
)
|
||||
cfg.validate()
|
||||
cfg
|
||||
}
|
||||
|
||||
def loadConfig: Option[Config] = {
|
||||
loadFromConfigFile.cata(
|
||||
{
|
||||
case Right(cfg) => Some(cfg.toConfig())
|
||||
case Left(ex) =>
|
||||
logger.error(
|
||||
s"Error loading oauth2-middleware config from file $configFile",
|
||||
ex.prettyPrint(),
|
||||
)
|
||||
None
|
||||
}, {
|
||||
logger.warn("Using cli opts for running oauth2-middleware is deprecated")
|
||||
Some(loadFromCliArgs)
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
|
||||
private[oauth2] object Cli {
|
||||
|
||||
private[oauth2] val Default =
|
||||
Cli(
|
||||
configFile = None,
|
||||
address = cliopts.Http.defaultAddress,
|
||||
port = DefaultHttpPort,
|
||||
portFile = None,
|
||||
callbackUri = None,
|
||||
maxLoginRequests = DefaultMaxLoginRequests,
|
||||
loginTimeout = DefaultLoginTimeout,
|
||||
cookieSecure = DefaultCookieSecure,
|
||||
oauthAuth = null,
|
||||
oauthToken = null,
|
||||
oauthAuthTemplate = None,
|
||||
oauthTokenTemplate = None,
|
||||
oauthRefreshTemplate = None,
|
||||
clientId = null,
|
||||
clientSecret = null,
|
||||
tokenVerifier = null,
|
||||
)
|
||||
|
||||
private val parser: OptionParser[Cli] = new scopt.OptionParser[Cli]("oauth-middleware") {
|
||||
help('h', "help").text("Print usage")
|
||||
opt[Option[File]]('c', "config")
|
||||
.text(
|
||||
"This is the recommended way to provide an app config file, the remaining cli-args are deprecated"
|
||||
)
|
||||
.valueName("<file>")
|
||||
.action((file, cli) => cli.copy(configFile = file))
|
||||
|
||||
cliopts.Http.serverParse(this, serviceName = "OAuth2 Middleware")(
|
||||
address = (f, c) => c.copy(address = f(c.address)),
|
||||
httpPort = (f, c) => c.copy(port = f(c.port)),
|
||||
defaultHttpPort = Some(DefaultHttpPort),
|
||||
portFile = Some((f, c) => c.copy(portFile = f(c.portFile))),
|
||||
)
|
||||
|
||||
opt[String]("callback")
|
||||
.action((x, c) => c.copy(callbackUri = Some(Uri(x))))
|
||||
.text(
|
||||
"URI to the auth middleware's callback endpoint `/cb`. By default constructed from the incoming login request."
|
||||
)
|
||||
|
||||
opt[Int]("max-pending-login-requests")
|
||||
.action((x, c) => c.copy(maxLoginRequests = x))
|
||||
.text(
|
||||
"Maximum number of simultaneously pending login requests. Requests will be denied when exceeded until earlier requests have been completed or timed out."
|
||||
)
|
||||
|
||||
opt[Boolean]("cookie-secure")
|
||||
.action((x, c) => c.copy(cookieSecure = x))
|
||||
.text(
|
||||
"Enable the Secure attribute on the cookie that stores the token. Defaults to true. Only disable this for testing and development purposes."
|
||||
)
|
||||
|
||||
opt[Long]("login-request-timeout")
|
||||
.action((x, c) => c.copy(loginTimeout = FiniteDuration(x, duration.SECONDS)))
|
||||
.text(
|
||||
"Login request timeout. Requests will be evicted if the callback endpoint receives no corresponding request in time."
|
||||
)
|
||||
|
||||
opt[String]("oauth-auth")
|
||||
.action((x, c) => c.copy(oauthAuth = Uri(x)))
|
||||
.text("URI of the OAuth2 authorization endpoint")
|
||||
|
||||
opt[String]("oauth-token")
|
||||
.action((x, c) => c.copy(oauthToken = Uri(x)))
|
||||
.text("URI of the OAuth2 token endpoint")
|
||||
|
||||
opt[String]("oauth-auth-template")
|
||||
.action((x, c) => c.copy(oauthAuthTemplate = Some(Paths.get(x))))
|
||||
.text("OAuth2 authorization request Jsonnet template")
|
||||
|
||||
opt[String]("oauth-token-template")
|
||||
.action((x, c) => c.copy(oauthTokenTemplate = Some(Paths.get(x))))
|
||||
.text("OAuth2 token request Jsonnet template")
|
||||
|
||||
opt[String]("oauth-refresh-template")
|
||||
.action((x, c) => c.copy(oauthRefreshTemplate = Some(Paths.get(x))))
|
||||
.text("OAuth2 refresh request Jsonnet template")
|
||||
|
||||
opt[String]("id")
|
||||
.hidden()
|
||||
.action((x, c) => c.copy(clientId = x))
|
||||
.withFallback(() => sys.env.getOrElse("DAML_CLIENT_ID", ""))
|
||||
|
||||
opt[String]("secret")
|
||||
.hidden()
|
||||
.action((x, c) => c.copy(clientSecret = SecretString(x)))
|
||||
.withFallback(() => sys.env.getOrElse("DAML_CLIENT_SECRET", ""))
|
||||
|
||||
cliopts.Metrics.metricsReporterParse(this)(
|
||||
(f, c) => c.copy(metricsReporter = f(c.metricsReporter)),
|
||||
(f, c) => c.copy(metricsReportingInterval = f(c.metricsReportingInterval)),
|
||||
)
|
||||
|
||||
JwtVerifierConfigurationCli.parse(this)((v, c) => c.copy(tokenVerifier = v))
|
||||
|
||||
checkConfig { cfg =>
|
||||
if (cfg.configFile.isEmpty && cfg.tokenVerifier == null)
|
||||
Left("You must specify one of the --auth-jwt-* flags for token verification.")
|
||||
else
|
||||
Right(())
|
||||
}
|
||||
|
||||
checkConfig { cfg =>
|
||||
if (cfg.configFile.isEmpty && (cfg.clientId.isEmpty || cfg.clientSecret.value.isEmpty))
|
||||
Left("Environment variable DAML_CLIENT_ID AND DAML_CLIENT_SECRET must not be empty")
|
||||
else
|
||||
Right(())
|
||||
}
|
||||
|
||||
checkConfig { cfg =>
|
||||
if (cfg.configFile.isEmpty && (cfg.oauthAuth == null || cfg.oauthToken == null))
|
||||
Left("oauth-auth and oauth-token values must not be empty")
|
||||
else
|
||||
Right(())
|
||||
}
|
||||
|
||||
checkConfig { cfg =>
|
||||
val cliOptionsAreDefined =
|
||||
cfg.oauthToken != null || cfg.oauthAuth != null || cfg.tokenVerifier != null
|
||||
if (cfg.configFile.isDefined && cliOptionsAreDefined) {
|
||||
Left("Found both config file and cli opts for the app, please provide only one of them")
|
||||
} else Right(())
|
||||
}
|
||||
|
||||
override def showUsageOnError: Option[Boolean] = Some(true)
|
||||
}
|
||||
|
||||
def parse(args: Array[String]): Option[Cli] = parser.parse(args, Default)
|
||||
|
||||
def parseConfig(args: Array[String]): Option[Config] = {
|
||||
val cli = parse(args)
|
||||
cli.flatMap(_.loadConfig)
|
||||
}
|
||||
}
|
@ -1,114 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import java.nio.file.Path
|
||||
|
||||
import org.apache.pekko.http.scaladsl.model.Uri
|
||||
import com.daml.auth.middleware.oauth2.Config._
|
||||
import com.daml.cliopts
|
||||
import com.daml.jwt.JwtVerifierBase
|
||||
import com.daml.metrics.{HistogramDefinition, MetricsConfig}
|
||||
import com.daml.metrics.api.reporters.MetricsReporter
|
||||
import com.daml.pureconfigutils.SharedConfigReaders._
|
||||
import pureconfig.{ConfigReader, ConvertHelpers}
|
||||
import pureconfig.generic.semiauto.deriveReader
|
||||
|
||||
import scala.concurrent.duration._
|
||||
|
||||
final case class Config(
|
||||
// Host and port the middleware listens on
|
||||
address: String = cliopts.Http.defaultAddress,
|
||||
port: Int = DefaultHttpPort,
|
||||
portFile: Option[Path] = None,
|
||||
// The URI to which the OAuth2 server will redirect after a completed login flow.
|
||||
// Must map to the `/cb` endpoint of the auth middleware.
|
||||
callbackUri: Option[Uri] = None,
|
||||
maxLoginRequests: Int = DefaultMaxLoginRequests,
|
||||
loginTimeout: FiniteDuration = DefaultLoginTimeout,
|
||||
cookieSecure: Boolean = DefaultCookieSecure,
|
||||
// OAuth2 server endpoints
|
||||
oauthAuth: Uri,
|
||||
oauthToken: Uri,
|
||||
// OAuth2 server request templates
|
||||
oauthAuthTemplate: Option[Path] = None,
|
||||
oauthTokenTemplate: Option[Path] = None,
|
||||
oauthRefreshTemplate: Option[Path] = None,
|
||||
// OAuth2 client properties
|
||||
clientId: String,
|
||||
clientSecret: SecretString,
|
||||
// Token verification
|
||||
tokenVerifier: JwtVerifierBase,
|
||||
metricsReporter: Option[MetricsReporter] = None,
|
||||
metricsReportingInterval: FiniteDuration = 10.seconds,
|
||||
histograms: Seq[HistogramDefinition],
|
||||
) {
|
||||
def validate(): Unit = {
|
||||
require(oauthToken != null, "Oauth token value on config cannot be null")
|
||||
require(oauthAuth != null, "Oauth auth value on config cannot be null")
|
||||
require(clientId.nonEmpty, "DAML_CLIENT_ID cannot be empty")
|
||||
require(clientSecret.value.nonEmpty, "DAML_CLIENT_SECRET cannot be empty")
|
||||
require(tokenVerifier != null, "token verifier must be defined")
|
||||
}
|
||||
}
|
||||
|
||||
@scala.annotation.nowarn("msg=Block result was adapted via implicit conversion")
|
||||
object FileConfig {
|
||||
implicit val clientSecretReader: ConfigReader[SecretString] =
|
||||
ConfigReader.fromString[SecretString](ConvertHelpers.catchReadError(s => SecretString(s)))
|
||||
|
||||
implicit val cfgReader: ConfigReader[FileConfig] = deriveReader[FileConfig]
|
||||
}
|
||||
|
||||
final case class FileConfig(
|
||||
address: String = cliopts.Http.defaultAddress,
|
||||
port: Int = DefaultHttpPort,
|
||||
portFile: Option[Path] = None,
|
||||
callbackUri: Option[Uri] = None,
|
||||
maxLoginRequests: Int = DefaultMaxLoginRequests,
|
||||
loginTimeout: FiniteDuration = DefaultLoginTimeout,
|
||||
cookieSecure: Boolean = DefaultCookieSecure,
|
||||
oauthAuth: Uri,
|
||||
oauthToken: Uri,
|
||||
oauthAuthTemplate: Option[Path] = None,
|
||||
oauthTokenTemplate: Option[Path] = None,
|
||||
oauthRefreshTemplate: Option[Path] = None,
|
||||
clientId: String,
|
||||
clientSecret: SecretString,
|
||||
tokenVerifier: JwtVerifierBase,
|
||||
metrics: Option[MetricsConfig] = None,
|
||||
) {
|
||||
def toConfig(): Config = Config(
|
||||
address = address,
|
||||
port = port,
|
||||
portFile = portFile,
|
||||
callbackUri = callbackUri,
|
||||
maxLoginRequests = maxLoginRequests,
|
||||
loginTimeout = loginTimeout,
|
||||
cookieSecure = cookieSecure,
|
||||
oauthAuth = oauthAuth,
|
||||
oauthToken = oauthToken,
|
||||
oauthAuthTemplate = oauthAuthTemplate,
|
||||
oauthTokenTemplate = oauthTokenTemplate,
|
||||
oauthRefreshTemplate = oauthRefreshTemplate,
|
||||
clientId = clientId,
|
||||
clientSecret = clientSecret,
|
||||
tokenVerifier = tokenVerifier,
|
||||
metricsReporter = metrics.map(_.reporter),
|
||||
metricsReportingInterval =
|
||||
metrics.map(_.reportingInterval).getOrElse(MetricsConfig.DefaultMetricsReportingInterval),
|
||||
histograms = metrics.toList.flatMap(_.histograms),
|
||||
)
|
||||
}
|
||||
|
||||
final case class SecretString(value: String) {
|
||||
override def toString: String = "###"
|
||||
}
|
||||
|
||||
object Config {
|
||||
val DefaultHttpPort: Int = 3000
|
||||
val DefaultCookieSecure: Boolean = true
|
||||
val DefaultMaxLoginRequests: Int = 100
|
||||
val DefaultLoginTimeout: FiniteDuration = 5.minutes
|
||||
}
|
@ -1,46 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import com.daml.scalautil.Statement.discard
|
||||
import com.typesafe.scalalogging.StrictLogging
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.{Await, ExecutionContext}
|
||||
import scala.util.{Failure, Success}
|
||||
|
||||
object Main extends StrictLogging {
|
||||
def main(args: Array[String]): Unit = {
|
||||
Cli.parseConfig(args) match {
|
||||
case Some(config) => main(config)
|
||||
case None => sys.exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
private def main(config: Config): Unit = {
|
||||
implicit val system: ActorSystem = ActorSystem("system")
|
||||
implicit val executionContext: ExecutionContext = system.dispatcher
|
||||
|
||||
def terminate() = Await.result(system.terminate(), 10.seconds)
|
||||
|
||||
val bindingFuture = Server.start(config, registerGlobalOpenTelemetry = true)
|
||||
|
||||
discard(
|
||||
sys.addShutdownHook(
|
||||
Server.stop(bindingFuture).onComplete(_ => terminate())
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(s"Configuration $config")
|
||||
|
||||
bindingFuture.onComplete {
|
||||
case Success(binding) =>
|
||||
logger.info(s"Started server: $binding")
|
||||
case Failure(e) =>
|
||||
logger.error(s"Failed to start server: $e")
|
||||
terminate()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import com.daml.metrics.api.opentelemetry.OpenTelemetryMetricsFactory
|
||||
import com.daml.metrics.http.DamlHttpMetrics
|
||||
import io.opentelemetry.api.metrics.{Meter => OtelMeter}
|
||||
|
||||
case class Oauth2MiddlewareMetrics(otelMeter: OtelMeter) {
|
||||
|
||||
val openTelemetryFactory = new OpenTelemetryMetricsFactory(otelMeter)
|
||||
|
||||
val http = new DamlHttpMetrics(openTelemetryFactory, "oauth2-middleware")
|
||||
|
||||
}
|
@ -1,111 +0,0 @@
|
||||
# Trigger Service OAuth 2.0 Middleware
|
||||
|
||||
Implements an OAuth2 middleware according to the trigger service
|
||||
authentication/authorization specification in
|
||||
`triggers/service/authentication.md`.
|
||||
|
||||
## Manual Testing against Auth0
|
||||
|
||||
Apart from the automated tests defined in this repository, the middleware can
|
||||
be tested manually against an auth0 OAuth2 setup. The necessary steps are
|
||||
extracted and adapted from the [Secure Daml Infrastructure
|
||||
repository](https://github.com/digital-asset/ex-secure-daml-infra).
|
||||
|
||||
### Setup
|
||||
|
||||
* Sign up for an account on [Auth0](https://auth0.com).
|
||||
* Create a new API.
|
||||
- Provide a name (`ex-daml-api`).
|
||||
- Provide an Identifier (`https://daml.com/ledger-api`).
|
||||
- Select Signing Algorithm of `RS256`.
|
||||
- Allow offline access to enable refresh token generation.
|
||||
This allows the OAuth2 client, i.e. the auth middleware, to request access through a refresh token when the resource owner, i.e. the user, is offline.
|
||||
* Create a new native application.
|
||||
- Provide a name (`ex-daml-auth-middleware`).
|
||||
- Select the authorized API (`ex-daml-api`).
|
||||
- Configure the allowed callback URLs in the settings (`http://localhost:3000/cb`).
|
||||
- Note the "Client ID" and "Client Secret" displayed in the "Basic
|
||||
Information" pane of the application settings.
|
||||
- Note the "OAuth Authorization URL" and the "OAuth Token URL" in the
|
||||
"Endpoints" tab of the advanced settings.
|
||||
* Create a new empty rule.
|
||||
- Provide a name (`ex-daml-claims`).
|
||||
- Provide a script
|
||||
``` javascript
|
||||
function (user, context, callback) {
|
||||
// Only handle ledger-api audience.
|
||||
const audience = context.request.query && context.request.query.audience || "";
|
||||
if (audience !== "https://daml.com/ledger-api") {
|
||||
return callback(null, user, context);
|
||||
}
|
||||
|
||||
// Grant all requested claims
|
||||
const scope = (context.request.query && context.request.query.scope || "").split(" ");
|
||||
var actAs = [];
|
||||
var readAs = [];
|
||||
var admin = false;
|
||||
scope.forEach(s => {
|
||||
if (s.startsWith("actAs:")) {
|
||||
actAs.push(s.slice(6));
|
||||
} else if (s.startsWith("readAs:")) {
|
||||
readAs.push(s.slice(7));
|
||||
} else if (s === "admin") {
|
||||
admin = true;
|
||||
}
|
||||
});
|
||||
|
||||
// Construct access token.
|
||||
const namespace = 'https://daml.com/ledger-api';
|
||||
context.accessToken[namespace] = {
|
||||
// NOTE change the ledger ID to match your deployment.
|
||||
"ledgerId": "2D105384-CE61-4CCC-8E0E-37248BA935A3",
|
||||
"actAs": actAs,
|
||||
"readAs": readAs,
|
||||
"admin": admin
|
||||
};
|
||||
|
||||
return callback(null, user, context);
|
||||
}
|
||||
```
|
||||
* Create a new user.
|
||||
- Provide an email address (`alice@localhost`)
|
||||
- Provide a secure password
|
||||
- Mark the email address as verified on the user's "Details" page.
|
||||
|
||||
### Testing
|
||||
|
||||
* Start the middleware by executing the following command.
|
||||
```
|
||||
$ DAML_CLIENT_ID=CLIENTID \
|
||||
DAML_CLIENT_SECRET=CLIENTSECRET \
|
||||
bazel run //triggers/service/auth:oauth-middleware-binary -- \
|
||||
--config oauth-middleware.conf
|
||||
```
|
||||
- Replace `CLIENTID` and `CLIENTSECRET` by the "Client ID" and "Client
|
||||
Secret" from above.
|
||||
|
||||
The basic minimal config that needs to be supplied needs to have appropriate
|
||||
`callback-uri`,`oauth-auth` and `oauth-token` urls defined,
|
||||
along with the `token-verifier`,`client-id` and `client-secret` fields. e.g
|
||||
```
|
||||
{
|
||||
callback-uri = "https://example.com/auth/cb"
|
||||
oauth-auth = "https://XYZ.auth0.com/authorize"
|
||||
oauth-token = "https://XYZ.auth0.com/oauth/token"
|
||||
|
||||
client-id = ${DAML_CLIENT_ID}
|
||||
client-secret = ${DAML_CLIENT_SECRET}
|
||||
|
||||
// type can be one of rs256-crt, es256-crt, es512-crt, rs256-jwks
|
||||
// uri is the uri to the cert file or the jwks url
|
||||
token-verifier {
|
||||
type = "rs256-jwks"
|
||||
uri = "https://example.com/.well-known/jwks.json"
|
||||
}
|
||||
}
|
||||
```
|
||||
- Browse to the middleware's login endpoint.
|
||||
- URL `http://localhost:3000/login?redirect_uri=callback&claims=actAs:Alice`
|
||||
- Login as the new user created above.
|
||||
- Authorize the middleware application to access the tenant.
|
||||
- You should be redirected to `http://localhost:3000/callback`.
|
@ -1,171 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import java.nio.file.Path
|
||||
import java.util.UUID
|
||||
|
||||
import org.apache.pekko.http.scaladsl.model.Uri
|
||||
import com.daml.auth.middleware.api.Request
|
||||
import com.daml.auth.middleware.api.Tagged.RefreshToken
|
||||
|
||||
import scala.collection.concurrent.TrieMap
|
||||
import scala.io.{BufferedSource, Source}
|
||||
import scala.util.Try
|
||||
|
||||
private[oauth2] class RequestTemplates(
|
||||
clientId: String,
|
||||
clientSecret: SecretString,
|
||||
authTemplate: Option[Path],
|
||||
tokenTemplate: Option[Path],
|
||||
refreshTemplate: Option[Path],
|
||||
) {
|
||||
|
||||
private val authResourcePath: String = "auth0_request_authorization.jsonnet"
|
||||
private val tokenResourcePath: String = "auth0_request_token.jsonnet"
|
||||
private val refreshResourcePath: String = "auth0_request_refresh.jsonnet"
|
||||
|
||||
/** Load a Jsonnet source file.
|
||||
* @param optFilePath Load from this file path, if provided.
|
||||
* @param resourcePath Load from this JAR resource, if no file is provided.
|
||||
* @return Content and file path (for error reporting) of the loaded Jsonnet file.
|
||||
*/
|
||||
private def jsonnetSource(
|
||||
optFilePath: Option[Path],
|
||||
resourcePath: String,
|
||||
): (String, sjsonnet.Path) = {
|
||||
def readSource(source: BufferedSource): String = {
|
||||
try { source.mkString }
|
||||
finally { source.close() }
|
||||
}
|
||||
optFilePath match {
|
||||
case Some(filePath) =>
|
||||
val content: String = readSource(Source.fromFile(filePath.toString))
|
||||
val path: sjsonnet.Path = sjsonnet.OsPath(os.Path(filePath.toAbsolutePath))
|
||||
(content, path)
|
||||
case None =>
|
||||
val resource = getClass.getResource(resourcePath)
|
||||
val content: String = readSource(Source.fromInputStream(resource.openStream()))
|
||||
// This path is only used for error reporting and a builtin template should not raise any errors.
|
||||
// However, if it does it should be clear that the path refers to a builtin file.
|
||||
// Paths are reported relative to `$PWD`, we prefix `$PWD` to avoid `../../` noise.
|
||||
val path: sjsonnet.Path =
|
||||
sjsonnet.OsPath(os.RelPath(s"BUILTIN/$resourcePath").resolveFrom(os.pwd))
|
||||
(content, path)
|
||||
}
|
||||
}
|
||||
|
||||
private val jsonnetParseCache
|
||||
: TrieMap[String, fastparse.Parsed[(sjsonnet.Expr, Map[String, Int])]] = TrieMap.empty
|
||||
|
||||
/** Interpret the given Jsonnet code.
|
||||
* @param source The Jsonnet source code.
|
||||
* @param sourcePath The Jsonnet source file path (for error reporting).
|
||||
* @param arguments Top-level arguments to pass to the Jsonnet code.
|
||||
* @return The resulting JSON value.
|
||||
*/
|
||||
private def interpretJsonnet(
|
||||
source: String,
|
||||
sourcePath: sjsonnet.Path,
|
||||
arguments: Map[String, ujson.Value],
|
||||
): Try[ujson.Value] = {
|
||||
val interp = new sjsonnet.Interpreter(
|
||||
jsonnetParseCache,
|
||||
Map(),
|
||||
arguments,
|
||||
sjsonnet.OsPath(os.pwd),
|
||||
importer = sjsonnet.SjsonnetMain.resolveImport(Nil, None),
|
||||
)
|
||||
interp
|
||||
.interpret(source, sourcePath)
|
||||
.left
|
||||
.map(new RequestTemplates.InterpretTemplateException(_))
|
||||
.toTry
|
||||
}
|
||||
|
||||
/** Convert a JSON value to a string mapping representing request parameters.
|
||||
*/
|
||||
private def toRequestParams(value: ujson.Value): Try[Map[String, String]] =
|
||||
Try(value.obj.view.mapValues(_.str).toMap)
|
||||
|
||||
private def createRequest(
|
||||
template: (String, sjsonnet.Path),
|
||||
args: Map[String, ujson.Value],
|
||||
): Try[Map[String, String]] = {
|
||||
val (jsonnet_src, jsonnet_path) = template
|
||||
interpretJsonnet(jsonnet_src, jsonnet_path, args).flatMap(toRequestParams)
|
||||
}
|
||||
|
||||
private lazy val config: ujson.Value = ujson.Obj(
|
||||
"clientId" -> clientId,
|
||||
"clientSecret" -> clientSecret.value,
|
||||
)
|
||||
|
||||
private lazy val authJsonnetSource: (String, sjsonnet.Path) =
|
||||
jsonnetSource(authTemplate, authResourcePath)
|
||||
private def authArguments(
|
||||
claims: Request.Claims,
|
||||
requestId: UUID,
|
||||
redirectUri: Uri,
|
||||
): Map[String, ujson.Value] =
|
||||
Map(
|
||||
"config" -> config,
|
||||
"request" -> ujson.Obj(
|
||||
"claims" -> ujson.Obj(
|
||||
"admin" -> claims.admin,
|
||||
"applicationId" -> (claims.applicationId match {
|
||||
case Some(appId) => appId
|
||||
case None => ujson.Null
|
||||
}),
|
||||
"actAs" -> claims.actAs,
|
||||
"readAs" -> claims.readAs,
|
||||
),
|
||||
"redirectUri" -> redirectUri.toString,
|
||||
"state" -> requestId.toString,
|
||||
),
|
||||
)
|
||||
def createAuthRequest(
|
||||
claims: Request.Claims,
|
||||
requestId: UUID,
|
||||
redirectUri: Uri,
|
||||
): Try[Map[String, String]] = {
|
||||
createRequest(authJsonnetSource, authArguments(claims, requestId, redirectUri))
|
||||
}
|
||||
|
||||
private lazy val tokenJsonnetSource: (String, sjsonnet.Path) =
|
||||
jsonnetSource(tokenTemplate, tokenResourcePath)
|
||||
private def tokenArguments(code: String, redirectUri: Uri): Map[String, ujson.Value] = Map(
|
||||
"config" -> config,
|
||||
"request" -> ujson.Obj(
|
||||
"code" -> code,
|
||||
"redirectUri" -> redirectUri.toString,
|
||||
),
|
||||
)
|
||||
def createTokenRequest(code: String, redirectUri: Uri): Try[Map[String, String]] =
|
||||
createRequest(tokenJsonnetSource, tokenArguments(code, redirectUri))
|
||||
|
||||
private lazy val refreshJsonnetSource: (String, sjsonnet.Path) =
|
||||
jsonnetSource(refreshTemplate, refreshResourcePath)
|
||||
private def refreshArguments(refreshToken: RefreshToken): Map[String, ujson.Value] = Map(
|
||||
"config" -> config,
|
||||
"request" -> ujson.Obj(
|
||||
"refreshToken" -> RefreshToken.unwrap(refreshToken)
|
||||
),
|
||||
)
|
||||
def createRefreshRequest(refreshToken: RefreshToken): Try[Map[String, String]] =
|
||||
createRequest(refreshJsonnetSource, refreshArguments(refreshToken))
|
||||
}
|
||||
|
||||
object RequestTemplates {
|
||||
class InterpretTemplateException(msg: String) extends RuntimeException(msg)
|
||||
|
||||
def apply(
|
||||
clientId: String,
|
||||
clientSecret: SecretString,
|
||||
authTemplate: Option[Path],
|
||||
tokenTemplate: Option[Path],
|
||||
refreshTemplate: Option[Path],
|
||||
): RequestTemplates =
|
||||
new RequestTemplates(clientId, clientSecret, authTemplate, tokenTemplate, refreshTemplate)
|
||||
}
|
@ -1,395 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import org.apache.pekko.Done
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import org.apache.pekko.http.scaladsl.Http
|
||||
import org.apache.pekko.http.scaladsl.Http.ServerBinding
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import org.apache.pekko.http.scaladsl.model._
|
||||
import org.apache.pekko.http.scaladsl.model.headers.{HttpCookie, HttpCookiePair}
|
||||
import org.apache.pekko.http.scaladsl.server.{Directive1, Route}
|
||||
import org.apache.pekko.http.scaladsl.server.Directives._
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
|
||||
import com.daml.auth.oauth2.api.{JsonProtocol => OAuthJsonProtocol, Response => OAuthResponse}
|
||||
import com.daml.ledger.api.{auth => lapiauth}
|
||||
import com.daml.ledger.resources.ResourceContext
|
||||
import com.daml.metrics.api.reporters.MetricsReporting
|
||||
import com.daml.metrics.pekkohttp.HttpMetricsInterceptor
|
||||
import com.typesafe.scalalogging.StrictLogging
|
||||
|
||||
import java.util.UUID
|
||||
import com.daml.auth.middleware.api.{Request, RequestStore, Response}
|
||||
import com.daml.jwt.{JwtDecoder, JwtVerifierBase}
|
||||
import com.daml.jwt.domain.Jwt
|
||||
import com.daml.auth.middleware.api.Tagged.{AccessToken, RefreshToken}
|
||||
import com.daml.ports.{Port, PortFiles}
|
||||
import scalaz.{-\/, \/-}
|
||||
import spray.json._
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.language.postfixOps
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
// This is an implementation of the trigger service auth middleware
|
||||
// for OAuth2 as specified in `/triggers/service/authentication.md`
|
||||
class Server(config: Config) extends StrictLogging {
|
||||
import com.daml.auth.middleware.api.JsonProtocol._
|
||||
import com.daml.auth.oauth2.api.JsonProtocol._
|
||||
import Server.rightsProvideClaims
|
||||
|
||||
implicit private val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
|
||||
|
||||
private def toRedirectUri(uri: Uri) =
|
||||
config.callbackUri.getOrElse {
|
||||
Uri()
|
||||
.withScheme(uri.scheme)
|
||||
.withAuthority(uri.authority)
|
||||
.withPath(Uri.Path./("cb"))
|
||||
}
|
||||
|
||||
private val cookieName = "daml-ledger-token"
|
||||
|
||||
private def optionalToken: Directive1[Option[OAuthResponse.Token]] = {
|
||||
def f(x: HttpCookiePair) = OAuthResponse.Token.fromCookieValue(x.value)
|
||||
optionalCookie(cookieName).map(_.flatMap(f))
|
||||
}
|
||||
|
||||
// Check whether the provided token's signature is valid.
|
||||
private def tokenIsValid(accessToken: String, verifier: JwtVerifierBase): Boolean = {
|
||||
verifier.verify(Jwt(accessToken)).isRight
|
||||
}
|
||||
|
||||
// Check whether the provided access token grants at least the requested claims.
|
||||
private def tokenProvidesClaims(accessToken: String, claims: Request.Claims): Boolean = {
|
||||
for {
|
||||
decodedJwt <- JwtDecoder.decode(Jwt(accessToken)).toOption
|
||||
tokenPayload <- lapiauth.AuthServiceJWTCodec
|
||||
.readFromString(decodedJwt.payload)
|
||||
.toOption
|
||||
} yield rightsProvideClaims(tokenPayload, claims)
|
||||
} getOrElse false
|
||||
|
||||
private val requestTemplates: RequestTemplates = RequestTemplates(
|
||||
config.clientId,
|
||||
config.clientSecret,
|
||||
config.oauthAuthTemplate,
|
||||
config.oauthTokenTemplate,
|
||||
config.oauthRefreshTemplate,
|
||||
)
|
||||
|
||||
private def onTemplateSuccess(
|
||||
request: String,
|
||||
tryParams: Try[Map[String, String]],
|
||||
): Directive1[Map[String, String]] =
|
||||
tryParams match {
|
||||
case Failure(exception) =>
|
||||
logger.error(s"Failed to interpret $request request template: ${exception.getMessage}")
|
||||
complete(StatusCodes.InternalServerError, s"Failed to construct $request request")
|
||||
case Success(params) => provide(params)
|
||||
}
|
||||
|
||||
private val auth: Route =
|
||||
parameters(Symbol("claims").as[Request.Claims])
|
||||
.as[Request.Auth](Request.Auth) { auth =>
|
||||
optionalToken {
|
||||
case Some(token)
|
||||
if tokenIsValid(token.accessToken, config.tokenVerifier) &&
|
||||
tokenProvidesClaims(token.accessToken, auth.claims) =>
|
||||
complete(
|
||||
Response
|
||||
.Authorize(
|
||||
accessToken = AccessToken(token.accessToken),
|
||||
refreshToken = RefreshToken.subst(token.refreshToken),
|
||||
)
|
||||
)
|
||||
// TODO[AH] Include a `WWW-Authenticate` header.
|
||||
case _ => complete(StatusCodes.Unauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
private val requests: RequestStore[UUID, Option[Uri]] =
|
||||
new RequestStore(config.maxLoginRequests, config.loginTimeout)
|
||||
|
||||
private val login: Route =
|
||||
parameters(
|
||||
Symbol("redirect_uri").as[Uri] ?,
|
||||
Symbol("claims").as[Request.Claims],
|
||||
Symbol("state") ?,
|
||||
)
|
||||
.as[Request.Login](Request.Login) { login =>
|
||||
extractRequest { request =>
|
||||
val requestId = UUID.randomUUID
|
||||
val stored = requests.put(
|
||||
requestId,
|
||||
login.redirectUri.map { redirectUri =>
|
||||
var query = redirectUri.query().to(Seq)
|
||||
login.state.foreach(x => query ++= Seq("state" -> x))
|
||||
redirectUri.withQuery(Uri.Query(query: _*))
|
||||
},
|
||||
)
|
||||
if (stored) {
|
||||
onTemplateSuccess(
|
||||
"authorization",
|
||||
requestTemplates.createAuthRequest(
|
||||
login.claims,
|
||||
requestId,
|
||||
toRedirectUri(request.uri),
|
||||
),
|
||||
) { params =>
|
||||
val query = Uri.Query(params)
|
||||
val uri = config.oauthAuth.withQuery(query)
|
||||
redirect(uri, StatusCodes.Found)
|
||||
}
|
||||
} else {
|
||||
complete(StatusCodes.ServiceUnavailable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val loginCallback: Route = {
|
||||
extractActorSystem { implicit sys =>
|
||||
extractExecutionContext { implicit ec =>
|
||||
def popRequest(optState: Option[String]): Directive1[Option[Uri]] = {
|
||||
val redirectUri = for {
|
||||
state <- optState
|
||||
requestId <- Try(UUID.fromString(state)).toOption
|
||||
redirectUri <- requests.pop(requestId)
|
||||
} yield redirectUri
|
||||
redirectUri match {
|
||||
case Some(redirectUri) => provide(redirectUri)
|
||||
case None => complete(StatusCodes.NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
concat(
|
||||
parameters(Symbol("code"), Symbol("state") ?)
|
||||
.as[OAuthResponse.Authorize](OAuthResponse.Authorize) { authorize =>
|
||||
popRequest(authorize.state) { redirectUri =>
|
||||
extractRequest { request =>
|
||||
onTemplateSuccess(
|
||||
"token",
|
||||
requestTemplates.createTokenRequest(authorize.code, toRedirectUri(request.uri)),
|
||||
) { params =>
|
||||
val entity = FormData(params).toEntity
|
||||
val req = HttpRequest(
|
||||
uri = config.oauthToken,
|
||||
entity = entity,
|
||||
method = HttpMethods.POST,
|
||||
)
|
||||
val tokenRequest =
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
tokenResp <-
|
||||
if (resp.status != StatusCodes.OK) {
|
||||
Unmarshal(resp).to[String].flatMap { msg =>
|
||||
Future.failed(
|
||||
new RuntimeException(
|
||||
s"Failed to retrieve token at ${req.uri} (${resp.status}): $msg"
|
||||
)
|
||||
)
|
||||
}
|
||||
} else {
|
||||
Unmarshal(resp).to[OAuthResponse.Token]
|
||||
}
|
||||
} yield tokenResp
|
||||
onSuccess(tokenRequest) { token =>
|
||||
setCookie(
|
||||
HttpCookie(
|
||||
name = cookieName,
|
||||
value = token.toCookieValue,
|
||||
path = Some("/"),
|
||||
maxAge = token.expiresIn.map(_.toLong),
|
||||
secure = config.cookieSecure,
|
||||
httpOnly = true,
|
||||
)
|
||||
) {
|
||||
redirectUri match {
|
||||
case Some(uri) =>
|
||||
redirect(uri, StatusCodes.Found)
|
||||
case None =>
|
||||
complete(StatusCodes.OK)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
parameters(
|
||||
Symbol("error"),
|
||||
Symbol("error_description") ?,
|
||||
Symbol("error_uri").as[Uri] ?,
|
||||
Symbol("state") ?,
|
||||
)
|
||||
.as[OAuthResponse.Error](OAuthResponse.Error) { error =>
|
||||
popRequest(error.state) {
|
||||
case Some(redirectUri) =>
|
||||
val uri = redirectUri.withQuery {
|
||||
var params = redirectUri.query().to(Seq)
|
||||
params ++= Seq("error" -> error.error)
|
||||
error.errorDescription.foreach(x => params ++= Seq("error_description" -> x))
|
||||
Uri.Query(params: _*)
|
||||
}
|
||||
redirect(uri, StatusCodes.Found)
|
||||
case None =>
|
||||
import OAuthJsonProtocol.errorRespFormat
|
||||
complete(StatusCodes.Forbidden, error)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val refresh: Route = {
|
||||
extractActorSystem { implicit sys =>
|
||||
extractExecutionContext { implicit ec =>
|
||||
entity(as[Request.Refresh]) { refresh =>
|
||||
onTemplateSuccess(
|
||||
"refresh",
|
||||
requestTemplates.createRefreshRequest(refresh.refreshToken),
|
||||
) { params =>
|
||||
val entity = FormData(params).toEntity
|
||||
val req =
|
||||
HttpRequest(uri = config.oauthToken, entity = entity, method = HttpMethods.POST)
|
||||
val tokenRequest = Http().singleRequest(req)
|
||||
onSuccess(tokenRequest) { resp =>
|
||||
resp.status match {
|
||||
// Return access and refresh token on success.
|
||||
case StatusCodes.OK =>
|
||||
val authResponse = Unmarshal(resp).to[OAuthResponse.Token].map { token =>
|
||||
Response.Authorize(
|
||||
accessToken = AccessToken(token.accessToken),
|
||||
refreshToken = RefreshToken.subst(token.refreshToken),
|
||||
)
|
||||
}
|
||||
complete(authResponse)
|
||||
// Forward client errors.
|
||||
case status: StatusCodes.ClientError =>
|
||||
complete(HttpResponse.apply(status = status, entity = resp.entity))
|
||||
// Fail on unexpected responses.
|
||||
case _ =>
|
||||
onSuccess(Unmarshal(resp).to[String]) { msg =>
|
||||
failWith(
|
||||
new RuntimeException(
|
||||
s"Failed to retrieve refresh token (${resp.status}): $msg"
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def route: Route = concat(
|
||||
path("auth") {
|
||||
get {
|
||||
auth
|
||||
}
|
||||
},
|
||||
path("login") {
|
||||
get {
|
||||
login
|
||||
}
|
||||
},
|
||||
path("cb") {
|
||||
get {
|
||||
loginCallback
|
||||
}
|
||||
},
|
||||
path("refresh") {
|
||||
post {
|
||||
refresh
|
||||
}
|
||||
},
|
||||
path("livez") {
|
||||
complete(StatusCodes.OK, JsObject("status" -> JsString("pass")))
|
||||
},
|
||||
path("readyz") {
|
||||
complete(StatusCodes.OK, JsObject("status" -> JsString("pass")))
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
object Server extends StrictLogging {
|
||||
def start(config: Config, registerGlobalOpenTelemetry: Boolean)(implicit
|
||||
sys: ActorSystem
|
||||
): Future[ServerBinding] = {
|
||||
implicit val ec: ExecutionContext = sys.getDispatcher
|
||||
|
||||
implicit val rc: ResourceContext = ResourceContext(ec)
|
||||
|
||||
val metricsReporting = new MetricsReporting(
|
||||
getClass.getName,
|
||||
config.metricsReporter,
|
||||
config.metricsReportingInterval,
|
||||
registerGlobalOpenTelemetry,
|
||||
config.histograms,
|
||||
)((_, otelMeter) => Oauth2MiddlewareMetrics(otelMeter))
|
||||
val metricsResource = metricsReporting.acquire()
|
||||
|
||||
val rateDurationSizeMetrics = metricsResource.asFuture.map { implicit metrics =>
|
||||
HttpMetricsInterceptor.rateDurationSizeMetrics(
|
||||
metrics.http
|
||||
)
|
||||
}
|
||||
|
||||
val route = new Server(config).route
|
||||
|
||||
for {
|
||||
metricsInterceptor <- rateDurationSizeMetrics
|
||||
binding <- Http()
|
||||
.newServerAt(config.address, config.port)
|
||||
.bind(metricsInterceptor apply route)
|
||||
_ <- config.portFile match {
|
||||
case Some(portFile) =>
|
||||
PortFiles.write(portFile, Port(binding.localAddress.getPort)) match {
|
||||
case -\/(err) =>
|
||||
Future.failed(new RuntimeException(s"Failed to create port file: ${err.toString}"))
|
||||
case \/-(()) => Future.successful(())
|
||||
}
|
||||
case None => Future.successful(())
|
||||
}
|
||||
} yield binding
|
||||
}
|
||||
|
||||
def stop(f: Future[ServerBinding])(implicit ec: ExecutionContext): Future[Done] =
|
||||
f.flatMap(_.unbind())
|
||||
|
||||
private[oauth2] def rightsProvideClaims(
|
||||
r: lapiauth.AuthServiceJWTPayload,
|
||||
claims: Request.Claims,
|
||||
): Boolean = {
|
||||
val (precond, userId) = r match {
|
||||
case tp: lapiauth.CustomDamlJWTPayload =>
|
||||
(
|
||||
(tp.admin || !claims.admin) &&
|
||||
claims.actAs
|
||||
.toSet[String]
|
||||
.subsetOf(tp.actAs.toSet) &&
|
||||
claims.readAs
|
||||
.toSet[String]
|
||||
.subsetOf(tp.readAs.toSet ++ tp.actAs),
|
||||
tp.applicationId,
|
||||
)
|
||||
case tp: lapiauth.StandardJWTPayload =>
|
||||
// NB: in this mode we check the applicationId claim (if supplied)
|
||||
// and ignore everything else
|
||||
(true, Some(tp.userId))
|
||||
}
|
||||
precond && ((claims.applicationId, userId) match {
|
||||
// No requirement on app id
|
||||
case (None, _) => true
|
||||
// Token valid for all app ids.
|
||||
case (_, None) => true
|
||||
case (Some(expectedAppId), Some(actualAppId)) => expectedAppId == actualAppId
|
||||
})
|
||||
}
|
||||
|
||||
}
|
@ -1,190 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.api
|
||||
|
||||
import java.util.Base64
|
||||
|
||||
import org.apache.pekko.http.scaladsl.model.{FormData, HttpEntity, RequestEntity, Uri}
|
||||
import org.apache.pekko.http.scaladsl.model.Uri.Query
|
||||
import org.apache.pekko.http.scaladsl.marshalling._
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling._
|
||||
import spray.json._
|
||||
|
||||
import scala.util.Try
|
||||
|
||||
object Request {
|
||||
|
||||
// https://tools.ietf.org/html/rfc6749#section-4.1.1
|
||||
case class Authorize(
|
||||
responseType: String,
|
||||
clientId: String,
|
||||
redirectUri: Uri, // optional in oauth but we require it
|
||||
scope: Option[String],
|
||||
state: Option[String],
|
||||
audience: Option[Uri],
|
||||
) { // required by auth0 to obtain an access_token
|
||||
def toQuery: Query = {
|
||||
var params: Seq[(String, String)] =
|
||||
Seq(
|
||||
("response_type", responseType),
|
||||
("client_id", clientId),
|
||||
("redirect_uri", redirectUri.toString),
|
||||
)
|
||||
scope.foreach { scope =>
|
||||
params ++= Seq(("scope", scope))
|
||||
}
|
||||
state.foreach { state =>
|
||||
params ++= Seq(("state", state))
|
||||
}
|
||||
audience.foreach { audience =>
|
||||
params ++= Seq(("audience", audience.toString))
|
||||
}
|
||||
Query(params: _*)
|
||||
}
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc6749#section-4.1.3
|
||||
case class Token(
|
||||
grantType: String,
|
||||
code: String,
|
||||
redirectUri: Uri,
|
||||
clientId: String,
|
||||
clientSecret: String,
|
||||
)
|
||||
|
||||
object Token {
|
||||
implicit val marshalRequestEntity: Marshaller[Token, RequestEntity] =
|
||||
Marshaller.combined { token =>
|
||||
FormData(
|
||||
"grant_type" -> token.grantType,
|
||||
"code" -> token.code,
|
||||
"redirect_uri" -> token.redirectUri.toString,
|
||||
"client_id" -> token.clientId,
|
||||
"client_secret" -> token.clientSecret,
|
||||
)
|
||||
}
|
||||
implicit val unmarshalHttpEntity: Unmarshaller[HttpEntity, Token] =
|
||||
Unmarshaller.defaultUrlEncodedFormDataUnmarshaller.map { form =>
|
||||
Token(
|
||||
grantType = form.fields.get("grant_type").get,
|
||||
code = form.fields.get("code").get,
|
||||
redirectUri = form.fields.get("redirect_uri").get,
|
||||
clientId = form.fields.get("client_id").get,
|
||||
clientSecret = form.fields.get("client_secret").get,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc6749#section-6
|
||||
case class Refresh(
|
||||
grantType: String,
|
||||
refreshToken: String,
|
||||
clientId: String,
|
||||
clientSecret: String,
|
||||
)
|
||||
|
||||
object Refresh {
|
||||
implicit val marshalRequestEntity: Marshaller[Refresh, RequestEntity] =
|
||||
Marshaller.combined { refresh =>
|
||||
FormData(
|
||||
"grant_type" -> refresh.grantType,
|
||||
"refresh_token" -> refresh.refreshToken,
|
||||
"client_id" -> refresh.clientId,
|
||||
"client_secret" -> refresh.clientSecret,
|
||||
)
|
||||
}
|
||||
implicit val unmarshalHttpEntity: Unmarshaller[HttpEntity, Refresh] =
|
||||
Unmarshaller.defaultUrlEncodedFormDataUnmarshaller.map { form =>
|
||||
Refresh(
|
||||
grantType = form.fields.get("grant_type").get,
|
||||
refreshToken = form.fields.get("refresh_token").get,
|
||||
clientId = form.fields.get("client_id").get,
|
||||
clientSecret = form.fields.get("client_secret").get,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object Response {
|
||||
|
||||
// https://tools.ietf.org/html/rfc6749#section-4.1.2
|
||||
case class Authorize(code: String, state: Option[String]) {
|
||||
def toQuery: Query = state match {
|
||||
case None => Query(("code", code))
|
||||
case Some(state) => Query(("code", code), ("state", state))
|
||||
}
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc6749#section-4.1.2.1
|
||||
case class Error(
|
||||
error: String,
|
||||
errorDescription: Option[String],
|
||||
errorUri: Option[Uri],
|
||||
state: Option[String],
|
||||
) {
|
||||
def toQuery: Query = {
|
||||
var params: Seq[(String, String)] = Seq("error" -> error)
|
||||
errorDescription.foreach(x => params ++= Seq("error_description" -> x))
|
||||
errorUri.foreach(x => params ++= Seq("error_uri" -> x.toString))
|
||||
state.foreach(x => params ++= Seq("state" -> x))
|
||||
Query(params: _*)
|
||||
}
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc6749#section-5.1
|
||||
case class Token(
|
||||
accessToken: String,
|
||||
tokenType: String,
|
||||
expiresIn: Option[Int],
|
||||
refreshToken: Option[String],
|
||||
scope: Option[String],
|
||||
) {
|
||||
def toCookieValue: String = {
|
||||
import JsonProtocol._
|
||||
Base64.getUrlEncoder().encodeToString(this.toJson.compactPrint.getBytes)
|
||||
}
|
||||
}
|
||||
|
||||
object Token {
|
||||
def fromCookieValue(s: String): Option[Token] = {
|
||||
import JsonProtocol._
|
||||
for {
|
||||
bytes <- Try(Base64.getUrlDecoder().decode(s))
|
||||
json <- Try(new String(bytes).parseJson)
|
||||
token <- Try(json.convertTo[Token])
|
||||
} yield token
|
||||
}.toOption
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object JsonProtocol extends DefaultJsonProtocol {
|
||||
implicit object UriFormat extends JsonFormat[Uri] {
|
||||
def read(value: JsValue) = value match {
|
||||
case JsString(s) => Uri(s)
|
||||
case _ => deserializationError(s"Expected Uri string but got $value")
|
||||
}
|
||||
def write(uri: Uri) = JsString(uri.toString)
|
||||
}
|
||||
implicit val tokenRespFormat: RootJsonFormat[Response.Token] = {
|
||||
jsonFormat(
|
||||
Response.Token.apply,
|
||||
"access_token",
|
||||
"token_type",
|
||||
"expires_in",
|
||||
"refresh_token",
|
||||
"scope",
|
||||
)
|
||||
}
|
||||
implicit val errorRespFormat: RootJsonFormat[Response.Error] = {
|
||||
jsonFormat(
|
||||
Response.Error.apply,
|
||||
"error",
|
||||
"error_description",
|
||||
"error_uri",
|
||||
"state",
|
||||
)
|
||||
}
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.test.server
|
||||
|
||||
import java.time.Clock
|
||||
|
||||
import com.daml.ports.Port
|
||||
|
||||
case class Config(
|
||||
// Port the authorization server listens on
|
||||
port: Port,
|
||||
// Ledger ID of issued tokens
|
||||
ledgerId: String,
|
||||
// Secret used to sign JWTs
|
||||
jwtSecret: String,
|
||||
// Use the provided clock instead of system time for token generation.
|
||||
clock: Option[Clock],
|
||||
// produce user tokens instead of claim tokens
|
||||
yieldUserTokens: Boolean,
|
||||
)
|
||||
|
||||
object Config {
|
||||
private val Empty =
|
||||
Config(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = null,
|
||||
jwtSecret = null,
|
||||
clock = None,
|
||||
yieldUserTokens = false,
|
||||
)
|
||||
|
||||
def parseConfig(args: collection.Seq[String]): Option[Config] =
|
||||
configParser.parse(args, Empty)
|
||||
|
||||
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
|
||||
val configParser: scopt.OptionParser[Config] =
|
||||
new scopt.OptionParser[Config]("oauth-test-server") {
|
||||
head("OAuth2 TestServer")
|
||||
|
||||
opt[Int]("port")
|
||||
.action((x, c) => c.copy(port = Port(x)))
|
||||
.required()
|
||||
.text("Port to listen on")
|
||||
|
||||
opt[String]("ledger-id")
|
||||
.action((x, c) => c.copy(ledgerId = x))
|
||||
|
||||
opt[String]("secret")
|
||||
.action((x, c) => c.copy(jwtSecret = x))
|
||||
|
||||
opt[Unit]("yield-user-tokens")
|
||||
.optional()
|
||||
.action((_, c) => c.copy(yieldUserTokens = true))
|
||||
|
||||
help("help").text("Print this usage text")
|
||||
}
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.test.server
|
||||
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import com.typesafe.scalalogging.StrictLogging
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.{Await, ExecutionContext}
|
||||
import scala.util.{Failure, Success}
|
||||
|
||||
object Main extends StrictLogging {
|
||||
def main(args: Array[String]): Unit = {
|
||||
Config.parseConfig(args) match {
|
||||
case Some(config) => main(config)
|
||||
case None => sys.exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
private def main(config: Config): Unit = {
|
||||
implicit val system: ActorSystem = ActorSystem("system")
|
||||
implicit val executionContext: ExecutionContext = system.dispatcher
|
||||
|
||||
def terminate() = Await.result(system.terminate(), 10.seconds)
|
||||
|
||||
val bindingFuture = Server(config).start()
|
||||
|
||||
sys.addShutdownHook {
|
||||
Server
|
||||
.stop(bindingFuture)
|
||||
.onComplete { _ =>
|
||||
terminate()
|
||||
}
|
||||
}
|
||||
|
||||
bindingFuture.onComplete {
|
||||
case Success(binding) =>
|
||||
logger.info(s"Started server: $binding")
|
||||
case Failure(e) =>
|
||||
logger.error(s"Failed to start server: $e")
|
||||
terminate()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,249 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.test.server
|
||||
|
||||
import java.time.Instant
|
||||
import java.util.UUID
|
||||
import org.apache.pekko.Done
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import org.apache.pekko.http.scaladsl.Http
|
||||
import org.apache.pekko.http.scaladsl.Http.ServerBinding
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import org.apache.pekko.http.scaladsl.model.{StatusCodes, Uri}
|
||||
import org.apache.pekko.http.scaladsl.server.Directives._
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshaller
|
||||
import com.daml.auth.oauth2.api.{Request, Response}
|
||||
import com.daml.jwt.JwtSigner
|
||||
import com.daml.jwt.domain.DecodedJwt
|
||||
import com.daml.ledger.api.auth.{
|
||||
AuthServiceJWTCodec,
|
||||
AuthServiceJWTPayload,
|
||||
CustomDamlJWTPayload,
|
||||
StandardJWTPayload,
|
||||
StandardJWTTokenFormat,
|
||||
}
|
||||
import com.daml.lf.data.Ref
|
||||
|
||||
import scala.collection.concurrent.TrieMap
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.language.postfixOps
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
// This is a test authorization server that implements the OAuth2 authorization code flow.
|
||||
// This is primarily intended for use in the trigger service tests but could also serve
|
||||
// as a useful ground for experimentation.
|
||||
// Given scopes of the form `actAs:$party`, the authorization server will issue
|
||||
// tokens with the respective claims. Requests for authorized parties will be accepted and
|
||||
// request to /authorize are immediately redirected to the redirect_uri.
|
||||
class Server(config: Config) {
|
||||
import Server.withExp
|
||||
|
||||
private val jwtHeader = """{"alg": "HS256", "typ": "JWT"}"""
|
||||
val tokenLifetimeSeconds = 24 * 60 * 60
|
||||
|
||||
private var unauthorizedParties: Set[Ref.Party] = Set()
|
||||
|
||||
// Remove the given party from the set of unauthorized parties.
|
||||
def authorizeParty(party: Ref.Party): Unit = {
|
||||
unauthorizedParties = unauthorizedParties - party
|
||||
}
|
||||
|
||||
// Add the given party to the set of unauthorized parties.
|
||||
def revokeParty(party: Ref.Party): Unit = {
|
||||
unauthorizedParties = unauthorizedParties + party
|
||||
}
|
||||
|
||||
// Clear the set of unauthorized parties.
|
||||
def resetAuthorizedParties(): Unit = {
|
||||
unauthorizedParties = Set()
|
||||
}
|
||||
|
||||
private var allowAdmin = true
|
||||
|
||||
def authorizeAdmin(): Unit = {
|
||||
allowAdmin = true
|
||||
}
|
||||
|
||||
def revokeAdmin(): Unit = {
|
||||
allowAdmin = false
|
||||
}
|
||||
|
||||
def resetAdmin(): Unit = {
|
||||
allowAdmin = true
|
||||
}
|
||||
|
||||
// To keep things as simple as possible, we use a UUID as the authorization code and refresh token
|
||||
// and in the /authorize request we already pre-compute the JWT payload based on the scope.
|
||||
// The token request then only does a lookup and signs the token.
|
||||
private val requests = TrieMap.empty[UUID, AuthServiceJWTPayload]
|
||||
|
||||
private def tokenExpiry(): Instant = {
|
||||
val now = config.clock match {
|
||||
case Some(clock) => Instant.now(clock)
|
||||
case None => Instant.now()
|
||||
}
|
||||
now.plusSeconds(tokenLifetimeSeconds.asInstanceOf[Long])
|
||||
}
|
||||
private def toPayload(req: Request.Authorize): AuthServiceJWTPayload = {
|
||||
var actAs: Seq[String] = Seq()
|
||||
var readAs: Seq[String] = Seq()
|
||||
var admin: Boolean = false
|
||||
var applicationId: Option[String] = None
|
||||
req.scope.foreach(_.split(" ").foreach {
|
||||
case s if s.startsWith("actAs:") => actAs ++= Seq(s.stripPrefix("actAs:"))
|
||||
case s if s.startsWith("readAs:") => readAs ++= Seq(s.stripPrefix("readAs:"))
|
||||
case s if s == "admin" => admin = true
|
||||
// Given that this is only for testing,
|
||||
// we don’t guard against multiple application id claims.
|
||||
case s if s.startsWith("applicationId:") =>
|
||||
applicationId = Some(s.stripPrefix("applicationId:"))
|
||||
case _ => ()
|
||||
})
|
||||
if (config.yieldUserTokens) // ignore everything but the applicationId
|
||||
StandardJWTPayload(
|
||||
issuer = None,
|
||||
userId = applicationId getOrElse "",
|
||||
participantId = None,
|
||||
exp = None,
|
||||
format = StandardJWTTokenFormat.Scope,
|
||||
audiences = List.empty,
|
||||
scope = Some("daml_ledger_api"),
|
||||
)
|
||||
else
|
||||
CustomDamlJWTPayload(
|
||||
ledgerId = Some(config.ledgerId),
|
||||
applicationId = applicationId,
|
||||
// Not required by the default auth service
|
||||
participantId = None,
|
||||
// Expiry is set when the token is retrieved
|
||||
exp = None,
|
||||
// no admin claim for now.
|
||||
admin = admin,
|
||||
actAs = actAs.toList,
|
||||
readAs = readAs.toList,
|
||||
)
|
||||
}
|
||||
// Whether the current configuration of unauthorized parties and admin rights allows to grant the given token payload.
|
||||
private def authorize(payload: AuthServiceJWTPayload): Either[String, Unit] = payload match {
|
||||
case payload: CustomDamlJWTPayload =>
|
||||
val parties = (payload.readAs ++ payload.actAs).toSet
|
||||
val deniedParties = parties & unauthorizedParties.toSet[String]
|
||||
val deniedAdmin: Boolean = payload.admin && !allowAdmin
|
||||
if (deniedParties.nonEmpty) {
|
||||
Left(s"Access to parties ${deniedParties.mkString(" ")} denied")
|
||||
} else if (deniedAdmin) {
|
||||
Left("Admin access denied")
|
||||
} else {
|
||||
Right(())
|
||||
}
|
||||
case _: StandardJWTPayload => Right(())
|
||||
}
|
||||
|
||||
import Request.Refresh.unmarshalHttpEntity
|
||||
implicit val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
|
||||
|
||||
val route = concat(
|
||||
path("authorize") {
|
||||
get {
|
||||
parameters(
|
||||
Symbol("response_type"),
|
||||
Symbol("client_id"),
|
||||
Symbol("redirect_uri").as[Uri],
|
||||
Symbol("scope") ?,
|
||||
Symbol("state") ?,
|
||||
Symbol("audience").as[Uri] ?,
|
||||
)
|
||||
.as[Request.Authorize](Request.Authorize) { request =>
|
||||
val payload = toPayload(request)
|
||||
authorize(payload) match {
|
||||
case Left(msg) =>
|
||||
val params =
|
||||
Response
|
||||
.Error(
|
||||
error = "access_denied",
|
||||
errorDescription = Some(msg),
|
||||
errorUri = None,
|
||||
state = request.state,
|
||||
)
|
||||
.toQuery
|
||||
redirect(request.redirectUri.withQuery(params), StatusCodes.Found)
|
||||
case Right(()) =>
|
||||
val authorizationCode = UUID.randomUUID()
|
||||
val params =
|
||||
Response
|
||||
.Authorize(code = authorizationCode.toString, state = request.state)
|
||||
.toQuery
|
||||
requests.update(authorizationCode, payload)
|
||||
// We skip any actual consent screen since this is only intended for testing and
|
||||
// this is outside of the scope of the trigger service anyway.
|
||||
redirect(request.redirectUri.withQuery(params), StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
path("token") {
|
||||
post {
|
||||
def returnToken(uuid: String) =
|
||||
Try(UUID.fromString(uuid)) match {
|
||||
case Failure(_) =>
|
||||
complete((StatusCodes.BadRequest, "Malformed code or refresh token"))
|
||||
case Success(uuid) =>
|
||||
requests.remove(uuid) match {
|
||||
case Some(payload) =>
|
||||
// Generate refresh token
|
||||
val refreshCode = UUID.randomUUID()
|
||||
requests.update(refreshCode, payload)
|
||||
// Construct access token with expiry
|
||||
val accessToken = JwtSigner.HMAC256
|
||||
.sign(
|
||||
DecodedJwt(
|
||||
jwtHeader,
|
||||
AuthServiceJWTCodec.compactPrint(withExp(payload, Some(tokenExpiry()))),
|
||||
),
|
||||
config.jwtSecret,
|
||||
)
|
||||
.getOrElse(throw new IllegalArgumentException("Failed to sign a token"))
|
||||
.value
|
||||
import com.daml.auth.oauth2.api.JsonProtocol._
|
||||
complete(
|
||||
Response.Token(
|
||||
accessToken = accessToken,
|
||||
refreshToken = Some(refreshCode.toString),
|
||||
expiresIn = Some(tokenLifetimeSeconds),
|
||||
scope = None,
|
||||
tokenType = "bearer",
|
||||
)
|
||||
)
|
||||
case None =>
|
||||
complete(StatusCodes.NotFound)
|
||||
}
|
||||
}
|
||||
concat(
|
||||
entity(as[Request.Token]) { request =>
|
||||
returnToken(request.code)
|
||||
},
|
||||
entity(as[Request.Refresh]) { request =>
|
||||
returnToken(request.refreshToken)
|
||||
},
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def start()(implicit system: ActorSystem): Future[ServerBinding] = {
|
||||
Http().newServerAt("localhost", config.port.value).bind(route)
|
||||
}
|
||||
}
|
||||
|
||||
object Server {
|
||||
def apply(config: Config) = new Server(config)
|
||||
def stop(f: Future[ServerBinding])(implicit ec: ExecutionContext): Future[Done] =
|
||||
f.flatMap(_.unbind())
|
||||
|
||||
private def withExp(payload: AuthServiceJWTPayload, exp: Option[Instant]): AuthServiceJWTPayload =
|
||||
payload match {
|
||||
case payload: CustomDamlJWTPayload => payload.copy(exp = exp)
|
||||
case payload: StandardJWTPayload => payload.copy(exp = exp)
|
||||
}
|
||||
}
|
@ -1,18 +0,0 @@
|
||||
{
|
||||
callback-uri = "https://example.com/auth/cb"
|
||||
oauth-auth = "https://oauth2/uri"
|
||||
oauth-token = "https://oauth2/token"
|
||||
|
||||
// client-id = ${DAML_CLIENT_ID}
|
||||
// client-secret = ${DAML_CLIENT_SECRET}
|
||||
// can be set via env variables , dummy values for test purposes
|
||||
client-id = foo
|
||||
client-secret = bar
|
||||
|
||||
// type can be one of rs256-crt, es256-crt, es512-crt, rs256-jwks
|
||||
// uri is the uri to the cert file or the jwks url
|
||||
token-verifier {
|
||||
type = "rs256-jwks"
|
||||
uri = "https://example.com/.well-known/jwks.json"
|
||||
}
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
{
|
||||
address = "127.0.0.1"
|
||||
port = 3000
|
||||
callback-uri = "https://example.com/auth/cb"
|
||||
max-login-requests = 10
|
||||
login-timeout = 60s
|
||||
cookie-secure = false
|
||||
oauth-auth = "https://oauth2/uri"
|
||||
oauth-token = "https://oauth2/token"
|
||||
|
||||
oauth-auth-template = "auth_template"
|
||||
oauth-token-template = "token_template"
|
||||
oauth-refresh-template = "refresh_template"
|
||||
|
||||
// client-id = ${DAML_CLIENT_ID}
|
||||
// client-secret = ${DAML_CLIENT_SECRET}
|
||||
// can be set via env variables , dummy values for test purposes
|
||||
client-id = foo
|
||||
client-secret = bar
|
||||
|
||||
// type can be one of rs256-crt, es256-crt, es512-crt, rs256-jwks
|
||||
// uri is the uri to the cert file or the jwks url
|
||||
token-verifier {
|
||||
type = "rs256-jwks"
|
||||
uri = "https://example.com/.well-known/jwks.json"
|
||||
}
|
||||
metrics {
|
||||
reporter = "prometheus://0.0.0.0:5104"
|
||||
reporting-interval = 30s
|
||||
}
|
||||
}
|
@ -1,181 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.jwt
|
||||
|
||||
import java.math.BigInteger
|
||||
import java.nio.file.{Files, Path}
|
||||
import java.security.interfaces.{ECPrivateKey, ECPublicKey, RSAPrivateKey, RSAPublicKey}
|
||||
import java.security.{KeyPair, KeyPairGenerator, PrivateKey, PublicKey, Security}
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
import com.auth0.jwt.JWT
|
||||
import com.auth0.jwt.algorithms.Algorithm
|
||||
import com.daml.fs.TemporaryDirectory
|
||||
import com.daml.http.test.SimpleHttpServer
|
||||
import com.daml.scalautil.Statement.discard
|
||||
import com.daml.jwt.JwtVerifierConfigurationCliSpec._
|
||||
import com.daml.ledger.api.auth.ClaimSet.Claims
|
||||
import com.daml.ledger.api.auth.{AuthService, AuthServiceJWT, AuthServiceWildcard, ClaimPublic}
|
||||
import io.grpc.Metadata
|
||||
import org.bouncycastle.asn1.x500.X500Name
|
||||
import org.bouncycastle.cert.jcajce.{JcaX509CertificateConverter, JcaX509v3CertificateBuilder}
|
||||
import org.bouncycastle.jce.provider.BouncyCastleProvider
|
||||
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
|
||||
import org.scalatest.Assertion
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.wordspec.AsyncWordSpec
|
||||
import scopt.OptionParser
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.jdk.FutureConverters._
|
||||
|
||||
class JwtVerifierConfigurationCliSpec extends AsyncWordSpec with Matchers {
|
||||
discard(Security.addProvider(new BouncyCastleProvider))
|
||||
|
||||
"auth command-line parsers" should {
|
||||
"parse and configure the authorisation mechanism correctly when `--auth-jwt-hs256-unsafe <secret>` is passed" in {
|
||||
val secret = "someSecret"
|
||||
val authService = parseConfig(Array("--auth-jwt-hs256-unsafe", secret))
|
||||
val token = JWT.create().sign(Algorithm.HMAC256(secret))
|
||||
val metadata = createAuthMetadata(token)
|
||||
decodeAndCheckMetadata(authService, metadata)
|
||||
}
|
||||
|
||||
"parse and configure the authorisation mechanism correctly when `--auth-jwt-rs256-crt <PK.crt>` is passed" in
|
||||
new TemporaryDirectory(getClass.getSimpleName).use { directory =>
|
||||
val (publicKey, privateKey) = newRsaKeyPair()
|
||||
val certificatePath = newCertificate("SHA256WithRSA", directory, publicKey, privateKey)
|
||||
val token = JWT.create().sign(Algorithm.RSA256(publicKey, privateKey))
|
||||
|
||||
val authService = parseConfig(Array("--auth-jwt-rs256-crt", certificatePath.toString))
|
||||
val metadata = createAuthMetadata(token)
|
||||
decodeAndCheckMetadata(authService, metadata)
|
||||
}
|
||||
|
||||
"parse and configure the authorisation mechanism correctly when `--auth-jwt-es256-crt <PK.crt>` is passed" in
|
||||
new TemporaryDirectory(getClass.getSimpleName).use { directory =>
|
||||
val (publicKey, privateKey) = newEcdsaKeyPair()
|
||||
val certificatePath = newCertificate("SHA256WithECDSA", directory, publicKey, privateKey)
|
||||
val token = JWT.create().sign(Algorithm.ECDSA256(publicKey, privateKey))
|
||||
|
||||
val authService = parseConfig(Array("--auth-jwt-es256-crt", certificatePath.toString))
|
||||
val metadata = createAuthMetadata(token)
|
||||
decodeAndCheckMetadata(authService, metadata)
|
||||
}
|
||||
|
||||
"parse and configure the authorisation mechanism correctly when `--auth-jwt-es512-crt <PK.crt>` is passed" in
|
||||
new TemporaryDirectory(getClass.getSimpleName).use { directory =>
|
||||
val (publicKey, privateKey) = newEcdsaKeyPair()
|
||||
val certificatePath = newCertificate("SHA512WithECDSA", directory, publicKey, privateKey)
|
||||
val token = JWT.create().sign(Algorithm.ECDSA512(publicKey, privateKey))
|
||||
|
||||
val authService = parseConfig(Array("--auth-jwt-es512-crt", certificatePath.toString))
|
||||
val metadata = createAuthMetadata(token)
|
||||
decodeAndCheckMetadata(authService, metadata)
|
||||
}
|
||||
|
||||
"parse and configure the authorisation mechanism correctly when `--auth-jwt-rs256-jwks <URL>` is passed" in {
|
||||
val (publicKey, privateKey) = newRsaKeyPair()
|
||||
val keyId = "test-key-1"
|
||||
val token = JWT.create().withKeyId(keyId).sign(Algorithm.RSA256(publicKey, privateKey))
|
||||
|
||||
// Start a JWKS server and create a verifier using the JWKS server
|
||||
val jwks = KeyUtils.generateJwks(
|
||||
Map(
|
||||
keyId -> publicKey
|
||||
)
|
||||
)
|
||||
|
||||
val server = SimpleHttpServer.start(jwks)
|
||||
Future {
|
||||
val url = SimpleHttpServer.responseUrl(server)
|
||||
val authService = parseConfig(Array("--auth-jwt-rs256-jwks", url))
|
||||
val metadata = createAuthMetadata(token)
|
||||
(authService, metadata)
|
||||
}.flatMap { case (authService, metadata) =>
|
||||
decodeAndCheckMetadata(authService, metadata)
|
||||
}.andThen { case _ =>
|
||||
SimpleHttpServer.stop(server)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object JwtVerifierConfigurationCliSpec {
|
||||
private def parseConfig(args: Array[String]): AuthService = {
|
||||
val parser = new OptionParser[AtomicReference[AuthService]]("test") {}
|
||||
JwtVerifierConfigurationCli.parse(parser) { (verifier, config) =>
|
||||
config.set(AuthServiceJWT(verifier, targetAudience = None, targetScope = None))
|
||||
config
|
||||
}
|
||||
parser.parse(args, new AtomicReference[AuthService](AuthServiceWildcard)).get.get()
|
||||
}
|
||||
|
||||
private def createAuthMetadata(token: String) = {
|
||||
val metadata = new Metadata()
|
||||
metadata.put(AuthService.AUTHORIZATION_KEY, s"Bearer $token")
|
||||
metadata
|
||||
}
|
||||
|
||||
private def decodeAndCheckMetadata(
|
||||
authService: AuthService,
|
||||
metadata: Metadata,
|
||||
)(implicit executionContext: ExecutionContext): Future[Assertion] = {
|
||||
import org.scalatest.Inside._
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
authService.decodeMetadata(metadata).asScala.map { auth =>
|
||||
inside(auth) { case claims: Claims =>
|
||||
claims.claims should be(List(ClaimPublic))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def newRsaKeyPair(): (RSAPublicKey, RSAPrivateKey) = {
|
||||
val keyPair = newKeyPair("RSA", 2048)
|
||||
val publicKey = keyPair.getPublic.asInstanceOf[RSAPublicKey]
|
||||
val privateKey = keyPair.getPrivate.asInstanceOf[RSAPrivateKey]
|
||||
(publicKey, privateKey)
|
||||
}
|
||||
|
||||
private def newEcdsaKeyPair(): (ECPublicKey, ECPrivateKey) = {
|
||||
val keyPair = newKeyPair("ECDSA", 256)
|
||||
val publicKey = keyPair.getPublic.asInstanceOf[ECPublicKey]
|
||||
val privateKey = keyPair.getPrivate.asInstanceOf[ECPrivateKey]
|
||||
(publicKey, privateKey)
|
||||
}
|
||||
|
||||
private def newKeyPair(algorithm: String, keySize: Int): KeyPair = {
|
||||
val generator = KeyPairGenerator.getInstance(algorithm, BouncyCastleProvider.PROVIDER_NAME)
|
||||
generator.initialize(keySize)
|
||||
generator.generateKeyPair()
|
||||
}
|
||||
|
||||
private def newCertificate(
|
||||
signatureAlgorithm: String,
|
||||
directory: Path,
|
||||
publicKey: PublicKey,
|
||||
privateKey: PrivateKey,
|
||||
): Path = {
|
||||
val now = Instant.now()
|
||||
val dnName = new X500Name(s"CN=${getClass.getSimpleName}")
|
||||
val contentSigner = new JcaContentSignerBuilder(signatureAlgorithm).build(privateKey)
|
||||
val certBuilder = new JcaX509v3CertificateBuilder(
|
||||
dnName,
|
||||
BigInteger.valueOf(now.toEpochMilli),
|
||||
java.util.Date.from(now),
|
||||
java.util.Date.from(now.plusSeconds(60)),
|
||||
dnName,
|
||||
publicKey,
|
||||
)
|
||||
val certificate =
|
||||
new JcaX509CertificateConverter()
|
||||
.setProvider(BouncyCastleProvider.PROVIDER_NAME)
|
||||
.getCertificate(certBuilder.build(contentSigner))
|
||||
val certificatePath = directory.resolve("certificate")
|
||||
discard(Files.write(certificatePath, certificate.getEncoded))
|
||||
certificatePath
|
||||
}
|
||||
}
|
@ -1,130 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import org.apache.pekko.http.scaladsl.model.Uri
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.wordspec.AsyncWordSpec
|
||||
import com.daml.bazeltools.BazelRunfiles.requiredResource
|
||||
import com.daml.jwt.JwksVerifier
|
||||
import com.daml.metrics.MetricsConfig
|
||||
import com.daml.metrics.api.reporters.MetricsReporter
|
||||
import org.scalatest.Inside.inside
|
||||
import pureconfig.error.{CannotReadFile, ConfigReaderFailures}
|
||||
|
||||
import java.nio.file.Paths
|
||||
import java.net.InetSocketAddress
|
||||
import scala.concurrent.duration._
|
||||
|
||||
class CliSpec extends AsyncWordSpec with Matchers {
|
||||
val minimalCfg = FileConfig(
|
||||
oauthAuth = Uri("https://oauth2/uri"),
|
||||
oauthToken = Uri("https://oauth2/token"),
|
||||
callbackUri = Some(Uri("https://example.com/auth/cb")),
|
||||
clientId = sys.env.getOrElse("DAML_CLIENT_ID", "foo"),
|
||||
clientSecret = SecretString(sys.env.getOrElse("DAML_CLIENT_SECRET", "bar")),
|
||||
tokenVerifier = null,
|
||||
)
|
||||
val confFile = "triggers/service/auth/src/test/resources/oauth2-middleware.conf"
|
||||
def loadCli(file: String): Cli = {
|
||||
Cli.parse(Array("--config", file)).getOrElse(fail("Could not load Cli on parse"))
|
||||
}
|
||||
|
||||
"should pickup the config file provided" in {
|
||||
val file = requiredResource(confFile)
|
||||
val cli = loadCli(file.getAbsolutePath)
|
||||
cli.configFile should not be empty
|
||||
}
|
||||
|
||||
"should take default values on loading minimal config" in {
|
||||
val file =
|
||||
requiredResource("triggers/service/auth/src/test/resources/oauth2-middleware-minimal.conf")
|
||||
val cli = loadCli(file.getAbsolutePath)
|
||||
cli.configFile should not be empty
|
||||
val cfg = cli.loadFromConfigFile
|
||||
inside(cfg) { case Some(Right(c)) =>
|
||||
c.copy(tokenVerifier = null) shouldBe minimalCfg
|
||||
// token verifier needs to be set.
|
||||
c.tokenVerifier shouldBe a[JwksVerifier]
|
||||
}
|
||||
}
|
||||
|
||||
"should be able to successfully load the config based on the file provided" in {
|
||||
val file = requiredResource(confFile)
|
||||
val cli = loadCli(file.getAbsolutePath)
|
||||
cli.configFile should not be empty
|
||||
val cfg = cli.loadFromConfigFile
|
||||
inside(cfg) { case Some(Right(c)) =>
|
||||
c.copy(tokenVerifier = null) shouldBe minimalCfg.copy(
|
||||
port = 3000,
|
||||
maxLoginRequests = 10,
|
||||
loginTimeout = 60.seconds,
|
||||
cookieSecure = false,
|
||||
oauthAuthTemplate = Some(Paths.get("auth_template")),
|
||||
oauthTokenTemplate = Some(Paths.get("token_template")),
|
||||
oauthRefreshTemplate = Some(Paths.get("refresh_template")),
|
||||
metrics = Some(
|
||||
MetricsConfig(
|
||||
MetricsReporter.Prometheus(new InetSocketAddress("0.0.0.0", 5104)),
|
||||
30.seconds,
|
||||
Seq.empty,
|
||||
)
|
||||
),
|
||||
)
|
||||
// token verifier needs to be set.
|
||||
c.tokenVerifier shouldBe a[JwksVerifier]
|
||||
}
|
||||
}
|
||||
|
||||
"parse should raise error on non-existent config file" in {
|
||||
val cli = loadCli("missingFile.conf")
|
||||
cli.configFile should not be empty
|
||||
val cfg = cli.loadFromConfigFile
|
||||
inside(cfg) { case Some(Left(ConfigReaderFailures(head))) =>
|
||||
head shouldBe a[CannotReadFile]
|
||||
}
|
||||
|
||||
// parseConfig for non-existent file should return a None
|
||||
Cli.parseConfig(
|
||||
Array(
|
||||
"--config",
|
||||
"missingFile.conf",
|
||||
)
|
||||
) shouldBe None
|
||||
}
|
||||
|
||||
"should load config from cli args on missing conf file " in {
|
||||
Cli
|
||||
.parseConfig(
|
||||
Array(
|
||||
"--oauth-auth",
|
||||
"file://foo",
|
||||
"--oauth-token",
|
||||
"file://bar",
|
||||
"--id",
|
||||
"foo",
|
||||
"--secret",
|
||||
"bar",
|
||||
"--auth-jwt-hs256-unsafe",
|
||||
"unsafe",
|
||||
)
|
||||
) should not be empty
|
||||
}
|
||||
|
||||
"should fail to load config from cli args on incomplete cli args" in {
|
||||
Cli
|
||||
.parseConfig(
|
||||
Array(
|
||||
"--oauth-auth",
|
||||
"file://foo",
|
||||
"--id",
|
||||
"foo",
|
||||
"--secret",
|
||||
"bar",
|
||||
"--auth-jwt-hs256-unsafe",
|
||||
"unsafe",
|
||||
)
|
||||
) shouldBe None
|
||||
}
|
||||
}
|
@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import java.io.File
|
||||
import java.nio.file.Files
|
||||
import java.time.{Clock, Duration, Instant, ZoneId}
|
||||
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import org.apache.pekko.http.scaladsl.Http
|
||||
import org.apache.pekko.http.scaladsl.Http.ServerBinding
|
||||
import org.apache.pekko.http.scaladsl.model.{StatusCodes, Uri}
|
||||
import org.apache.pekko.http.scaladsl.server.Directives._
|
||||
import com.daml.auth.middleware.api.Client
|
||||
import com.daml.auth.middleware.api.Request.Claims
|
||||
import com.daml.clock.AdjustableClock
|
||||
import com.daml.ledger.resources.{Resource, ResourceContext, ResourceOwner}
|
||||
import com.daml.auth.oauth2.test.server.{Server => OAuthServer}
|
||||
import com.daml.scalautil.Statement.discard
|
||||
|
||||
import scala.concurrent.Future
|
||||
|
||||
object Resources {
|
||||
def clock(start: Instant, zoneId: ZoneId): ResourceOwner[AdjustableClock] =
|
||||
new ResourceOwner[AdjustableClock] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[AdjustableClock] = {
|
||||
Resource(Future(AdjustableClock(Clock.fixed(start, zoneId), Duration.ZERO)))(_ =>
|
||||
Future(())
|
||||
)
|
||||
}
|
||||
}
|
||||
def temporaryDirectory(): ResourceOwner[File] =
|
||||
new ResourceOwner[File] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[File] =
|
||||
Resource(Future(Files.createTempDirectory("daml-oauth2-middleware").toFile))(dir =>
|
||||
Future(discard { dir.delete() })
|
||||
)
|
||||
}
|
||||
def authServerBinding(
|
||||
server: OAuthServer
|
||||
)(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
|
||||
new ResourceOwner[ServerBinding] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
|
||||
Resource(server.start())(_.unbind().map(_ => ()))
|
||||
}
|
||||
def authMiddlewareBinding(
|
||||
config: Config
|
||||
)(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
|
||||
new ResourceOwner[ServerBinding] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
|
||||
Resource(Server.start(config, registerGlobalOpenTelemetry = false))(_.unbind().map(_ => ()))
|
||||
}
|
||||
def authMiddlewareClientBinding(client: Client, callbackPath: Uri.Path)(implicit
|
||||
sys: ActorSystem
|
||||
): ResourceOwner[ServerBinding] =
|
||||
new ResourceOwner[ServerBinding] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
|
||||
Resource {
|
||||
Http()
|
||||
.newServerAt("localhost", 0)
|
||||
.bind {
|
||||
extractUri { reqUri =>
|
||||
val callbackUri = Uri()
|
||||
.withScheme(reqUri.scheme)
|
||||
.withAuthority(reqUri.authority)
|
||||
.withPath(callbackPath)
|
||||
val clientRoutes = client.routes(callbackUri)
|
||||
concat(
|
||||
path("authorize") {
|
||||
get {
|
||||
parameters(Symbol("claims").as[Claims]) { claims =>
|
||||
clientRoutes.authorize(claims) {
|
||||
case Client.Authorized(authorization) =>
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import com.daml.auth.middleware.api.JsonProtocol.responseAuthorizeFormat
|
||||
complete(StatusCodes.OK, authorization)
|
||||
case Client.Unauthorized =>
|
||||
complete(StatusCodes.Unauthorized)
|
||||
case Client.LoginFailed(loginError) =>
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import com.daml.auth.middleware.api.JsonProtocol.ResponseLoginFormat
|
||||
complete(
|
||||
StatusCodes.Forbidden,
|
||||
loginError: com.daml.auth.middleware.api.Response.Login,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
path("login") {
|
||||
get {
|
||||
parameters(Symbol("claims").as[Claims]) { claims =>
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import com.daml.auth.middleware.api.JsonProtocol.ResponseLoginFormat
|
||||
clientRoutes.login(claims, login => complete(StatusCodes.OK, login))
|
||||
}
|
||||
}
|
||||
},
|
||||
path("cb") { get { clientRoutes.callbackHandler } },
|
||||
)
|
||||
}
|
||||
}
|
||||
} {
|
||||
_.unbind().map(_ => ())
|
||||
}
|
||||
}
|
||||
}
|
@ -1,156 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import java.io.File
|
||||
import java.time.{Instant, ZoneId}
|
||||
|
||||
import org.apache.pekko.http.scaladsl.Http.ServerBinding
|
||||
import org.apache.pekko.http.scaladsl.model.Uri
|
||||
import com.auth0.jwt.JWTVerifier.BaseVerification
|
||||
import com.auth0.jwt.JWT
|
||||
import com.auth0.jwt.algorithms.Algorithm
|
||||
import com.daml.auth.middleware.api.Client
|
||||
import com.daml.clock.AdjustableClock
|
||||
import com.daml.jwt.JwtVerifier
|
||||
import com.daml.ledger.api.testing.utils.{
|
||||
PekkoBeforeAndAfterAll,
|
||||
OwnedResource,
|
||||
Resource,
|
||||
SuiteResource,
|
||||
}
|
||||
import com.daml.ledger.resources.ResourceContext
|
||||
import com.daml.auth.oauth2.test.server.{Config => OAuthConfig, Server => OAuthServer}
|
||||
import com.daml.ports.Port
|
||||
import org.scalatest.{BeforeAndAfterEach, Suite}
|
||||
|
||||
import scala.concurrent.duration
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
|
||||
case class TestResources(
|
||||
clock: AdjustableClock,
|
||||
authServer: OAuthServer,
|
||||
authServerBinding: ServerBinding,
|
||||
authMiddlewarePortFile: File,
|
||||
authMiddlewareBinding: ServerBinding,
|
||||
authMiddlewareClient: Client,
|
||||
authMiddlewareClientBinding: ServerBinding,
|
||||
)
|
||||
|
||||
trait TestFixture
|
||||
extends PekkoBeforeAndAfterAll
|
||||
with BeforeAndAfterEach
|
||||
with SuiteResource[TestResources] {
|
||||
self: Suite =>
|
||||
protected val ledgerId: String = "test-ledger"
|
||||
protected val jwtSecret: String = "secret"
|
||||
protected val maxMiddlewareLogins: Int = Config.DefaultMaxLoginRequests
|
||||
protected val maxClientAuthCallbacks: Int = 1000
|
||||
protected val middlewareCallbackUri: Option[Uri] = None
|
||||
protected val middlewareClientCallbackPath: Uri.Path = Uri.Path./("cb")
|
||||
protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.Yes
|
||||
lazy protected val clock: AdjustableClock = suiteResource.value.clock
|
||||
lazy protected val server: OAuthServer = suiteResource.value.authServer
|
||||
lazy protected val serverBinding: ServerBinding = suiteResource.value.authServerBinding
|
||||
lazy protected val middlewarePortFile: File = suiteResource.value.authMiddlewarePortFile
|
||||
lazy protected val middlewareBinding: ServerBinding = suiteResource.value.authMiddlewareBinding
|
||||
lazy protected val middlewareClient: Client = suiteResource.value.authMiddlewareClient
|
||||
lazy protected val middlewareClientBinding: ServerBinding = {
|
||||
suiteResource.value.authMiddlewareClientBinding
|
||||
}
|
||||
lazy protected val middlewareClientCallbackUri: Uri = {
|
||||
val host = middlewareClientBinding.localAddress
|
||||
Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority("localhost", host.getPort)
|
||||
.withPath(middlewareClientCallbackPath)
|
||||
}
|
||||
lazy protected val middlewareClientRoutes: Client.Routes =
|
||||
middlewareClient.routes(middlewareClientCallbackUri)
|
||||
protected def oauthYieldsUserTokens: Boolean = true
|
||||
override protected lazy val suiteResource: Resource[TestResources] = {
|
||||
implicit val resourceContext: ResourceContext = ResourceContext(system.dispatcher)
|
||||
new OwnedResource[ResourceContext, TestResources](
|
||||
for {
|
||||
clock <- Resources.clock(Instant.now(), ZoneId.systemDefault())
|
||||
server = OAuthServer(
|
||||
OAuthConfig(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = ledgerId,
|
||||
jwtSecret = jwtSecret,
|
||||
clock = Some(clock),
|
||||
yieldUserTokens = oauthYieldsUserTokens,
|
||||
)
|
||||
)
|
||||
serverBinding <- Resources.authServerBinding(server)
|
||||
serverUri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(
|
||||
serverBinding.localAddress.getHostString,
|
||||
serverBinding.localAddress.getPort,
|
||||
)
|
||||
tempDir <- Resources.temporaryDirectory()
|
||||
middlewarePortFile = new File(tempDir, "port")
|
||||
middlewareBinding <- Resources.authMiddlewareBinding(
|
||||
Config(
|
||||
address = "localhost",
|
||||
port = 0,
|
||||
portFile = Some(middlewarePortFile.toPath),
|
||||
callbackUri = middlewareCallbackUri,
|
||||
maxLoginRequests = maxMiddlewareLogins,
|
||||
loginTimeout = Config.DefaultLoginTimeout,
|
||||
cookieSecure = Config.DefaultCookieSecure,
|
||||
oauthAuth = serverUri.withPath(Uri.Path./("authorize")),
|
||||
oauthToken = serverUri.withPath(Uri.Path./("token")),
|
||||
oauthAuthTemplate = None,
|
||||
oauthTokenTemplate = None,
|
||||
oauthRefreshTemplate = None,
|
||||
clientId = "middleware",
|
||||
clientSecret = SecretString("middleware-secret"),
|
||||
tokenVerifier = new JwtVerifier(
|
||||
JWT
|
||||
.require(Algorithm.HMAC256(jwtSecret))
|
||||
.asInstanceOf[BaseVerification]
|
||||
.build(clock)
|
||||
),
|
||||
histograms = Seq.empty,
|
||||
)
|
||||
)
|
||||
authUri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(
|
||||
middlewareBinding.localAddress.getHostName,
|
||||
middlewareBinding.localAddress.getPort,
|
||||
)
|
||||
middlewareClientConfig = Client.Config(
|
||||
authMiddlewareInternalUri = authUri,
|
||||
authMiddlewareExternalUri = authUri,
|
||||
redirectToLogin = redirectToLogin,
|
||||
maxAuthCallbacks = maxClientAuthCallbacks,
|
||||
authCallbackTimeout = FiniteDuration(1, duration.MINUTES),
|
||||
maxHttpEntityUploadSize = 4194304,
|
||||
httpEntityUploadTimeout = FiniteDuration(1, duration.MINUTES),
|
||||
)
|
||||
middlewareClient = Client(middlewareClientConfig)
|
||||
middlewareClientBinding <- Resources
|
||||
.authMiddlewareClientBinding(middlewareClient, middlewareClientCallbackPath)
|
||||
} yield TestResources(
|
||||
clock = clock,
|
||||
authServer = server,
|
||||
authServerBinding = serverBinding,
|
||||
authMiddlewarePortFile = middlewarePortFile,
|
||||
authMiddlewareBinding = middlewareBinding,
|
||||
authMiddlewareClient = middlewareClient,
|
||||
authMiddlewareClientBinding = middlewareClientBinding,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
override protected def afterEach(): Unit = {
|
||||
server.resetAuthorizedParties()
|
||||
server.resetAdmin()
|
||||
|
||||
super.afterEach()
|
||||
}
|
||||
}
|
@ -1,779 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import java.time.Duration
|
||||
import org.apache.pekko.http.scaladsl.Http
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import org.apache.pekko.http.scaladsl.model._
|
||||
import org.apache.pekko.http.scaladsl.model.headers.{Cookie, Location, `Set-Cookie`}
|
||||
import org.apache.pekko.http.scaladsl.testkit.ScalatestRouteTest
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshal
|
||||
import com.daml.auth.middleware.api.{Client, Request, Response}
|
||||
import com.daml.auth.middleware.api.Request.Claims
|
||||
import com.daml.auth.middleware.api.Tagged.{AccessToken, RefreshToken}
|
||||
import com.daml.jwt.JwtSigner
|
||||
import com.daml.jwt.domain.DecodedJwt
|
||||
import com.daml.ledger.api.auth.{
|
||||
AuthServiceJWTCodec,
|
||||
AuthServiceJWTPayload,
|
||||
CustomDamlJWTPayload,
|
||||
StandardJWTPayload,
|
||||
StandardJWTTokenFormat,
|
||||
}
|
||||
import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
|
||||
import com.daml.lf.data.Ref
|
||||
import com.daml.auth.oauth2.api.{Response => OAuthResponse}
|
||||
import com.daml.test.evidence.tag.Security.SecurityTest.Property.{Authenticity, Authorization}
|
||||
import com.daml.test.evidence.tag.Security.{Attack, SecurityTest}
|
||||
import com.daml.test.evidence.scalatest.ScalaTestSupport.Implicits._
|
||||
import org.scalatest.{OptionValues, TryValues}
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.wordspec.{AnyWordSpec, AsyncWordSpec}
|
||||
|
||||
import scala.collection.immutable
|
||||
import scala.collection.immutable.Seq
|
||||
import scala.concurrent.duration
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
import scala.io.Source
|
||||
import scala.util.{Failure, Success}
|
||||
|
||||
abstract class TestMiddleware
|
||||
extends AsyncWordSpec
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll
|
||||
with Matchers
|
||||
with OptionValues {
|
||||
|
||||
val authenticationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
|
||||
|
||||
val authorizationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authorization, asset = "OAuth2 Middleware")
|
||||
|
||||
protected[this] def makeJwt(
|
||||
claims: Request.Claims,
|
||||
expiresIn: Option[Duration],
|
||||
): AuthServiceJWTPayload
|
||||
|
||||
protected[this] def makeToken(
|
||||
claims: Request.Claims,
|
||||
secret: String = "secret",
|
||||
expiresIn: Option[Duration] = None,
|
||||
): OAuthResponse.Token = {
|
||||
val jwtHeader = """{"alg": "HS256", "typ": "JWT"}"""
|
||||
val jwtPayload = makeJwt(claims, expiresIn)
|
||||
OAuthResponse.Token(
|
||||
accessToken = JwtSigner.HMAC256
|
||||
.sign(DecodedJwt(jwtHeader, AuthServiceJWTCodec.compactPrint(jwtPayload)), secret)
|
||||
.getOrElse(
|
||||
throw new IllegalArgumentException("Failed to sign a token")
|
||||
)
|
||||
.value,
|
||||
tokenType = "bearer",
|
||||
expiresIn = expiresIn.map(in => in.getSeconds.toInt),
|
||||
refreshToken = None,
|
||||
scope = Some(claims.toQueryString()),
|
||||
)
|
||||
}
|
||||
|
||||
"the port file" should {
|
||||
"list the HTTP port" in {
|
||||
val bindingPort = middlewareBinding.localAddress.getPort.toString
|
||||
val filePort = {
|
||||
val source = Source.fromFile(middlewarePortFile)
|
||||
try {
|
||||
source.mkString.stripLineEnd
|
||||
} finally {
|
||||
source.close()
|
||||
}
|
||||
}
|
||||
bindingPort should ===(filePort)
|
||||
}
|
||||
}
|
||||
"the /auth endpoint" should {
|
||||
"return unauthorized without cookie" taggedAs authorizationSecurity in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
for {
|
||||
result <- middlewareClient.requestAuth(claims, Nil)
|
||||
} yield {
|
||||
result should ===(None)
|
||||
}
|
||||
}
|
||||
"return the token from a cookie" taggedAs authorizationSecurity in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val token = makeToken(claims)
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
for {
|
||||
result <- middlewareClient.requestAuth(claims, List(cookieHeader))
|
||||
auth = result.value
|
||||
} yield {
|
||||
auth.accessToken should ===(token.accessToken)
|
||||
auth.refreshToken should ===(token.refreshToken)
|
||||
}
|
||||
}
|
||||
"return unauthorized on insufficient app id claims" taggedAs authorizationSecurity in {
|
||||
val claims = Request.Claims(
|
||||
actAs = List(Ref.Party.assertFromString("Alice")),
|
||||
applicationId = Some(Ref.ApplicationId.assertFromString("other-id")),
|
||||
)
|
||||
val token = makeToken(
|
||||
Request.Claims(
|
||||
actAs = List(Ref.Party.assertFromString("Alice")),
|
||||
applicationId = Some(Ref.ApplicationId.assertFromString("my-app-id")),
|
||||
)
|
||||
)
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
for {
|
||||
result <- middlewareClient.requestAuth(claims, List(cookieHeader))
|
||||
} yield {
|
||||
result should ===(None)
|
||||
}
|
||||
}
|
||||
"return unauthorized on an invalid token" taggedAs authorizationSecurity in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val token = makeToken(claims, "wrong-secret")
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
for {
|
||||
result <- middlewareClient.requestAuth(claims, List(cookieHeader))
|
||||
} yield {
|
||||
result should ===(None)
|
||||
}
|
||||
}
|
||||
"return unauthorized on an expired token" taggedAs authorizationSecurity in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val token = makeToken(claims, expiresIn = Some(Duration.ZERO))
|
||||
val _ = clock.fastForward(Duration.ofSeconds(1))
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
for {
|
||||
result <- middlewareClient.requestAuth(claims, List(cookieHeader))
|
||||
} yield {
|
||||
result should ===(None)
|
||||
}
|
||||
}
|
||||
|
||||
"accept user tokens" taggedAs authorizationSecurity in {
|
||||
import com.daml.auth.middleware.oauth2.Server.rightsProvideClaims
|
||||
rightsProvideClaims(
|
||||
StandardJWTPayload(
|
||||
None,
|
||||
"foo",
|
||||
None,
|
||||
None,
|
||||
StandardJWTTokenFormat.Scope,
|
||||
List.empty,
|
||||
Some("daml_ledger_api"),
|
||||
),
|
||||
Claims(
|
||||
admin = true,
|
||||
actAs = List(Ref.Party.assertFromString("Alice")),
|
||||
readAs = List(Ref.Party.assertFromString("Bob")),
|
||||
applicationId = Some(Ref.ApplicationId.assertFromString("foo")),
|
||||
),
|
||||
) should ===(true)
|
||||
}
|
||||
}
|
||||
"the /login endpoint" should {
|
||||
"redirect and set cookie" taggedAs authenticationSecurity.setHappyCase(
|
||||
"A valid request to /login redirects to client callback and sets cookie"
|
||||
) in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
// Redirect to /authorize on authorization server
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
// Redirect to /cb on middleware
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
} yield {
|
||||
// Redirect to client callback
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
resp.header[Location].value.uri should ===(middlewareClientCallbackUri)
|
||||
// Store token in cookie
|
||||
val cookie = resp.header[`Set-Cookie`].value.cookie
|
||||
cookie.name should ===("daml-ledger-token")
|
||||
val token = OAuthResponse.Token.fromCookieValue(cookie.value).value
|
||||
token.tokenType should ===("bearer")
|
||||
}
|
||||
}
|
||||
"return OK and set cookie without redirectUri" taggedAs authenticationSecurity.setHappyCase(
|
||||
"A valid request to /login returns OK and sets cookie when redirect is off"
|
||||
) in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None, redirect = false))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
// Redirect to /authorize on authorization server
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
// Redirect to /cb on middleware
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
} yield {
|
||||
// Return OK
|
||||
resp.status should ===(StatusCodes.OK)
|
||||
// Store token in cookie
|
||||
val cookie = resp.header[`Set-Cookie`].value.cookie
|
||||
cookie.name should ===("daml-ledger-token")
|
||||
val token = OAuthResponse.Token.fromCookieValue(cookie.value).value
|
||||
token.tokenType should ===("bearer")
|
||||
}
|
||||
}
|
||||
}
|
||||
"the /refresh endpoint" should {
|
||||
"return a new access token" taggedAs authorizationSecurity.setHappyCase(
|
||||
"A valid request to /refresh returns a new access token"
|
||||
) in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val loginReq = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
|
||||
for {
|
||||
resp <- Http().singleRequest(loginReq)
|
||||
// Redirect to /authorize on authorization server
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
// Redirect to /cb on middleware
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
// Extract token from cookie
|
||||
(token1, refreshToken) = {
|
||||
val cookie = resp.header[`Set-Cookie`].value.cookie
|
||||
val token = OAuthResponse.Token.fromCookieValue(cookie.value).value
|
||||
(AccessToken(token.accessToken), RefreshToken(token.refreshToken.value))
|
||||
}
|
||||
// Advance time
|
||||
_ = clock.fastForward(Duration.ofSeconds(1))
|
||||
// Request /refresh
|
||||
authorize <- middlewareClient.requestRefresh(refreshToken)
|
||||
} yield {
|
||||
// Test that we got a new access token
|
||||
authorize.accessToken should !==(token1)
|
||||
// Test that we got a new refresh token
|
||||
authorize.refreshToken.value should !==(refreshToken)
|
||||
}
|
||||
}
|
||||
"fail on an invalid refresh token" taggedAs authorizationSecurity.setAttack(
|
||||
Attack("HTTP Client", "Presents an invalid refresh token", "refuse request with CLIENT_ERROR")
|
||||
) in {
|
||||
for {
|
||||
exception <- middlewareClient.requestRefresh(RefreshToken("made-up-token")).transform {
|
||||
case Failure(exception: Client.RefreshException) => Success(exception)
|
||||
case value => fail(s"Expected failure with RefreshException but got $value")
|
||||
}
|
||||
} yield {
|
||||
exception.status shouldBe a[StatusCodes.ClientError]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareClaimsToken extends TestMiddleware {
|
||||
override protected[this] def oauthYieldsUserTokens = false
|
||||
override protected[this] def makeJwt(
|
||||
claims: Request.Claims,
|
||||
expiresIn: Option[Duration],
|
||||
): AuthServiceJWTPayload =
|
||||
CustomDamlJWTPayload(
|
||||
ledgerId = Some("test-ledger"),
|
||||
applicationId = Some("test-application"),
|
||||
participantId = None,
|
||||
exp = expiresIn.map(in => clock.instant.plus(in)),
|
||||
admin = claims.admin,
|
||||
actAs = claims.actAs,
|
||||
readAs = claims.readAs,
|
||||
)
|
||||
|
||||
"the /auth endpoint given claim token" should {
|
||||
"return unauthorized on insufficient party claims" taggedAs authorizationSecurity in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Bob")))
|
||||
def r(actAs: String*)(readAs: String*) =
|
||||
middlewareClient
|
||||
.requestAuth(
|
||||
claims,
|
||||
Seq(
|
||||
Cookie(
|
||||
"daml-ledger-token",
|
||||
makeToken(
|
||||
Request.Claims(
|
||||
actAs = actAs.map(Ref.Party.assertFromString).toList,
|
||||
readAs = readAs.map(Ref.Party.assertFromString).toList,
|
||||
)
|
||||
).toCookieValue,
|
||||
)
|
||||
),
|
||||
)
|
||||
for {
|
||||
aliceA <- r("Alice")()
|
||||
nothing <- r()()
|
||||
aliceA_bobA <- r("Alice", "Bob")()
|
||||
aliceA_bobR <- r("Alice")("Bob")
|
||||
aliceR_bobA <- r("Bob")("Alice")
|
||||
aliceR_bobR <- r()("Alice", "Bob")
|
||||
bobA <- r("Bob")()
|
||||
bobR <- r()("Bob")
|
||||
bobAR <- r("Bob")("Bob")
|
||||
} yield {
|
||||
aliceA shouldBe empty
|
||||
nothing shouldBe empty
|
||||
aliceA_bobA should not be empty
|
||||
aliceA_bobR shouldBe empty
|
||||
aliceR_bobA should not be empty
|
||||
aliceR_bobR shouldBe empty
|
||||
bobA should not be empty
|
||||
bobR shouldBe empty
|
||||
bobAR should not be empty
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"the /login endpoint with an oauth server checking claims" should {
|
||||
"not authorize unauthorized parties" taggedAs authorizationSecurity in {
|
||||
server.revokeParty(Ref.Party.assertFromString("Eve"))
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Eve")))
|
||||
ensureDisallowed(claims)
|
||||
}
|
||||
|
||||
"not authorize disallowed admin claims" taggedAs authorizationSecurity in {
|
||||
server.revokeAdmin()
|
||||
val claims = Request.Claims(admin = true)
|
||||
ensureDisallowed(claims)
|
||||
}
|
||||
|
||||
def ensureDisallowed(claims: Request.Claims) = {
|
||||
val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
// Redirect to /authorize on authorization server
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
// Redirect to /cb on middleware
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
} yield {
|
||||
// Redirect to client callback
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
resp.header[Location].value.uri.withQuery(Uri.Query()) should ===(
|
||||
middlewareClientCallbackUri
|
||||
)
|
||||
// with error parameter set
|
||||
resp.header[Location].value.uri.query().toMap.get("error") should ===(Some("access_denied"))
|
||||
// Without token in cookie
|
||||
val cookie = resp.header[`Set-Cookie`]
|
||||
cookie should ===(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareUserToken extends TestMiddleware {
|
||||
override protected[this] def makeJwt(
|
||||
claims: Request.Claims,
|
||||
expiresIn: Option[Duration],
|
||||
): AuthServiceJWTPayload =
|
||||
StandardJWTPayload(
|
||||
issuer = None,
|
||||
userId = "test-application",
|
||||
participantId = None,
|
||||
exp = expiresIn.map(in => clock.instant.plus(in)),
|
||||
format = StandardJWTTokenFormat.Scope,
|
||||
audiences = List.empty,
|
||||
scope = Some("daml_ledger_api"),
|
||||
)
|
||||
}
|
||||
|
||||
class TestMiddlewareCallbackUriOverride
|
||||
extends AsyncWordSpec
|
||||
with Matchers
|
||||
with OptionValues
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll {
|
||||
val authenticationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
|
||||
|
||||
override protected val middlewareCallbackUri = Some(Uri("http://localhost/MIDDLEWARE_CALLBACK"))
|
||||
"the /login endpoint with an oauth server checking claims" should {
|
||||
"redirect to the configured middleware callback URI" taggedAs authenticationSecurity
|
||||
.setHappyCase(
|
||||
"A valid request to /login redirects to middleware callback"
|
||||
) in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val req = HttpRequest(uri = middlewareClientRoutes.loginUri(claims, None))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
// Redirect to /authorize on authorization server
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
} yield {
|
||||
// Redirect to configured callback URI on middleware
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
resp.header[Location].value.uri.withQuery(Uri.Query()) should ===(
|
||||
middlewareCallbackUri.value
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareLimitedCallbackStore
|
||||
extends AsyncWordSpec
|
||||
with Matchers
|
||||
with OptionValues
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll {
|
||||
val authenticationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
|
||||
|
||||
override protected val maxMiddlewareLogins = 2
|
||||
"the /login endpoint with an oauth server checking claims" should {
|
||||
"refuse requests when max capacity is reached" taggedAs authenticationSecurity.setAttack(
|
||||
Attack(
|
||||
"HTTP client",
|
||||
"Issues too many requests to the /login endpoint",
|
||||
"Refuse request with SERVICE_UNAVAILABLE",
|
||||
)
|
||||
) in {
|
||||
def login(actAs: Ref.Party) = {
|
||||
val claims = Request.Claims(actAs = List(actAs))
|
||||
val uri = middlewareClientRoutes.loginUri(claims, None)
|
||||
val req = HttpRequest(uri = uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
|
||||
def followRedirect(resp: HttpResponse) = {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val uri = resp.header[Location].value.uri
|
||||
val req = HttpRequest(uri = uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
|
||||
for {
|
||||
// Follow login flows up to redirect to middleware callback.
|
||||
redirectAlice <- login(Ref.Party.assertFromString("Alice"))
|
||||
.flatMap(followRedirect)
|
||||
redirectBob <- login(Ref.Party.assertFromString("Bob"))
|
||||
.flatMap(followRedirect)
|
||||
// The store should be full
|
||||
refusedCarol <- login(Ref.Party.assertFromString("Carol"))
|
||||
_ = refusedCarol.status should ===(StatusCodes.ServiceUnavailable)
|
||||
// Follow first redirect to middleware callback.
|
||||
resultAlice <- followRedirect(redirectAlice)
|
||||
_ = resultAlice.status should ===(StatusCodes.Found)
|
||||
// The store should have space again
|
||||
redirectCarol <- login(Ref.Party.assertFromString("Carol"))
|
||||
.flatMap(followRedirect)
|
||||
// Follow redirects to middleware callback.
|
||||
resultBob <- followRedirect(redirectBob)
|
||||
resultCarol <- followRedirect(redirectCarol)
|
||||
} yield {
|
||||
resultBob.status should ===(StatusCodes.Found)
|
||||
resultCarol.status should ===(StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareClientLimitedCallbackStore
|
||||
extends AsyncWordSpec
|
||||
with Matchers
|
||||
with OptionValues
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll {
|
||||
val authenticationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
|
||||
|
||||
override protected val maxClientAuthCallbacks = 2
|
||||
"the /login client with an oauth server checking claims" should {
|
||||
"refuse requests when max capacity is reached" taggedAs authenticationSecurity.setAttack(
|
||||
Attack(
|
||||
"HTTP client",
|
||||
"Issues too many requests to the /login endpoint",
|
||||
"Refuse request with SERVICE_UNAVAILABLE",
|
||||
)
|
||||
) in {
|
||||
def login(actAs: Ref.Party) = {
|
||||
val claims = Request.Claims(actAs = List(actAs))
|
||||
val host = middlewareClientBinding.localAddress
|
||||
val uri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(host.getHostName, host.getPort)
|
||||
.withPath(Uri.Path./("login"))
|
||||
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
|
||||
val req = HttpRequest(uri = uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
|
||||
def followRedirect(resp: HttpResponse) = {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val uri = resp.header[Location].value.uri
|
||||
val req = HttpRequest(uri = uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
|
||||
for {
|
||||
// Follow login flows up to last redirect to middleware client.
|
||||
redirectAlice <- login(Ref.Party.assertFromString("Alice"))
|
||||
.flatMap(followRedirect)
|
||||
.flatMap(followRedirect)
|
||||
.flatMap(followRedirect)
|
||||
redirectBob <- login(Ref.Party.assertFromString("Bob"))
|
||||
.flatMap(followRedirect)
|
||||
.flatMap(followRedirect)
|
||||
.flatMap(followRedirect)
|
||||
// The store should be full
|
||||
refusedCarol <- login(Ref.Party.assertFromString("Carol"))
|
||||
_ = refusedCarol.status should ===(StatusCodes.ServiceUnavailable)
|
||||
// Follow first redirect to middleware client.
|
||||
resultAlice <- followRedirect(redirectAlice)
|
||||
_ = resultAlice.status should ===(StatusCodes.OK)
|
||||
// The store should have space again
|
||||
redirectCarol <- login(Ref.Party.assertFromString("Carol"))
|
||||
.flatMap(followRedirect)
|
||||
.flatMap(followRedirect)
|
||||
.flatMap(followRedirect)
|
||||
resultBob <- followRedirect(redirectBob)
|
||||
resultCarol <- followRedirect(redirectCarol)
|
||||
} yield {
|
||||
resultBob.status should ===(StatusCodes.OK)
|
||||
resultCarol.status should ===(StatusCodes.OK)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareClientNoRedirectToLogin
|
||||
extends AsyncWordSpec
|
||||
with Matchers
|
||||
with OptionValues
|
||||
with TryValues
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll {
|
||||
val authenticationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
|
||||
|
||||
override protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.No
|
||||
"the TestMiddlewareClientNoRedirectToLogin client" should {
|
||||
"not redirect to /login" taggedAs authenticationSecurity.setHappyCase(
|
||||
"An unauthorized request should not redirect to /login using the JSON protocol"
|
||||
) in {
|
||||
import com.daml.auth.middleware.api.JsonProtocol.responseAuthenticateChallengeFormat
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val host = middlewareClientBinding.localAddress
|
||||
val uri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(host.getHostName, host.getPort)
|
||||
.withPath(Uri.Path./("authorize"))
|
||||
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
|
||||
val req = HttpRequest(uri = uri)
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
// Unauthorized with WWW-Authenticate header
|
||||
_ = resp.status should ===(StatusCodes.Unauthorized)
|
||||
wwwAuthenticate = resp.header[headers.`WWW-Authenticate`].value
|
||||
challenge = wwwAuthenticate.challenges
|
||||
.find(_.scheme == Response.authenticateChallengeName)
|
||||
.value
|
||||
_ = challenge.params.keys should contain.allOf("auth", "login")
|
||||
authUri = challenge.params.get("auth").value
|
||||
loginUri = challenge.params.get("login").value
|
||||
headerChallenge = Response.AuthenticateChallenge(
|
||||
Request.Claims(challenge.realm),
|
||||
loginUri,
|
||||
authUri,
|
||||
)
|
||||
// The body should include the same challenge
|
||||
bodyChallenge <- Unmarshal(resp).to[Response.AuthenticateChallenge]
|
||||
} yield {
|
||||
headerChallenge.auth should ===(middlewareClient.authUri(claims))
|
||||
headerChallenge.login.withQuery(Uri.Query.Empty) should ===(
|
||||
middlewareClientRoutes.loginUri(claims).withQuery(Uri.Query.Empty)
|
||||
)
|
||||
headerChallenge.login.query().get("claims").value should ===(claims.toQueryString())
|
||||
bodyChallenge should ===(headerChallenge)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareClientYesRedirectToLogin
|
||||
extends AsyncWordSpec
|
||||
with Matchers
|
||||
with OptionValues
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll {
|
||||
val authenticationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
|
||||
|
||||
override protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.Yes
|
||||
"the TestMiddlewareClientYesRedirectToLogin client" should {
|
||||
"redirect to /login" taggedAs authenticationSecurity.setHappyCase(
|
||||
"A valid HTTP request should redirect to /login"
|
||||
) in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val host = middlewareClientBinding.localAddress
|
||||
val uri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(host.getHostName, host.getPort)
|
||||
.withPath(Uri.Path./("authorize"))
|
||||
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
|
||||
val req = HttpRequest(uri = uri)
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
} yield {
|
||||
// Redirect to /login on middleware
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val loginUri = resp.header[headers.Location].value.uri
|
||||
loginUri.withQuery(Uri.Query.Empty) should ===(
|
||||
middlewareClientRoutes
|
||||
.loginUri(claims)
|
||||
.withQuery(Uri.Query.Empty)
|
||||
)
|
||||
loginUri.query().get("claims").value should ===(claims.toQueryString())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareClientAutoRedirectToLogin
|
||||
extends AsyncWordSpec
|
||||
with Matchers
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll {
|
||||
val authenticationSecurity: SecurityTest =
|
||||
SecurityTest(property = Authenticity, asset = "OAuth2 Middleware")
|
||||
|
||||
override protected val redirectToLogin: Client.RedirectToLogin = Client.RedirectToLogin.Auto
|
||||
"the TestMiddlewareClientAutoRedirectToLogin client" should {
|
||||
"redirect to /login for HTML request" taggedAs authenticationSecurity.setHappyCase(
|
||||
"A HTML request is redirected to /login"
|
||||
) in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val host = middlewareClientBinding.localAddress
|
||||
val uri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(host.getHostName, host.getPort)
|
||||
.withPath(Uri.Path./("authorize"))
|
||||
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
|
||||
val acceptHtml: HttpHeader = headers.Accept(MediaTypes.`text/html`)
|
||||
val req = HttpRequest(uri = uri, headers = immutable.Seq(acceptHtml))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
} yield {
|
||||
// Redirect to /login on middleware
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
"not redirect to /login for JSON request" taggedAs authenticationSecurity.setHappyCase(
|
||||
"A JSON request is not redirected to /login"
|
||||
) in {
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
val host = middlewareClientBinding.localAddress
|
||||
val uri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(host.getHostName, host.getPort)
|
||||
.withPath(Uri.Path./("authorize"))
|
||||
.withQuery(Uri.Query("claims" -> claims.toQueryString()))
|
||||
val acceptHtml: HttpHeader = headers.Accept(MediaTypes.`application/json`)
|
||||
val req = HttpRequest(uri = uri, headers = immutable.Seq(acceptHtml))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
} yield {
|
||||
// Unauthorized with WWW-Authenticate header
|
||||
resp.status should ===(StatusCodes.Unauthorized)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestMiddlewareClientLoginCallbackUri
|
||||
extends AnyWordSpec
|
||||
with Matchers
|
||||
with ScalatestRouteTest {
|
||||
private val client = Client(
|
||||
Client.Config(
|
||||
authMiddlewareInternalUri = Uri("http://auth.internal"),
|
||||
authMiddlewareExternalUri = Uri("http://auth.external"),
|
||||
redirectToLogin = Client.RedirectToLogin.Yes,
|
||||
maxAuthCallbacks = 1000,
|
||||
authCallbackTimeout = FiniteDuration(1, duration.MINUTES),
|
||||
maxHttpEntityUploadSize = 4194304,
|
||||
httpEntityUploadTimeout = FiniteDuration(1, duration.MINUTES),
|
||||
)
|
||||
)
|
||||
"fixed callback URI" should {
|
||||
"be absolute" in {
|
||||
an[AssertionError] should be thrownBy client.routes(Uri().withPath(Uri.Path./("cb")))
|
||||
}
|
||||
"be used in login URI" in {
|
||||
val routes = client.routes(Uri("http://client.domain/cb"))
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
routes.loginUri(claims = claims) shouldBe
|
||||
Uri(
|
||||
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
|
||||
)
|
||||
}
|
||||
}
|
||||
"callback URI from request" should {
|
||||
"be used in login URI" in {
|
||||
val routes = client.routesFromRequestAuthority(Uri.Path./("cb"))
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
import org.apache.pekko.http.scaladsl.server.directives.RouteDirectives._
|
||||
Get("http://client.domain") ~> routes { routes =>
|
||||
complete(routes.loginUri(claims).toString)
|
||||
} ~> check {
|
||||
responseAs[String] shouldEqual
|
||||
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
|
||||
}
|
||||
}
|
||||
}
|
||||
"automatic callback URI" should {
|
||||
"be fixed when absolute" in {
|
||||
val routes = client.routesAuto(Uri("http://client.domain/cb"))
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
import org.apache.pekko.http.scaladsl.server.directives.RouteDirectives._
|
||||
Get() ~> routes { routes => complete(routes.loginUri(claims).toString) } ~> check {
|
||||
responseAs[String] shouldEqual
|
||||
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
|
||||
}
|
||||
}
|
||||
"be from request when relative" in {
|
||||
val routes = client.routesAuto(Uri().withPath(Uri.Path./("cb")))
|
||||
val claims = Request.Claims(actAs = List(Ref.Party.assertFromString("Alice")))
|
||||
import org.apache.pekko.http.scaladsl.server.directives.RouteDirectives._
|
||||
Get("http://client.domain") ~> routes { routes =>
|
||||
complete(routes.loginUri(claims).toString)
|
||||
} ~> check {
|
||||
responseAs[String] shouldEqual
|
||||
s"http://auth.external/login?claims=${claims.toQueryString()}&redirect_uri=http://client.domain/cb"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import com.daml.auth.middleware.api.RequestStore
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.wordspec.AsyncWordSpec
|
||||
|
||||
import scala.concurrent.duration._
|
||||
|
||||
class TestRequestStore extends AsyncWordSpec with Matchers {
|
||||
"return None on missing element" in {
|
||||
val store = new RequestStore[Int, String](1, 1.day)
|
||||
store.pop(0) should ===(None)
|
||||
}
|
||||
|
||||
"return previously put element" in {
|
||||
val store = new RequestStore[Int, String](1, 1.day)
|
||||
store.put(0, "zero")
|
||||
store.pop(0) should ===(Some("zero"))
|
||||
}
|
||||
|
||||
"return None on previously popped element" in {
|
||||
val store = new RequestStore[Int, String](1, 1.day)
|
||||
store.put(0, "zero")
|
||||
store.pop(0)
|
||||
store.pop(0) should ===(None)
|
||||
}
|
||||
|
||||
"store multiple elements" in {
|
||||
val store = new RequestStore[Int, String](3, 1.day)
|
||||
store.put(0, "zero")
|
||||
store.put(1, "one")
|
||||
store.put(2, "two")
|
||||
store.pop(0) should ===(Some("zero"))
|
||||
store.pop(1) should ===(Some("one"))
|
||||
store.pop(2) should ===(Some("two"))
|
||||
}
|
||||
|
||||
"store no more than max capacity" in {
|
||||
val store = new RequestStore[Int, String](2, 1.day)
|
||||
assert(store.put(0, "zero"))
|
||||
assert(store.put(1, "one"))
|
||||
assert(!store.put(2, "two"))
|
||||
store.pop(0) should ===(Some("zero"))
|
||||
store.pop(1) should ===(Some("one"))
|
||||
store.pop(2) should ===(None)
|
||||
}
|
||||
|
||||
"return None on timed out element" in {
|
||||
var time: Long = 0
|
||||
val store = new RequestStore[Int, String](1, 1.day, () => time)
|
||||
store.put(0, "zero")
|
||||
time += 1.day.toNanos
|
||||
store.pop(0) should ===(None)
|
||||
}
|
||||
|
||||
"free capacity for timed out elements" in {
|
||||
var time: Long = 0
|
||||
val store = new RequestStore[Int, String](1, 1.day, () => time)
|
||||
assert(store.put(0, "zero"))
|
||||
assert(!store.put(1, "one"))
|
||||
time += 1.day.toNanos
|
||||
assert(store.put(2, "two"))
|
||||
store.pop(2) should ===(Some("two"))
|
||||
}
|
||||
}
|
@ -1,236 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.middleware.oauth2
|
||||
|
||||
import java.io._
|
||||
import java.nio.file.Path
|
||||
import java.util.UUID
|
||||
import org.apache.pekko.http.scaladsl.model.Uri
|
||||
import com.daml.auth.middleware.api.Request.Claims
|
||||
import com.daml.auth.middleware.api.Tagged.RefreshToken
|
||||
import com.daml.lf.data.Ref
|
||||
import com.daml.scalautil.Statement.discard
|
||||
import org.scalatest.{PartialFunctionValues, TryValues}
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.wordspec.AnyWordSpec
|
||||
|
||||
class TestRequestTemplates
|
||||
extends AnyWordSpec
|
||||
with Matchers
|
||||
with TryValues
|
||||
with PartialFunctionValues {
|
||||
private val clientId = "client-id"
|
||||
private val clientSecret = SecretString("client-secret")
|
||||
|
||||
private def getTemplates(
|
||||
authTemplate: Option[Path] = None,
|
||||
tokenTemplate: Option[Path] = None,
|
||||
refreshTemplate: Option[Path] = None,
|
||||
): RequestTemplates =
|
||||
RequestTemplates(
|
||||
clientId = clientId,
|
||||
clientSecret = clientSecret,
|
||||
authTemplate = authTemplate,
|
||||
tokenTemplate = tokenTemplate,
|
||||
refreshTemplate = refreshTemplate,
|
||||
)
|
||||
|
||||
private def withJsonnetFile(content: String)(testCode: Path => Any): Unit = {
|
||||
val file = File.createTempFile("test-request-template", ".jsonnet")
|
||||
val writer = new FileWriter(file)
|
||||
try {
|
||||
writer.write(content)
|
||||
writer.close()
|
||||
discard(testCode(file.toPath))
|
||||
} finally discard(file.delete())
|
||||
}
|
||||
|
||||
"the builtin auth template" should {
|
||||
"handle empty claims" in {
|
||||
val templates = getTemplates()
|
||||
val claims = Claims(admin = false, actAs = Nil, readAs = Nil, applicationId = None)
|
||||
val requestId = UUID.randomUUID()
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
|
||||
params.keys should contain.only(
|
||||
"audience",
|
||||
"client_id",
|
||||
"redirect_uri",
|
||||
"response_type",
|
||||
"scope",
|
||||
"state",
|
||||
)
|
||||
params should contain.allOf(
|
||||
"audience" -> "https://daml.com/ledger-api",
|
||||
"client_id" -> clientId,
|
||||
"redirect_uri" -> redirectUri.toString,
|
||||
"response_type" -> "code",
|
||||
"scope" -> "offline_access",
|
||||
"state" -> requestId.toString,
|
||||
)
|
||||
}
|
||||
"handle an admin claim" in {
|
||||
val templates = getTemplates()
|
||||
val claims = Claims(admin = true, actAs = Nil, readAs = Nil, applicationId = None)
|
||||
val requestId = UUID.randomUUID()
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
|
||||
params.keys should contain.only(
|
||||
"audience",
|
||||
"client_id",
|
||||
"redirect_uri",
|
||||
"response_type",
|
||||
"scope",
|
||||
"state",
|
||||
)
|
||||
params should contain.allOf(
|
||||
"audience" -> "https://daml.com/ledger-api",
|
||||
"client_id" -> clientId,
|
||||
"redirect_uri" -> redirectUri.toString,
|
||||
"response_type" -> "code",
|
||||
"state" -> requestId.toString,
|
||||
)
|
||||
val scope = params.valueAt("scope").split(" ")
|
||||
scope should contain.allOf("admin", "offline_access")
|
||||
}
|
||||
"handle actAs claims" in {
|
||||
val templates = getTemplates()
|
||||
val claims = Claims(
|
||||
admin = false,
|
||||
actAs = List("Alice", "Bob").map(Ref.Party.assertFromString),
|
||||
readAs = Nil,
|
||||
applicationId = None,
|
||||
)
|
||||
val requestId = UUID.randomUUID()
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
|
||||
params.keys should contain.only(
|
||||
"audience",
|
||||
"client_id",
|
||||
"redirect_uri",
|
||||
"response_type",
|
||||
"scope",
|
||||
"state",
|
||||
)
|
||||
params should contain.allOf(
|
||||
"audience" -> "https://daml.com/ledger-api",
|
||||
"client_id" -> clientId,
|
||||
"redirect_uri" -> redirectUri.toString,
|
||||
"response_type" -> "code",
|
||||
"state" -> requestId.toString,
|
||||
)
|
||||
val scope = params.valueAt("scope").split(" ")
|
||||
scope should contain.allOf("actAs:Alice", "actAs:Bob", "offline_access")
|
||||
}
|
||||
"handle readAs claims" in {
|
||||
val templates = getTemplates()
|
||||
val claims = Claims(
|
||||
admin = false,
|
||||
actAs = Nil,
|
||||
readAs = List("Alice", "Bob").map(Ref.Party.assertFromString),
|
||||
applicationId = None,
|
||||
)
|
||||
val requestId = UUID.randomUUID()
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
|
||||
params.keys should contain.only(
|
||||
"audience",
|
||||
"client_id",
|
||||
"redirect_uri",
|
||||
"response_type",
|
||||
"scope",
|
||||
"state",
|
||||
)
|
||||
params should contain.allOf(
|
||||
"audience" -> "https://daml.com/ledger-api",
|
||||
"client_id" -> clientId,
|
||||
"redirect_uri" -> redirectUri.toString,
|
||||
"response_type" -> "code",
|
||||
"state" -> requestId.toString,
|
||||
)
|
||||
val scope = params.valueAt("scope").split(" ")
|
||||
scope should contain.allOf("offline_access", "readAs:Alice", "readAs:Bob")
|
||||
}
|
||||
"handle an applicationId claim" in {
|
||||
val templates = getTemplates()
|
||||
val claims = Claims(
|
||||
admin = false,
|
||||
actAs = Nil,
|
||||
readAs = Nil,
|
||||
applicationId = Some(Ref.ApplicationId.assertFromString("application-id")),
|
||||
)
|
||||
val requestId = UUID.randomUUID()
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
|
||||
params.keys should contain.only(
|
||||
"audience",
|
||||
"client_id",
|
||||
"redirect_uri",
|
||||
"response_type",
|
||||
"scope",
|
||||
"state",
|
||||
)
|
||||
params should contain.allOf(
|
||||
"audience" -> "https://daml.com/ledger-api",
|
||||
"client_id" -> clientId,
|
||||
"redirect_uri" -> redirectUri.toString,
|
||||
"response_type" -> "code",
|
||||
"state" -> requestId.toString,
|
||||
)
|
||||
val scope = params.valueAt("scope").split(" ")
|
||||
scope should contain.allOf("applicationId:application-id", "offline_access")
|
||||
}
|
||||
}
|
||||
"the builtin token template" should {
|
||||
"be complete" in {
|
||||
val templates = getTemplates()
|
||||
val code = "request-code"
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createTokenRequest(code, redirectUri).success.value
|
||||
params shouldBe Map(
|
||||
"client_id" -> clientId,
|
||||
"client_secret" -> clientSecret.value,
|
||||
"code" -> code,
|
||||
"grant_type" -> "authorization_code",
|
||||
"redirect_uri" -> redirectUri.toString,
|
||||
)
|
||||
}
|
||||
}
|
||||
"the builtin refresh template" should {
|
||||
"be complete" in {
|
||||
val templates = getTemplates()
|
||||
val refreshToken = RefreshToken("refresh-token")
|
||||
val params = templates.createRefreshRequest(refreshToken).success.value
|
||||
params shouldBe Map(
|
||||
"client_id" -> clientId,
|
||||
"client_secret" -> clientSecret.value,
|
||||
"grant_type" -> "refresh_code",
|
||||
"refresh_token" -> refreshToken,
|
||||
)
|
||||
}
|
||||
}
|
||||
"user defined templates" should {
|
||||
"override the auth template" in withJsonnetFile("""{"key": "value"}""") { templatePath =>
|
||||
val templates = getTemplates(authTemplate = Some(templatePath))
|
||||
val claims = Claims(admin = false, actAs = Nil, readAs = Nil, applicationId = None)
|
||||
val requestId = UUID.randomUUID()
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createAuthRequest(claims, requestId, redirectUri).success.value
|
||||
params shouldBe Map("key" -> "value")
|
||||
}
|
||||
"override the token template" in withJsonnetFile("""{"key": "value"}""") { templatePath =>
|
||||
val templates = getTemplates(tokenTemplate = Some(templatePath))
|
||||
val code = "request-code"
|
||||
val redirectUri = Uri("https://localhost/cb")
|
||||
val params = templates.createTokenRequest(code, redirectUri).success.value
|
||||
params shouldBe Map("key" -> "value")
|
||||
}
|
||||
"override the refresh template" in withJsonnetFile("""{"key": "value"}""") { templatePath =>
|
||||
val templates = getTemplates(refreshTemplate = Some(templatePath))
|
||||
val refreshToken = RefreshToken("refresh-token")
|
||||
val params = templates.createRefreshRequest(refreshToken).success.value
|
||||
params shouldBe Map("key" -> "value")
|
||||
}
|
||||
}
|
||||
}
|
@ -1,198 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.test.server
|
||||
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import org.apache.pekko.http.scaladsl.Http
|
||||
import org.apache.pekko.http.scaladsl.Http.ServerBinding
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import org.apache.pekko.http.scaladsl.marshalling.Marshal
|
||||
import org.apache.pekko.http.scaladsl.model.Uri.Path
|
||||
import org.apache.pekko.http.scaladsl.model._
|
||||
import org.apache.pekko.http.scaladsl.server.Directives._
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
|
||||
import com.daml.auth.oauth2.api.{Request, Response}
|
||||
import com.daml.ports.Port
|
||||
import spray.json._
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.language.postfixOps
|
||||
|
||||
// This is a test client (using terminology from oauth).
|
||||
// The trigger service would also take the role of a client.
|
||||
object Client {
|
||||
import com.daml.auth.oauth2.api.JsonProtocol._
|
||||
|
||||
case class Config(
|
||||
port: Port,
|
||||
authServerUrl: Uri,
|
||||
clientId: String,
|
||||
clientSecret: String,
|
||||
)
|
||||
|
||||
object JsonProtocol extends DefaultJsonProtocol {
|
||||
implicit val accessParamsFormat: RootJsonFormat[AccessParams] = jsonFormat3(AccessParams)
|
||||
implicit val refreshParamsFormat: RootJsonFormat[RefreshParams] = jsonFormat1(RefreshParams)
|
||||
implicit object ResponseJsonFormat extends RootJsonFormat[Response] {
|
||||
implicit private val accessFormat: RootJsonFormat[AccessResponse] = jsonFormat2(
|
||||
AccessResponse
|
||||
)
|
||||
implicit private val errorFormat: RootJsonFormat[ErrorResponse] = jsonFormat1(ErrorResponse)
|
||||
def write(resp: Response) = resp match {
|
||||
case resp @ AccessResponse(_, _) => resp.toJson
|
||||
case resp @ ErrorResponse(_) => resp.toJson
|
||||
}
|
||||
def read(value: JsValue) =
|
||||
(
|
||||
value.convertTo(safeReader[AccessResponse]),
|
||||
value.convertTo(safeReader[ErrorResponse]),
|
||||
) match {
|
||||
case (Right(a), _) => a
|
||||
case (_, Right(b)) => b
|
||||
case (Left(ea), Left(eb)) =>
|
||||
deserializationError(s"Could not read Response value:\n$ea\n$eb")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class AccessParams(parties: Seq[String], admin: Boolean, applicationId: Option[String])
|
||||
case class RefreshParams(refreshToken: String)
|
||||
sealed trait Response
|
||||
final case class AccessResponse(token: String, refresh: String) extends Response
|
||||
final case class ErrorResponse(error: String) extends Response
|
||||
|
||||
def toRedirectUri(uri: Uri): Uri = uri.withPath(Path./("cb"))
|
||||
|
||||
def start(
|
||||
config: Config
|
||||
)(implicit asys: ActorSystem, ec: ExecutionContext): Future[ServerBinding] = {
|
||||
import JsonProtocol._
|
||||
implicit val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
|
||||
val route = concat(
|
||||
// Some parameter that requires authorization for some parties. This will in the end return the token
|
||||
// produced by the authorization server.
|
||||
path("access") {
|
||||
post {
|
||||
entity(as[AccessParams]) { params =>
|
||||
extractRequest { request =>
|
||||
val redirectUri = toRedirectUri(request.uri)
|
||||
val scope =
|
||||
(params.parties.map(p => "actAs:" + p) ++
|
||||
(if (params.admin) List("admin") else Nil) ++
|
||||
params.applicationId.toList.map(id => "applicationId:" + id)).mkString(" ")
|
||||
val authParams = Request.Authorize(
|
||||
responseType = "code",
|
||||
clientId = config.clientId,
|
||||
redirectUri = redirectUri,
|
||||
scope = Some(scope),
|
||||
state = None,
|
||||
audience = Some("https://daml.com/ledger-api"),
|
||||
)
|
||||
redirect(
|
||||
config.authServerUrl
|
||||
.withQuery(authParams.toQuery)
|
||||
.withPath(Path./("authorize")),
|
||||
StatusCodes.Found,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
path("cb") {
|
||||
get {
|
||||
parameters(Symbol("code"), Symbol("state") ?).as[Response.Authorize](Response.Authorize) {
|
||||
resp =>
|
||||
extractRequest { request =>
|
||||
// We got the code, now request a token
|
||||
val body = Request.Token(
|
||||
grantType = "authorization_code",
|
||||
code = resp.code,
|
||||
redirectUri = toRedirectUri(request.uri),
|
||||
clientId = config.clientId,
|
||||
clientSecret = config.clientSecret,
|
||||
)
|
||||
val f = for {
|
||||
entity <- Marshal(body).to[RequestEntity]
|
||||
req = HttpRequest(
|
||||
uri = config.authServerUrl.withPath(Path./("token")),
|
||||
entity = entity,
|
||||
method = HttpMethods.POST,
|
||||
)
|
||||
resp <- Http().singleRequest(req)
|
||||
tokenResp <- Unmarshal(resp).to[Response.Token]
|
||||
} yield tokenResp
|
||||
onSuccess(f) { tokenResp =>
|
||||
// Now we have the access_token and potentially the refresh token. At this point,
|
||||
// we would start the trigger.
|
||||
complete(
|
||||
AccessResponse(
|
||||
tokenResp.accessToken,
|
||||
tokenResp.refreshToken.getOrElse(
|
||||
sys.error("/token endpoint failed to return a refresh token")
|
||||
),
|
||||
): Response
|
||||
)
|
||||
}
|
||||
}
|
||||
} ~
|
||||
parameters(
|
||||
Symbol("error"),
|
||||
Symbol("error_description") ?,
|
||||
Symbol("error_uri").as[Uri] ?,
|
||||
Symbol("state") ?,
|
||||
)
|
||||
.as[Response.Error](Response.Error) { resp =>
|
||||
complete(ErrorResponse(resp.error): Response)
|
||||
}
|
||||
}
|
||||
},
|
||||
path("refresh") {
|
||||
post {
|
||||
entity(as[RefreshParams]) { params =>
|
||||
val body = Request.Refresh(
|
||||
grantType = "refresh_token",
|
||||
refreshToken = params.refreshToken,
|
||||
clientId = config.clientId,
|
||||
clientSecret = config.clientSecret,
|
||||
)
|
||||
val f =
|
||||
for {
|
||||
entity <- Marshal(body).to[RequestEntity]
|
||||
req = HttpRequest(
|
||||
uri = config.authServerUrl.withPath(Path./("token")),
|
||||
entity = entity,
|
||||
method = HttpMethods.POST,
|
||||
)
|
||||
resp <- Http().singleRequest(req)
|
||||
tokenResp <-
|
||||
if (resp.status != StatusCodes.OK) {
|
||||
Unmarshal(resp).to[String].flatMap { msg =>
|
||||
throw new RuntimeException(
|
||||
s"Failed to fetch refresh token (${resp.status}): $msg."
|
||||
)
|
||||
}
|
||||
} else {
|
||||
Unmarshal(resp).to[Response.Token]
|
||||
}
|
||||
} yield tokenResp
|
||||
onSuccess(f) { tokenResp =>
|
||||
// Now we have the access_token and potentially the refresh token. At this point,
|
||||
// we would start the trigger.
|
||||
complete(
|
||||
AccessResponse(
|
||||
tokenResp.accessToken,
|
||||
tokenResp.refreshToken.getOrElse(
|
||||
sys.error("/token endpoint failed to return a refresh token")
|
||||
),
|
||||
): Response
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
Http().newServerAt("localhost", config.port.value).bind(route)
|
||||
}
|
||||
|
||||
}
|
@ -1,36 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.test.server
|
||||
|
||||
import java.time.{Clock, Duration, Instant, ZoneId}
|
||||
|
||||
import org.apache.pekko.actor.ActorSystem
|
||||
import org.apache.pekko.http.scaladsl.Http.ServerBinding
|
||||
import com.daml.clock.AdjustableClock
|
||||
import com.daml.ledger.resources.{Resource, ResourceContext, ResourceOwner}
|
||||
|
||||
import scala.concurrent.Future
|
||||
|
||||
object Resources {
|
||||
def clock(start: Instant, zoneId: ZoneId): ResourceOwner[AdjustableClock] =
|
||||
new ResourceOwner[AdjustableClock] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[AdjustableClock] = {
|
||||
Resource(Future(AdjustableClock(Clock.fixed(start, zoneId), Duration.ZERO)))(_ =>
|
||||
Future(())
|
||||
)
|
||||
}
|
||||
}
|
||||
def authServerBinding(server: Server)(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
|
||||
new ResourceOwner[ServerBinding] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
|
||||
Resource(server.start())(_.unbind().map(_ => ()))
|
||||
}
|
||||
def authClientBinding(
|
||||
config: Client.Config
|
||||
)(implicit sys: ActorSystem): ResourceOwner[ServerBinding] =
|
||||
new ResourceOwner[ServerBinding] {
|
||||
override def acquire()(implicit context: ResourceContext): Resource[ServerBinding] =
|
||||
Resource(Client.start(config))(_.unbind().map(_ => ()))
|
||||
}
|
||||
}
|
@ -1,285 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.test.server
|
||||
|
||||
import org.apache.pekko.http.scaladsl.Http
|
||||
import org.apache.pekko.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import org.apache.pekko.http.scaladsl.model.Uri.Path
|
||||
import org.apache.pekko.http.scaladsl.model._
|
||||
import org.apache.pekko.http.scaladsl.model.headers.Location
|
||||
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshal
|
||||
import com.daml.jwt.JwtDecoder
|
||||
import com.daml.jwt.domain.Jwt
|
||||
import com.daml.ledger.api.auth.{
|
||||
AuthServiceJWTCodec,
|
||||
CustomDamlJWTPayload,
|
||||
AuthServiceJWTPayload,
|
||||
StandardJWTPayload,
|
||||
}
|
||||
import com.daml.lf.data.Ref
|
||||
import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
|
||||
import org.scalatest.OptionValues
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.wordspec.AsyncWordSpec
|
||||
import spray.json._
|
||||
|
||||
import java.time.Instant
|
||||
import scala.concurrent.Future
|
||||
import scala.util.Try
|
||||
|
||||
abstract class Test
|
||||
extends AsyncWordSpec
|
||||
with Matchers
|
||||
with OptionValues
|
||||
with TestFixture
|
||||
with SuiteResourceManagementAroundAll {
|
||||
import Client.JsonProtocol._
|
||||
import Test._
|
||||
|
||||
type Tok <: AuthServiceJWTPayload
|
||||
protected[this] val Tok: TokenCompat[Tok]
|
||||
implicit def `default Token`: Token[Tok]
|
||||
|
||||
private def readJWTTokenFromString[A](
|
||||
serializedPayload: String
|
||||
)(implicit A: Token[A]): Try[A] =
|
||||
AuthServiceJWTCodec.readFromString(serializedPayload).flatMap { t => Try(A.run(t)) }
|
||||
|
||||
private def requestToken[A: Token](
|
||||
parties: Seq[String],
|
||||
admin: Boolean,
|
||||
applicationId: Option[String],
|
||||
): Future[Either[String, (A, String)]] = {
|
||||
lazy val clientUri = Uri()
|
||||
.withAuthority(clientBinding.localAddress.getHostString, clientBinding.localAddress.getPort)
|
||||
val req = HttpRequest(
|
||||
uri = clientUri.withPath(Path./("access")).withScheme("http"),
|
||||
method = HttpMethods.POST,
|
||||
entity = HttpEntity(
|
||||
MediaTypes.`application/json`,
|
||||
Client.AccessParams(parties, admin, applicationId).toJson.compactPrint,
|
||||
),
|
||||
)
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
// Redirect to /authorize on authorization server (No automatic redirect handling in pekko-http)
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
// Redirect to /cb on client.
|
||||
resp <- {
|
||||
resp.status should ===(StatusCodes.Found)
|
||||
val req = HttpRequest(uri = resp.header[Location].value.uri)
|
||||
Http().singleRequest(req)
|
||||
}
|
||||
// Actual token response (proxied from auth server to us via the client)
|
||||
body <- Unmarshal(resp).to[Client.Response]
|
||||
result <- body match {
|
||||
case Client.AccessResponse(token, refreshToken) =>
|
||||
for {
|
||||
decodedJwt <- JwtDecoder
|
||||
.decode(Jwt(token))
|
||||
.fold(
|
||||
e => Future.failed(new IllegalArgumentException(e.toString)),
|
||||
Future.successful(_),
|
||||
)
|
||||
payload <- Future.fromTry(readJWTTokenFromString[A](decodedJwt.payload))
|
||||
} yield Right((payload, refreshToken))
|
||||
case Client.ErrorResponse(error) => Future(Left(error))
|
||||
}
|
||||
} yield result
|
||||
}
|
||||
|
||||
private def requestRefresh[A: Token](
|
||||
refreshToken: String
|
||||
): Future[Either[String, (A, String)]] = {
|
||||
lazy val clientUri = Uri()
|
||||
.withAuthority(clientBinding.localAddress.getHostString, clientBinding.localAddress.getPort)
|
||||
val req = HttpRequest(
|
||||
uri = clientUri.withPath(Path./("refresh")).withScheme("http"),
|
||||
method = HttpMethods.POST,
|
||||
entity = HttpEntity(
|
||||
MediaTypes.`application/json`,
|
||||
Client.RefreshParams(refreshToken).toJson.compactPrint,
|
||||
),
|
||||
)
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
// Token response (proxied from auth server to us via the client)
|
||||
body <- Unmarshal(resp).to[Client.Response]
|
||||
result <- body match {
|
||||
case Client.AccessResponse(token, refreshToken) =>
|
||||
for {
|
||||
decodedJwt <- JwtDecoder
|
||||
.decode(Jwt(token))
|
||||
.fold(
|
||||
e => Future.failed(new IllegalArgumentException(e.toString)),
|
||||
Future.successful(_),
|
||||
)
|
||||
payload <- Future.fromTry(readJWTTokenFromString[A](decodedJwt.payload))
|
||||
} yield Right((payload, refreshToken))
|
||||
case Client.ErrorResponse(error) => Future(Left(error))
|
||||
}
|
||||
} yield result
|
||||
}
|
||||
|
||||
protected[this] def expectToken(
|
||||
parties: Seq[String],
|
||||
admin: Boolean = false,
|
||||
applicationId: Option[String] = None,
|
||||
): Future[(Tok, String)] =
|
||||
requestToken(parties, admin, applicationId).flatMap {
|
||||
case Left(error) => fail(s"Expected token but got error-code $error")
|
||||
case Right(token) => Future(token)
|
||||
}
|
||||
|
||||
protected[this] def expectError(
|
||||
parties: Seq[String],
|
||||
admin: Boolean = false,
|
||||
applicationId: Option[String] = None,
|
||||
): Future[String] =
|
||||
requestToken[AuthServiceJWTPayload](parties, admin, applicationId).flatMap {
|
||||
case Left(error) => Future(error)
|
||||
case Right(_) => fail("Expected an error but got a token")
|
||||
}
|
||||
|
||||
private def expectRefresh(refreshToken: String): Future[(Tok, String)] =
|
||||
requestRefresh(refreshToken).flatMap {
|
||||
case Left(error) => fail(s"Expected token but got error-code $error")
|
||||
case Right(token) => Future(token)
|
||||
}
|
||||
|
||||
"the auth server" should {
|
||||
"refresh a token" in {
|
||||
for {
|
||||
(token1, refresh1) <- expectToken(Seq())
|
||||
_ <- Future(clock.set((Tok exp token1) plusSeconds 1))
|
||||
(token2, _) <- expectRefresh(refresh1)
|
||||
} yield {
|
||||
(Tok exp token2) should be > (Tok exp token1)
|
||||
(Tok withoutExp token1) should ===((Tok withoutExp token2))
|
||||
}
|
||||
}
|
||||
"return a token with the requested app id" in {
|
||||
for {
|
||||
(token, __) <- expectToken(Seq(), applicationId = Some("my-app-id"))
|
||||
} yield {
|
||||
Tok.userId(token) should ===(Some("my-app-id"))
|
||||
}
|
||||
}
|
||||
"return a token with no app id if non is requested" in {
|
||||
for {
|
||||
(token, __) <- expectToken(Seq(), applicationId = None)
|
||||
} yield {
|
||||
Tok.userId(token) should ===(None)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
class ClaimTokenTest extends Test {
|
||||
import Test._
|
||||
|
||||
override def yieldUserTokens = false
|
||||
|
||||
type Tok = CustomDamlJWTPayload
|
||||
override object Tok extends TokenCompat[Tok] {
|
||||
override def userId(t: Tok) = t.applicationId
|
||||
override def exp(t: Tok) = t.exp.value
|
||||
override def withoutExp(t: Tok) = t copy (exp = None)
|
||||
}
|
||||
|
||||
implicit override def `default Token`: Token[Tok] = new Token({
|
||||
case _: StandardJWTPayload =>
|
||||
throw new IllegalStateException(
|
||||
"auth-middleware: user access tokens are not expected here"
|
||||
)
|
||||
case payload: CustomDamlJWTPayload => payload
|
||||
})
|
||||
|
||||
"the auth server with claim tokens" should {
|
||||
"issue a token with no parties" in {
|
||||
for {
|
||||
(token, _) <- expectToken(Seq())
|
||||
} yield {
|
||||
token.actAs should ===(Seq())
|
||||
}
|
||||
}
|
||||
"issue a token with 1 party" in {
|
||||
for {
|
||||
(token, _) <- expectToken(Seq("Alice"))
|
||||
} yield {
|
||||
token.actAs should ===(Seq("Alice"))
|
||||
}
|
||||
}
|
||||
"issue a token with multiple parties" in {
|
||||
for {
|
||||
(token, _) <- expectToken(Seq("Alice", "Bob"))
|
||||
} yield {
|
||||
token.actAs should ===(Seq("Alice", "Bob"))
|
||||
}
|
||||
}
|
||||
"deny access to unauthorized parties" in {
|
||||
server.revokeParty(Ref.Party.assertFromString("Eve"))
|
||||
for {
|
||||
error <- expectError(Seq("Alice", "Eve"))
|
||||
} yield {
|
||||
error should ===("access_denied")
|
||||
}
|
||||
}
|
||||
"issue a token with admin access" in {
|
||||
for {
|
||||
(token, _) <- expectToken(Seq(), admin = true)
|
||||
} yield {
|
||||
assert(token.admin)
|
||||
}
|
||||
}
|
||||
"deny admin access if unauthorized" in {
|
||||
server.revokeAdmin()
|
||||
for {
|
||||
error <- expectError(Seq(), admin = true)
|
||||
} yield {
|
||||
error should ===("access_denied")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class UserTokenTest extends Test {
|
||||
import Test._
|
||||
|
||||
override def yieldUserTokens = true
|
||||
|
||||
type Tok = StandardJWTPayload
|
||||
override object Tok extends TokenCompat[Tok] {
|
||||
override def userId(t: Tok) = Some(t.userId).filter(_.nonEmpty)
|
||||
override def exp(t: Tok) = t.exp.value
|
||||
override def withoutExp(t: Tok) = t copy (exp = None)
|
||||
}
|
||||
|
||||
implicit override def `default Token`: Token[Tok] = new Token({
|
||||
case payload: StandardJWTPayload => payload
|
||||
case _: CustomDamlJWTPayload =>
|
||||
throw new IllegalStateException(
|
||||
"auth-middleware: custom tokens are not expected here"
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
object Test {
|
||||
final class Token[A](val run: AuthServiceJWTPayload => A) extends AnyVal
|
||||
|
||||
object Token {
|
||||
implicit val any: Token[AuthServiceJWTPayload] = new Token(identity)
|
||||
}
|
||||
|
||||
private[server] abstract class TokenCompat[Tok] {
|
||||
def userId(t: Tok): Option[String]
|
||||
def exp(t: Tok): Instant
|
||||
def withoutExp(t: Tok): Tok
|
||||
}
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.auth.oauth2.test.server
|
||||
|
||||
import java.time.{Instant, ZoneId}
|
||||
|
||||
import org.apache.pekko.http.scaladsl.Http.ServerBinding
|
||||
import org.apache.pekko.http.scaladsl.model.Uri
|
||||
import com.daml.clock.AdjustableClock
|
||||
import com.daml.ledger.api.testing.utils.{
|
||||
PekkoBeforeAndAfterAll,
|
||||
OwnedResource,
|
||||
Resource,
|
||||
SuiteResource,
|
||||
}
|
||||
import com.daml.ledger.resources.ResourceContext
|
||||
import com.daml.ports.Port
|
||||
import org.scalatest.{BeforeAndAfterEach, Suite}
|
||||
|
||||
trait TestFixture
|
||||
extends PekkoBeforeAndAfterAll
|
||||
with BeforeAndAfterEach
|
||||
with SuiteResource[(AdjustableClock, Server, ServerBinding, ServerBinding)] {
|
||||
self: Suite =>
|
||||
protected val ledgerId: String = "test-ledger"
|
||||
protected val applicationId: String = "test-application"
|
||||
protected val jwtSecret: String = "secret"
|
||||
lazy protected val clock: AdjustableClock = suiteResource.value._1
|
||||
lazy protected val server: Server = suiteResource.value._2
|
||||
lazy protected val serverBinding: ServerBinding = suiteResource.value._3
|
||||
lazy protected val clientBinding: ServerBinding = suiteResource.value._4
|
||||
protected[this] def yieldUserTokens: Boolean
|
||||
override protected lazy val suiteResource
|
||||
: Resource[(AdjustableClock, Server, ServerBinding, ServerBinding)] = {
|
||||
implicit val resourceContext: ResourceContext = ResourceContext(system.dispatcher)
|
||||
new OwnedResource[ResourceContext, (AdjustableClock, Server, ServerBinding, ServerBinding)](
|
||||
for {
|
||||
clock <- Resources.clock(Instant.now(), ZoneId.systemDefault())
|
||||
server = Server(
|
||||
Config(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = ledgerId,
|
||||
jwtSecret = jwtSecret,
|
||||
clock = Some(clock),
|
||||
yieldUserTokens = yieldUserTokens,
|
||||
)
|
||||
)
|
||||
serverBinding <- Resources.authServerBinding(server)
|
||||
clientBinding <- Resources.authClientBinding(
|
||||
Client.Config(
|
||||
port = Port.Dynamic,
|
||||
authServerUrl = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(
|
||||
serverBinding.localAddress.getHostString,
|
||||
serverBinding.localAddress.getPort,
|
||||
),
|
||||
clientId = "test-client",
|
||||
clientSecret = "test-client-secret",
|
||||
)
|
||||
)
|
||||
} yield { (clock, server, serverBinding, clientBinding) }
|
||||
)
|
||||
}
|
||||
|
||||
override protected def afterEach(): Unit = {
|
||||
server.resetAuthorizedParties()
|
||||
server.resetAdmin()
|
||||
|
||||
super.afterEach()
|
||||
}
|
||||
}
|
@ -1,90 +0,0 @@
|
||||
# Design for Trigger Service Authentication/Authorization
|
||||
|
||||
## Goals
|
||||
|
||||
- Be compatible with an OAuth2 authorization code grant
|
||||
https://tools.ietf.org/html/rfc6749#section-4.1
|
||||
- Do not require OAuth2 or any other specific
|
||||
authentication/authorization protocol from the IAM. In other words,
|
||||
the communication with the IAM must be pluggable.
|
||||
- Do not rely on wildcard access for the trigger service, it should
|
||||
only be able to start triggers on behalf of a party if a user that
|
||||
controls that party has given consent.
|
||||
- Support long-running triggers without constant user
|
||||
interaction. Since auth tokens are often short-lived (e.g., expire
|
||||
after 1h), this implies some mechanism for token refresh.
|
||||
|
||||
## Design
|
||||
|
||||
This involves 3 components:
|
||||
|
||||
1. The trigger service provided by DA.
|
||||
2. An auth middleware. DA provides an implementation of this for at
|
||||
least the OAuth2 authorization code grant but this is completely
|
||||
pluggable so if the DA-provided middleware does not cover the IAM
|
||||
infrastructure of a client, they can implement their own.
|
||||
3. The IAM. This is the entity that signs Ledger API tokens. This is
|
||||
not provided by DA. The Ledger is configured to trust this entity.
|
||||
|
||||
### Auth Middleware API
|
||||
|
||||
The auth middleware provides a few endpoints (the names don’t matter
|
||||
all that much, they just need to be fixed once).
|
||||
|
||||
1. /auth The trigger service, will contact this endpoint with a set of
|
||||
claims passing along all cookies in the original request. If
|
||||
the user has already authenticated and is authorized for those
|
||||
claims, it will return an access token (an opaque blob to the
|
||||
trigger service) for at least those claims and a refresh token
|
||||
(another opaque blob). If not, it will return an unauthorized
|
||||
status code.
|
||||
|
||||
2. /login If /auth returned unauthorized, the trigger service will
|
||||
redirect users to this.
|
||||
For HTML requests via HTTP redirect, otherwise via a custom WWW-Authenticate challenge in a 401 resonse.
|
||||
The parameters will include the requested claims as well as an optional callback URL (note that this is not the OAuth2 callback url but a callback URL on the trigger service). This will start an auth flow,
|
||||
e.g., an OAuth2 authorization code grant. If the flow succeeds the
|
||||
auth service will set a cookie with the access and refresh token
|
||||
and redirect to the callback URL if present or return status code 200.
|
||||
At this point, a request to
|
||||
/auth will succeed (based on the cookie). If the flow failed the
|
||||
auth service will not set a cookie and redirect to the callback URL
|
||||
with an additional error and optional error_description parameter
|
||||
or return 403 with error and optional error_description in the response body.
|
||||
|
||||
3. /refresh This accepts a refresh token and returns a new access
|
||||
token and optionally a new refresh token (or fails).
|
||||
|
||||
### Auth Middleware Implementation based on OAuth2 Authorization Code Grant
|
||||
|
||||
1. /auth checks for the presence of a cookie with the tokens in it.
|
||||
2. /login starts an OAuth2 authorization code grant flow. After the
|
||||
redirect URI is called by the authorization server, the middleware
|
||||
makes a request to get the tokens, sets them in cookies and
|
||||
redirects back to the callback URI. Upon failure the middleware
|
||||
forwards the error and error_description to the callback URI.
|
||||
3. /refresh simply proxies to the refresh endpoint on the
|
||||
authorization server adding the client id and secret.
|
||||
|
||||
Note that the auth middleware does not need to persist any state in
|
||||
this model. The trigger service does need to persist at least the
|
||||
refresh token and potentially the access token.
|
||||
|
||||
## Related Projects
|
||||
|
||||
The design here is very close to existing OAuth2 middlewares/proxies
|
||||
such as [Vouch](https://github.com/vouch/vouch-proxy) or [OAuth2
|
||||
Proxy](https://github.com/oauth2-proxy/oauth2-proxy). There are two
|
||||
main differences:
|
||||
|
||||
1. The trigger service extracts the required claims from the original
|
||||
request. This means that the first request goes to the trigger
|
||||
service and not to the proxy/middleware. The middleware shouldn’t
|
||||
have to know how to map a request to the set of required claims so
|
||||
this seems the only workable option.
|
||||
|
||||
2. The /auth request takes a custom set of claims. The existing
|
||||
proxies focus on OIDC and don’t support any custom claims.
|
||||
|
||||
Nevertheless, the fact that the design is very close to existing
|
||||
proxies seems like a good thing.
|
@ -1,12 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
set -eou pipefail
|
||||
shopt -s globstar
|
||||
|
||||
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
|
||||
|
||||
"$DIR"/../../libs-scala/flyway-testing/hash-migrations.sh \
|
||||
"$DIR"/src/main/resources/com/daml/lf/engine/trigger/db/migration/**/*.sql
|
@ -1,26 +0,0 @@
|
||||
<configuration>
|
||||
<appender name="console" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<if condition='isDefined("LOG_FORMAT_JSON")'>
|
||||
<then>
|
||||
<encoder class="net.logstash.logback.encoder.LogstashEncoder"/>
|
||||
</then>
|
||||
<else>
|
||||
<encoder>
|
||||
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg %replace(, context: %marker){', context: $', ''} %n</pattern>
|
||||
</encoder>
|
||||
</else>
|
||||
</if>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="net.logstash.logback.appender.LoggingEventAsyncDisruptorAppender">
|
||||
<appender-ref ref="console"/>
|
||||
</appender>
|
||||
|
||||
<logger name="io.netty" level="WARN"/>
|
||||
<logger name="io.grpc.netty" level="WARN"/>
|
||||
<logger name="com.daml.lf.engine.trigger" level="INFO"/>
|
||||
|
||||
<root level="${LOG_LEVEL_ROOT:-INFO}">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</root>
|
||||
</configuration>
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user