// Source path: sq/drivers/csv/csv.go

// Package csv implements the sq driver for CSV/TSV et al.
package csv
import (
"context"
"database/sql"
2020-08-06 20:58:47 +03:00
"encoding/csv"
"errors"
2020-08-06 20:58:47 +03:00
"io"
"strconv"
"github.com/neilotoole/lg"
"github.com/neilotoole/sq/cli/output/csvw"
"github.com/neilotoole/sq/libsq/core/errz"
2020-08-06 20:58:47 +03:00
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/source"
)
// Driver types served by this package.
const (
	// TypeCSV is the driver type for Comma-Separated Values sources.
	TypeCSV source.Type = "csv"

	// TypeTSV is the driver type for Tab-Separated Values sources.
	TypeTSV source.Type = "tsv"
)
// Provider implements driver.Provider. Its fields are handed to
// every driver instance created by DriverFor.
type Provider struct {
	// Log is the logger given to each created driver.
	Log lg.Log

	// Scratcher is given to each created driver, which uses it to
	// open the scratch database that backs a source.
	Scratcher driver.ScratchDatabaseOpener

	// Files is given to each created driver, which uses it to open
	// source files.
	Files *source.Files
}
// DriverFor implements driver.Provider.
func (d *Provider) DriverFor(typ source.Type) (driver.Driver, error) {
switch typ { //nolint:exhaustive
2020-08-06 20:58:47 +03:00
case TypeCSV:
return &driveri{log: d.Log, typ: TypeCSV, scratcher: d.Scratcher, files: d.Files}, nil
2020-08-06 20:58:47 +03:00
case TypeTSV:
return &driveri{log: d.Log, typ: TypeTSV, scratcher: d.Scratcher, files: d.Files}, nil
2020-08-06 20:58:47 +03:00
}
return nil, errz.Errorf("unsupported driver type %q", typ)
}
// Driver implements driver.Driver.
type driveri struct {
2020-08-06 20:58:47 +03:00
log lg.Log
typ source.Type
scratcher driver.ScratchDatabaseOpener
files *source.Files
}
// DriverMetadata implements driver.Driver.
func (d *driveri) DriverMetadata() driver.Metadata {
2020-08-06 20:58:47 +03:00
md := driver.Metadata{Type: d.typ, Monotable: true}
if d.typ == TypeCSV {
md.Description = "Comma-Separated Values"
md.Doc = "https://en.wikipedia.org/wiki/Comma-separated_values"
} else {
md.Description = "Tab-Separated Values"
md.Doc = "https://en.wikipedia.org/wiki/Tab-separated_values"
}
return md
}
// Open implements driver.Driver.
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Database, error) {
dbase := &database{
log: d.log,
src: src,
files: d.files,
2020-08-06 20:58:47 +03:00
}
var err error
2020-08-06 20:58:47 +03:00
dbase.impl, err = d.scratcher.OpenScratch(ctx, src.Handle)
if err != nil {
return nil, err
}
err = importCSV(ctx, d.log, src, d.files.OpenFunc(src), dbase.impl)
2020-08-06 20:58:47 +03:00
if err != nil {
return nil, err
}
return dbase, nil
}
// Truncate implements driver.Driver.
func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (int64, error) {
2020-08-06 20:58:47 +03:00
// TODO: CSV could support Truncate for local files
return 0, errz.Errorf("truncate not supported for %s", d.DriverMetadata().Type)
}
// ValidateSource implements driver.Driver.
func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) {
2020-08-06 20:58:47 +03:00
if src.Type != d.typ {
return nil, errz.Errorf("expected source type %q but got %q", d.typ, src.Type)
}
if src.Options != nil || len(src.Options) > 0 {
d.log.Debugf("opts: %v", src.Options.Encode())
key := "header"
v := src.Options.Get(key)
if v != "" {
_, err := strconv.ParseBool(v)
if err != nil {
return nil, errz.Errorf(`unable to parse option %q: %v`, key, err)
}
}
}
return src, nil
}
// Ping implements driver.Driver.
func (d *driveri) Ping(ctx context.Context, src *source.Source) error {
2020-08-06 20:58:47 +03:00
d.log.Debugf("driver %q attempting to ping %q", d.typ, src)
r, err := d.files.Open(src)
2020-08-06 20:58:47 +03:00
if err != nil {
return err
}
defer d.log.WarnIfCloseError(r)
return nil
}
// database implements driver.Database. Most operations are
// delegated to impl, while source-level details come from src.
type database struct {
	// log receives debug output.
	log lg.Log

	// src is the source this database was opened for.
	src *source.Source

	// impl is the underlying database that calls are delegated to.
	impl driver.Database

	// files is used to look up the size of the source's file.
	files *source.Files
}
// DB implements driver.Database. It returns the *sql.DB of the
// underlying database.
func (d *database) DB() *sql.DB {
	return d.impl.DB()
}
// SQLDriver implements driver.Database. It returns the SQL driver
// of the underlying database.
func (d *database) SQLDriver() driver.SQLDriver {
	return d.impl.SQLDriver()
}
// Source implements driver.Database. It returns the source that
// this database was opened for.
func (d *database) Source() *source.Source {
	return d.src
}
// TableMetadata implements driver.Database.
func (d *database) TableMetadata(ctx context.Context, tblName string) (*source.TableMetadata, error) {
if tblName != source.MonotableName {
return nil, errz.Errorf("table name should be %s for CSV/TSV etc., but got: %s",
source.MonotableName, tblName)
}
srcMeta, err := d.SourceMetadata(ctx)
2020-08-06 20:58:47 +03:00
if err != nil {
return nil, err
2020-08-06 20:58:47 +03:00
}
// There will only ever be one table for CSV.
return srcMeta.Tables[0], nil
}
// SourceMetadata implements driver.Database. It takes the metadata
// reported by the underlying database, then overwrites the handle,
// location, type, name, size and fully-qualified name with values
// derived from the source.
func (d *database) SourceMetadata(ctx context.Context) (*source.Metadata, error) {
	md, err := d.impl.SourceMetadata(ctx)
	if err != nil {
		return nil, err
	}

	src := d.src
	md.Handle, md.Location, md.SourceType = src.Handle, src.Location, src.Type

	if md.Name, err = source.LocationFileName(src); err != nil {
		return nil, err
	}

	if md.Size, err = d.files.Size(src); err != nil {
		return nil, err
	}

	// The fully-qualified name is the same as the name.
	md.FQName = md.Name
	return md, nil
}
// Close implements driver.Database. It closes the underlying
// database, passing any resulting error through errz.Err.
func (d *database) Close() error {
	d.log.Debugf("Close database: %s", d.src)

	closeErr := d.impl.Close()
	return errz.Err(closeErr)
}
// Compile-time checks that the detector functions have the
// source.TypeDetectFunc signature.
var (
	_ source.TypeDetectFunc = DetectCSV
	_ source.TypeDetectFunc = DetectTSV
)
// DetectCSV implements source.TypeDetectFunc. It delegates to
// detectType, asking whether the data from openFn looks like
// TypeCSV.
func DetectCSV(
	ctx context.Context, log lg.Log, openFn source.FileOpenFunc,
) (detected source.Type, score float32, err error) {
	return detectType(ctx, TypeCSV, log, openFn)
}
// DetectTSV implements source.TypeDetectFunc. It delegates to
// detectType, asking whether the data from openFn looks like
// TypeTSV.
func DetectTSV(
	ctx context.Context, log lg.Log, openFn source.FileOpenFunc,
) (detected source.Type, score float32, err error) {
	return detectType(ctx, TypeTSV, log, openFn)
}
func detectType(ctx context.Context, typ source.Type, log lg.Log, openFn source.FileOpenFunc) (detected source.Type,
score float32, err error,
) {
var r io.ReadCloser
r, err = openFn()
if err != nil {
return source.TypeNone, 0, errz.Err(err)
}
defer log.WarnIfCloseError(r)
delim := csvw.Comma
if typ == TypeTSV {
delim = csvw.Tab
}
cr := csv.NewReader(&crFilterReader{r: r})
cr.Comma = delim
cr.FieldsPerRecord = -1
score = isCSV(ctx, cr)
if score > 0 {
return typ, score, nil
}
return source.TypeNone, 0, nil
2020-08-06 20:58:47 +03:00
}
// Confidence scores returned by isCSV.
const (
	scoreNo       float32 = 0
	scoreMaybe    float32 = 0.1
	scoreProbably float32 = 0.2

	// scoreYes is less than 1.0 because other detectors
	// (e.g. XLSX) can be more confident.
	scoreYes float32 = 0.9
)

