fixed broken mysql tests (parseTime param); moved some test funcs to pkg tutil (#109)

This commit is contained in:
Neil O'Toole 2022-12-16 19:09:49 -07:00 committed by GitHub
parent e674cdc724
commit 6a0878bc6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 548 additions and 177 deletions

View File

@ -1,6 +1,7 @@
package cli_test
import (
"github.com/neilotoole/sq/testh/tutil"
"testing"
"github.com/neilotoole/sq/drivers/csv"
@ -58,7 +59,7 @@ func TestCmdAdd(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(testh.Name(tc.wantHandle, tc.loc, tc.driver), func(t *testing.T) {
t.Run(tutil.Name(tc.wantHandle, tc.loc, tc.driver), func(t *testing.T) {
args := []string{"add", tc.loc}
if tc.handle != "" {
args = append(args, "--handle="+tc.handle)

View File

@ -2,6 +2,7 @@ package cli_test
import (
"encoding/json"
"github.com/neilotoole/sq/testh/tutil"
"os"
"testing"
@ -125,7 +126,7 @@ func TestCmdInspect_Stdin(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(testh.Name(tc.fpath), func(t *testing.T) {
t.Run(tutil.Name(tc.fpath), func(t *testing.T) {
f, err := os.Open(tc.fpath) // No need to close f
require.NoError(t, err)

View File

@ -2,6 +2,7 @@ package cli_test
import (
"fmt"
"github.com/neilotoole/sq/testh/tutil"
"os"
"strings"
"testing"
@ -22,7 +23,7 @@ func TestCmdSQL_Insert(t *testing.T) {
origin := origin
t.Run("origin_"+origin, func(t *testing.T) {
testh.SkipShort(t, origin == sakila.XLSX)
tutil.SkipShort(t, origin == sakila.XLSX)
for _, dest := range sakila.SQLLatest() {
dest := dest
@ -126,7 +127,7 @@ func TestCmdSQL_StdinQuery(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(testh.Name(tc.fpath), func(t *testing.T) {
t.Run(tutil.Name(tc.fpath), func(t *testing.T) {
t.Parallel()
f, err := os.Open(tc.fpath)

View File

@ -1,6 +1,7 @@
package config_test
import (
"github.com/neilotoole/sq/testh/tutil"
"io/ioutil"
"path/filepath"
"testing"
@ -9,7 +10,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/neilotoole/sq/cli/config"
"github.com/neilotoole/sq/testh"
"github.com/neilotoole/sq/testh/proj"
)
@ -72,7 +72,7 @@ func TestFileStore_Load(t *testing.T) {
for _, match := range good {
match := match
t.Run(testh.Name(match), func(t *testing.T) {
t.Run(tutil.Name(match), func(t *testing.T) {
fs.Path = match
_, err = fs.Load()
require.NoError(t, err, match)
@ -82,7 +82,7 @@ func TestFileStore_Load(t *testing.T) {
for _, match := range bad {
match := match
t.Run(testh.Name(match), func(t *testing.T) {
t.Run(tutil.Name(match), func(t *testing.T) {
fs.Path = match
_, err = fs.Load()
require.Error(t, err, match)

View File

@ -3,6 +3,7 @@ package output_test
import (
"context"
"fmt"
"github.com/neilotoole/sq/testh/tutil"
"testing"
"time"
@ -122,7 +123,7 @@ func TestRecordWriterAdapter_FlushAfterDuration(t *testing.T) {
testCases := []struct {
flushAfter time.Duration
wantFlushed int
assertFn testh.AssertCompareFunc
assertFn tutil.AssertCompareFunc
}{
{flushAfter: -1, wantFlushed: 0, assertFn: require.Equal},
{flushAfter: 0, wantFlushed: 0, assertFn: require.Equal},

View File

@ -296,7 +296,12 @@ func makeMapStringInterface(n int) map[string]any {
}
func testName(v any) string {
return fmt.Sprintf("%T", v)
name := fmt.Sprintf("%T", v)
if len(name) > 80 {
name = name[:80]
}
return name
}
type codeResponse2 struct {

View File

@ -3,6 +3,7 @@ package json_test
import (
"bytes"
stdj "encoding/json"
"github.com/neilotoole/sq/testh/tutil"
"io"
"io/ioutil"
"os"
@ -73,7 +74,7 @@ func TestImportJSONL_Flat(t *testing.T) {
for i, tc := range testCases {
tc := tc
t.Run(testh.Name(i, tc.fpath, tc.input), func(t *testing.T) {
t.Run(tutil.Name(i, tc.fpath, tc.input), func(t *testing.T) {
openFn := func() (io.ReadCloser, error) {
return ioutil.NopCloser(strings.NewReader(tc.input)), nil
}
@ -170,7 +171,7 @@ func TestScanObjectsInArray(t *testing.T) {
for i, tc := range testCases {
tc := tc
t.Run(testh.Name(i, tc.in), func(t *testing.T) {
t.Run(tutil.Name(i, tc.in), func(t *testing.T) {
r := bytes.NewReader([]byte(tc.in))
gotObjs, gotChunks, err := json.ScanObjectsInArray(r)
if tc.wantErr {
@ -202,7 +203,7 @@ func TestScanObjectsInArray_Files(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(testh.Name(tc.fname), func(t *testing.T) {
t.Run(tutil.Name(tc.fname), func(t *testing.T) {
f, err := os.Open(tc.fname)
require.NoError(t, err)
defer f.Close()
@ -241,7 +242,7 @@ func TestColumnOrderFlat(t *testing.T) {
for i, tc := range testCases {
tc := tc
t.Run(testh.Name(i, tc.in), func(t *testing.T) {
t.Run(tutil.Name(i, tc.in), func(t *testing.T) {
require.True(t, stdj.Valid([]byte(tc.in)))
gotCols, err := json.ColumnOrderFlat([]byte(tc.in))

View File

@ -3,6 +3,7 @@ package json_test
import (
"context"
"fmt"
"github.com/neilotoole/sq/testh/tutil"
"io"
"os"
"path/filepath"
@ -13,7 +14,6 @@ import (
"github.com/neilotoole/sq/drivers/json"
"github.com/neilotoole/sq/libsq/source"
"github.com/neilotoole/sq/testh"
)
func TestTypeDetectorFuncs(t *testing.T) {
@ -87,7 +87,7 @@ func TestTypeDetectorFuncs(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(testh.Name(tc.fn, tc.fname), func(t *testing.T) {
t.Run(tutil.Name(tc.fn, tc.fname), func(t *testing.T) {
openFn := func() (io.ReadCloser, error) { return os.Open(filepath.Join("testdata", tc.fname)) }
detectFn := detectFns[tc.fn]

View File

@ -2,6 +2,7 @@ package mysql_test
import (
"fmt"
"github.com/neilotoole/sq/testh/tutil"
"io/ioutil"
"strings"
"testing"
@ -431,7 +432,7 @@ func TestDatabaseTypeJSON(t *testing.T) {
require.Equal(t, len(testVals), len(sink.Recs))
for i := range testVals {
for j := range testVals[i] {
require.Equal(t, testVals[i][j], testh.Val(sink.Recs[i][j]))
require.Equal(t, testVals[i][j], tutil.Val(sink.Recs[i][j]))
}
}
})

View File

@ -1,6 +1,7 @@
package mysql
import (
"github.com/neilotoole/sq/testh/tutil"
"testing"
"github.com/go-sql-driver/mysql"
@ -47,38 +48,60 @@ func TestHasErrCode(t *testing.T) {
func TestDSNFromLocation(t *testing.T) {
testCases := []struct {
loc string
wantDSN string
wantErr bool
loc string
parseTime bool
wantDSN string
wantErr bool
}{
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest",
wantErr: false,
},
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=1",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=1",
wantErr: false,
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?parseTime=true",
parseTime: true,
},
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowCleartextPasswords=true&allowOldPasswords=1",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowCleartextPasswords=true&allowOldPasswords=1",
wantErr: false,
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?parseTime=true",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?parseTime=true",
parseTime: true,
},
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=true",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=true",
},
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=true",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=true&parseTime=true",
parseTime: true,
},
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=true",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=true",
},
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowCleartextPasswords=true&allowOldPasswords=true",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowCleartextPasswords=true&allowOldPasswords=true",
},
{
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?parseTime=true&allowCleartextPasswords=true&allowOldPasswords=true",
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowCleartextPasswords=true&allowOldPasswords=true&parseTime=true",
parseTime: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.loc, func(t *testing.T) {
t.Run(tutil.Name(tc.loc, tc.parseTime), func(t *testing.T) {
src := &source.Source{
Handle: "@testhandle",
Type: Type,
Location: tc.loc,
}
gotDSN, gotErr := dsnFromLocation(src)
gotDSN, gotErr := dsnFromLocation(src, tc.parseTime)
if tc.wantErr {
require.Error(t, gotErr)
return

View File

@ -39,7 +39,9 @@ func kindFromDBTypeName(log lg.Log, colName, dbTypeName string) kind.Kind {
knd = kind.Unknown
case "":
knd = kind.Unknown
case "INTEGER", "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT", "YEAR", "BIT":
case "INTEGER", "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT", "YEAR", "BIT",
"UNSIGNED INTEGER", "UNSIGNED INT", "UNSIGNED TINYINT",
"UNSIGNED SMALLINT", "UNSIGNED MEDIUMINT", "UNSIGNED BIGINT":
knd = kind.Int
case "DECIMAL", "NUMERIC":
knd = kind.Decimal

View File

@ -259,7 +259,7 @@ func (d *driveri) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName st
// Open implements driver.Driver.
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Database, error) {
dsn, err := dsnFromLocation(src)
dsn, err := dsnFromLocation(src, true)
if err != nil {
return nil, err
}
@ -296,7 +296,7 @@ func (d *driveri) Ping(ctx context.Context, src *source.Source) error {
// the TRUNCATE statement.
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64, err error) {
// https://dev.mysql.com/doc/refman/8.0/en/truncate-table.html
dsn, err := dsnFromLocation(src)
dsn, err := dsnFromLocation(src, true)
if err != nil {
return 0, err
}
@ -390,8 +390,10 @@ func hasErrCode(err error, code uint16) bool {
const errNumTableNotExist = uint16(1146)
// dsnFromLocation extracts the mysql driver DSN from src.Location.
func dsnFromLocation(src *source.Source) (string, error) {
// dsnFromLocation builds the mysql driver DSN from src.Location.
// If parseTime is true, the param "parseTime=true" is added. This
// is because of: https://stackoverflow.com/questions/29341590/how-to-parse-time-from-database/29343013#29343013
func dsnFromLocation(src *source.Source, parseTime bool) (string, error) {
if !strings.HasPrefix(src.Location, "mysql://") || len(src.Location) < 10 {
return "", errz.Errorf("invalid source location %s", src.RedactedLocation())
}
@ -406,10 +408,13 @@ func dsnFromLocation(src *source.Source) (string, error) {
// Driver DSN: sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=1
driverDSN := u.DSN
_, err = mysql.ParseDSN(driverDSN) // verify
myCfg, err := mysql.ParseDSN(driverDSN) // verify
if err != nil {
return "", errz.Wrapf(err, "invalid source location: %q", driverDSN)
}
myCfg.ParseTime = parseTime
driverDSN = myCfg.FormatDSN()
return driverDSN, nil
}

View File

@ -2,6 +2,7 @@ package sqlite3_test
import (
"fmt"
"github.com/neilotoole/sq/testh/tutil"
"io/ioutil"
"strings"
"testing"
@ -188,7 +189,7 @@ func TestDatabaseTypes(t *testing.T) {
continue
}
require.Equal(t, wantVal, testh.Val(gotVal),
require.Equal(t, wantVal, tutil.Val(gotVal),
"%s[%d][%d] (%s) expected %T(%v) but got %T(%v)",
actualTblName, i, j, typeTestColNames[j], wantVal, wantVal, gotVal, gotVal)
}

View File

@ -2,6 +2,7 @@ package xmlud_test
import (
"bytes"
"github.com/neilotoole/sq/testh/tutil"
"testing"
"github.com/stretchr/testify/assert"
@ -49,19 +50,19 @@ func TestImport_Ppl(t *testing.T) {
sink, err := th.QuerySQL(scratchDB.Source(), "SELECT * FROM person")
require.NoError(t, err)
require.Equal(t, 3, len(sink.Recs))
require.Equal(t, "Nikola", testh.Val(sink.Recs[0][1]))
require.Equal(t, "Nikola", tutil.Val(sink.Recs[0][1]))
for i, rec := range sink.Recs {
// Verify that the primary id cols are sequential
require.Equal(t, int64(i+1), testh.Val(rec[0]))
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
}
sink, err = th.QuerySQL(scratchDB.Source(), "SELECT * FROM skill")
require.NoError(t, err)
require.Equal(t, 6, len(sink.Recs))
require.Equal(t, "Electrifying", testh.Val(sink.Recs[0][2]))
require.Equal(t, "Electrifying", tutil.Val(sink.Recs[0][2]))
for i, rec := range sink.Recs {
// Verify that the primary id cols are sequential
require.Equal(t, int64(i+1), testh.Val(rec[0]))
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
}
}
@ -95,27 +96,27 @@ func TestImport_RSS(t *testing.T) {
sink, err := th.QuerySQL(scratchDB.Source(), "SELECT * FROM channel")
require.NoError(t, err)
require.Equal(t, 1, len(sink.Recs))
require.Equal(t, "NYT > World", testh.Val(sink.Recs[0][1]))
require.Equal(t, "NYT > World", tutil.Val(sink.Recs[0][1]))
for i, rec := range sink.Recs {
// Verify that the primary id cols are sequential
require.Equal(t, int64(i+1), testh.Val(rec[0]))
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
}
sink, err = th.QuerySQL(scratchDB.Source(), "SELECT * FROM category")
require.NoError(t, err)
require.Equal(t, 251, len(sink.Recs))
require.EqualValues(t, "Extradition", testh.Val(sink.Recs[0][2]))
require.EqualValues(t, "Extradition", tutil.Val(sink.Recs[0][2]))
for i, rec := range sink.Recs {
// Verify that the primary id cols are sequential
require.Equal(t, int64(i+1), testh.Val(rec[0]))
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
}
sink, err = th.QuerySQL(scratchDB.Source(), "SELECT * FROM item")
require.NoError(t, err)
require.Equal(t, 45, len(sink.Recs))
require.EqualValues(t, "Trilobites: Fishing for Clues to Solve Namibias Fairy Circle Mystery", testh.Val(sink.Recs[17][4]))
require.EqualValues(t, "Trilobites: Fishing for Clues to Solve Namibias Fairy Circle Mystery", tutil.Val(sink.Recs[17][4]))
for i, rec := range sink.Recs {
// Verify that the primary id cols are sequential
require.Equal(t, int64(i+1), testh.Val(rec[0]))
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
}
}

View File

@ -1,6 +1,7 @@
package xlsx_test
import (
"github.com/neilotoole/sq/testh/tutil"
"testing"
"github.com/stretchr/testify/require"
@ -24,7 +25,7 @@ func Test_Smoke_Subset(t *testing.T) {
}
func Test_Smoke_Full(t *testing.T) {
testh.SkipShort(t, true)
tutil.SkipShort(t, true)
th := testh.New(t)
src := th.Source(sakila.XLSX)

2
go.mod
View File

@ -14,7 +14,7 @@ require (
github.com/djherbis/fscache v0.10.1
github.com/emirpasic/gods v1.9.0
github.com/fatih/color v1.13.0
github.com/go-sql-driver/mysql v1.6.0
github.com/go-sql-driver/mysql v1.7.0
github.com/google/uuid v1.3.0
github.com/h2non/filetype v1.1.0
github.com/jackc/pgconn v1.5.0

2
go.sum
View File

@ -352,6 +352,8 @@ github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/godbus/dbus v0.0.0-20151105175453-c7fdd8b5cd55/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw=
github.com/godbus/dbus v0.0.0-20180201030542-885f9cc04c9c/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw=

View File

@ -389,3 +389,13 @@ func LineCount(r io.Reader, skipEmpty bool) int {
return i
}
// TrimLen returns s but with a maximum length of maxLen.
func TrimLen(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@ -1,6 +1,7 @@
package stringz_test
import (
"github.com/neilotoole/sq/testh/tutil"
"strings"
"testing"
@ -8,7 +9,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/testh"
)
func TestGenerateAlphaColName(t *testing.T) {
@ -74,6 +74,30 @@ func TestPluralize(t *testing.T) {
}
}
func TestTrimLen(t *testing.T) {
testCases := []struct {
s string
i int
want string
}{
{s: "", i: 0, want: ""},
{s: "", i: 1, want: ""},
{s: "abc", i: 0, want: ""},
{s: "abc", i: 1, want: "a"},
{s: "abc", i: 2, want: "ab"},
{s: "abc", i: 3, want: "abc"},
{s: "abc", i: 4, want: "abc"},
{s: "abc", i: 5, want: "abc"},
}
for _, tc := range testCases {
got := stringz.TrimLen(tc.s, tc.i)
require.Equal(t, tc.want, got)
}
}
func TestRepeatJoin(t *testing.T) {
testCases := []struct {
s string
@ -272,7 +296,7 @@ func TestLineCount(t *testing.T) {
for i, tc := range testCases {
tc := tc
t.Run(testh.Name(i, tc.in), func(t *testing.T) {
t.Run(tutil.Name(i, tc.in), func(t *testing.T) {
count := stringz.LineCount(strings.NewReader(tc.in), false)
require.Equal(t, tc.withEmpty, count)
count = stringz.LineCount(strings.NewReader(tc.in), true)

View File

@ -1,6 +1,7 @@
package driver_test
import (
"github.com/neilotoole/sq/testh/tutil"
"testing"
"github.com/stretchr/testify/assert"
@ -145,7 +146,7 @@ func TestDriver_TableColumnTypes(t *testing.T) {
handle := handle
t.Run(handle, func(t *testing.T) {
testh.SkipShort(t, handle == sakila.XLSX)
tutil.SkipShort(t, handle == sakila.XLSX)
t.Parallel()
th := testh.New(t)
@ -190,7 +191,7 @@ func TestSQLDriver_PrepareUpdateStmt(t *testing.T) {
handle := handle
t.Run(handle, func(t *testing.T) {
testh.SkipShort(t, handle == sakila.XLSX)
tutil.SkipShort(t, handle == sakila.XLSX)
t.Parallel()
th, src, dbase, drvr := testh.NewWith(t, handle)
@ -220,9 +221,9 @@ func TestSQLDriver_PrepareUpdateStmt(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 1, len(sink.Recs))
require.Equal(t, actorID, testh.Val(sink.Recs[0][0]))
require.Equal(t, wantVals[0], testh.Val(sink.Recs[0][1]))
require.Equal(t, wantVals[1], testh.Val(sink.Recs[0][2]))
require.Equal(t, actorID, tutil.Val(sink.Recs[0][0]))
require.Equal(t, wantVals[0], tutil.Val(sink.Recs[0][1]))
require.Equal(t, wantVals[1], tutil.Val(sink.Recs[0][2]))
})
}
}
@ -235,7 +236,7 @@ func TestDriver_Ping(t *testing.T) {
handle := handle
t.Run(handle, func(t *testing.T) {
testh.SkipShort(t, handle == sakila.XLSX)
tutil.SkipShort(t, handle == sakila.XLSX)
th := testh.New(t)
src := th.Source(handle)
@ -256,7 +257,7 @@ func TestDriver_Open(t *testing.T) {
handle := handle
t.Run(handle, func(t *testing.T) {
testh.SkipShort(t, handle == sakila.XLSX)
tutil.SkipShort(t, handle == sakila.XLSX)
t.Parallel()
th := testh.New(t)

View File

@ -1,6 +1,7 @@
package libsq_test
import (
"github.com/neilotoole/sq/testh/tutil"
"testing"
"github.com/stretchr/testify/require"
@ -28,7 +29,7 @@ func TestSLQ2SQL(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(testh.Name(tc.slq), func(t *testing.T) {
t.Run(tutil.Name(tc.slq), func(t *testing.T) {
th := testh.New(t)
srcs := th.NewSourceSet(tc.handles...)

View File

@ -1,6 +1,7 @@
package libsq_test
import (
"github.com/neilotoole/sq/testh/tutil"
"reflect"
"testing"
@ -62,7 +63,7 @@ func TestQuerySQL_Smoke(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.handle, func(t *testing.T) {
testh.SkipShort(t, tc.handle == sakila.XLSX)
tutil.SkipShort(t, tc.handle == sakila.XLSX)
t.Parallel()
th := testh.New(t)

View File

@ -2,6 +2,7 @@ package source_test
import (
"context"
"github.com/neilotoole/sq/testh/tutil"
"io"
"io/ioutil"
"os"
@ -187,7 +188,7 @@ func TestFiles_Stdin(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(testh.Name(tc.fpath), func(t *testing.T) {
t.Run(tutil.Name(tc.fpath), func(t *testing.T) {
th := testh.New(t)
fs := th.Files()

View File

@ -4,17 +4,14 @@ package testh
import (
"context"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"testing"
"github.com/alexflint/go-filemutex"
"github.com/neilotoole/lg"
"github.com/neilotoole/lg/testlg"
"github.com/neilotoole/sq/cli/config"
@ -625,42 +622,6 @@ func DriverDefsFrom(t testing.TB, cfgFiles ...string) []*userdriver.DriverDef {
return userDriverDefs
}
// SkipShort invokes t.Skip if testing.Short and arg skip are both true.
func SkipShort(t *testing.T, skip bool) {
if skip && testing.Short() {
t.Skip("Skipping long-running test because -short is true.")
}
}
// Val returns the fully dereferenced value of i. If i
// is nil, nil is returned. If i has type *(*string),
// Val(i) returns string.
// Useful for testing.
func Val(i any) any {
if i == nil {
return nil
}
v := reflect.ValueOf(i)
for {
if !v.IsValid() {
return nil
}
switch v.Kind() {
default:
return v.Interface()
case reflect.Ptr, reflect.Interface:
if v.IsNil() {
return nil
}
v = v.Elem()
// Loop again
continue
}
}
}
// TypeDetectors returns the common set of TypeDetectorFuncs.
func TypeDetectors() []source.TypeDetectFunc {
return []source.TypeDetectFunc{
@ -670,68 +631,3 @@ func TypeDetectors() []source.TypeDetectFunc {
/*json.DetectJSON,*/ json.DetectJSONA, json.DetectJSONL, // FIXME: enable DetectJSON when it's ready
}
}
// AssertCompareFunc matches several of the the testify/require funcs.
// It can be used to choose assertion comparison funcs in test cases.
type AssertCompareFunc func(require.TestingT, any, any, ...any)
// Verify that a sample of the require funcs match AssertCompareFunc.
var (
_ AssertCompareFunc = require.Equal
_ AssertCompareFunc = require.GreaterOrEqual
_ AssertCompareFunc = require.Greater
)
// Name is a convenience function for building a test name to
// pass to t.Run.
//
// t.Run(testh.Name("my_test", 1), func(t *testing.T) {
//
// The most common usage is with test names that are file
// paths.
//
// testh.Name("path/to/file") --> "path_to_file"
//
// Any element of arg that prints to empty string is skipped.
func Name(args ...any) string {
var parts []string
var s string
for _, a := range args {
s = fmt.Sprintf("%v", a)
if s == "" {
continue
}
s = strings.Replace(s, "/", "_", -1)
parts = append(parts, s)
}
s = strings.Join(parts, "_")
if s == "" {
return "empty"
}
return s
}
// Lock obtains a universal (cross-process) mutex for all tests.
// This should be called by tests that cannot be executed in parallel
// with any other test (even those in another package).
//
// Why? The vast majority of tests can be run in parallel, both inside
// each test package and across test packages. The handful of tests
// that must not be run in parallel can use this function to guarantee
// sequential execution.
//
// This is implemented via a lock file /tmp/go_test.lock.
// The lock is released via t.Cleanup.
func Lock(t testing.TB) {
fp := filepath.Join(os.TempDir(), "go_test.lock")
mu, err := filemutex.New(fp)
require.NoError(t, err)
t.Cleanup(func() {
err := mu.Unlock()
assert.NoError(t, err)
})
}

View File

@ -1,6 +1,7 @@
package testh_test
import (
"github.com/neilotoole/sq/testh/tutil"
"io/ioutil"
"testing"
"time"
@ -23,12 +24,12 @@ func TestVal(t *testing.T) {
want := "hello"
var got any
if testh.Val(nil) != nil {
if tutil.Val(nil) != nil {
t.FailNow()
}
var v0 any
if testh.Val(v0) != nil {
if tutil.Val(v0) != nil {
t.FailNow()
}
@ -41,7 +42,7 @@ func TestVal(t *testing.T) {
vals := []any{v1, v1a, v2, v3, v4, v5}
for _, val := range vals {
got = testh.Val(val)
got = tutil.Val(val)
if got != want {
t.Errorf("expected %T(%v) but got %T(%v)", want, want, got, got)
@ -49,26 +50,26 @@ func TestVal(t *testing.T) {
}
slice := []string{"a", "b"}
require.Equal(t, slice, testh.Val(slice))
require.Equal(t, slice, testh.Val(&slice))
require.Equal(t, slice, tutil.Val(slice))
require.Equal(t, slice, tutil.Val(&slice))
b := true
require.Equal(t, b, testh.Val(b))
require.Equal(t, b, testh.Val(&b))
require.Equal(t, b, tutil.Val(b))
require.Equal(t, b, tutil.Val(&b))
type structT struct {
f string
}
st1 := structT{f: "hello"}
require.Equal(t, st1, testh.Val(st1))
require.Equal(t, st1, testh.Val(&st1))
require.Equal(t, st1, tutil.Val(st1))
require.Equal(t, st1, tutil.Val(&st1))
var c chan int
require.Nil(t, testh.Val(c))
require.Nil(t, tutil.Val(c))
c = make(chan int, 10)
require.Equal(t, c, testh.Val(c))
require.Equal(t, c, testh.Val(&c))
require.Equal(t, c, tutil.Val(c))
require.Equal(t, c, tutil.Val(&c))
}
func TestCopyRecords(t *testing.T) {
@ -106,7 +107,7 @@ func TestCopyRecords(t *testing.T) {
require.False(t, recs[i][j] == recs2[i][j],
"pointer values should not be equal: %#v --> %#v", recs[i][j], recs2[i][j])
val1, val2 := testh.Val(recs[i][j]), testh.Val(recs2[i][j])
val1, val2 := tutil.Val(recs[i][j]), tutil.Val(recs2[i][j])
require.Equal(t, val1, val2,
"dereferenced values should be equal: %#v --> %#v", val1, val2)
}
@ -176,7 +177,7 @@ func TestTName(t *testing.T) {
}
for _, tc := range testCases {
got := testh.Name(tc.a...)
got := tutil.Name(tc.a...)
require.Equal(t, tc.want, got)
}

295
testh/tutil/tutil.go Normal file
View File

@ -0,0 +1,295 @@
// Package tutil contains basic generic test utilities.
package tutil
import (
"fmt"
"github.com/alexflint/go-filemutex"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
// SkipIff skips t if b is true. If msgAndArgs is non-empty, its first
// element must be a string, which can be a format string if there are
// additional elements.
//
// Examples:
//
// tutil.SkipIff(t, a == b)
// tutil.SkipIff(t, a == b, "skipping because a == b")
// tutil.SkipIff(t, a == b, "skipping because a is %v and b is %v", a, b)
func SkipIff(t testing.TB, b bool, format string, args ...any) {
if b {
if format == "" {
t.SkipNow()
} else {
t.Skipf(format, args...)
}
}
}
// StructFieldValue extracts the value of fieldName from arg strct.
// If strct is nil, nil is returned.
// The function will panic if strct is not a struct (or pointer to struct), or if
// the struct does not have fieldName. The returned value may be nil if the
// field is a pointer and is nil.
//
// Note that this function uses reflection, and may panic. It is only
// to be used by test code.
//
// See also: SliceFieldValues, SliceFieldKeyValues.
func StructFieldValue(fieldName string, strct any) any {
if strct == nil {
return nil
}
// zv is the zero value of reflect.Value, which can be returned by FieldByName
zv := reflect.Value{}
e := reflect.Indirect(reflect.ValueOf(strct))
if e.Kind() != reflect.Struct {
panic(fmt.Sprintf("strct expected to be struct but was %s", e.Kind()))
}
f := e.FieldByName(fieldName)
if f == zv {
panic(fmt.Sprintf("struct (%T) does not have field {%s}", strct, fieldName))
}
fieldValue := f.Interface()
return fieldValue
}
// SliceFieldValues takes a slice of structs, and returns a slice
// containing the value of fieldName for each element of slice.
//
// Note that slice can be []interface{}, or a typed slice (e.g. []*Person).
// If slice is nil, nil is returned. If slice has len zero, an empty slice
// is returned. The function panics if slice is not a slice, or if any element
// of slice is not a struct (excepting nil elements).
//
// Note that this function uses reflection, and may panic. It is only
// to be used by test code.
//
// See also: StructFieldValue, SliceFieldKeyValues.
func SliceFieldValues(fieldName string, slice any) []any {
if slice == nil {
return nil
}
s := reflect.ValueOf(slice)
if s.Kind() != reflect.Slice {
panic(fmt.Sprintf("arg slice expected to be a slice, but was {%T}", slice))
}
iSlice := InterfaceSlice(slice)
retVals := make([]any, len(iSlice))
for i := range iSlice {
retVals[i] = StructFieldValue(fieldName, iSlice[i])
}
return retVals
}
// SliceFieldKeyValues is similar to SliceFieldValues, but instead of
// returning a slice of field values, it returns a map containing two
// field values, a "key" and a "value". For example:
//
// persons := []*person{
// {Name: "Alice", Age: 42},
// {Name: "Bob", Age: 27},
// }
//
// m := SliceFieldKeyValues("Name", "Age", persons)
// // map[Alice:42 Bob:27]
//
// Note that this function uses reflection, and may panic. It is only
// to be used by test code.
//
// See also: StructFieldValue, SliceFieldValues.
func SliceFieldKeyValues(keyFieldName, valFieldName string, slice any) map[any]any {
if slice == nil {
return nil
}
s := reflect.ValueOf(slice)
if s.Kind() != reflect.Slice {
panic(fmt.Sprintf("arg slice expected to be a slice, but was {%T}", slice))
}
iSlice := InterfaceSlice(slice)
m := make(map[any]any, len(iSlice))
for i := range iSlice {
key := StructFieldValue(keyFieldName, iSlice[i])
val := StructFieldValue(valFieldName, iSlice[i])
m[key] = val
}
return m
}
// InterfaceSlice converts a typed slice (such as []string) to []interface{}.
// If slice is already of type []interface{}, it is returned unmodified.
// Otherwise a new []interface{} is constructed. If slice is nil, nil is
// returned. The function panics if slice is not a slice.
//
// Note that this function uses reflection, and may panic. It is only
// to be used by test code.
func InterfaceSlice(slice any) []any {
if slice == nil {
return nil
}
// If it's already an []interface{}, then just return
if iSlice, ok := slice.([]any); ok {
return iSlice
}
s := reflect.ValueOf(slice)
if s.Kind() != reflect.Slice {
panic(fmt.Sprintf("arg slice expected to be a slice, but was {%T}", slice))
}
// Keep the distinction between nil and empty slice input
if s.IsNil() {
return nil
}
ret := make([]any, s.Len())
for i := 0; i < s.Len(); i++ {
ret[i] = s.Index(i).Interface()
}
return ret
}
// StringSlice accepts a slice of arbitrary type (e.g. []int64 or []interface{})
// and returns a slice of string.
func StringSlice(slice any) []string {
if slice == nil {
return nil
}
// If it's already []string, return directly
if sSlice, ok := slice.([]string); ok {
return sSlice
}
iSlice := InterfaceSlice(slice)
sSlice := make([]string, len(iSlice))
for i := range iSlice {
sSlice[i] = fmt.Sprintf("%v", iSlice[i])
}
return sSlice
}
// Name is a convenience function for building a test name to
// pass to t.Run.
//
// t.Run(testh.Name("my_test", 1), func(t *testing.T) {
//
// The most common usage is with test names that are file
// paths.
//
// testh.Name("path/to/file") --> "path_to_file"
//
// Any element of arg that prints to empty string is skipped.
func Name(args ...any) string {
var parts []string
var s string
for _, a := range args {
s = fmt.Sprintf("%v", a)
if s == "" {
continue
}
s = strings.Replace(s, "/", "_", -1)
s = stringz.TrimLen(s, 40) // we don't want it to be too long
parts = append(parts, s)
}
s = strings.Join(parts, "_")
if s == "" {
return "empty"
}
return s
}
// SkipShort invokes t.Skip if testing.Short and arg skip are both true.
func SkipShort(t *testing.T, skip bool) {
if skip && testing.Short() {
t.Skip("Skipping long-running test because -short is true.")
}
}
// Val returns the fully dereferenced value of i. If i
// is nil, nil is returned. If i has type *(*string),
// Val(i) returns string.
// Useful for testing.
func Val(i any) any {
if i == nil {
return nil
}
v := reflect.ValueOf(i)
for {
if !v.IsValid() {
return nil
}
switch v.Kind() {
default:
return v.Interface()
case reflect.Ptr, reflect.Interface:
if v.IsNil() {
return nil
}
v = v.Elem()
// Loop again
continue
}
}
}
// AssertCompareFunc matches several of the testify/require funcs.
// It can be used to choose assertion comparison funcs in test cases.
type AssertCompareFunc func(require.TestingT, any, any, ...any)
// Verify that a sample of the require funcs match AssertCompareFunc.
var (
_ AssertCompareFunc = require.Equal
_ AssertCompareFunc = require.GreaterOrEqual
_ AssertCompareFunc = require.Greater
)
// Lock obtains a universal (cross-process) mutex for all tests.
// This should be called by tests that cannot be executed in parallel
// with any other test (even those in another package).
//
// Why? The vast majority of tests can be run in parallel, both inside
// each test package and across test packages. The handful of tests
// that must not be run in parallel can use this function to guarantee
// sequential execution.
//
// This is implemented via a lock file /tmp/go_test.lock.
// The lock is released via t.Cleanup.
func Lock(t testing.TB) {
fp := filepath.Join(os.TempDir(), "go_test.lock")
mu, err := filemutex.New(fp)
require.NoError(t, err)
t.Cleanup(func() {
err := mu.Unlock()
assert.NoError(t, err)
})
}

95
testh/tutil/tutil_test.go Normal file
View File

@ -0,0 +1,95 @@
package tutil
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestFieldExtractionFunctions tests StructFieldValue, SliceFieldValues,
// SliceFieldKeyValues.
func TestFieldExtractionFunctions(t *testing.T) {
type person struct {
UUID string
Age int
Nickname *string
}
p1 := &person{
UUID: "235a50d7-3955-431c-8641-6ce171abf589",
Age: 42,
Nickname: nil,
}
nn := "The Great"
p2 := &person{
UUID: "81975a8f-6add-441a-8c81-3806a9f4c6f0",
Age: 27,
Nickname: &nn,
}
uu := StructFieldValue("UUID", p1)
require.Equal(t, uu, p1.UUID)
age := StructFieldValue("Age", p1)
require.Equal(t, age, 42)
require.Panics(t, func() {
_ = StructFieldValue("UUID", 123)
}, "non-struct arg should panic")
require.Nil(t, StructFieldValue("UUID", nil))
require.Panics(t, func() {
_ = StructFieldValue("", p1)
}, "invalid fieldName should panic")
require.Panics(t, func() {
_ = StructFieldValue("NotAField", p1)
}, "invalid fieldName should panic")
nickname := StructFieldValue("Nickname", p1)
require.Nil(t, nickname)
nickname = StructFieldValue("Nickname", p2)
require.NotNil(t, nickname)
require.EqualValues(t, nickname, p2.Nickname)
iSlice := []any{p1, p2}
iVals := SliceFieldValues("UUID", iSlice)
require.Len(t, iVals, 2)
require.Equal(t, p1.UUID, iVals[0])
personPtrSlice := []*person{p1, p2}
iVals2 := SliceFieldValues("UUID", personPtrSlice)
require.Len(t, iVals2, 2)
require.EqualValues(t, iVals, iVals2)
personSlice := []person{*p1, *p2}
iVals3 := SliceFieldValues("UUID", personSlice)
require.Len(t, iVals2, 2)
require.EqualValues(t, iVals, iVals3)
require.Panics(t, func() {
_ = SliceFieldValues("UUID", p1)
}, "non-slice arg should panic")
m1 := SliceFieldKeyValues("UUID", "Age", iSlice)
require.Len(t, m1, 2)
require.Equal(t, m1[p1.UUID], p1.Age)
require.Equal(t, m1[p2.UUID], p2.Age)
}
func TestInterfaceSlice(t *testing.T) {
stringSlice := []string{"hello", "world"}
iSlice := InterfaceSlice(stringSlice)
require.Equal(t, len(stringSlice), len(iSlice))
require.Equal(t, stringSlice[0], iSlice[0])
iSlice = InterfaceSlice(nil)
require.Nil(t, iSlice)
require.Panics(t, func() {
_ = InterfaceSlice(42)
}, "should panic for non-slice arg")
}