mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-25 09:16:59 +03:00
parent
08dfa10325
commit
ac7535609d
@ -301,5 +301,3 @@ issues:
|
|||||||
- gosec
|
- gosec
|
||||||
- noctx
|
- noctx
|
||||||
- wrapcheck
|
- wrapcheck
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ func (w *RecordWriter) writeRecord(rec sqlz.Record) error {
|
|||||||
case *[]byte:
|
case *[]byte:
|
||||||
s = base64.StdEncoding.EncodeToString(*val)
|
s = base64.StdEncoding.EncodeToString(*val)
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
switch w.recMeta[i].Kind() {
|
switch w.recMeta[i].Kind() { //nolint:exhaustive
|
||||||
default:
|
default:
|
||||||
s = val.Format(stringz.DatetimeFormat)
|
s = val.Format(stringz.DatetimeFormat)
|
||||||
case kind.Time:
|
case kind.Time:
|
||||||
|
@ -60,7 +60,7 @@ func (w *recordWriter) WriteRecords(recs []sqlz.Record) error {
|
|||||||
case *float64:
|
case *float64:
|
||||||
fmt.Fprint(w.out, stringz.FormatFloat(*val))
|
fmt.Fprint(w.out, stringz.FormatFloat(*val))
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
switch w.recMeta[i].Kind() {
|
switch w.recMeta[i].Kind() { //nolint:exhaustive
|
||||||
default:
|
default:
|
||||||
fmt.Fprint(w.out, val.Format(stringz.DatetimeFormat))
|
fmt.Fprint(w.out, val.Format(stringz.DatetimeFormat))
|
||||||
case kind.Time:
|
case kind.Time:
|
||||||
|
@ -35,7 +35,7 @@ type table struct {
|
|||||||
tblImpl *internal.Table
|
tblImpl *internal.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) renderResultCell(knd kind.Kind, val any) string {
|
func (t *table) renderResultCell(knd kind.Kind, val any) string { //nolint:funlen
|
||||||
switch val := val.(type) {
|
switch val := val.(type) {
|
||||||
case string:
|
case string:
|
||||||
return val
|
return val
|
||||||
@ -159,7 +159,7 @@ func (t *table) renderResultCell(knd kind.Kind, val any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var s string
|
var s string
|
||||||
switch knd {
|
switch knd { //nolint:exhaustive
|
||||||
default:
|
default:
|
||||||
s = val.Format(stringz.DatetimeFormat)
|
s = val.Format(stringz.DatetimeFormat)
|
||||||
case kind.Time:
|
case kind.Time:
|
||||||
|
@ -88,7 +88,7 @@ func (w *recordWriter) WriteRecords(recs []sqlz.Record) error {
|
|||||||
case *float64:
|
case *float64:
|
||||||
cell.SetFloat(*val)
|
cell.SetFloat(*val)
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
switch w.recMeta[i].Kind() {
|
switch w.recMeta[i].Kind() { //nolint:exhaustive
|
||||||
default:
|
default:
|
||||||
cell.SetDateTime(*val)
|
cell.SetDateTime(*val)
|
||||||
case kind.Date:
|
case kind.Date:
|
||||||
|
@ -196,7 +196,7 @@ func (w *recordWriter) writeRecord(rec sqlz.Record) error {
|
|||||||
case *float64:
|
case *float64:
|
||||||
w.fieldPrintFns[i](w.outBuf, stringz.FormatFloat(*val))
|
w.fieldPrintFns[i](w.outBuf, stringz.FormatFloat(*val))
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
switch w.recMeta[i].Kind() {
|
switch w.recMeta[i].Kind() { //nolint:exhaustive
|
||||||
default:
|
default:
|
||||||
w.fieldPrintFns[i](w.outBuf, val.Format(stringz.DatetimeFormat))
|
w.fieldPrintFns[i](w.outBuf, val.Format(stringz.DatetimeFormat))
|
||||||
case kind.Time:
|
case kind.Time:
|
||||||
|
@ -34,7 +34,7 @@ type Provider struct {
|
|||||||
|
|
||||||
// DriverFor implements driver.Provider.
|
// DriverFor implements driver.Provider.
|
||||||
func (d *Provider) DriverFor(typ source.Type) (driver.Driver, error) {
|
func (d *Provider) DriverFor(typ source.Type) (driver.Driver, error) {
|
||||||
switch typ {
|
switch typ { //nolint:exhaustive
|
||||||
case TypeCSV:
|
case TypeCSV:
|
||||||
return &driveri{log: d.Log, typ: TypeCSV, scratcher: d.Scratcher, files: d.Files}, nil
|
return &driveri{log: d.Log, typ: TypeCSV, scratcher: d.Scratcher, files: d.Files}, nil
|
||||||
case TypeTSV:
|
case TypeTSV:
|
||||||
|
@ -44,3 +44,14 @@ func TestReplacePlaceholders(t *testing.T) {
|
|||||||
require.Equal(t, want, got)
|
require.Equal(t, want, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_idSanitize(t *testing.T) {
|
||||||
|
testCases := map[string]string{
|
||||||
|
`tbl_name`: `"tbl_name"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for input, want := range testCases {
|
||||||
|
got := idSanitize(input)
|
||||||
|
require.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
// Import jackc/pgx, which is our postgres driver.
|
// Import jackc/pgx, which is our postgres driver.
|
||||||
_ "github.com/jackc/pgx/v4/stdlib"
|
_ "github.com/jackc/pgx/v4/stdlib"
|
||||||
"github.com/neilotoole/lg"
|
"github.com/neilotoole/lg"
|
||||||
@ -137,26 +138,25 @@ func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string,
|
|||||||
// https://www.postgresql.org/docs/9.1/sql-truncate.html
|
// https://www.postgresql.org/docs/9.1/sql-truncate.html
|
||||||
|
|
||||||
// RESTART IDENTITY and CASCADE/RESTRICT are from pg 8.2 onwards
|
// RESTART IDENTITY and CASCADE/RESTRICT are from pg 8.2 onwards
|
||||||
// TODO: should first check the pg version for < pg8.2 support
|
// FIXME: should first check the pg version for < pg8.2 support
|
||||||
|
|
||||||
db, err := sql.Open(dbDrvr, src.Location)
|
db, err := sql.Open(dbDrvr, src.Location)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return affected, errz.Err(err)
|
return affected, errz.Err(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
truncateQuery := fmt.Sprintf("TRUNCATE TABLE %q", tbl)
|
affectedQuery := "SELECT COUNT(*) FROM " + idSanitize(tbl)
|
||||||
if reset {
|
|
||||||
// if reset & src.DBVersion >= 8.2
|
|
||||||
truncateQuery += " RESTART IDENTITY" // default is CONTINUE IDENTITY
|
|
||||||
}
|
|
||||||
// We could add RESTRICT here; alternative is CASCADE
|
|
||||||
|
|
||||||
affectedQuery := fmt.Sprintf("SELECT COUNT(*) FROM %q", tbl)
|
|
||||||
err = db.QueryRowContext(ctx, affectedQuery).Scan(&affected)
|
err = db.QueryRowContext(ctx, affectedQuery).Scan(&affected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errz.Err(err)
|
return 0, errz.Err(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
truncateQuery := "TRUNCATE TABLE " + idSanitize(tbl)
|
||||||
|
if reset {
|
||||||
|
// if reset & src.DBVersion >= 8.2
|
||||||
|
truncateQuery += " RESTART IDENTITY" // default is CONTINUE IDENTITY
|
||||||
|
}
|
||||||
|
// We could add RESTRICT here; alternative is CASCADE
|
||||||
_, err = db.ExecContext(ctx, truncateQuery)
|
_, err = db.ExecContext(ctx, truncateQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errz.Err(err)
|
return 0, errz.Err(err)
|
||||||
@ -165,6 +165,14 @@ func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string,
|
|||||||
return affected, nil
|
return affected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// idSanitize sanitizes an identifier (such as table name). It will
|
||||||
|
// add surrounding quotes. For example:
|
||||||
|
//
|
||||||
|
// table_name --> "table_name"
|
||||||
|
func idSanitize(s string) string {
|
||||||
|
return pgx.Identifier([]string{s}).Sanitize()
|
||||||
|
}
|
||||||
|
|
||||||
// CreateTable implements driver.SQLDriver.
|
// CreateTable implements driver.SQLDriver.
|
||||||
func (d *driveri) CreateTable(ctx context.Context, db sqlz.DB, tblDef *sqlmodel.TableDef) error {
|
func (d *driveri) CreateTable(ctx context.Context, db sqlz.DB, tblDef *sqlmodel.TableDef) error {
|
||||||
stmt := buildCreateTableStmt(tblDef)
|
stmt := buildCreateTableStmt(tblDef)
|
||||||
|
@ -488,12 +488,7 @@ func (d *database) Close() error {
|
|||||||
d.closeMu.Lock()
|
d.closeMu.Lock()
|
||||||
defer d.closeMu.Unlock()
|
defer d.closeMu.Unlock()
|
||||||
|
|
||||||
//if !d.closed {
|
|
||||||
// debug.PrintStack()
|
|
||||||
//}
|
|
||||||
|
|
||||||
if d.closed {
|
if d.closed {
|
||||||
//panic( "SQLITE DB already closed")
|
|
||||||
d.log.Warnf("SQLite DB already closed: %v", d.src)
|
d.log.Warnf("SQLite DB already closed: %v", d.src)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -147,7 +147,6 @@ func (d *driveri) Ping(ctx context.Context, src *source.Source) error {
|
|||||||
//nolint:lll
|
//nolint:lll
|
||||||
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64,
|
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64,
|
||||||
err error) {
|
err error) {
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15
|
// https://docs.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15
|
||||||
|
|
||||||
// When there are foreign key constraints on mssql tables,
|
// When there are foreign key constraints on mssql tables,
|
||||||
|
@ -199,7 +199,7 @@ func stringWithCharset(length int, charset string) string {
|
|||||||
|
|
||||||
b := make([]byte, length)
|
b := make([]byte, length)
|
||||||
for i := range b {
|
for i := range b {
|
||||||
b[i] = charset[rand.Intn(len(charset))]
|
b[i] = charset[rand.Intn(len(charset))] //#nosec G404 // Doesn't need to be strongly random
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(b)
|
return string(b)
|
||||||
|
@ -108,7 +108,8 @@ func (x *StmtExecer) Close() error {
|
|||||||
// copied directly into rec, and its index is returned in skipped.
|
// copied directly into rec, and its index is returned in skipped.
|
||||||
// The caller must take appropriate action to deal with all
|
// The caller must take appropriate action to deal with all
|
||||||
// elements of rec listed in skipped.
|
// elements of rec listed in skipped.
|
||||||
func NewRecordFromScanRow(meta sqlz.RecordMeta, row []any, skip []int) (rec sqlz.Record, skipped []int) {
|
func NewRecordFromScanRow(meta sqlz.RecordMeta, row []any, skip []int) (rec sqlz.Record,
|
||||||
|
skipped []int) { //nolint:funlen
|
||||||
rec = make([]any, len(row))
|
rec = make([]any, len(row))
|
||||||
|
|
||||||
// For convenience, make a map of the skip row indices.
|
// For convenience, make a map of the skip row indices.
|
||||||
|
@ -188,7 +188,7 @@ func CopyRecord(rec sqlz.Record) sqlz.Record {
|
|||||||
// KindScanType returns the default scan type for kind. The returned
|
// KindScanType returns the default scan type for kind. The returned
|
||||||
// type is typically a sql.NullType.
|
// type is typically a sql.NullType.
|
||||||
func KindScanType(knd kind.Kind) reflect.Type {
|
func KindScanType(knd kind.Kind) reflect.Type {
|
||||||
switch knd {
|
switch knd { //nolint:exhaustive
|
||||||
default:
|
default:
|
||||||
return sqlz.RTypeNullString
|
return sqlz.RTypeNullString
|
||||||
|
|
||||||
|
@ -203,7 +203,7 @@ func (h *Helper) Source(handle string) *source.Source {
|
|||||||
require.NoError(t, err,
|
require.NoError(t, err,
|
||||||
"source %s was not found in %s", handle, testsrc.PathSrcsConfig)
|
"source %s was not found in %s", handle, testsrc.PathSrcsConfig)
|
||||||
|
|
||||||
switch src.Type {
|
switch src.Type { //nolint:exhaustive
|
||||||
case sqlite3.Type:
|
case sqlite3.Type:
|
||||||
// This could be easily generalized for CSV/XLSX etc.
|
// This could be easily generalized for CSV/XLSX etc.
|
||||||
fpath, err := sqlite3.PathFromLocation(src)
|
fpath, err := sqlite3.PathFromLocation(src)
|
||||||
|
Loading…
Reference in New Issue
Block a user