#244: shell completion for "sq add LOCATION" (#246)

- Shell completion for `sq add LOCATION`.
Neil O'Toole 2023-06-13 10:06:18 -06:00 committed by GitHub
parent 3ecdde5595
commit 9cb42bf579
38 changed files with 3558 additions and 113 deletions


@ -9,9 +9,9 @@ on:
workflow_dispatch:
env:
GO_VERSION: 1.20.2
GO_VERSION: 1.20.5
GORELEASER_VERSION: 1.13.1
GOLANGCI_LINT_VERSION: v1.52.2
GOLANGCI_LINT_VERSION: v1.53.2
jobs:


@ -10,9 +10,10 @@
// usage pattern by eliminating all pkg-level constructs
// (which makes testing easier).
//
// All interaction with cobra should happen inside this package.
// All interaction with cobra should happen inside this package, or
// via the utility cli/cobraz package.
// That is to say, the spf13/cobra package should not be imported
// anywhere outside this package.
// anywhere outside this package and cli/cobraz.
//
// The entry point to this pkg is the Execute function.
package cli
@ -103,7 +104,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error {
// now handles this situation?
// We need to perform handling for autocomplete
if len(args) > 0 && args[0] == "__complete" {
if len(args) > 0 && args[0] == cobra.ShellCompRequestCmd {
if hasMatchingChildCommand(rootCmd, args[1]) {
// If there is a matching child command, we let rootCmd
// handle it, as per normal.
@ -112,7 +113,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error {
// There's no command matching the first argument to __complete.
// Therefore, we assume that we want to perform completion
// for the "slq" command (which is the pseudo-root command).
effectiveArgs := append([]string{"__complete", "slq"}, args[1:]...)
effectiveArgs := append([]string{cobra.ShellCompRequestCmd, "slq"}, args[1:]...)
rootCmd.SetArgs(effectiveArgs)
}
} else {

cli/cli_export_test.go (new file, +34)

@ -0,0 +1,34 @@
package cli
// This file exports package constructs for testing.
import (
"testing"
"github.com/neilotoole/sq/cli/run"
)
type PlocStage = plocStage
const (
PlocInit = plocInit
PlocScheme = plocScheme
PlocUser = plocUser
PlocPass = plocPass
PlocHostname = plocHostname
PlocHost = plocHost
PlocPath = plocPath
)
var DoCompleteAddLocationFile = locCompListFiles
// DoTestParseLocStage is a helper for testing the
// non-exported locCompParseLoc function.
func DoTestParseLocStage(t testing.TB, ru *run.Run, loc string) (PlocStage, error) { //nolint:revive
ploc, err := locCompParseLoc(loc)
if err != nil {
return PlocInit, err
}
return ploc.stageDone, nil
}


@ -26,9 +26,10 @@ import (
func newSrcAddCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "add [--handle @HANDLE] LOCATION",
RunE: execSrcAdd,
Args: cobra.ExactArgs(1),
Use: "add [--handle @HANDLE] LOCATION",
RunE: execSrcAdd,
Args: cobra.ExactArgs(1),
ValidArgsFunction: completeAddLocation,
Example: `
When adding a data source, LOCATION is the only required arg.
@ -39,21 +40,30 @@ Note that sq generated the handle "@actor". But you can explicitly specify
a handle.
# Add a postgres source with handle "@sakila/pg"
$ sq add --handle @sakila/pg 'postgres://user:pass@localhost/sakila'
$ sq add --handle @sakila/pg postgres://user:pass@localhost/sakila
This handle format "@sakila/pg" includes a group, "sakila". Using a group
is entirely optional: it is a way to organize sources. For example:
$ sq add --handle @dev/pg 'postgres://user:pass@dev.db.example.com/sakila'
$ sq add --handle @prod/pg 'postgres://user:pass@prod.db.acme.com/sakila'
$ sq add --handle @dev/pg postgres://user:pass@dev.db.acme.com/sakila
$ sq add --handle @prod/pg postgres://user:pass@prod.db.acme.com/sakila
The format of LOCATION is driver-specific, but is generally a DB connection
string, a file path, or a URL.
DRIVER://USER:PASS@HOST:PORT/DBNAME
DRIVER://USER:PASS@HOST:PORT/DBNAME?PARAM=VAL
/path/to/local/file.ext
https://sq.io/data/test1.xlsx
If LOCATION contains special shell characters, it's necessary to enclose
it in single quotes, or to escape the special character. For example,
note the "\?" in the unquoted location below.
$ sq add postgres://user:pass@localhost/sakila\?sslmode=disable
A significant advantage of not quoting LOCATION is that sq provides extensive
shell completion when inputting the location value.
If flag --handle is omitted, sq will generate a handle based
on LOCATION and the source driver type.
@ -61,7 +71,7 @@ It's a security hazard to expose the data source password via
the LOCATION string. If flag --password (-p) is set, sq prompts the
user for the password:
$ sq add 'postgres://user@localhost/sakila' -p
$ sq add postgres://user@localhost/sakila -p
Password: ****
However, if there's input on stdin, sq will read the password from
@ -69,18 +79,18 @@ there instead of prompting the user:
# Add a source, but read password from an environment variable
$ export PASSWD='open:;"_Ses@me'
$ sq add 'postgres://user@localhost/sakila' -p <<< $PASSWD
$ sq add postgres://user@localhost/sakila -p <<< $PASSWD
# Same as above, but instead read password from file
$ echo 'open:;"_Ses@me' > password.txt
$ sq add 'postgres://user@localhost/sakila' -p < password.txt
$ sq add postgres://user@localhost/sakila -p < password.txt
There are various driver-specific options available. For example:
$ sq add actor.csv --ingest.header=false --driver.csv.delim=colon
If flag --driver is omitted, sq will attempt to determine the
type from LOCATION via file suffix, content type, etc.. If the result
type from LOCATION via file suffix, content type, etc. If the result
is ambiguous, explicitly specify the driver type.
$ sq add --driver=tsv ./mystery.data
@ -106,19 +116,25 @@ use flag --active to make the new source active.
More examples:
# Add a source, but prompt user for password
$ sq add 'postgres://user@localhost/sakila' -p
$ sq add postgres://user@localhost/sakila -p
Password: ****
# Explicitly set flags
$ sq add --handle @sakila_pg --driver postgres 'postgres://user:pass@localhost/sakila'
$ sq add --handle @sakila_pg --driver postgres postgres://user:pass@localhost/sakila
# Same as above, but with short flags
$ sq add -n @sakila_pg -d postgres 'postgres://user:pass@localhost/sakila'
$ sq add -n @sakila_pg -d postgres postgres://user:pass@localhost/sakila
# Specify some params (note escaped chars)
$ sq add postgres://user:pass@localhost/sakila\?sslmode=disable\&application_name=sq
# Specify some params, but use quoted string (no shell completion)
$ sq add 'postgres://user:pass@localhost/sakila?sslmode=disable&application_name=sq'
# Add a SQL Server source; will have generated handle @sakila
$ sq add 'sqlserver://user:pass@localhost?database=sakila'
# Add a sqlite db, and immediately make it the active source
# Add a SQLite DB, and immediately make it the active source
$ sq add ./testdata/sqlite1.db --active
# Add an Excel spreadsheet, with options
@ -134,7 +150,7 @@ More examples:
$ sq add ./actor.csv --handle @csv/actor
# Add a currently unreachable source
$ sq add 'postgres://user:pass@db.offline.com/sakila' --skip-verify`,
$ sq add postgres://user:pass@db.offline.com/sakila --skip-verify`,
Short: "Add data source",
Long: `Add data source specified by LOCATION, optionally identified by @HANDLE.`,
}

cli/cobraz/cobraz.go (new file, +113)

@ -0,0 +1,113 @@
// Package cobraz contains supplemental logic for dealing with spf13/cobra.
package cobraz
import (
"fmt"
"strings"
"github.com/spf13/cobra"
)
// Defines the text values for cobra.ShellCompDirective.
const (
ShellCompDirectiveErrorText = "ShellCompDirectiveError"
ShellCompDirectiveNoSpaceText = "ShellCompDirectiveNoSpace"
ShellCompDirectiveNoFileCompText = "ShellCompDirectiveNoFileComp"
ShellCompDirectiveFilterFileExtText = "ShellCompDirectiveFilterFileExt"
ShellCompDirectiveFilterDirsText = "ShellCompDirectiveFilterDirs"
ShellCompDirectiveKeepOrderText = "ShellCompDirectiveKeepOrder"
ShellCompDirectiveDefaultText = "ShellCompDirectiveDefault"
ShellCompDirectiveUnknownText = "ShellCompDirectiveUnknown"
)
// ParseDirectivesLine parses the line of text returned by "__complete" cmd
// that contains the text description of the result.
// The line looks like:
//
// Completion ended with directive: ShellCompDirectiveNoSpace, ShellCompDirectiveKeepOrder
//
// Note that this function will panic on an unknown directive.
func ParseDirectivesLine(directivesLine string) []cobra.ShellCompDirective {
trimmedLine := strings.TrimPrefix(strings.TrimSpace(directivesLine), "Completion ended with directive: ")
parts := strings.Split(trimmedLine, ", ")
directives := make([]cobra.ShellCompDirective, 0, len(parts))
for _, part := range parts {
switch part {
case ShellCompDirectiveErrorText:
directives = append(directives, cobra.ShellCompDirectiveError)
case ShellCompDirectiveNoSpaceText:
directives = append(directives, cobra.ShellCompDirectiveNoSpace)
case ShellCompDirectiveNoFileCompText:
directives = append(directives, cobra.ShellCompDirectiveNoFileComp)
case ShellCompDirectiveFilterFileExtText:
directives = append(directives, cobra.ShellCompDirectiveFilterFileExt)
case ShellCompDirectiveFilterDirsText:
directives = append(directives, cobra.ShellCompDirectiveFilterDirs)
case ShellCompDirectiveKeepOrderText:
directives = append(directives, cobra.ShellCompDirectiveKeepOrder)
case ShellCompDirectiveDefaultText:
directives = append(directives, cobra.ShellCompDirectiveDefault)
default:
panic(fmt.Sprintf("Unknown cobra.ShellCompDirective {%s} in: %s", part, directivesLine))
}
}
return directives
}
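// exampleParseDirectivesLine is an illustrative sketch, not part of this commit: it decodes
// the directive-description line emitted by cobra's hidden "__complete" command.
func exampleParseDirectivesLine() []cobra.ShellCompDirective {
    line := "Completion ended with directive: ShellCompDirectiveNoSpace, ShellCompDirectiveKeepOrder"
    return ParseDirectivesLine(line)
    // -> [cobra.ShellCompDirectiveNoSpace, cobra.ShellCompDirectiveKeepOrder]
}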
// ExtractDirectives extracts the individual directives
// from a combined directive.
func ExtractDirectives(result cobra.ShellCompDirective) []cobra.ShellCompDirective {
if result == cobra.ShellCompDirectiveDefault {
return []cobra.ShellCompDirective{cobra.ShellCompDirectiveDefault}
}
var a []cobra.ShellCompDirective
allDirectives := []cobra.ShellCompDirective{
cobra.ShellCompDirectiveError,
cobra.ShellCompDirectiveNoSpace,
cobra.ShellCompDirectiveNoFileComp,
cobra.ShellCompDirectiveFilterFileExt,
cobra.ShellCompDirectiveFilterDirs,
cobra.ShellCompDirectiveKeepOrder,
cobra.ShellCompDirectiveDefault,
}
for _, directive := range allDirectives {
if directive&result > 0 {
a = append(a, directive)
}
}
return a
}
// MarshalDirective marshals a cobra.ShellCompDirective to text strings,
// after extracting the embedded directives.
func MarshalDirective(directive cobra.ShellCompDirective) []string {
gotDirectives := ExtractDirectives(directive)
s := make([]string, len(gotDirectives))
for i, d := range gotDirectives {
switch d {
case cobra.ShellCompDirectiveError:
s[i] = ShellCompDirectiveErrorText
case cobra.ShellCompDirectiveNoSpace:
s[i] = ShellCompDirectiveNoSpaceText
case cobra.ShellCompDirectiveNoFileComp:
s[i] = ShellCompDirectiveNoFileCompText
case cobra.ShellCompDirectiveFilterFileExt:
s[i] = ShellCompDirectiveFilterFileExtText
case cobra.ShellCompDirectiveFilterDirs:
s[i] = ShellCompDirectiveFilterDirsText
case cobra.ShellCompDirectiveKeepOrder:
s[i] = ShellCompDirectiveKeepOrderText
case cobra.ShellCompDirectiveDefault:
s[i] = ShellCompDirectiveDefaultText
default:
// Should never happen
s[i] = ShellCompDirectiveUnknownText
}
}
return s
}

cli/cobraz/cobraz_test.go (new file, +42)

@ -0,0 +1,42 @@
package cobraz
import (
"testing"
"github.com/neilotoole/sq/testh/tutil"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
func TestExtractDirectives(t *testing.T) {
testCases := []struct {
in cobra.ShellCompDirective
want []cobra.ShellCompDirective
wantStrings []string
}{
{
cobra.ShellCompDirectiveError,
[]cobra.ShellCompDirective{cobra.ShellCompDirectiveError},
[]string{ShellCompDirectiveErrorText},
},
{
cobra.ShellCompDirectiveError | cobra.ShellCompDirectiveNoSpace,
[]cobra.ShellCompDirective{cobra.ShellCompDirectiveError, cobra.ShellCompDirectiveNoSpace},
[]string{ShellCompDirectiveErrorText, ShellCompDirectiveNoSpaceText},
},
{
cobra.ShellCompDirectiveDefault,
[]cobra.ShellCompDirective{cobra.ShellCompDirectiveDefault},
[]string{ShellCompDirectiveDefaultText},
},
}
for i, tc := range testCases {
t.Run(tutil.Name(i, tc.in), func(t *testing.T) {
gotDirectives := ExtractDirectives(tc.in)
require.Equal(t, tc.want, gotDirectives)
gotStrings := MarshalDirective(tc.in)
require.Equal(t, tc.wantStrings, gotStrings)
})
}
}

cli/complete_location.go (new file, +970)

@ -0,0 +1,970 @@
package cli
import (
"context"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/neilotoole/sq/libsq/core/urlz"
"golang.org/x/exp/slog"
"github.com/neilotoole/sq/libsq/core/ioz"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/cli/run"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/driver"
"golang.org/x/exp/slices"
"github.com/neilotoole/sq/drivers/mysql"
"github.com/neilotoole/sq/drivers/postgres"
"github.com/neilotoole/sq/drivers/sqlite3"
"github.com/neilotoole/sq/drivers/sqlserver"
"github.com/xo/dburl"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/source"
"github.com/samber/lo"
"github.com/spf13/cobra"
)
// locCompStdDirective is the standard cobra shell completion directive
// returned by completeAddLocation.
const locCompStdDirective = cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder
// completeAddLocation provides completion for the "sq add LOCATION" arg.
// This is a messy task, as LOCATION can be a database driver URL,
// and it can also be a filepath. To complicate matters further, sqlite
// has a format sqlite://FILE/PATH?param=val, which is a driver URL, with
// embedded file completion.
//
// The general strategy is:
// - Does toComplete have a driver prefix ("postgres://", "sqlite3://" etc)?
// If so, delegate to the appropriate function.
// - Is toComplete definitively NOT a driver URL? That is to say, is toComplete
// a file path? If so, then we need regular shell file completion.
// Return cobra.ShellCompDirectiveDefault, and let the shell handle it.
// - There's a messy overlap where toComplete could be either a driver URL
// or a filepath. For example, "post" could be leading to "postgres://", or
// to a file named "post.db". For this situation, it is necessary to
// mimic in code the behavior of the shell's file completion.
// - There's another layer of complexity: previous locations (i.e. "history")
// are also suggested.
//
// The code, as currently structured, is ungainly, and downright ugly in
// spots, and probably won't scale well if more drivers are supported. That
is to say, this mechanism would benefit from a thorough refactor.
func completeAddLocation(cmd *cobra.Command, args []string, toComplete string) (
[]string, cobra.ShellCompDirective,
) {
if len(args) > 0 {
return nil, cobra.ShellCompDirectiveError
}
if strings.HasPrefix(toComplete, "/") || strings.HasPrefix(toComplete, ".") {
// This has to be a file path.
// Go straight to default (file) completion.
return nil, cobra.ShellCompDirectiveDefault
}
var a []string
if toComplete == "" {
// No input yet. Offer both the driver URL schemes and file listing.
a = append(a, locSchemes...)
files := locCompListFiles(cmd.Context(), toComplete)
if len(files) > 0 {
a = append(a, files...)
}
return a, locCompStdDirective
}
// We've got some input in toComplete...
if !stringz.HasAnyPrefix(toComplete, locSchemes...) {
// But toComplete isn't a full match for any of the driver
// URL schemes. However, it could still be a partial match.
a = stringz.FilterPrefix(toComplete, locSchemes...)
if len(a) == 0 {
// We're not matching any URL prefix, fall back to default
// shell completion, i.e. list files.
return nil, cobra.ShellCompDirectiveDefault
}
// Partial match, e.g. "post". So, this could match both
// a URL such as "postgres://", or a file such as "post.db".
files := locCompListFiles(cmd.Context(), toComplete)
if len(files) > 0 {
a = append(a, files...)
}
return a, locCompStdDirective
}
// If we got this far, we know that toComplete starts with one of the
// driver schemes, e.g. "postgres://". There's no possibility that
// this could be a file completion.
if strings.HasPrefix(toComplete, string(sqlite3.Type)) {
// Special handling for sqlite.
return locCompDoSQLite3(cmd, args, toComplete)
}
return locCompDoGenericDriver(cmd, args, toComplete)
}
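// Illustrative sketch, not part of this commit: completion is ultimately exercised through
// cobra's hidden "__complete" command. With cli/testdata/add_location as the working dir,
// a session might look roughly like this (":34" encodes
// ShellCompDirectiveNoSpace|ShellCompDirectiveKeepOrder; actual candidates depend on the
// working dir and on source history):
//
//	$ sq __complete add post
//	postgres://
//	post.db
//	:34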
// locCompDoGenericDriver provides completion for generic SQL drivers.
// Specifically, it's tested with postgres, sqlserver, and mysql. Note that
// sqlserver is slightly different from the others, in that the db name goes
// in a query param, not in the URL path. It might be cleaner to split sqlserver
// off into its own function.
func locCompDoGenericDriver(cmd *cobra.Command, _ []string, toComplete string, //nolint:funlen,gocognit
) ([]string, cobra.ShellCompDirective) {
// If we get this far, then toComplete is at least a partial URL
// starting with "postgres://", "mysql://", etc.
var (
ctx = cmd.Context()
log = lg.FromContext(ctx)
ru = run.FromContext(ctx)
drvr driver.SQLDriver
ploc *parsedLoc
a []string // a holds the completion strings to be returned
err error
)
if err = FinishRunInit(ctx, ru); err != nil {
log.Error("Init run", lga.Err, err)
return nil, cobra.ShellCompDirectiveError
}
if ploc, err = locCompParseLoc(toComplete); err != nil {
log.Error("Parse location", lga.Err, err)
return nil, cobra.ShellCompDirectiveError
}
if drvr, err = ru.DriverRegistry.SQLDriverFor(ploc.typ); err != nil {
log.Error("Load driver", lga.Err, err)
return nil, cobra.ShellCompDirectiveError
}
hist := &locHistory{
coll: ru.Config.Collection,
typ: ploc.typ,
log: log,
}
switch ploc.stageDone { //nolint:exhaustive
case plocScheme:
unames := hist.usernames()
if ploc.user == "" {
a = stringz.PrefixSlice(unames, toComplete)
a = append(a, toComplete+"username")
a = lo.Uniq(a)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
}
// Else, we have at least a partial username
a = []string{
toComplete + "@",
toComplete + ":",
}
for _, uname := range unames {
v := string(ploc.typ) + "://" + uname
a = append(a, v+"@")
a = append(a, v+":")
}
a = lo.Uniq(a)
a = stringz.FilterPrefix(toComplete, a...)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
case plocUser:
if ploc.pass == "" {
a = []string{
toComplete,
toComplete + "@",
toComplete + "password@",
}
return a, locCompStdDirective
}
a = []string{
toComplete + "@",
}
return a, locCompStdDirective
case plocPass:
hosts := hist.hosts()
hostsWithPath := hist.hostsWithPathAndQuery()
defaultPort := locCompDriverPort(drvr)
afterHost := locCompAfterHost(ploc.typ)
if ploc.hostname == "" {
if defaultPort == "" {
a = []string{
toComplete + "localhost" + afterHost,
}
} else {
a = []string{
toComplete + "localhost" + afterHost,
toComplete + "localhost:" + defaultPort + afterHost,
}
}
var b []string
for _, h := range hostsWithPath {
v := toComplete + h
b = append(b, v)
}
for _, h := range hosts {
v := toComplete + h + afterHost
b = append(b, v)
}
slices.Sort(b)
a = append(a, b...)
a = lo.Uniq(a)
a = stringz.FilterPrefix(toComplete, a...)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
}
base, _, _ := strings.Cut(toComplete, "@")
base += "@"
if ploc.port <= 0 {
if defaultPort == "" {
a = []string{
toComplete + afterHost,
base + "localhost" + afterHost,
}
} else {
a = []string{
toComplete + afterHost,
toComplete + ":" + defaultPort + afterHost,
base + "localhost" + afterHost,
base + "localhost:" + defaultPort + afterHost,
}
}
var b []string
for _, h := range hostsWithPath {
v := base + h
b = append(b, v)
}
for _, h := range hosts {
v := base + h + afterHost
b = append(b, v)
}
slices.Sort(b)
a = append(a, b...)
a = lo.Uniq(a)
a = stringz.FilterPrefix(toComplete, a...)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
}
if defaultPort == "" {
a = []string{
base + "localhost" + afterHost,
toComplete + afterHost,
}
} else {
a = []string{
base + "localhost" + afterHost,
base + "localhost:" + defaultPort + afterHost,
toComplete + afterHost,
}
}
var b []string
for _, h := range hostsWithPath {
v := base + h
b = append(b, v)
}
for _, h := range hosts {
v := base + h + afterHost
b = append(b, v)
}
slices.Sort(b)
a = append(a, b...)
a = lo.Uniq(a)
a = stringz.FilterPrefix(toComplete, a...)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
case plocHostname:
defaultPort := locCompDriverPort(drvr)
afterHost := locCompAfterHost(ploc.typ)
if strings.HasSuffix(toComplete, ":") {
a = []string{toComplete + defaultPort + afterHost}
return a, locCompStdDirective
}
a = []string{toComplete + afterHost}
return a, locCompStdDirective
case plocHost:
dbNames := hist.databases()
// Special handling for SQLServer. The input is typically of the form:
// sqlserver://alice@server?database=db
// But it can also be of the form:
// sqlserver://alice@server/instance?database=db
if ploc.typ == sqlserver.Type {
if ploc.du.Path == "/" {
a = []string{toComplete + "instance?database="}
return a, locCompStdDirective
}
a = []string{toComplete + "?database="}
return a, locCompStdDirective
}
if ploc.name == "" {
a = []string{toComplete + "db"}
for _, dbName := range dbNames {
v := toComplete + dbName
a = append(a, v)
}
a = lo.Uniq(a)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
}
// We already have a partial dbname
a = []string{toComplete + "?"}
base := urlz.StripPath(ploc.du.URL, true)
for _, dbName := range dbNames {
a = append(a, base+"/"+dbName)
}
a = lo.Uniq(a)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
default:
// We're at plocName (db name is done), so it's on to conn params.
return locCompDoConnParams(ploc.du, hist, drvr, toComplete)
}
}
// locCompDoSQLite3 completes a location starting with "sqlite3://".
// We have special handling for SQLite, because it's not a generic
// driver URL, but rather sqlite3://FILE/PATH?param=X.
func locCompDoSQLite3(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
var (
ctx = cmd.Context()
log = lg.FromContext(ctx)
ru = run.FromContext(ctx)
drvr driver.SQLDriver
err error
)
if err = FinishRunInit(ctx, ru); err != nil {
log.Error("Init run", lga.Err, err)
return nil, cobra.ShellCompDirectiveError
}
if drvr, err = ru.DriverRegistry.SQLDriverFor(sqlite3.Type); err != nil {
// Shouldn't happen
log.Error("Cannot load driver", lga.Err, err)
return nil, cobra.ShellCompDirectiveError
}
hist := &locHistory{
coll: ru.Config.Collection,
typ: sqlite3.Type,
log: log,
}
du, err := dburl.Parse(toComplete)
if err == nil {
// Check if we're done with the filepath part, and on to conn params?
if du.URL.RawQuery != "" || strings.HasSuffix(toComplete, "?") {
return locCompDoConnParams(du, hist, drvr, toComplete)
}
}
// Build a list of files.
start := strings.TrimPrefix(toComplete, "sqlite3://")
paths := locCompListFiles(ctx, start)
for i := range paths {
if ioz.IsPathToRegularFile(paths[i]) && paths[i] == start {
paths[i] += "?"
}
paths[i] = "sqlite3://" + paths[i]
}
a := hist.locations()
a = append(a, paths...)
a = lo.Uniq(a)
a = stringz.FilterPrefix(toComplete, a...)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
}
// locCompDoConnParams completes the query params. For example, given
// a toComplete value "sqlite3://my.db?", the result would include values
// such as "sqlite3://my.db?cache=".
func locCompDoConnParams(du *dburl.URL, hist *locHistory, drvr driver.SQLDriver, toComplete string) (
[]string, cobra.ShellCompDirective,
) {
var (
a []string
query = du.RawQuery
drvrParamKeys, drvrParams = locCompGetConnParams(drvr)
)
pathsWithQueries := hist.pathsWithQueries()
base := urlz.StripPath(du.URL, true)
if query == "" {
a = stringz.PrefixSlice(pathsWithQueries, base)
v := stringz.PrefixSlice(drvrParamKeys, toComplete)
v = stringz.SuffixSlice(v, "=")
a = append(a, v...)
a = lo.Uniq(a)
a = stringz.FilterPrefix(toComplete, a...)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
}
for _, pwq := range pathsWithQueries {
if strings.HasPrefix(pwq, du.Path) {
a = append(a, base+pwq)
}
}
a = stringz.FilterPrefix(toComplete, a...)
actualKeys, err := urlz.QueryParamKeys(query)
if err != nil || len(actualKeys) == 0 {
return nil, cobra.ShellCompDirectiveError
}
actualValues, err := url.ParseQuery(query)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
elements := strings.Split(query, "&")
// could be "sslmo", "sslmode", "sslmode=", "sslmode=dis"
lastElement := elements[len(elements)-1]
stump := strings.TrimSuffix(toComplete, lastElement)
before, _, ok := strings.Cut(lastElement, "=")
if !ok {
candidateKeys := stringz.ElementsHavingPrefix(drvrParamKeys, before)
candidateKeys = lo.Reject(candidateKeys, func(candidateKey string, index int) bool {
// We don't want the same candidate to show up twice, so we exclude
// it, but only if it already has a value in the query string.
if slices.Contains(actualKeys, candidateKey) {
vals, ok := actualValues[candidateKey]
if !ok || len(vals) == 0 {
return false
}
for _, val := range vals {
if val != "" {
return true
}
}
}
return false
})
for i := range candidateKeys {
s := stump + candidateKeys[i] + "="
a = append(a, s)
}
a = lo.Uniq(a)
a = lo.Without(a, toComplete)
return a, locCompStdDirective
}
candidateVals := drvrParams[before]
for i := range candidateVals {
s := stump + before + "=" + candidateVals[i]
a = append(a, s)
}
if len(candidateVals) == 0 {
lastChar := toComplete[len(toComplete)-1]
switch lastChar {
case '&', '?', '=':
default:
a = append(a, toComplete+"&")
}
}
a = lo.Uniq(a)
a = stringz.FilterPrefix(toComplete, a...)
if len(a) == 0 {
// If it's an unknown value, append "&" to move
// on to a further query param.
a = []string{toComplete + "&"}
return a, locCompStdDirective
}
if len(a) == 1 && a[0] == toComplete {
// If it's a completed known value ("sslmode=disable"),
// then append "&" to move on to a further query param.
a[0] += "&"
}
return a, locCompStdDirective
}
// locCompGetConnParams returns the driver's connection params. The returned
// keys are sorted appropriately for the driver, and are query encoded.
func locCompGetConnParams(drvr driver.SQLDriver) (keys []string, params map[string][]string) {
ogParams := drvr.ConnParams()
ogKeys := lo.Keys(ogParams)
slices.Sort(ogKeys)
if drvr.DriverMetadata().Type == sqlserver.Type {
// For SQLServer, the "database" key should come first, because
// it's required.
ogKeys = lo.Without(ogKeys, "database")
ogKeys = append([]string{"database"}, ogKeys...)
}
keys = make([]string, len(ogKeys))
params = make(map[string][]string, len(ogParams))
for i := range ogKeys {
k := url.QueryEscape(ogKeys[i])
keys[i] = k
params[k] = ogParams[ogKeys[i]]
}
return keys, params
}
// locCompDriverPort returns the default port for the driver, as a string,
// or empty string if not applicable.
func locCompDriverPort(drvr driver.SQLDriver) string {
p := drvr.DriverMetadata().DefaultPort
if p <= 0 {
return ""
}
return strconv.Itoa(p)
}
// locCompAfterHost returns the next text to show after the host
// part of the URL is complete.
func locCompAfterHost(typ source.DriverType) string {
if typ == sqlserver.Type {
return "?database="
}
return "/"
}
// locCompParseLoc parses a location string. The string can
// be in various stages of construction, e.g. "postgres://user" or
// "postgres://user@locahost/db". The stage is noted in parsedLoc.stageDone.
func locCompParseLoc(loc string) (*parsedLoc, error) {
p := &parsedLoc{loc: loc}
if !stringz.HasAnyPrefix(loc, locSchemes...) {
return p, nil
}
var (
s string
ok bool
err error
creds string
)
p.stageDone = plocScheme
p.scheme, s, ok = strings.Cut(loc, "://")
p.typ = source.DriverType(p.scheme)
if s == "" || !ok {
return p, nil
}
creds, s, ok = strings.Cut(s, "@")
if creds != "" {
// creds can be:
// user:pass
// user:
// user
var hasColon bool
p.user, p.pass, hasColon = strings.Cut(creds, ":")
if hasColon {
p.stageDone = plocUser
}
}
if !ok {
return p, nil
}
p.stageDone = plocPass
// At a minimum, we're at this point:
//  postgres://user@
// Next we're looking for the host, e.g.
//  postgres://alice:huzzah@localhost
if p.du, err = dburl.Parse(p.loc); err != nil {
return p, errz.Err(err)
}
du := p.du
p.scheme = du.OriginalScheme
if du.User != nil {
p.user = du.User.Username()
p.pass, _ = du.User.Password()
}
p.hostname = du.Hostname()
if strings.ContainsRune(du.URL.Host, ':') {
p.stageDone = plocHostname
}
if du.Port() != "" {
p.stageDone = plocHostname
p.port, err = strconv.Atoi(du.Port())
if err != nil {
p.port = -1
return p, nil //nolint:nilerr
}
}
switch p.typ { //nolint:exhaustive
default:
case sqlserver.Type:
var u *url.URL
if u, err = url.ParseRequestURI(loc); err == nil {
var vals url.Values
if vals, err = url.ParseQuery(u.RawQuery); err == nil {
p.name = vals.Get("database")
}
}
case postgres.Type, mysql.Type:
p.name = strings.TrimPrefix(du.Path, "/")
}
if strings.HasSuffix(s, "/") || strings.HasSuffix(s, `\?`) || du.URL.Path != "" {
p.stageDone = plocHost
}
if strings.HasSuffix(s, "?") {
p.stageDone = plocPath
}
if du.URL.RawQuery != "" {
p.stageDone = plocPath
}
return p, nil
}
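// Illustrative sketch, not part of this commit: per the parsing logic above, stageDone
// advances roughly as follows while the user types a postgres location:
//
//	"postgres://"                         -> plocScheme
//	"postgres://alice:"                   -> plocUser
//	"postgres://alice:pw@"                -> plocPass
//	"postgres://alice:pw@localhost:5432"  -> plocHostname
//	"postgres://alice:pw@localhost:5432/" -> plocHost
//	"postgres://alice:pw@localhost/db?"   -> plocPath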
// parsedLoc is a parsed representation of a driver location URL.
// It can represent partial or fully constructed locations. The stage
// of construction is noted in parsedLoc.stageDone.
type parsedLoc struct {
// loc is the original unparsed location value.
loc string
// stageDone indicates what stage of construction the location
// string is in.
stageDone plocStage
// typ is the associated source driver type, which may
// be empty until later determination.
typ source.DriverType
// scheme is the original location scheme
scheme string
// user is the username, if applicable.
user string
// pass is the password, if applicable.
pass string
// hostname is the hostname, if applicable.
hostname string
// port is the port number, or 0 if not applicable.
port int
// name is the database name.
name string
// du holds the parsed db url. This may be nil.
du *dburl.URL
}
// plocStage is an enum indicating what stage of construction
// a location string is in.
type plocStage string
const (
plocInit plocStage = ""
plocScheme plocStage = "scheme"
plocUser plocStage = "user"
plocPass plocStage = "pass"
plocHostname plocStage = "hostname"
plocHost plocStage = "host" // host is hostname+port, or just hostname
plocPath plocStage = "path"
)
// locSchemes is the set of built-in (SQL) driver schemes.
var locSchemes = []string{
"mysql://",
"postgres://",
"sqlite3://",
"sqlserver://",
}
// locCompListFiles completes filenames. This function tries to
// mimic what a shell would do. Any errors are logged and swallowed.
func locCompListFiles(ctx context.Context, toComplete string) []string {
var (
start = toComplete
files []string
err error
)
if start == "" {
start, err = os.Getwd()
if err != nil {
return nil
}
files, err = ioz.ReadDir(start, false, true, false)
if err != nil {
lg.FromContext(ctx).Warn("Read dir", lga.Path, start, lga.Err, err)
}
return files
}
if strings.HasSuffix(start, "/") {
files, err = ioz.ReadDir(start, true, true, false)
if err != nil {
lg.FromContext(ctx).Warn("Read dir", lga.Path, start, lga.Err, err)
}
return files
}
// We could have a situation like this:
// + [working dir]
// - my.db
// - my/my2.db
dir := filepath.Dir(start)
fi, err := os.Stat(dir)
if err == nil && fi.IsDir() {
files, err = ioz.ReadDir(dir, true, true, false)
if err != nil {
lg.FromContext(ctx).Warn("Read dir", lga.Path, start, lga.Err, err)
}
} else {
files = []string{start}
}
return stringz.FilterPrefix(toComplete, files...)
}
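// Illustrative sketch, not part of this commit: with cli/testdata/add_location as the
// working dir, locCompListFiles(ctx, "my") would return ["my/", "my.db"], mirroring what
// the shell's own file completion offers for the prefix "my".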
// locHistory provides methods for getting previously used
// elements of a location.
type locHistory struct {
coll *source.Collection
typ source.DriverType
log *slog.Logger
}
func (h *locHistory) usernames() []string {
var unames []string
_ = h.coll.Visit(func(src *source.Source) error {
if src.Type != h.typ {
return nil
}
du, err := dburl.Parse(src.Location)
if err != nil {
// Shouldn't happen
h.log.Warn("Parse source location", lga.Err, err)
return nil
}
if du.User != nil {
uname := du.User.Username()
if uname != "" {
unames = append(unames, uname)
}
}
return nil
})
unames = lo.Uniq(unames)
slices.Sort(unames)
return unames
}
func (h *locHistory) hosts() []string {
var hosts []string
_ = h.coll.Visit(func(src *source.Source) error {
if src.Type != h.typ {
return nil
}
du, err := dburl.Parse(src.Location)
if err != nil {
// Shouldn't happen
h.log.Warn("Parse source location", lga.Err, err)
return nil
}
hosts = append(hosts, du.Host)
return nil
})
hosts = lo.Uniq(hosts)
slices.Sort(hosts)
return hosts
}
func (h *locHistory) databases() []string {
var dbNames []string
_ = h.coll.Visit(func(src *source.Source) error {
if src.Type != h.typ {
return nil
}
du, err := dburl.Parse(src.Location)
if err != nil {
// Shouldn't happen
h.log.Warn("Parse source location", lga.Err, err)
return nil
}
if h.typ == sqlserver.Type && du.RawQuery != "" {
var vals url.Values
if vals, err = url.ParseQuery(du.RawQuery); err == nil {
db := vals.Get("database")
if db != "" {
dbNames = append(dbNames, db)
}
}
return nil
}
if du.Path != "" {
v := strings.TrimPrefix(du.Path, "/")
dbNames = append(dbNames, v)
}
return nil
})
dbNames = lo.Uniq(dbNames)
slices.Sort(dbNames)
return dbNames
}
func (h *locHistory) locations() []string {
var locs []string
_ = h.coll.Visit(func(src *source.Source) error {
if src.Type != h.typ {
return nil
}
locs = append(locs, src.Location)
return nil
})
locs = lo.Uniq(locs)
slices.Sort(locs)
return locs
}
// hostsWithPathAndQuery returns locations, minus the
// scheme and user info.
func (h *locHistory) hostsWithPathAndQuery() []string {
var values []string
_ = h.coll.Visit(func(src *source.Source) error {
if src.Type != h.typ {
return nil
}
du, err := dburl.Parse(src.Location)
if err != nil {
// Shouldn't happen
h.log.Warn("Parse source location", lga.Err, err)
return nil
}
v := urlz.StripSchemeAndUser(du.URL)
if v != "" {
values = append(values, v)
}
return nil
})
values = lo.Uniq(values)
slices.Sort(values)
return values
}
// pathsWithQueries returns the location elements after
// the host, i.e. the path and query.
func (h *locHistory) pathsWithQueries() []string {
var values []string
_ = h.coll.Visit(func(src *source.Source) error {
if src.Type != h.typ {
return nil
}
du, err := dburl.Parse(src.Location)
if err != nil {
// Shouldn't happen
h.log.Warn("Parse source location", lga.Err, err)
return nil
}
v := du.Path
if du.RawQuery != "" {
v += `?` + du.RawQuery
}
values = append(values, v)
return nil
})
values = lo.Uniq(values)
slices.Sort(values)
return values
}
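// Illustrative sketch, not part of this commit: given a collection that already contains
// a source at postgres://alice:pw@prod.db.example.com:5432/sakila, the history helpers
// would yield:
//
//	hist := &locHistory{coll: ru.Config.Collection, typ: postgres.Type, log: log}
//	hist.usernames() // ["alice"]
//	hist.hosts()     // ["prod.db.example.com:5432"]
//	hist.databases() // ["sakila"]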

File diff suppressed because it is too large


@ -226,7 +226,7 @@ func addOptionFlag(flags *pflag.FlagSet, opt options.Opt) (key string) {
}
flags.BoolP(key, string(opt.Short()), opt.Default(), opt.Usage())
return
return key
case options.Duration:
if opt.Short() == 0 {
flags.Duration(key, opt.Default(), opt.Usage())


@ -116,59 +116,18 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args
return ru, log, nil
}
// preRun is invoked by cobra prior to the command's RunE being
// invoked. It sets up the driver registry, databases, writers and related
// fundamental components. Subsequent invocations of this method
// are no-op.
func preRun(cmd *cobra.Command, ru *run.Run) error {
if ru == nil {
return errz.New("Run is nil")
}
if ru.Writers != nil {
// If ru.Writers is already set, then this function has already been
// called on ru. That's ok, just return.
return nil
}
// FinishRunInit finishes setting up ru.
//
// TODO: This run.Run initialization mechanism is a bit of a mess.
// There's logic in newRun, preRun, FinishRunInit, as well as testh.Helper.init.
// Surely the init logic can be consolidated.
func FinishRunInit(ctx context.Context, ru *run.Run) error {
if ru.Cleanup == nil {
ru.Cleanup = cleanup.New()
}
ctx := cmd.Context()
cfg, log := ru.Config, lg.FromContext(ctx)
// If the --output=/some/file flag is set, then we need to
// override ru.Out (which is typically stdout) to point it at
// the output destination file.
if cmdFlagChanged(ru.Cmd, flag.Output) {
fpath, _ := ru.Cmd.Flags().GetString(flag.Output)
fpath, err := filepath.Abs(fpath)
if err != nil {
return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output)
}
// Ensure the parent dir exists
err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm)
if err != nil {
return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output)
}
f, err := os.Create(fpath)
if err != nil {
return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output)
}
ru.Cleanup.AddC(f) // Make sure the file gets closed eventually
ru.Out = f
}
cmdOpts, err := getOptionsFromCmd(ru.Cmd)
if err != nil {
return err
}
ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, cmdOpts, ru.Out, ru.ErrOut)
var scratchSrcFunc driver.ScratchSrcFunc
// scratchSrc could be nil, and that's ok
@ -181,10 +140,13 @@ func preRun(cmd *cobra.Command, ru *run.Run) error {
}
}
ru.Files, err = source.NewFiles(ctx)
if err != nil {
lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run)
return err
var err error
if ru.Files == nil {
ru.Files, err = source.NewFiles(ctx)
if err != nil {
lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run)
return err
}
}
// Note: it's important that files.Close is invoked
@ -260,3 +222,57 @@ func preRun(cmd *cobra.Command, ru *run.Run) error {
return nil
}
// preRun is invoked by cobra prior to the command's RunE being
// invoked. It sets up the driver registry, databases, writers and related
// fundamental components. Subsequent invocations of this method
// are a no-op.
func preRun(cmd *cobra.Command, ru *run.Run) error {
if ru == nil {
return errz.New("Run is nil")
}
if ru.Writers != nil {
// If ru.Writers is already set, then this function has already been
// called on ru. That's ok, just return.
return nil
}
if ru.Cleanup == nil {
ru.Cleanup = cleanup.New()
}
ctx := cmd.Context()
// If the --output=/some/file flag is set, then we need to
// override ru.Out (which is typically stdout) to point it at
// the output destination file.
if cmdFlagChanged(ru.Cmd, flag.Output) {
fpath, _ := ru.Cmd.Flags().GetString(flag.Output)
fpath, err := filepath.Abs(fpath)
if err != nil {
return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output)
}
// Ensure the parent dir exists
err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm)
if err != nil {
return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output)
}
f, err := os.Create(fpath)
if err != nil {
return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output)
}
ru.Cleanup.AddC(f) // Make sure the file gets closed eventually
ru.Out = f
}
cmdOpts, err := getOptionsFromCmd(ru.Cmd)
if err != nil {
return err
}
ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, cmdOpts, ru.Out, ru.ErrOut)
return FinishRunInit(ctx, ru)
}

cli/testdata/add_location/my.db (new empty file)
cli/testdata/add_location/my/my1.db (new empty file)
cli/testdata/add_location/post.db (new empty file)
cli/testdata/add_location/sqlite.db (new empty file)

@ -26,13 +26,13 @@ import (
// TestRun is a helper for testing sq commands.
type TestRun struct {
T *testing.T
ctx context.Context
mu sync.Mutex
Run *run.Run
Out *bytes.Buffer
ErrOut *bytes.Buffer
used bool
T testing.TB
Context context.Context
mu sync.Mutex
Run *run.Run
Out *bytes.Buffer
ErrOut *bytes.Buffer
used bool
// When true, out and errOut are not logged.
hushOutput bool
@ -41,8 +41,12 @@ type TestRun struct {
// New returns a new run instance for testing sq commands.
// If from is non-nil, its config is used. This allows sequential
// commands to use the same config.
func New(ctx context.Context, t *testing.T, from *TestRun) *TestRun {
tr := &TestRun{T: t, ctx: ctx}
func New(ctx context.Context, t testing.TB, from *TestRun) *TestRun {
if ctx == nil {
ctx = context.Background()
}
tr := &TestRun{T: t, Context: ctx}
var cfgStore config.Store
if from != nil {
@ -92,12 +96,15 @@ func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.R
OptionsRegistry: optsReg,
}
require.NoError(t, cli.FinishRunInit(ctx, ru))
return ru, out, errOut
}
// Add adds srcs to tr.Run.Config.Collection. If the collection
// does not already have an active source, the first element
// of srcs is used as the active source.
//
// REVISIT: Why not use *source.Source instead of the value?
func (tr *TestRun) Add(srcs ...source.Source) *TestRun {
tr.mu.Lock()
defer tr.mu.Unlock()
@ -119,6 +126,9 @@ func (tr *TestRun) Add(srcs ...source.Source) *TestRun {
require.NoError(tr.T, err)
}
err := tr.Run.ConfigStore.Save(tr.Context, tr.Run.Config)
require.NoError(tr.T, err)
return tr
}
@ -141,7 +151,7 @@ func (tr *TestRun) doExec(args []string) error {
require.False(tr.T, tr.used, "TestRun instance must only be used once")
ctx, cancelFn := context.WithCancel(context.Background())
ctx, cancelFn := context.WithCancel(tr.Context)
tr.T.Cleanup(cancelFn)
execErr := cli.ExecuteWith(ctx, tr.Run, args)

drivers/mysql/collation.go (new file, +276)

@ -0,0 +1,276 @@
package mysql
var collations = []string{
"armscii8_bin",
"armscii8_general_ci",
"ascii_bin",
"ascii_general_ci",
"big5_bin",
"big5_chinese_ci",
"binary",
"cp1250_bin",
"cp1250_croatian_ci",
"cp1250_czech_cs",
"cp1250_general_ci",
"cp1250_polish_ci",
"cp1251_bin",
"cp1251_bulgarian_ci",
"cp1251_general_ci",
"cp1251_general_cs",
"cp1251_ukrainian_ci",
"cp1256_bin",
"cp1256_general_ci",
"cp1257_bin",
"cp1257_general_ci",
"cp1257_lithuanian_ci",
"cp850_bin",
"cp850_general_ci",
"cp852_bin",
"cp852_general_ci",
"cp866_bin",
"cp866_general_ci",
"cp932_bin",
"cp932_japanese_ci",
"dec8_bin",
"dec8_swedish_ci",
"eucjpms_bin",
"eucjpms_japanese_ci",
"euckr_bin",
"euckr_korean_ci",
"gb18030_bin",
"gb18030_chinese_ci",
"gb18030_unicode_520_ci",
"gb2312_bin",
"gb2312_chinese_ci",
"gbk_bin",
"gbk_chinese_ci",
"geostd8_bin",
"geostd8_general_ci",
"greek_bin",
"greek_general_ci",
"hebrew_bin",
"hebrew_general_ci",
"hp8_bin",
"hp8_english_ci",
"keybcs2_bin",
"keybcs2_general_ci",
"koi8r_bin",
"koi8r_general_ci",
"koi8u_bin",
"koi8u_general_ci",
"latin1_bin",
"latin1_danish_ci",
"latin1_general_ci",
"latin1_general_cs",
"latin1_german1_ci",
"latin1_german2_ci",
"latin1_spanish_ci",
"latin1_swedish_ci",
"latin2_bin",
"latin2_croatian_ci",
"latin2_czech_cs",
"latin2_general_ci",
"latin2_hungarian_ci",
"latin5_bin",
"latin5_turkish_ci",
"latin7_bin",
"latin7_estonian_cs",
"latin7_general_ci",
"latin7_general_cs",
"macce_bin",
"macce_general_ci",
"macroman_bin",
"macroman_general_ci",
"sjis_bin",
"sjis_japanese_ci",
"swe7_bin",
"swe7_swedish_ci",
"tis620_bin",
"tis620_thai_ci",
"ucs2_bin",
"ucs2_croatian_ci",
"ucs2_czech_ci",
"ucs2_danish_ci",
"ucs2_esperanto_ci",
"ucs2_estonian_ci",
"ucs2_general_ci",
"ucs2_general_mysql500_ci",
"ucs2_german2_ci",
"ucs2_hungarian_ci",
"ucs2_icelandic_ci",
"ucs2_latvian_ci",
"ucs2_lithuanian_ci",
"ucs2_persian_ci",
"ucs2_polish_ci",
"ucs2_romanian_ci",
"ucs2_roman_ci",
"ucs2_sinhala_ci",
"ucs2_slovak_ci",
"ucs2_slovenian_ci",
"ucs2_spanish2_ci",
"ucs2_spanish_ci",
"ucs2_swedish_ci",
"ucs2_turkish_ci",
"ucs2_unicode_520_ci",
"ucs2_unicode_ci",
"ucs2_vietnamese_ci",
"ujis_bin",
"ujis_japanese_ci",
"utf16le_bin",
"utf16le_general_ci",
"utf16_bin",
"utf16_croatian_ci",
"utf16_czech_ci",
"utf16_danish_ci",
"utf16_esperanto_ci",
"utf16_estonian_ci",
"utf16_general_ci",
"utf16_german2_ci",
"utf16_hungarian_ci",
"utf16_icelandic_ci",
"utf16_latvian_ci",
"utf16_lithuanian_ci",
"utf16_persian_ci",
"utf16_polish_ci",
"utf16_romanian_ci",
"utf16_roman_ci",
"utf16_sinhala_ci",
"utf16_slovak_ci",
"utf16_slovenian_ci",
"utf16_spanish2_ci",
"utf16_spanish_ci",
"utf16_swedish_ci",
"utf16_turkish_ci",
"utf16_unicode_520_ci",
"utf16_unicode_ci",
"utf16_vietnamese_ci",
"utf32_bin",
"utf32_croatian_ci",
"utf32_czech_ci",
"utf32_danish_ci",
"utf32_esperanto_ci",
"utf32_estonian_ci",
"utf32_general_ci",
"utf32_german2_ci",
"utf32_hungarian_ci",
"utf32_icelandic_ci",
"utf32_latvian_ci",
"utf32_lithuanian_ci",
"utf32_persian_ci",
"utf32_polish_ci",
"utf32_romanian_ci",
"utf32_roman_ci",
"utf32_sinhala_ci",
"utf32_slovak_ci",
"utf32_slovenian_ci",
"utf32_spanish2_ci",
"utf32_spanish_ci",
"utf32_swedish_ci",
"utf32_turkish_ci",
"utf32_unicode_520_ci",
"utf32_unicode_ci",
"utf32_vietnamese_ci",
"utf8mb4_0900_ai_ci",
"utf8mb4_0900_as_ci",
"utf8mb4_0900_as_cs",
"utf8mb4_0900_bin",
"utf8mb4_bin",
"utf8mb4_croatian_ci",
"utf8mb4_cs_0900_ai_ci",
"utf8mb4_cs_0900_as_cs",
"utf8mb4_czech_ci",
"utf8mb4_danish_ci",
"utf8mb4_da_0900_ai_ci",
"utf8mb4_da_0900_as_cs",
"utf8mb4_de_pb_0900_ai_ci",
"utf8mb4_de_pb_0900_as_cs",
"utf8mb4_eo_0900_ai_ci",
"utf8mb4_eo_0900_as_cs",
"utf8mb4_esperanto_ci",
"utf8mb4_estonian_ci",
"utf8mb4_es_0900_ai_ci",
"utf8mb4_es_0900_as_cs",
"utf8mb4_es_trad_0900_ai_ci",
"utf8mb4_es_trad_0900_as_cs",
"utf8mb4_et_0900_ai_ci",
"utf8mb4_et_0900_as_cs",
"utf8mb4_general_ci",
"utf8mb4_german2_ci",
"utf8mb4_hr_0900_ai_ci",
"utf8mb4_hr_0900_as_cs",
"utf8mb4_hungarian_ci",
"utf8mb4_hu_0900_ai_ci",
"utf8mb4_hu_0900_as_cs",
"utf8mb4_icelandic_ci",
"utf8mb4_is_0900_ai_ci",
"utf8mb4_is_0900_as_cs",
"utf8mb4_ja_0900_as_cs",
"utf8mb4_ja_0900_as_cs_ks",
"utf8mb4_latvian_ci",
"utf8mb4_la_0900_ai_ci",
"utf8mb4_la_0900_as_cs",
"utf8mb4_lithuanian_ci",
"utf8mb4_lt_0900_ai_ci",
"utf8mb4_lt_0900_as_cs",
"utf8mb4_lv_0900_ai_ci",
"utf8mb4_lv_0900_as_cs",
"utf8mb4_persian_ci",
"utf8mb4_pl_0900_ai_ci",
"utf8mb4_pl_0900_as_cs",
"utf8mb4_polish_ci",
"utf8mb4_romanian_ci",
"utf8mb4_roman_ci",
"utf8mb4_ro_0900_ai_ci",
"utf8mb4_ro_0900_as_cs",
"utf8mb4_ru_0900_ai_ci",
"utf8mb4_ru_0900_as_cs",
"utf8mb4_sinhala_ci",
"utf8mb4_sk_0900_ai_ci",
"utf8mb4_sk_0900_as_cs",
"utf8mb4_slovak_ci",
"utf8mb4_slovenian_ci",
"utf8mb4_sl_0900_ai_ci",
"utf8mb4_sl_0900_as_cs",
"utf8mb4_spanish2_ci",
"utf8mb4_spanish_ci",
"utf8mb4_sv_0900_ai_ci",
"utf8mb4_sv_0900_as_cs",
"utf8mb4_swedish_ci",
"utf8mb4_tr_0900_ai_ci",
"utf8mb4_tr_0900_as_cs",
"utf8mb4_turkish_ci",
"utf8mb4_unicode_520_ci",
"utf8mb4_unicode_ci",
"utf8mb4_vietnamese_ci",
"utf8mb4_vi_0900_ai_ci",
"utf8mb4_vi_0900_as_cs",
"utf8mb4_zh_0900_as_cs",
"utf8_bin",
"utf8_croatian_ci",
"utf8_czech_ci",
"utf8_danish_ci",
"utf8_esperanto_ci",
"utf8_estonian_ci",
"utf8_general_ci",
"utf8_general_mysql500_ci",
"utf8_german2_ci",
"utf8_hungarian_ci",
"utf8_icelandic_ci",
"utf8_latvian_ci",
"utf8_lithuanian_ci",
"utf8_persian_ci",
"utf8_polish_ci",
"utf8_romanian_ci",
"utf8_roman_ci",
"utf8_sinhala_ci",
"utf8_slovak_ci",
"utf8_slovenian_ci",
"utf8_spanish2_ci",
"utf8_spanish_ci",
"utf8_swedish_ci",
"utf8_tolower_ci",
"utf8_turkish_ci",
"utf8_unicode_520_ci",
"utf8_unicode_ci",
"utf8_vietnamese_ci",
}


@ -57,6 +57,34 @@ type driveri struct {
log *slog.Logger
}
// ConnParams implements driver.SQLDriver.
// See: https://github.com/go-sql-driver/mysql#dsn-data-source-name.
func (d *driveri) ConnParams() map[string][]string {
return map[string][]string{
"allowAllFiles": {"false", "true"},
"allowCleartextPasswords": {"false", "true"},
"allowFallbackToPlaintext": {"false", "true"},
"allowNativePasswords": {"false", "true"},
"allowOldPasswords": {"false", "true"},
"charset": nil,
"checkConnLiveness": {"true", "false"},
"clientFoundRows": {"false", "true"},
"collation": collations,
"columnsWithAlias": {"false", "true"},
"connectionAttributes": nil,
"interpolateParams": {"false", "true"},
"loc": {"UTC"},
"maxAllowedPackage": {"0", "67108864"},
"multiStatements": {"false", "true"},
"parseTime": {"false", "true"},
"readTimeout": {"0"},
"rejectReadOnly": {"false", "true"},
"timeout": nil,
"tls": {"false", "true", "skip-verify", "preferred"},
"writeTimeout": {"0"},
}
}
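// Illustrative note, not part of this commit: these keys and values feed location
// completion. For example, given the partial location "mysql://root@localhost/sakila?tls=",
// locCompDoConnParams would offer candidates such as "mysql://root@localhost/sakila?tls=false",
// "...?tls=true", "...?tls=skip-verify" and "...?tls=preferred".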
// ErrWrapFunc implements driver.SQLDriver.
func (d *driveri) ErrWrapFunc() func(error) error {
return errw
@ -74,6 +102,7 @@ func (d *driveri) DriverMetadata() driver.Metadata {
Description: "MySQL",
Doc: "https://github.com/go-sql-driver/mysql",
IsSQL: true,
DefaultPort: 3306,
}
}


@ -62,11 +62,26 @@ func (p *Provider) DriverFor(typ source.DriverType) (driver.Driver, error) {
return &driveri{log: p.Log}, nil
}
var _ driver.SQLDriver = (*driveri)(nil)
// driveri is the postgres implementation of driver.Driver.
type driveri struct {
log *slog.Logger
}
// ConnParams implements driver.SQLDriver.
func (d *driveri) ConnParams() map[string][]string {
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
return map[string][]string{
"channel_binding": {"prefer", "require", "disable"},
"connect_timeout": {"2"},
"application_name": nil,
"fallback_application_name": nil,
"gssencmode": {"disable", "prefer", "require"},
"sslmode": {"disable", "allow", "prefer", "require", "verify-ca", "verify-full"},
}
}
// ErrWrapFunc implements driver.SQLDriver.
func (d *driveri) ErrWrapFunc() func(error) error {
return errw
@ -79,6 +94,7 @@ func (d *driveri) DriverMetadata() driver.Metadata {
Description: "PostgreSQL",
Doc: "https://github.com/jackc/pgx",
IsSQL: true,
DefaultPort: 5432,
}
}


@ -52,7 +52,7 @@ func readPragma(ctx context.Context, db sqlz.DB, pragma string) (any, error) {
// Some of the pragmas can't be selected from. Ignore these.
// SQLite returns a generic (1) SQLITE_ERROR in this case,
// so we match using the error string.
return nil, nil
return nil, nil //nolint:nilnil
}
return nil, errw(err)
@ -61,7 +61,7 @@ func readPragma(ctx context.Context, db sqlz.DB, pragma string) (any, error) {
defer lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseDBRows, rows)
if !rows.Next() {
return nil, nil
return nil, nil //nolint:nilnil
}
cols, err := rows.Columns()
@ -72,7 +72,7 @@ func readPragma(ctx context.Context, db sqlz.DB, pragma string) (any, error) {
switch len(cols) {
case 0:
// Shouldn't happen
return nil, nil
return nil, nil //nolint:nilnil
case 1:
var val any
if err = rows.Scan(&val); err != nil {


@ -54,10 +54,6 @@ var _ driver.Provider = (*Provider)(nil)
// Provider is the SQLite3 implementation of driver.Provider.
type Provider struct {
Log *slog.Logger
// NOTE: Unlike other driver.SQLDriver impls, sqlite doesn't
// seem to benefit from applying a driver.SQLConfig to
// its sql.DB.
}
// DriverFor implements driver.Provider.
@ -69,13 +65,43 @@ func (p *Provider) DriverFor(typ source.DriverType) (driver.Driver, error) {
return &driveri{log: p.Log}, nil
}
var _ driver.Driver = (*driveri)(nil)
var _ driver.SQLDriver = (*driveri)(nil)
// driveri is the SQLite3 implementation of driver.Driver.
// driveri is the SQLite3 implementation of driver.SQLDriver.
type driveri struct {
log *slog.Logger
}
// ConnParams implements driver.SQLDriver.
// See: https://github.com/mattn/go-sqlite3#connection-string.
func (d *driveri) ConnParams() map[string][]string {
return map[string][]string{
"_auth": nil,
"_auth_crypt": {"SHA1", "SSHA1", "SHA256", "SSHA256", "SHA384", "SSHA384", "SHA512", "SSHA512"},
"_auth_pass": nil,
"_auth_salt": nil,
"_auth_user": nil,
"_auto_vacuum": {"none", "full", "incremental"},
"_busy_timeout": nil,
"_cache_size": {"-2000"},
"_case_sensitive_like": {"true", "false"},
"_defer_foreign_keys": {"true", "false"},
"_foreign_keys": {"true", "false"},
"_ignore_check_constraints": {"true", "false"},
"_journal_mode": {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"},
"_loc": nil,
"_locking_mode": {"NORMAL", "EXCLUSIVE"},
"_mutex": {"no", "full"},
"_query_only": {"true", "false"},
"_recursive_triggers": {"true", "false"},
"_secure_delete": {"true", "false", "FAST"},
"_synchronous": {"OFF", "NORMAL", "FULL", "EXTRA"},
"_txlock": {"immediate", "deferred", "exclusive"},
"cache": {"true", "false", "FAST"},
"mode": {"ro", "rw", "rwc", "memory"},
}
}
// ErrWrapFunc implements driver.SQLDriver.
func (d *driveri) ErrWrapFunc() func(error) error {
return errw


@ -61,6 +61,32 @@ type driveri struct {
log *slog.Logger
}
// ConnParams implements driver.SQLDriver.
func (d *driveri) ConnParams() map[string][]string {
// https://github.com/microsoft/go-mssqldb#connection-parameters-and-dsn.
return map[string][]string{
"ApplicationIntent": {"ReadOnly"},
"ServerSPN": nil,
"TrustServerCertificate": {"false", "true"},
"Workstation ID": nil,
"app name": {"sq"},
"certificate": nil,
"connection timeout": {"0"},
"database": nil,
"dial timeout": {"0"},
"encrypt": {"disable", "false", "true"},
"failoverpartner": nil,
"failoverport": {"1433"},
"hostNameInCertificate": nil,
"keepAlive": {"0", "30"},
"log": {"0", "1", "2", "4", "8", "16", "32", "64", "128", "255"},
"packet size": {"512", "4096", "16383", "32767"},
"protocol": nil,
"tlsmin": {"1.0", "1.1", "1.2", "1.3"},
"user id": nil,
}
}
// ErrWrapFunc implements driver.SQLDriver.
func (d *driveri) ErrWrapFunc() func(error) error {
return errw
@ -73,6 +99,7 @@ func (d *driveri) DriverMetadata() driver.Metadata {
Description: "Microsoft SQL Server / Azure SQL Edge",
Doc: "https://github.com/microsoft/go-mssqldb",
IsSQL: true,
DefaultPort: 1433,
}
}


@ -6,6 +6,8 @@ import (
"context"
"io"
"os"
"path/filepath"
"strings"
"github.com/neilotoole/sq/libsq/core/lg"
@ -83,3 +85,92 @@ func MarshalYAML(v any) ([]byte, error) {
func UnmarshallYAML(data []byte, v any) error {
return errz.Err(yaml.Unmarshal(data, v))
}
// ReadDir lists the contents of dir, returning the relative paths
// of the files. If markDirs is true, directories are listed with
// a "/" suffix (including symlinked dirs). If includeDirPath is true,
// the listing is of the form "dir/name". If includeDot is true,
// files beginning with period (dot files) are included. The function
// attempts to continue in the presence of errors: the returned paths
// may contain values even in the presence of a returned error (which
// may be a multierr).
func ReadDir(dir string, includeDirPath, markDirs, includeDot bool) (paths []string, err error) {
fi, err := os.Stat(dir)
if err != nil {
return nil, errz.Err(err)
}
if !fi.Mode().IsDir() {
return nil, errz.Errorf("not a dir: %s", dir)
}
var entries []os.DirEntry
if entries, err = os.ReadDir(dir); err != nil {
return nil, errz.Err(err)
}
var name string
for _, entry := range entries {
name = entry.Name()
if strings.HasPrefix(name, ".") && !includeDot {
// Skip invisible files
continue
}
mode := entry.Type()
if !mode.IsRegular() && markDirs {
if entry.IsDir() {
name += "/"
} else if mode&os.ModeSymlink != 0 {
// Follow the symlink to detect if it's a dir
linked, err2 := filepath.EvalSymlinks(filepath.Join(dir, name))
if err2 != nil {
err = errz.Append(err, errz.Err(err2))
continue
}
fi, err2 = os.Stat(linked)
if err2 != nil {
err = errz.Append(err, errz.Err(err2))
continue
}
if fi.IsDir() {
name += "/"
}
}
}
paths = append(paths, name)
}
if includeDirPath {
for i := range paths {
// filepath.Join strips the "/" suffix, so we need to preserve it.
hasSlashSuffix := strings.HasSuffix(paths[i], "/")
paths[i] = filepath.Join(dir, paths[i])
if hasSlashSuffix {
paths[i] += "/"
}
}
}
return paths, nil
}
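// exampleReadDir is an illustrative sketch, not part of this commit. Run against the new
// cli/testdata/add_location fixture dir, it would return something like
// ["my/", "my.db", "post.db", "sqlite.db"], with directories marked by a "/" suffix.
func exampleReadDir() ([]string, error) {
    return ReadDir("cli/testdata/add_location", false, true, false)
}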
// IsPathToRegularFile return true if path is a regular file or
// a symlink that resolves to a regular file. False is returned on
// any error.
func IsPathToRegularFile(path string) bool {
dest, err := filepath.EvalSymlinks(path)
if err != nil {
return false
}
fi, err := os.Stat(dest)
if err != nil {
return false
}
return fi.Mode().IsRegular()
}


@ -152,7 +152,7 @@ func (d *Detector) doSampleString(s string) {
d.mungeFns[Time] = func(val any) (any, error) {
if val == nil {
return nil, nil
return nil, nil //nolint:nilnil
}
s, ok = val.(string)
@ -161,7 +161,7 @@ func (d *Detector) doSampleString(s string) {
}
if s == "" {
return nil, nil
return nil, nil //nolint:nilnil
}
var t time.Time
@ -186,7 +186,7 @@ func (d *Detector) doSampleString(s string) {
d.mungeFns[Date] = func(val any) (any, error) {
if val == nil {
return nil, nil
return nil, nil //nolint:nilnil
}
s, ok = val.(string)
@ -195,7 +195,7 @@ func (d *Detector) doSampleString(s string) {
}
if s == "" {
return nil, nil
return nil, nil //nolint:nilnil
}
var t time.Time
@ -222,7 +222,7 @@ func (d *Detector) doSampleString(s string) {
// it returns a time.Time instead of a string
d.mungeFns[Datetime] = func(val any) (any, error) {
if val == nil {
return nil, nil
return nil, nil //nolint:nilnil
}
s, ok := val.(string)
@ -231,7 +231,7 @@ func (d *Detector) doSampleString(s string) {
}
if s == "" {
return nil, nil
return nil, nil //nolint:nilnil
}
t, err := time.Parse(format, s)


@ -12,14 +12,14 @@ var _ MungeFunc = MungeEmptyStringAsNil
func MungeEmptyStringAsNil(v any) (any, error) {
switch v := v.(type) {
case nil:
return nil, nil
return nil, nil //nolint:nilnil
case *string:
if len(*v) == 0 {
return nil, nil
return nil, nil //nolint:nilnil
}
case string:
if len(v) == 0 {
return nil, nil
return nil, nil //nolint:nilnil
}
}


@ -37,6 +37,7 @@ const (
To = "to"
Type = "type"
Line = "line"
URL = "url"
Val = "value"
Via = "via"
Version = "version"

View File

@ -16,6 +16,8 @@ import (
"time"
"unicode"
"github.com/samber/lo"
"github.com/google/uuid"
"github.com/neilotoole/sq/libsq/core/errz"
@ -298,9 +300,9 @@ func SurroundSlice(a []string, w string) []string {
}
// PrefixSlice returns a new slice with each element
// of a prefixed with w, unless a is nil, in which
// of a prefixed with prefix, unless a is nil, in which
// case nil is returned.
func PrefixSlice(a []string, w string) []string {
func PrefixSlice(a []string, prefix string) []string {
if a == nil {
return nil
}
@ -310,8 +312,8 @@ func PrefixSlice(a []string, w string) []string {
ret := make([]string, len(a))
sb := strings.Builder{}
for i := 0; i < len(a); i++ {
sb.Grow(len(a[i]) + len(w))
sb.WriteString(w)
sb.Grow(len(a[i]) + len(prefix))
sb.WriteString(prefix)
sb.WriteString(a[i])
ret[i] = sb.String()
sb.Reset()
@ -592,3 +594,28 @@ func IndentLines(s, indent string) string {
return indent + line
})
}
// HasAnyPrefix returns true if s has any of the prefixes.
func HasAnyPrefix(s string, prefixes ...string) bool {
for _, prefix := range prefixes {
if strings.HasPrefix(s, prefix) {
return true
}
}
return false
}
// FilterPrefix returns a new slice containing each element
// of a that has prefix.
func FilterPrefix(prefix string, a ...string) []string {
return lo.Filter(a, func(item string, index int) bool {
return strings.HasPrefix(item, prefix)
})
}
// ElementsHavingPrefix returns the elements of a that have prefix.
func ElementsHavingPrefix(a []string, prefix string) []string {
return lo.Filter(a, func(item string, index int) bool {
return strings.HasPrefix(item, prefix)
})
}
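// Usage sketch (illustrative, not part of this commit): these helpers are
// convenient for filtering completion candidates by prefix. Assumes this
// sits in the same package as the functions above.
func prefixHelpersSketch() {
	_ = HasAnyPrefix("postgres://sakila", "postgres://", "mysql://")          // true
	_ = FilterPrefix("sql", "sqlite3://", "sqlserver://", "mysql://")         // [sqlite3:// sqlserver://]
	_ = ElementsHavingPrefix([]string{"postgres://", "mysql://"}, "postgres") // [postgres://]
}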

105
libsq/core/urlz/urlz.go Normal file
View File

@ -0,0 +1,105 @@
// Package urlz contains URL utility functionality.
package urlz
import (
"net/url"
"strings"
"github.com/neilotoole/sq/libsq/core/errz"
)
// QueryParamKeys returns the keys of a URL query, in the order they
// appear. This function exists because url.ParseQuery returns
// url.Values, a map type that doesn't preserve key order.
func QueryParamKeys(query string) (keys []string, err error) {
// Code is adapted from url.ParseQuery.
for query != "" {
var key string
key, query, _ = strings.Cut(query, "&")
if strings.Contains(key, ";") {
err = errz.Errorf("invalid semicolon separator in query")
continue
}
if key == "" {
continue
}
key, _, _ = strings.Cut(key, "=")
key, err1 := url.QueryUnescape(key)
if err1 != nil {
if err == nil {
err = errz.Err(err1)
}
continue
}
keys = append(keys, key)
}
return keys, err
}
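// Usage sketch (illustrative, not part of this commit): key order is
// preserved, unlike url.ParseQuery.
func queryParamKeysSketch() {
	keys, err := QueryParamKeys("sslmode=disable&connect_timeout=10")
	// keys: [sslmode connect_timeout], err: nil
	_, _ = keys, err
}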
// RenameQueryParamKey renames all instances of oldKey in query
// to newKey, where query is a URL query string.
func RenameQueryParamKey(query, oldKey, newKey string) string {
if query == "" {
return ""
}
parts := strings.Split(query, "&")
for i, part := range parts {
if part == oldKey {
parts[i] = newKey
continue
}
if strings.HasPrefix(part, oldKey+"=") {
parts[i] = strings.Replace(part, oldKey, newKey, 1)
}
}
return strings.Join(parts, "&")
}
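// Usage sketch (illustrative, not part of this commit): rename a query
// param key; values and other params are left untouched.
func renameQueryParamKeySketch() {
	q := RenameQueryParamKey("user=alice&passwd=secret", "passwd", "password")
	// q: "user=alice&password=secret"
	_ = q
}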
// StripQuery strips the query params from u.
func StripQuery(u url.URL) string {
u.RawQuery = ""
u.ForceQuery = false
return u.String()
}
// StripPath strips the URL's path. If stripQuery is true, the
// query is also stripped.
func StripPath(u url.URL, stripQuery bool) string {
u.Path = ""
if stripQuery {
u.RawQuery = ""
u.ForceQuery = false
}
return u.String()
}
// StripUser strips the URL's user info.
func StripUser(u url.URL) string {
u2 := u
u2.User = nil
s := u2.String()
return s
}
// StripScheme removes the URL's scheme.
func StripScheme(u url.URL) string {
u2 := u
u2.Scheme = ""
s := u2.String()
s = strings.TrimPrefix(s, "//")
return s
}
// StripSchemeAndUser removes the URL's scheme and user info.
func StripSchemeAndUser(u url.URL) string {
u2 := u
u2.User = nil
u2.Scheme = ""
s := u2.String()
s = strings.TrimPrefix(s, "//")
return s
}
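// Usage sketch (illustrative, not part of this commit): StripPath reduces
// a location to its scheme/user/host portion; the tests below cover the
// other Strip funcs.
func stripPathSketch() {
	u, _ := url.Parse("postgres://alice@localhost:5432/sakila?sslmode=disable")
	s := StripPath(*u, true)
	// s: "postgres://alice@localhost:5432"
	_ = s
}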

View File

@ -0,0 +1,152 @@
package urlz_test
import (
"net/url"
"testing"
"github.com/neilotoole/sq/libsq/core/urlz"
"github.com/neilotoole/sq/testh/tutil"
"github.com/stretchr/testify/require"
)
func TestQueryParamKeys(t *testing.T) {
testCases := []struct {
q string
want []string
wantErr bool
}{
{"a=1", []string{"a"}, false},
{"a=1&b=2", []string{"a", "b"}, false},
{"b=1&a=2", []string{"b", "a"}, false},
{"a=1&b=", []string{"a", "b"}, false},
{"a=1&b=;", []string{"a"}, true},
}
for i, tc := range testCases {
tc := tc
t.Run(tutil.Name(i, tc.q), func(t *testing.T) {
got, gotErr := urlz.QueryParamKeys(tc.q)
if tc.wantErr {
require.Error(t, gotErr)
} else {
require.NoError(t, gotErr)
}
require.Equal(t, tc.want, got)
})
}
}
func TestRenameQueryParamKey(t *testing.T) {
testCases := []struct {
q string
oldKey string
newKey string
want string
}{
{"", "a", "b", ""},
{"a=1", "a", "b", "b=1"},
{"a", "a", "b", "b"},
{"aa", "a", "b", "aa"},
{"a=", "a", "b", "b="},
{"a=1&a=2", "a", "b", "b=1&b=2"},
{"a=1&c=2", "a", "b", "b=1&c=2"},
{"a=1&c=2&a=3&a=4", "a", "b", "b=1&c=2&b=3&b=4"},
{"a=a&c=2&a=b&a=c", "a", "b", "b=a&c=2&b=b&b=c"},
}
for i, tc := range testCases {
tc := tc
t.Run(tutil.Name(i, tc.q, tc.oldKey, tc.newKey), func(t *testing.T) {
got := urlz.RenameQueryParamKey(tc.q, tc.oldKey, tc.newKey)
require.Equal(t, tc.want, got)
})
}
}
func TestStripQuery(t *testing.T) {
testCases := []struct {
in string
want string
}{
{"https://sq.io", "https://sq.io"},
{"https://sq.io/path", "https://sq.io/path"},
{"https://sq.io/path#frag", "https://sq.io/path#frag"},
{"https://sq.io?a=b", "https://sq.io"},
{"https://sq.io/path?a=b", "https://sq.io/path"},
{"https://sq.io/path?a=b&c=d#frag", "https://sq.io/path#frag"},
}
for i, tc := range testCases {
tc := tc
t.Run(tutil.Name(i, tc), func(t *testing.T) {
u, err := url.Parse(tc.in)
require.NoError(t, err)
got := urlz.StripQuery(*u)
require.Equal(t, tc.want, got)
})
}
}
func TestStripUser(t *testing.T) {
testCases := []struct {
in string
want string
}{
{"https://sq.io", "https://sq.io"},
{"https://alice:123@sq.io/path", "https://sq.io/path"},
{"https://alice@sq.io/path", "https://sq.io/path"},
{"https://alice:@sq.io/path", "https://sq.io/path"},
}
for i, tc := range testCases {
tc := tc
t.Run(tutil.Name(i, tc), func(t *testing.T) {
u, err := url.Parse(tc.in)
require.NoError(t, err)
got := urlz.StripUser(*u)
require.Equal(t, tc.want, got)
})
}
}
func TestStripScheme(t *testing.T) {
testCases := []struct {
in string
want string
}{
{"https://sq.io", "sq.io"},
{"https://alice:123@sq.io/path", "alice:123@sq.io/path"},
}
for i, tc := range testCases {
tc := tc
t.Run(tutil.Name(i, tc), func(t *testing.T) {
u, err := url.Parse(tc.in)
require.NoError(t, err)
got := urlz.StripScheme(*u)
require.Equal(t, tc.want, got)
})
}
}
func TestStripSchemeAndUser(t *testing.T) {
testCases := []struct {
in string
want string
}{
{"https://sq.io", "sq.io"},
{"https://alice:123@sq.io/path", "sq.io/path"},
{"https://alice@sq.io/path", "sq.io/path"},
{"https://alice:@sq.io/path", "sq.io/path"},
}
for i, tc := range testCases {
tc := tc
t.Run(tutil.Name(i, tc), func(t *testing.T) {
u, err := url.Parse(tc.in)
require.NoError(t, err)
got := urlz.StripSchemeAndUser(*u)
require.Equal(t, tc.want, got)
})
}
}

View File

@ -216,6 +216,13 @@ type SQLDriver interface {
// Dialect returns the SQL dialect.
Dialect() dialect.Dialect
// ConnParams returns the db parameters available for use in a connection
// string. The key is the parameter name (e.g. "sslmode"), and the value
// may be the set of allowed values, sample values, or nil.
// These values are used for shell completion and the like. The returned
// map does not have to be exhaustive, and can be nil.
ConnParams() map[string][]string
// ErrWrapFunc returns a func that wraps the driver's errors.
ErrWrapFunc() func(error) error
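// The following is an illustrative sketch (not part of this commit) of how
// a SQLDriver implementation might satisfy ConnParams. The exampleDriver
// type and the param names/values are hypothetical; real drivers may differ.
type exampleDriver struct{}
func (d *exampleDriver) ConnParams() map[string][]string {
	return map[string][]string{
		// Enumerable param: suggest the allowed values.
		"sslmode": {"disable", "allow", "prefer", "require", "verify-ca", "verify-full"},
		// Free-form param: nil means no suggested values.
		"connect_timeout": nil,
	}
}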
@ -342,25 +349,29 @@ type Database interface {
// Metadata holds driver metadata.
type Metadata struct {
// Type is the driver type, e.g. "mysql" or "csv", etc.
Type source.DriverType `json:"type"`
Type source.DriverType `json:"type" yaml:"type"`
// Description is typically the long name of the driver, e.g.
// "MySQL" or "Microsoft Excel XLSX".
Description string `json:"description"`
Description string `json:"description" yaml:"description"`
// Doc is optional documentation, typically a URL.
Doc string `json:"doc,omitempty"`
Doc string `json:"doc,omitempty" yaml:"doc,omitempty"`
// UserDefined is true if this driver is the product of a
// user driver definition, and false if built-in.
UserDefined bool `json:"user_defined"`
UserDefined bool `json:"user_defined" yaml:"user_defined"`
// IsSQL is true if this driver is a SQL driver.
IsSQL bool `json:"is_sql"`
IsSQL bool `json:"is_sql" yaml:"is_sql"`
// Monotable is true if this is a non-SQL document type that
// effectively has a single table, such as CSV.
Monotable bool `json:"monotable"`
Monotable bool `json:"monotable" yaml:"monotable"`
// DefaultPort is the default port that a driver connects on. A
// value <= 0 indicates not applicable.
DefaultPort int `json:"default_port" yaml:"default_port"`
}
var _ DatabaseOpener = (*Databases)(nil)

View File

@ -66,6 +66,21 @@ func (r *Registry) DriverFor(typ source.DriverType) (Driver, error) {
return p.DriverFor(typ)
}
// SQLDriverFor is a convenience method for getting a SQLDriver.
func (r *Registry) SQLDriverFor(typ source.DriverType) (SQLDriver, error) {
drvr, err := r.DriverFor(typ)
if err != nil {
return nil, err
}
sqlDrvr, ok := drvr.(SQLDriver)
if !ok {
return nil, errz.Errorf("driver %T is not of type %T", drvr, sqlDrvr)
}
return sqlDrvr, nil
}
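// Usage sketch (illustrative, not part of this commit): during shell
// completion, a caller can fetch the SQLDriver for a source type and ask
// it for connection params. The function name here is an assumption.
func connParamsForType(r *Registry, typ source.DriverType) (map[string][]string, error) {
	sqlDrvr, err := r.SQLDriverFor(typ)
	if err != nil {
		return nil, err
	}
	return sqlDrvr.ConnParams(), nil
}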
// DriversMetadata returns metadata for each registered driver type.
func (r *Registry) DriversMetadata() []Metadata {
var md []Metadata

View File

@ -97,7 +97,7 @@ func (s *Source) LogValue() slog.Value {
// String returns a log/debug-friendly representation.
func (s *Source) String() string {
return fmt.Sprintf("%s|%s| %s", s.Handle, s.Type, s.RedactedLocation())
return fmt.Sprintf("%s|%s|%s", s.Handle, s.Type, s.RedactedLocation())
}
// Group returns the source's group. If s is in the root group,

View File

@ -12,6 +12,8 @@ import (
"testing"
"time"
"github.com/neilotoole/sq/cli/run"
"github.com/neilotoole/sq/drivers"
"github.com/neilotoole/sq/cli"
@ -97,6 +99,7 @@ type Helper struct {
registry *driver.Registry
files *source.Files
databases *driver.Databases
run *run.Run
initOnce sync.Once
@ -191,6 +194,16 @@ func (h *Helper) init() {
h.files.AddDriverDetectors(xlsx.DetectXLSX)
h.addUserDrivers()
h.run = &run.Run{
Stdin: os.Stdin,
Out: os.Stdout,
ErrOut: os.Stderr,
Config: config.New(),
ConfigStore: config.DiscardStore{},
OptionsRegistry: &options.Registry{},
DriverRegistry: h.registry,
}
})
}
@ -606,6 +619,12 @@ func (h *Helper) Registry() *driver.Registry {
return h.registry
}
// Run returns the helper's run instance.
func (h *Helper) Run() *run.Run {
h.init()
return h.run
}
// addUserDrivers adds some user drivers to the registry.
func (h *Helper) addUserDrivers() {
userDriverDefs := DriverDefsFrom(h.T, testsrc.PathDriverDefPpl, testsrc.PathDriverDefRSS)

View File

@ -5,7 +5,9 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
"unicode"
@ -318,3 +320,41 @@ func (t *tWriter) Write(p []byte) (n int, err error) {
t.t.Log("\n" + string(p))
return len(p), nil
}
// Chdir changes the working directory to dir, or if dir is empty,
// to a temp dir. On test end, the original working dir is restored,
// and the temp dir deleted (if applicable). The absolute path
// of the changed working dir is returned.
func Chdir(t *testing.T, dir string) (absDir string) {
origDir, err := os.Getwd()
require.NoError(t, err)
if dir == "" {
// t.TempDir handles deletion of the temp dir on test cleanup.
dir = t.TempDir()
}
if filepath.IsAbs(dir) {
absDir = dir
} else {
absDir, err = filepath.Abs(dir)
require.NoError(t, err)
}
require.NoError(t, os.Chdir(dir))
t.Cleanup(func() {
_ = os.Chdir(origDir)
})
return absDir
}
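// Usage sketch (illustrative, not part of this commit): tests that read or
// write relative paths can run inside a throwaway working directory.
func chdirUsageSketch(t *testing.T) {
	tmpAbs := Chdir(t, "") // empty dir arg: chdir into a fresh temp dir
	// ... create files relative to tmpAbs, exercise code that uses the CWD ...
	_ = tmpAbs
}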
// SkipWindows skips t if running on Windows.
func SkipWindows(t *testing.T, format string, args ...any) {
if runtime.GOOS == "windows" {
t.Skipf(format, args...)
}
}