diff --git a/bin/server/cmd/root.go b/bin/server/cmd/root.go index f86306b4..5a96b5b6 100644 --- a/bin/server/cmd/root.go +++ b/bin/server/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" @@ -26,8 +27,10 @@ type Main struct { } func (m *Main) Run() error { + ctx := context.Background() + db := DB.NewDB(m.profile) - if err := db.Open(); err != nil { + if err := db.Open(ctx); err != nil { return fmt.Errorf("cannot open db: %w", err) } diff --git a/store/db/db.go b/store/db/db.go index e068fab4..5f7ce059 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "embed" "errors" @@ -38,7 +39,7 @@ func NewDB(profile *profile.Profile) *DB { return db } -func (db *DB) Open() (err error) { +func (db *DB) Open(ctx context.Context) (err error) { // Ensure a DSN is set before attempting to open the database. if db.profile.DSN == "" { return fmt.Errorf("dsn required") @@ -53,32 +54,32 @@ func (db *DB) Open() (err error) { db.Db = sqlDB // If mode is dev, we should migrate and seed the database. if db.profile.Mode == "dev" { - if err := db.applyLatestSchema(); err != nil { + if err := db.applyLatestSchema(ctx); err != nil { return fmt.Errorf("failed to apply latest schema: %w", err) } - if err := db.seed(); err != nil { + if err := db.seed(ctx); err != nil { return fmt.Errorf("failed to seed: %w", err) } } else { // If db file not exists, we should migrate the database. if _, err := os.Stat(db.profile.DSN); errors.Is(err, os.ErrNotExist) { - err := db.applyLatestSchema() + err := db.applyLatestSchema(ctx) if err != nil { return fmt.Errorf("failed to apply latest schema: %w", err) } } else { - err := db.createMigrationHistoryTable() + err := db.createMigrationHistoryTable(ctx) if err != nil { return fmt.Errorf("failed to create migration_history table: %w", err) } currentVersion := common.GetCurrentVersion(db.profile.Mode) - migrationHistory, err := findMigrationHistory(db.Db, &MigrationHistoryFind{}) + migrationHistory, err := db.FindMigrationHistory(ctx, &MigrationHistoryFind{}) if err != nil { return err } if migrationHistory == nil { - migrationHistory, err = upsertMigrationHistory(db.Db, &MigrationHistoryUpsert{ + migrationHistory, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ Version: currentVersion, }) if err != nil { @@ -105,7 +106,7 @@ func (db *DB) Open() (err error) { normalizedVersion := minorVersion + ".0" if common.IsVersionGreaterThan(normalizedVersion, migrationHistory.Version) && common.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { println("applying migration for", normalizedVersion) - if err := db.applyMigrationForMinorVersion(minorVersion); err != nil { + if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { return fmt.Errorf("failed to apply minor version migration: %w", err) } } @@ -127,20 +128,20 @@ const ( latestSchemaFileName = "LATEST__SCHEMA.sql" ) -func (db *DB) applyLatestSchema() error { +func (db *DB) applyLatestSchema(ctx context.Context) error { latestSchemaPath := fmt.Sprintf("%s/%s/%s", "migration", db.profile.Mode, latestSchemaFileName) buf, err := migrationFS.ReadFile(latestSchemaPath) if err != nil { return fmt.Errorf("failed to read latest schema %q, error %w", latestSchemaPath, err) } stmt := string(buf) - if err := db.execute(stmt); err != nil { + if err := db.execute(ctx, stmt); err != nil { return fmt.Errorf("migrate error: statement:%s err=%w", stmt, err) } return nil } -func (db *DB) applyMigrationForMinorVersion(minorVersion string) error { +func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion)) if err != nil { return err @@ -157,22 +158,32 @@ func (db *DB) applyMigrationForMinorVersion(minorVersion string) error { } stmt := string(buf) migrationStmt += stmt - if err := db.execute(stmt); err != nil { + if err := db.execute(ctx, stmt); err != nil { return fmt.Errorf("migrate error: statement:%s err=%w", stmt, err) } } + tx, err := db.Db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + // upsert the newest version to migration_history - if _, err = upsertMigrationHistory(db.Db, &MigrationHistoryUpsert{ + if _, err = upsertMigrationHistory(ctx, tx, &MigrationHistoryUpsert{ Version: minorVersion + ".0", }); err != nil { return err } + if err := tx.Commit(); err != nil { + return err + } + return nil } -func (db *DB) seed() error { +func (db *DB) seed(ctx context.Context) error { filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) if err != nil { return err @@ -187,7 +198,7 @@ func (db *DB) seed() error { return fmt.Errorf("failed to read seed file, filename=%s err=%w", filename, err) } stmt := string(buf) - if err := db.execute(stmt); err != nil { + if err := db.execute(ctx, stmt); err != nil { return fmt.Errorf("seed error: statement:%s err=%w", stmt, err) } } @@ -195,18 +206,22 @@ func (db *DB) seed() error { } // excecute runs a single SQL statement within a transaction. -func (db *DB) execute(stmt string) error { +func (db *DB) execute(ctx context.Context, stmt string) error { tx, err := db.Db.Begin() if err != nil { return err } defer tx.Rollback() - if _, err := tx.Exec(stmt); err != nil { + if _, err := tx.ExecContext(ctx, stmt); err != nil { return err } - return tx.Commit() + if err := tx.Commit(); err != nil { + return err + } + + return nil } // minorDirRegexp is a regular expression for minor version directory. @@ -234,8 +249,14 @@ func getMinorVersionList() []string { } // createMigrationHistoryTable creates the migration_history table if it doesn't exist. -func (db *DB) createMigrationHistoryTable() error { - if err := createTable(db.Db, ` +func (db *DB) createMigrationHistoryTable(ctx context.Context) error { + tx, err := db.Db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if err := createTable(ctx, tx, ` CREATE TABLE IF NOT EXISTS migration_history ( version TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')) @@ -244,5 +265,9 @@ func (db *DB) createMigrationHistoryTable() error { return err } + if err := tx.Commit(); err != nil { + return err + } + return nil } diff --git a/store/db/migration_history.go b/store/db/migration_history.go index 8e74c0c5..da6eac8c 100644 --- a/store/db/migration_history.go +++ b/store/db/migration_history.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "strings" ) @@ -18,23 +19,61 @@ type MigrationHistoryFind struct { Version *string } -func findMigrationHistoryList(db *sql.DB, find *MigrationHistoryFind) ([]*MigrationHistory, error) { +func (db *DB) FindMigrationHistory(ctx context.Context, find *MigrationHistoryFind) (*MigrationHistory, error) { + tx, err := db.Db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := findMigrationHistoryList(ctx, tx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } else { + return list[0], nil + } +} + +func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { + tx, err := db.Db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + migrationHistory, err := upsertMigrationHistory(ctx, tx, upsert) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return migrationHistory, nil +} + +func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHistoryFind) ([]*MigrationHistory, error) { where, args := []string{"1 = 1"}, []interface{}{} if v := find.Version; v != nil { where, args = append(where, "version = ?"), append(args, *v) } - rows, err := db.Query(` + query := ` SELECT version, created_ts FROM migration_history - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC`, - args..., - ) + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC + ` + rows, err := tx.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -56,21 +95,8 @@ func findMigrationHistoryList(db *sql.DB, find *MigrationHistoryFind) ([]*Migrat return migrationHistoryList, nil } -func findMigrationHistory(db *sql.DB, find *MigrationHistoryFind) (*MigrationHistory, error) { - list, err := findMigrationHistoryList(db, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } else { - return list[0], nil - } -} - -func upsertMigrationHistory(db *sql.DB, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - row, err := db.Query(` +func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { + query := ` INSERT INTO migration_history ( version ) @@ -79,9 +105,8 @@ func upsertMigrationHistory(db *sql.DB, upsert *MigrationHistoryUpsert) (*Migrat SET version=EXCLUDED.version RETURNING version, created_ts - `, - upsert.Version, - ) + ` + row, err := tx.QueryContext(ctx, query, upsert.Version) if err != nil { return nil, err } diff --git a/store/db/table.go b/store/db/table.go index bd069ea2..11460f57 100644 --- a/store/db/table.go +++ b/store/db/table.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "strings" ) @@ -11,20 +12,19 @@ type Table struct { } //lint:ignore U1000 Ignore unused function temporarily for debugging -func findTable(db *sql.DB, tableName string) (*Table, error) { +func findTable(ctx context.Context, tx *sql.Tx, tableName string) (*Table, error) { where, args := []string{"1 = 1"}, []interface{}{} where, args = append(where, "type = ?"), append(args, "table") where, args = append(where, "name = ?"), append(args, tableName) - rows, err := db.Query(` + query := ` SELECT tbl_name, sql FROM sqlite_schema - WHERE `+strings.Join(where, " AND "), - args..., - ) + WHERE ` + strings.Join(where, " AND ") + rows, err := tx.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -54,13 +54,11 @@ func findTable(db *sql.DB, tableName string) (*Table, error) { } } -func createTable(db *sql.DB, sql string) error { - result, err := db.Exec(sql) +func createTable(ctx context.Context, tx *sql.Tx, stmt string) error { + _, err := tx.ExecContext(ctx, stmt) if err != nil { return err } - _, err = result.RowsAffected() - - return err + return nil }