ledger-on-sql: Split read queries out from write queries. (#4620)

* ledger-on-sql: Provide queries in the transaction lambda.

* ledger-on-sql: Split read queries out from write queries.

Can't run write queries from a read transaction.

CHANGELOG_BEGIN
CHANGELOG_END

* ledger-on-sql: Pass the connection into the `Queries` constructors.

Way less typing this way round.
This commit is contained in:
Samir Talwar 2020-02-20 14:48:37 +01:00 committed by GitHub
parent 87a8a2548e
commit 46e046a68b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 103 additions and 77 deletions

View File

@ -5,7 +5,13 @@ package com.daml.ledger.on.sql
import java.sql.Connection
import com.daml.ledger.on.sql.queries.{H2Queries, PostgresqlQueries, Queries, SqliteQueries}
import com.daml.ledger.on.sql.queries.{
H2Queries,
PostgresqlQueries,
Queries,
ReadQueries,
SqliteQueries
}
import com.digitalasset.logging.{ContextualizedLogger, LoggingContext}
import com.digitalasset.resources.ProgramResource.StartupException
import com.digitalasset.resources.ResourceOwner
@ -19,22 +25,22 @@ import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
final class Database(
val queries: Queries,
queries: Connection => Queries,
readerConnectionPool: DataSource,
writerConnectionPool: DataSource,
) {
private val logger = ContextualizedLogger.get(this.getClass)
def inReadTransaction[T](message: String)(
body: Connection => Future[T],
body: ReadQueries => Future[T],
)(implicit executionContext: ExecutionContext, logCtx: LoggingContext): Future[T] = {
inTransaction(message, readerConnectionPool)(body)
inTransaction(message, readerConnectionPool)(connection => body(queries(connection)))
}
def inWriteTransaction[T](message: String)(
body: Connection => Future[T],
body: Queries => Future[T],
)(implicit executionContext: ExecutionContext, logCtx: LoggingContext): Future[T] = {
inTransaction(message, writerConnectionPool)(body)
inTransaction(message, writerConnectionPool)(connection => body(queries(connection)))
}
private def inTransaction[T](message: String, connectionPool: DataSource)(
@ -176,26 +182,26 @@ object Database {
sealed trait RDBMS {
val name: String
val queries: Queries
val queries: Connection => Queries
}
object RDBMS {
object H2 extends RDBMS {
override val name: String = "h2"
override val queries: Queries = new H2Queries
override val queries: Connection => Queries = H2Queries.apply
}
object PostgreSQL extends RDBMS {
override val name: String = "postgresql"
override val queries: Queries = new PostgresqlQueries
override val queries: Connection => Queries = PostgresqlQueries.apply
}
object SQLite extends RDBMS {
override val name: String = "sqlite"
override val queries: Queries = new SqliteQueries
override val queries: Connection => Queries = SqliteQueries.apply
}
}

View File

@ -3,7 +3,6 @@
package com.daml.ledger.on.sql
import java.sql.Connection
import java.time.{Clock, Instant}
import java.util.UUID
@ -11,6 +10,7 @@ import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.ledger.on.sql.SqlLedgerReaderWriter._
import com.daml.ledger.on.sql.queries.Queries
import com.daml.ledger.participant.state.kvutils.api.{LedgerReader, LedgerRecord, LedgerWriter}
import com.daml.ledger.participant.state.v1._
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
@ -46,8 +46,6 @@ final class SqlLedgerReaderWriter(
) extends LedgerWriter
with LedgerReader {
private val queries = database.queries
private val committer = new ValidatingCommitter[Index](
participantId,
now,
@ -65,9 +63,8 @@ final class SqlLedgerReaderWriter(
RangeSource((start, end) => {
Source
.futureSource(database
.inReadTransaction(s"Querying events [$start, $end[ from log") {
implicit connection =>
Future.successful(queries.selectFromLog(start, end))
.inReadTransaction(s"Querying events [$start, $end[ from log") { queries =>
Future.successful(queries.selectFromLog(start, end))
}
.map { result =>
if (result.length < end - start) {
@ -88,13 +85,12 @@ final class SqlLedgerReaderWriter(
object SqlLedgerStateAccess extends LedgerStateAccess[Index] {
override def inTransaction[T](body: LedgerStateOperations[Index] => Future[T]): Future[T] =
database.inWriteTransaction("Committing a submission") { implicit connection =>
body(new SqlLedgerStateOperations)
database.inWriteTransaction("Committing a submission") { queries =>
body(new SqlLedgerStateOperations(queries))
}
}
class SqlLedgerStateOperations(implicit connection: Connection)
extends BatchingLedgerStateOperations[Index] {
class SqlLedgerStateOperations(queries: Queries) extends BatchingLedgerStateOperations[Index] {
override def readState(keys: Seq[Key]): Future[Seq[Option[Value]]] =
Future.successful(queries.selectStateValuesByKeys(keys))
@ -132,10 +128,10 @@ object SqlLedgerReaderWriter {
implicit executionContext: ExecutionContext,
logCtx: LoggingContext,
): Future[LedgerId] =
database.inWriteTransaction("Checking ledger ID at startup") { implicit connection =>
database.inWriteTransaction("Checking ledger ID at startup") { queries =>
val providedLedgerId =
initialLedgerId.getOrElse(Ref.LedgerString.assertFromString(UUID.randomUUID.toString))
val ledgerId = database.queries.updateOrRetrieveLedgerId(providedLedgerId)
val ledgerId = queries.updateOrRetrieveLedgerId(providedLedgerId)
if (initialLedgerId.exists(_ != ledgerId)) {
Future.failed(
new LedgerIdMismatchException(
@ -152,9 +148,8 @@ object SqlLedgerReaderWriter {
logCtx: LoggingContext,
): Future[Dispatcher[Index]] =
database
.inReadTransaction("Reading head at startup") { implicit connection =>
Future.successful(
database.queries.selectLatestLogEntryId().map(_ + 1).getOrElse(StartIndex))
.inReadTransaction("Reading head at startup") { queries =>
Future.successful(queries.selectLatestLogEntryId().map(_ + 1).getOrElse(StartIndex))
}
.map(head => Dispatcher("sql-participant-state", StartIndex, head))
}

View File

@ -19,15 +19,14 @@ import com.google.protobuf.ByteString
import scala.collection.{breakOut, immutable}
trait CommonQueries extends Queries {
override final def selectLatestLogEntryId()(implicit connection: Connection): Option[Index] =
protected implicit val connection: Connection
override final def selectLatestLogEntryId(): Option[Index] =
SQL"SELECT MAX(sequence_no) max_sequence_no FROM #$LogTable"
.as(get[Option[Long]]("max_sequence_no").singleOpt)
.flatten
override final def selectFromLog(
start: Index,
end: Index,
)(implicit connection: Connection): immutable.Seq[(Index, LedgerRecord)] =
override final def selectFromLog(start: Index, end: Index): immutable.Seq[(Index, LedgerRecord)] =
SQL"SELECT sequence_no, entry_id, envelope FROM #$LogTable WHERE sequence_no >= $start AND sequence_no < $end"
.as(
(long("sequence_no") ~ binaryStream("entry_id") ~ byteArray("envelope")).map {
@ -36,9 +35,7 @@ trait CommonQueries extends Queries {
}.*,
)
override final def selectStateValuesByKeys(
keys: Seq[Key],
)(implicit connection: Connection): immutable.Seq[Option[Value]] = {
override final def selectStateValuesByKeys(keys: Seq[Key]): immutable.Seq[Option[Value]] = {
val results = SQL"SELECT key, value FROM #$StateTable WHERE key IN ($keys)"
.fold(Map.newBuilder[ByteString, Array[Byte]], ColumnAliaser.empty)((builder, row) =>
builder += ByteString.readFrom(row[InputStream]("key")) -> row[Value]("value"))
@ -48,9 +45,7 @@ trait CommonQueries extends Queries {
keys.map(key => results.get(ByteString.copyFrom(key)))(breakOut)
}
override final def updateState(
stateUpdates: Seq[(Key, Value)],
)(implicit connection: Connection): Unit =
override final def updateState(stateUpdates: Seq[(Key, Value)]): Unit =
executeBatchSql(updateStateQuery, stateUpdates.map {
case (key, value) => Seq[NamedParameter]("key" -> key, "value" -> value)
})

View File

@ -12,17 +12,17 @@ import com.daml.ledger.on.sql.queries.Queries._
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
final class H2Queries extends Queries with CommonQueries {
override def updateOrRetrieveLedgerId(
providedLedgerId: LedgerId,
)(implicit connection: Connection): LedgerId = {
final class H2Queries(override protected implicit val connection: Connection)
extends Queries
with CommonQueries {
override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId = {
SQL"MERGE INTO #$MetaTable USING DUAL ON table_key = $MetaTableKey WHEN NOT MATCHED THEN INSERT (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId)"
.executeInsert()
SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey"
.as(str("ledger_id").single)
}
override def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index = {
override def insertIntoLog(key: Key, value: Value): Index = {
SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value)"
.executeInsert()
SQL"CALL IDENTITY()"
@ -32,3 +32,10 @@ final class H2Queries extends Queries with CommonQueries {
override protected val updateStateQuery: String =
s"MERGE INTO $StateTable VALUES ({key}, {value})"
}
object H2Queries {
def apply(connection: Connection): Queries = {
implicit val conn: Connection = connection
new H2Queries
}
}

View File

@ -12,17 +12,17 @@ import com.daml.ledger.on.sql.queries.Queries._
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
final class PostgresqlQueries extends Queries with CommonQueries {
override def updateOrRetrieveLedgerId(
providedLedgerId: LedgerId,
)(implicit connection: Connection): LedgerId = {
final class PostgresqlQueries(override protected implicit val connection: Connection)
extends Queries
with CommonQueries {
override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId = {
SQL"INSERT INTO #$MetaTable (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId) ON CONFLICT DO NOTHING"
.executeInsert()
SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey"
.as(str("ledger_id").single)
}
override def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index = {
override def insertIntoLog(key: Key, value: Value): Index = {
SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value) RETURNING sequence_no"
.as(long("sequence_no").single)
}
@ -30,3 +30,10 @@ final class PostgresqlQueries extends Queries with CommonQueries {
override protected val updateStateQuery: String =
s"INSERT INTO $StateTable VALUES ({key}, {value}) ON CONFLICT(key) DO UPDATE SET value = {value}"
}
object PostgresqlQueries {
def apply(connection: Connection): Queries = {
implicit val conn: Connection = connection
new PostgresqlQueries
}
}

View File

@ -6,33 +6,8 @@ package com.daml.ledger.on.sql.queries
import java.sql.Connection
import anorm.{BatchSql, NamedParameter}
import com.daml.ledger.on.sql.Index
import com.daml.ledger.participant.state.kvutils.api.LedgerRecord
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import scala.collection.immutable
trait Queries {
def updateOrRetrieveLedgerId(
providedLedgerId: LedgerId,
)(implicit connection: Connection): LedgerId
def selectLatestLogEntryId()(implicit connection: Connection): Option[Index]
def selectFromLog(
start: Index,
end: Index,
)(implicit connection: Connection): immutable.Seq[(Index, LedgerRecord)]
def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index
def selectStateValuesByKeys(
keys: Seq[Key],
)(implicit connection: Connection): immutable.Seq[Option[Value]]
def updateState(stateUpdates: Seq[(Key, Value)])(implicit connection: Connection): Unit
}
trait Queries extends ReadQueries with WriteQueries
object Queries {
val TablePrefix = "ledger"

View File

@ -0,0 +1,18 @@
// Copyright (c) 2020 The DAML Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.on.sql.queries
import com.daml.ledger.on.sql.Index
import com.daml.ledger.participant.state.kvutils.api.LedgerRecord
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import scala.collection.immutable
trait ReadQueries {
def selectLatestLogEntryId(): Option[Index]
def selectFromLog(start: Index, end: Index): immutable.Seq[(Index, LedgerRecord)]
def selectStateValuesByKeys(keys: Seq[Key]): immutable.Seq[Option[Value]]
}

View File

@ -12,17 +12,17 @@ import com.daml.ledger.on.sql.queries.Queries._
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
final class SqliteQueries extends Queries with CommonQueries {
override def updateOrRetrieveLedgerId(
providedLedgerId: LedgerId,
)(implicit connection: Connection): LedgerId = {
final class SqliteQueries(override protected implicit val connection: Connection)
extends Queries
with CommonQueries {
override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId = {
SQL"INSERT INTO #$MetaTable (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId) ON CONFLICT DO NOTHING"
.executeInsert()
SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey"
.as(str("ledger_id").single)
}
override def insertIntoLog(key: Key, value: Value)(implicit connection: Connection): Index = {
override def insertIntoLog(key: Key, value: Value): Index = {
SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value)"
.executeInsert()
SQL"SELECT LAST_INSERT_ROWID()"
@ -32,3 +32,10 @@ final class SqliteQueries extends Queries with CommonQueries {
override protected val updateStateQuery: String =
s"INSERT INTO $StateTable VALUES ({key}, {value}) ON CONFLICT(key) DO UPDATE SET value = {value}"
}
object SqliteQueries {
def apply(connection: Connection): Queries = {
implicit val conn: Connection = connection
new SqliteQueries
}
}

View File

@ -0,0 +1,16 @@
// Copyright (c) 2020 The DAML Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.on.sql.queries
import com.daml.ledger.on.sql.Index
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
trait WriteQueries {
def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): LedgerId
def insertIntoLog(key: Key, value: Value): Index
def updateState(stateUpdates: Seq[(Key, Value)]): Unit
}