diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f0d23a35..4e55f7bc 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -9,9 +9,9 @@ on: workflow_dispatch: env: - GO_VERSION: 1.20.2 + GO_VERSION: 1.20.5 GORELEASER_VERSION: 1.13.1 - GOLANGCI_LINT_VERSION: v1.52.2 + GOLANGCI_LINT_VERSION: v1.53.2 jobs: diff --git a/cli/cli.go b/cli/cli.go index c4c1d692..7c6708bb 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -10,9 +10,10 @@ // usage pattern by eliminating all pkg-level constructs // (which makes testing easier). // -// All interaction with cobra should happen inside this package. +// All interaction with cobra should happen inside this package, or +// via the utility cli/cobraz package. // That is to say, the spf13/cobra package should not be imported -// anywhere outside this package. +// anywhere outside this package and cli/cobraz. // // The entry point to this pkg is the Execute function. package cli @@ -103,7 +104,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { // now handles this situation? // We need to perform handling for autocomplete - if len(args) > 0 && args[0] == "__complete" { + if len(args) > 0 && args[0] == cobra.ShellCompRequestCmd { if hasMatchingChildCommand(rootCmd, args[1]) { // If there is a matching child command, we let rootCmd // handle it, as per normal. @@ -112,7 +113,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { // There's no command matching the first argument to __complete. // Therefore, we assume that we want to perform completion // for the "slq" command (which is the pseudo-root command). - effectiveArgs := append([]string{"__complete", "slq"}, args[1:]...) + effectiveArgs := append([]string{cobra.ShellCompRequestCmd, "slq"}, args[1:]...) rootCmd.SetArgs(effectiveArgs) } } else { diff --git a/cli/cli_export_test.go b/cli/cli_export_test.go new file mode 100644 index 00000000..7d7ace08 --- /dev/null +++ b/cli/cli_export_test.go @@ -0,0 +1,34 @@ +package cli + +// This file exports package constructs for testing. + +import ( + "testing" + + "github.com/neilotoole/sq/cli/run" +) + +type PlocStage = plocStage + +const ( + PlocInit = plocInit + PlocScheme = plocScheme + PlocUser = plocUser + PlocPass = plocPass + PlocHostname = plocHostname + PlocHost = plocHost + PlocPath = plocPath +) + +var DoCompleteAddLocationFile = locCompListFiles + +// ToTestParseLocStage is a helper to test the +// non-exported locCompletionHelper.locCompParseLoc method. +func DoTestParseLocStage(t testing.TB, ru *run.Run, loc string) (PlocStage, error) { //nolint:revive + ploc, err := locCompParseLoc(loc) + if err != nil { + return PlocInit, err + } + + return ploc.stageDone, nil +} diff --git a/cli/cmd_add.go b/cli/cmd_add.go index d65c72f7..55703dd0 100644 --- a/cli/cmd_add.go +++ b/cli/cmd_add.go @@ -26,9 +26,10 @@ import ( func newSrcAddCmd() *cobra.Command { cmd := &cobra.Command{ - Use: "add [--handle @HANDLE] LOCATION", - RunE: execSrcAdd, - Args: cobra.ExactArgs(1), + Use: "add [--handle @HANDLE] LOCATION", + RunE: execSrcAdd, + Args: cobra.ExactArgs(1), + ValidArgsFunction: completeAddLocation, Example: ` When adding a data source, LOCATION is the only required arg. @@ -39,21 +40,30 @@ Note that sq generated the handle "@actor". But you can explicitly specify a handle. # Add a postgres source with handle "@sakila/pg" - $ sq add --handle @sakila/pg 'postgres://user:pass@localhost/sakila' + $ sq add --handle @sakila/pg postgres://user:pass@localhost/sakila This handle format "@sakila/pg" includes a group, "sakila". Using a group is entirely optional: it is a way to organize sources. For example: - $ sq add --handle @dev/pg 'postgres://user:pass@dev.db.example.com/sakila' - $ sq add --handle @prod/pg 'postgres://user:pass@prod.db.acme.com/sakila' + $ sq add --handle @dev/pg postgres://user:pass@dev.db.acme.com/sakila + $ sq add --handle @prod/pg postgres://user:pass@prod.db.acme.com/sakila The format of LOCATION is driver-specific, but is generally a DB connection string, a file path, or a URL. - DRIVER://USER:PASS@HOST:PORT/DBNAME + DRIVER://USER:PASS@HOST:PORT/DBNAME?PARAM=VAL /path/to/local/file.ext https://sq.io/data/test1.xlsx +If LOCATION contains special shell characters, it's necessary to enclose +it in single quotes, or to escape the special character. For example, +note the "\?" in the unquoted location below. + + $ sq add postgres://user:pass@localhost/sakila\?sslmode=disable + +A significant advantage of not quoting LOCATION is that sq provides extensive +shell completion when inputting the location value. + If flag --handle is omitted, sq will generate a handle based on LOCATION and the source driver type. @@ -61,7 +71,7 @@ It's a security hazard to expose the data source password via the LOCATION string. If flag --password (-p) is set, sq prompt the user for the password: - $ sq add 'postgres://user@localhost/sakila' -p + $ sq add postgres://user@localhost/sakila -p Password: **** However, if there's input on stdin, sq will read the password from @@ -69,18 +79,18 @@ there instead of prompting the user: # Add a source, but read password from an environment variable $ export PASSWD='open:;"_Ses@me' - $ sq add 'postgres://user@localhost/sakila' -p <<< $PASSWD + $ sq add postgres://user@localhost/sakila -p <<< $PASSWD # Same as above, but instead read password from file $ echo 'open:;"_Ses@me' > password.txt - $ sq add 'postgres://user@localhost/sakila' -p < password.txt + $ sq add postgres://user@localhost/sakila -p < password.txt There are various driver-specific options available. For example: $ sq add actor.csv --ingest.header=false --driver.csv.delim=colon If flag --driver is omitted, sq will attempt to determine the -type from LOCATION via file suffix, content type, etc.. If the result +type from LOCATION via file suffix, content type, etc. If the result is ambiguous, explicitly specify the driver type. $ sq add --driver=tsv ./mystery.data @@ -106,19 +116,25 @@ use flag --active to make the new source active. More examples: # Add a source, but prompt user for password - $ sq add 'postgres://user@localhost/sakila' -p + $ sq add postgres://user@localhost/sakila -p Password: **** # Explicitly set flags - $ sq add --handle @sakila_pg --driver postgres 'postgres://user:pass@localhost/sakila' + $ sq add --handle @sakila_pg --driver postgres postgres://user:pass@localhost/sakila # Same as above, but with short flags - $ sq add -n @sakila_pg -d postgres 'postgres://user:pass@localhost/sakila' + $ sq add -n @sakila_pg -d postgres postgres://user:pass@localhost/sakila + + # Specify some params (note escaped chars) + $ sq add postgres://user:pass@localhost/sakila\?sslmode=disable\&application_name=sq + +# Specify some params, but use quoted string (no shell completion) + $ sq add 'postgres://user:pass@localhost/sakila?sslmode=disable&application_name=sq'' # Add a SQL Server source; will have generated handle @sakila $ sq add 'sqlserver://user:pass@localhost?database=sakila' - # Add a sqlite db, and immediately make it the active source + # Add a SQLite DB, and immediately make it the active source $ sq add ./testdata/sqlite1.db --active # Add an Excel spreadsheet, with options @@ -134,7 +150,7 @@ More examples: $ sq add ./actor.csv --handle @csv/actor # Add a currently unreachable source - $ sq add 'postgres://user:pass@db.offline.com/sakila' --skip-verify`, + $ sq add postgres://user:pass@db.offline.com/sakila --skip-verify`, Short: "Add data source", Long: `Add data source specified by LOCATION, optionally identified by @HANDLE.`, } diff --git a/cli/cobraz/cobraz.go b/cli/cobraz/cobraz.go new file mode 100644 index 00000000..97ff6f3a --- /dev/null +++ b/cli/cobraz/cobraz.go @@ -0,0 +1,113 @@ +// Package cobraz contains supplemental logic for dealing with spf13/cobra. +package cobraz + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" +) + +// Defines the text values for cobra.ShellCompDirective. +const ( + ShellCompDirectiveErrorText = "ShellCompDirectiveError" + ShellCompDirectiveNoSpaceText = "ShellCompDirectiveNoSpace" + ShellCompDirectiveNoFileCompText = "ShellCompDirectiveNoFileComp" + ShellCompDirectiveFilterFileExtText = "ShellCompDirectiveFilterFileExt" + ShellCompDirectiveFilterDirsText = "ShellCompDirectiveFilterDirs" + ShellCompDirectiveKeepOrderText = "ShellCompDirectiveKeepOrder" + ShellCompDirectiveDefaultText = "ShellCompDirectiveDefault" + ShellCompDirectiveUnknownText = "ShellCompDirectiveUnknown" +) + +// ParseDirectivesLine parses the line of text returned by "__complete" cmd +// that contains the text description of the result. +// The line looks like: +// +// Completion ended with directive: ShellCompDirectiveNoSpace, ShellCompDirectiveKeepOrder +// +// Note that this function will panic on an unknown directive. +func ParseDirectivesLine(directivesLine string) []cobra.ShellCompDirective { + trimmedLine := strings.TrimPrefix(strings.TrimSpace(directivesLine), "Completion ended with directive: ") + parts := strings.Split(trimmedLine, ", ") + directives := make([]cobra.ShellCompDirective, 0, len(parts)) + for _, part := range parts { + switch part { + case ShellCompDirectiveErrorText: + directives = append(directives, cobra.ShellCompDirectiveError) + case ShellCompDirectiveNoSpaceText: + directives = append(directives, cobra.ShellCompDirectiveNoSpace) + case ShellCompDirectiveNoFileCompText: + directives = append(directives, cobra.ShellCompDirectiveNoFileComp) + case ShellCompDirectiveFilterFileExtText: + directives = append(directives, cobra.ShellCompDirectiveFilterFileExt) + case ShellCompDirectiveFilterDirsText: + directives = append(directives, cobra.ShellCompDirectiveFilterDirs) + case ShellCompDirectiveKeepOrderText: + directives = append(directives, cobra.ShellCompDirectiveKeepOrder) + case ShellCompDirectiveDefaultText: + directives = append(directives, cobra.ShellCompDirectiveDefault) + default: + panic(fmt.Sprintf("Unknown cobra.ShellCompDirective {%s} in: %s", part, directivesLine)) + } + } + return directives +} + +// ExtractDirectives extracts the individual directives +// from a combined directive. +func ExtractDirectives(result cobra.ShellCompDirective) []cobra.ShellCompDirective { + if result == cobra.ShellCompDirectiveDefault { + return []cobra.ShellCompDirective{cobra.ShellCompDirectiveDefault} + } + + var a []cobra.ShellCompDirective + + allDirectives := []cobra.ShellCompDirective{ + cobra.ShellCompDirectiveError, + cobra.ShellCompDirectiveNoSpace, + cobra.ShellCompDirectiveNoFileComp, + cobra.ShellCompDirectiveFilterFileExt, + cobra.ShellCompDirectiveFilterDirs, + cobra.ShellCompDirectiveDefault, + } + + for _, directive := range allDirectives { + if directive&result > 0 { + a = append(a, directive) + } + } + + return a +} + +// MarshalDirective marshals a cobra.ShellCompDirective to text strings, +// after extracting the embedded directives. +func MarshalDirective(directive cobra.ShellCompDirective) []string { + gotDirectives := ExtractDirectives(directive) + + s := make([]string, len(gotDirectives)) + for i, d := range gotDirectives { + switch d { + case cobra.ShellCompDirectiveError: + s[i] = ShellCompDirectiveErrorText + case cobra.ShellCompDirectiveNoSpace: + s[i] = ShellCompDirectiveNoSpaceText + case cobra.ShellCompDirectiveNoFileComp: + s[i] = ShellCompDirectiveNoFileCompText + case cobra.ShellCompDirectiveFilterFileExt: + s[i] = ShellCompDirectiveFilterFileExtText + case cobra.ShellCompDirectiveFilterDirs: + s[i] = ShellCompDirectiveFilterDirsText + case cobra.ShellCompDirectiveKeepOrder: + s[i] = ShellCompDirectiveKeepOrderText + case cobra.ShellCompDirectiveDefault: + s[i] = ShellCompDirectiveDefaultText + default: + // Should never happen + s[i] = ShellCompDirectiveUnknownText + } + } + + return s +} diff --git a/cli/cobraz/cobraz_test.go b/cli/cobraz/cobraz_test.go new file mode 100644 index 00000000..8a26b8bd --- /dev/null +++ b/cli/cobraz/cobraz_test.go @@ -0,0 +1,42 @@ +package cobraz + +import ( + "testing" + + "github.com/neilotoole/sq/testh/tutil" + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestExtractDirectives(t *testing.T) { + testCases := []struct { + in cobra.ShellCompDirective + want []cobra.ShellCompDirective + wantStrings []string + }{ + { + cobra.ShellCompDirectiveError, + []cobra.ShellCompDirective{cobra.ShellCompDirectiveError}, + []string{ShellCompDirectiveErrorText}, + }, + { + cobra.ShellCompDirectiveError | cobra.ShellCompDirectiveNoSpace, + []cobra.ShellCompDirective{cobra.ShellCompDirectiveError, cobra.ShellCompDirectiveNoSpace}, + []string{ShellCompDirectiveErrorText, ShellCompDirectiveNoSpaceText}, + }, + { + cobra.ShellCompDirectiveDefault, + []cobra.ShellCompDirective{cobra.ShellCompDirectiveDefault}, + []string{ShellCompDirectiveDefaultText}, + }, + } + + for i, tc := range testCases { + t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + gotDirectives := ExtractDirectives(tc.in) + require.Equal(t, tc.want, gotDirectives) + gotStrings := MarshalDirective(tc.in) + require.Equal(t, tc.wantStrings, gotStrings) + }) + } +} diff --git a/cli/complete_location.go b/cli/complete_location.go new file mode 100644 index 00000000..72245a2e --- /dev/null +++ b/cli/complete_location.go @@ -0,0 +1,970 @@ +package cli + +import ( + "context" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/neilotoole/sq/libsq/core/urlz" + + "golang.org/x/exp/slog" + + "github.com/neilotoole/sq/libsq/core/ioz" + + "github.com/neilotoole/sq/libsq/core/errz" + + "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/driver" + "golang.org/x/exp/slices" + + "github.com/neilotoole/sq/drivers/mysql" + "github.com/neilotoole/sq/drivers/postgres" + "github.com/neilotoole/sq/drivers/sqlite3" + "github.com/neilotoole/sq/drivers/sqlserver" + + "github.com/xo/dburl" + + "github.com/neilotoole/sq/libsq/core/lg/lga" + + "github.com/neilotoole/sq/libsq/core/stringz" + + "github.com/neilotoole/sq/libsq/source" + + "github.com/samber/lo" + "github.com/spf13/cobra" +) + +// locCompStdDirective is the standard cobra shell completion directive +// returned by completeAddLocation. +const locCompStdDirective = cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder + +// completeAddLocation provides completion for the "sq add LOCATION" arg. +// This is a messy task, as LOCATION can be a database driver URL, +// and it can also be a filepath. To complicate matters further, sqlite +// has a format sqlite://FILE/PATH?param=val, which is a driver URL, with +// embedded file completion. +// +// The general strategy is: +// - Does toComplete have a driver prefix ("postgres://", "sqlite3://" etc)? +// If so, delegate to the appropriate function. +// - Is toComplete definitively NOT a driver URL? That is to say, is toComplete +// a file path? If so, then we need regular shell file completion. +// Return cobra.ShellCompDirectiveDefault, and let the shell handle it. +// - There's a messy overlap where toComplete could be either a driver URL +// or a filepath. For example, "post" could be leading to "postgres://", or +// to a file named "post.db". For this situation, it is necessary to +// mimic in code the behavior of the shell's file completion. +// - There's another layer of complexity: previous locations (i.e. "history") +// are also suggested. +// +// The code, as currently structured, is ungainly, and downright ugly in +// spots, and probably won't scale well if more drivers are supported. That +// is to say, this mechanism would benefit from a through refactor. +func completeAddLocation(cmd *cobra.Command, args []string, toComplete string) ( + []string, cobra.ShellCompDirective, +) { + if len(args) > 0 { + return nil, cobra.ShellCompDirectiveError + } + + if strings.HasPrefix(toComplete, "/") || strings.HasPrefix(toComplete, ".") { + // This has to be a file path. + // Go straight to default (file) completion. + return nil, cobra.ShellCompDirectiveDefault + } + + var a []string + if toComplete == "" { + // No input yet. Offer both the driver URL schemes and file listing. + a = append(a, locSchemes...) + files := locCompListFiles(cmd.Context(), toComplete) + if len(files) > 0 { + a = append(a, files...) + } + + return a, locCompStdDirective + } + + // We've got some input in toComplete... + if !stringz.HasAnyPrefix(toComplete, locSchemes...) { + // But toComplete isn't a full match for any of the driver + // URL schemes. However, it could still be a partial match. + + a = stringz.FilterPrefix(toComplete, locSchemes...) + if len(a) == 0 { + // We're not matching any URL prefix, fall back to default + // shell completion, i.e. list files. + return nil, cobra.ShellCompDirectiveDefault + } + + // Partial match, e.g. "post". So, this could match both + // a URL such as "postgres://", or a file such as "post.db". + files := locCompListFiles(cmd.Context(), toComplete) + if len(files) > 0 { + a = append(a, files...) + } + + return a, locCompStdDirective + } + + // If we got this far, we know that toComplete starts with one of the + // driver schemes, e.g. "postgres://". There's no possibility that + // this could be a file completion. + + if strings.HasPrefix(toComplete, string(sqlite3.Type)) { + // Special handling for sqlite. + return locCompDoSQLite3(cmd, args, toComplete) + } + + return locCompDoGenericDriver(cmd, args, toComplete) +} + +// locCompDoGenericDriver provides completion for generic SQL drivers. +// Specifically, it's tested with postgres, sqlserver, and mysql. Note that +// sqlserver is slightly different from the others, in that the db name goes +// in a query param, not in the URL path. It might be cleaner to split sqlserver +// off into its own function. +func locCompDoGenericDriver(cmd *cobra.Command, _ []string, toComplete string, //nolint:funlen,gocognit +) ([]string, cobra.ShellCompDirective) { + // If we get this far, then toComplete is at least a partial URL + // starting with "postgres://", "mysql://", etc. + + var ( + ctx = cmd.Context() + log = lg.FromContext(ctx) + ru = run.FromContext(ctx) + drvr driver.SQLDriver + ploc *parsedLoc + a []string // a holds the completion strings to be returned + err error + ) + + if err = FinishRunInit(ctx, ru); err != nil { + log.Error("Init run", lga.Err, err) + return nil, cobra.ShellCompDirectiveError + } + + if ploc, err = locCompParseLoc(toComplete); err != nil { + log.Error("Parse location", lga.Err, err) + return nil, cobra.ShellCompDirectiveError + } + + if drvr, err = ru.DriverRegistry.SQLDriverFor(ploc.typ); err != nil { + log.Error("Load driver", lga.Err, err) + return nil, cobra.ShellCompDirectiveError + } + + hist := &locHistory{ + coll: ru.Config.Collection, + typ: ploc.typ, + log: log, + } + + switch ploc.stageDone { //nolint:exhaustive + case plocScheme: + unames := hist.usernames() + + if ploc.user == "" { + a = stringz.PrefixSlice(unames, toComplete) + a = append(a, toComplete+"username") + a = lo.Uniq(a) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + } + + // Else, we have at least a partial username + a = []string{ + toComplete + "@", + toComplete + ":", + } + + for _, uname := range unames { + v := string(ploc.typ) + "://" + uname + a = append(a, v+"@") + a = append(a, v+":") + } + + a = lo.Uniq(a) + a = stringz.FilterPrefix(toComplete, a...) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + case plocUser: + if ploc.pass == "" { + a = []string{ + toComplete, + toComplete + "@", + toComplete + "password@", + } + + return a, locCompStdDirective + } + + a = []string{ + toComplete + "@", + } + + return a, locCompStdDirective + case plocPass: + hosts := hist.hosts() + hostsWithPath := hist.hostsWithPathAndQuery() + defaultPort := locCompDriverPort(drvr) + afterHost := locCompAfterHost(ploc.typ) + + if ploc.hostname == "" { + if defaultPort == "" { + a = []string{ + toComplete + "localhost" + afterHost, + } + } else { + a = []string{ + toComplete + "localhost" + afterHost, + toComplete + "localhost:" + defaultPort + afterHost, + } + } + + var b []string + for _, h := range hostsWithPath { + v := toComplete + h + b = append(b, v) + } + for _, h := range hosts { + v := toComplete + h + afterHost + b = append(b, v) + } + + slices.Sort(b) + a = append(a, b...) + a = lo.Uniq(a) + a = stringz.FilterPrefix(toComplete, a...) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + } + + base, _, _ := strings.Cut(toComplete, "@") + base += "@" + + if ploc.port <= 0 { + if defaultPort == "" { + a = []string{ + toComplete + afterHost, + base + "localhost" + afterHost, + } + } else { + a = []string{ + toComplete + afterHost, + toComplete + ":" + defaultPort + afterHost, + base + "localhost" + afterHost, + base + "localhost:" + defaultPort + afterHost, + } + } + + var b []string + for _, h := range hostsWithPath { + v := base + h + b = append(b, v) + } + for _, h := range hosts { + v := base + h + afterHost + b = append(b, v) + } + + slices.Sort(b) + a = append(a, b...) + a = lo.Uniq(a) + a = stringz.FilterPrefix(toComplete, a...) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + } + + if defaultPort == "" { + a = []string{ + base + "localhost" + afterHost, + toComplete + afterHost, + } + } else { + a = []string{ + base + "localhost" + afterHost, + base + "localhost:" + defaultPort + afterHost, + toComplete + afterHost, + } + } + + var b []string + for _, h := range hostsWithPath { + v := base + h + b = append(b, v) + } + for _, h := range hosts { + v := base + h + afterHost + b = append(b, v) + } + slices.Sort(b) + a = append(a, b...) + a = lo.Uniq(a) + a = stringz.FilterPrefix(toComplete, a...) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + case plocHostname: + defaultPort := locCompDriverPort(drvr) + afterHost := locCompAfterHost(ploc.typ) + if strings.HasSuffix(toComplete, ":") { + a = []string{toComplete + defaultPort + afterHost} + return a, locCompStdDirective + } + + a = []string{toComplete + afterHost} + return a, locCompStdDirective + + case plocHost: + dbNames := hist.databases() + // Special handling for SQLServer. The input is typically of the form: + // sqlserver://alice@server?database=db + // But it can also be of the form: + // sqlserver://alice@server/instance?database=db + if ploc.typ == sqlserver.Type { + if ploc.du.Path == "/" { + a = []string{toComplete + "instance?database="} + return a, locCompStdDirective + } + + a = []string{toComplete + "?database="} + return a, locCompStdDirective + } + + if ploc.name == "" { + a = []string{toComplete + "db"} + for _, dbName := range dbNames { + v := toComplete + dbName + a = append(a, v) + } + a = lo.Uniq(a) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + } + + // We already have a partial dbname + a = []string{toComplete + "?"} + + base := urlz.StripPath(ploc.du.URL, true) + for _, dbName := range dbNames { + a = append(a, base+"/"+dbName) + } + + a = lo.Uniq(a) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + + default: + // We're at plocName (db name is done), so it's on to conn params. + return locCompDoConnParams(ploc.du, hist, drvr, toComplete) + } +} + +// locCompDoSQLite3 completes a location starting with "sqlite3://". +// We have special handling for SQLite, because it's not a generic +// driver URL, but rather sqlite3://FILE/PATH?param=X. +func locCompDoSQLite3(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var ( + ctx = cmd.Context() + log = lg.FromContext(ctx) + ru = run.FromContext(ctx) + drvr driver.SQLDriver + err error + ) + + if err = FinishRunInit(ctx, ru); err != nil { + log.Error("Init run", lga.Err, err) + return nil, cobra.ShellCompDirectiveError + } + + if drvr, err = ru.DriverRegistry.SQLDriverFor(sqlite3.Type); err != nil { + // Shouldn't happen + log.Error("Cannot load driver", lga.Err, err) + return nil, cobra.ShellCompDirectiveError + } + + hist := &locHistory{ + coll: ru.Config.Collection, + typ: sqlite3.Type, + log: log, + } + + du, err := dburl.Parse(toComplete) + if err == nil { + // Check if we're done with the filepath part, and on to conn params? + if du.URL.RawQuery != "" || strings.HasSuffix(toComplete, "?") { + return locCompDoConnParams(du, hist, drvr, toComplete) + } + } + + // Build a list of files. + start := strings.TrimPrefix(toComplete, "sqlite3://") + paths := locCompListFiles(ctx, start) + for i := range paths { + if ioz.IsPathToRegularFile(paths[i]) && paths[i] == start { + paths[i] += "?" + } + + paths[i] = "sqlite3://" + paths[i] + } + + a := hist.locations() + a = append(a, paths...) + a = lo.Uniq(a) + a = stringz.FilterPrefix(toComplete, a...) + a = lo.Without(a, toComplete) + + return a, locCompStdDirective +} + +// locCompDoConnParams completes the query params. For example, given +// a toComplete value "sqlite3://my.db?", the result would include values +// such as "sqlite3://my.db?cache=". +func locCompDoConnParams(du *dburl.URL, hist *locHistory, drvr driver.SQLDriver, toComplete string) ( + []string, cobra.ShellCompDirective, +) { + var ( + a []string + query = du.RawQuery + drvrParamKeys, drvrParams = locCompGetConnParams(drvr) + ) + + pathsWithQueries := hist.pathsWithQueries() + + base := urlz.StripPath(du.URL, true) + if query == "" { + a = stringz.PrefixSlice(pathsWithQueries, base) + v := stringz.PrefixSlice(drvrParamKeys, toComplete) + v = stringz.SuffixSlice(v, "=") + a = append(a, v...) + a = lo.Uniq(a) + a = stringz.FilterPrefix(toComplete, a...) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + } + + for _, pwq := range pathsWithQueries { + if strings.HasPrefix(pwq, du.Path) { + a = append(a, base+pwq) + } + } + a = stringz.FilterPrefix(toComplete, a...) + + actualKeys, err := urlz.QueryParamKeys(query) + if err != nil || len(actualKeys) == 0 { + return nil, cobra.ShellCompDirectiveError + } + + actualValues, err := url.ParseQuery(query) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + elements := strings.Split(query, "&") + + // could be "sslmo", "sslmode", "sslmode=", "sslmode=dis" + lastElement := elements[len(elements)-1] + stump := strings.TrimSuffix(toComplete, lastElement) + before, _, ok := strings.Cut(lastElement, "=") + if !ok { + candidateKeys := stringz.ElementsHavingPrefix(drvrParamKeys, before) + candidateKeys = lo.Reject(candidateKeys, func(candidateKey string, index int) bool { + // We don't want the same candidate to show up twice, so we exclude + // it, but only if it already has a value in the query string. + if slices.Contains(actualKeys, candidateKey) { + vals, ok := actualValues[candidateKey] + if !ok || len(vals) == 0 { + return false + } + + for _, val := range vals { + if val != "" { + return true + } + } + } + + return false + }) + + for i := range candidateKeys { + s := stump + candidateKeys[i] + "=" + a = append(a, s) + } + + a = lo.Uniq(a) + a = lo.Without(a, toComplete) + return a, locCompStdDirective + } + + candidateVals := drvrParams[before] + for i := range candidateVals { + s := stump + before + "=" + candidateVals[i] + a = append(a, s) + } + + if len(candidateVals) == 0 { + lastChar := toComplete[len(toComplete)-1] + switch lastChar { + case '&', '?', '=': + default: + a = append(a, toComplete+"&") + } + } + + a = lo.Uniq(a) + a = stringz.FilterPrefix(toComplete, a...) + if len(a) == 0 { + // If it's an unknown value, append "&" to move + // on to a further query param. + a = []string{toComplete + "&"} + return a, locCompStdDirective + } + + if len(a) == 1 && a[0] == toComplete { + // If it's a completed known value ("sslmode=disable"), + // then append "?" to move on to a further query param. + a[0] += "&" + } + + return a, locCompStdDirective +} + +// locCompGetConnParams returns the driver's connection params. The returned +// keys are sorted appropriately for the driver, and are query encoded. +func locCompGetConnParams(drvr driver.SQLDriver) (keys []string, params map[string][]string) { + ogParams := drvr.ConnParams() + ogKeys := lo.Keys(ogParams) + slices.Sort(ogKeys) + + if drvr.DriverMetadata().Type == sqlserver.Type { + // For SQLServer, the "database" key should come first, because + // it's required. + ogKeys = lo.Without(ogKeys, "database") + ogKeys = append([]string{"database"}, ogKeys...) + } + + keys = make([]string, len(ogKeys)) + params = make(map[string][]string, len(ogParams)) + for i := range ogKeys { + k := url.QueryEscape(ogKeys[i]) + keys[i] = k + params[k] = ogParams[ogKeys[i]] + } + + return keys, params +} + +// locCompDriverPort returns the default port for the driver, as a string, +// or empty string if not applicable. +func locCompDriverPort(drvr driver.SQLDriver) string { + p := drvr.DriverMetadata().DefaultPort + if p <= 0 { + return "" + } + + return strconv.Itoa(p) +} + +// locCompAfterHost returns the next text to show after the host +// part of the URL is complete. +func locCompAfterHost(typ source.DriverType) string { + if typ == sqlserver.Type { + return "?database=" + } + + return "/" +} + +// locCompParseLoc parses a location string. The string can +// be in various stages of construction, e.g. "postgres://user" or +// "postgres://user@locahost/db". The stage is noted in parsedLoc.stageDone. +func locCompParseLoc(loc string) (*parsedLoc, error) { + p := &parsedLoc{loc: loc} + if !stringz.HasAnyPrefix(loc, locSchemes...) { + return p, nil + } + + var ( + s string + ok bool + err error + creds string + ) + + p.stageDone = plocScheme + p.scheme, s, ok = strings.Cut(loc, "://") + p.typ = source.DriverType(p.scheme) + + if s == "" || !ok { + return p, nil + } + + creds, s, ok = strings.Cut(s, "@") + if creds != "" { + // creds can be: + // user:pass + // user: + // user + + var hasColon bool + p.user, p.pass, hasColon = strings.Cut(creds, ":") + if hasColon { + p.stageDone = plocUser + } + } + if !ok { + return p, nil + } + + p.stageDone = plocPass + + // At a minimum, we're at this point: + // postgres:// + + // Next we're looking for user:pass, e.g. + // postgres://alice:huzzah@localhost + + if p.du, err = dburl.Parse(p.loc); err != nil { + return p, errz.Err(err) + } + du := p.du + p.scheme = du.OriginalScheme + if du.User != nil { + p.user = du.User.Username() + p.pass, _ = du.User.Password() + } + p.hostname = du.Hostname() + + if strings.ContainsRune(du.URL.Host, ':') { + p.stageDone = plocHostname + } + + if du.Port() != "" { + p.stageDone = plocHostname + p.port, err = strconv.Atoi(du.Port()) + if err != nil { + p.port = -1 + return p, nil //nolint:nilerr + } + } + + switch p.typ { //nolint:exhaustive + default: + case sqlserver.Type: + var u *url.URL + if u, err = url.ParseRequestURI(loc); err == nil { + var vals url.Values + if vals, err = url.ParseQuery(u.RawQuery); err == nil { + p.name = vals.Get("database") + } + } + + case postgres.Type, mysql.Type: + p.name = strings.TrimPrefix(du.Path, "/") + } + + if strings.HasSuffix(s, "/") || strings.HasSuffix(s, `\?`) || du.URL.Path != "" { + p.stageDone = plocHost + } + + if strings.HasSuffix(s, "?") { + p.stageDone = plocPath + } + + if du.URL.RawQuery != "" { + p.stageDone = plocPath + } + + return p, nil +} + +// parsedLoc is a parsed representation of a driver location URL. +// It can represent partial or fully constructed locations. The stage +// of construction is noted in parsedLoc.stageDone. +type parsedLoc struct { + // loc is the original unparsed location value. + loc string + + // stageDone indicates what stage of construction the location + // string is in. + stageDone plocStage + + // typ is the associated source driver type, which may + // be empty until later determination. + typ source.DriverType + + // scheme is the original location scheme + scheme string + + // user is the username, if applicable. + user string + + // pass is the password, if applicable. + pass string + + // hostname is the hostname, if applicable. + hostname string + + // port is the port number, or 0 if not applicable. + port int + + // name is the database name. + name string + + // du holds the parsed db url. This may be nil. + du *dburl.URL +} + +// plocStage is an enum indicating what stage of construction +// a location string is in. +type plocStage string + +const ( + plocInit plocStage = "" + plocScheme plocStage = "scheme" + plocUser plocStage = "user" + plocPass plocStage = "pass" + plocHostname plocStage = "hostname" + plocHost plocStage = "host" // host is hostname+port, or just hostname + plocPath plocStage = "path" +) + +// locSchemes is the set of built-in (SQL) driver schemes. +var locSchemes = []string{ + "mysql://", + "postgres://", + "sqlite3://", + "sqlserver://", +} + +// locCompListFiles completes filenames. This function tries to +// mimic what a shell would do. Any errors are logged and swallowed. +func locCompListFiles(ctx context.Context, toComplete string) []string { + var ( + start = toComplete + files []string + err error + ) + + if start == "" { + start, err = os.Getwd() + if err != nil { + return nil + } + files, err = ioz.ReadDir(start, false, true, false) + if err != nil { + lg.FromContext(ctx).Warn("Read dir", lga.Path, start, lga.Err, err) + } + + return files + } + + if strings.HasSuffix(start, "/") { + files, err = ioz.ReadDir(start, true, true, false) + if err != nil { + lg.FromContext(ctx).Warn("Read dir", lga.Path, start, lga.Err, err) + } + return files + } + + // We could have a situation like this: + // + [working dir] + // - my.db + // - my/my2.db + + dir := filepath.Dir(start) + fi, err := os.Stat(dir) + if err == nil && fi.IsDir() { + files, err = ioz.ReadDir(dir, true, true, false) + if err != nil { + lg.FromContext(ctx).Warn("Read dir", lga.Path, start, lga.Err, err) + } + } else { + files = []string{start} + } + + return stringz.FilterPrefix(toComplete, files...) +} + +// locHistory provides methods for getting previously used +// elements of a location. +type locHistory struct { + coll *source.Collection + typ source.DriverType + log *slog.Logger +} + +func (h *locHistory) usernames() []string { + var unames []string + + _ = h.coll.Visit(func(src *source.Source) error { + if src.Type != h.typ { + return nil + } + + du, err := dburl.Parse(src.Location) + if err != nil { + // Shouldn't happen + h.log.Warn("Parse source location", lga.Err, err) + return nil + } + + if du.User != nil { + uname := du.User.Username() + if uname != "" { + unames = append(unames, uname) + } + } + + return nil + }) + + unames = lo.Uniq(unames) + slices.Sort(unames) + return unames +} + +func (h *locHistory) hosts() []string { + var hosts []string + + _ = h.coll.Visit(func(src *source.Source) error { + if src.Type != h.typ { + return nil + } + + du, err := dburl.Parse(src.Location) + if err != nil { + // Shouldn't happen + h.log.Warn("Parse source location", lga.Err, err) + return nil + } + + hosts = append(hosts, du.Host) + + return nil + }) + + hosts = lo.Uniq(hosts) + slices.Sort(hosts) + return hosts +} + +func (h *locHistory) databases() []string { + var dbNames []string + + _ = h.coll.Visit(func(src *source.Source) error { + if src.Type != h.typ { + return nil + } + + du, err := dburl.Parse(src.Location) + if err != nil { + // Shouldn't happen + h.log.Warn("Parse source location", lga.Err, err) + return nil + } + + if h.typ == sqlserver.Type && du.RawQuery != "" { + var vals url.Values + if vals, err = url.ParseQuery(du.RawQuery); err == nil { + db := vals.Get("database") + if db != "" { + dbNames = append(dbNames, db) + } + } + return nil + } + + if du.Path != "" { + v := strings.TrimPrefix(du.Path, "/") + dbNames = append(dbNames, v) + } + + return nil + }) + + dbNames = lo.Uniq(dbNames) + slices.Sort(dbNames) + return dbNames +} + +func (h *locHistory) locations() []string { + var locs []string + _ = h.coll.Visit(func(src *source.Source) error { + if src.Type != h.typ { + return nil + } + + locs = append(locs, src.Location) + return nil + }) + + locs = lo.Uniq(locs) + slices.Sort(locs) + return locs +} + +// hostsWithPathAndQuery returns locations, minus the +// scheme and user info. +func (h *locHistory) hostsWithPathAndQuery() []string { + var values []string + + _ = h.coll.Visit(func(src *source.Source) error { + if src.Type != h.typ { + return nil + } + + du, err := dburl.Parse(src.Location) + if err != nil { + // Shouldn't happen + h.log.Warn("Parse source location", lga.Err, err) + return nil + } + + v := urlz.StripSchemeAndUser(du.URL) + if v != "" { + values = append(values, v) + } + return nil + }) + + values = lo.Uniq(values) + slices.Sort(values) + return values +} + +// pathsWithQueries returns the location elements after +// the host, i.e. the path and query. +func (h *locHistory) pathsWithQueries() []string { + var values []string + + _ = h.coll.Visit(func(src *source.Source) error { + if src.Type != h.typ { + return nil + } + + du, err := dburl.Parse(src.Location) + if err != nil { + // Shouldn't happen + h.log.Warn("Parse source location", lga.Err, err) + return nil + } + + v := du.Path + if du.RawQuery != "" { + v += `?` + du.RawQuery + } + + values = append(values, v) + return nil + }) + + values = lo.Uniq(values) + slices.Sort(values) + return values +} diff --git a/cli/complete_location_test.go b/cli/complete_location_test.go new file mode 100644 index 00000000..c39b6d9d --- /dev/null +++ b/cli/complete_location_test.go @@ -0,0 +1,1408 @@ +package cli_test + +import ( + "context" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/neilotoole/sq/drivers/sqlite3" + + "github.com/neilotoole/sq/drivers/sqlserver" + + "github.com/neilotoole/sq/drivers/postgres" + "github.com/neilotoole/sq/libsq/source" + + "github.com/neilotoole/sq/libsq/core/stringz" + + "github.com/samber/lo" + + "github.com/neilotoole/sq/testh" + + "github.com/neilotoole/sq/cli" + + "github.com/stretchr/testify/assert" + + "github.com/neilotoole/sq/cli/cobraz" + + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/lg" + + "github.com/neilotoole/sq/testh/tutil" + + "github.com/stretchr/testify/require" + + "github.com/neilotoole/sq/cli/testrun" + "github.com/spf13/cobra" +) + +var locSchemes = []string{ + "mysql://", + "postgres://", + "sqlite3://", + "sqlserver://", +} + +const stdDirective = cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder + +func TestCompleteAddLocation_Postgres(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + + wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", wd) + + testCases := []struct { + // args will have "add" prepended + args []string + want []string + wantResult cobra.ShellCompDirective + }{ + { + args: []string{""}, + want: lo.Union(locSchemes, []string{ + "data/", "my/", "my.db", "post/", "post.db", "sqlite/", "sqlite.db", + }), + wantResult: stdDirective, + }, + { + args: []string{"p"}, + want: []string{"postgres://", "post/", "post.db"}, + wantResult: stdDirective, + }, + { + args: []string{"postgres:/"}, + want: []string{"postgres://"}, + wantResult: stdDirective, + }, + { + args: []string{"postgres://"}, + want: []string{ + "postgres://username", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice"}, + want: []string{ + "postgres://alice@", + "postgres://alice:", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice:"}, + want: []string{ + "postgres://alice:", + "postgres://alice:@", + "postgres://alice:password@", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@"}, + want: []string{ + "postgres://alice@localhost/", + "postgres://alice@localhost:5432/", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@server"}, + want: []string{ + "postgres://alice@server/", + "postgres://alice@server:5432/", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localho"}, + want: []string{ + "postgres://alice@localho/", + "postgres://alice@localho:5432/", + "postgres://alice@localhost/", + "postgres://alice@localhost:5432/", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost"}, + want: []string{ + "postgres://alice@localhost/", + "postgres://alice@localhost:5432/", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost:"}, + want: []string{ + "postgres://alice@localhost:5432/", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost:80"}, + want: []string{ + "postgres://alice@localhost:80/", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/"}, + want: []string{ + "postgres://alice@localhost/db", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila"}, + want: []string{ + "postgres://alice@localhost/sakila?", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?"}, + want: []string{ + "postgres://alice@localhost/sakila?application_name=", + "postgres://alice@localhost/sakila?channel_binding=", + "postgres://alice@localhost/sakila?connect_timeout=", + "postgres://alice@localhost/sakila?fallback_application_name=", + "postgres://alice@localhost/sakila?gssencmode=", + "postgres://alice@localhost/sakila?sslmode=", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?ss"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?a=1&b=2&ss"}, + want: []string{ + "postgres://alice@localhost/sakila?a=1&b=2&sslmode=", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?a=1&b=2&sslmode"}, + want: []string{ + "postgres://alice@localhost/sakila?a=1&b=2&sslmode=", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?sslmode="}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=disable", + "postgres://alice@localhost/sakila?sslmode=allow", + "postgres://alice@localhost/sakila?sslmode=prefer", + "postgres://alice@localhost/sakila?sslmode=require", + "postgres://alice@localhost/sakila?sslmode=verify-ca", + "postgres://alice@localhost/sakila?sslmode=verify-full", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?sslmode=v"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=verify-ca", + "postgres://alice@localhost/sakila?sslmode=verify-full", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?sslmode=verify-"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=verify-ca", + "postgres://alice@localhost/sakila?sslmode=verify-full", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?sslmode=verify-ful"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=verify-full", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?sslmode=verify-full"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=verify-full&", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?sslmode=verify-full-something"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=verify-full-something&", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@localhost/sakila?sslmode=disable"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=disable&", + }, + wantResult: stdDirective, + }, + { + // Note the extra "?", which apparently is valid + args: []string{"postgres://alice@localhost/sakila?sslmode=disable?"}, + want: []string{"postgres://alice@localhost/sakila?sslmode=disable?&"}, + wantResult: stdDirective, + }, + { + // Being that sslmode is already specified, it should not appear a + // second time. + args: []string{"postgres://alice@localhost/sakila?sslmode=disable&"}, + want: []string{ + "postgres://alice@localhost/sakila?sslmode=disable&application_name=", + "postgres://alice@localhost/sakila?sslmode=disable&channel_binding=", + "postgres://alice@localhost/sakila?sslmode=disable&connect_timeout=", + "postgres://alice@localhost/sakila?sslmode=disable&fallback_application_name=", + "postgres://alice@localhost/sakila?sslmode=disable&gssencmode=", + }, + wantResult: stdDirective, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + args := append([]string{"add"}, tc.args...) + got := testComplete(t, nil, args...) + assert.Equal(t, tc.wantResult, got.result, got.directives) + assert.Equal(t, tc.want, got.values) + }) + } +} + +func TestCompleteAddLocation_SQLServer(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + + wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", wd) + + testCases := []struct { + args []string + want []string + wantResult cobra.ShellCompDirective + }{ + { + args: []string{"sqlse"}, + want: []string{"sqlserver://"}, + wantResult: stdDirective, + }, + + { + args: []string{"sqlserver:/"}, + want: []string{"sqlserver://"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://"}, + want: []string{ + "sqlserver://username", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server"}, + want: []string{ + "sqlserver://alice@server?database=", + "sqlserver://alice@server:1433?database=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server/"}, + want: []string{"sqlserver://alice@server/instance?database="}, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server/instance"}, + want: []string{"sqlserver://alice@server/instance?database="}, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server?"}, + want: []string{ + "sqlserver://alice@server?database=", + "sqlserver://alice@server?ApplicationIntent=", + "sqlserver://alice@server?ServerSPN=", + "sqlserver://alice@server?TrustServerCertificate=", + "sqlserver://alice@server?Workstation+ID=", + "sqlserver://alice@server?app+name=", + "sqlserver://alice@server?certificate=", + "sqlserver://alice@server?connection+timeout=", + "sqlserver://alice@server?dial+timeout=", + "sqlserver://alice@server?encrypt=", + "sqlserver://alice@server?failoverpartner=", + "sqlserver://alice@server?failoverport=", + "sqlserver://alice@server?hostNameInCertificate=", + "sqlserver://alice@server?keepAlive=", + "sqlserver://alice@server?log=", + "sqlserver://alice@server?packet+size=", + "sqlserver://alice@server?protocol=", + "sqlserver://alice@server?tlsmin=", + "sqlserver://alice@server?user+id=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server?da"}, + want: []string{ + "sqlserver://alice@server?database=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server?database"}, + want: []string{ + "sqlserver://alice@server?database=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server?database=sakila"}, + want: []string{ + "sqlserver://alice@server?database=sakila&", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server?database=sakila&tls"}, + want: []string{ + "sqlserver://alice@server?database=sakila&tlsmin=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server?database=sakila&tlsmin"}, + want: []string{ + "sqlserver://alice@server?database=sakila&tlsmin=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@server?database=sakila&tlsmin="}, + want: []string{ + "sqlserver://alice@server?database=sakila&tlsmin=1.0", + "sqlserver://alice@server?database=sakila&tlsmin=1.1", + "sqlserver://alice@server?database=sakila&tlsmin=1.2", + "sqlserver://alice@server?database=sakila&tlsmin=1.3", + }, + wantResult: stdDirective, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + args := append([]string{"add"}, tc.args...) + got := testComplete(t, nil, args...) + assert.Equal(t, tc.wantResult, got.result, got.directives) + assert.Equal(t, tc.want, got.values) + }) + } +} + +func TestCompleteAddLocation_MySQL(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + + wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", wd) + + testCases := []struct { + args []string + want []string + wantResult cobra.ShellCompDirective + }{ + { + args: []string{"m"}, + want: []string{"mysql://", "my/", "my.db"}, + wantResult: stdDirective, + }, + { + args: []string{"my"}, + want: []string{"mysql://", "my/", "my.db"}, + wantResult: stdDirective, + }, + { + // When the input is definitively not a db url, the completion + // switches to the default shell (file) completion. + args: []string{"my/"}, + want: []string{}, + wantResult: cobra.ShellCompDirectiveDefault, + }, + { + args: []string{"mysql"}, + want: []string{"mysql://"}, + wantResult: stdDirective, + }, + { + args: []string{"mysql:/"}, + want: []string{"mysql://"}, + wantResult: stdDirective, + }, + { + args: []string{"mysql://"}, + want: []string{ + "mysql://username", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice"}, + want: []string{ + "mysql://alice@", + "mysql://alice:", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice:"}, + want: []string{ + "mysql://alice:", + "mysql://alice:@", + "mysql://alice:password@", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@"}, + want: []string{ + "mysql://alice@localhost/", + "mysql://alice@localhost:3306/", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@server"}, + want: []string{ + "mysql://alice@server/", + "mysql://alice@server:3306/", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localho"}, + want: []string{ + "mysql://alice@localho/", + "mysql://alice@localho:3306/", + "mysql://alice@localhost/", + "mysql://alice@localhost:3306/", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost"}, + want: []string{ + "mysql://alice@localhost/", + "mysql://alice@localhost:3306/", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost:"}, + want: []string{ + "mysql://alice@localhost:3306/", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost:80"}, + want: []string{ + "mysql://alice@localhost:80/", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/"}, + want: []string{ + "mysql://alice@localhost/db", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila"}, + want: []string{ + "mysql://alice@localhost/sakila?", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?"}, + want: []string{ + "mysql://alice@localhost/sakila?allowAllFiles=", + "mysql://alice@localhost/sakila?allowCleartextPasswords=", + "mysql://alice@localhost/sakila?allowFallbackToPlaintext=", + "mysql://alice@localhost/sakila?allowNativePasswords=", + "mysql://alice@localhost/sakila?allowOldPasswords=", + "mysql://alice@localhost/sakila?charset=", + "mysql://alice@localhost/sakila?checkConnLiveness=", + "mysql://alice@localhost/sakila?clientFoundRows=", + "mysql://alice@localhost/sakila?collation=", + "mysql://alice@localhost/sakila?columnsWithAlias=", + "mysql://alice@localhost/sakila?connectionAttributes=", + "mysql://alice@localhost/sakila?interpolateParams=", + "mysql://alice@localhost/sakila?loc=", + "mysql://alice@localhost/sakila?maxAllowedPackage=", + "mysql://alice@localhost/sakila?multiStatements=", + "mysql://alice@localhost/sakila?parseTime=", + "mysql://alice@localhost/sakila?readTimeout=", + "mysql://alice@localhost/sakila?rejectReadOnly=", + "mysql://alice@localhost/sakila?timeout=", + "mysql://alice@localhost/sakila?tls=", + "mysql://alice@localhost/sakila?writeTimeout=", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?tl"}, + want: []string{ + "mysql://alice@localhost/sakila?tls=", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?a=1&b=2&tl"}, + want: []string{ + "mysql://alice@localhost/sakila?a=1&b=2&tls=", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?a=1&b=2&tls"}, + want: []string{ + "mysql://alice@localhost/sakila?a=1&b=2&tls=", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?tls="}, + want: []string{ + "mysql://alice@localhost/sakila?tls=false", + "mysql://alice@localhost/sakila?tls=true", + "mysql://alice@localhost/sakila?tls=skip-verify", + "mysql://alice@localhost/sakila?tls=preferred", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?tls=s"}, + want: []string{ + "mysql://alice@localhost/sakila?tls=skip-verify", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?tls=skip-verify"}, + want: []string{ + "mysql://alice@localhost/sakila?tls=skip-verify&", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?tls=skip-verify&lo"}, + want: []string{ + "mysql://alice@localhost/sakila?tls=skip-verify&loc=", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?tls=skip-verify&loc="}, + want: []string{ + "mysql://alice@localhost/sakila?tls=skip-verify&loc=UTC", + }, + wantResult: stdDirective, + }, + { + args: []string{"mysql://alice@localhost/sakila?tls=skip-verify&loc=UTC"}, + want: []string{ + "mysql://alice@localhost/sakila?tls=skip-verify&loc=UTC&", + }, + wantResult: stdDirective, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + args := append([]string{"add"}, tc.args...) + got := testComplete(t, nil, args...) + assert.Equal(t, tc.wantResult, got.result, got.directives) + assert.Equal(t, tc.want, got.values) + }) + } +} + +func TestCompleteAddLocation_SQLite3(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + + wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", wd) + + testCases := []struct { + args []string + want []string + wantResult cobra.ShellCompDirective + }{ + { + args: []string{"s"}, + want: []string{"sqlite3://", "sqlserver://", "sqlite/", "sqlite.db"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite"}, + want: []string{"sqlite3://", "sqlite/", "sqlite.db"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite/"}, + want: []string{}, + wantResult: cobra.ShellCompDirectiveDefault, + }, + { + args: []string{"my/my_"}, + want: []string{}, + wantResult: cobra.ShellCompDirectiveDefault, + }, + { + args: []string{"sqlite3:"}, + want: []string{"sqlite3://"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3:/"}, + want: []string{"sqlite3://"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://"}, + want: []string{ + "sqlite3://data/", + "sqlite3://my/", + "sqlite3://my.db", + "sqlite3://post/", + "sqlite3://post.db", + "sqlite3://sqlite/", + "sqlite3://sqlite.db", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my"}, + want: []string{ + "sqlite3://my/", + "sqlite3://my.db", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.d"}, + want: []string{"sqlite3://my.db"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db"}, + want: []string{"sqlite3://my.db?"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://data/nest1/data.db"}, + want: []string{"sqlite3://data/nest1/data.db?", "sqlite3://data/nest1/data.db2"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?"}, + want: []string{ + "sqlite3://my.db?_auth=", + "sqlite3://my.db?_auth_crypt=", + "sqlite3://my.db?_auth_pass=", + "sqlite3://my.db?_auth_salt=", + "sqlite3://my.db?_auth_user=", + "sqlite3://my.db?_auto_vacuum=", + "sqlite3://my.db?_busy_timeout=", + "sqlite3://my.db?_cache_size=", + "sqlite3://my.db?_case_sensitive_like=", + "sqlite3://my.db?_defer_foreign_keys=", + "sqlite3://my.db?_foreign_keys=", + "sqlite3://my.db?_ignore_check_constraints=", + "sqlite3://my.db?_journal_mode=", + "sqlite3://my.db?_loc=", + "sqlite3://my.db?_locking_mode=", + "sqlite3://my.db?_mutex=", + "sqlite3://my.db?_query_only=", + "sqlite3://my.db?_recursive_triggers=", + "sqlite3://my.db?_secure_delete=", + "sqlite3://my.db?_synchronous=", + "sqlite3://my.db?_txlock=", + "sqlite3://my.db?cache=", + "sqlite3://my.db?mode=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?_locking_"}, + want: []string{"sqlite3://my.db?_locking_mode="}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?_locking_mode="}, + want: []string{ + "sqlite3://my.db?_locking_mode=NORMAL", + "sqlite3://my.db?_locking_mode=EXCLUSIVE", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?_locking_mode=NORM"}, + want: []string{"sqlite3://my.db?_locking_mode=NORMAL"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?_locking_mode=NORMAL"}, + want: []string{"sqlite3://my.db?_locking_mode=NORMAL&"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?_locking_mode=NORMAL"}, + want: []string{"sqlite3://my.db?_locking_mode=NORMAL&"}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?_locking_mode=NORMAL&ca"}, + want: []string{"sqlite3://my.db?_locking_mode=NORMAL&cache="}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my.db?_locking_mode=NORMAL&cache="}, + want: []string{ + "sqlite3://my.db?_locking_mode=NORMAL&cache=true", + "sqlite3://my.db?_locking_mode=NORMAL&cache=false", + "sqlite3://my.db?_locking_mode=NORMAL&cache=FAST", + }, + wantResult: stdDirective, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + args := append([]string{"add"}, tc.args...) + got := testComplete(t, nil, args...) + assert.Equal(t, tc.wantResult, got.result, got.directives) + assert.Equal(t, tc.want, got.values) + }) + } +} + +func testComplete(t testing.TB, from *testrun.TestRun, args ...string) completion { + ctx := lg.NewContext(context.Background(), slogt.New(t)) + + tr := testrun.New(ctx, t, from) + args = append([]string{"__complete"}, args...) + + err := tr.Exec(args...) + require.NoError(t, err) + + c := parseCompletion(tr) + return c +} + +// parseCompletion parses the output of cobra "__complete". +// Example output: +// +// @active +// @sakila +// :4 +// Completion ended with directive: ShellCompDirectiveNoFileComp +// +// The tr.T test will fail on any error. +func parseCompletion(tr *testrun.TestRun) completion { + c := completion{ + stdout: tr.Out.String(), + stderr: tr.ErrOut.String(), + } + + lines := strings.Split(strings.TrimSpace(c.stdout), "\n") + require.True(tr.T, len(lines) >= 1) + c.values = lines[:len(lines)-1] + + result, err := strconv.Atoi(lines[len(lines)-1][1:]) + require.NoError(tr.T, err) + c.result = cobra.ShellCompDirective(result) + + c.directives = cobraz.ParseDirectivesLine(c.stderr) + return c +} + +// completion models the result returned from the cobra "__complete" command. +type completion struct { + stdout string + stderr string + values []string + result cobra.ShellCompDirective + directives []cobra.ShellCompDirective +} + +func TestCompleteAddLocation_History_Postgres(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", wd) + + th := testh.New(t) + tr := testrun.New(th.Context, t, nil) + + tr.Add( + source.Source{ + Handle: "@src1", + Type: postgres.Type, + Location: "postgres://alice:abc123@dev.acme.com:7777/sakila?application_name=app1&channel_binding=prefer", + }, + source.Source{ + Handle: "@src2", + Type: postgres.Type, + Location: "postgres://bob:abc123@prod.acme.com:8888/sales?application_name=app2&channel_binding=require", + }, + ) + + testCases := []struct { + args []string + want []string + wantResult cobra.ShellCompDirective + }{ + { + args: []string{"postgres://"}, + want: []string{ + "postgres://alice", + "postgres://bob", + "postgres://username", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://a"}, + want: []string{ + "postgres://a@", + "postgres://a:", + "postgres://alice@", + "postgres://alice:", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice"}, + want: []string{ + "postgres://alice@", + "postgres://alice:", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@"}, + want: []string{ + "postgres://alice@localhost/", + "postgres://alice@localhost:5432/", + "postgres://alice@dev.acme.com:7777/", + "postgres://alice@dev.acme.com:7777/sakila?application_name=app1&channel_binding=prefer", + "postgres://alice@prod.acme.com:8888/", + "postgres://alice@prod.acme.com:8888/sales?application_name=app2&channel_binding=require", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@dev"}, + want: []string{ + "postgres://alice@dev/", + "postgres://alice@dev:5432/", + "postgres://alice@dev.acme.com:7777/", + "postgres://alice@dev.acme.com:7777/sakila?application_name=app1&channel_binding=prefer", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@dev.acme.com"}, + want: []string{ + "postgres://alice@dev.acme.com/", + "postgres://alice@dev.acme.com:5432/", + "postgres://alice@dev.acme.com:7777/", + "postgres://alice@dev.acme.com:7777/sakila?application_name=app1&channel_binding=prefer", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@dev.acme.com/"}, + want: []string{ + "postgres://alice@dev.acme.com/db", + "postgres://alice@dev.acme.com/sakila", + "postgres://alice@dev.acme.com/sales", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@dev.acme.com/sa"}, + want: []string{ + "postgres://alice@dev.acme.com/sa?", + "postgres://alice@dev.acme.com/sakila", + "postgres://alice@dev.acme.com/sales", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@dev.acme.com/sakila?"}, + want: []string{ + "postgres://alice@dev.acme.com/sakila?application_name=app1&channel_binding=prefer", + "postgres://alice@dev.acme.com/sakila?application_name=", + "postgres://alice@dev.acme.com/sakila?channel_binding=", + "postgres://alice@dev.acme.com/sakila?connect_timeout=", + "postgres://alice@dev.acme.com/sakila?fallback_application_name=", + "postgres://alice@dev.acme.com/sakila?gssencmode=", + "postgres://alice@dev.acme.com/sakila?sslmode=", + }, + wantResult: stdDirective, + }, + { + args: []string{"postgres://alice@dev.acme.com/sakila?app"}, + want: []string{ + "postgres://alice@dev.acme.com/sakila?application_name=app1&channel_binding=prefer", + "postgres://alice@dev.acme.com/sakila?application_name=", + }, + wantResult: stdDirective, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + args := append([]string{"add"}, tc.args...) + got := testComplete(t, tr, args...) + assert.Equal(t, tc.wantResult, got.result, got.directives) + assert.Equal(t, tc.want, got.values) + }) + } +} + +func TestCompleteAddLocation_History_SQLServer(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", wd) + + th := testh.New(t) + tr := testrun.New(th.Context, t, nil) + + tr.Add( + source.Source{ + Handle: "@src1", + Type: sqlserver.Type, + Location: "sqlserver://alice:abc123@dev.acme.com:7777?database=sakila&app+name=app1&encrypt=disable", + }, + source.Source{ + Handle: "@src2", + Type: sqlserver.Type, + Location: "sqlserver://bob:abc123@prod.acme.com:8888?database=sales&app+name=app2&encrypt=true", + }, + source.Source{ + Handle: "@src3", + Type: sqlserver.Type, + Location: "sqlserver://bob:abc123@prod.acme.com:8888/my_instance?database=sakila", + }, + ) + + testCases := []struct { + args []string + want []string + wantResult cobra.ShellCompDirective + }{ + { + args: []string{"sqlserver://"}, + want: []string{ + "sqlserver://alice", + "sqlserver://bob", + "sqlserver://username", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://a"}, + want: []string{ + "sqlserver://a@", + "sqlserver://a:", + "sqlserver://alice@", + "sqlserver://alice:", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice"}, + want: []string{ + "sqlserver://alice@", + "sqlserver://alice:", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@"}, + want: []string{ + "sqlserver://alice@localhost?database=", + "sqlserver://alice@localhost:1433?database=", + "sqlserver://alice@dev.acme.com:7777?database=", + "sqlserver://alice@dev.acme.com:7777?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@prod.acme.com:8888/my_instance?database=sakila", + "sqlserver://alice@prod.acme.com:8888?database=", + "sqlserver://alice@prod.acme.com:8888?database=sales&app+name=app2&encrypt=true", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev"}, + want: []string{ + "sqlserver://alice@dev?database=", + "sqlserver://alice@dev:1433?database=", + "sqlserver://alice@dev.acme.com:7777?database=", + "sqlserver://alice@dev.acme.com:7777?database=sakila&app+name=app1&encrypt=disable", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@prod"}, + want: []string{ + "sqlserver://alice@prod?database=", + "sqlserver://alice@prod:1433?database=", + "sqlserver://alice@prod.acme.com:8888/my_instance?database=sakila", + "sqlserver://alice@prod.acme.com:8888?database=", + "sqlserver://alice@prod.acme.com:8888?database=sales&app+name=app2&encrypt=true", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=", + "sqlserver://alice@dev.acme.com:1433?database=", + "sqlserver://alice@dev.acme.com:7777?database=", + "sqlserver://alice@dev.acme.com:7777?database=sakila&app+name=app1&encrypt=disable", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=sales&app+name=app2&encrypt=true", + "sqlserver://alice@dev.acme.com?database=", + "sqlserver://alice@dev.acme.com?ApplicationIntent=", + "sqlserver://alice@dev.acme.com?ServerSPN=", + "sqlserver://alice@dev.acme.com?TrustServerCertificate=", + "sqlserver://alice@dev.acme.com?Workstation+ID=", + "sqlserver://alice@dev.acme.com?app+name=", + "sqlserver://alice@dev.acme.com?certificate=", + "sqlserver://alice@dev.acme.com?connection+timeout=", + "sqlserver://alice@dev.acme.com?dial+timeout=", + "sqlserver://alice@dev.acme.com?encrypt=", + "sqlserver://alice@dev.acme.com?failoverpartner=", + "sqlserver://alice@dev.acme.com?failoverport=", + "sqlserver://alice@dev.acme.com?hostNameInCertificate=", + "sqlserver://alice@dev.acme.com?keepAlive=", + "sqlserver://alice@dev.acme.com?log=", + "sqlserver://alice@dev.acme.com?packet+size=", + "sqlserver://alice@dev.acme.com?protocol=", + "sqlserver://alice@dev.acme.com?tlsmin=", + "sqlserver://alice@dev.acme.com?user+id=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?data"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=sales&app+name=app2&encrypt=true", + "sqlserver://alice@dev.acme.com?database=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?database"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=sales&app+name=app2&encrypt=true", + "sqlserver://alice@dev.acme.com?database=", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?database="}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=sales&app+name=app2&encrypt=true", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?database=sa"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=sales&app+name=app2&encrypt=true", + "sqlserver://alice@dev.acme.com?database=sa&", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?database=saki"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=saki&", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?database=sakila"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=sakila&", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlserver://alice@dev.acme.com?database=sakila&app"}, + want: []string{ + "sqlserver://alice@dev.acme.com?database=sakila&app+name=app1&encrypt=disable", + "sqlserver://alice@dev.acme.com?database=sakila&app+name=", + }, + wantResult: stdDirective, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + args := append([]string{"add"}, tc.args...) + got := testComplete(t, tr, args...) + assert.Equal(t, tc.wantResult, got.result, got.directives) + assert.Equal(t, tc.want, got.values) + }) + } +} + +func TestCompleteAddLocation_History_SQLite3(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", wd) + src3Loc := "sqlite3://" + wd + "/my.db?cache=FAST" + + th := testh.New(t) + tr := testrun.New(th.Context, t, nil) + tr.Add( + source.Source{ + Handle: "@src2", + Type: sqlite3.Type, + // Note that this file doesn't actually exist + Location: "sqlite3:///zz_dir1/sqtest/sq/src2.db?mode=rwc&cache=FAST", + }, + source.Source{ + Handle: "@src1", + Type: sqlite3.Type, + // Note that this file doesn't actually exist + Location: "sqlite3:///zz_dir1/sqtest/sq/src1.db", + }, + source.Source{ + Handle: "@src3", + Type: sqlite3.Type, + // This file DOES exist + Location: src3Loc, + }, + ) + + testCases := []struct { + args []string + want []string + wantResult cobra.ShellCompDirective + }{ + { + args: []string{"sqlite3://"}, + want: []string{ + src3Loc, + "sqlite3:///zz_dir1/sqtest/sq/src1.db", + "sqlite3:///zz_dir1/sqtest/sq/src2.db?mode=rwc&cache=FAST", + "sqlite3://data/", + "sqlite3://my/", + "sqlite3://my.db", + "sqlite3://post/", + "sqlite3://post.db", + "sqlite3://sqlite/", + "sqlite3://sqlite.db", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3://my"}, + want: []string{ + "sqlite3://my/", + "sqlite3://my.db", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3:///zz_dir1/sqtest/"}, + want: []string{ + "sqlite3:///zz_dir1/sqtest/sq/src1.db", + "sqlite3:///zz_dir1/sqtest/sq/src2.db?mode=rwc&cache=FAST", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3:///zz_dir1/sqtest/sq/not_a_dir"}, + want: []string{}, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3:///zz_dir1/sqtest/sq/src"}, + want: []string{ + "sqlite3:///zz_dir1/sqtest/sq/src1.db", + "sqlite3:///zz_dir1/sqtest/sq/src2.db?mode=rwc&cache=FAST", + }, + wantResult: stdDirective, + }, + { + args: []string{"sqlite3:///zz_dir1/sqtest/sq/src1.db"}, + want: []string{}, // Empty because file doesn't actually exist + wantResult: stdDirective, + }, + { + args: []string{src3Loc}, + want: []string{src3Loc + "&"}, + wantResult: stdDirective, + }, + { + args: []string{src3Loc + "&"}, + want: []string{ + src3Loc + "&_auth=", + src3Loc + "&_auth_crypt=", + src3Loc + "&_auth_pass=", + src3Loc + "&_auth_salt=", + src3Loc + "&_auth_user=", + src3Loc + "&_auto_vacuum=", + src3Loc + "&_busy_timeout=", + src3Loc + "&_cache_size=", + src3Loc + "&_case_sensitive_like=", + src3Loc + "&_defer_foreign_keys=", + src3Loc + "&_foreign_keys=", + src3Loc + "&_ignore_check_constraints=", + src3Loc + "&_journal_mode=", + src3Loc + "&_loc=", + src3Loc + "&_locking_mode=", + src3Loc + "&_mutex=", + src3Loc + "&_query_only=", + src3Loc + "&_recursive_triggers=", + src3Loc + "&_secure_delete=", + src3Loc + "&_synchronous=", + src3Loc + "&_txlock=", + src3Loc + "&mode=", + }, + wantResult: stdDirective, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + args := append([]string{"add"}, tc.args...) + got := testComplete(t, tr, args...) + assert.Equal(t, tc.wantResult, got.result, got.directives) + assert.Equal(t, tc.want, got.values) + }) + } +} + +func TestParseLoc_stage(t *testing.T) { + testCases := []struct { + loc string + want cli.PlocStage + }{ + {"", cli.PlocInit}, + {"postgres", cli.PlocInit}, + {"postgres:/", cli.PlocInit}, + {"postgres://", cli.PlocScheme}, + {"postgres://alice", cli.PlocScheme}, + {"postgres://alice:", cli.PlocUser}, + {"postgres://alice:pass", cli.PlocUser}, + {"postgres://alice:pass@", cli.PlocPass}, + {"postgres://alice:@", cli.PlocPass}, + {"postgres://alice@", cli.PlocPass}, + {"postgres://alice@localhost", cli.PlocPass}, + {"postgres://alice:@localhost", cli.PlocPass}, + {"postgres://alice:pass@localhost", cli.PlocPass}, + {"postgres://alice@localhost:", cli.PlocHostname}, + {"postgres://alice:@localhost:", cli.PlocHostname}, + {"postgres://alice:pass@localhost:", cli.PlocHostname}, + {"postgres://alice@localhost:5432", cli.PlocHostname}, + {"postgres://alice@localhost:5432/", cli.PlocHost}, + {"postgres://alice@localhost:5432/s", cli.PlocHost}, + {"postgres://alice@localhost:5432/sakila", cli.PlocHost}, + {"postgres://alice@localhost:5432/sakila?", cli.PlocPath}, + {"postgres://alice@localhost:5432/sakila?sslmode=verify-ca", cli.PlocPath}, + {"postgres://alice:@localhost:5432/sakila?sslmode=verify-ca", cli.PlocPath}, + {"postgres://alice:pass@localhost:5432/sakila?sslmode=verify-ca", cli.PlocPath}, + {"sqlserver://alice:pass@localhost?", cli.PlocPath}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc.loc), func(t *testing.T) { + th := testh.New(t) + ru := th.Run() + + gotStage, err := cli.DoTestParseLocStage(t, ru, tc.loc) + require.NoError(t, err) + require.Equal(t, tc.want, gotStage) + }) + } +} + +func TestDoCompleteAddLocationFile(t *testing.T) { + tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + + absDir := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + t.Logf("Working dir: %s", absDir) + + testCases := []struct { + in string + want []string + }{ + {"", []string{"data/", "my/", "my.db", "post/", "post.db", "sqlite/", "sqlite.db"}}, + {"m", []string{"my/", "my.db"}}, + {"my", []string{"my/", "my.db"}}, + {"my/", []string{"my/my1.db", "my/my_nest/"}}, + {"my/my", []string{"my/my1.db", "my/my_nest/"}}, + {"my/my1", []string{"my/my1.db"}}, + {"my/my1.db", []string{"my/my1.db"}}, + {"my/my_nes", []string{"my/my_nest/"}}, + {"my/my_nest", []string{"my/my_nest/"}}, + {"my/my_nest/", []string{"my/my_nest/my2.db"}}, + {"data/nest1/", []string{"data/nest1/data.db", "data/nest1/data.db2"}}, + {"data/nest1/data.db", []string{"data/nest1/data.db", "data/nest1/data.db2"}}, + { + absDir + "/", + stringz.PrefixSlice([]string{ + "data/", "my/", "my.db", "post/", "post.db", "sqlite/", "sqlite.db", + }, absDir+"/"), + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + ctx := lg.NewContext(context.Background(), slogt.New(t)) + t.Logf("input: %s", tc.in) + t.Logf("want: %s", tc.want) + got := cli.DoCompleteAddLocationFile(ctx, tc.in) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/cli/options.go b/cli/options.go index b20531a2..6b68eec2 100644 --- a/cli/options.go +++ b/cli/options.go @@ -226,7 +226,7 @@ func addOptionFlag(flags *pflag.FlagSet, opt options.Opt) (key string) { } flags.BoolP(key, string(opt.Short()), opt.Default(), opt.Usage()) - return + return key case options.Duration: if opt.Short() == 0 { flags.Duration(key, opt.Default(), opt.Usage()) diff --git a/cli/run.go b/cli/run.go index e050c15b..7c8ef5ad 100644 --- a/cli/run.go +++ b/cli/run.go @@ -116,59 +116,18 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args return ru, log, nil } -// preRun is invoked by cobra prior to the command's RunE being -// invoked. It sets up the driver registry, databases, writers and related -// fundamental components. Subsequent invocations of this method -// are no-op. -func preRun(cmd *cobra.Command, ru *run.Run) error { - if ru == nil { - return errz.New("Run is nil") - } - - if ru.Writers != nil { - // If ru.Writers is already set, then this function has already been - // called on ru. That's ok, just return. - return nil - } - +// FinishRunInit finishes setting up ru. +// +// TODO: This run.Run initialization mechanism is a bit of a mess. +// There's logic in newRun, preRun, FinishRunInit, as well as testh.Helper.init. +// Surely the init logic can be consolidated. +func FinishRunInit(ctx context.Context, ru *run.Run) error { if ru.Cleanup == nil { ru.Cleanup = cleanup.New() } - ctx := cmd.Context() cfg, log := ru.Config, lg.FromContext(ctx) - // If the --output=/some/file flag is set, then we need to - // override ru.Out (which is typically stdout) to point it at - // the output destination file. - if cmdFlagChanged(ru.Cmd, flag.Output) { - fpath, _ := ru.Cmd.Flags().GetString(flag.Output) - fpath, err := filepath.Abs(fpath) - if err != nil { - return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output) - } - - // Ensure the parent dir exists - err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm) - if err != nil { - return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output) - } - - f, err := os.Create(fpath) - if err != nil { - return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output) - } - - ru.Cleanup.AddC(f) // Make sure the file gets closed eventually - ru.Out = f - } - - cmdOpts, err := getOptionsFromCmd(ru.Cmd) - if err != nil { - return err - } - ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, cmdOpts, ru.Out, ru.ErrOut) - var scratchSrcFunc driver.ScratchSrcFunc // scratchSrc could be nil, and that's ok @@ -181,10 +140,13 @@ func preRun(cmd *cobra.Command, ru *run.Run) error { } } - ru.Files, err = source.NewFiles(ctx) - if err != nil { - lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) - return err + var err error + if ru.Files == nil { + ru.Files, err = source.NewFiles(ctx) + if err != nil { + lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) + return err + } } // Note: it's important that files.Close is invoked @@ -260,3 +222,57 @@ func preRun(cmd *cobra.Command, ru *run.Run) error { return nil } + +// preRun is invoked by cobra prior to the command's RunE being +// invoked. It sets up the driver registry, databases, writers and related +// fundamental components. Subsequent invocations of this method +// are no-op. +func preRun(cmd *cobra.Command, ru *run.Run) error { + if ru == nil { + return errz.New("Run is nil") + } + + if ru.Writers != nil { + // If ru.Writers is already set, then this function has already been + // called on ru. That's ok, just return. + return nil + } + + if ru.Cleanup == nil { + ru.Cleanup = cleanup.New() + } + + ctx := cmd.Context() + // If the --output=/some/file flag is set, then we need to + // override ru.Out (which is typically stdout) to point it at + // the output destination file. + if cmdFlagChanged(ru.Cmd, flag.Output) { + fpath, _ := ru.Cmd.Flags().GetString(flag.Output) + fpath, err := filepath.Abs(fpath) + if err != nil { + return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output) + } + + // Ensure the parent dir exists + err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm) + if err != nil { + return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output) + } + + f, err := os.Create(fpath) + if err != nil { + return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output) + } + + ru.Cleanup.AddC(f) // Make sure the file gets closed eventually + ru.Out = f + } + + cmdOpts, err := getOptionsFromCmd(ru.Cmd) + if err != nil { + return err + } + ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, cmdOpts, ru.Out, ru.ErrOut) + + return FinishRunInit(ctx, ru) +} diff --git a/cli/testdata/add_location/data/nest1/data.db b/cli/testdata/add_location/data/nest1/data.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/data/nest1/data.db2 b/cli/testdata/add_location/data/nest1/data.db2 new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/my.db b/cli/testdata/add_location/my.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/my/my1.db b/cli/testdata/add_location/my/my1.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/my/my_nest/my2.db b/cli/testdata/add_location/my/my_nest/my2.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/post.db b/cli/testdata/add_location/post.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/post/post2.db b/cli/testdata/add_location/post/post2.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/sqlite.db b/cli/testdata/add_location/sqlite.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testdata/add_location/sqlite/sqlite2.db b/cli/testdata/add_location/sqlite/sqlite2.db new file mode 100644 index 00000000..e69de29b diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 8b713bb0..3a3cbf79 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -26,13 +26,13 @@ import ( // TestRun is a helper for testing sq commands. type TestRun struct { - T *testing.T - ctx context.Context - mu sync.Mutex - Run *run.Run - Out *bytes.Buffer - ErrOut *bytes.Buffer - used bool + T testing.TB + Context context.Context + mu sync.Mutex + Run *run.Run + Out *bytes.Buffer + ErrOut *bytes.Buffer + used bool // When true, out and errOut are not logged. hushOutput bool @@ -41,8 +41,12 @@ type TestRun struct { // New returns a new run instance for testing sq commands. // If from is non-nil, its config is used. This allows sequential // commands to use the same config. -func New(ctx context.Context, t *testing.T, from *TestRun) *TestRun { - tr := &TestRun{T: t, ctx: ctx} +func New(ctx context.Context, t testing.TB, from *TestRun) *TestRun { + if ctx == nil { + ctx = context.Background() + } + + tr := &TestRun{T: t, Context: ctx} var cfgStore config.Store if from != nil { @@ -92,12 +96,15 @@ func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.R OptionsRegistry: optsReg, } + require.NoError(t, cli.FinishRunInit(ctx, ru)) return ru, out, errOut } // Add adds srcs to tr.Run.Config.Collection. If the collection // does not already have an active source, the first element // of srcs is used as the active source. +// +// REVISIT: Why not use *source.Source instead of the value? func (tr *TestRun) Add(srcs ...source.Source) *TestRun { tr.mu.Lock() defer tr.mu.Unlock() @@ -119,6 +126,9 @@ func (tr *TestRun) Add(srcs ...source.Source) *TestRun { require.NoError(tr.T, err) } + err := tr.Run.ConfigStore.Save(tr.Context, tr.Run.Config) + require.NoError(tr.T, err) + return tr } @@ -141,7 +151,7 @@ func (tr *TestRun) doExec(args []string) error { require.False(tr.T, tr.used, "TestRun instance must only be used once") - ctx, cancelFn := context.WithCancel(context.Background()) + ctx, cancelFn := context.WithCancel(tr.Context) tr.T.Cleanup(cancelFn) execErr := cli.ExecuteWith(ctx, tr.Run, args) diff --git a/drivers/mysql/collation.go b/drivers/mysql/collation.go new file mode 100644 index 00000000..fbdceddd --- /dev/null +++ b/drivers/mysql/collation.go @@ -0,0 +1,276 @@ +package mysql + +var collations = []string{ + "armscii8_bin", + "armscii8_general_ci", + "ascii_bin", + "ascii_general_ci", + "big5_bin", + "big5_chinese_ci", + "binary", + "cp1250_bin", + "cp1250_croatian_ci", + "cp1250_czech_cs", + "cp1250_general_ci", + "cp1250_polish_ci", + "cp1251_bin", + "cp1251_bulgarian_ci", + "cp1251_general_ci", + "cp1251_general_cs", + "cp1251_ukrainian_ci", + "cp1256_bin", + "cp1256_general_ci", + "cp1257_bin", + "cp1257_general_ci", + "cp1257_lithuanian_ci", + "cp850_bin", + "cp850_general_ci", + "cp852_bin", + "cp852_general_ci", + "cp866_bin", + "cp866_general_ci", + "cp932_bin", + "cp932_japanese_ci", + "dec8_bin", + "dec8_swedish_ci", + "eucjpms_bin", + "eucjpms_japanese_ci", + "euckr_bin", + "euckr_korean_ci", + "gb18030_bin", + "gb18030_chinese_ci", + "gb18030_unicode_520_ci", + "gb2312_bin", + "gb2312_chinese_ci", + "gbk_bin", + "gbk_chinese_ci", + "geostd8_bin", + "geostd8_general_ci", + "greek_bin", + "greek_general_ci", + "hebrew_bin", + "hebrew_general_ci", + "hp8_bin", + "hp8_english_ci", + "keybcs2_bin", + "keybcs2_general_ci", + "koi8r_bin", + "koi8r_general_ci", + "koi8u_bin", + "koi8u_general_ci", + "latin1_bin", + "latin1_danish_ci", + "latin1_general_ci", + "latin1_general_cs", + "latin1_german1_ci", + "latin1_german2_ci", + "latin1_spanish_ci", + "latin1_swedish_ci", + "latin2_bin", + "latin2_croatian_ci", + "latin2_czech_cs", + "latin2_general_ci", + "latin2_hungarian_ci", + "latin5_bin", + "latin5_turkish_ci", + "latin7_bin", + "latin7_estonian_cs", + "latin7_general_ci", + "latin7_general_cs", + "macce_bin", + "macce_general_ci", + "macroman_bin", + "macroman_general_ci", + "sjis_bin", + "sjis_japanese_ci", + "swe7_bin", + "swe7_swedish_ci", + "tis620_bin", + "tis620_thai_ci", + "ucs2_bin", + "ucs2_croatian_ci", + "ucs2_czech_ci", + "ucs2_danish_ci", + "ucs2_esperanto_ci", + "ucs2_estonian_ci", + "ucs2_general_ci", + "ucs2_general_mysql500_ci", + "ucs2_german2_ci", + "ucs2_hungarian_ci", + "ucs2_icelandic_ci", + "ucs2_latvian_ci", + "ucs2_lithuanian_ci", + "ucs2_persian_ci", + "ucs2_polish_ci", + "ucs2_romanian_ci", + "ucs2_roman_ci", + "ucs2_sinhala_ci", + "ucs2_slovak_ci", + "ucs2_slovenian_ci", + "ucs2_spanish2_ci", + "ucs2_spanish_ci", + "ucs2_swedish_ci", + "ucs2_turkish_ci", + "ucs2_unicode_520_ci", + "ucs2_unicode_ci", + "ucs2_vietnamese_ci", + "ujis_bin", + "ujis_japanese_ci", + "utf16le_bin", + "utf16le_general_ci", + "utf16_bin", + "utf16_croatian_ci", + "utf16_czech_ci", + "utf16_danish_ci", + "utf16_esperanto_ci", + "utf16_estonian_ci", + "utf16_general_ci", + "utf16_german2_ci", + "utf16_hungarian_ci", + "utf16_icelandic_ci", + "utf16_latvian_ci", + "utf16_lithuanian_ci", + "utf16_persian_ci", + "utf16_polish_ci", + "utf16_romanian_ci", + "utf16_roman_ci", + "utf16_sinhala_ci", + "utf16_slovak_ci", + "utf16_slovenian_ci", + "utf16_spanish2_ci", + "utf16_spanish_ci", + "utf16_swedish_ci", + "utf16_turkish_ci", + "utf16_unicode_520_ci", + "utf16_unicode_ci", + "utf16_vietnamese_ci", + "utf32_bin", + "utf32_croatian_ci", + "utf32_czech_ci", + "utf32_danish_ci", + "utf32_esperanto_ci", + "utf32_estonian_ci", + "utf32_general_ci", + "utf32_german2_ci", + "utf32_hungarian_ci", + "utf32_icelandic_ci", + "utf32_latvian_ci", + "utf32_lithuanian_ci", + "utf32_persian_ci", + "utf32_polish_ci", + "utf32_romanian_ci", + "utf32_roman_ci", + "utf32_sinhala_ci", + "utf32_slovak_ci", + "utf32_slovenian_ci", + "utf32_spanish2_ci", + "utf32_spanish_ci", + "utf32_swedish_ci", + "utf32_turkish_ci", + "utf32_unicode_520_ci", + "utf32_unicode_ci", + "utf32_vietnamese_ci", + "utf8mb4_0900_ai_ci", + "utf8mb4_0900_as_ci", + "utf8mb4_0900_as_cs", + "utf8mb4_0900_bin", + "utf8mb4_bin", + "utf8mb4_croatian_ci", + "utf8mb4_cs_0900_ai_ci", + "utf8mb4_cs_0900_as_cs", + "utf8mb4_czech_ci", + "utf8mb4_danish_ci", + "utf8mb4_da_0900_ai_ci", + "utf8mb4_da_0900_as_cs", + "utf8mb4_de_pb_0900_ai_ci", + "utf8mb4_de_pb_0900_as_cs", + "utf8mb4_eo_0900_ai_ci", + "utf8mb4_eo_0900_as_cs", + "utf8mb4_esperanto_ci", + "utf8mb4_estonian_ci", + "utf8mb4_es_0900_ai_ci", + "utf8mb4_es_0900_as_cs", + "utf8mb4_es_trad_0900_ai_ci", + "utf8mb4_es_trad_0900_as_cs", + "utf8mb4_et_0900_ai_ci", + "utf8mb4_et_0900_as_cs", + "utf8mb4_general_ci", + "utf8mb4_german2_ci", + "utf8mb4_hr_0900_ai_ci", + "utf8mb4_hr_0900_as_cs", + "utf8mb4_hungarian_ci", + "utf8mb4_hu_0900_ai_ci", + "utf8mb4_hu_0900_as_cs", + "utf8mb4_icelandic_ci", + "utf8mb4_is_0900_ai_ci", + "utf8mb4_is_0900_as_cs", + "utf8mb4_ja_0900_as_cs", + "utf8mb4_ja_0900_as_cs_ks", + "utf8mb4_latvian_ci", + "utf8mb4_la_0900_ai_ci", + "utf8mb4_la_0900_as_cs", + "utf8mb4_lithuanian_ci", + "utf8mb4_lt_0900_ai_ci", + "utf8mb4_lt_0900_as_cs", + "utf8mb4_lv_0900_ai_ci", + "utf8mb4_lv_0900_as_cs", + "utf8mb4_persian_ci", + "utf8mb4_pl_0900_ai_ci", + "utf8mb4_pl_0900_as_cs", + "utf8mb4_polish_ci", + "utf8mb4_romanian_ci", + "utf8mb4_roman_ci", + "utf8mb4_ro_0900_ai_ci", + "utf8mb4_ro_0900_as_cs", + "utf8mb4_ru_0900_ai_ci", + "utf8mb4_ru_0900_as_cs", + "utf8mb4_sinhala_ci", + "utf8mb4_sk_0900_ai_ci", + "utf8mb4_sk_0900_as_cs", + "utf8mb4_slovak_ci", + "utf8mb4_slovenian_ci", + "utf8mb4_sl_0900_ai_ci", + "utf8mb4_sl_0900_as_cs", + "utf8mb4_spanish2_ci", + "utf8mb4_spanish_ci", + "utf8mb4_sv_0900_ai_ci", + "utf8mb4_sv_0900_as_cs", + "utf8mb4_swedish_ci", + "utf8mb4_tr_0900_ai_ci", + "utf8mb4_tr_0900_as_cs", + "utf8mb4_turkish_ci", + "utf8mb4_unicode_520_ci", + "utf8mb4_unicode_ci", + "utf8mb4_vietnamese_ci", + "utf8mb4_vi_0900_ai_ci", + "utf8mb4_vi_0900_as_cs", + "utf8mb4_zh_0900_as_cs", + "utf8_bin", + "utf8_croatian_ci", + "utf8_czech_ci", + "utf8_danish_ci", + "utf8_esperanto_ci", + "utf8_estonian_ci", + "utf8_general_ci", + "utf8_general_mysql500_ci", + "utf8_german2_ci", + "utf8_hungarian_ci", + "utf8_icelandic_ci", + "utf8_latvian_ci", + "utf8_lithuanian_ci", + "utf8_persian_ci", + "utf8_polish_ci", + "utf8_romanian_ci", + "utf8_roman_ci", + "utf8_sinhala_ci", + "utf8_slovak_ci", + "utf8_slovenian_ci", + "utf8_spanish2_ci", + "utf8_spanish_ci", + "utf8_swedish_ci", + "utf8_tolower_ci", + "utf8_turkish_ci", + "utf8_unicode_520_ci", + "utf8_unicode_ci", + "utf8_vietnamese_ci", +} diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index d5b354e2..019a6eb0 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -57,6 +57,34 @@ type driveri struct { log *slog.Logger } +// ConnParams implements driver.SQLDriver. +// See: https://github.com/go-sql-driver/mysql#dsn-data-source-name. +func (d *driveri) ConnParams() map[string][]string { + return map[string][]string{ + "allowAllFiles": {"false", "true"}, + "allowCleartextPasswords": {"false", "true"}, + "allowFallbackToPlaintext": {"false", "true"}, + "allowNativePasswords": {"false", "true"}, + "allowOldPasswords": {"false", "true"}, + "charset": nil, + "checkConnLiveness": {"true", "false"}, + "clientFoundRows": {"false", "true"}, + "collation": collations, + "columnsWithAlias": {"false", "true"}, + "connectionAttributes": nil, + "interpolateParams": {"false", "true"}, + "loc": {"UTC"}, + "maxAllowedPackage": {"0", "67108864"}, + "multiStatements": {"false", "true"}, + "parseTime": {"false", "true"}, + "readTimeout": {"0"}, + "rejectReadOnly": {"false", "true"}, + "timeout": nil, + "tls": {"false", "true", "skip-verify", "preferred"}, + "writeTimeout": {"0"}, + } +} + // ErrWrapFunc implements driver.SQLDriver. func (d *driveri) ErrWrapFunc() func(error) error { return errw @@ -74,6 +102,7 @@ func (d *driveri) DriverMetadata() driver.Metadata { Description: "MySQL", Doc: "https://github.com/go-sql-driver/mysql", IsSQL: true, + DefaultPort: 3306, } } diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index af6f85ed..e6d168d2 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -62,11 +62,26 @@ func (p *Provider) DriverFor(typ source.DriverType) (driver.Driver, error) { return &driveri{log: p.Log}, nil } +var _ driver.SQLDriver = (*driveri)(nil) + // driveri is the postgres implementation of driver.Driver. type driveri struct { log *slog.Logger } +// ConnParams implements driver.SQLDriver. +func (d *driveri) ConnParams() map[string][]string { + // https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS + return map[string][]string{ + "channel_binding": {"prefer", "require", "disable"}, + "connect_timeout": {"2"}, + "application_name": nil, + "fallback_application_name": nil, + "gssencmode": {"disable", "prefer", "require"}, + "sslmode": {"disable", "allow", "prefer", "require", "verify-ca", "verify-full"}, + } +} + // ErrWrapFunc implements driver.SQLDriver. func (d *driveri) ErrWrapFunc() func(error) error { return errw @@ -79,6 +94,7 @@ func (d *driveri) DriverMetadata() driver.Metadata { Description: "PostgreSQL", Doc: "https://github.com/jackc/pgx", IsSQL: true, + DefaultPort: 5432, } } diff --git a/drivers/sqlite3/pragma.go b/drivers/sqlite3/pragma.go index be24e7bb..dd202e8a 100644 --- a/drivers/sqlite3/pragma.go +++ b/drivers/sqlite3/pragma.go @@ -52,7 +52,7 @@ func readPragma(ctx context.Context, db sqlz.DB, pragma string) (any, error) { // Some of the pragmas can't be selected from. Ignore these. // SQLite returns a generic (1) SQLITE_ERROR in this case, // so we match using the error string. - return nil, nil + return nil, nil //nolint:nilnil } return nil, errw(err) @@ -61,7 +61,7 @@ func readPragma(ctx context.Context, db sqlz.DB, pragma string) (any, error) { defer lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseDBRows, rows) if !rows.Next() { - return nil, nil + return nil, nil //nolint:nilnil } cols, err := rows.Columns() @@ -72,7 +72,7 @@ func readPragma(ctx context.Context, db sqlz.DB, pragma string) (any, error) { switch len(cols) { case 0: // Shouldn't happen - return nil, nil + return nil, nil //nolint:nilnil case 1: var val any if err = rows.Scan(&val); err != nil { diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index e5abffcf..d61c462e 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -54,10 +54,6 @@ var _ driver.Provider = (*Provider)(nil) // Provider is the SQLite3 implementation of driver.Provider. type Provider struct { Log *slog.Logger - - // NOTE: Unlike other driver.SQLDriver impls, sqlite doesn't - // seem to benefit from applying a driver.SQLConfig to - // its sql.DB. } // DriverFor implements driver.Provider. @@ -69,13 +65,43 @@ func (p *Provider) DriverFor(typ source.DriverType) (driver.Driver, error) { return &driveri{log: p.Log}, nil } -var _ driver.Driver = (*driveri)(nil) +var _ driver.SQLDriver = (*driveri)(nil) -// driveri is the SQLite3 implementation of driver.Driver. +// driveri is the SQLite3 implementation of driver.SQLDriver. type driveri struct { log *slog.Logger } +// ConnParams implements driver.SQLDriver. +// See: https://github.com/mattn/go-sqlite3#connection-string. +func (d *driveri) ConnParams() map[string][]string { + return map[string][]string{ + "_auth": nil, + "_auth_crypt": {"SHA1", "SSHA1", "SHA256", "SSHA256", "SHA384", "SSHA384", "SHA512", "SSHA512"}, + "_auth_pass": nil, + "_auth_salt": nil, + "_auth_user": nil, + "_auto_vacuum": {"none", "full", "incremental"}, + "_busy_timeout": nil, + "_cache_size": {"-2000"}, + "_case_sensitive_like": {"true", "false"}, + "_defer_foreign_keys": {"true", "false"}, + "_foreign_keys": {"true", "false"}, + "_ignore_check_constraints": {"true", "false"}, + "_journal_mode": {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}, + "_loc": nil, + "_locking_mode": {"NORMAL", "EXCLUSIVE"}, + "_mutex": {"no", "full"}, + "_query_only": {"true", "false"}, + "_recursive_triggers": {"true", "false"}, + "_secure_delete": {"true", "false", "FAST"}, + "_synchronous": {"OFF", "NORMAL", "FULL", "EXTRA"}, + "_txlock": {"immediate", "deferred", "exclusive"}, + "cache": {"true", "false", "FAST"}, + "mode": {"ro", "rw", "rwc", "memory"}, + } +} + // ErrWrapFunc implements driver.SQLDriver. func (d *driveri) ErrWrapFunc() func(error) error { return errw diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index 5ddfba05..fc4f8116 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -61,6 +61,32 @@ type driveri struct { log *slog.Logger } +// ConnParams implements driver.SQLDriver. +func (d *driveri) ConnParams() map[string][]string { + // https://github.com/microsoft/go-mssqldb#connection-parameters-and-dsn. + return map[string][]string{ + "ApplicationIntent": {"ReadOnly"}, + "ServerSPN": nil, + "TrustServerCertificate": {"false", "true"}, + "Workstation ID": nil, + "app name": {"sq"}, + "certificate": nil, + "connection timeout": {"0"}, + "database": nil, + "dial timeout": {"0"}, + "encrypt": {"disable", "false", "true"}, + "failoverpartner": nil, + "failoverport": {"1433"}, + "hostNameInCertificate": nil, + "keepAlive": {"0", "30"}, + "log": {"0", "1", "2", "4", "8", "16", "32", "64", "128", "255"}, + "packet size": {"512", "4096", "16383", "32767"}, + "protocol": nil, + "tlsmin": {"1.0", "1.1", "1.2", "1.3"}, + "user id": nil, + } +} + // ErrWrapFunc implements driver.SQLDriver. func (d *driveri) ErrWrapFunc() func(error) error { return errw @@ -73,6 +99,7 @@ func (d *driveri) DriverMetadata() driver.Metadata { Description: "Microsoft SQL Server / Azure SQL Edge", Doc: "https://github.com/microsoft/go-mssqldb", IsSQL: true, + DefaultPort: 1433, } } diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 02783f23..a56696a1 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -6,6 +6,8 @@ import ( "context" "io" "os" + "path/filepath" + "strings" "github.com/neilotoole/sq/libsq/core/lg" @@ -83,3 +85,92 @@ func MarshalYAML(v any) ([]byte, error) { func UnmarshallYAML(data []byte, v any) error { return errz.Err(yaml.Unmarshal(data, v)) } + +// ReadDir lists the contents of dir, returning the relative paths +// of the files. If markDirs is true, directories are listed with +// a "/" suffix (including symlinked dirs). If includeDirPath is true, +// the listing is of the form "dir/name". If includeDot is true, +// files beginning with period (dot files) are included. The function +// attempts to continue in the present of errors: the returned paths +// may contain values even in the presence of a returned error (which +// may be a multierr). +func ReadDir(dir string, includeDirPath, markDirs, includeDot bool) (paths []string, err error) { + fi, err := os.Stat(dir) + if err != nil { + return nil, errz.Err(err) + } + + if !fi.Mode().IsDir() { + return nil, errz.Errorf("not a dir: %s", dir) + } + + var entries []os.DirEntry + if entries, err = os.ReadDir(dir); err != nil { + return nil, errz.Err(err) + } + + var name string + for _, entry := range entries { + name = entry.Name() + if strings.HasPrefix(name, ".") && !includeDot { + // Skip invisible files + continue + } + + mode := entry.Type() + if !mode.IsRegular() && markDirs { + if entry.IsDir() { + name += "/" + } else if mode&os.ModeSymlink != 0 { + // Follow the symlink to detect if it's a dir + linked, err2 := filepath.EvalSymlinks(filepath.Join(dir, name)) + if err2 != nil { + err = errz.Append(err, errz.Err(err2)) + continue + } + + fi, err2 = os.Stat(linked) + if err2 != nil { + err = errz.Append(err, errz.Err(err2)) + continue + } + + if fi.IsDir() { + name += "/" + } + } + } + + paths = append(paths, name) + } + + if includeDirPath { + for i := range paths { + // filepath.Join strips the "/" suffix, so we need to preserve it. + hasSlashSuffix := strings.HasSuffix(paths[i], "/") + paths[i] = filepath.Join(dir, paths[i]) + if hasSlashSuffix { + paths[i] += "/" + } + } + } + + return paths, nil +} + +// IsPathToRegularFile return true if path is a regular file or +// a symlink that resolves to a regular file. False is returned on +// any error. +func IsPathToRegularFile(path string) bool { + dest, err := filepath.EvalSymlinks(path) + if err != nil { + return false + } + + fi, err := os.Stat(dest) + if err != nil { + return false + } + + return fi.Mode().IsRegular() +} diff --git a/libsq/core/kind/detect.go b/libsq/core/kind/detect.go index 77f0f13a..d5ceeaea 100644 --- a/libsq/core/kind/detect.go +++ b/libsq/core/kind/detect.go @@ -152,7 +152,7 @@ func (d *Detector) doSampleString(s string) { d.mungeFns[Time] = func(val any) (any, error) { if val == nil { - return nil, nil + return nil, nil //nolint:nilnil } s, ok = val.(string) @@ -161,7 +161,7 @@ func (d *Detector) doSampleString(s string) { } if s == "" { - return nil, nil + return nil, nil //nolint:nilnil } var t time.Time @@ -186,7 +186,7 @@ func (d *Detector) doSampleString(s string) { d.mungeFns[Date] = func(val any) (any, error) { if val == nil { - return nil, nil + return nil, nil //nolint:nilnil } s, ok = val.(string) @@ -195,7 +195,7 @@ func (d *Detector) doSampleString(s string) { } if s == "" { - return nil, nil + return nil, nil //nolint:nilnil } var t time.Time @@ -222,7 +222,7 @@ func (d *Detector) doSampleString(s string) { // it returns a time.Time instead of a string d.mungeFns[Datetime] = func(val any) (any, error) { if val == nil { - return nil, nil + return nil, nil //nolint:nilnil } s, ok := val.(string) @@ -231,7 +231,7 @@ func (d *Detector) doSampleString(s string) { } if s == "" { - return nil, nil + return nil, nil //nolint:nilnil } t, err := time.Parse(format, s) diff --git a/libsq/core/kind/munge.go b/libsq/core/kind/munge.go index 87d64489..477d7aea 100644 --- a/libsq/core/kind/munge.go +++ b/libsq/core/kind/munge.go @@ -12,14 +12,14 @@ var _ MungeFunc = MungeEmptyStringAsNil func MungeEmptyStringAsNil(v any) (any, error) { switch v := v.(type) { case nil: - return nil, nil + return nil, nil //nolint:nilnil case *string: if len(*v) == 0 { - return nil, nil + return nil, nil //nolint:nilnil } case string: if len(v) == 0 { - return nil, nil + return nil, nil //nolint:nilnil } } diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 84e81b49..6e6688d1 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -37,6 +37,7 @@ const ( To = "to" Type = "type" Line = "line" + URL = "url" Val = "value" Via = "via" Version = "version" diff --git a/libsq/core/stringz/stringz.go b/libsq/core/stringz/stringz.go index 2382533e..16f5921f 100644 --- a/libsq/core/stringz/stringz.go +++ b/libsq/core/stringz/stringz.go @@ -16,6 +16,8 @@ import ( "time" "unicode" + "github.com/samber/lo" + "github.com/google/uuid" "github.com/neilotoole/sq/libsq/core/errz" @@ -298,9 +300,9 @@ func SurroundSlice(a []string, w string) []string { } // PrefixSlice returns a new slice with each element -// of a prefixed with w, unless a is nil, in which +// of a prefixed with prefix, unless a is nil, in which // case nil is returned. -func PrefixSlice(a []string, w string) []string { +func PrefixSlice(a []string, prefix string) []string { if a == nil { return nil } @@ -310,8 +312,8 @@ func PrefixSlice(a []string, w string) []string { ret := make([]string, len(a)) sb := strings.Builder{} for i := 0; i < len(a); i++ { - sb.Grow(len(a[i]) + len(w)) - sb.WriteString(w) + sb.Grow(len(a[i]) + len(prefix)) + sb.WriteString(prefix) sb.WriteString(a[i]) ret[i] = sb.String() sb.Reset() @@ -592,3 +594,28 @@ func IndentLines(s, indent string) string { return indent + line }) } + +// HasAnyPrefix returns true if s has any of the prefixes. +func HasAnyPrefix(s string, prefixes ...string) bool { + for _, prefix := range prefixes { + if strings.HasPrefix(s, prefix) { + return true + } + } + return false +} + +// FilterPrefix returns a new slice containing each element +// of a that has prefix. +func FilterPrefix(prefix string, a ...string) []string { + return lo.Filter(a, func(item string, index int) bool { + return strings.HasPrefix(item, prefix) + }) +} + +// ElementsHavingPrefix returns the elements of a that have prefix. +func ElementsHavingPrefix(a []string, prefix string) []string { + return lo.Filter(a, func(item string, index int) bool { + return strings.HasPrefix(item, prefix) + }) +} diff --git a/libsq/core/urlz/urlz.go b/libsq/core/urlz/urlz.go new file mode 100644 index 00000000..51bd492f --- /dev/null +++ b/libsq/core/urlz/urlz.go @@ -0,0 +1,105 @@ +// Package urlz contains URL utility functionality. +package urlz + +import ( + "net/url" + "strings" + + "github.com/neilotoole/sq/libsq/core/errz" +) + +// QueryParamKeys returns the keys of a URL query. This function +// exists because url.ParseQuery returns a url.Values, which is a +// map type, and the keys don't preserve order. +func QueryParamKeys(query string) (keys []string, err error) { + // Code is adapted from url.ParseQuery. + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + if strings.Contains(key, ";") { + err = errz.Errorf("invalid semicolon separator in query") + continue + } + if key == "" { + continue + } + key, _, _ = strings.Cut(key, "=") + key, err1 := url.QueryUnescape(key) + if err1 != nil { + if err == nil { + err = errz.Err(err1) + } + continue + } + + keys = append(keys, key) + } + return keys, err +} + +// RenameQueryParamKey renames all instances of oldKey in query +// to newKey, where query is a URL query string. +func RenameQueryParamKey(query, oldKey, newKey string) string { + if query == "" { + return "" + } + + parts := strings.Split(query, "&") + for i, part := range parts { + if part == oldKey { + parts[i] = newKey + continue + } + + if strings.HasPrefix(part, oldKey+"=") { + parts[i] = strings.Replace(part, oldKey, newKey, 1) + } + } + + return strings.Join(parts, "&") +} + +// StripQuery strips the query params from u. +func StripQuery(u url.URL) string { + u.RawQuery = "" + u.ForceQuery = false + return u.String() +} + +// StripPath strips the url's path. If stripQuery is true, the +// query is also stripped. +func StripPath(u url.URL, stripQuery bool) string { + u.Path = "" + if stripQuery { + u.RawQuery = "" + u.ForceQuery = false + } + return u.String() +} + +// StripUser strips the URL's user info. +func StripUser(u url.URL) string { + u2 := u + u2.User = nil + s := u2.String() + return s +} + +// StripScheme removes the URL's scheme. +func StripScheme(u url.URL) string { + u2 := u + u2.Scheme = "" + s := u2.String() + s = strings.TrimPrefix(s, "//") + return s +} + +// StripSchemeAndUser removes the URL's scheme and user info. +func StripSchemeAndUser(u url.URL) string { + u2 := u + u2.User = nil + u2.Scheme = "" + s := u2.String() + s = strings.TrimPrefix(s, "//") + return s +} diff --git a/libsq/core/urlz/urlz_test.go b/libsq/core/urlz/urlz_test.go new file mode 100644 index 00000000..b599e095 --- /dev/null +++ b/libsq/core/urlz/urlz_test.go @@ -0,0 +1,152 @@ +package urlz_test + +import ( + "net/url" + "testing" + + "github.com/neilotoole/sq/libsq/core/urlz" + "github.com/neilotoole/sq/testh/tutil" + "github.com/stretchr/testify/require" +) + +func TestQueryParamKeys(t *testing.T) { + testCases := []struct { + q string + want []string + wantErr bool + }{ + {"a=1", []string{"a"}, false}, + {"a=1&b=2", []string{"a", "b"}, false}, + {"b=1&a=2", []string{"b", "a"}, false}, + {"a=1&b=", []string{"a", "b"}, false}, + {"a=1&b=;", []string{"a"}, true}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc.q), func(t *testing.T) { + got, gotErr := urlz.QueryParamKeys(tc.q) + if tc.wantErr { + require.Error(t, gotErr) + } else { + require.NoError(t, gotErr) + } + require.Equal(t, tc.want, got) + }) + } +} + +func TestRenameQueryParamKey(t *testing.T) { + testCases := []struct { + q string + oldKey string + newKey string + want string + }{ + {"", "a", "b", ""}, + {"a=1", "a", "b", "b=1"}, + {"a", "a", "b", "b"}, + {"aa", "a", "b", "aa"}, + {"a=", "a", "b", "b="}, + {"a=1&a=2", "a", "b", "b=1&b=2"}, + {"a=1&c=2", "a", "b", "b=1&c=2"}, + {"a=1&c=2&a=3&a=4", "a", "b", "b=1&c=2&b=3&b=4"}, + {"a=a&c=2&a=b&a=c", "a", "b", "b=a&c=2&b=b&b=c"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc.q, tc.oldKey, tc.newKey), func(t *testing.T) { + got := urlz.RenameQueryParamKey(tc.q, tc.oldKey, tc.newKey) + require.Equal(t, tc.want, got) + }) + } +} + +func TestStripQuery(t *testing.T) { + testCases := []struct { + in string + want string + }{ + {"https://sq.io", "https://sq.io"}, + {"https://sq.io/path", "https://sq.io/path"}, + {"https://sq.io/path#frag", "https://sq.io/path#frag"}, + {"https://sq.io?a=b", "https://sq.io"}, + {"https://sq.io/path?a=b", "https://sq.io/path"}, + {"https://sq.io/path?a=b&c=d#frag", "https://sq.io/path#frag"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc), func(t *testing.T) { + u, err := url.Parse(tc.in) + require.NoError(t, err) + got := urlz.StripQuery(*u) + require.Equal(t, tc.want, got) + }) + } +} + +func TestStripUser(t *testing.T) { + testCases := []struct { + in string + want string + }{ + {"https://sq.io", "https://sq.io"}, + {"https://alice:123@sq.io/path", "https://sq.io/path"}, + {"https://alice@sq.io/path", "https://sq.io/path"}, + {"https://alice:@sq.io/path", "https://sq.io/path"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc), func(t *testing.T) { + u, err := url.Parse(tc.in) + require.NoError(t, err) + got := urlz.StripUser(*u) + require.Equal(t, tc.want, got) + }) + } +} + +func TestStripScheme(t *testing.T) { + testCases := []struct { + in string + want string + }{ + {"https://sq.io", "sq.io"}, + {"https://alice:123@sq.io/path", "alice:123@sq.io/path"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc), func(t *testing.T) { + u, err := url.Parse(tc.in) + require.NoError(t, err) + got := urlz.StripScheme(*u) + require.Equal(t, tc.want, got) + }) + } +} + +func TestStripSchemeAndUser(t *testing.T) { + testCases := []struct { + in string + want string + }{ + {"https://sq.io", "sq.io"}, + {"https://alice:123@sq.io/path", "sq.io/path"}, + {"https://alice@sq.io/path", "sq.io/path"}, + {"https://alice:@sq.io/path", "sq.io/path"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tutil.Name(i, tc), func(t *testing.T) { + u, err := url.Parse(tc.in) + require.NoError(t, err) + got := urlz.StripSchemeAndUser(*u) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index f368ce6a..fcf58487 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -216,6 +216,13 @@ type SQLDriver interface { // Dialect returns the SQL dialect. Dialect() dialect.Dialect + // ConnParams returns the db parameters available for use in a connection + // string. The key is the parameter name (e.g. "sslmode"), and the value + // can be either the set of allowed values, sample values, or nil. + // These values are used for shell completion and the like. The returned + // map does not have to be exhaustive, and can be nil. + ConnParams() map[string][]string + // ErrWrapFunc returns a func that wraps the driver's errors. ErrWrapFunc() func(error) error @@ -342,25 +349,29 @@ type Database interface { // Metadata holds driver metadata. type Metadata struct { // Type is the driver type, e.g. "mysql" or "csv", etc. - Type source.DriverType `json:"type"` + Type source.DriverType `json:"type" yaml:"type"` // Description is typically the long name of the driver, e.g. // "MySQL" or "Microsoft Excel XLSX". - Description string `json:"description"` + Description string `json:"description" yaml:"description"` // Doc is optional documentation, typically a URL. - Doc string `json:"doc,omitempty"` + Doc string `json:"doc,omitempty" yaml:"doc,omitempty"` // UserDefined is true if this driver is the product of a // user driver definition, and false if built-in. - UserDefined bool `json:"user_defined"` + UserDefined bool `json:"user_defined" yaml:"user_defined"` // IsSQL is true if this driver is a SQL driver. - IsSQL bool `json:"is_sql"` + IsSQL bool `json:"is_sql" yaml:"is_sql"` // Monotable is true if this is a non-SQL document type that // effectively has a single table, such as CSV. - Monotable bool `json:"monotable"` + Monotable bool `json:"monotable" yaml:"monotable"` + + // DefaultPort is the default port that a driver connects on. A + // value <= 0 indicates not applicable. + DefaultPort int `json:"default_port" yaml:"default_port"` } var _ DatabaseOpener = (*Databases)(nil) diff --git a/libsq/driver/registry.go b/libsq/driver/registry.go index b521c24f..112a3abf 100644 --- a/libsq/driver/registry.go +++ b/libsq/driver/registry.go @@ -66,6 +66,21 @@ func (r *Registry) DriverFor(typ source.DriverType) (Driver, error) { return p.DriverFor(typ) } +// SQLDriverFor for is a convenience method for getting a SQLDriver. +func (r *Registry) SQLDriverFor(typ source.DriverType) (SQLDriver, error) { + drvr, err := r.DriverFor(typ) + if err != nil { + return nil, err + } + + sqlDrvr, ok := drvr.(SQLDriver) + if !ok { + return nil, errz.Errorf("driver %T is not of type %T", drvr, sqlDrvr) + } + + return sqlDrvr, nil +} + // DriversMetadata returns metadata for each registered driver type. func (r *Registry) DriversMetadata() []Metadata { var md []Metadata diff --git a/libsq/source/source.go b/libsq/source/source.go index 518828f0..ddb150b3 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -97,7 +97,7 @@ func (s *Source) LogValue() slog.Value { // String returns a log/debug-friendly representation. func (s *Source) String() string { - return fmt.Sprintf("%s|%s| %s", s.Handle, s.Type, s.RedactedLocation()) + return fmt.Sprintf("%s|%s|%s", s.Handle, s.Type, s.RedactedLocation()) } // Group returns the source's group. If s is in the root group, diff --git a/testh/testh.go b/testh/testh.go index 69912c58..6abc434f 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/drivers" "github.com/neilotoole/sq/cli" @@ -97,6 +99,7 @@ type Helper struct { registry *driver.Registry files *source.Files databases *driver.Databases + run *run.Run initOnce sync.Once @@ -191,6 +194,16 @@ func (h *Helper) init() { h.files.AddDriverDetectors(xlsx.DetectXLSX) h.addUserDrivers() + + h.run = &run.Run{ + Stdin: os.Stdin, + Out: os.Stdout, + ErrOut: os.Stdin, + Config: config.New(), + ConfigStore: config.DiscardStore{}, + OptionsRegistry: &options.Registry{}, + DriverRegistry: h.registry, + } }) } @@ -606,6 +619,12 @@ func (h *Helper) Registry() *driver.Registry { return h.registry } +// Run returns the helper's run instance. +func (h *Helper) Run() *run.Run { + h.init() + return h.run +} + // addUserDrivers adds some user drivers to the registry. func (h *Helper) addUserDrivers() { userDriverDefs := DriverDefsFrom(h.T, testsrc.PathDriverDefPpl, testsrc.PathDriverDefRSS) diff --git a/testh/tutil/tutil.go b/testh/tutil/tutil.go index a356ddcd..33fadeef 100644 --- a/testh/tutil/tutil.go +++ b/testh/tutil/tutil.go @@ -5,7 +5,9 @@ import ( "fmt" "io" "os" + "path/filepath" "reflect" + "runtime" "strings" "testing" "unicode" @@ -318,3 +320,41 @@ func (t *tWriter) Write(p []byte) (n int, err error) { t.t.Log("\n" + string(p)) return len(p), nil } + +// Chdir changes the working directory to dir, or if dir is empty, +// to a temp dir. On test end, the original working dir is restored, +// and the temp dir deleted (if applicable). The absolute path +// of the changed working dir is returned. +func Chdir(t *testing.T, dir string) (absDir string) { + origDir, err := os.Getwd() + require.NoError(t, err) + + if filepath.IsAbs(dir) { + absDir = dir + } else { + absDir, err = filepath.Abs(dir) + require.NoError(t, err) + } + + if dir == "" { + tmpDir := t.TempDir() + t.Cleanup(func() { + _ = os.Remove(tmpDir) + }) + dir = tmpDir + } + + require.NoError(t, os.Chdir(dir)) + t.Cleanup(func() { + _ = os.Chdir(origDir) + }) + + return absDir +} + +// SkipWindows skips t if running on Windows. +func SkipWindows(t *testing.T, format string, args ...any) { + if runtime.GOOS == "windows" { + t.Skipf(format, args...) + } +}