mirror of
https://github.com/neilotoole/sq.git
synced 2024-11-24 03:45:56 +03:00
fixed broken mysql tests (parseTime param); moved some test funcs to pkg tutil (#109)
This commit is contained in:
parent
e674cdc724
commit
6a0878bc6b
@ -1,6 +1,7 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"testing"
|
||||
|
||||
"github.com/neilotoole/sq/drivers/csv"
|
||||
@ -58,7 +59,7 @@ func TestCmdAdd(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(tc.wantHandle, tc.loc, tc.driver), func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.wantHandle, tc.loc, tc.driver), func(t *testing.T) {
|
||||
args := []string{"add", tc.loc}
|
||||
if tc.handle != "" {
|
||||
args = append(args, "--handle="+tc.handle)
|
||||
|
@ -2,6 +2,7 @@ package cli_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@ -125,7 +126,7 @@ func TestCmdInspect_Stdin(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(tc.fpath), func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.fpath), func(t *testing.T) {
|
||||
f, err := os.Open(tc.fpath) // No need to close f
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -2,6 +2,7 @@ package cli_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -22,7 +23,7 @@ func TestCmdSQL_Insert(t *testing.T) {
|
||||
origin := origin
|
||||
|
||||
t.Run("origin_"+origin, func(t *testing.T) {
|
||||
testh.SkipShort(t, origin == sakila.XLSX)
|
||||
tutil.SkipShort(t, origin == sakila.XLSX)
|
||||
|
||||
for _, dest := range sakila.SQLLatest() {
|
||||
dest := dest
|
||||
@ -126,7 +127,7 @@ func TestCmdSQL_StdinQuery(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(tc.fpath), func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.fpath), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f, err := os.Open(tc.fpath)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@ -9,7 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/neilotoole/sq/cli/config"
|
||||
"github.com/neilotoole/sq/testh"
|
||||
"github.com/neilotoole/sq/testh/proj"
|
||||
)
|
||||
|
||||
@ -72,7 +72,7 @@ func TestFileStore_Load(t *testing.T) {
|
||||
|
||||
for _, match := range good {
|
||||
match := match
|
||||
t.Run(testh.Name(match), func(t *testing.T) {
|
||||
t.Run(tutil.Name(match), func(t *testing.T) {
|
||||
fs.Path = match
|
||||
_, err = fs.Load()
|
||||
require.NoError(t, err, match)
|
||||
@ -82,7 +82,7 @@ func TestFileStore_Load(t *testing.T) {
|
||||
|
||||
for _, match := range bad {
|
||||
match := match
|
||||
t.Run(testh.Name(match), func(t *testing.T) {
|
||||
t.Run(tutil.Name(match), func(t *testing.T) {
|
||||
fs.Path = match
|
||||
_, err = fs.Load()
|
||||
require.Error(t, err, match)
|
||||
|
@ -3,6 +3,7 @@ package output_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -122,7 +123,7 @@ func TestRecordWriterAdapter_FlushAfterDuration(t *testing.T) {
|
||||
testCases := []struct {
|
||||
flushAfter time.Duration
|
||||
wantFlushed int
|
||||
assertFn testh.AssertCompareFunc
|
||||
assertFn tutil.AssertCompareFunc
|
||||
}{
|
||||
{flushAfter: -1, wantFlushed: 0, assertFn: require.Equal},
|
||||
{flushAfter: 0, wantFlushed: 0, assertFn: require.Equal},
|
||||
|
@ -296,7 +296,12 @@ func makeMapStringInterface(n int) map[string]any {
|
||||
}
|
||||
|
||||
func testName(v any) string {
|
||||
return fmt.Sprintf("%T", v)
|
||||
name := fmt.Sprintf("%T", v)
|
||||
if len(name) > 80 {
|
||||
name = name[:80]
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
type codeResponse2 struct {
|
||||
|
@ -3,6 +3,7 @@ package json_test
|
||||
import (
|
||||
"bytes"
|
||||
stdj "encoding/json"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@ -73,7 +74,7 @@ func TestImportJSONL_Flat(t *testing.T) {
|
||||
for i, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(i, tc.fpath, tc.input), func(t *testing.T) {
|
||||
t.Run(tutil.Name(i, tc.fpath, tc.input), func(t *testing.T) {
|
||||
openFn := func() (io.ReadCloser, error) {
|
||||
return ioutil.NopCloser(strings.NewReader(tc.input)), nil
|
||||
}
|
||||
@ -170,7 +171,7 @@ func TestScanObjectsInArray(t *testing.T) {
|
||||
for i, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(i, tc.in), func(t *testing.T) {
|
||||
t.Run(tutil.Name(i, tc.in), func(t *testing.T) {
|
||||
r := bytes.NewReader([]byte(tc.in))
|
||||
gotObjs, gotChunks, err := json.ScanObjectsInArray(r)
|
||||
if tc.wantErr {
|
||||
@ -202,7 +203,7 @@ func TestScanObjectsInArray_Files(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(tc.fname), func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.fname), func(t *testing.T) {
|
||||
f, err := os.Open(tc.fname)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
@ -241,7 +242,7 @@ func TestColumnOrderFlat(t *testing.T) {
|
||||
for i, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(i, tc.in), func(t *testing.T) {
|
||||
t.Run(tutil.Name(i, tc.in), func(t *testing.T) {
|
||||
require.True(t, stdj.Valid([]byte(tc.in)))
|
||||
|
||||
gotCols, err := json.ColumnOrderFlat([]byte(tc.in))
|
||||
|
@ -3,6 +3,7 @@ package json_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -13,7 +14,6 @@ import (
|
||||
|
||||
"github.com/neilotoole/sq/drivers/json"
|
||||
"github.com/neilotoole/sq/libsq/source"
|
||||
"github.com/neilotoole/sq/testh"
|
||||
)
|
||||
|
||||
func TestTypeDetectorFuncs(t *testing.T) {
|
||||
@ -87,7 +87,7 @@ func TestTypeDetectorFuncs(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(tc.fn, tc.fname), func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.fn, tc.fname), func(t *testing.T) {
|
||||
openFn := func() (io.ReadCloser, error) { return os.Open(filepath.Join("testdata", tc.fname)) }
|
||||
detectFn := detectFns[tc.fn]
|
||||
|
||||
|
@ -2,6 +2,7 @@ package mysql_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -431,7 +432,7 @@ func TestDatabaseTypeJSON(t *testing.T) {
|
||||
require.Equal(t, len(testVals), len(sink.Recs))
|
||||
for i := range testVals {
|
||||
for j := range testVals[i] {
|
||||
require.Equal(t, testVals[i][j], testh.Val(sink.Recs[i][j]))
|
||||
require.Equal(t, testVals[i][j], tutil.Val(sink.Recs[i][j]))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"testing"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
@ -47,38 +48,60 @@ func TestHasErrCode(t *testing.T) {
|
||||
|
||||
func TestDSNFromLocation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
loc string
|
||||
wantDSN string
|
||||
wantErr bool
|
||||
loc string
|
||||
parseTime bool
|
||||
wantDSN string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=1",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=1",
|
||||
wantErr: false,
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?parseTime=true",
|
||||
parseTime: true,
|
||||
},
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowCleartextPasswords=true&allowOldPasswords=1",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowCleartextPasswords=true&allowOldPasswords=1",
|
||||
wantErr: false,
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?parseTime=true",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?parseTime=true",
|
||||
parseTime: true,
|
||||
},
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=true",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=true",
|
||||
},
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=true",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=true&parseTime=true",
|
||||
parseTime: true,
|
||||
},
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowOldPasswords=true",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=true",
|
||||
},
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?allowCleartextPasswords=true&allowOldPasswords=true",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowCleartextPasswords=true&allowOldPasswords=true",
|
||||
},
|
||||
{
|
||||
loc: "mysql://sakila:p_ssW0rd@localhost:3306/sqtest?parseTime=true&allowCleartextPasswords=true&allowOldPasswords=true",
|
||||
wantDSN: "sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowCleartextPasswords=true&allowOldPasswords=true&parseTime=true",
|
||||
parseTime: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.loc, func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.loc, tc.parseTime), func(t *testing.T) {
|
||||
src := &source.Source{
|
||||
Handle: "@testhandle",
|
||||
Type: Type,
|
||||
Location: tc.loc,
|
||||
}
|
||||
|
||||
gotDSN, gotErr := dsnFromLocation(src)
|
||||
gotDSN, gotErr := dsnFromLocation(src, tc.parseTime)
|
||||
if tc.wantErr {
|
||||
require.Error(t, gotErr)
|
||||
return
|
||||
|
@ -39,7 +39,9 @@ func kindFromDBTypeName(log lg.Log, colName, dbTypeName string) kind.Kind {
|
||||
knd = kind.Unknown
|
||||
case "":
|
||||
knd = kind.Unknown
|
||||
case "INTEGER", "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT", "YEAR", "BIT":
|
||||
case "INTEGER", "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT", "YEAR", "BIT",
|
||||
"UNSIGNED INTEGER", "UNSIGNED INT", "UNSIGNED TINYINT",
|
||||
"UNSIGNED SMALLINT", "UNSIGNED MEDIUMINT", "UNSIGNED BIGINT":
|
||||
knd = kind.Int
|
||||
case "DECIMAL", "NUMERIC":
|
||||
knd = kind.Decimal
|
||||
|
@ -259,7 +259,7 @@ func (d *driveri) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName st
|
||||
|
||||
// Open implements driver.Driver.
|
||||
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Database, error) {
|
||||
dsn, err := dsnFromLocation(src)
|
||||
dsn, err := dsnFromLocation(src, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -296,7 +296,7 @@ func (d *driveri) Ping(ctx context.Context, src *source.Source) error {
|
||||
// the TRUNCATE statement.
|
||||
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64, err error) {
|
||||
// https://dev.mysql.com/doc/refman/8.0/en/truncate-table.html
|
||||
dsn, err := dsnFromLocation(src)
|
||||
dsn, err := dsnFromLocation(src, true)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -390,8 +390,10 @@ func hasErrCode(err error, code uint16) bool {
|
||||
|
||||
const errNumTableNotExist = uint16(1146)
|
||||
|
||||
// dsnFromLocation extracts the mysql driver DSN from src.Location.
|
||||
func dsnFromLocation(src *source.Source) (string, error) {
|
||||
// dsnFromLocation builds the mysql driver DSN from src.Location.
|
||||
// If parseTime is true, the param "parseTime=true" is added. This
|
||||
// is because of: https://stackoverflow.com/questions/29341590/how-to-parse-time-from-database/29343013#29343013
|
||||
func dsnFromLocation(src *source.Source, parseTime bool) (string, error) {
|
||||
if !strings.HasPrefix(src.Location, "mysql://") || len(src.Location) < 10 {
|
||||
return "", errz.Errorf("invalid source location %s", src.RedactedLocation())
|
||||
}
|
||||
@ -406,10 +408,13 @@ func dsnFromLocation(src *source.Source) (string, error) {
|
||||
// Driver DSN: sakila:p_ssW0rd@tcp(localhost:3306)/sqtest?allowOldPasswords=1
|
||||
driverDSN := u.DSN
|
||||
|
||||
_, err = mysql.ParseDSN(driverDSN) // verify
|
||||
myCfg, err := mysql.ParseDSN(driverDSN) // verify
|
||||
if err != nil {
|
||||
return "", errz.Wrapf(err, "invalid source location: %q", driverDSN)
|
||||
}
|
||||
|
||||
myCfg.ParseTime = parseTime
|
||||
driverDSN = myCfg.FormatDSN()
|
||||
|
||||
return driverDSN, nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package sqlite3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -188,7 +189,7 @@ func TestDatabaseTypes(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
require.Equal(t, wantVal, testh.Val(gotVal),
|
||||
require.Equal(t, wantVal, tutil.Val(gotVal),
|
||||
"%s[%d][%d] (%s) expected %T(%v) but got %T(%v)",
|
||||
actualTblName, i, j, typeTestColNames[j], wantVal, wantVal, gotVal, gotVal)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package xmlud_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -49,19 +50,19 @@ func TestImport_Ppl(t *testing.T) {
|
||||
sink, err := th.QuerySQL(scratchDB.Source(), "SELECT * FROM person")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, len(sink.Recs))
|
||||
require.Equal(t, "Nikola", testh.Val(sink.Recs[0][1]))
|
||||
require.Equal(t, "Nikola", tutil.Val(sink.Recs[0][1]))
|
||||
for i, rec := range sink.Recs {
|
||||
// Verify that the primary id cols are sequential
|
||||
require.Equal(t, int64(i+1), testh.Val(rec[0]))
|
||||
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
|
||||
}
|
||||
|
||||
sink, err = th.QuerySQL(scratchDB.Source(), "SELECT * FROM skill")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 6, len(sink.Recs))
|
||||
require.Equal(t, "Electrifying", testh.Val(sink.Recs[0][2]))
|
||||
require.Equal(t, "Electrifying", tutil.Val(sink.Recs[0][2]))
|
||||
for i, rec := range sink.Recs {
|
||||
// Verify that the primary id cols are sequential
|
||||
require.Equal(t, int64(i+1), testh.Val(rec[0]))
|
||||
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,27 +96,27 @@ func TestImport_RSS(t *testing.T) {
|
||||
sink, err := th.QuerySQL(scratchDB.Source(), "SELECT * FROM channel")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(sink.Recs))
|
||||
require.Equal(t, "NYT > World", testh.Val(sink.Recs[0][1]))
|
||||
require.Equal(t, "NYT > World", tutil.Val(sink.Recs[0][1]))
|
||||
for i, rec := range sink.Recs {
|
||||
// Verify that the primary id cols are sequential
|
||||
require.Equal(t, int64(i+1), testh.Val(rec[0]))
|
||||
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
|
||||
}
|
||||
|
||||
sink, err = th.QuerySQL(scratchDB.Source(), "SELECT * FROM category")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 251, len(sink.Recs))
|
||||
require.EqualValues(t, "Extradition", testh.Val(sink.Recs[0][2]))
|
||||
require.EqualValues(t, "Extradition", tutil.Val(sink.Recs[0][2]))
|
||||
for i, rec := range sink.Recs {
|
||||
// Verify that the primary id cols are sequential
|
||||
require.Equal(t, int64(i+1), testh.Val(rec[0]))
|
||||
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
|
||||
}
|
||||
|
||||
sink, err = th.QuerySQL(scratchDB.Source(), "SELECT * FROM item")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 45, len(sink.Recs))
|
||||
require.EqualValues(t, "Trilobites: Fishing for Clues to Solve Namibia’s Fairy Circle Mystery", testh.Val(sink.Recs[17][4]))
|
||||
require.EqualValues(t, "Trilobites: Fishing for Clues to Solve Namibia’s Fairy Circle Mystery", tutil.Val(sink.Recs[17][4]))
|
||||
for i, rec := range sink.Recs {
|
||||
// Verify that the primary id cols are sequential
|
||||
require.Equal(t, int64(i+1), testh.Val(rec[0]))
|
||||
require.Equal(t, int64(i+1), tutil.Val(rec[0]))
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package xlsx_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -24,7 +25,7 @@ func Test_Smoke_Subset(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Smoke_Full(t *testing.T) {
|
||||
testh.SkipShort(t, true)
|
||||
tutil.SkipShort(t, true)
|
||||
|
||||
th := testh.New(t)
|
||||
src := th.Source(sakila.XLSX)
|
||||
|
2
go.mod
2
go.mod
@ -14,7 +14,7 @@ require (
|
||||
github.com/djherbis/fscache v0.10.1
|
||||
github.com/emirpasic/gods v1.9.0
|
||||
github.com/fatih/color v1.13.0
|
||||
github.com/go-sql-driver/mysql v1.6.0
|
||||
github.com/go-sql-driver/mysql v1.7.0
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/h2non/filetype v1.1.0
|
||||
github.com/jackc/pgconn v1.5.0
|
||||
|
2
go.sum
2
go.sum
@ -352,6 +352,8 @@ github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/godbus/dbus v0.0.0-20151105175453-c7fdd8b5cd55/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw=
|
||||
github.com/godbus/dbus v0.0.0-20180201030542-885f9cc04c9c/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw=
|
||||
|
@ -389,3 +389,13 @@ func LineCount(r io.Reader, skipEmpty bool) int {
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// TrimLen returns s but with a maximum length of maxLen.
|
||||
func TrimLen(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package stringz_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -8,7 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/neilotoole/sq/libsq/core/stringz"
|
||||
"github.com/neilotoole/sq/testh"
|
||||
)
|
||||
|
||||
func TestGenerateAlphaColName(t *testing.T) {
|
||||
@ -74,6 +74,30 @@ func TestPluralize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestTrimLen(t *testing.T) {
|
||||
testCases := []struct {
|
||||
s string
|
||||
i int
|
||||
want string
|
||||
}{
|
||||
{s: "", i: 0, want: ""},
|
||||
{s: "", i: 1, want: ""},
|
||||
{s: "abc", i: 0, want: ""},
|
||||
{s: "abc", i: 1, want: "a"},
|
||||
{s: "abc", i: 2, want: "ab"},
|
||||
{s: "abc", i: 3, want: "abc"},
|
||||
{s: "abc", i: 4, want: "abc"},
|
||||
{s: "abc", i: 5, want: "abc"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
got := stringz.TrimLen(tc.s, tc.i)
|
||||
require.Equal(t, tc.want, got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRepeatJoin(t *testing.T) {
|
||||
testCases := []struct {
|
||||
s string
|
||||
@ -272,7 +296,7 @@ func TestLineCount(t *testing.T) {
|
||||
for i, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(i, tc.in), func(t *testing.T) {
|
||||
t.Run(tutil.Name(i, tc.in), func(t *testing.T) {
|
||||
count := stringz.LineCount(strings.NewReader(tc.in), false)
|
||||
require.Equal(t, tc.withEmpty, count)
|
||||
count = stringz.LineCount(strings.NewReader(tc.in), true)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package driver_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -145,7 +146,7 @@ func TestDriver_TableColumnTypes(t *testing.T) {
|
||||
handle := handle
|
||||
|
||||
t.Run(handle, func(t *testing.T) {
|
||||
testh.SkipShort(t, handle == sakila.XLSX)
|
||||
tutil.SkipShort(t, handle == sakila.XLSX)
|
||||
t.Parallel()
|
||||
|
||||
th := testh.New(t)
|
||||
@ -190,7 +191,7 @@ func TestSQLDriver_PrepareUpdateStmt(t *testing.T) {
|
||||
handle := handle
|
||||
|
||||
t.Run(handle, func(t *testing.T) {
|
||||
testh.SkipShort(t, handle == sakila.XLSX)
|
||||
tutil.SkipShort(t, handle == sakila.XLSX)
|
||||
t.Parallel()
|
||||
|
||||
th, src, dbase, drvr := testh.NewWith(t, handle)
|
||||
@ -220,9 +221,9 @@ func TestSQLDriver_PrepareUpdateStmt(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(sink.Recs))
|
||||
require.Equal(t, actorID, testh.Val(sink.Recs[0][0]))
|
||||
require.Equal(t, wantVals[0], testh.Val(sink.Recs[0][1]))
|
||||
require.Equal(t, wantVals[1], testh.Val(sink.Recs[0][2]))
|
||||
require.Equal(t, actorID, tutil.Val(sink.Recs[0][0]))
|
||||
require.Equal(t, wantVals[0], tutil.Val(sink.Recs[0][1]))
|
||||
require.Equal(t, wantVals[1], tutil.Val(sink.Recs[0][2]))
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -235,7 +236,7 @@ func TestDriver_Ping(t *testing.T) {
|
||||
handle := handle
|
||||
|
||||
t.Run(handle, func(t *testing.T) {
|
||||
testh.SkipShort(t, handle == sakila.XLSX)
|
||||
tutil.SkipShort(t, handle == sakila.XLSX)
|
||||
|
||||
th := testh.New(t)
|
||||
src := th.Source(handle)
|
||||
@ -256,7 +257,7 @@ func TestDriver_Open(t *testing.T) {
|
||||
handle := handle
|
||||
|
||||
t.Run(handle, func(t *testing.T) {
|
||||
testh.SkipShort(t, handle == sakila.XLSX)
|
||||
tutil.SkipShort(t, handle == sakila.XLSX)
|
||||
t.Parallel()
|
||||
|
||||
th := testh.New(t)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package libsq_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -28,7 +29,7 @@ func TestSLQ2SQL(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(tc.slq), func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.slq), func(t *testing.T) {
|
||||
th := testh.New(t)
|
||||
srcs := th.NewSourceSet(tc.handles...)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package libsq_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@ -62,7 +63,7 @@ func TestQuerySQL_Smoke(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.handle, func(t *testing.T) {
|
||||
testh.SkipShort(t, tc.handle == sakila.XLSX)
|
||||
tutil.SkipShort(t, tc.handle == sakila.XLSX)
|
||||
t.Parallel()
|
||||
|
||||
th := testh.New(t)
|
||||
|
@ -2,6 +2,7 @@ package source_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@ -187,7 +188,7 @@ func TestFiles_Stdin(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(testh.Name(tc.fpath), func(t *testing.T) {
|
||||
t.Run(tutil.Name(tc.fpath), func(t *testing.T) {
|
||||
th := testh.New(t)
|
||||
fs := th.Files()
|
||||
|
||||
|
104
testh/testh.go
104
testh/testh.go
@ -4,17 +4,14 @@ package testh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/alexflint/go-filemutex"
|
||||
"github.com/neilotoole/lg"
|
||||
"github.com/neilotoole/lg/testlg"
|
||||
"github.com/neilotoole/sq/cli/config"
|
||||
@ -625,42 +622,6 @@ func DriverDefsFrom(t testing.TB, cfgFiles ...string) []*userdriver.DriverDef {
|
||||
return userDriverDefs
|
||||
}
|
||||
|
||||
// SkipShort invokes t.Skip if testing.Short and arg skip are both true.
|
||||
func SkipShort(t *testing.T, skip bool) {
|
||||
if skip && testing.Short() {
|
||||
t.Skip("Skipping long-running test because -short is true.")
|
||||
}
|
||||
}
|
||||
|
||||
// Val returns the fully dereferenced value of i. If i
|
||||
// is nil, nil is returned. If i has type *(*string),
|
||||
// Val(i) returns string.
|
||||
// Useful for testing.
|
||||
func Val(i any) any {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(i)
|
||||
for {
|
||||
if !v.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
default:
|
||||
return v.Interface()
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
if v.IsNil() {
|
||||
return nil
|
||||
}
|
||||
v = v.Elem()
|
||||
// Loop again
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TypeDetectors returns the common set of TypeDetectorFuncs.
|
||||
func TypeDetectors() []source.TypeDetectFunc {
|
||||
return []source.TypeDetectFunc{
|
||||
@ -670,68 +631,3 @@ func TypeDetectors() []source.TypeDetectFunc {
|
||||
/*json.DetectJSON,*/ json.DetectJSONA, json.DetectJSONL, // FIXME: enable DetectJSON when it's ready
|
||||
}
|
||||
}
|
||||
|
||||
// AssertCompareFunc matches several of the the testify/require funcs.
|
||||
// It can be used to choose assertion comparison funcs in test cases.
|
||||
type AssertCompareFunc func(require.TestingT, any, any, ...any)
|
||||
|
||||
// Verify that a sample of the require funcs match AssertCompareFunc.
|
||||
var (
|
||||
_ AssertCompareFunc = require.Equal
|
||||
_ AssertCompareFunc = require.GreaterOrEqual
|
||||
_ AssertCompareFunc = require.Greater
|
||||
)
|
||||
|
||||
// Name is a convenience function for building a test name to
|
||||
// pass to t.Run.
|
||||
//
|
||||
// t.Run(testh.Name("my_test", 1), func(t *testing.T) {
|
||||
//
|
||||
// The most common usage is with test names that are file
|
||||
// paths.
|
||||
//
|
||||
// testh.Name("path/to/file") --> "path_to_file"
|
||||
//
|
||||
// Any element of arg that prints to empty string is skipped.
|
||||
func Name(args ...any) string {
|
||||
var parts []string
|
||||
var s string
|
||||
for _, a := range args {
|
||||
s = fmt.Sprintf("%v", a)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
s = strings.Replace(s, "/", "_", -1)
|
||||
parts = append(parts, s)
|
||||
}
|
||||
|
||||
s = strings.Join(parts, "_")
|
||||
if s == "" {
|
||||
return "empty"
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Lock obtains a universal (cross-process) mutex for all tests.
|
||||
// This should be called by tests that cannot be executed in parallel
|
||||
// with any other test (even those in another package).
|
||||
//
|
||||
// Why? The vast majority of tests can be run in parallel, both inside
|
||||
// each test package and across test packages. The handful of tests
|
||||
// that must not be run in parallel can use this function to guarantee
|
||||
// sequential execution.
|
||||
//
|
||||
// This is implemented via a lock file /tmp/go_test.lock.
|
||||
// The lock is released via t.Cleanup.
|
||||
func Lock(t testing.TB) {
|
||||
fp := filepath.Join(os.TempDir(), "go_test.lock")
|
||||
mu, err := filemutex.New(fp)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := mu.Unlock()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package testh_test
|
||||
|
||||
import (
|
||||
"github.com/neilotoole/sq/testh/tutil"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
"time"
|
||||
@ -23,12 +24,12 @@ func TestVal(t *testing.T) {
|
||||
want := "hello"
|
||||
var got any
|
||||
|
||||
if testh.Val(nil) != nil {
|
||||
if tutil.Val(nil) != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
var v0 any
|
||||
if testh.Val(v0) != nil {
|
||||
if tutil.Val(v0) != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
@ -41,7 +42,7 @@ func TestVal(t *testing.T) {
|
||||
|
||||
vals := []any{v1, v1a, v2, v3, v4, v5}
|
||||
for _, val := range vals {
|
||||
got = testh.Val(val)
|
||||
got = tutil.Val(val)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("expected %T(%v) but got %T(%v)", want, want, got, got)
|
||||
@ -49,26 +50,26 @@ func TestVal(t *testing.T) {
|
||||
}
|
||||
|
||||
slice := []string{"a", "b"}
|
||||
require.Equal(t, slice, testh.Val(slice))
|
||||
require.Equal(t, slice, testh.Val(&slice))
|
||||
require.Equal(t, slice, tutil.Val(slice))
|
||||
require.Equal(t, slice, tutil.Val(&slice))
|
||||
|
||||
b := true
|
||||
require.Equal(t, b, testh.Val(b))
|
||||
require.Equal(t, b, testh.Val(&b))
|
||||
require.Equal(t, b, tutil.Val(b))
|
||||
require.Equal(t, b, tutil.Val(&b))
|
||||
|
||||
type structT struct {
|
||||
f string
|
||||
}
|
||||
|
||||
st1 := structT{f: "hello"}
|
||||
require.Equal(t, st1, testh.Val(st1))
|
||||
require.Equal(t, st1, testh.Val(&st1))
|
||||
require.Equal(t, st1, tutil.Val(st1))
|
||||
require.Equal(t, st1, tutil.Val(&st1))
|
||||
|
||||
var c chan int
|
||||
require.Nil(t, testh.Val(c))
|
||||
require.Nil(t, tutil.Val(c))
|
||||
c = make(chan int, 10)
|
||||
require.Equal(t, c, testh.Val(c))
|
||||
require.Equal(t, c, testh.Val(&c))
|
||||
require.Equal(t, c, tutil.Val(c))
|
||||
require.Equal(t, c, tutil.Val(&c))
|
||||
}
|
||||
|
||||
func TestCopyRecords(t *testing.T) {
|
||||
@ -106,7 +107,7 @@ func TestCopyRecords(t *testing.T) {
|
||||
require.False(t, recs[i][j] == recs2[i][j],
|
||||
"pointer values should not be equal: %#v --> %#v", recs[i][j], recs2[i][j])
|
||||
|
||||
val1, val2 := testh.Val(recs[i][j]), testh.Val(recs2[i][j])
|
||||
val1, val2 := tutil.Val(recs[i][j]), tutil.Val(recs2[i][j])
|
||||
require.Equal(t, val1, val2,
|
||||
"dereferenced values should be equal: %#v --> %#v", val1, val2)
|
||||
}
|
||||
@ -176,7 +177,7 @@ func TestTName(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
got := testh.Name(tc.a...)
|
||||
got := tutil.Name(tc.a...)
|
||||
require.Equal(t, tc.want, got)
|
||||
}
|
||||
|
||||
|
295
testh/tutil/tutil.go
Normal file
295
testh/tutil/tutil.go
Normal file
@ -0,0 +1,295 @@
|
||||
// Package tutil contains basic generic test utilities.
|
||||
package tutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/alexflint/go-filemutex"
|
||||
"github.com/neilotoole/sq/libsq/core/stringz"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// SkipIff skips t if b is true. If msgAndArgs is non-empty, its first
|
||||
// element must be a string, which can be a format string if there are
|
||||
// additional elements.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// tutil.SkipIff(t, a == b)
|
||||
// tutil.SkipIff(t, a == b, "skipping because a == b")
|
||||
// tutil.SkipIff(t, a == b, "skipping because a is %v and b is %v", a, b)
|
||||
func SkipIff(t testing.TB, b bool, format string, args ...any) {
|
||||
if b {
|
||||
if format == "" {
|
||||
t.SkipNow()
|
||||
} else {
|
||||
t.Skipf(format, args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StructFieldValue extracts the value of fieldName from arg strct.
|
||||
// If strct is nil, nil is returned.
|
||||
// The function will panic if strct is not a struct (or pointer to struct), or if
|
||||
// the struct does not have fieldName. The returned value may be nil if the
|
||||
// field is a pointer and is nil.
|
||||
//
|
||||
// Note that this function uses reflection, and may panic. It is only
|
||||
// to be used by test code.
|
||||
//
|
||||
// See also: SliceFieldValues, SliceFieldKeyValues.
|
||||
func StructFieldValue(fieldName string, strct any) any {
|
||||
if strct == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// zv is the zero value of reflect.Value, which can be returned by FieldByName
|
||||
zv := reflect.Value{}
|
||||
|
||||
e := reflect.Indirect(reflect.ValueOf(strct))
|
||||
if e.Kind() != reflect.Struct {
|
||||
panic(fmt.Sprintf("strct expected to be struct but was %s", e.Kind()))
|
||||
}
|
||||
|
||||
f := e.FieldByName(fieldName)
|
||||
if f == zv {
|
||||
panic(fmt.Sprintf("struct (%T) does not have field {%s}", strct, fieldName))
|
||||
}
|
||||
fieldValue := f.Interface()
|
||||
return fieldValue
|
||||
}
|
||||
|
||||
// SliceFieldValues takes a slice of structs, and returns a slice
|
||||
// containing the value of fieldName for each element of slice.
|
||||
//
|
||||
// Note that slice can be []interface{}, or a typed slice (e.g. []*Person).
|
||||
// If slice is nil, nil is returned. If slice has len zero, an empty slice
|
||||
// is returned. The function panics if slice is not a slice, or if any element
|
||||
// of slice is not a struct (excepting nil elements).
|
||||
//
|
||||
// Note that this function uses reflection, and may panic. It is only
|
||||
// to be used by test code.
|
||||
//
|
||||
// See also: StructFieldValue, SliceFieldKeyValues.
|
||||
func SliceFieldValues(fieldName string, slice any) []any {
|
||||
if slice == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := reflect.ValueOf(slice)
|
||||
if s.Kind() != reflect.Slice {
|
||||
panic(fmt.Sprintf("arg slice expected to be a slice, but was {%T}", slice))
|
||||
}
|
||||
|
||||
iSlice := InterfaceSlice(slice)
|
||||
retVals := make([]any, len(iSlice))
|
||||
|
||||
for i := range iSlice {
|
||||
retVals[i] = StructFieldValue(fieldName, iSlice[i])
|
||||
}
|
||||
|
||||
return retVals
|
||||
}
|
||||
|
||||
// SliceFieldKeyValues is similar to SliceFieldValues, but instead of
|
||||
// returning a slice of field values, it returns a map containing two
|
||||
// field values, a "key" and a "value". For example:
|
||||
//
|
||||
// persons := []*person{
|
||||
// {Name: "Alice", Age: 42},
|
||||
// {Name: "Bob", Age: 27},
|
||||
// }
|
||||
//
|
||||
// m := SliceFieldKeyValues("Name", "Age", persons)
|
||||
// // map[Alice:42 Bob:27]
|
||||
//
|
||||
// Note that this function uses reflection, and may panic. It is only
|
||||
// to be used by test code.
|
||||
//
|
||||
// See also: StructFieldValue, SliceFieldValues.
|
||||
func SliceFieldKeyValues(keyFieldName, valFieldName string, slice any) map[any]any {
|
||||
if slice == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := reflect.ValueOf(slice)
|
||||
if s.Kind() != reflect.Slice {
|
||||
panic(fmt.Sprintf("arg slice expected to be a slice, but was {%T}", slice))
|
||||
}
|
||||
|
||||
iSlice := InterfaceSlice(slice)
|
||||
m := make(map[any]any, len(iSlice))
|
||||
|
||||
for i := range iSlice {
|
||||
key := StructFieldValue(keyFieldName, iSlice[i])
|
||||
val := StructFieldValue(valFieldName, iSlice[i])
|
||||
|
||||
m[key] = val
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// InterfaceSlice converts a typed slice (such as []string) to []interface{}.
|
||||
// If slice is already of type []interface{}, it is returned unmodified.
|
||||
// Otherwise a new []interface{} is constructed. If slice is nil, nil is
|
||||
// returned. The function panics if slice is not a slice.
|
||||
//
|
||||
// Note that this function uses reflection, and may panic. It is only
|
||||
// to be used by test code.
|
||||
func InterfaceSlice(slice any) []any {
|
||||
if slice == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's already an []interface{}, then just return
|
||||
if iSlice, ok := slice.([]any); ok {
|
||||
return iSlice
|
||||
}
|
||||
|
||||
s := reflect.ValueOf(slice)
|
||||
if s.Kind() != reflect.Slice {
|
||||
panic(fmt.Sprintf("arg slice expected to be a slice, but was {%T}", slice))
|
||||
}
|
||||
|
||||
// Keep the distinction between nil and empty slice input
|
||||
if s.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := make([]any, s.Len())
|
||||
|
||||
for i := 0; i < s.Len(); i++ {
|
||||
ret[i] = s.Index(i).Interface()
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// StringSlice accepts a slice of arbitrary type (e.g. []int64 or []interface{})
|
||||
// and returns a slice of string.
|
||||
func StringSlice(slice any) []string {
|
||||
if slice == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's already []string, return directly
|
||||
if sSlice, ok := slice.([]string); ok {
|
||||
return sSlice
|
||||
}
|
||||
|
||||
iSlice := InterfaceSlice(slice)
|
||||
sSlice := make([]string, len(iSlice))
|
||||
for i := range iSlice {
|
||||
sSlice[i] = fmt.Sprintf("%v", iSlice[i])
|
||||
}
|
||||
|
||||
return sSlice
|
||||
}
|
||||
|
||||
// Name is a convenience function for building a test name to
|
||||
// pass to t.Run.
|
||||
//
|
||||
// t.Run(testh.Name("my_test", 1), func(t *testing.T) {
|
||||
//
|
||||
// The most common usage is with test names that are file
|
||||
// paths.
|
||||
//
|
||||
// testh.Name("path/to/file") --> "path_to_file"
|
||||
//
|
||||
// Any element of arg that prints to empty string is skipped.
|
||||
func Name(args ...any) string {
|
||||
var parts []string
|
||||
var s string
|
||||
for _, a := range args {
|
||||
s = fmt.Sprintf("%v", a)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
s = strings.Replace(s, "/", "_", -1)
|
||||
s = stringz.TrimLen(s, 40) // we don't want it to be too long
|
||||
parts = append(parts, s)
|
||||
}
|
||||
|
||||
s = strings.Join(parts, "_")
|
||||
if s == "" {
|
||||
return "empty"
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SkipShort invokes t.Skip if testing.Short and arg skip are both true.
|
||||
func SkipShort(t *testing.T, skip bool) {
|
||||
if skip && testing.Short() {
|
||||
t.Skip("Skipping long-running test because -short is true.")
|
||||
}
|
||||
}
|
||||
|
||||
// Val returns the fully dereferenced value of i. If i
|
||||
// is nil, nil is returned. If i has type *(*string),
|
||||
// Val(i) returns string.
|
||||
// Useful for testing.
|
||||
func Val(i any) any {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(i)
|
||||
for {
|
||||
if !v.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
default:
|
||||
return v.Interface()
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
if v.IsNil() {
|
||||
return nil
|
||||
}
|
||||
v = v.Elem()
|
||||
// Loop again
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AssertCompareFunc matches several of the testify/require funcs.
|
||||
// It can be used to choose assertion comparison funcs in test cases.
|
||||
type AssertCompareFunc func(require.TestingT, any, any, ...any)
|
||||
|
||||
// Verify that a sample of the require funcs match AssertCompareFunc.
|
||||
var (
|
||||
_ AssertCompareFunc = require.Equal
|
||||
_ AssertCompareFunc = require.GreaterOrEqual
|
||||
_ AssertCompareFunc = require.Greater
|
||||
)
|
||||
|
||||
// Lock obtains a universal (cross-process) mutex for all tests.
|
||||
// This should be called by tests that cannot be executed in parallel
|
||||
// with any other test (even those in another package).
|
||||
//
|
||||
// Why? The vast majority of tests can be run in parallel, both inside
|
||||
// each test package and across test packages. The handful of tests
|
||||
// that must not be run in parallel can use this function to guarantee
|
||||
// sequential execution.
|
||||
//
|
||||
// This is implemented via a lock file /tmp/go_test.lock.
|
||||
// The lock is released via t.Cleanup.
|
||||
func Lock(t testing.TB) {
|
||||
fp := filepath.Join(os.TempDir(), "go_test.lock")
|
||||
mu, err := filemutex.New(fp)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := mu.Unlock()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
95
testh/tutil/tutil_test.go
Normal file
95
testh/tutil/tutil_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package tutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestFieldExtractionFunctions tests StructFieldValue, SliceFieldValues,
|
||||
// SliceFieldKeyValues.
|
||||
func TestFieldExtractionFunctions(t *testing.T) {
|
||||
type person struct {
|
||||
UUID string
|
||||
Age int
|
||||
Nickname *string
|
||||
}
|
||||
|
||||
p1 := &person{
|
||||
UUID: "235a50d7-3955-431c-8641-6ce171abf589",
|
||||
Age: 42,
|
||||
Nickname: nil,
|
||||
}
|
||||
|
||||
nn := "The Great"
|
||||
p2 := &person{
|
||||
UUID: "81975a8f-6add-441a-8c81-3806a9f4c6f0",
|
||||
Age: 27,
|
||||
Nickname: &nn,
|
||||
}
|
||||
|
||||
uu := StructFieldValue("UUID", p1)
|
||||
require.Equal(t, uu, p1.UUID)
|
||||
age := StructFieldValue("Age", p1)
|
||||
require.Equal(t, age, 42)
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = StructFieldValue("UUID", 123)
|
||||
}, "non-struct arg should panic")
|
||||
|
||||
require.Nil(t, StructFieldValue("UUID", nil))
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = StructFieldValue("", p1)
|
||||
}, "invalid fieldName should panic")
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = StructFieldValue("NotAField", p1)
|
||||
}, "invalid fieldName should panic")
|
||||
|
||||
nickname := StructFieldValue("Nickname", p1)
|
||||
require.Nil(t, nickname)
|
||||
|
||||
nickname = StructFieldValue("Nickname", p2)
|
||||
require.NotNil(t, nickname)
|
||||
require.EqualValues(t, nickname, p2.Nickname)
|
||||
|
||||
iSlice := []any{p1, p2}
|
||||
iVals := SliceFieldValues("UUID", iSlice)
|
||||
require.Len(t, iVals, 2)
|
||||
require.Equal(t, p1.UUID, iVals[0])
|
||||
|
||||
personPtrSlice := []*person{p1, p2}
|
||||
iVals2 := SliceFieldValues("UUID", personPtrSlice)
|
||||
require.Len(t, iVals2, 2)
|
||||
require.EqualValues(t, iVals, iVals2)
|
||||
|
||||
personSlice := []person{*p1, *p2}
|
||||
iVals3 := SliceFieldValues("UUID", personSlice)
|
||||
require.Len(t, iVals2, 2)
|
||||
require.EqualValues(t, iVals, iVals3)
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = SliceFieldValues("UUID", p1)
|
||||
}, "non-slice arg should panic")
|
||||
|
||||
m1 := SliceFieldKeyValues("UUID", "Age", iSlice)
|
||||
require.Len(t, m1, 2)
|
||||
|
||||
require.Equal(t, m1[p1.UUID], p1.Age)
|
||||
require.Equal(t, m1[p2.UUID], p2.Age)
|
||||
}
|
||||
|
||||
func TestInterfaceSlice(t *testing.T) {
|
||||
stringSlice := []string{"hello", "world"}
|
||||
iSlice := InterfaceSlice(stringSlice)
|
||||
require.Equal(t, len(stringSlice), len(iSlice))
|
||||
require.Equal(t, stringSlice[0], iSlice[0])
|
||||
|
||||
iSlice = InterfaceSlice(nil)
|
||||
require.Nil(t, iSlice)
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = InterfaceSlice(42)
|
||||
}, "should panic for non-slice arg")
|
||||
}
|
Loading…
Reference in New Issue
Block a user