mirror of
https://github.com/neilotoole/sq.git
synced 2024-11-28 03:53:07 +03:00
fab365f43c
* gofumpt on files * more gofumpt
222 lines
4.7 KiB
Go
222 lines
4.7 KiB
Go
package testh
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/neilotoole/lg"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/neilotoole/sq/drivers/sqlite3"
|
|
"github.com/neilotoole/sq/libsq/core/kind"
|
|
"github.com/neilotoole/sq/libsq/core/sqlz"
|
|
)
|
|
|
|
// RecordSink is a testing impl of output.RecordWriter that
|
|
// captures invocations of that interface.
|
|
type RecordSink struct {
|
|
mu sync.Mutex
|
|
|
|
// RecMeta holds the recMeta received via Open.
|
|
RecMeta sqlz.RecordMeta
|
|
|
|
// Recs holds the records received via WriteRecords.
|
|
Recs []sqlz.Record
|
|
|
|
// Closed tracks the times Close was invoked.
|
|
Closed []time.Time
|
|
|
|
// Flushed tracks the times Flush was invoked.
|
|
Flushed []time.Time
|
|
}
|
|
|
|
// Open implements libsq.RecordWriter.
|
|
func (r *RecordSink) Open(recMeta sqlz.RecordMeta) error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
r.RecMeta = recMeta
|
|
return nil
|
|
}
|
|
|
|
// WriteRecords implements libsq.RecordWriter.
|
|
func (r *RecordSink) WriteRecords(recs []sqlz.Record) error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
r.Recs = append(r.Recs, recs...)
|
|
return nil
|
|
}
|
|
|
|
// Flush implements libsq.RecordWriter.
|
|
func (r *RecordSink) Flush() error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.Flushed = append(r.Flushed, time.Now())
|
|
return nil
|
|
}
|
|
|
|
// Close implements libsq.RecordWriter.
|
|
func (r *RecordSink) Close() error {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.Closed = append(r.Closed, time.Now())
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
recSinkCache = map[string]*RecordSink{}
|
|
recSinkMu sync.Mutex
|
|
)
|
|
|
|
// RecordsFromTbl returns a cached copy of all records from handle.tbl.
|
|
// The function performs a "SELECT * FROM tbl" and caches (in a package
|
|
// variable) the returned recs and recMeta for subsequent calls. Thus
|
|
// if the underlying data source records are modified, the returned records
|
|
// may be inconsistent.
|
|
//
|
|
// This function effectively exists to speed up testing times.
|
|
func RecordsFromTbl(tb testing.TB, handle, tbl string) (recMeta sqlz.RecordMeta, recs []sqlz.Record) {
|
|
recSinkMu.Lock()
|
|
defer recSinkMu.Unlock()
|
|
|
|
key := fmt.Sprintf("#rec_sink__%s__%s", handle, tbl)
|
|
sink, ok := recSinkCache[key]
|
|
if !ok {
|
|
th := New(tb)
|
|
th.Log = lg.Discard()
|
|
src := th.Source(handle)
|
|
var err error
|
|
sink, err = th.QuerySQL(src, "SELECT * FROM "+tbl)
|
|
require.NoError(tb, err)
|
|
recSinkCache[key] = sink
|
|
}
|
|
|
|
// Make copies so that the caller can mutate their records
|
|
// without it affecting other callers
|
|
recMeta = make(sqlz.RecordMeta, len(sink.RecMeta))
|
|
|
|
// Don't need to make a deep copy of each FieldMeta because
|
|
// the type is effectively immutable
|
|
copy(recMeta, sink.RecMeta)
|
|
|
|
recs = CopyRecords(sink.Recs)
|
|
return recMeta, recs
|
|
}
|
|
|
|
// NewRecordMeta builds a new RecordMeta instance for testing.
|
|
func NewRecordMeta(colNames []string, colKinds []kind.Kind) sqlz.RecordMeta {
|
|
recMeta := make(sqlz.RecordMeta, len(colNames))
|
|
for i := range colNames {
|
|
knd := colKinds[i]
|
|
ct := &sqlz.ColumnTypeData{
|
|
Name: colNames[i],
|
|
HasNullable: true,
|
|
Nullable: true,
|
|
DatabaseTypeName: sqlite3.DBTypeForKind(knd),
|
|
ScanType: KindScanType(knd),
|
|
Kind: knd,
|
|
}
|
|
|
|
recMeta[i] = sqlz.NewFieldMeta(ct)
|
|
}
|
|
|
|
return recMeta
|
|
}
|
|
|
|
// CopyRecords returns a deep copy of recs.
|
|
func CopyRecords(recs []sqlz.Record) []sqlz.Record {
|
|
if recs == nil {
|
|
return recs
|
|
}
|
|
|
|
if len(recs) == 0 {
|
|
return []sqlz.Record{}
|
|
}
|
|
|
|
r2 := make([]sqlz.Record, len(recs))
|
|
for i := range recs {
|
|
r2[i] = CopyRecord(recs[i])
|
|
}
|
|
return r2
|
|
}
|
|
|
|
// CopyRecord returns a deep copy of rec.
|
|
func CopyRecord(rec sqlz.Record) sqlz.Record {
|
|
if rec == nil {
|
|
return nil
|
|
}
|
|
|
|
if len(rec) == 0 {
|
|
return sqlz.Record{}
|
|
}
|
|
|
|
r2 := make(sqlz.Record, len(rec))
|
|
for i := range rec {
|
|
val := rec[i]
|
|
switch val := val.(type) {
|
|
case nil:
|
|
continue
|
|
case *int64:
|
|
v := *val
|
|
r2[i] = &v
|
|
case *bool:
|
|
v := *val
|
|
r2[i] = &v
|
|
case *float64:
|
|
v := *val
|
|
r2[i] = &v
|
|
case *string:
|
|
v := *val
|
|
r2[i] = &v
|
|
case *[]byte:
|
|
b := make([]byte, len(*val))
|
|
copy(b, *val)
|
|
r2[i] = &b
|
|
case *time.Time:
|
|
v := *val
|
|
r2[i] = &v
|
|
default:
|
|
panic(fmt.Sprintf("field [%d] has unacceptable record value type %T", i, val))
|
|
}
|
|
}
|
|
|
|
return r2
|
|
}
|
|
|
|
// KindScanType returns the default scan type for kind. The returned
|
|
// type is typically a sql.NullType.
|
|
func KindScanType(knd kind.Kind) reflect.Type {
|
|
switch knd { //nolint:exhaustive
|
|
default:
|
|
return sqlz.RTypeNullString
|
|
|
|
case kind.Text, kind.Decimal:
|
|
return sqlz.RTypeNullString
|
|
|
|
case kind.Int:
|
|
return sqlz.RTypeNullInt64
|
|
|
|
case kind.Bool:
|
|
return sqlz.RTypeNullBool
|
|
|
|
case kind.Float:
|
|
return sqlz.RTypeNullFloat64
|
|
|
|
case kind.Bytes:
|
|
return sqlz.RTypeBytes
|
|
|
|
case kind.Datetime:
|
|
return sqlz.RTypeNullTime
|
|
|
|
case kind.Date:
|
|
return sqlz.RTypeNullTime
|
|
|
|
case kind.Time:
|
|
return sqlz.RTypeNullTime
|
|
}
|
|
}
|