diff --git a/.gitignore b/.gitignore index d381fc7e..a2d6de8b 100644 --- a/.gitignore +++ b/.gitignore @@ -42,7 +42,7 @@ _testmain.go .envrc **/*.bench go.work* - +*.dump.sql # Some apps create temp files when editing, e.g. Excel with drivers/xlsx/testdata/~$test_header.xlsx **/testdata/~* @@ -59,3 +59,5 @@ goreleaser-test.sh /*.db /.CHANGELOG.delta.md /testh/progress-remove.test.sh + +/*.dump diff --git a/CHANGELOG.md b/CHANGELOG.md index f81d2810..be900f1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Breaking changes are annotated with ☢️, and alpha/beta features with 🐥. +## [v0.47.4] - UPCOMING + +Minor changes to the behavior of the `--src.schema` flag. +See the earlier [`v0.47.0`](https://github.com/neilotoole/sq/releases/tag/v0.47.0) +release for recent headline features. + +### Changed + +- The [`--src.schema`](https://sq.io/docs/source#source-override) flag (as used in [`sq inspect`](https://sq.io/docs/cmd/inspect), + [`sq sql`](https://sq.io/docs/cmd/sql), and the root [`sq`](https://sq.io/docs/cmd/sq#override-active-schema) cmd) + now accepts `--src.schema=CATALOG.` (note the `.` suffix). This is in addition to the existing allowed forms `SCHEMA` + and `CATALOG.SCHEMA`. This new `CATALOG.` form is effectively equivalent to `CATALOG.CURRENT_SCHEMA`. + + ```shell + # Inspect using the default schema in the "sales" catalog + $ sq inspect --src.schema=sales. + ``` + +- The [`--src.schema`](https://sq.io/docs/source#source-override) flag is now validated. Previously, if you provided a non-existing catalog or schema + value, `sq` would silently ignore it and use the defaults. This could mislead the user into thinking that + they were getting valid results from the non-existent catalog or schema. Now an error is returned. + + ## [v0.47.3] - 2024-02-03 Minor bug fix release. See the earlier [`v0.47.0`](https://github.com/neilotoole/sq/releases/tag/v0.47.0) @@ -1138,3 +1161,4 @@ make working with lots of sources much easier. [v0.47.1]: https://github.com/neilotoole/sq/compare/v0.47.0...v0.47.1 [v0.47.2]: https://github.com/neilotoole/sq/compare/v0.47.1...v0.47.2 [v0.47.3]: https://github.com/neilotoole/sq/compare/v0.47.2...v0.47.3 +[v0.47.4]: https://github.com/neilotoole/sq/compare/v0.47.3...v0.47.4 diff --git a/Makefile b/Makefile index 6389eed7..8a2c6ead 100644 --- a/Makefile +++ b/Makefile @@ -29,11 +29,13 @@ gen: .PHONY: fmt fmt: @# https://github.com/incu6us/goimports-reviser - @# Note that *_windows.go is excluded because the tool seems + @# Note that termz_windows.go is excluded because the tool seems @# to mangle Go code that is guarded by build tags that - @# are not in use. + @# are not in use. Alas, we can't provide a double star glob, + @# e.g. **/*_windows.go, because filepath.Match doesn't support + @# double star, so we explicitly name the file. @goimports-reviser -company-prefixes github.com/neilotoole -set-alias \ - -excludes '**/*_windows.go' \ + -excludes 'libsq/core/termz/termz_windows.go' \ -rm-unused -output write \ -project-name github.com/neilotoole/sq ./... diff --git a/cli/cli.go b/cli/cli.go index 0b6381a4..cdc20ed7 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -259,6 +259,15 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { addCmd(ru, tblCmd, newTblTruncateCmd()) addCmd(ru, tblCmd, newTblDropCmd()) + dbCmd := addCmd(ru, rootCmd, newDBCmd()) + addCmd(ru, dbCmd, newDBExecCmd()) + dbDumpCmd := addCmd(ru, dbCmd, newDBDumpCmd()) + addCmd(ru, dbDumpCmd, newDBDumpCatalogCmd()) + addCmd(ru, dbDumpCmd, newDBDumpClusterCmd()) + dbRestoreCmd := addCmd(ru, dbCmd, newDBRestoreCmd()) + addCmd(ru, dbRestoreCmd, newDBRestoreCatalogCmd()) + addCmd(ru, dbRestoreCmd, newDBRestoreClusterCmd()) + addCmd(ru, rootCmd, newDiffCmd()) driverCmd := addCmd(ru, rootCmd, newDriverCmd()) diff --git a/cli/cmd_add.go b/cli/cmd_add.go index 77950163..51ed2733 100644 --- a/cli/cmd_add.go +++ b/cli/cmd_add.go @@ -156,7 +156,7 @@ More examples: Long: `Add data source specified by LOCATION, optionally identified by @HANDLE.`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) @@ -242,7 +242,7 @@ func execSrcAdd(cmd *cobra.Command, args []string) error { // or sq prompts the user. if cmdFlagIsSetTrue(cmd, flag.PasswordPrompt) { var passwd []byte - if passwd, err = readPassword(ctx, ru.Stdin, ru.Out, ru.Writers.Printing); err != nil { + if passwd, err = readPassword(ctx, ru.Stdin, ru.Out, ru.Writers.OutPrinting); err != nil { return err } diff --git a/cli/cmd_add_test.go b/cli/cmd_add_test.go index c8f65b8a..02b4ae8f 100644 --- a/cli/cmd_add_test.go +++ b/cli/cmd_add_test.go @@ -141,14 +141,12 @@ func TestCmdAdd(t *testing.T) { wantHandle: "@sakila", wantType: drivertype.SQLite, }, - { // with scheme loc: proj.Abs(sakila.PathSL3), wantHandle: "@sakila", wantType: drivertype.SQLite, }, - { // without scheme, relative path loc: proj.Rel(sakila.PathSL3), diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index 4f1cb6b4..78e6ebfb 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -74,7 +74,7 @@ func newCacheStatCmd() *cobra.Command { /Users/neilotoole/Library/Caches/sq/f36ac695 enabled (472.8MB)`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) @@ -111,7 +111,7 @@ func newCacheClearCmd() *cobra.Command { $ sq cache clear @sakila`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) return cmd } @@ -150,7 +150,7 @@ func newCacheTreeCmd() *cobra.Command { $ sq cache tree --size`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) _ = cmd.Flags().BoolP(flag.CacheTreeSize, flag.CacheTreeSizeShort, false, flag.CacheTreeSizeUsage) return cmd } @@ -163,7 +163,7 @@ func execCacheTree(cmd *cobra.Command, _ []string) error { } showSize := cmdFlagBool(cmd, flag.CacheTreeSize) - return ioz.PrintTree(ru.Out, cacheDir, showSize, !ru.Writers.Printing.IsMonochrome()) + return ioz.PrintTree(ru.Out, cacheDir, showSize, !ru.Writers.OutPrinting.IsMonochrome()) } func newCacheEnableCmd() *cobra.Command { //nolint:dupl @@ -200,7 +200,7 @@ func newCacheEnableCmd() *cobra.Command { //nolint:dupl $ sq cache enable @sakila`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) return cmd } @@ -238,6 +238,6 @@ func newCacheDisableCmd() *cobra.Command { //nolint:dupl $ sq cache disable @sakila`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) return cmd } diff --git a/cli/cmd_config_edit.go b/cli/cmd_config_edit.go index ca521b22..075a3d97 100644 --- a/cli/cmd_config_edit.go +++ b/cli/cmd_config_edit.go @@ -51,7 +51,7 @@ in envar $SQ_EDITOR or $EDITOR.`, # Use a different editor $ SQ_EDITOR=nano sq config edit`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) return cmd } @@ -134,7 +134,7 @@ func execConfigEditSource(cmd *cobra.Command, args []string) error { // The Catalog and Schema fields have yaml tag 'omitempty', // so they wouldn't be rendered in the editor yaml if empty. - // However, we to render the fields commented-out if empty. + // However, we want to render the fields commented-out if empty. // Hence this little hack. if tmpSrc.Catalog == "" { // Forces yaml rendering diff --git a/cli/cmd_config_set.go b/cli/cmd_config_set.go index ca1ae43f..4dcb9852 100644 --- a/cli/cmd_config_set.go +++ b/cli/cmd_config_set.go @@ -41,7 +41,7 @@ Use "sq config ls -v" to list available options.`, # Help for an individual option $ sq config set conn.max-open --help`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) diff --git a/cli/cmd_db.go b/cli/cmd_db.go new file mode 100644 index 00000000..a94452ae --- /dev/null +++ b/cli/cmd_db.go @@ -0,0 +1,18 @@ +package cli + +import ( + "github.com/spf13/cobra" +) + +func newDBCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "db", + Short: "Useful database actions", + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + Example: ` # TBD`, + } + + return cmd +} diff --git a/cli/cmd_db_dump.go b/cli/cmd_db_dump.go new file mode 100644 index 00000000..669ff7f4 --- /dev/null +++ b/cli/cmd_db_dump.go @@ -0,0 +1,269 @@ +package cli + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/drivers/postgres" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/execz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/drivertype" +) + +func newDBDumpCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "dump", + Short: "Dump db catalog or cluster", + Long: `Execute or print db-native dump command for db catalog or cluster.`, + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + } + + return cmd +} + +func newDBDumpCatalogCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "catalog @src [--print]", + Short: "Dump db catalog", + Long: `Dump db catalog using database-native dump tool.`, + ValidArgsFunction: completeHandle(1, true), + Args: cobra.MaximumNArgs(1), + RunE: execDBDumpCatalog, + Example: ` # Dump @sakila_pg to file sakila.dump using pg_dump + $ sq db dump catalog @sakila_pg -o sakila.dump + + # Same as above, but verbose mode, and dump via stdout + $ sq db dump catalog @sakila_pg -v > sakila.dump + + # Dump without ownership or ACL + $ sq db dump catalog --no-owner @sakila_pg > sakila.dump + + # Print the dump command, but don't execute it + $ sq db dump catalog @sakila_pg --print + + # Dump a catalog (db) other than the source's current catalog + $ sq db dump catalog @sakila_pg --catalog sales > sales.dump`, + } + + // Calling cmdMarkPlainStdout means that ru.Stdout will be + // the plain os.Stdout, and won't be decorated with color, or + // progress listeners etc. The dump commands handle their own output. + cmdMarkPlainStdout(cmd) + + cmd.Flags().String(flag.DBDumpCatalog, "", flag.DBDumpCatalogUsage) + panicOn(cmd.RegisterFlagCompletionFunc(flag.DBDumpCatalog, completeCatalog(0))) + cmd.Flags().Bool(flag.DBDumpNoOwner, false, flag.DBDumpNoOwnerUsage) + cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage) + cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage) + cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage) + cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd) + + return cmd +} + +func execDBDumpCatalog(cmd *cobra.Command, args []string) error { + ru := run.FromContext(cmd.Context()) + + var ( + src *source.Source + err error + ) + + if len(args) == 0 { + if src = ru.Config.Collection.Active(); src == nil { + return errz.New(msgNoActiveSrc) + } + } else if src, err = ru.Config.Collection.Get(args[0]); err != nil { + return err + } + + if err = applySourceOptions(cmd, src); err != nil { + return err + } + + if cmdFlagChanged(cmd, flag.DBDumpCatalog) { + // Use a different catalog than the source's current catalog. + if src.Catalog, err = cmd.Flags().GetString(flag.DBDumpCatalog); err != nil { + return err + } + } + + var ( + errPrefix = fmt.Sprintf("db dump catalog: %s", src.Handle) + dumpVerbose = cmdFlagBool(cmd, flag.Verbose) + dumpNoOwner = cmdFlagBool(cmd, flag.DBDumpNoOwner) + dumpLongFlags = cmdFlagBool(cmd, flag.DBPrintLongToolCmd) + dumpFile string + ) + + if cmdFlagChanged(cmd, flag.FileOutput) { + if dumpFile, err = cmd.Flags().GetString(flag.FileOutput); err != nil { + return err + } + + if dumpFile = strings.TrimSpace(dumpFile); dumpFile == "" { + return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.FileOutput) + } + } + + var execCmd *execz.Cmd + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + params := &postgres.ToolParams{ + Verbose: dumpVerbose, + NoOwner: dumpNoOwner, + File: dumpFile, + LongFlags: dumpLongFlags, + } + execCmd, err = postgres.DumpCatalogCmd(src, params) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } + + if err != nil { + return errz.Wrap(err, errPrefix) + } + + execCmd.NoProgress = !OptProgress.Get(src.Options) + execCmd.Label = src.Handle + ": " + execCmd.Name + execCmd.Stdin = ru.Stdin + execCmd.Stdout = ru.Stdout + execCmd.Stderr = ru.ErrOut + execCmd.ErrPrefix = errPrefix + + if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) { + lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd) + _, err = fmt.Fprintln(ru.Out, execCmd.String()) + return errz.Err(err) + } + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd) + return execz.Exec(cmd.Context(), execCmd) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } +} + +func newDBDumpClusterCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "cluster @src [--print]", + Short: "Dump entire db cluster", + Long: `Dump all catalogs in src's db cluster using the db-native dump tool.`, + ValidArgsFunction: completeHandle(1, true), + Args: cobra.MaximumNArgs(1), + RunE: execDBDumpCluster, + Example: ` # Dump all catalogs in @sakila_pg's cluster using pg_dumpall + $ sq db dump cluster @sakila_pg -f all.dump + + # Same as above, but verbose mode and using stdout + $ sq db dump cluster @sakila_pg -v > all.dump + + # Dump without ownership or ACL + $ sq db dump cluster @sakila_pg --no-owner > all.dump + + # Print the dump command, but don't execute it + $ sq db dump cluster @sakila_pg -f all.dump --print`, + } + + // Calling cmdMarkPlainStdout means that ru.Stdout will be + // the plain os.Stdout, and won't be decorated with color, or + // progress listeners etc. The dump commands handle their own output. + cmdMarkPlainStdout(cmd) + cmd.Flags().Bool(flag.DBDumpNoOwner, false, flag.DBDumpNoOwnerUsage) + cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage) + cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage) + cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage) + cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd) + + return cmd +} + +func execDBDumpCluster(cmd *cobra.Command, args []string) error { + var ( + ru = run.FromContext(cmd.Context()) + src *source.Source + err error + ) + + if len(args) == 0 { + if src = ru.Config.Collection.Active(); src == nil { + return errz.New(msgNoActiveSrc) + } + } else if src, err = ru.Config.Collection.Get(args[0]); err != nil { + return err + } + + if err = applySourceOptions(cmd, src); err != nil { + return err + } + + var ( + errPrefix = fmt.Sprintf("db dump cluster: %s", src.Handle) + dumpVerbose = cmdFlagBool(cmd, flag.Verbose) + dumpNoOwner = cmdFlagBool(cmd, flag.DBDumpNoOwner) + dumpLongFlags = cmdFlagBool(cmd, flag.DBPrintLongToolCmd) + dumpFile string + ) + + if cmdFlagChanged(cmd, flag.FileOutput) { + if dumpFile, err = cmd.Flags().GetString(flag.FileOutput); err != nil { + return err + } + + if dumpFile = strings.TrimSpace(dumpFile); dumpFile == "" { + return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.FileOutput) + } + } + + var execCmd *execz.Cmd + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + params := &postgres.ToolParams{ + Verbose: dumpVerbose, + NoOwner: dumpNoOwner, + File: dumpFile, + LongFlags: dumpLongFlags, + } + execCmd, err = postgres.DumpClusterCmd(src, params) + default: + err = errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } + + if err != nil { + return errz.Wrap(err, errPrefix) + } + + execCmd.NoProgress = !OptProgress.Get(src.Options) + execCmd.Label = src.Handle + ": " + execCmd.Name + execCmd.Stdin = ru.Stdin + execCmd.Stdout = ru.Stdout + execCmd.Stderr = ru.ErrOut + execCmd.ErrPrefix = errPrefix + + if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) { + lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd) + _, err = fmt.Fprintln(ru.Out, execCmd.String()) + return errz.Err(err) + } + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd) + return execz.Exec(cmd.Context(), execCmd) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } +} diff --git a/cli/cmd_db_exec.go b/cli/cmd_db_exec.go new file mode 100644 index 00000000..06750d4d --- /dev/null +++ b/cli/cmd_db_exec.go @@ -0,0 +1,151 @@ +package cli + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/drivers/postgres" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/execz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/drivertype" +) + +func newDBExecCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "exec [@src] [--f SCRIPT.sql] [-c 'SQL'] [--print]", + Short: "Execute SQL script or command", + Long: `Execute SQL script or command using the db-native tool. + +If no source is specified, the active source is used. + +If --file is specified, the SQL is read from that file; otherwise if --command +is specified, that command string is used; otherwise the SQL commands are +read from stdin. + +If --print or --print-long are specified, the SQL is not executed, but instead +the db-native command is printed to stdout. Note that the output will include DB +credentials.`, + Args: cobra.MaximumNArgs(1), + ValidArgsFunction: completeHandle(1, true), + RunE: execDBExec, + Example: ` # Execute query.sql on @sakila_pg + $ sq db exec @sakila_pg -f query.sql + + # Same as above, but use stdin + $ sq db exec @sakila_pg < query.sql + + # Execute a command string against the active source + $ sq db exec -c 'SELECT 777' + 777 + + # Print the db-native command, but don't execute it + $ sq db exec -f query.sql --print + psql -d 'postgres://alice:abc123@db.acme.com:5432/sales' -f query.sql + + # Execute against an alternative catalog or schema + $ sq db exec @sakila_pg --schema inventory.public -f query.sql`, + } + + cmdMarkPlainStdout(cmd) + + cmd.Flags().StringP(flag.DBExecFile, flag.DBExecFileShort, "", flag.DBExecFileUsage) + cmd.Flags().StringP(flag.DBExecCmd, flag.DBExecCmdShort, "", flag.DBExecCmdUsage) + cmd.MarkFlagsMutuallyExclusive(flag.DBExecFile, flag.DBExecCmd) + + cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage) + cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage) + cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd) + + cmd.Flags().String(flag.ActiveSchema, "", flag.ActiveSchemaUsage) + panicOn(cmd.RegisterFlagCompletionFunc(flag.ActiveSchema, + activeSchemaCompleter{getActiveSourceViaArgs}.complete)) + + return cmd +} + +func execDBExec(cmd *cobra.Command, args []string) error { + var ( + ru = run.FromContext(cmd.Context()) + src *source.Source + err error + + // scriptFile is the (optional) path to the SQL file. + // If empty, cmdString or stdin is used. + scriptFile string + + // scriptString is the optional SQL command string. + // If empty, scriptFile or stdin is used. + cmdString string + + verbose = cmdFlagBool(cmd, flag.Verbose) + ) + + if src, err = getCmdSource(cmd, args); err != nil { + return err + } + + errPrefix := "db exec: " + src.Handle + if cmdFlagChanged(cmd, flag.DBExecFile) { + if scriptFile = strings.TrimSpace(cmd.Flag(flag.DBExecFile).Value.String()); scriptFile == "" { + return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBExecFile) + } + } + + if cmdFlagChanged(cmd, flag.DBExecCmd) { + if cmdString = strings.TrimSpace(cmd.Flag(flag.DBExecCmd).Value.String()); cmdString == "" { + return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBExecCmd) + } + } + + if err = applySourceOptions(cmd, src); err != nil { + return err + } + + var execCmd *execz.Cmd + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + params := &postgres.ExecToolParams{ + Verbose: verbose, + ScriptFile: scriptFile, + CmdString: cmdString, + LongFlags: cmdFlagChanged(cmd, flag.DBPrintLongToolCmd), + } + execCmd, err = postgres.ExecCmd(src, params) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } + + if err != nil { + return errz.Wrap(err, errPrefix) + } + + execCmd.NoProgress = !OptProgress.Get(src.Options) + execCmd.Label = src.Handle + ": " + execCmd.Name + execCmd.Stdin = ru.Stdin + execCmd.Stdout = ru.Stdout + execCmd.Stderr = ru.ErrOut + execCmd.ErrPrefix = errPrefix + + if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) { + lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd) + s := execCmd.String() + _, err = fmt.Fprintln(ru.Out, s) + return errz.Err(err) + } + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd) + return execz.Exec(cmd.Context(), execCmd) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } +} diff --git a/cli/cmd_db_restore.go b/cli/cmd_db_restore.go new file mode 100644 index 00000000..221d7a76 --- /dev/null +++ b/cli/cmd_db_restore.go @@ -0,0 +1,259 @@ +package cli + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/drivers/postgres" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/execz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/drivertype" +) + +func newDBRestoreCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "restore", + Short: "Restore db catalog or cluster from dump", + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + } + + return cmd +} + +func newDBRestoreCatalogCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "catalog @src [--from file.dump] [--print]", + Short: "Restore db catalog from dump", + Long: `Restore into @src from dump file, using the db-native restore tool. + +If --from is specified, the dump is read from that file; otherwise stdin is used. + +When --no-owner is specified, the source user will own the restored objects: the +ownership (and ACLs) from the dump file are disregarded. + +If --print or --print-long are specified, the restore command is not executed, but +instead the db-native command is printed to stdout. Note that the command output +will include DB credentials. For a Postgres source, it would look something like: + + pg_restore -d 'postgres://alice:abc123@localhost:5432/sales' backup.dump`, + Args: cobra.ExactArgs(1), + ValidArgsFunction: completeHandle(1, true), + RunE: execDBRestoreCatalog, + Example: ` # Restore @sakila_pg from backup.dump + $ sq db restore catalog @sakila_pg -f backup.dump + + # With verbose output, and reading from stdin + $ sq db restore catalog -v @sakila_pg < backup.dump + + # Don't use ownership from dump; the source user will own the restored objects + $ sq db restore catalog @sakila_pg --no-owner < backup.dump + + # Print the db-native restore command, but don't execute it + $ sq db restore catalog @sakila_pg -f backup.dump --print`, + } + + cmdMarkPlainStdout(cmd) + + // FIXME: add --src.Schema + + cmd.Flags().StringP(flag.DBRestoreFrom, flag.DBRestoreFromShort, "", flag.DBRestoreFromUsage) + cmd.Flags().Bool(flag.DBRestoreNoOwner, false, flag.DBRestoreNoOwnerUsage) + cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage) + cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage) + cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage) + cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd) + + return cmd +} + +func execDBRestoreCatalog(cmd *cobra.Command, args []string) error { + var ( + ru = run.FromContext(cmd.Context()) + src *source.Source + err error + // fpDump is the (optional) path to the dump file. + // If empty, stdin is used. + dumpFile string + ) + + if src, err = ru.Config.Collection.Get(args[0]); err != nil { + return err + } + + errPrefix := "db restore catalog: " + src.Handle + if cmdFlagChanged(cmd, flag.DBRestoreFrom) { + if dumpFile = strings.TrimSpace(cmd.Flag(flag.DBRestoreFrom).Value.String()); dumpFile == "" { + return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBRestoreFrom) + } + } + + if err = applySourceOptions(cmd, src); err != nil { + return err + } + + verbose := cmdFlagBool(cmd, flag.Verbose) + noOwner := cmdFlagBool(cmd, flag.DBRestoreNoOwner) + + var execCmd *execz.Cmd + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + params := &postgres.ToolParams{ + Verbose: verbose, + NoOwner: noOwner, + File: dumpFile, + } + execCmd, err = postgres.RestoreCatalogCmd(src, params) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } + + if err != nil { + return errz.Wrap(err, errPrefix) + } + + execCmd.NoProgress = !OptProgress.Get(src.Options) + execCmd.Label = src.Handle + ": " + execCmd.Name + execCmd.Stdin = ru.Stdin + execCmd.Stdout = ru.Stdout + execCmd.Stderr = ru.ErrOut + execCmd.ErrPrefix = errPrefix + + if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) { + lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd) + _, err = fmt.Fprintln(ru.Out, execCmd.String()) + return errz.Err(err) + } + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd) + return execz.Exec(cmd.Context(), execCmd) + default: + return errz.Errorf("%s: cmd not supported for %s", errPrefix, src.Type) + } +} + +func newDBRestoreClusterCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "cluster @src [--from file.dump] [--print]", + Short: "Restore db cluster from dump", + Long: `Restore entire db cluster into @src from dump file, using the db-native restore +tool. + +If --from is specified, the dump is read from that file; otherwise stdin is used. + +When --no-owner is specified, the source user will own the restored objects: the +ownership (and ACLs) from the dump file are disregarded. + +If --print or --print-long are specified, the restore command is not executed, but +instead the db-native command is printed to stdout. Note that the command output +will include DB credentials. For a Postgres source, it would look something like: + +FIXME: example command + psql -d 'postgres://alice:abc123@localhost:5432/sales' backup.dump`, + Args: cobra.ExactArgs(1), + ValidArgsFunction: completeHandle(1, true), + RunE: execDBRestoreCluster, + Example: ` # Restore @sakila_pg from backup.dump + $ sq db restore cluster @sakila_pg -f backup.dump + + # With verbose output, and reading from stdin + $ sq db restore cluster -v @sakila_pg < backup.dump + + # Don't use ownership from dump; the source user will own the restored objects + $ sq db restore cluster @sakila_pg --no-owner < backup.dump + + # Print the db-native restore command, but don't execute it + $ sq db restore cluster @sakila_pg -f backup.dump --print`, + } + + cmdMarkPlainStdout(cmd) + cmd.Flags().StringP(flag.DBRestoreFrom, flag.DBRestoreFromShort, "", flag.DBRestoreFromUsage) + cmd.Flags().Bool(flag.DBRestoreNoOwner, false, flag.DBRestoreNoOwnerUsage) + cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage) + cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage) + cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage) + cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd) + + return cmd +} + +func execDBRestoreCluster(cmd *cobra.Command, args []string) error { + var ( + ru = run.FromContext(cmd.Context()) + src *source.Source + err error + // dumpFile is the (optional) path to the dump file. + // If empty, stdin is used. + dumpFile string + ) + + if src, err = ru.Config.Collection.Get(args[0]); err != nil { + return err + } + + errPrefix := "db restore cluster: " + src.Handle + if cmdFlagChanged(cmd, flag.DBRestoreFrom) { + if dumpFile = strings.TrimSpace(cmd.Flag(flag.DBRestoreFrom).Value.String()); dumpFile == "" { + return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBRestoreFrom) + } + } + + if err = applySourceOptions(cmd, src); err != nil { + return err + } + + verbose := cmdFlagBool(cmd, flag.Verbose) + + // FIXME: get rid of noOwner from this command? + // noOwner := cmdFlagBool(cmd, flag.RestoreNoOwner) + + var execCmd *execz.Cmd + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + params := &postgres.ToolParams{ + Verbose: verbose, + File: dumpFile, + } + execCmd, err = postgres.RestoreClusterCmd(src, params) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } + + if err != nil { + return errz.Wrap(err, errPrefix) + } + + execCmd.NoProgress = !OptProgress.Get(src.Options) + execCmd.Label = src.Handle + ": " + execCmd.Name + execCmd.Stdin = ru.Stdin + execCmd.Stdout = ru.Stdout + execCmd.Stderr = ru.ErrOut + execCmd.ErrPrefix = errPrefix + + if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) { + lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd) + s := execCmd.String() + _, err = fmt.Fprintln(ru.Out, s) + return errz.Err(err) + } + + switch src.Type { //nolint:exhaustive + case drivertype.Pg: + lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd) + return execz.Exec(cmd.Context(), execCmd) + default: + return errz.Errorf("%s: not supported for %s", errPrefix, src.Type) + } +} diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 7e3fbec9..77f2c270 100644 --- a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -121,7 +121,8 @@ func execInspect(cmd *cobra.Command, args []string) error { // Handle flag.ActiveSchema (--src.schema=SCHEMA). This func will mutate // src's Catalog and Schema fields if appropriate. - if err = processFlagActiveSchema(cmd, src); err != nil { + var srcModified bool + if srcModified, err = processFlagActiveSchema(cmd, src); err != nil { return err } @@ -129,6 +130,12 @@ func execInspect(cmd *cobra.Command, args []string) error { return err } + if srcModified { + if err = verifySourceCatalogSchema(ctx, ru, src); err != nil { + return err + } + } + grip, err := ru.Grips.Open(ctx, src) if err != nil { return errz.Wrapf(err, "failed to inspect %s", src.Handle) diff --git a/cli/cmd_mv.go b/cli/cmd_mv.go index af8c28a7..11c5061f 100644 --- a/cli/cmd_mv.go +++ b/cli/cmd_mv.go @@ -42,7 +42,7 @@ source handles are files, and groups are directories.`, $ sq mv production prod`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) diff --git a/cli/cmd_remove.go b/cli/cmd_remove.go index a50ce6cd..ade3d48b 100644 --- a/cli/cmd_remove.go +++ b/cli/cmd_remove.go @@ -34,7 +34,7 @@ may have changed, if that source or group was removed.`, $ sq rm @staging/sakila_db @staging/backup_db dev`, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) diff --git a/cli/cmd_scratch.go b/cli/cmd_scratch.go index 36b04547..88c03bed 100644 --- a/cli/cmd_scratch.go +++ b/cli/cmd_scratch.go @@ -36,7 +36,7 @@ importing non-SQL data, or cross-database joins. If no argument provided, get th source. Otherwise, set @HANDLE or an internal db as the scratch data source. The reserved handle "@scratch" resets the `, } - markCmdRequiresConfigLock(cmd) + cmdMarkRequiresConfigLock(cmd) return cmd } diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index 5e8ebfe1..1fc757a8 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -17,7 +17,6 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" @@ -204,7 +203,7 @@ func execSLQPrint(ctx context.Context, ru *run.Run, mArgs map[string]string) err // // $ cat something.xlsx | sq @stdin.sheet1 func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, error) { - log, reg, grips, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Grips, ru.Config.Collection + log, reg, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Config.Collection activeSrc := coll.Active() if len(args) == 0 { @@ -212,6 +211,10 @@ func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, // but sq is receiving pipe input. Let's say the user does this: // // $ cat something.csv | sq # query becomes "@stdin.data" + // + // REVISIT: It's not clear that this is even reachable any more? + // Plus, it's a bit ugly in general. Was the code already changed + // to force providing at least one query arg? if activeSrc == nil { // Piped input would result in an active @stdin src. We don't // have that; we don't have any active src. @@ -230,27 +233,26 @@ func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, } tblName := source.MonotableName - if !drvr.DriverMetadata().Monotable { // This isn't a monotable src, so we can't // just select @stdin.data. Instead we'll select // the first table name, as found in the source meta. - grip, err := grips.Open(ctx, activeSrc) - if err != nil { - return "", err - } - defer lg.WarnIfCloseError(log, lgm.CloseDB, grip) - srcMeta, err := grip.SourceMetadata(ctx, false) + db, sqlDrvr, err := ru.DB(ctx, activeSrc) if err != nil { return "", err } - if len(srcMeta.Tables) == 0 { + tables, err := sqlDrvr.ListTableNames(ctx, db, "", true, true) + if err != nil { + return "", err + } + + if len(tables) == 0 { return "", errz.New(msgSrcNoData) } - tblName = srcMeta.Tables[0].Name + tblName = tables[0] if tblName == "" { return "", errz.New(msgSrcEmptyTableName) } @@ -325,7 +327,7 @@ func addQueryCmdFlags(cmd *cobra.Command) { addTimeFormatOptsFlags(cmd) - cmd.Flags().StringP(flag.Output, flag.OutputShort, "", flag.OutputUsage) + cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage) cmd.Flags().StringP(flag.Input, flag.InputShort, "", flag.InputUsage) panicOn(cmd.Flags().MarkHidden(flag.Input)) // Hide for now; this is mostly used for testing. diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 13d08c04..7c0b0a22 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -14,7 +14,7 @@ import ( "github.com/neilotoole/sq/libsq/files" ) -// newXCmd returns the root "x" command, which is the container +// newXCmd returns the "x" command, which is the container // for a set of hidden commands that are useful for development. // The x commands are not intended for end users. func newXCmd() *cobra.Command { diff --git a/cli/complete.go b/cli/complete.go index 453a32ba..59a38073 100644 --- a/cli/complete.go +++ b/cli/complete.go @@ -113,6 +113,65 @@ func completeHandle(max int, includeActive bool) completionFunc { } } +// completeCatalog is a completionFunc that suggests catalogs. +// If srcArgPos >= 0 and in the range of the cmd args, then +// that value is used to determine the source handle. Typically the +// src handle is the first arg, but it could be elsewhere. +// For example, arg[0] below is @sakila, thus srcArgPos should be 0: +// +// $ sq db dump @sakila --catalog [COMPLETE] +// +// If srcArgPos < 0, the catalogs for the active source are suggested. +func completeCatalog(srcArgPos int) completionFunc { + return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + log, ru := logFrom(cmd), getRun(cmd) + if err := preRun(cmd, ru); err != nil { + lg.Unexpected(log, err) + return nil, cobra.ShellCompDirectiveError + } + + var ( + ctx = cmd.Context() + coll = ru.Config.Collection + src *source.Source + err error + ) + + if srcArgPos >= 0 && srcArgPos < len(args) { + if src, err = coll.Get(args[srcArgPos]); err != nil { + lg.Unexpected(log, err) + return nil, cobra.ShellCompDirectiveError + } + } + + if src == nil { + if src = coll.Active(); src == nil { + log.Debug("No active source, so no catalog completions") + return nil, cobra.ShellCompDirectiveError + } + } + + db, drvr, err := ru.DB(ctx, src) + if err != nil { + lg.Unexpected(log, err) + return nil, cobra.ShellCompDirectiveError + } + + catalogs, err := drvr.ListCatalogs(ctx, db) + if err != nil { + lg.Unexpected(log, err) + return nil, cobra.ShellCompDirectiveError + } + + catalogs = lo.Filter(catalogs, func(catalog string, index int) bool { + return strings.HasPrefix(catalog, toComplete) + }) + + slices.Sort(catalogs) + return catalogs, cobra.ShellCompDirectiveNoFileComp + } +} + // completeGroup is a completionFunc that suggests groups. // The max arg is the maximum number of completions. Set to 0 // for no limit. @@ -793,29 +852,21 @@ func getTableNamesForHandle(ctx context.Context, ru *run.Run, handle string) ([] return nil, err } - grip, err := ru.Grips.Open(ctx, src) + db, drvr, err := ru.DB(ctx, src) if err != nil { return nil, err } - // TODO: We shouldn't have to load the full metadata just to get - // the table names. driver.SQLDriver should have a method ListTables. - md, err := grip.SourceMetadata(ctx, false) - if err != nil { - return nil, err - } - - return md.TableNames(), nil + return drvr.ListTableNames(ctx, db, src.Schema, true, true) } // maybeFilterHandlesByActiveGroup filters the supplied handles by // active group, if appropriate. func maybeFilterHandlesByActiveGroup(ru *run.Run, toComplete string, suggestions []string) []string { - handleFilter := getActiveGroupHandleFilterPrefix(ru) - if handleFilter != "" { - if strings.HasPrefix(handleFilter, toComplete) { - suggestions = lo.Filter(suggestions, func(handle string, index int) bool { - return strings.HasPrefix(handle, handleFilter) + if groupPrefix := getActiveGroupHandleFilterPrefix(ru); groupPrefix != "" { + if strings.HasPrefix(groupPrefix, toComplete) { + suggestions = lo.Filter(suggestions, func(handle string, _ int) bool { + return strings.HasPrefix(handle, groupPrefix) }) } } diff --git a/cli/config/yamlstore/yamlstore.go b/cli/config/yamlstore/yamlstore.go index 6d6c738a..74f9351f 100644 --- a/cli/config/yamlstore/yamlstore.go +++ b/cli/config/yamlstore/yamlstore.go @@ -201,6 +201,8 @@ func (fs *Store) write(ctx context.Context, data []byte) error { return errz.Wrapf(err, "failed to make parent dir of config file: %s", filepath.Dir(fs.Path)) } + // FIXME: Store.Save should do a two-step atomic write of the file. + if err := os.WriteFile(fs.Path, data, ioz.RWPerms); err != nil { return errz.Wrap(err, "failed to save config file") } diff --git a/cli/diff/data_naive.go b/cli/diff/data_naive.go index 2a21f121..59c39951 100644 --- a/cli/diff/data_naive.go +++ b/cli/diff/data_naive.go @@ -42,7 +42,7 @@ func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config, query2 := td2.src.Handle + "." + td2.tblName log := lg.FromContext(ctx).With("a", query1).With("b", query2) - pr := ru.Writers.Printing.Clone() + pr := ru.Writers.OutPrinting.Clone() pr.EnableColor(false) buf1, buf2 := &bytes.Buffer{}, &bytes.Buffer{} @@ -176,7 +176,7 @@ func execSourceDataDiff(ctx context.Context, ru *run.Run, cfg *Config, sd1, sd2 } tblDataDiff = diffs[printIndex] - if err := Print(ctx, ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff); err != nil { + if err := Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDataDiff.header, tblDataDiff.diff); err != nil { printErrCh <- err return } diff --git a/cli/diff/record.go b/cli/diff/record.go index 7ae5dbc6..8ac2f834 100644 --- a/cli/diff/record.go +++ b/cli/diff/record.go @@ -158,7 +158,7 @@ func findRecordDiff(ctx context.Context, ru *run.Run, lines int, row: i, } - if err = populateRecordDiff(ctx, lines, ru.Writers.Printing, recDiff); err != nil { + if err = populateRecordDiff(ctx, lines, ru.Writers.OutPrinting, recDiff); err != nil { return nil, err } diff --git a/cli/diff/source.go b/cli/diff/source.go index 6b52530b..1ff1c3ca 100644 --- a/cli/diff/source.go +++ b/cli/diff/source.go @@ -47,7 +47,7 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config, return err } - if err = Print(ctx, ru.Out, ru.Writers.Printing, srcDiff.header, srcDiff.diff); err != nil { + if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, srcDiff.header, srcDiff.diff); err != nil { return err } } @@ -57,7 +57,7 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config, if err != nil { return err } - if err = Print(ctx, ru.Out, ru.Writers.Printing, propsDiff.header, propsDiff.diff); err != nil { + if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, propsDiff.header, propsDiff.diff); err != nil { return err } } @@ -68,7 +68,7 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config, return err } for _, tblDiff := range tblDiffs { - if err = Print(ctx, ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil { + if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDiff.header, tblDiff.diff); err != nil { return err } } diff --git a/cli/diff/table.go b/cli/diff/table.go index 50dbbfc5..ca63cac4 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -51,7 +51,7 @@ func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Element return err } - if err = Print(ctx, ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil { + if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDiff.header, tblDiff.diff); err != nil { return err } } @@ -69,7 +69,7 @@ func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Element return nil } - return Print(ctx, ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff) + return Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDataDiff.header, tblDataDiff.diff) } func buildTableStructureDiff(ctx context.Context, cfg *Config, showRowCounts bool, diff --git a/cli/error.go b/cli/error.go index 867edf89..fd07a882 100644 --- a/cli/error.go +++ b/cli/error.go @@ -19,6 +19,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/termz" ) // PrintError is the centralized function for printing @@ -88,8 +89,10 @@ func PrintError(ctx context.Context, ru *run.Run, err error) { } else { clnup = cleanup.New() } - // getPrinting works even if cmd is nil - pr, _, errOut := getPrinting(cmd, clnup, opts, os.Stdout, os.Stderr) + // getOutputConfig works even if cmd is nil + fm := getFormat(cmd, opts) + outCfg := getOutputConfig(cmd, clnup, fm, opts, ru.Stdout, ru.Stderr) + errOut, pr := outCfg.errOut, outCfg.errOutPr // Execute the cleanup before we print the error. if cleanErr := clnup.Run(); cleanErr != nil { log.Error("Cleanup failed", lga.Err, cleanErr) @@ -103,7 +106,7 @@ func PrintError(ctx context.Context, ru *run.Run, err error) { } // The user didn't want JSON, so we just print to stderr. - if isColorTerminal(os.Stderr) { + if termz.IsColorTerminal(os.Stderr) { pr.Error.Fprintln(os.Stderr, "sq: "+err.Error()) } else { fmt.Fprintln(os.Stderr, "sq: "+err.Error()) diff --git a/cli/flag/flag.go b/cli/flag/flag.go index 3b5453ac..bce0c35c 100644 --- a/cli/flag/flag.go +++ b/cli/flag/flag.go @@ -1,12 +1,15 @@ // Package flag holds CLI flags. package flag +// FIXME: Need to update docs for use of src.schema to note the +// new "catalog." mechanism. + const ( ActiveSrc = "src" ActiveSrcUsage = "Override active source for this query" ActiveSchema = "src.schema" - ActiveSchemaUsage = "Override active schema or catalog.schema for this query" + ActiveSchemaUsage = "Override active schema (and/or catalog) for this query" ConfigSrc = "src" ConfigSrcUsage = "Config for source" @@ -67,12 +70,9 @@ const ( MonochromeShort = "M" MonochromeUsage = "Don't colorize output" - NoProgress = "no-progress" - NoProgressUsage = "Don't show progress bar" - - Output = "output" - OutputShort = "o" - OutputUsage = "Write output to instead of stdout" + FileOutput = "output" + FileOutputShort = "o" + FileOutputUsage = "Write output to instead of stdout" // Input sets Run.Stdin to the named file. At this time, this is used // mainly for debugging, so it's marked hidden by the CLI. I'm not @@ -210,6 +210,29 @@ const ( DiffAll = "all" DiffAllShort = "a" DiffAllUsage = "Compare everything (caution: may be slow)" + + DBDumpCatalog = "catalog" + DBDumpCatalogUsage = "Dump the named catalog" + DBDumpNoOwner = "no-owner" + DBDumpNoOwnerUsage = "Don't set ownership or ACL" + + DBPrintToolCmd = "print" + DBPrintToolCmdUsage = "Print the db-native tool command, but don't execute it" + DBPrintLongToolCmd = "print-long" + DBPrintLongToolCmdUsage = "Print the long-form db-native tool command, but don't execute it" + + DBRestoreFrom = "from" + DBRestoreFromShort = "f" + DBRestoreFromUsage = "Restore from dump file; if omitted, read from stdin" + DBRestoreNoOwner = "no-owner" + DBRestoreNoOwnerUsage = "Don't use ownership or ACL from dump" + + DBExecFile = "file" + DBExecFileShort = "f" + DBExecFileUsage = "Read SQL from instead of stdin" + DBExecCmd = "command" + DBExecCmdShort = "c" + DBExecCmdUsage = "Execute SQL command string" ) // OutputFormatFlags is the set of flags that control output format. diff --git a/cli/output.go b/cli/output.go index 02810e93..28c55eb5 100644 --- a/cli/output.go +++ b/cli/output.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "io" "os" @@ -30,6 +31,7 @@ import ( "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/core/termz" "github.com/neilotoole/sq/libsq/core/timez" ) @@ -269,13 +271,15 @@ the rendered value is not an integer. ) // newWriters returns an output.Writers instance configured per defaults and/or -// flags from cmd. The returned out2/errOut2 values may differ -// from the out/errOut args (e.g. decorated to support colorization). +// flags from cmd. The returned writers in [outputConfig] may differ from +// the stdout and stderr params (e.g. decorated to support colorization). func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options, - out, errOut io.Writer, -) (w *output.Writers, out2, errOut2 io.Writer) { - var pr *output.Printing - pr, out2, errOut2 = getPrinting(cmd, clnup, o, out, errOut) + stdout, stderr io.Writer, +) (w *output.Writers, outCfg *outputConfig) { + // Invoke getFormat to see if the format was specified + // via config or flag. + fm := getFormat(cmd, o) + outCfg = getOutputConfig(cmd, clnup, fm, o, stdout, stderr) log := logFrom(cmd) // Package tablew has writer impls for each of the writer interfaces, @@ -283,50 +287,47 @@ func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options, // flags and set the various writer fields depending upon which // writers the format implements. w = &output.Writers{ - Printing: pr, - Record: tablew.NewRecordWriter(out2, pr), - Metadata: tablew.NewMetadataWriter(out2, pr), - Source: tablew.NewSourceWriter(out2, pr), - Ping: tablew.NewPingWriter(out2, pr), - Error: tablew.NewErrorWriter(errOut2, pr), - Version: tablew.NewVersionWriter(out2, pr), - Config: tablew.NewConfigWriter(out2, pr), + OutPrinting: outCfg.outPr, + ErrPrinting: outCfg.errOutPr, + Record: tablew.NewRecordWriter(outCfg.out, outCfg.outPr), + Metadata: tablew.NewMetadataWriter(outCfg.out, outCfg.outPr), + Source: tablew.NewSourceWriter(outCfg.out, outCfg.outPr), + Ping: tablew.NewPingWriter(outCfg.out, outCfg.outPr), + Error: tablew.NewErrorWriter(outCfg.errOut, outCfg.errOutPr), + Version: tablew.NewVersionWriter(outCfg.out, outCfg.outPr), + Config: tablew.NewConfigWriter(outCfg.out, outCfg.outPr), } if OptErrorFormat.Get(o) == format.JSON { // This logic works because the only supported values are text and json. - w.Error = jsonw.NewErrorWriter(log, errOut2, pr) + w.Error = jsonw.NewErrorWriter(log, outCfg.errOut, outCfg.errOutPr) } - // Invoke getFormat to see if the format was specified - // via config or flag. - fm := getFormat(cmd, o) - //nolint:exhaustive switch fm { case format.JSON: // No format specified, use JSON - w.Metadata = jsonw.NewMetadataWriter(out2, pr) - w.Source = jsonw.NewSourceWriter(out2, pr) - w.Version = jsonw.NewVersionWriter(out2, pr) - w.Ping = jsonw.NewPingWriter(out2, pr) - w.Config = jsonw.NewConfigWriter(out2, pr) + w.Metadata = jsonw.NewMetadataWriter(outCfg.out, outCfg.outPr) + w.Source = jsonw.NewSourceWriter(outCfg.out, outCfg.outPr) + w.Version = jsonw.NewVersionWriter(outCfg.out, outCfg.outPr) + w.Ping = jsonw.NewPingWriter(outCfg.out, outCfg.outPr) + w.Config = jsonw.NewConfigWriter(outCfg.out, outCfg.outPr) case format.Text: // Don't delete this case, it's actually needed due to // the slightly odd logic that determines format. case format.TSV: - w.Ping = csvw.NewPingWriter(out2, csvw.Tab) + w.Ping = csvw.NewPingWriter(outCfg.out, csvw.Tab) case format.CSV: - w.Ping = csvw.NewPingWriter(out2, csvw.Comma) + w.Ping = csvw.NewPingWriter(outCfg.out, csvw.Comma) case format.YAML: - w.Config = yamlw.NewConfigWriter(out2, pr) - w.Metadata = yamlw.NewMetadataWriter(out2, pr) - w.Source = yamlw.NewSourceWriter(out2, pr) - w.Version = yamlw.NewVersionWriter(out2, pr) + w.Config = yamlw.NewConfigWriter(outCfg.out, outCfg.outPr) + w.Metadata = yamlw.NewMetadataWriter(outCfg.out, outCfg.outPr) + w.Source = yamlw.NewSourceWriter(outCfg.out, outCfg.outPr) + w.Version = yamlw.NewVersionWriter(outCfg.out, outCfg.outPr) default: } @@ -335,10 +336,10 @@ func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options, // We can still continue, because w.Record was already set above. log.Warn("No record writer impl for format", "format", fm) } else { - w.Record = recwFn(out2, pr) + w.Record = recwFn(outCfg.out, outCfg.outPr) } - return w, out2, errOut2 + return w, outCfg } // getRecordWriterFunc returns a func that creates a new output.RecordWriter @@ -374,23 +375,71 @@ func getRecordWriterFunc(f format.Format) output.NewRecordWriterFunc { } } -// getPrinting returns a Printing instance and -// colorable or non-colorable writers. It is permissible -// for the cmd arg to be nil. The caller should use the returned -// io.Writer instances instead of the supplied writers, as they -// may be decorated for dealing with color, etc. -// The supplied opts must already have flags merged into it -// via getOptionsFromCmd. -// -// Be cautious making changes to getPrinting. This function must -// be absolutely bulletproof, as it's called by all commands, as well -// as by the error handling mechanism. So, be sure to always check -// for nil cmd, nil cmd.Context, etc. -func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Options, - out, errOut io.Writer, -) (pr *output.Printing, out2, errOut2 io.Writer) { - pr = output.NewPrinting() +// outputConfig is a container for the various output writers. +type outputConfig struct { + // outPr is the printing config for out. + outPr *output.Printing + // out is the output writer that should be used for stdout output. + out io.Writer + + // stdout is the original stdout, which probably was os.Stdin. + // It's referenced here for special cases. + stdout io.Writer + + // errOutPr is the printing config for errOut. + errOutPr *output.Printing + + // errOut is the output writer that should be used for stderr output. + errOut io.Writer + + // stderr is the original errOut, which probably was os.Stderr. + // It's referenced here for special cases. + stderr io.Writer +} + +// getOutputConfig returns the configured output writers for cmd. Generally +// speaking, the caller should use the outputConfig.out and outputConfig.errOut +// writers for program output, as they are decorated appropriately for dealing +// with colorization, progress bars, etc. In very rare cases, such as calling +// out to an external program (e.g. pg_dump), the original outputConfig.stdout +// and outputConfig.stderr may be used. +// +// The supplied opts must already have flags merged into it via getOptionsFromCmd. +// +// If the progress bar is enabled and possible (stderr is TTY etc.), then cmd's +// context is decorated with via [progress.NewContext]. +// +// Be VERY cautious about making changes to getOutputConfig. This function must +// be absolutely bulletproof, as it's called by all commands, as well as by the +// error handling mechanism. So, be sure to always check for nil: any of the +// args could be nil, or their fields could be nil. Check EVERYTHING for nil. +// +// The returned outputConfig and all of its fields are guaranteed to be non-nil. +// +// See also: [OptMonochrome], [OptProgress], newWriters. +func getOutputConfig(cmd *cobra.Command, clnup *cleanup.Cleanup, + fm format.Format, opts options.Options, stdout, stderr io.Writer, +) (outCfg *outputConfig) { + if opts == nil { + opts = options.Options{} + } + + var ctx context.Context + if cmd != nil { + ctx = cmd.Context() + } + + if stdout == nil { + stdout = os.Stdout + } + if stderr == nil { + stderr = os.Stderr + } + + outCfg = &outputConfig{stdout: stdout, stderr: stderr} + + pr := output.NewPrinting() pr.FormatDatetime = timez.FormatFunc(OptDatetimeFormat.Get(opts)) pr.FormatDatetimeAsNumber = OptDatetimeFormatAsNumber.Get(opts) pr.FormatTime = timez.FormatFunc(OptTimeFormat.Get(opts)) @@ -416,79 +465,104 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option pr.ShowHeader = OptPrintHeader.Get(opts) } - colorize := !OptMonochrome.Get(opts) + var ( + prog *progress.Progress + noProg = !OptProgress.Get(opts) + progColors = progress.DefaultColors() + monochrome = OptMonochrome.Get(opts) + ) - if cmdFlagChanged(cmd, flag.Output) { - // We're outputting to a file, thus no color. - colorize = false - } - - if !colorize { + if monochrome { color.NoColor = true pr.EnableColor(false) - out2 = out - errOut2 = errOut - - if cmd != nil && cmd.Context() != nil && OptProgress.Get(opts) && isTerminal(errOut) { - progColors := progress.DefaultColors() - progColors.EnableColor(false) - ctx := cmd.Context() - renderDelay := OptProgressDelay.Get(opts) - pb := progress.New(ctx, errOut2, renderDelay, progColors) - clnup.Add(pb.Stop) - // On first write to stdout, we remove the progress widget. - out2 = ioz.NotifyOnceWriter(out2, pb.Stop) - cmd.SetContext(progress.NewContext(ctx, pb)) - } - - return pr, out2, errOut2 - } - - // We do want to colorize - if !isColorTerminal(out) { - // But out can't be colorized. - color.NoColor = true - pr.EnableColor(false) - out2, errOut2 = out, errOut - return pr, out2, errOut2 - } - - // out can be colorized. - color.NoColor = false - pr.EnableColor(true) - out2 = colorable.NewColorable(out.(*os.File)) - - // Check if we can colorize errOut - if isColorTerminal(errOut) { - errOut2 = colorable.NewColorable(errOut.(*os.File)) + progColors.EnableColor(false) } else { - // errOut2 can't be colorized, but since we're colorizing - // out, we'll apply the non-colorable filter to errOut. - errOut2 = colorable.NewNonColorable(errOut) + color.NoColor = false + pr.EnableColor(true) + progColors.EnableColor(true) } - if cmd != nil && cmd.Context() != nil && OptProgress.Get(opts) && isTerminal(errOut) { - progColors := progress.DefaultColors() - progColors.EnableColor(isColorTerminal(errOut)) + outCfg.outPr = pr.Clone() + outCfg.errOutPr = pr.Clone() + pr = nil //nolint:wastedassign // Make sure we don't accidentally use pr again - ctx := cmd.Context() - renderDelay := OptProgressDelay.Get(opts) - pb := progress.New(ctx, errOut2, renderDelay, progColors) - clnup.Add(pb.Stop) - - // On first write to stdout, we remove the progress widget. - out2 = ioz.NotifyOnceWriter(out2, pb.Stop) - - cmd.SetContext(progress.NewContext(ctx, pb)) + switch { + case termz.IsColorTerminal(stderr) && !monochrome: + // stderr is a color terminal and we're colorizing, thus + // we enable progress if allowed and if ctx is non-nil. + outCfg.errOut = colorable.NewColorable(stderr.(*os.File)) + outCfg.errOutPr.EnableColor(true) + if ctx != nil && !noProg { + progColors.EnableColor(true) + prog = progress.New(ctx, outCfg.errOut, OptProgressDelay.Get(opts), progColors) + } + case termz.IsTerminal(stderr): + // stderr is a terminal, and won't have color output, but we still enable + // progress, if allowed and ctx is non-nil. + // + // But... slightly weirdly, we still need to wrap stderr in a colorable, or + // else the progress bar won't render correctly. But it's not a problem, + // because we'll just disable the colors directly. + outCfg.errOut = colorable.NewColorable(stderr.(*os.File)) + outCfg.errOutPr.EnableColor(false) + if ctx != nil && !noProg { + progColors.EnableColor(false) + prog = progress.New(ctx, outCfg.errOut, OptProgressDelay.Get(opts), progColors) + } + default: + // stderr is a not a terminal at all. No color, no progress. + outCfg.errOut = colorable.NewNonColorable(stderr) + outCfg.errOutPr.EnableColor(false) + progColors.EnableColor(false) + prog = nil // Set to nil just to be explicit. } - return pr, out2, errOut2 + switch { + case cmdFlagChanged(cmd, flag.FileOutput) || fm == format.Raw: + // For file or raw output, we don't decorate stdout with + // any colorable decorator. + outCfg.out = stdout + outCfg.outPr.EnableColor(false) + case cmd != nil && cmdFlagChanged(cmd, flag.FileOutput): + // stdout is an actual regular file on disk, so no color. + outCfg.out = colorable.NewNonColorable(stdout) + outCfg.outPr.EnableColor(false) + case termz.IsColorTerminal(stdout) && !monochrome: + // stdout is a color terminal and we're colorizing. + outCfg.out = colorable.NewColorable(stdout.(*os.File)) + outCfg.outPr.EnableColor(true) + case termz.IsTerminal(stderr): + // stdout is a terminal, but won't be colorized. + outCfg.out = colorable.NewNonColorable(stdout) + outCfg.outPr.EnableColor(false) + default: + // stdout is a not a terminal at all. No color. + outCfg.out = colorable.NewNonColorable(stdout) + outCfg.outPr.EnableColor(false) + } + + if !noProg && prog != nil && cmd != nil && ctx != nil { + // The progress bar is enabled. + + // Be sure to stop the progress bar eventually. + clnup.Add(prog.Stop) + + // Also, stop the progress bar as soon as bytes are written + // to out, because we don't want the progress bar to + // corrupt the terminal output. + outCfg.out = ioz.NotifyOnceWriter(outCfg.out, prog.Stop) + cmd.SetContext(progress.NewContext(ctx, prog)) + } + + return outCfg } func getFormat(cmd *cobra.Command, o options.Options) format.Format { var fm format.Format switch { + case cmd == nil: + fm = OptFormat.Get(o) case cmdFlagChanged(cmd, flag.TSV): fm = format.TSV case cmdFlagChanged(cmd, flag.CSV): diff --git a/cli/output/tablew/pingwriter.go b/cli/output/tablew/pingwriter.go index 0a99ddbd..8fe7b313 100644 --- a/cli/output/tablew/pingwriter.go +++ b/cli/output/tablew/pingwriter.go @@ -40,7 +40,7 @@ func (w *PingWriter) Open(srcs []*source.Source) error { // Result implements output.PingWriter. func (w *PingWriter) Result(src *source.Source, d time.Duration, err error) error { - w.pr.Number.Fprintf(w.out, "%-"+strconv.Itoa(w.handleWidthMax)+"s", src.Handle) + w.pr.Handle.Fprintf(w.out, "%-"+strconv.Itoa(w.handleWidthMax)+"s", src.Handle) w.pr.Duration.Fprintf(w.out, "%10s ", d.Truncate(time.Millisecond).String()) // The ping result is one of: diff --git a/cli/output/writers.go b/cli/output/writers.go index f58e5b08..a76b9332 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -153,7 +153,8 @@ type ConfigWriter interface { // Writers is a container for the various output Writers. type Writers struct { - Printing *Printing + OutPrinting *Printing + ErrPrinting *Printing Record RecordWriter Metadata MetadataWriter diff --git a/cli/run.go b/cli/run.go index 23498e98..80eb3c2d 100644 --- a/cli/run.go +++ b/cli/run.go @@ -70,8 +70,8 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args ru := &run.Run{ Stdin: stdin, - Out: stdout, - ErrOut: stderr, + Stdout: stdout, + Stderr: stderr, OptionsRegistry: &options.Registry{}, } @@ -158,35 +158,41 @@ func preRun(cmd *cobra.Command, ru *run.Run) error { } // If the --output=/some/file flag is set, then we need to - // override ru.Out (which is typically stdout) to point it at + // override ru.Stdout (which is typically stdout) to point it at // the output destination file. - if cmdFlagChanged(ru.Cmd, flag.Output) { - fpath, _ := ru.Cmd.Flags().GetString(flag.Output) + // + + if cmdFlagChanged(ru.Cmd, flag.FileOutput) && !cmdRequiresPlainStdout(ru.Cmd) { + fpath, _ := ru.Cmd.Flags().GetString(flag.FileOutput) fpath, err := filepath.Abs(fpath) if err != nil { - return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output) + return errz.Wrapf(err, "failed to get absolute path for --%s", flag.FileOutput) } // Ensure the parent dir exists err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm) if err != nil { - return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output) + return errz.Wrapf(err, "failed to make parent dir for --%s", flag.FileOutput) } f, err := os.Create(fpath) if err != nil { - return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output) + return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.FileOutput) } ru.Cleanup.AddC(f) // Make sure the file gets closed eventually - ru.Out = f + ru.Stdout = f } cmdOpts, err := getOptionsFromCmd(ru.Cmd) if err != nil { return err } - ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Out, ru.ErrOut) + + var outCfg *outputConfig + ru.Writers, outCfg = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Stdout, ru.Stderr) + ru.Out = outCfg.out + ru.ErrOut = outCfg.errOut if err = FinishRunInit(ctx, ru); err != nil { return err @@ -330,20 +336,20 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { return nil } -// markCmdRequiresConfigLock marks cmd as requiring a config lock. +// cmdMarkRequiresConfigLock marks cmd as requiring a config lock. // Thus, before the command's RunE is invoked, the config lock // is acquired (in preRun), and released on cleanup. -func markCmdRequiresConfigLock(cmd *cobra.Command) { +func cmdMarkRequiresConfigLock(cmd *cobra.Command) { if cmd.Annotations == nil { cmd.Annotations = make(map[string]string) } - cmd.Annotations["config.lock"] = "true" + cmd.Annotations["config.lock"] = "true" //nolint:goconst } -// cmdRequiresConfigLock returns true if markCmdRequiresConfigLock was +// cmdRequiresConfigLock returns true if cmdMarkRequiresConfigLock was // previously invoked on cmd. func cmdRequiresConfigLock(cmd *cobra.Command) bool { - return cmd.Annotations != nil && cmd.Annotations["config.lock"] == "true" + return cmd != nil && cmd.Annotations != nil && cmd.Annotations["config.lock"] == "true" } // lockReloadConfig acquires the lock for the config store, and updates the @@ -363,7 +369,7 @@ func cmdRequiresConfigLock(cmd *cobra.Command) bool { // defer unlock() // } // -// However, in practice, most commands will invoke markCmdRequiresConfigLock +// However, in practice, most commands will invoke cmdMarkRequiresConfigLock // instead of explicitly invoking lockReloadConfig. func lockReloadConfig(cmd *cobra.Command) (unlock func(), err error) { ctx := cmd.Context() @@ -433,3 +439,20 @@ func newProgressLockFunc(lock lockfile.Lockfile, msg string, timeout time.Durati }, nil } } + +// cmdMarkPlainStdout indicates that the command's stdout should +// not be decorated in any way, e.g. with color or progress bars. +// This is useful for binary output. +func cmdMarkPlainStdout(cmd *cobra.Command) { + // FIXME: implement this in newWriters or such? + if cmd.Annotations == nil { + cmd.Annotations = make(map[string]string) + } + cmd.Annotations["stdout.plain"] = "true" +} + +// cmdRequiresPlainStdout returns true if cmdMarkPlainStdout was +// previously invoked on cmd. +func cmdRequiresPlainStdout(cmd *cobra.Command) bool { + return cmd != nil && cmd.Annotations != nil && cmd.Annotations["stdout.plain"] == "true" +} diff --git a/cli/run/run.go b/cli/run/run.go index 7e72710d..9e474dc0 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -16,8 +16,10 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/files" + "github.com/neilotoole/sq/libsq/source" ) type runKey struct{} @@ -40,12 +42,28 @@ func FromContext(ctx context.Context) *Run { // to all cobra exec funcs. The Close method should be invoked when // the Run is no longer needed. type Run struct { - // Out is the output destination, typically os.Stdout. + // Out is the output destination, typically a decorated writer over + // [Run.Stdout]. This writer should generally be used for program output, + // not [Run.Stdout]. Out io.Writer - // ErrOut is the error output destination, typically os.Stderr. + // Stdout is the original stdout file descriptor, which is typically + // the actual os.Stdout. Output should generally be written to [Run.Out] + // except for a few rare circumstances, such as executing an external + // program. + Stdout io.Writer + + // ErrOut is the error output destination, typically a decorated writer + // over [Run.Stderr]. This writer should generally be used for error output, + // not [Run.Stderr]. ErrOut io.Writer + // Stderr is the original stderr file descriptor, which is typically + // the actual os.Stderr. Error output should generally be written to + // [Run.ErrOut] except for a few rare circumstances, such as executing an + // external program. + Stderr io.Writer + // ConfigStore manages config persistence. ConfigStore config.Store @@ -108,6 +126,21 @@ func (ru *Run) Close() error { return errz.Wrap(ru.Cleanup.Run(), "close run") } +// DB is a convenience method that gets the sqlz.DB and driver.SQLDriver\ +// for src. +func (ru *Run) DB(ctx context.Context, src *source.Source) (sqlz.DB, driver.SQLDriver, error) { + grip, err := ru.Grips.Open(ctx, src) + if err != nil { + return nil, nil, err + } + db, err := grip.DB(ctx) + if err != nil { + return nil, nil, err + } + + return db, grip.SQLDriver(), nil +} + // NewQueryContext returns a *libsq.QueryContext constructed from ru. func NewQueryContext(ru *Run, args map[string]string) *libsq.QueryContext { return &libsq.QueryContext{ diff --git a/cli/source.go b/cli/source.go index bab3b549..dcf321dc 100644 --- a/cli/source.go +++ b/cli/source.go @@ -70,14 +70,16 @@ func determineSources(ctx context.Context, ru *run.Run, requireActive bool) erro } // activeSrcAndSchemaFromFlagsOrConfig gets the active source, either -// from flagActiveSrc or from srcs.Active. An error is returned +// from [flag.ActiveSrc] or from srcs.Active. An error is returned // if the flag src is not found: if the flag src is found, // it is set as the active src on coll. If the flag was not // set and there is no active src in coll, (nil, nil) is // returned. // -// This source also checks flag.ActiveSchema, and changes the schema +// This source also checks [flag.ActiveSchema], and changes the catalog/schema // of the source if the flag is set. +// +// See also: processFlagActiveSchema, verifySourceCatalogSchema. func activeSrcAndSchemaFromFlagsOrConfig(ru *run.Run) (*source.Source, error) { cmd, coll := ru.Cmd, ru.Config.Collection var activeSrc *source.Source @@ -100,49 +102,148 @@ func activeSrcAndSchemaFromFlagsOrConfig(ru *run.Run) (*source.Source, error) { activeSrc = coll.Active() } - if err := processFlagActiveSchema(cmd, activeSrc); err != nil { + srcModified, err := processFlagActiveSchema(cmd, activeSrc) + if err != nil { return nil, err } + if srcModified { + if err = verifySourceCatalogSchema(ru.Cmd.Context(), ru, activeSrc); err != nil { + return nil, err + } + } + return activeSrc, nil } -// processFlagActiveSchema processes the --src.schema flag, setting -// appropriate Source.Catalog and Source.Schema values on activeSrc. -// If flag.ActiveSchema is not set, this is no-op. If activeSrc is nil, -// an error is returned. -func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) error { +// verifySourceCatalogSchema verifies that src's non-empty [source.Source.Catalog] +// and [source.Source.Schema] are valid, in that they are referenceable in src's +// DB. If both fields are empty, this is a no-op. This function is typically +// used when the source's catalog or schema are modified, e.g. via [flag.ActiveSchema]. +// +// See also: processFlagActiveSchema. +func verifySourceCatalogSchema(ctx context.Context, ru *run.Run, src *source.Source) error { + if src.Catalog == "" && src.Schema == "" { + return nil + } + + db, drvr, err := ru.DB(ctx, src) + if err != nil { + return err + } + + var exists bool + if src.Catalog != "" { + if exists, err = drvr.CatalogExists(ctx, db, src.Catalog); err != nil { + return err + } + if !exists { + return errz.Errorf("%s: catalog {%s} doesn't exist or not referenceable", src.Handle, src.Catalog) + } + } + + if src.Schema != "" { + if exists, err = drvr.SchemaExists(ctx, db, src.Schema); err != nil { + return err + } + if !exists { + return errz.Errorf("%s: schema {%s} doesn't exist or not referenceable", src.Handle, src.Schema) + } + } + + return nil +} + +// getCmdSource gets the source specified in args[0], or if args is empty, the +// active source, and calls applySourceOptions. If [flag.ActiveSchema] is set, +// the [source.Source.Catalog] and [source.Source.Schema] fields are configured +// (and validated) as appropriate. +// +// See: applySourceOptions, processFlagActiveSchema, verifySourceCatalogSchema. +func getCmdSource(cmd *cobra.Command, args []string) (*source.Source, error) { + ru := run.FromContext(cmd.Context()) + + var src *source.Source + var err error + if len(args) == 0 { + if src = ru.Config.Collection.Active(); src == nil { + return nil, errz.New("no active source") + } + } else { + src, err = ru.Config.Collection.Get(args[0]) + if err != nil { + return nil, err + } + } + + if !cmdFlagChanged(cmd, flag.ActiveSchema) { + return src, nil + } + + // Handle flag.ActiveSchema (--src.schema=CATALOG.SCHEMA). This func may + // mutate src's Catalog and Schema fields if appropriate. + var srcModified bool + if srcModified, err = processFlagActiveSchema(cmd, src); err != nil { + return nil, err + } + + if err = applySourceOptions(cmd, src); err != nil { + return nil, err + } + + if srcModified { + if err = verifySourceCatalogSchema(cmd.Context(), ru, src); err != nil { + return nil, err + } + } + + return src, nil +} + +// processFlagActiveSchema processes the --src.schema flag, setting appropriate +// [source.Source.Catalog] and [source.Source.Schema] values on activeSrc. If +// the src is modified by this function, modified returns true. If +// [flag.ActiveSchema] is not set, this is no-op. If activeSrc is nil, an error +// is returned. +// +// See also: verifySourceCatalogSchema. +func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) (modified bool, err error) { ru := run.FromContext(cmd.Context()) if !cmdFlagChanged(cmd, flag.ActiveSchema) { // Nothing to do here - return nil + return false, nil } if activeSrc == nil { - return errz.Errorf("active catalog/schema specified via --%s, but active source is nil", + return false, errz.Errorf("active catalog/schema specified via --%s, but active source is nil", flag.ActiveSchema) } val, _ := cmd.Flags().GetString(flag.ActiveSchema) if val = strings.TrimSpace(val); val == "" { - return errz.Errorf("active catalog/schema specified via --%s, but schema is empty", + return false, errz.Errorf("active catalog/schema specified via --%s, but schema is empty", flag.ActiveSchema) } catalog, schema, err := ast.ParseCatalogSchema(val) if err != nil { - return errz.Wrapf(err, "invalid active schema specified via --%s", + return false, errz.Wrapf(err, "invalid active schema specified via --%s", flag.ActiveSchema) } + if catalog != activeSrc.Catalog || schema != activeSrc.Schema { + modified = true + } + drvr, err := ru.DriverRegistry.SQLDriverFor(activeSrc.Type) if err != nil { - return err + return false, err } if catalog != "" { if !drvr.Dialect().Catalog { - return errz.Errorf("driver {%s} does not support catalog", activeSrc.Type) + return false, errz.Errorf("driver {%s} does not support catalog", activeSrc.Type) } + activeSrc.Catalog = catalog } @@ -150,7 +251,7 @@ func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) error activeSrc.Schema = schema } - return nil + return modified, nil } // checkStdinSource checks if there's stdin data (on pipe/redirect). diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 9ec5769d..a3e9295b 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -112,7 +112,9 @@ func newRun(ctx context.Context, tb testing.TB, ru = &run.Run{ Stdin: os.Stdin, + Stdout: out, Out: out, + Stderr: errOut, ErrOut: errOut, Config: cfg, ConfigStore: cfgStore, diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index 30cc84e9..60e7cb90 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -206,9 +206,12 @@ func TestIngest_DuplicateColumns(t *testing.T) { require.NoError(t, err) tr = testrun.New(ctx, t, tr).Hush() - require.NoError(t, tr.Exec("--csv", ".data")) + err = tr.Exec("--csv", ".data") + require.NoError(t, err) wantHeaders := []string{"actor_id", "first_name", "last_name", "last_update", "actor_id_1"} data := tr.BindCSV() + gotOut := tr.OutString() + _ = gotOut require.Equal(t, wantHeaders, data[0]) // Make sure the data is correct diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index 44cdfa7f..d1d4ec0d 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -224,6 +224,17 @@ func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error) return schemas, nil } +// SchemaExists implements driver.SQLDriver. +func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) { + if schma == "" { + return false, nil + } + + const q = "SELECT COUNT(SCHEMA_NAME) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?" + var count int + return count > 0, errw(db.QueryRowContext(ctx, q, schma).Scan(&count)) +} + // ListSchemaMetadata implements driver.SQLDriver. func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metadata.Schema, error) { log := lg.FromContext(ctx) @@ -273,6 +284,43 @@ func (d *driveri) CurrentCatalog(ctx context.Context, db sqlz.DB) (string, error return catalog, nil } +// ListTableNames implements driver.SQLDriver. +func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) { + var tblClause string + switch { + case tables && views: + tblClause = " AND (TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW')" + case tables: + tblClause = " AND TABLE_TYPE = 'BASE TABLE'" + case views: + tblClause = " AND TABLE_TYPE = 'VIEW'" + default: + return []string{}, nil + } + + var args []any + q := "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = " + if schma == "" { + q += "DATABASE()" + } else { + q += "?" + args = append(args, schma) + } + q += tblClause + " ORDER BY TABLE_NAME" + + rows, err := db.QueryContext(ctx, q, args...) + if err != nil { + return nil, errw(err) + } + + names, err := sqlz.RowsScanColumn[string](ctx, rows) + if err != nil { + return nil, errw(err) + } + + return names, nil +} + // ListCatalogs implements driver.SQLDriver. MySQL does not really support catalogs, // but this method simply delegates to CurrentCatalog, which returns the value // found in INFORMATION_SCHEMA.SCHEMATA, i.e. "def". @@ -285,6 +333,12 @@ func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error return []string{catalog}, nil } +// CatalogExists implements driver.SQLDriver. It returns true if catalog is "def", +// and false otherwise, nothing that MySQL doesn't really support catalogs. +func (d *driveri) CatalogExists(_ context.Context, _ sqlz.DB, catalog string) (bool, error) { + return catalog == "def", nil +} + // AlterTableRename implements driver.SQLDriver. func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error { q := fmt.Sprintf("RENAME TABLE `%s` TO `%s`", tbl, newName) diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index a81c4b5b..ae561d99 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -150,35 +150,9 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) { log := lg.FromContext(ctx) ctx = options.NewContext(ctx, src.Options) - dbCfg, err := pgxpool.ParseConfig(src.Location) + poolCfg, err := getPoolConfig(src, false) if err != nil { - return nil, errw(err) - } - - if src.Catalog != "" && src.Catalog != dbCfg.ConnConfig.Database { - // The catalog differs from the database in the connection string. - // OOTB, Postgres doesn't support cross-database references. So, - // we'll need to change the connection string to use the catalog - // as the database. Note that we don't modify src.Location, but it's - // not entirely clear if that's the correct approach. Are there any - // downsides to modifying it (as long as the modified Location is not - // persisted back to config)? - var u *dburl.URL - if u, err = dburl.Parse(src.Location); err != nil { - return nil, errw(err) - } - - u.Path = src.Catalog - connStr := u.String() - dbCfg, err = pgxpool.ParseConfig(connStr) - if err != nil { - return nil, errw(err) - } - log.Debug("Using catalog as database in connection string", - lga.Src, src, - lga.Catalog, src.Catalog, - lga.Conn, location.Redact(connStr), - ) + return nil, err } var opts []stdlib.OptionOpenDB @@ -196,7 +170,7 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro log.Debug("Setting default schema (search_path) on Postgres DB connection", lga.Src, src, - lga.Conn, location.Redact(dbCfg.ConnString()), + lga.Conn, location.Redact(poolCfg.ConnString()), lga.Catalog, src.Catalog, lga.Schema, src.Schema, lga.Old, oldSearchPath, @@ -207,9 +181,9 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro })) } - dbCfg.ConnConfig.ConnectTimeout = driver.OptConnOpenTimeout.Get(src.Options) + poolCfg.ConnConfig.ConnectTimeout = driver.OptConnOpenTimeout.Get(src.Options) - db := stdlib.OpenDB(*dbCfg.ConnConfig, opts...) + db := stdlib.OpenDB(*poolCfg.ConnConfig, opts...) driver.ConfigureDB(ctx, db, src.Options) return db, nil @@ -430,6 +404,32 @@ ORDER BY datname` return catalogs, nil } +// SchemaExists implements driver.SQLDriver. +func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) { + if schma == "" { + return false, nil + } + + const q = `SELECT COUNT(schema_name) FROM information_schema.schemata + WHERE schema_name = $1 AND catalog_name = current_database()` + + var count int + return count > 0, errw(db.QueryRowContext(ctx, q, schma).Scan(&count)) +} + +// CatalogExists implements driver.SQLDriver. +func (d *driveri) CatalogExists(ctx context.Context, db sqlz.DB, catalog string) (bool, error) { + if catalog == "" { + return false, nil + } + + const q = `SELECT COUNT(datname) FROM pg_catalog.pg_database +WHERE datistemplate = FALSE AND datallowconn = TRUE AND datname = $1` + + var count int + return count > 0, errw(db.QueryRowContext(ctx, q, catalog).Scan(&count)) +} + // AlterTableRename implements driver.SQLDriver. func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error { q := fmt.Sprintf(`ALTER TABLE %q RENAME TO %q`, tbl, newName) @@ -549,6 +549,43 @@ WHERE table_name = $1` return count == 1, nil } +// ListTableNames implements driver.SQLDriver. +func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) { + var tblClause string + switch { + case tables && views: + tblClause = " AND (table_type = 'BASE TABLE' OR table_type = 'VIEW')" + case tables: + tblClause = " AND table_type = 'BASE TABLE'" + case views: + tblClause = " AND table_type = 'VIEW'" + default: + return []string{}, nil + } + + var args []any + q := "SELECT table_name FROM information_schema.tables WHERE table_schema = " + if schma == "" { + q += "current_schema()" + } else { + q += "$1" + args = append(args, schma) + } + q += tblClause + " ORDER BY table_name" + + rows, err := db.QueryContext(ctx, q, args...) + if err != nil { + return nil, errw(err) + } + + names, err := sqlz.RowsScanColumn[string](ctx, rows) + if err != nil { + return nil, errw(err) + } + + return names, nil +} + // DropTable implements driver.SQLDriver. func (d *driveri) DropTable(ctx context.Context, db sqlz.DB, tbl tablefq.T, ifExists bool) error { var stmt string @@ -741,15 +778,66 @@ func (d *driveri) RecordMeta(ctx context.Context, colTypes []*sql.ColumnType) ( return recMeta, mungeFn, nil } +// getPoolConfig returns the native postgres [*pgxpool.Config] for src, applying +// src's fields, such as [source.Source.Catalog] as appropriate. If +// includeConnTimeout is true, then 'connect_timeout' is included in the +// returned config; this is provided as an option, because the connection +// timeout is sometimes better handled via [context.WithTimeout]. +func getPoolConfig(src *source.Source, includeConnTimeout bool) (*pgxpool.Config, error) { + poolCfg, err := pgxpool.ParseConfig(src.Location) + if err != nil { + return nil, errw(err) + } + + if src.Catalog != "" && src.Catalog != poolCfg.ConnConfig.Database { + // The catalog differs from the database in the connection string. + // OOTB, Postgres doesn't support cross-database references. So, + // we'll need to change the connection string to use the catalog + // as the database. Note that we don't modify src.Location, but it's + // not entirely clear if that's the correct approach. Are there any + // downsides to modifying it (as long as the modified Location is not + // persisted back to config)? + var u *dburl.URL + if u, err = dburl.Parse(src.Location); err != nil { + return nil, errw(err) + } + + u.Path = src.Catalog + connStr := u.String() + poolCfg, err = pgxpool.ParseConfig(connStr) + if err != nil { + return nil, errw(err) + } + } + + if includeConnTimeout { + srcTimeout := driver.OptConnOpenTimeout.Get(src.Options) + // Only set connect_timeout if it's non-zero and differs from the + // already-configured value. + // REVISIT: We should actually always set it, otherwise the user's + // envar PGCONNECT_TIMEOUT may override it? + + if srcTimeout > 0 || poolCfg.ConnConfig.ConnectTimeout != srcTimeout { + var u *dburl.URL + if u, err = dburl.Parse(poolCfg.ConnString()); err != nil { + return nil, errw(err) + } + + q := u.Query() + q.Set("connect_timeout", strconv.Itoa(int(srcTimeout.Seconds()))) + u.RawQuery = q.Encode() + poolCfg, err = pgxpool.ParseConfig(u.String()) + if err != nil { + return nil, errw(err) + } + } + } + + return poolCfg, nil +} + // doRetry executes fn with retry on isErrTooManyConnections. func doRetry(ctx context.Context, fn func() error) error { maxRetryInterval := driver.OptMaxRetryInterval.Get(options.FromContext(ctx)) return retry.Do(ctx, maxRetryInterval, fn, isErrTooManyConnections) } - -// tblfmt formats a table name for use in a query. The arg can be a string, -// or a tablefq.T. -func tblfmt[T string | tablefq.T](tbl T) string { - tfq := tablefq.From(tbl) - return tfq.Render(stringz.DoubleQuote) -} diff --git a/drivers/postgres/render.go b/drivers/postgres/render.go index 9a360432..e02659e4 100644 --- a/drivers/postgres/render.go +++ b/drivers/postgres/render.go @@ -8,8 +8,17 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/schema" + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/core/tablefq" ) +// tblfmt formats a table name for use in a query. The arg can be a string, +// or a tablefq.T. +func tblfmt[T string | tablefq.T](tbl T) string { + tfq := tablefq.From(tbl) + return tfq.Render(stringz.DoubleQuote) +} + func dbTypeNameFromKind(knd kind.Kind) string { switch knd { //nolint:exhaustive default: diff --git a/drivers/postgres/tools.go b/drivers/postgres/tools.go new file mode 100644 index 00000000..f43e130a --- /dev/null +++ b/drivers/postgres/tools.go @@ -0,0 +1,332 @@ +package postgres + +import ( + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/execz" + "github.com/neilotoole/sq/libsq/source" +) + +// REVISIT: DumpCatalogCmd and DumpClusterCmd could be methods on driver.SQLDriver. + +// TODO: Unify DumpCatalogCmd and DumpClusterCmd, as they're almost identical, probably +// in the form: +// DumpCatalogCmd(src *source.Source, all bool) (cmd []string, err error). + +// ToolParams are parameters for postgres tools such as pg_dump and pg_restore. +// +// - https://www.postgresql.org/docs/9.6/app-pgdump.html +// - https://www.postgresql.org/docs/9.6/app-pgrestore.html. +// - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html +// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp +// +// Not every flag is applicable to all tools. +type ToolParams struct { + // File is the path to the file. + File string + + // Verbose indicates verbose output (progress). + Verbose bool + + // NoOwner won't output commands to set ownership of objects; the source's + // connection user will own all objects. This also sets the --no-acl flag. + // Maybe NoOwner should be named "no security" or similar? + NoOwner bool + + // LongFlags indicates whether to use long flags, e.g. --no-owner instead + // of -O. + LongFlags bool +} + +func (p *ToolParams) flag(name string) string { + if p.LongFlags { + return flagsLong[name] + } + return flagsShort[name] +} + +// DumpCatalogCmd returns the shell command to execute pg_dump for src. +// Example output: +// +// pg_dump -Fc -d postgres://alice:vNgR6R@db.acme.com:5432/sales sales.dump +// +// Reference: +// +// - https://www.postgresql.org/docs/9.6/app-pgdump.html +// - https://www.postgresql.org/docs/9.6/app-pgrestore.html +// +// See also: [RestoreCatalogCmd]. +func DumpCatalogCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) { + // - https://www.postgresql.org/docs/9.6/app-pgdump.html + // - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp + // - https://gist.github.com/vielhuber/96eefdb3aff327bdf8230d753aaee1e1 + + cfg, err := getPoolConfig(src, true) + if err != nil { + return nil, err + } + + cmd := &execz.Cmd{Name: "pg_dump"} + + if p.Verbose { + cmd.ProgressFromStderr = true + cmd.Args = append(cmd.Args, p.flag(flagVerbose)) + } + + cmd.Args = append(cmd.Args, p.flag(flagClean), p.flag(flagIfExists)) // TODO: should be optional + cmd.Args = append(cmd.Args, p.flag(flagFormatCustomArchive)) + + if p.NoOwner { + // You might expect we'd add --no-owner, but if we're outputting a custom + // archive (-Fc), then --no-owner is the default. From the pg_dump docs: + // + // This option is ignored when emitting an archive (non-text) output file. + // For the archive formats, you can specify the option when you call pg_restore. + // + // If we ultimately allow non-archive formats, then we'll need to add + // special handling for --no-owner. + cmd.Args = append(cmd.Args, p.flag(flagNoACL)) + } + cmd.Args = append(cmd.Args, p.flag(flagDBName), cfg.ConnString()) + if p.File != "" { + cmd.UsesOutputFile = p.File + cmd.Args = append(cmd.Args, p.flag(flagFile), p.File) + } + return cmd, nil +} + +// RestoreCatalogCmd returns the shell command to restore a pg catalog (db) from +// a dump produced by pg_dump ([DumpClusterCmd]). Example command: +// +// pg_restore -d postgres://alice:vNgR6R@db.acme.com:5432/sales sales.dump +// +// Reference: +// +// - https://www.postgresql.org/docs/9.6/app-pgrestore.html +// - https://www.postgresql.org/docs/9.6/app-pgdump.html +// +// See also: [DumpCatalogCmd]. +func RestoreCatalogCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) { + // - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp + // - https://gist.github.com/vielhuber/96eefdb3aff327bdf8230d753aaee1e1 + + cfg, err := getPoolConfig(src, true) + if err != nil { + return nil, err + } + + cmd := &execz.Cmd{Name: "pg_restore", CmdDirPath: true} + if p.Verbose { + cmd.ProgressFromStderr = true + cmd.Args = append(cmd.Args, p.flag(flagVerbose)) + } + if p.NoOwner { + // NoOwner sets both --no-owner and --no-acl. Maybe these should + // be separate options. + cmd.Args = append(cmd.Args, p.flag(flagNoACL), p.flag(flagNoOwner)) // -O is --no-owner + } + + cmd.Args = append(cmd.Args, + p.flag(flagClean), + p.flag(flagIfExists), + p.flag(flagCreate), + p.flag(flagDBName), + cfg.ConnString(), + ) + + if p.File != "" { + cmd.UsesOutputFile = p.File + cmd.Args = append(cmd.Args, p.File) + } + + return cmd, nil +} + +// DumpClusterCmd returns the shell command to execute pg_dumpall for src. +// Example output (components concatenated with space): +// +// PGPASSWORD=vNgR6R pg_dumpall -w -l sales -d postgres://alice:vNgR6R@db.acme.com:5432/sales -f cluster.dump +// +// Note that the dump produced by pg_dumpall is executed by psql, not pg_restore. +// +// - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html +// - https://www.postgresql.org/docs/9.6/app-psql.html +// - https://www.postgresql.org/docs/9.6/app-pgdump.html +// - https://www.postgresql.org/docs/9.6/app-pgrestore.html +// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp +// +// See also: [RestoreClusterCmd]. +func DumpClusterCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) { + // - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html + // - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp + + cfg, err := getPoolConfig(src, true) + if err != nil { + return nil, err + } + + cmd := &execz.Cmd{ + Name: "pg_dumpall", + CmdDirPath: true, + Env: []string{"PGPASSWORD=" + cfg.ConnConfig.Password}, + } + + // FIXME: need mechanism to indicate that env contains password + if p.Verbose { + cmd.ProgressFromStderr = true + cmd.Args = append(cmd.Args, p.flag(flagVerbose)) + } + cmd.Args = append(cmd.Args, p.flag(flagClean), p.flag(flagIfExists)) // TODO: should be optional + if p.NoOwner { + // NoOwner sets both --no-owner and --no-acl. Maybe these should + // be separate options. + cmd.Args = append(cmd.Args, p.flag(flagNoACL), p.flag(flagNoOwner)) + } + + cmd.Args = append(cmd.Args, + p.flag(flagNoPassword), + p.flag(flagDatabase), cfg.ConnConfig.Database, + p.flag(flagDBName), cfg.ConnString(), + ) + + if p.File != "" { + cmd.Args = append(cmd.Args, p.flag(flagFile), p.File) + } + + return cmd, nil +} + +// RestoreClusterCmd returns the shell command to restore a pg cluster from a +// dump produced by pg_dumpall (DumpClusterCmd). Note that the dump produced +// by pg_dumpall is executed by psql, not pg_restore. Example command: +// +// psql -d postgres://alice:vNgR6R@db.acme.com:5432/sales -f sales.dump +// +// Reference: +// +// - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html +// - https://www.postgresql.org/docs/9.6/app-psql.html +// - https://www.postgresql.org/docs/9.6/app-pgdump.html +// - https://www.postgresql.org/docs/9.6/app-pgrestore.html +// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp +// +// See also: [DumpClusterCmd]. +func RestoreClusterCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) { + // - https://gist.github.com/vielhuber/96eefdb3aff327bdf8230d753aaee1e1 + cfg, err := getPoolConfig(src, true) + if err != nil { + return nil, err + } + + cmd := &execz.Cmd{Name: "psql", CmdDirPath: true} + if p.Verbose { + cmd.ProgressFromStderr = true + cmd.Args = append(cmd.Args, p.flag(flagVerbose)) + } + cmd.Args = append(cmd.Args, p.flag(flagDBName), cfg.ConnString()) + if p.File != "" { + cmd.Args = append(cmd.Args, p.flag(flagFile), p.File) + } + return cmd, nil +} + +type ExecToolParams struct { + // ScriptFile is the path to the script file. + // Only one of ScriptFile or CmdString will be set. + ScriptFile string + + // CmdString is the literal SQL command string. + CmdString string + + // Verbose indicates verbose output (progress). + // Only one of ScriptFile or CmdString will be set. + Verbose bool + + // LongFlags indicates whether to use long flags, e.g. --file instead of -f. + LongFlags bool +} + +func (p *ExecToolParams) flag(name string) string { + if p.LongFlags { + return flagsLong[name] + } + return flagsShort[name] +} + +// ExecCmd returns the shell command to execute psql with a script file +// or command string. Example command: +// +// psql -d postgres://alice:vNgR6R@db.acme.com:5432/sales -f query.sql +// +// See: https://www.postgresql.org/docs/9.6/app-psql.html. +func ExecCmd(src *source.Source, p *ExecToolParams) (*execz.Cmd, error) { + cfg, err := getPoolConfig(src, true) + if err != nil { + return nil, err + } + + cmd := &execz.Cmd{Name: "psql"} + if !p.Verbose { + cmd.Args = append(cmd.Args, p.flag(flagQuiet)) + } + cmd.Args = append(cmd.Args, p.flag(flagDBName), cfg.ConnString()) + if p.ScriptFile != "" { + cmd.Args = append(cmd.Args, p.flag(flagFile), p.ScriptFile) + } + if p.CmdString != "" { + if p.ScriptFile != "" { + return nil, errz.Errorf("only one of --file or --command may be set") + } + cmd.Args = append(cmd.Args, p.flag(flagCommand), p.CmdString) + } + + return cmd, nil +} + +// flags for pg_dump and pg_restore programs. +const ( + flagNoOwner = "--no-owner" + flagVerbose = "--verbose" + flagQuiet = "--quiet" + flagNoACL = "--no-acl" + flagCreate = "--create" + flagDBName = "--dbname" + flagDatabase = "--database" + flagFormatCustomArchive = "--format=custom" + flagIfExists = "--if-exists" + flagClean = "--clean" + flagNoPassword = "--no-password" + flagFile = "--file" + flagCommand = "--command" +) + +var flagsLong = map[string]string{ + flagNoOwner: flagNoOwner, + flagVerbose: flagVerbose, + flagQuiet: flagQuiet, + flagNoACL: flagNoACL, + flagCreate: flagCreate, + flagDBName: flagDBName, + flagIfExists: flagIfExists, + flagFormatCustomArchive: flagFormatCustomArchive, + flagClean: flagClean, + flagNoPassword: flagNoPassword, + flagDatabase: flagDatabase, + flagFile: flagFile, + flagCommand: flagCommand, +} + +var flagsShort = map[string]string{ + flagNoOwner: "-O", + flagVerbose: "-v", + flagQuiet: "-q", + flagNoACL: "-x", + flagCreate: "-C", + flagClean: "-c", + flagDBName: "-d", + flagFormatCustomArchive: "-Fc", + flagIfExists: flagIfExists, + flagNoPassword: "-w", + flagDatabase: "-l", + flagFile: "-f", + flagCommand: "-c", +} diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index be17dfb4..a537c786 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -682,6 +682,18 @@ func (d *driveri) CurrentSchema(ctx context.Context, db sqlz.DB) (string, error) return name, nil } +// SchemaExists implements driver.SQLDriver. +func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) { + if schma == "" { + return false, nil + } + + const q = `SELECT COUNT(name) FROM pragma_database_list WHERE name = ?` + + var count int + return count > 0, errw(db.QueryRowContext(ctx, q, schma).Scan(&count)) +} + // ListSchemas implements driver.SQLDriver. func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error) { log := lg.FromContext(ctx) @@ -709,6 +721,44 @@ func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error) return schemas, nil } +// ListTableNames implements driver.SQLDriver. The returned names exclude +// any sqlite_ internal tables. +func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) { + var tblClause string + switch { + case tables && views: + tblClause = " WHERE (type = 'table' OR type = 'view')" + case tables: + tblClause = " WHERE type = 'table'" + case views: + tblClause = " WHERE type = 'view'" + default: + return []string{}, nil + } + + tblClause += " AND name NOT LIKE 'sqlite_%'" + + q := "SELECT name FROM " + if schma == "" { + q += "sqlite_master" + } else { + q += stringz.DoubleQuote(schma) + ".sqlite_master" + } + q += tblClause + " ORDER BY name" + + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, errw(err) + } + + names, err := sqlz.RowsScanColumn[string](ctx, rows) + if err != nil { + return nil, errw(err) + } + + return names, nil +} + // ListSchemaMetadata implements driver.SQLDriver. // The returned metadata.Schema instances will have a Catalog // value of "default", and an empty Owner value. @@ -728,6 +778,12 @@ func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metada return schemas, nil } +// CatalogExists implements driver.SQLDriver. SQLite does not support catalogs, +// so this method always returns an error. +func (d *driveri) CatalogExists(_ context.Context, _ sqlz.DB, _ string) (bool, error) { + return false, errz.New("sqlite3: catalog mechanism not supported") +} + // CurrentCatalog implements driver.SQLDriver. SQLite does not support catalogs, // so this method returns an error. func (d *driveri) CurrentCatalog(_ context.Context, _ sqlz.DB) (string, error) { diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index 2625a1aa..51b6666b 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -402,6 +402,19 @@ func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error) return schemas, nil } +// SchemaExists implements driver.SQLDriver. +func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) { + if schma == "" { + return false, nil + } + + const q = `SELECT COUNT(SCHEMA_NAME) FROM INFORMATION_SCHEMA.SCHEMATA +WHERE SCHEMA_NAME = @p1 AND CATALOG_NAME = DB_NAME()` + + var count int + return count > 0, errw(db.QueryRowContext(ctx, q, schma).Scan(&count)) +} + // ListSchemaMetadata implements driver.SQLDriver. func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metadata.Schema, error) { log := lg.FromContext(ctx) @@ -450,6 +463,18 @@ func (d *driveri) CurrentCatalog(ctx context.Context, db sqlz.DB) (string, error return name, nil } +// CatalogExists implements driver.SQLDriver. +func (d *driveri) CatalogExists(ctx context.Context, db sqlz.DB, catalog string) (bool, error) { + if catalog == "" { + return false, nil + } + + const q = `SELECT COUNT(name) FROM sys.databases WHERE name = @p1` + + var count int + return count > 0, errw(db.QueryRowContext(ctx, q, catalog).Scan(&count)) +} + // ListCatalogs implements driver.SQLDriver. func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error) { catalogs := make([]string, 1, 3) @@ -457,9 +482,7 @@ func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error return nil, errw(err) } - const q = `SELECT name FROM sys.databases -WHERE name != DB_NAME() -ORDER BY name` + const q = `SELECT name FROM sys.databases WHERE name != DB_NAME() ORDER BY name` rows, err := db.QueryContext(ctx, q) if err != nil { @@ -511,6 +534,44 @@ func (d *driveri) DropSchema(ctx context.Context, db sqlz.DB, schemaName string) return nil } +// ListTableNames implements driver.SQLDriver. +func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) { + var tblClause string + + switch { + case tables && views: + tblClause = " AND (TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW')" + case tables: + tblClause = " AND TABLE_TYPE = 'BASE TABLE'" + case views: + tblClause = " AND TABLE_TYPE = 'VIEW'" + default: + return []string{}, nil + } + + var args []any + q := "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = " + if schma == "" { + q += "SCHEMA_NAME()" + } else { + q += "@p1" + args = append(args, schma) + } + q += tblClause + " ORDER BY TABLE_NAME" + + rows, err := db.QueryContext(ctx, q, args...) + if err != nil { + return nil, errw(err) + } + + names, err := sqlz.RowsScanColumn[string](ctx, rows) + if err != nil { + return nil, errw(err) + } + + return names, nil +} + // CreateTable implements driver.SQLDriver. func (d *driveri) CreateTable(ctx context.Context, db sqlz.DB, tblDef *schema.Table) error { stmt := buildCreateTableStmt(tblDef) diff --git a/drivers/userdriver/grip.go b/drivers/userdriver/grip.go index 32ba0d2f..e544b96b 100644 --- a/drivers/userdriver/grip.go +++ b/drivers/userdriver/grip.go @@ -42,20 +42,20 @@ func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab // SourceMetadata implements driver.Grip. func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - meta, err := g.impl.SourceMetadata(ctx, noSchema) + md, err := g.impl.SourceMetadata(ctx, noSchema) if err != nil { return nil, err } - meta.Handle = g.src.Handle - meta.Location = g.src.Location - meta.Name, err = location.Filename(g.src.Location) + md.Handle = g.src.Handle + md.Location = g.src.Location + md.Name, err = location.Filename(g.src.Location) if err != nil { return nil, err } - meta.FQName = meta.Name - return meta, nil + md.FQName = md.Name + return md, nil } // Close implements driver.Grip. diff --git a/go.mod b/go.mod index dc4684f9..810c1a27 100644 --- a/go.mod +++ b/go.mod @@ -56,7 +56,9 @@ require ( github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/VividCortex/ewma v1.2.0 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect + github.com/alecthomas/chroma v0.10.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/huandu/xstrings v1.4.0 // indirect diff --git a/go.sum b/go.sum index 74cf1651..7d349298 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,8 @@ github.com/a8m/tree v0.0.0-20240104212747-2c8764a5f17e h1:KMVieI1/Ub++GYfnhyFPoG github.com/a8m/tree v0.0.0-20240104212747-2c8764a5f17e/go.mod h1:j5astEcUkZQX8lK+KKlQ3NRQ50f4EE8ZjyZpCz3mrH4= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= +github.com/alecthomas/chroma v0.10.0 h1:7XDcGkCQopCNKjZHfYrNLraA+M7e0fMiJ/Mfikbfjek= +github.com/alecthomas/chroma v0.10.0/go.mod h1:jtJATyUxlIORhUOFNA9NZDWGAQ8wpxQQqNSB4rjA/1s= github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0= github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= @@ -38,6 +40,9 @@ github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ecnepsnai/osquery v1.0.1 h1:i96n/3uqcafKZtRYmXVNqekKbfrIm66q179mWZ/Y2Aw= diff --git a/huzzah.txt b/huzzah.txt new file mode 100644 index 00000000..a298f64e --- /dev/null +++ b/huzzah.txt @@ -0,0 +1,2 @@ +payment_id customer_id name staff_id rental_id amount payment_date last_update +1 71 Alice 71 0 3 500 2023-12-04T04:29:00Z 2023-12-04T04:29:00Z diff --git a/libsq/ast/ast.go b/libsq/ast/ast.go index c5cc12e0..0c66f3ce 100644 --- a/libsq/ast/ast.go +++ b/libsq/ast/ast.go @@ -198,14 +198,15 @@ func errorf(format string, v ...any) error { } // ParseCatalogSchema parses a string of the form 'catalog.schema' -// and returns the catalog and schema. An error is returned if the schema -// is empty (but catalog may be empty). Whitespace and quotes are handled +// and returns the catalog and schema. It is permissible for one of the +// components to be empty (but not both). Whitespace and quotes are handled // correctly. // // Examples: // // `catalog.schema` -> "catalog", "schema", nil -// `schema` -> "", "schema", nil +// `catalog.` -> "catalog", "", nil +// `schema` -> "", "schema", nil // `"my catalog"."my schema"` -> "my catalog", "my schema", nil // // An error is returned if s is empty. @@ -219,8 +220,18 @@ func ParseCatalogSchema(s string) (catalog, schema string, err error) { // We'll hijack the existing parser code. A value "catalog.schema" is // not valid, but ".catalog.schema" works as a selector. - + // + // Being that we accept "catalog." as valid (indicating the default schema), + // we'll use a hack to make the parser work: we append a const string to + // the input (which the parser will think is the schema name), and later + // on we check for that const string, and set schema to empty string if + // that const is found. + const schemaNameHack = "DEFAULT_SCHEMA_HACK_be8hx64wd45vxusdebez2e6tega8ussy" sel := "." + s + if strings.HasSuffix(s, ".") { + sel += schemaNameHack + } + a, err := Parse(lg.Discard(), sel) if err != nil { return "", "", errz.Errorf(errTpl, s) @@ -230,7 +241,9 @@ func ParseCatalogSchema(s string) (catalog, schema string, err error) { return "", "", errz.Errorf(errTpl, s) } - tblSel := NewInspector(a).FindFirstTableSelector() + insp := NewInspector(a) + + tblSel := insp.FindFirstTableSelector() if tblSel == nil { return "", "", errz.Errorf(errTpl, s) } @@ -243,6 +256,8 @@ func ParseCatalogSchema(s string) (catalog, schema string, err error) { } if schema == "" { return "", "", errz.Errorf(errTpl, s) + } else if schema == schemaNameHack { + schema = "" } return catalog, schema, nil diff --git a/libsq/ast/ast_test.go b/libsq/ast/ast_test.go index e58da1f4..95fb6325 100644 --- a/libsq/ast/ast_test.go +++ b/libsq/ast/ast_test.go @@ -16,7 +16,10 @@ func TestParseCatalogSchema(t *testing.T) { wantErr bool }{ {in: "", wantErr: true}, + {in: ".", wantErr: true}, {in: "dbo", wantCatalog: "", wantSchema: "dbo"}, + {in: "sakila.", wantCatalog: "sakila", wantSchema: ""}, + {in: ".dbo", wantErr: true}, {in: "sakila.dbo", wantCatalog: "sakila", wantSchema: "dbo"}, {in: `"my catalog"."my schema"`, wantCatalog: "my catalog", wantSchema: "my schema"}, {in: `"my catalog""."my schema"`, wantErr: true}, diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index f4f9f31f..c5b33d5a 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -266,6 +266,31 @@ func Chain(err error) []error { return errs } +// ExitCoder is an interface that an error type can implement to indicate +// that the program should exit with a specific status code. +// In particular, note that *exec.ExitError implements this interface. +type ExitCoder interface { + // ExitCode returns the exit code indicated by the error, or -1 if + // the error does not indicate a particular exit code. + ExitCode() int +} + +// ExitCode returns the exit code of the first error in err's chain +// that implements ExitCoder, otherwise -1. +func ExitCode(err error) (code int) { + if err == nil { + return -1 + } + + chain := Chain(err) + for i := range chain { + if coder, ok := chain[i].(ExitCoder); ok { //nolint:errorlint + return coder.ExitCode() + } + } + return -1 +} + // SprintTreeTypes returns a string representation of err's type tree. // A multi-error is represented as a slice of its children. func SprintTreeTypes(err error) string { diff --git a/libsq/core/execz/execz.go b/libsq/core/execz/execz.go new file mode 100644 index 00000000..9f8b27d8 --- /dev/null +++ b/libsq/core/execz/execz.go @@ -0,0 +1,332 @@ +// Package execz builds on stdlib os/exec. +package execz + +import ( + "bytes" + "context" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/loz" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/core/termz" + "github.com/neilotoole/sq/libsq/source/location" +) + +// Cmd represents an external command being prepared or run. +type Cmd struct { + // Stdin is the command's stdin. If nil, [os.Stdin] is used. + Stdin io.Reader + + // Stdout is the command's stdout. If nil, [os.Stdout] is used. + Stdout io.Writer + + // Stderr is the command's stderr. If nil, [os.Stderr] is used. + Stderr io.Writer + + // Name is the executable name, e.g. "pg_dump". + Name string + + // Label is a human-readable label for the command, e.g. "@sakila: dump". + // If empty, [Cmd.Name] is used. + Label string + + // ErrPrefix is the prefix to use for error messages. + ErrPrefix string + + // UsesOutputFile indicates that the command its output to this filepath + // instead of stdout. If empty, stdout is being used. + UsesOutputFile string + + // Args is the set of args to the command. + Args []string + + // Env is the set of environment variables to set for the command. + Env []string + + // NoProgress indicates that progress messages should not be output. + NoProgress bool + + // ProgressFromStderr indicates that the command outputs progress messages + // on stderr. + ProgressFromStderr bool + + // CmdDirPath controls whether the command's PATH will include the parent dir + // of the command. This allows the command to access sibling commands in the + // same dir, e.g. "pg_dumpall" needs to invoke "pg_dump". + CmdDirPath bool +} + +// String returns what command would look like if executed in a shell. +// Note that the returned string could contain sensitive information such as +// passwords, so it's not safe for logging. Instead, see [Cmd.LogValue]. +func (c *Cmd) String() string { + if c == nil { + return "" + } + + sb := strings.Builder{} + + for i := range c.Env { + if i > 0 { + sb.WriteRune(' ') + } + sb.WriteString(stringz.ShellEscape(c.Env[i])) + } + + if sb.Len() > 0 { + sb.WriteRune(' ') + } + + sb.WriteString(stringz.ShellEscape(c.Name)) + + for i := range c.Args { + sb.WriteRune(' ') + sb.WriteString(stringz.ShellEscape(c.Args[i])) + } + + return sb.String() +} + +// redactedCmd returns a redacted rendering of c, suitable for logging (but +// not execution). If escape is true, the string is also shell-escaped. +func (c *Cmd) redactedCmd(escape bool) string { + if c == nil { + return "" + } + + env := c.redactedEnv(escape) + args := c.redactedArgs(escape) + + switch { + case len(env) == 0 && len(args) == 0: + return c.Name + case len(env) == 0: + return c.Name + " " + strings.Join(args, " ") + case len(args) == 0: + return strings.Join(env, " ") + " " + c.Name + default: + return strings.Join(env, " ") + " " + c.Name + " " + strings.Join(args, " ") + } +} + +// redactedEnv returns c's env with sensitive values redacted. +// If escape is true, the values are also shell-escaped. +func (c *Cmd) redactedEnv(escape bool) []string { + if c == nil || len(c.Env) == 0 { + return []string{} + } + + envars := make([]string, len(c.Env)) + for i := range c.Env { + parts := strings.SplitN(c.Env[i], "=", 2) + if len(parts) < 2 { + // Shouldn't happen, but just in case. + if escape { + envars[i] = stringz.ShellEscape(c.Env[i]) + } else { + envars[i] = c.Env[i] + } + continue + } + + // If the envar value is a SQL or HTTP location, redact it. + if location.TypeOf(parts[1]).IsURL() { + parts[1] = location.Redact(parts[1]) + } + + if escape { + envars[i] = parts[0] + "=" + stringz.ShellEscape(parts[1]) + } else { + envars[i] = parts[0] + "=" + parts[1] + } + } + return envars +} + +// redactedArgs returns c's args with sensitive values redacted. +// If escape is true, the values are also shell-escaped. +func (c *Cmd) redactedArgs(escape bool) []string { + if c == nil || len(c.Args) == 0 { + return []string{} + } + + args := make([]string, len(c.Args)) + for i := range c.Args { + if location.TypeOf(c.Args[i]).IsURL() { + args[i] = location.Redact(c.Args[i]) + if escape { + args[i] = stringz.ShellEscape(args[i]) + } + continue + } + + if escape { + args[i] = stringz.ShellEscape(c.Args[i]) + } else { + args[i] = c.Args[i] + } + } + return args +} + +var _ slog.LogValuer = (*Cmd)(nil) + +// LogValue implements [slog.LogValuer]. It redacts sensitive information +// (passwords etc.) from URL-like values. +func (c *Cmd) LogValue() slog.Value { + if c == nil { + return slog.Value{} + } + + attrs := []slog.Attr{ + slog.String("name", c.Name), + slog.String("exec", c.redactedCmd(false)), + } + + return slog.GroupValue(attrs...) +} + +// Exec executes cmd. +func Exec(ctx context.Context, cmd *Cmd) (err error) { + log := lg.FromContext(ctx) + + defer func() { + if err != nil && cmd.UsesOutputFile != "" { + // If an error occurred, we want to remove the output file. + lg.WarnIfError(lg.FromContext(ctx), lgm.RemoveFile, os.Remove(cmd.UsesOutputFile)) + } + }() + + if cmd.Stdin == nil { + cmd.Stdin = os.Stdin + } + if cmd.Stdout == nil { + cmd.Stdout = os.Stdout + } + if cmd.Stderr == nil { + cmd.Stderr = os.Stderr + } + + execCmd := exec.CommandContext(ctx, cmd.Name, cmd.Args...) //nolint:gosec + if cmd.CmdDirPath { + execCmd.Env = append(execCmd.Env, "PATH="+filepath.Dir(execCmd.Path)) + } + execCmd.Env = append(execCmd.Env, cmd.Env...) + execCmd.Stdin = cmd.Stdin + execCmd.Stdout = cmd.Stdout + + stderrBuf := &bytes.Buffer{} + execCmd.Stderr = io.MultiWriter(stderrBuf, cmd.Stderr) + + switch { + case cmd.ProgressFromStderr: + log.Warn("It's cmd.ProgressFromStderr") + // TODO: We really want to print stderr. + case cmd.UsesOutputFile != "": + log.Warn("It's cmd.UsesOutputFile") + // Truncate the file, ignoring any error (e.g. if it doesn't exist). + _ = os.Truncate(cmd.UsesOutputFile, 0) + + if !cmd.NoProgress { + bar := progress.FromContext(ctx).NewFilesizeCounter( + loz.NonEmptyOf(cmd.Label, cmd.Name), + nil, + cmd.UsesOutputFile, + progress.OptTimer, + ) + defer bar.Stop() + } + + default: + log.Warn("It's default") + + // We're reduced to reading the size of stdout, but not if we're on a + // terminal. If we are on a terminal, then the user will get to see the + // command output in real-time and we don't need a progress bar. + if !termz.IsTerminal(os.Stdout) { + log.Warn("It's not a terminal") + + if _, ok := cmd.Stdout.(*os.File); ok && !cmd.NoProgress { + bar := progress.FromContext(ctx).NewFilesizeCounter( + loz.NonEmptyOf(cmd.Label, cmd.Name), + cmd.Stdout.(*os.File), + "", + progress.OptTimer, + ) + defer bar.Stop() + } + } + } + + if err = execCmd.Run(); err != nil { + return newExecError(cmd.ErrPrefix, cmd, execCmd, stderrBuf, err) + } + return nil +} + +var _ error = (*execError)(nil) + +// execError is an error that occurred during command execution. +type execError struct { + msg string + execErr error + cmd *Cmd + execCmd *exec.Cmd + errOut []byte +} + +// Error returns the error message. +func (e *execError) Error() string { + s := e.msg + ": " + e.execErr.Error() + + if len(e.errOut) > 0 { + s += ": " + string(e.errOut) + s = strings.TrimSuffix(s, "\r\n") // windows + s = strings.TrimSuffix(s, "\n") + } + + return s +} + +// Unwrap returns the underlying error. +func (e *execError) Unwrap() error { + return e.execErr +} + +// ExitCode returns the exit code of the command execution if the underlying +// execution error was an *exec.ExitError, otherwise -1. +func (e *execError) ExitCode() int { + if ee, ok := errz.As[*exec.ExitError](e.execErr); ok { + return ee.ExitCode() + } + return -1 +} + +// newExecError creates a new execError. If cmd.Stderr is +// a *bytes.Buffer, it will be used to populate the errOut field, +// otherwise errOut may be nil. +func newExecError(msg string, cmd *Cmd, execCmd *exec.Cmd, stderrBuf *bytes.Buffer, execErr error) *execError { + e := &execError{ + msg: msg, + execErr: execErr, + cmd: cmd, + execCmd: execCmd, + } + + // TODO: We should implement special handling for Lookup errors, + // e.g. "pg_dump" not found. + + if stderrBuf != nil { + e.errOut = stderrBuf.Bytes() + } + return e +} diff --git a/libsq/core/lg/lgm/lgm.go b/libsq/core/lg/lgm/lgm.go index 968dd8c4..8085d3e3 100644 --- a/libsq/core/lg/lgm/lgm.go +++ b/libsq/core/lg/lgm/lgm.go @@ -11,6 +11,7 @@ const ( CloseHTTPResponseBody = "Close HTTP response body" CloseFileReader = "Close file reader" CloseFileWriter = "Close file writer" + CloseOutputFile = "Close output file" CtxDone = "Context unexpectedly done" OpenSrc = "Open source" ReadDBRows = "Read DB rows" diff --git a/libsq/core/loz/loz.go b/libsq/core/loz/loz.go index db538b9b..76ce83ef 100644 --- a/libsq/core/loz/loz.go +++ b/libsq/core/loz/loz.go @@ -202,3 +202,20 @@ func RemoveUnordered[T any](a []*T, v *T) []*T { } return a } + +// Cond returns a if cond is true, else b. It's basically the ternary operator. +func Cond[T any](cond bool, a, b T) T { + if cond { + return a + } + return b +} + +// NonEmptyOf returns a if a is non-zero, else b. +func NonEmptyOf[T comparable](a, b T) T { + var zero T + if a != zero { + return a + } + return b +} diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index cf4b5c6d..a0c820db 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -1,6 +1,8 @@ package progress import ( + "fmt" + "os" "time" humanize "github.com/dustin/go-humanize" @@ -38,6 +40,40 @@ func (p *Progress) NewByteCounter(msg string, size int64, opts ...Opt) *Bar { return p.newBar(cfg, opts) } +// NewFilesizeCounter returns a new indeterminate bar whose label metric is a +// filesize, or "-" if it can't be read. If f is non-nil, its size is used; else +// the file at path fp is used. The caller is ultimately responsible for calling +// Bar.Stop on the returned Bar. +func (p *Progress) NewFilesizeCounter(msg string, f *os.File, fp string, opts ...Opt) *Bar { + if p == nil { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + cfg := &barConfig{msg: msg, total: -1, style: spinnerStyle(p.colors.Filler)} + + d := decor.Any(func(statistics decor.Statistics) string { + var fi os.FileInfo + var err error + if f != nil { + fi, err = f.Stat() + } else { + fi, err = os.Stat(fp) + } + + if err != nil { + return "-" + } + + return fmt.Sprintf("% .1f", decor.SizeB1024(fi.Size())) + }) + + cfg.decorators = []decor.Decorator{colorize(d, p.colors.Size)} + return p.newBar(cfg, opts) +} + // NewUnitCounter returns a new indeterminate bar whose label // metric is the plural of the provided unit. The caller is ultimately // responsible for calling Bar.Stop on the returned Bar. diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index 8ef5b145..f343cf6b 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -117,6 +117,17 @@ func newElapsedSeconds(c *color.Color, startTime time.Time, wcc ...decor.WC) dec return decor.Any(fn, wcc...) } +// OptTimer is an Opt that causes the bar to display elapsed seconds. +var OptTimer = optElapsedSeconds{} + +var _ Opt = optElapsedSeconds{} + +type optElapsedSeconds struct{} + +func (optElapsedSeconds) apply(p *Progress, cfg *barConfig) { + cfg.decorators = append(cfg.decorators, newElapsedSeconds(p.colors.Size, time.Now(), decor.WCSyncSpace)) +} + // OptMemUsage is an Opt that causes the bar to display program // memory usage. var OptMemUsage = optMemUsage{} diff --git a/libsq/core/sqlz/sqlz.go b/libsq/core/sqlz/sqlz.go index 808dcf9a..3d72b846 100644 --- a/libsq/core/sqlz/sqlz.go +++ b/libsq/core/sqlz/sqlz.go @@ -6,6 +6,8 @@ import ( "database/sql" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgm" ) // Execer abstracts the ExecContext method @@ -78,3 +80,36 @@ func RequireSingleConn(db DB) error { return nil } + +// RowsScanColumn scans a single-column [*sql.Rows] into a slice of T. If the +// returned value could be null, use a nullable type, e.g. [sql.NullString]. +// Arg rows is always closed. On any error, the returned slice is nil. +func RowsScanColumn[T any](ctx context.Context, rows *sql.Rows) (vals []T, err error) { + defer func() { + if rows != nil { + lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseDBRows, rows) + } + }() + + // We want to return an empty slice rather than nil. + vals = make([]T, 0) + for rows.Next() { + var val T + if err = rows.Scan(&val); err != nil { + return nil, errz.Err(err) + } + vals = append(vals, val) + } + + if err = rows.Err(); err != nil { + return nil, errz.Err(err) + } + + err = rows.Close() + rows = nil + if err != nil { + return nil, errz.Err(err) + } + + return vals, nil +} diff --git a/cli/terminal.go b/libsq/core/termz/termz.go similarity index 71% rename from cli/terminal.go rename to libsq/core/termz/termz.go index 24b3cbe5..ef144aa1 100644 --- a/cli/terminal.go +++ b/libsq/core/termz/termz.go @@ -1,6 +1,7 @@ //go:build !windows -package cli +// Package termz contains a handful of terminal utilities. +package termz import ( "io" @@ -9,8 +10,8 @@ import ( "golang.org/x/term" ) -// isTerminal returns true if w is a terminal. -func isTerminal(w io.Writer) bool { +// IsTerminal returns true if w is a terminal. +func IsTerminal(w io.Writer) bool { switch v := w.(type) { case *os.File: return term.IsTerminal(int(v.Fd())) @@ -19,12 +20,12 @@ func isTerminal(w io.Writer) bool { } } -// isColorTerminal returns true if w is a colorable terminal. +// IsColorTerminal returns true if w is a colorable terminal. // It respects [NO_COLOR], [FORCE_COLOR] and TERM=dumb environment variables. // // [NO_COLOR]: https://no-color.org/ // [FORCE_COLOR]: https://force-color.org/ -func isColorTerminal(w io.Writer) bool { +func IsColorTerminal(w io.Writer) bool { if os.Getenv("NO_COLOR") != "" { return false } diff --git a/cli/terminal_windows.go b/libsq/core/termz/termz_windows.go similarity index 85% rename from cli/terminal_windows.go rename to libsq/core/termz/termz_windows.go index 40b2573d..aad72f5c 100644 --- a/cli/terminal_windows.go +++ b/libsq/core/termz/termz_windows.go @@ -1,4 +1,4 @@ -package cli +package termz import ( "io" @@ -8,8 +8,8 @@ import ( "golang.org/x/term" ) -// isTerminal returns true if w is a terminal. -func isTerminal(w io.Writer) bool { +// IsTerminal returns true if w is a terminal. +func IsTerminal(w io.Writer) bool { switch v := w.(type) { case *os.File: return term.IsTerminal(int(v.Fd())) @@ -18,7 +18,7 @@ func isTerminal(w io.Writer) bool { } } -// isColorTerminal returns true if w is a colorable terminal. +// IsColorTerminal returns true if w is a colorable terminal. // It respects [NO_COLOR], [FORCE_COLOR] and TERM=dumb environment variables. // // Acknowledgement: This function is lifted from neilotoole/jsoncolor, but @@ -27,7 +27,7 @@ func isTerminal(w io.Writer) bool { // // [NO_COLOR]: https://no-color.org/ // [FORCE_COLOR]: https://force-color.org/ -func isColorTerminal(w io.Writer) bool { +func IsColorTerminal(w io.Writer) bool { if os.Getenv("NO_COLOR") != "" { return false } diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index 08e14ede..9a37bb47 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -150,6 +150,14 @@ type SQLDriver interface { // DropSchema drops the named schema in db. DropSchema(ctx context.Context, db sqlz.DB, schemaName string) error + // CatalogExists returns true if db can reference the named catalog. If + // catalog is empty string, false is returned. + CatalogExists(ctx context.Context, db sqlz.DB, catalog string) (bool, error) + + // SchemaExists returns true if db can reference the named schema. If + // schma is empty string, false is returned. + SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) + // Truncate truncates tbl in src. If arg reset is true, the // identity counter for tbl should be reset, if supported // by the driver. Some DB impls may reset the identity @@ -159,6 +167,11 @@ type SQLDriver interface { // TableExists returns true if there's an existing table tbl in db. TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error) + // ListTableNames lists the tables of schma in db. The "tables" and "views" + // args filter TABLE and VIEW types, respectively. If both are false, an empty + // slice is returned. If schma is empty, the current schema is used. + ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) + // CopyTable copies fromTable into a new table toTable. // If copyData is true, fromTable's data is also copied. // Constraints (keys, defaults etc.) may not be copied. The diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index 01a8250d..e4951e85 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "testing" "github.com/stretchr/testify/assert" @@ -430,7 +431,7 @@ func TestRegistry_DriversMetadata_Doc(t *testing.T) { } } -func TestDatabase_TableMetadata(t *testing.T) { //nolint:tparallel +func TestGrip_TableMetadata(t *testing.T) { //nolint:tparallel for _, handle := range sakila.SQLAll() { handle := handle @@ -447,7 +448,7 @@ func TestDatabase_TableMetadata(t *testing.T) { //nolint:tparallel } } -func TestDatabase_SourceMetadata(t *testing.T) { +func TestGrip_SourceMetadata(t *testing.T) { t.Parallel() for _, handle := range sakila.SQLAll() { @@ -466,9 +467,99 @@ func TestDatabase_SourceMetadata(t *testing.T) { } } -// TestDatabase_SourceMetadata_concurrent tests the behavior of the +// TestSQLDriver_ListTableNames_ArgSchemaEmpty tests [driver.SQLDriver.ListTableNames] +// with an empty schema arg. +func TestSQLDriver_ListTableNames_ArgSchemaEmpty(t *testing.T) { //nolint:tparallel + for _, handle := range sakila.SQLLatest() { + handle := handle + + t.Run(handle, func(t *testing.T) { + t.Parallel() + + th, _, drvr, _, db := testh.NewWith(t, handle) + + got, err := drvr.ListTableNames(th.Context, db, "", false, false) + require.NoError(t, err) + require.NotNil(t, got) + require.True(t, len(got) == 0) + + got, err = drvr.ListTableNames(th.Context, db, "", true, false) + require.NoError(t, err) + require.NotNil(t, got) + require.Contains(t, got, sakila.TblActor) + require.NotContains(t, got, sakila.ViewFilmList) + + got, err = drvr.ListTableNames(th.Context, db, "", false, true) + require.NoError(t, err) + require.NotNil(t, got) + require.NotContains(t, got, sakila.TblActor) + require.Contains(t, got, sakila.ViewFilmList) + + got, err = drvr.ListTableNames(th.Context, db, "", true, true) + require.NoError(t, err) + require.NotNil(t, got) + require.Contains(t, got, sakila.TblActor) + require.Contains(t, got, sakila.ViewFilmList) + + gotCopy := append([]string(nil), got...) + slices.Sort(gotCopy) + require.Equal(t, got, gotCopy, "expected results to be sorted") + }) + } +} + +// TestSQLDriver_ListTableNames_ArgSchemaNotEmpty tests +// [driver.SQLDriver.ListTableNames] with a non-empty schema arg. +func TestSQLDriver_ListTableNames_ArgSchemaNotEmpty(t *testing.T) { //nolint:tparallel + testCases := []struct { + handle string + schema string + wantTables int + wantViews int + }{ + {handle: sakila.Pg12, schema: "public", wantTables: 25, wantViews: 5}, + {handle: sakila.MS19, schema: "dbo", wantTables: 17, wantViews: 5}, + {handle: sakila.SL3, schema: "main", wantTables: 16, wantViews: 5}, + {handle: sakila.My8, schema: "sakila", wantTables: 16, wantViews: 7}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.handle, func(t *testing.T) { + t.Parallel() + + th, _, drvr, _, db := testh.NewWith(t, tc.handle) + + got, err := drvr.ListTableNames(th.Context, db, tc.schema, false, false) + require.NoError(t, err) + require.NotNil(t, got) + require.True(t, len(got) == 0) + + got, err = drvr.ListTableNames(th.Context, db, tc.schema, true, false) + require.NoError(t, err) + require.NotNil(t, got) + require.Len(t, got, tc.wantTables) + + got, err = drvr.ListTableNames(th.Context, db, tc.schema, false, true) + require.NoError(t, err) + require.NotNil(t, got) + require.Len(t, got, tc.wantViews) + + got, err = drvr.ListTableNames(th.Context, db, tc.schema, true, true) + require.NoError(t, err) + require.NotNil(t, got) + require.Len(t, got, tc.wantTables+tc.wantViews) + + gotCopy := append([]string(nil), got...) + slices.Sort(gotCopy) + require.Equal(t, got, gotCopy, "expected results to be sorted") + }) + } +} + +// TestGrip_SourceMetadata_concurrent tests the behavior of the // drivers when SourceMetadata is invoked concurrently. -func TestDatabase_SourceMetadata_concurrent(t *testing.T) { //nolint:tparallel +func TestGrip_SourceMetadata_concurrent(t *testing.T) { //nolint:tparallel const concurrency = 5 handles := sakila.SQLLatest() @@ -648,6 +739,89 @@ func TestSQLDriver_CurrentSchemaCatalog(t *testing.T) { } } +func TestSQLDriver_SchemaExists(t *testing.T) { + t.Parallel() + + testCases := []struct { + handle string + schema string + wantOK bool + }{ + {handle: sakila.SL3, schema: "main", wantOK: true}, + {handle: sakila.SL3, schema: "", wantOK: false}, + {handle: sakila.SL3, schema: "not_exist", wantOK: false}, + {handle: sakila.Pg, schema: "public", wantOK: true}, + {handle: sakila.Pg, schema: "information_schema", wantOK: true}, + {handle: sakila.Pg, schema: "not_exist", wantOK: false}, + {handle: sakila.Pg, schema: "", wantOK: false}, + {handle: sakila.My, schema: "sakila", wantOK: true}, + {handle: sakila.My, schema: "", wantOK: false}, + {handle: sakila.My, schema: "not_exist", wantOK: false}, + {handle: sakila.MS, schema: "dbo", wantOK: true}, + {handle: sakila.MS, schema: "sys", wantOK: true}, + {handle: sakila.MS, schema: "INFORMATION_SCHEMA", wantOK: true}, + {handle: sakila.MS, schema: "", wantOK: false}, + {handle: sakila.MS, schema: "not_exist", wantOK: false}, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tu.Name(tc.handle, tc.schema, tc.wantOK), func(t *testing.T) { + t.Parallel() + + th, _, drvr, _, db := testh.NewWith(t, tc.handle) + ok, err := drvr.SchemaExists(th.Context, db, tc.schema) + require.NoError(t, err) + require.Equal(t, tc.wantOK, ok) + }) + } +} + +func TestSQLDriver_CatalogExists(t *testing.T) { + t.Parallel() + + testCases := []struct { + handle string + catalog string + wantOK bool + wantErr bool + }{ + {handle: sakila.SL3, catalog: "default", wantErr: true}, + {handle: sakila.SL3, catalog: "not_exist", wantErr: true}, + {handle: sakila.SL3, catalog: "", wantErr: true}, + {handle: sakila.Pg, catalog: "sakila", wantOK: true}, + {handle: sakila.Pg, catalog: "postgres", wantOK: true}, + {handle: sakila.Pg, catalog: "not_exist", wantOK: false}, + {handle: sakila.Pg, catalog: "", wantOK: false}, + {handle: sakila.My, catalog: "def", wantOK: true}, + {handle: sakila.My, catalog: "not_exist", wantOK: false}, + {handle: sakila.My, catalog: "", wantOK: false}, + {handle: sakila.MS, catalog: "sakila", wantOK: true}, + {handle: sakila.MS, catalog: "model", wantOK: true}, + {handle: sakila.MS, catalog: "not_exist", wantOK: false}, + {handle: sakila.MS, catalog: "", wantOK: false}, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tu.Name(tc.handle, tc.catalog, tc.wantOK), func(t *testing.T) { + t.Parallel() + + th, _, drvr, _, db := testh.NewWith(t, tc.handle) + + ok, err := drvr.CatalogExists(th.Context, db, tc.catalog) + require.Equal(t, tc.wantOK, ok) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + func TestDriverCreateDropSchema(t *testing.T) { testCases := []struct { handle string diff --git a/libsq/source/location/location.go b/libsq/source/location/location.go index ff0570cb..bc019244 100644 --- a/libsq/source/location/location.go +++ b/libsq/source/location/location.go @@ -358,6 +358,11 @@ func TypeOf(loc string) Type { return TypeFile } +// IsURL returns true if t is TypeHTTP or TypeSQL. +func (t Type) IsURL() bool { + return t == TypeHTTP || t == TypeSQL +} + // isHTTP tests if s is a well-structured HTTP or HTTPS url, and // if so, returns the url and true. func isHTTP(s string) (u *url.URL, ok bool) { diff --git a/main.go b/main.go index 0b702ea8..0c6cf062 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,10 @@ func main() { defer func() { cancelFn() if err != nil { + if code := errz.ExitCode(err); code > 0 { + os.Exit(code) + } + os.Exit(1) } }()