// isCSV returns a score indicating the confidence that cr is
// reading legitimate CSV, where a score <= 0 is not CSV, and a
// score >= 1 is definitely CSV.
//
// At most 100 records are read from cr. The score is scoreNo if
// ctx is done, if no records are read, or if any read fails with
// an error other than end-of-data. Otherwise the score rises with
// the average number of fields per record, and is capped at
// scoreProbably when only a single record was read.
func isCSV(ctx context.Context, cr *csv.Reader) (score float32) {
	const maxRecords = 100

	var recordCount, totalFieldCount int

	for i := 0; i < maxRecords; i++ {
		if ctx.Err() != nil {
			return scoreNo
		}

		rec, err := cr.Read()
		if err != nil {
			if errors.Is(err, io.EOF) && rec == nil {
				// This means end of data
				break
			}

			// It's a genuine error
			return scoreNo
		}

		totalFieldCount += len(rec)
		recordCount++
	}

	if recordCount == 0 {
		return scoreNo
	}

	avgFields := float32(totalFieldCount) / float32(recordCount)

	if recordCount == 1 {
		// A lone record is weak evidence, however many fields
		// it has.
		if avgFields <= 2 {
			return scoreMaybe
		}
		return scoreProbably
	}

	// recordCount >= 2
	switch {
	case avgFields <= 1:
		return scoreMaybe
	case avgFields <= 2:
		return scoreProbably
	default:
		// avgFields > 2
		return scoreYes
	}
}