diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/Database.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/Database.scala index 035f80e48d..7e7a8ac7e3 100644 --- a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/Database.scala +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/Database.scala @@ -5,7 +5,13 @@ package com.daml.ledger.on.sql import java.sql.Connection -import com.daml.ledger.on.sql.queries.{H2Queries, PostgresqlQueries, Queries, SqliteQueries} +import com.daml.ledger.on.sql.queries.{ + H2Queries, + PostgresqlQueries, + Queries, + ReadQueries, + SqliteQueries +} import com.digitalasset.logging.{ContextualizedLogger, LoggingContext} import com.digitalasset.resources.ProgramResource.StartupException import com.digitalasset.resources.ResourceOwner @@ -19,22 +25,22 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} final class Database( - val queries: Queries, + queries: Connection => Queries, readerConnectionPool: DataSource, writerConnectionPool: DataSource, ) { private val logger = ContextualizedLogger.get(this.getClass) def inReadTransaction[T](message: String)( - body: Connection => Future[T], + body: ReadQueries => Future[T], )(implicit executionContext: ExecutionContext, logCtx: LoggingContext): Future[T] = { - inTransaction(message, readerConnectionPool)(body) + inTransaction(message, readerConnectionPool)(connection => body(queries(connection))) } def inWriteTransaction[T](message: String)( - body: Connection => Future[T], + body: Queries => Future[T], )(implicit executionContext: ExecutionContext, logCtx: LoggingContext): Future[T] = { - inTransaction(message, writerConnectionPool)(body) + inTransaction(message, writerConnectionPool)(connection => body(queries(connection))) } private def inTransaction[T](message: String, connectionPool: DataSource)( @@ -176,26 +182,26 @@ object Database { sealed trait RDBMS { val name: String - val queries: Queries + val queries: Connection => Queries } object RDBMS { object H2 extends RDBMS { override val name: String = "h2" - override val queries: Queries = new H2Queries + override val queries: Connection => Queries = H2Queries.apply } object PostgreSQL extends RDBMS { override val name: String = "postgresql" - override val queries: Queries = new PostgresqlQueries + override val queries: Connection => Queries = PostgresqlQueries.apply } object SQLite extends RDBMS { override val name: String = "sqlite" - override val queries: Queries = new SqliteQueries + override val queries: Connection => Queries = SqliteQueries.apply } } diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/SqlLedgerReaderWriter.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/SqlLedgerReaderWriter.scala index 674d782801..de95a7451d 100644 --- a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/SqlLedgerReaderWriter.scala +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/SqlLedgerReaderWriter.scala @@ -3,7 +3,6 @@ package com.daml.ledger.on.sql -import java.sql.Connection import java.time.{Clock, Instant} import java.util.UUID @@ -11,6 +10,7 @@ import akka.NotUsed import akka.stream.Materializer import akka.stream.scaladsl.Source import com.daml.ledger.on.sql.SqlLedgerReaderWriter._ +import com.daml.ledger.on.sql.queries.Queries import com.daml.ledger.participant.state.kvutils.api.{LedgerReader, LedgerRecord, LedgerWriter} import com.daml.ledger.participant.state.v1._ import com.daml.ledger.validator.LedgerStateOperations.{Key, Value} @@ -46,8 +46,6 @@ final class SqlLedgerReaderWriter( ) extends LedgerWriter with LedgerReader { - private val queries = database.queries - private val committer = new ValidatingCommitter[Index]( participantId, now, @@ -65,9 +63,8 @@ final class SqlLedgerReaderWriter( RangeSource((start, end) => { Source .futureSource(database - .inReadTransaction(s"Querying events [$start, $end[ from log") { - implicit connection => - Future.successful(queries.selectFromLog(start, end)) + .inReadTransaction(s"Querying events [$start, $end[ from log") { queries => + Future.successful(queries.selectFromLog(start, end)) } .map { result => if (result.length < end - start) { @@ -88,13 +85,12 @@ final class SqlLedgerReaderWriter( object SqlLedgerStateAccess extends LedgerStateAccess[Index] { override def inTransaction[T](body: LedgerStateOperations[Index] => Future[T]): Future[T] = - database.inWriteTransaction("Committing a submission") { implicit connection => - body(new SqlLedgerStateOperations) + database.inWriteTransaction("Committing a submission") { queries => + body(new SqlLedgerStateOperations(queries)) } } - class SqlLedgerStateOperations(implicit connection: Connection) - extends BatchingLedgerStateOperations[Index] { + class SqlLedgerStateOperations(queries: Queries) extends BatchingLedgerStateOperations[Index] { override def readState(keys: Seq[Key]): Future[Seq[Option[Value]]] = Future.successful(queries.selectStateValuesByKeys(keys)) @@ -132,10 +128,10 @@ object SqlLedgerReaderWriter { implicit executionContext: ExecutionContext, logCtx: LoggingContext, ): Future[LedgerId] = - database.inWriteTransaction("Checking ledger ID at startup") { implicit connection => + database.inWriteTransaction("Checking ledger ID at startup") { queries => val providedLedgerId = initialLedgerId.getOrElse(Ref.LedgerString.assertFromString(UUID.randomUUID.toString)) - val ledgerId = database.queries.updateOrRetrieveLedgerId(providedLedgerId) + val ledgerId = queries.updateOrRetrieveLedgerId(providedLedgerId) if (initialLedgerId.exists(_ != ledgerId)) { Future.failed( new LedgerIdMismatchException( @@ -152,9 +148,8 @@ object SqlLedgerReaderWriter { logCtx: LoggingContext, ): Future[Dispatcher[Index]] = database - .inReadTransaction("Reading head at startup") { implicit connection => - Future.successful( - database.queries.selectLatestLogEntryId().map(_ + 1).getOrElse(StartIndex)) + .inReadTransaction("Reading head at startup") { queries => + Future.successful(queries.selectLatestLogEntryId().map(_ + 1).getOrElse(StartIndex)) } .map(head => Dispatcher("sql-participant-state", StartIndex, head)) } diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/CommonQueries.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/CommonQueries.scala index 4bf30fb1e1..5f070195a0 100644 --- a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/CommonQueries.scala +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/CommonQueries.scala @@ -19,15 +19,14 @@ import com.google.protobuf.ByteString import scala.collection.{breakOut, immutable} trait CommonQueries extends Queries { - override final def selectLatestLogEntryId()(implicit connection: Connection): Option[Index] = + protected implicit val connection: Connection + + override final def selectLatestLogEntryId(): Option[Index] = SQL"SELECT MAX(sequence_no) max_sequence_no FROM #$LogTable" .as(get[Option[Long]]("max_sequence_no").singleOpt) .flatten - override final def selectFromLog( - start: Index, - end: Index, - )(implicit connection: Connection): immutable.Seq[(Index, LedgerRecord)] = + override final def selectFromLog(start: Index, end: Index): immutable.Seq[(Index, LedgerRecord)] = SQL"SELECT sequence_no, entry_id, envelope FROM #$LogTable WHERE sequence_no >= $start AND sequence_no < $end" .as( (long("sequence_no") ~ binaryStream("entry_id") ~ byteArray("envelope")).map { @@ -36,9 +35,7 @@ trait CommonQueries extends Queries { }.*, ) - override final def selectStateValuesByKeys( - keys: Seq[Key], - )(implicit connection: Connection): immutable.Seq[Option[Value]] = { + override final def selectStateValuesByKeys(keys: Seq[Key]): immutable.Seq[Option[Value]] = { val results = SQL"SELECT key, value FROM #$StateTable WHERE key IN ($keys)" .fold(Map.newBuilder[ByteString, Array[Byte]], ColumnAliaser.empty)((builder, row) => builder += ByteString.readFrom(row[InputStream]("key")) -> row[Value]("value")) @@ -48,9 +45,7 @@ trait CommonQueries extends Queries { keys.map(key => results.get(ByteString.copyFrom(key)))(breakOut) } - override final def updateState( - stateUpdates: Seq[(Key, Value)], - )(implicit connection: Connection): Unit = + override final def updateState(stateUpdates: Seq[(Key, Value)]): Unit = executeBatchSql(updateStateQuery, stateUpdates.map { case (key, value) => Seq[NamedParameter]("key" -> key, "value" -> value) }) diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/H2Queries.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/H2Queries.scala index f3a64e8c6a..69bb301e77 100644 --- a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/H2Queries.scala +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/H2Queries.scala @@ -12,17 +12,17 @@ import com.daml.ledger.on.sql.queries.Queries._ import com.daml.ledger.participant.state.v1.LedgerId import com.daml.ledger.validator.LedgerStateOperations.{Key, Value} -final class H2Queries extends Queries with CommonQueries { - override def updateOrRetrieveLedgerId( - providedLedgerId: LedgerId, - )(implicit connection: Connection): LedgerId = { +final class H2Queries(override protected implicit val connection: Connection) + extends Queries + with CommonQueries { + override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId = { SQL"MERGE INTO #$MetaTable USING DUAL ON table_key = $MetaTableKey WHEN NOT MATCHED THEN INSERT (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId)" .executeInsert() SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey" .as(str("ledger_id").single) } - override def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index = { + override def insertIntoLog(key: Key, value: Value): Index = { SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value)" .executeInsert() SQL"CALL IDENTITY()" @@ -32,3 +32,10 @@ final class H2Queries extends Queries with CommonQueries { override protected val updateStateQuery: String = s"MERGE INTO $StateTable VALUES ({key}, {value})" } + +object H2Queries { + def apply(connection: Connection): Queries = { + implicit val conn: Connection = connection + new H2Queries + } +} diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/PostgresqlQueries.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/PostgresqlQueries.scala index 90555dc8a8..cd2e40a53d 100644 --- a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/PostgresqlQueries.scala +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/PostgresqlQueries.scala @@ -12,17 +12,17 @@ import com.daml.ledger.on.sql.queries.Queries._ import com.daml.ledger.participant.state.v1.LedgerId import com.daml.ledger.validator.LedgerStateOperations.{Key, Value} -final class PostgresqlQueries extends Queries with CommonQueries { - override def updateOrRetrieveLedgerId( - providedLedgerId: LedgerId, - )(implicit connection: Connection): LedgerId = { +final class PostgresqlQueries(override protected implicit val connection: Connection) + extends Queries + with CommonQueries { + override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId = { SQL"INSERT INTO #$MetaTable (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId) ON CONFLICT DO NOTHING" .executeInsert() SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey" .as(str("ledger_id").single) } - override def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index = { + override def insertIntoLog(key: Key, value: Value): Index = { SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value) RETURNING sequence_no" .as(long("sequence_no").single) } @@ -30,3 +30,10 @@ final class PostgresqlQueries extends Queries with CommonQueries { override protected val updateStateQuery: String = s"INSERT INTO $StateTable VALUES ({key}, {value}) ON CONFLICT(key) DO UPDATE SET value = {value}" } + +object PostgresqlQueries { + def apply(connection: Connection): Queries = { + implicit val conn: Connection = connection + new PostgresqlQueries + } +} diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/Queries.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/Queries.scala index 6c476c1e92..b98e24c2a4 100644 --- a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/Queries.scala +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/Queries.scala @@ -6,33 +6,8 @@ package com.daml.ledger.on.sql.queries import java.sql.Connection import anorm.{BatchSql, NamedParameter} -import com.daml.ledger.on.sql.Index -import com.daml.ledger.participant.state.kvutils.api.LedgerRecord -import com.daml.ledger.participant.state.v1.LedgerId -import com.daml.ledger.validator.LedgerStateOperations.{Key, Value} -import scala.collection.immutable - -trait Queries { - def updateOrRetrieveLedgerId( - providedLedgerId: LedgerId, - )(implicit connection: Connection): LedgerId - - def selectLatestLogEntryId()(implicit connection: Connection): Option[Index] - - def selectFromLog( - start: Index, - end: Index, - )(implicit connection: Connection): immutable.Seq[(Index, LedgerRecord)] - - def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index - - def selectStateValuesByKeys( - keys: Seq[Key], - )(implicit connection: Connection): immutable.Seq[Option[Value]] - - def updateState(stateUpdates: Seq[(Key, Value)])(implicit connection: Connection): Unit -} +trait Queries extends ReadQueries with WriteQueries object Queries { val TablePrefix = "ledger" diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/ReadQueries.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/ReadQueries.scala new file mode 100644 index 0000000000..c04cb723ea --- /dev/null +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/ReadQueries.scala @@ -0,0 +1,18 @@ +// Copyright (c) 2020 The DAML Authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.daml.ledger.on.sql.queries + +import com.daml.ledger.on.sql.Index +import com.daml.ledger.participant.state.kvutils.api.LedgerRecord +import com.daml.ledger.validator.LedgerStateOperations.{Key, Value} + +import scala.collection.immutable + +trait ReadQueries { + def selectLatestLogEntryId(): Option[Index] + + def selectFromLog(start: Index, end: Index): immutable.Seq[(Index, LedgerRecord)] + + def selectStateValuesByKeys(keys: Seq[Key]): immutable.Seq[Option[Value]] +} diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/SqliteQueries.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/SqliteQueries.scala index b86f08a04c..ede961c375 100644 --- a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/SqliteQueries.scala +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/SqliteQueries.scala @@ -12,17 +12,17 @@ import com.daml.ledger.on.sql.queries.Queries._ import com.daml.ledger.participant.state.v1.LedgerId import com.daml.ledger.validator.LedgerStateOperations.{Key, Value} -final class SqliteQueries extends Queries with CommonQueries { - override def updateOrRetrieveLedgerId( - providedLedgerId: LedgerId, - )(implicit connection: Connection): LedgerId = { +final class SqliteQueries(override protected implicit val connection: Connection) + extends Queries + with CommonQueries { + override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId = { SQL"INSERT INTO #$MetaTable (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId) ON CONFLICT DO NOTHING" .executeInsert() SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey" .as(str("ledger_id").single) } - override def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index = { + override def insertIntoLog(key: Key, value: Value): Index = { SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value)" .executeInsert() SQL"SELECT LAST_INSERT_ROWID()" @@ -32,3 +32,10 @@ final class SqliteQueries extends Queries with CommonQueries { override protected val updateStateQuery: String = s"INSERT INTO $StateTable VALUES ({key}, {value}) ON CONFLICT(key) DO UPDATE SET value = {value}" } + +object SqliteQueries { + def apply(connection: Connection): Queries = { + implicit val conn: Connection = connection + new SqliteQueries + } +} diff --git a/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/WriteQueries.scala b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/WriteQueries.scala new file mode 100644 index 0000000000..fcda599c0e --- /dev/null +++ b/ledger/ledger-on-sql/src/main/scala/com/daml/ledger/on/sql/queries/WriteQueries.scala @@ -0,0 +1,16 @@ +// Copyright (c) 2020 The DAML Authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.daml.ledger.on.sql.queries + +import com.daml.ledger.on.sql.Index +import com.daml.ledger.participant.state.v1.LedgerId +import com.daml.ledger.validator.LedgerStateOperations.{Key, Value} + +trait WriteQueries { + def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId + + def insertIntoLog(key: Key, value: Value): Index + + def updateState(stateUpdates: Seq[(Key, Value)]): Unit +}