mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-18 13:41:49 +03:00
3f6157c4c4
- Switch to slog logger.
304 lines
8.0 KiB
Go
304 lines
8.0 KiB
Go
package postgres_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/neilotoole/sq/libsq/core/lg"
|
|
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
|
|
"github.com/neilotoole/sq/libsq/source"
|
|
|
|
_ "github.com/jackc/pgx/v4/stdlib"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/neilotoole/sq/drivers/postgres"
|
|
"github.com/neilotoole/sq/libsq/core/sqlmodel"
|
|
"github.com/neilotoole/sq/libsq/core/stringz"
|
|
"github.com/neilotoole/sq/testh"
|
|
"github.com/neilotoole/sq/testh/fixt"
|
|
"github.com/neilotoole/sq/testh/sakila"
|
|
)
|
|
|
|
func TestSmoke(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, handle := range sakila.PgAll() {
|
|
handle := handle
|
|
t.Run(handle, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
th, src, _, _ := testh.NewWith(t, handle)
|
|
sink, err := th.QuerySQL(src, "SELECT * FROM actor")
|
|
require.NoError(t, err)
|
|
require.Equal(t, len(sakila.TblActorCols()), len(sink.RecMeta))
|
|
require.Equal(t, sakila.TblActorCount, len(sink.Recs))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDriverBehavior(t *testing.T) {
|
|
// This test was created to help understand the behavior of the driver impl.
|
|
// It can be deleted eventually.
|
|
testCases := sakila.PgAll()
|
|
|
|
for _, handle := range testCases {
|
|
handle := handle
|
|
|
|
t.Run(handle, func(t *testing.T) {
|
|
th := testh.New(t)
|
|
src := th.Source(handle)
|
|
db := th.Open(src).DB()
|
|
|
|
query := `SELECT
|
|
(SELECT actor_id FROM actor limit 1) AS actor_id,
|
|
(SELECT first_name FROM actor LIMIT 1) AS first_name,
|
|
(SELECT last_name FROM actor LIMIT 1) AS last_name
|
|
LIMIT 1`
|
|
|
|
rows, err := db.QueryContext(th.Context, query)
|
|
require.NoError(t, err)
|
|
require.NoError(t, rows.Err())
|
|
t.Cleanup(func() { assert.NoError(t, rows.Close()) })
|
|
|
|
colTypes, err := rows.ColumnTypes()
|
|
require.NoError(t, err)
|
|
|
|
for i, colType := range colTypes {
|
|
nullable, ok := colType.Nullable()
|
|
t.Logf("%d: %s %s %s nullable,ok={%v,%v}", i, colType.Name(), colType.DatabaseTypeName(),
|
|
colType.ScanType().Name(), nullable, ok)
|
|
|
|
if !nullable {
|
|
scanType := colType.ScanType()
|
|
z := reflect.Zero(scanType)
|
|
t.Logf("zero: %T %v", z, z)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_VerifyDriverDoesNotReportNullability(t *testing.T) {
|
|
t.Parallel()
|
|
// This test demonstrates that the backing pgx driver
|
|
// does not report column nullability (as one might hope).
|
|
//
|
|
// When/if the driver is modified to behave as hoped (if
|
|
// at all possible) then we can simplify the
|
|
// postgres driver wrapper.
|
|
testCases := sakila.PgAll()
|
|
for _, handle := range testCases {
|
|
handle := handle
|
|
|
|
t.Run(handle, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
th := testh.New(t)
|
|
src := th.Source(handle)
|
|
db := th.Open(src).DB()
|
|
|
|
_, actualTblName := createTypeTestTable(th, src, true)
|
|
t.Cleanup(func() { th.DropTable(src, actualTblName) })
|
|
|
|
rows, err := db.Query("SELECT * FROM " + actualTblName)
|
|
require.NoError(t, err)
|
|
require.NoError(t, rows.Err())
|
|
t.Cleanup(func() { assert.NoError(t, rows.Close()) })
|
|
|
|
colTypes, err := rows.ColumnTypes()
|
|
require.NoError(t, err)
|
|
|
|
for _, colType := range colTypes {
|
|
colName := colType.Name()
|
|
|
|
// The _n suffix indicates a nullable col
|
|
if !strings.HasSuffix(colName, "_n") {
|
|
continue
|
|
}
|
|
|
|
// The col is indicated as nullable via its name/suffix
|
|
nullable, hasNullable := colType.Nullable()
|
|
require.False(t, hasNullable, "ColumnType.hasNullable is unfortunately expected to be false for {%s}",
|
|
colName)
|
|
require.False(t, nullable, "ColumnType.nullable is unfortunately expected to be false for {%s}", colName)
|
|
}
|
|
|
|
for rows.Next() {
|
|
require.NoError(t, rows.Err())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetTableColumnNames(t *testing.T) {
|
|
testCases := sakila.PgAll()
|
|
|
|
for _, handle := range testCases {
|
|
handle := handle
|
|
t.Run(handle, func(t *testing.T) {
|
|
th := testh.New(t)
|
|
src := th.Source(handle)
|
|
|
|
colNames, err := postgres.GetTableColumnNames(th.Context, th.Open(src).DB(), sakila.TblActor)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sakila.TblActorCols(), colNames)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDriver_CreateTable_NotNullDefault(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := sakila.PgAll()
|
|
for _, handle := range testCases {
|
|
handle := handle
|
|
|
|
t.Run(handle, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
th, src, dbase, drvr := testh.NewWith(t, handle)
|
|
|
|
tblName := stringz.UniqTableName(t.Name())
|
|
colNames, colKinds := fixt.ColNamePerKind(drvr.Dialect().IntBool, false, false)
|
|
|
|
tblDef := sqlmodel.NewTableDef(tblName, colNames, colKinds)
|
|
for _, colDef := range tblDef.Cols {
|
|
colDef.NotNull = true
|
|
colDef.HasDefault = true
|
|
}
|
|
|
|
err := drvr.CreateTable(th.Context, dbase.DB(), tblDef)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { th.DropTable(src, tblName) })
|
|
|
|
th.InsertDefaultRow(src, tblName)
|
|
|
|
sink, err := th.QuerySQL(src, "SELECT * FROM "+tblName)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, len(sink.Recs))
|
|
require.Equal(t, len(colNames), len(sink.RecMeta))
|
|
for i := range colNames {
|
|
require.NotNil(t, sink.Recs[0][i])
|
|
_, ok := sink.RecMeta[i].Nullable()
|
|
require.False(t, ok, "postgres driver doesn't report nullability")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// getAdminSource returns a new src that users the "postgres" user
|
|
// instead of "sakila" user. This is useful if we need to create
|
|
// databases, etc. A better solution is probably to change the
|
|
// sakila DBs to grant all privileges to "sakila".
|
|
func getAdminSakilaSource(src *source.Source) *source.Source {
|
|
s := *src
|
|
s.Location = strings.Replace(s.Location, "postgres://sakila:", "postgres://postgres:", 1)
|
|
return &s
|
|
}
|
|
|
|
// TestAlternateSchema verifies that we can access a schema
|
|
// other than the default ("public").
|
|
func TestAlternateSchema(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
th := testh.New(t)
|
|
ctx := th.Context
|
|
|
|
src := getAdminSakilaSource(th.Source(sakila.Pg))
|
|
t.Logf("Using src: {%s}", src)
|
|
|
|
dbase := th.Open(src)
|
|
db := dbase.DB()
|
|
require.NoError(t, db.Ping())
|
|
|
|
schemaName := stringz.UniqSuffix("test_schema")
|
|
err := createSchema(ctx, db, schemaName)
|
|
t.Logf("Created schema {%s} in {%s}", schemaName, src)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
assert.NoError(t, dropSchema(ctx, db, schemaName))
|
|
})
|
|
|
|
tblName := stringz.UniqSuffix("test_table")
|
|
const wantRowCount = 5
|
|
require.NoError(t, createSimpleTable(ctx, db, schemaName, tblName, wantRowCount))
|
|
|
|
// We create a new src to point to the new schema.
|
|
// We change the schema by setting the search_path param.
|
|
src2 := src.Clone()
|
|
src2.Handle += "2"
|
|
src2.Location += "?search_path=" + schemaName
|
|
dbase2 := th.Open(src2)
|
|
md2, err := dbase2.SourceMetadata(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, schemaName, md2.Schema)
|
|
|
|
tblMeta2, err := dbase2.TableMetadata(ctx, tblName)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(wantRowCount), tblMeta2.RowCount)
|
|
}
|
|
|
|
// createSchema drops the named schema if it already exists (CASCADE),
// and then creates it afresh. The name is embedded in the statement
// using the %q verb. Any error from executing the statement is returned
// as-is.
func createSchema(ctx context.Context, db *sql.DB, name string) error {
	// The [1] argument index reuses name for both verbs.
	const dropAndCreate = `DROP SCHEMA IF EXISTS %[1]q CASCADE;
CREATE SCHEMA %[1]q;`

	query := fmt.Sprintf(dropAndCreate, name)
	if _, err := db.ExecContext(ctx, query); err != nil {
		return err
	}

	return nil
}
|
|
|
|
// dropSchema drops the named schema (CASCADE) if it exists. The name is
// embedded in the statement using the %q verb. Any error from executing
// the statement is returned as-is.
func dropSchema(ctx context.Context, db *sql.DB, name string) error {
	query := fmt.Sprintf(`DROP SCHEMA IF EXISTS %q CASCADE;`, name)

	if _, err := db.ExecContext(ctx, query); err != nil {
		return err
	}

	return nil
}
|
|
|
|
func createSimpleTable(ctx context.Context, db *sql.DB, schemaName, tblName string, insertRowCount int) error {
|
|
const tpl = `CREATE TABLE %q.%q
|
|
(
|
|
id serial PRIMARY KEY,
|
|
NAME VARCHAR(255)
|
|
);`
|
|
|
|
stmt := fmt.Sprintf(tpl, schemaName, tblName)
|
|
|
|
_, err := db.ExecContext(ctx, stmt)
|
|
if err != nil {
|
|
return errz.Err(err)
|
|
}
|
|
|
|
stmt = fmt.Sprintf("INSERT INTO %q.%q (NAME) VALUES ($1)", schemaName, tblName)
|
|
|
|
for i := 0; i < insertRowCount; i++ {
|
|
_, err = db.ExecContext(ctx, stmt, fmt.Sprintf("name-%d", i))
|
|
if err != nil {
|
|
return errz.Err(err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func BenchmarkDatabase_SourceMetadata(b *testing.B) {
|
|
for _, handle := range sakila.PgAll() {
|
|
handle := handle
|
|
b.Run(handle, func(b *testing.B) {
|
|
th := testh.New(b)
|
|
th.Log = lg.Discard()
|
|
dbase := th.Open(th.Source(handle))
|
|
b.ResetTimer()
|
|
|
|
md, err := dbase.SourceMetadata(th.Context)
|
|
require.NoError(b, err)
|
|
require.Equal(b, "sakila", md.Name)
|
|
})
|
|
}
|
|
}
|