ledger-api-common: Do not mock final classes. (#10733)

* ledger-api-common: Replace `URL` with a trait so we can fake it.

This means we don't have to use Mockito to stub out `URL`, which is
final and therefore requires magic to stub.

CHANGELOG_BEGIN
CHANGELOG_END

* ledger-api-common: Move tests that open streams into SecretsUrlTest.

* ledger-api-common: Make SecretsUrl constructors simple methods.

* ledger-api-common: Reduce concrete implementations of SecretsUrl.
This commit is contained in:
Samir Talwar 2021-09-01 15:36:30 +02:00 committed by GitHub
parent e9c8af5024
commit f6a75b42f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 148 additions and 86 deletions

View File

@ -3,16 +3,16 @@
package com.daml.ledger.api.tls
import org.apache.commons.codec.binary.Hex
import org.apache.commons.io.IOUtils
import spray.json.{DefaultJsonProtocol, RootJsonFormat}
import java.io.File
import java.net.URL
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import javax.crypto.Cipher
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
import org.apache.commons.codec.binary.Hex
import org.apache.commons.io.IOUtils
import spray.json.{DefaultJsonProtocol, RootJsonFormat}
import scala.util.Using
final class PrivateKeyDecryptionException(cause: Throwable) extends Exception(cause)
@ -63,17 +63,15 @@ object DecryptionParameters {
/** Creates an instance of [[DecryptionParameters]] by fetching necessary information from an URL
*/
def fromSecretsServer(url: URL): DecryptionParameters = {
val text = fetchPayload(url)
parsePayload(text)
def fromSecretsServer(url: SecretsUrl): DecryptionParameters = {
val body = fetchPayload(url)
parsePayload(body)
}
private[tls] def fetchPayload(url: URL): String = {
val text = Using.resource(url.openStream()) { stream =>
private[tls] def fetchPayload(url: SecretsUrl): String =
Using.resource(url.openStream()) { stream =>
IOUtils.toString(stream, StandardCharsets.UTF_8.name())
}
text
}
private[tls] def parsePayload(payload: String): DecryptionParameters = {
import DecryptionParametersJsonProtocol._

View File

@ -0,0 +1,25 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.api.tls
import java.io.InputStream
import java.net.URL
import java.nio.file.Path
// This trait is not sealed so we can replace it with a fake in tests.
trait SecretsUrl {
def openStream(): InputStream
}
object SecretsUrl {
def fromString(string: String): SecretsUrl = new FromUrl(new URL(string))
def fromPath(path: Path): SecretsUrl = new FromUrl(path.toUri.toURL)
def fromUrl(url: URL): SecretsUrl = new FromUrl(url)
private final case class FromUrl(url: URL) extends SecretsUrl {
override def openStream(): InputStream = url.openStream()
}
}

View File

@ -3,12 +3,12 @@
package com.daml.ledger.api.tls
import java.io.{ByteArrayInputStream, File, FileInputStream, InputStream}
import java.nio.file.Files
import io.grpc.netty.GrpcSslContexts
import io.netty.handler.ssl.{ClientAuth, SslContext}
import java.io.{ByteArrayInputStream, File, FileInputStream, InputStream}
import java.net.URL
import java.nio.file.Files
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
@ -17,7 +17,7 @@ final case class TlsConfiguration(
keyCertChainFile: Option[File], // mutual auth is disabled if null
keyFile: Option[File],
trustCertCollectionFile: Option[File], // System default if null
secretsUrl: Option[URL] = None,
secretsUrl: Option[SecretsUrl] = None,
clientAuth: ClientAuth =
ClientAuth.REQUIRE, // Client auth setting used by the server. This is not used in the client configuration.
enableCertRevocationChecking: Boolean = false,
@ -90,7 +90,7 @@ final case class TlsConfiguration(
new ByteArrayInputStream(bytes)
}
def secretsUrlOrFail: URL = secretsUrl.getOrElse(
private def secretsUrlOrFail: SecretsUrl = secretsUrl.getOrElse(
throw new IllegalStateException(
s"Unable to convert ${this.toString} to SSL Context: cannot decrypt keyFile without secretsUrl."
)

View File

@ -3,18 +3,17 @@
package com.daml.ledger.api.tls
import com.daml.testing.SimpleHttpServer
import java.io.ByteArrayInputStream
import java.nio.file.Files
import javax.crypto.{Cipher, KeyGenerator, SecretKey}
import org.apache.commons.codec.binary.Hex
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import java.net.URL
import java.nio.file.Files
import javax.crypto.{Cipher, KeyGenerator, SecretKey}
class DecryptionParametersTest extends AnyWordSpec with Matchers {
DecryptionParameters.getClass.getSimpleName should {
"decryption parameters" should {
// given
val key: SecretKey = KeyGenerator.getInstance("AES").generateKey()
@ -106,39 +105,17 @@ class DecryptionParametersTest extends AnyWordSpec with Matchers {
actual shouldBe expected
}
"fetch JSON document from a file URL" in {
"fetch JSON document from a secrets URL" in {
// given
val tmpFilePath = Files.createTempFile("decryption-params", ".json")
val expected = "decryption-params123"
Files.write(tmpFilePath, expected.getBytes)
assume(new String(Files.readAllBytes(tmpFilePath)) == expected)
val url = tmpFilePath.toUri.toURL
assume(url.getProtocol == "file")
val secretsUrl: SecretsUrl = () => new ByteArrayInputStream(expected.getBytes)
// when
val actual = DecryptionParameters.fetchPayload(url)
val actual = DecryptionParameters.fetchPayload(secretsUrl)
// then
actual shouldBe expected
}
"fetch JSON document from a http URL" in {
// given
val expected = "payload123"
val server = SimpleHttpServer.start(expected)
try {
val url = new URL(SimpleHttpServer.responseUrl(server))
assume(url.getProtocol == "http")
// when
val actual = DecryptionParameters.fetchPayload(url)
// then
actual shouldBe expected
} finally {
SimpleHttpServer.stop(server)
}
}
}
}

View File

@ -0,0 +1,62 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.api.tls
import java.io.{BufferedReader, InputStream, InputStreamReader}
import java.net.URL
import java.nio.file.Files
import java.util.stream.Collectors
import com.daml.ledger.api.tls.SecretsUrlTest._
import com.daml.testing.SimpleHttpServer
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import scala.util.Using
class SecretsUrlTest extends AnyWordSpec with Matchers {
"a secrets URL based on a file" should {
"open a stream" in {
val contents = "Here is some text."
val filePath = Files.createTempFile(getClass.getSimpleName, ".txt")
try {
Files.write(filePath, contents.getBytes)
val secretsUrl = SecretsUrl.fromPath(filePath)
val actualContents = readStreamFully(secretsUrl.openStream())
actualContents should be(contents)
} finally {
Files.delete(filePath)
}
}
}
"a secrets URL based on a URL" should {
"open a stream" in {
val contents = "Here is a response body."
val server = SimpleHttpServer.start(contents)
try {
val url = new URL(SimpleHttpServer.responseUrl(server))
url.getProtocol should be("http")
val secretsUrl = SecretsUrl.fromUrl(url)
val actualContents = readStreamFully(secretsUrl.openStream())
actualContents should be(contents)
} finally {
SimpleHttpServer.stop(server)
}
}
}
}
object SecretsUrlTest {
private def readStreamFully(newStream: => InputStream): String =
Using.resource(newStream) { stream =>
new BufferedReader(new InputStreamReader(stream))
.lines()
.collect(Collectors.joining(System.lineSeparator()))
}
}

View File

@ -3,22 +3,24 @@
package com.daml.ledger.api.tls
import org.apache.commons.io.IOUtils
import org.mockito.Mockito
import org.scalatest.BeforeAndAfterEach
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import java.io.InputStream
import java.net.{ConnectException, URL}
import java.net.ConnectException
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.security.Security
import org.apache.commons.io.IOUtils
import org.scalatest.BeforeAndAfterEach
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
class TlsConfigurationTest extends AnyWordSpec with Matchers with BeforeAndAfterEach {
var systemProperties: Map[String, Option[String]] = Map.empty
var ocspSecurityProperty: Option[String] = None
private var systemProperties: Map[String, Option[String]] = Map.empty
private var ocspSecurityProperty: Option[String] = None
private val Enabled = "true"
private val Disabled = "false"
override def beforeEach(): Unit = {
super.beforeEach()
@ -98,10 +100,9 @@ class TlsConfigurationTest extends AnyWordSpec with Matchers with BeforeAndAfter
Files.write(keyFilePath, "private-key-123".getBytes())
assume(Files.readAllBytes(keyFilePath) sameElements "private-key-123".getBytes)
val keyFile = keyFilePath.toFile
val urlMock = Mockito.mock(classOf[URL])
Mockito.when(urlMock.openStream()).thenThrow(new ConnectException("Mocked url 123"))
val tested = TlsConfiguration.Empty
.copy(secretsUrl = Some(urlMock))
val tested = TlsConfiguration.Empty.copy(
secretsUrl = Some(() => throw new ConnectException("Mocked url 123"))
)
// when
val e = intercept[PrivateKeyDecryptionException] {
@ -127,7 +128,4 @@ class TlsConfigurationTest extends AnyWordSpec with Matchers with BeforeAndAfter
Security.getProperty(OcspProperties.EnableOcspProperty) shouldBe expectedValue
}
private val Enabled = "true"
private val Disabled = "false"
}

View File

@ -3,8 +3,14 @@
package com.daml.ledger.participant.state.kvutils.app
import java.io.File
import java.nio.file.Path
import java.time.Duration
import java.util.UUID
import java.util.concurrent.TimeUnit
import com.daml.caching
import com.daml.ledger.api.tls.TlsConfiguration
import com.daml.ledger.api.tls.{SecretsUrl, TlsConfiguration}
import com.daml.ledger.resources.ResourceOwner
import com.daml.lf.VersionRange
import com.daml.lf.data.Ref
@ -21,12 +27,6 @@ import com.daml.ports.Port
import io.netty.handler.ssl.ClientAuth
import scopt.OptionParser
import java.io.File
import java.net.URL
import java.nio.file.Path
import java.time.Duration
import java.util.UUID
import java.util.concurrent.TimeUnit
import scala.concurrent.duration.FiniteDuration
final case class Config[Extra](
@ -327,7 +327,7 @@ object Config {
"TLS: URL of a secrets service that provide parameters needed to decrypt the private key. Required when private key is encrypted (indicated by '.enc' filename suffix)."
)
.action((url, config) =>
config.withTlsConfig(c => c.copy(secretsUrl = Some(new URL(url))))
config.withTlsConfig(c => c.copy(secretsUrl = Some(SecretsUrl.fromString(url))))
)
checkConfig(c =>

View File

@ -4,10 +4,9 @@
package com.daml.ledger.participant.state.kvutils.app
import java.io.File
import java.net.URL
import java.time.Duration
import com.daml.ledger.api.tls.TlsConfiguration
import com.daml.ledger.api.tls.{SecretsUrl, TlsConfiguration}
import com.daml.lf.data.Ref
import io.netty.handler.ssl.ClientAuth
import org.scalatest.OptionValues
@ -76,7 +75,7 @@ final class ConfigSpec
actual.get.tlsConfig shouldBe Some(
TlsConfiguration(
enabled = true,
secretsUrl = Some(new URL("http://aaa")),
secretsUrl = Some(SecretsUrl.fromString("http://aaa")),
keyFile = Some(new File("key.enc")),
keyCertChainFile = None,
trustCertCollectionFile = None,

View File

@ -3,11 +3,14 @@
package com.daml.platform.sandbox.cli
import java.io.File
import java.time.Duration
import com.daml.buildinfo.BuildInfo
import com.daml.jwt.JwtVerifierConfigurationCli
import com.daml.ledger.api.auth.AuthServiceJWT
import com.daml.ledger.api.domain.LedgerId
import com.daml.ledger.api.tls.TlsConfiguration
import com.daml.ledger.api.tls.{SecretsUrl, TlsConfiguration}
import com.daml.ledger.configuration.LedgerTimeModel
import com.daml.lf.data.Ref
import com.daml.platform.apiserver.SeedService.Seeding
@ -21,9 +24,6 @@ import io.netty.handler.ssl.ClientAuth
import scalaz.syntax.tag._
import scopt.OptionParser
import java.io.File
import java.net.URL
import java.time.Duration
import scala.util.Try
// [[SandboxConfig]] should not expose Options for mandatory fields as such validations should not
@ -127,7 +127,9 @@ class CommonCliBase(name: LedgerName) {
.text(
"TLS: URL of a secrets service that provides parameters needed to decrypt the private key. Required when private key is encrypted (indicated by '.enc' filename suffix)."
)
.action((url, config) => config.withTlsConfig(c => c.copy(secretsUrl = Some(new URL(url)))))
.action((url, config) =>
config.withTlsConfig(c => c.copy(secretsUrl = Some(SecretsUrl.fromString(url))))
)
checkConfig(c =>
c.tlsConfig.fold(success) { tlsConfig =>

View File

@ -3,8 +3,12 @@
package com.daml.platform.sandbox.cli
import java.io.File
import java.net.InetSocketAddress
import java.nio.file.{Files, Paths}
import com.daml.bazeltools.BazelRunfiles.rlocation
import com.daml.ledger.api.tls.TlsConfiguration
import com.daml.ledger.api.tls.{SecretsUrl, TlsConfiguration}
import com.daml.ledger.test.ModelTestDar
import com.daml.lf.data.Ref
import com.daml.metrics.MetricsReporter
@ -17,9 +21,6 @@ import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import java.io.File
import java.net.{InetSocketAddress, URL}
import java.nio.file.{Files, Paths}
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
@ -130,7 +131,7 @@ abstract class CommonCliSpecBase(
Some(
TlsConfiguration(
enabled = true,
secretsUrl = Some(new URL("http://aaa")),
secretsUrl = Some(SecretsUrl.fromString("http://aaa")),
keyFile = Some(new File("key.enc")),
keyCertChainFile = None,
trustCertCollectionFile = None,