db tools preliminary work; --src.schema changes (#392)

- Preliminary work on the (currently hidden) `db` cmds.
- Improvements to `--src.schema`
This commit is contained in:
Neil O'Toole 2024-02-09 09:08:39 -07:00 committed by GitHub
parent 4a884e147c
commit 99454852f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
60 changed files with 2612 additions and 274 deletions

4
.gitignore vendored
View File

@ -42,7 +42,7 @@ _testmain.go
.envrc
**/*.bench
go.work*
*.dump.sql
# Some apps create temp files when editing, e.g. Excel with drivers/xlsx/testdata/~$test_header.xlsx
**/testdata/~*
@ -59,3 +59,5 @@ goreleaser-test.sh
/*.db
/.CHANGELOG.delta.md
/testh/progress-remove.test.sh
/*.dump

View File

@ -8,6 +8,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Breaking changes are annotated with ☢️, and alpha/beta features with 🐥.
## [v0.47.4] - UPCOMING
Minor changes to the behavior of the `--src.schema` flag.
See the earlier [`v0.47.0`](https://github.com/neilotoole/sq/releases/tag/v0.47.0)
release for recent headline features.
### Changed
- The [`--src.schema`](https://sq.io/docs/source#source-override) flag (as used in [`sq inspect`](https://sq.io/docs/cmd/inspect),
[`sq sql`](https://sq.io/docs/cmd/sql), and the root [`sq`](https://sq.io/docs/cmd/sq#override-active-schema) cmd)
now accepts `--src.schema=CATALOG.` (note the `.` suffix). This is in addition to the existing allowed forms `SCHEMA`
and `CATALOG.SCHEMA`. This new `CATALOG.` form is effectively equivalent to `CATALOG.CURRENT_SCHEMA`.
```shell
# Inspect using the default schema in the "sales" catalog
$ sq inspect --src.schema=sales.
```
- The [`--src.schema`](https://sq.io/docs/source#source-override) flag is now validated. Previously, if you provided a non-existing catalog or schema
value, `sq` would silently ignore it and use the defaults. This could mislead the user into thinking that
they were getting valid results from the non-existent catalog or schema. Now an error is returned.
## [v0.47.3] - 2024-02-03
Minor bug fix release. See the earlier [`v0.47.0`](https://github.com/neilotoole/sq/releases/tag/v0.47.0)
@ -1138,3 +1161,4 @@ make working with lots of sources much easier.
[v0.47.1]: https://github.com/neilotoole/sq/compare/v0.47.0...v0.47.1
[v0.47.2]: https://github.com/neilotoole/sq/compare/v0.47.1...v0.47.2
[v0.47.3]: https://github.com/neilotoole/sq/compare/v0.47.2...v0.47.3
[v0.47.4]: https://github.com/neilotoole/sq/compare/v0.47.3...v0.47.4

View File

@ -29,11 +29,13 @@ gen:
.PHONY: fmt
fmt:
@# https://github.com/incu6us/goimports-reviser
@# Note that *_windows.go is excluded because the tool seems
@# Note that termz_windows.go is excluded because the tool seems
@# to mangle Go code that is guarded by build tags that
@# are not in use.
@# are not in use. Alas, we can't provide a double star glob,
@# e.g. **/*_windows.go, because filepath.Match doesn't support
@# double star, so we explicitly name the file.
@goimports-reviser -company-prefixes github.com/neilotoole -set-alias \
-excludes '**/*_windows.go' \
-excludes 'libsq/core/termz/termz_windows.go' \
-rm-unused -output write \
-project-name github.com/neilotoole/sq ./...

View File

@ -259,6 +259,15 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) {
addCmd(ru, tblCmd, newTblTruncateCmd())
addCmd(ru, tblCmd, newTblDropCmd())
dbCmd := addCmd(ru, rootCmd, newDBCmd())
addCmd(ru, dbCmd, newDBExecCmd())
dbDumpCmd := addCmd(ru, dbCmd, newDBDumpCmd())
addCmd(ru, dbDumpCmd, newDBDumpCatalogCmd())
addCmd(ru, dbDumpCmd, newDBDumpClusterCmd())
dbRestoreCmd := addCmd(ru, dbCmd, newDBRestoreCmd())
addCmd(ru, dbRestoreCmd, newDBRestoreCatalogCmd())
addCmd(ru, dbRestoreCmd, newDBRestoreClusterCmd())
addCmd(ru, rootCmd, newDiffCmd())
driverCmd := addCmd(ru, rootCmd, newDriverCmd())

View File

@ -156,7 +156,7 @@ More examples:
Long: `Add data source specified by LOCATION, optionally identified by @HANDLE.`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
addTextFormatFlags(cmd)
cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage)
cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage)
@ -242,7 +242,7 @@ func execSrcAdd(cmd *cobra.Command, args []string) error {
// or sq prompts the user.
if cmdFlagIsSetTrue(cmd, flag.PasswordPrompt) {
var passwd []byte
if passwd, err = readPassword(ctx, ru.Stdin, ru.Out, ru.Writers.Printing); err != nil {
if passwd, err = readPassword(ctx, ru.Stdin, ru.Out, ru.Writers.OutPrinting); err != nil {
return err
}

View File

@ -141,14 +141,12 @@ func TestCmdAdd(t *testing.T) {
wantHandle: "@sakila",
wantType: drivertype.SQLite,
},
{
// with scheme
loc: proj.Abs(sakila.PathSL3),
wantHandle: "@sakila",
wantType: drivertype.SQLite,
},
{
// without scheme, relative path
loc: proj.Rel(sakila.PathSL3),

View File

@ -74,7 +74,7 @@ func newCacheStatCmd() *cobra.Command {
/Users/neilotoole/Library/Caches/sq/f36ac695 enabled (472.8MB)`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
addTextFormatFlags(cmd)
cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage)
cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage)
@ -111,7 +111,7 @@ func newCacheClearCmd() *cobra.Command {
$ sq cache clear @sakila`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
return cmd
}
@ -150,7 +150,7 @@ func newCacheTreeCmd() *cobra.Command {
$ sq cache tree --size`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
_ = cmd.Flags().BoolP(flag.CacheTreeSize, flag.CacheTreeSizeShort, false, flag.CacheTreeSizeUsage)
return cmd
}
@ -163,7 +163,7 @@ func execCacheTree(cmd *cobra.Command, _ []string) error {
}
showSize := cmdFlagBool(cmd, flag.CacheTreeSize)
return ioz.PrintTree(ru.Out, cacheDir, showSize, !ru.Writers.Printing.IsMonochrome())
return ioz.PrintTree(ru.Out, cacheDir, showSize, !ru.Writers.OutPrinting.IsMonochrome())
}
func newCacheEnableCmd() *cobra.Command { //nolint:dupl
@ -200,7 +200,7 @@ func newCacheEnableCmd() *cobra.Command { //nolint:dupl
$ sq cache enable @sakila`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
return cmd
}
@ -238,6 +238,6 @@ func newCacheDisableCmd() *cobra.Command { //nolint:dupl
$ sq cache disable @sakila`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
return cmd
}

View File

@ -51,7 +51,7 @@ in envar $SQ_EDITOR or $EDITOR.`,
# Use a different editor
$ SQ_EDITOR=nano sq config edit`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
return cmd
}
@ -134,7 +134,7 @@ func execConfigEditSource(cmd *cobra.Command, args []string) error {
// The Catalog and Schema fields have yaml tag 'omitempty',
// so they wouldn't be rendered in the editor yaml if empty.
// However, we to render the fields commented-out if empty.
// However, we want to render the fields commented-out if empty.
// Hence this little hack.
if tmpSrc.Catalog == "" {
// Forces yaml rendering

View File

@ -41,7 +41,7 @@ Use "sq config ls -v" to list available options.`,
# Help for an individual option
$ sq config set conn.max-open --help`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
addTextFormatFlags(cmd)
cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage)
cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage)

18
cli/cmd_db.go Normal file
View File

@ -0,0 +1,18 @@
package cli
import (
"github.com/spf13/cobra"
)
// newDBCmd returns the (currently hidden) "db" command, which is the
// parent of the db subcommands such as "db exec", "db dump", and
// "db restore". Invoked without a subcommand, it just prints help.
func newDBCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "db",
		Short: "Useful database actions",
		RunE: func(cmd *cobra.Command, args []string) error {
			// The parent command is a no-op; show help.
			return cmd.Help()
		},
		Example: ` # TBD`,
	}
	return cmd
}

269
cli/cmd_db_dump.go Normal file
View File

@ -0,0 +1,269 @@
package cli
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"github.com/neilotoole/sq/cli/flag"
"github.com/neilotoole/sq/cli/run"
"github.com/neilotoole/sq/drivers/postgres"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/execz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/source"
"github.com/neilotoole/sq/libsq/source/drivertype"
)
// newDBDumpCmd returns the "db dump" command, which is the parent of
// the "db dump catalog" and "db dump cluster" subcommands. Invoked
// without a subcommand, it just prints help.
func newDBDumpCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "dump",
		Short: "Dump db catalog or cluster",
		Long: `Execute or print db-native dump command for db catalog or cluster.`,
		RunE: func(cmd *cobra.Command, args []string) error {
			// The parent command is a no-op; show help.
			return cmd.Help()
		},
	}
	return cmd
}
// newDBDumpCatalogCmd returns the "db dump catalog" command, which
// executes (or, with --print/--print-long, merely prints) the
// db-native dump command (e.g. pg_dump) for a source's catalog.
// The actual work happens in execDBDumpCatalog.
func newDBDumpCatalogCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "catalog @src [--print]",
		Short: "Dump db catalog",
		Long: `Dump db catalog using database-native dump tool.`,
		ValidArgsFunction: completeHandle(1, true),
		Args: cobra.MaximumNArgs(1),
		RunE: execDBDumpCatalog,
		Example: ` # Dump @sakila_pg to file sakila.dump using pg_dump
$ sq db dump catalog @sakila_pg -o sakila.dump
# Same as above, but verbose mode, and dump via stdout
$ sq db dump catalog @sakila_pg -v > sakila.dump
# Dump without ownership or ACL
$ sq db dump catalog --no-owner @sakila_pg > sakila.dump
# Print the dump command, but don't execute it
$ sq db dump catalog @sakila_pg --print
# Dump a catalog (db) other than the source's current catalog
$ sq db dump catalog @sakila_pg --catalog sales > sales.dump`,
	}
	// Calling cmdMarkPlainStdout means that ru.Stdout will be
	// the plain os.Stdout, and won't be decorated with color, or
	// progress listeners etc. The dump commands handle their own output.
	cmdMarkPlainStdout(cmd)
	// --catalog: dump a catalog other than the source's current one.
	cmd.Flags().String(flag.DBDumpCatalog, "", flag.DBDumpCatalogUsage)
	panicOn(cmd.RegisterFlagCompletionFunc(flag.DBDumpCatalog, completeCatalog(0)))
	cmd.Flags().Bool(flag.DBDumpNoOwner, false, flag.DBDumpNoOwnerUsage)
	cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage)
	// --print and --print-long both print (rather than execute) the
	// tool command, so they're mutually exclusive.
	cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage)
	cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage)
	cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd)
	return cmd
}
// execDBDumpCatalog implements the "db dump catalog" command. It
// determines the target source (the arg if given, else the active
// source), builds the db-native dump command (e.g. pg_dump) for the
// source's catalog, and then either prints that command
// (--print/--print-long) or executes it, streaming output to stdout
// or to the file given via -o/--output.
//
// Only Postgres sources are currently supported; other driver types
// return an error.
func execDBDumpCatalog(cmd *cobra.Command, args []string) error {
	ru := run.FromContext(cmd.Context())

	var (
		src *source.Source
		err error
	)
	if len(args) == 0 {
		// No handle arg: fall back to the active source.
		if src = ru.Config.Collection.Active(); src == nil {
			return errz.New(msgNoActiveSrc)
		}
	} else if src, err = ru.Config.Collection.Get(args[0]); err != nil {
		return err
	}

	if err = applySourceOptions(cmd, src); err != nil {
		return err
	}

	if cmdFlagChanged(cmd, flag.DBDumpCatalog) {
		// Use a different catalog than the source's current catalog.
		if src.Catalog, err = cmd.Flags().GetString(flag.DBDumpCatalog); err != nil {
			return err
		}
	}

	var (
		errPrefix     = fmt.Sprintf("db dump catalog: %s", src.Handle)
		dumpVerbose   = cmdFlagBool(cmd, flag.Verbose)
		dumpNoOwner   = cmdFlagBool(cmd, flag.DBDumpNoOwner)
		dumpLongFlags = cmdFlagBool(cmd, flag.DBPrintLongToolCmd)
		dumpFile      string
	)

	if cmdFlagChanged(cmd, flag.FileOutput) {
		if dumpFile, err = cmd.Flags().GetString(flag.FileOutput); err != nil {
			return err
		}
		if dumpFile = strings.TrimSpace(dumpFile); dumpFile == "" {
			return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.FileOutput)
		}
	}

	var execCmd *execz.Cmd
	switch src.Type { //nolint:exhaustive
	case drivertype.Pg:
		params := &postgres.ToolParams{
			Verbose:   dumpVerbose,
			NoOwner:   dumpNoOwner,
			File:      dumpFile,
			LongFlags: dumpLongFlags,
		}
		execCmd, err = postgres.DumpCatalogCmd(src, params)
	default:
		return errz.Errorf("%s: not supported for %s", errPrefix, src.Type)
	}
	if err != nil {
		return errz.Wrap(err, errPrefix)
	}

	execCmd.NoProgress = !OptProgress.Get(src.Options)
	execCmd.Label = src.Handle + ": " + execCmd.Name
	execCmd.Stdin = ru.Stdin
	execCmd.Stdout = ru.Stdout
	execCmd.Stderr = ru.ErrOut
	execCmd.ErrPrefix = errPrefix

	if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) {
		lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd)
		_, err = fmt.Fprintln(ru.Out, execCmd.String())
		return errz.Err(err)
	}

	// The type switch above already returned for any driver other than
	// drivertype.Pg, so the second type switch that previously lived
	// here was unreachable dead code; execute directly.
	lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd)
	return execz.Exec(cmd.Context(), execCmd)
}
// newDBDumpClusterCmd returns the "db dump cluster" command, which
// executes (or, with --print/--print-long, merely prints) the
// db-native cluster dump command (e.g. pg_dumpall) for a source's
// entire cluster. The actual work happens in execDBDumpCluster.
func newDBDumpClusterCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "cluster @src [--print]",
		Short: "Dump entire db cluster",
		Long: `Dump all catalogs in src's db cluster using the db-native dump tool.`,
		ValidArgsFunction: completeHandle(1, true),
		Args: cobra.MaximumNArgs(1),
		RunE: execDBDumpCluster,
		Example: ` # Dump all catalogs in @sakila_pg's cluster using pg_dumpall
$ sq db dump cluster @sakila_pg -f all.dump
# Same as above, but verbose mode and using stdout
$ sq db dump cluster @sakila_pg -v > all.dump
# Dump without ownership or ACL
$ sq db dump cluster @sakila_pg --no-owner > all.dump
# Print the dump command, but don't execute it
$ sq db dump cluster @sakila_pg -f all.dump --print`,
	}
	// Calling cmdMarkPlainStdout means that ru.Stdout will be
	// the plain os.Stdout, and won't be decorated with color, or
	// progress listeners etc. The dump commands handle their own output.
	cmdMarkPlainStdout(cmd)
	cmd.Flags().Bool(flag.DBDumpNoOwner, false, flag.DBDumpNoOwnerUsage)
	cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage)
	// --print and --print-long both print (rather than execute) the
	// tool command, so they're mutually exclusive.
	cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage)
	cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage)
	cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd)
	return cmd
}
// execDBDumpCluster implements the "db dump cluster" command. It
// determines the target source (the arg if given, else the active
// source), builds the db-native cluster dump command (e.g.
// pg_dumpall), and then either prints that command
// (--print/--print-long) or executes it, streaming output to stdout
// or to the file given via -o/--output.
//
// Only Postgres sources are currently supported; other driver types
// return an error.
func execDBDumpCluster(cmd *cobra.Command, args []string) error {
	var (
		ru  = run.FromContext(cmd.Context())
		src *source.Source
		err error
	)
	if len(args) == 0 {
		// No handle arg: fall back to the active source.
		if src = ru.Config.Collection.Active(); src == nil {
			return errz.New(msgNoActiveSrc)
		}
	} else if src, err = ru.Config.Collection.Get(args[0]); err != nil {
		return err
	}

	if err = applySourceOptions(cmd, src); err != nil {
		return err
	}

	var (
		errPrefix     = fmt.Sprintf("db dump cluster: %s", src.Handle)
		dumpVerbose   = cmdFlagBool(cmd, flag.Verbose)
		dumpNoOwner   = cmdFlagBool(cmd, flag.DBDumpNoOwner)
		dumpLongFlags = cmdFlagBool(cmd, flag.DBPrintLongToolCmd)
		dumpFile      string
	)

	if cmdFlagChanged(cmd, flag.FileOutput) {
		if dumpFile, err = cmd.Flags().GetString(flag.FileOutput); err != nil {
			return err
		}
		if dumpFile = strings.TrimSpace(dumpFile); dumpFile == "" {
			return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.FileOutput)
		}
	}

	var execCmd *execz.Cmd
	switch src.Type { //nolint:exhaustive
	case drivertype.Pg:
		params := &postgres.ToolParams{
			Verbose:   dumpVerbose,
			NoOwner:   dumpNoOwner,
			File:      dumpFile,
			LongFlags: dumpLongFlags,
		}
		execCmd, err = postgres.DumpClusterCmd(src, params)
	default:
		// Return directly, consistent with execDBDumpCatalog.
		// Previously this arm assigned err and fell through to the
		// errz.Wrap(err, errPrefix) below, which double-prefixed the
		// message with errPrefix.
		return errz.Errorf("%s: not supported for %s", errPrefix, src.Type)
	}
	if err != nil {
		return errz.Wrap(err, errPrefix)
	}

	execCmd.NoProgress = !OptProgress.Get(src.Options)
	execCmd.Label = src.Handle + ": " + execCmd.Name
	execCmd.Stdin = ru.Stdin
	execCmd.Stdout = ru.Stdout
	execCmd.Stderr = ru.ErrOut
	execCmd.ErrPrefix = errPrefix

	if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) {
		lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd)
		_, err = fmt.Fprintln(ru.Out, execCmd.String())
		return errz.Err(err)
	}

	// The type switch above already returned for any driver other than
	// drivertype.Pg, so the second type switch that previously lived
	// here was unreachable dead code; execute directly.
	lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd)
	return execz.Exec(cmd.Context(), execCmd)
}

151
cli/cmd_db_exec.go Normal file
View File

@ -0,0 +1,151 @@
package cli
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"github.com/neilotoole/sq/cli/flag"
"github.com/neilotoole/sq/cli/run"
"github.com/neilotoole/sq/drivers/postgres"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/execz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/source"
"github.com/neilotoole/sq/libsq/source/drivertype"
)
// newDBExecCmd returns the "db exec" command, which executes (or,
// with --print/--print-long, merely prints) a SQL script or command
// via the db-native tool (e.g. psql). The actual work happens in
// execDBExec.
//
// NOTE(review): the Use string shows "[--f SCRIPT.sql]"; the double
// dash with a single letter looks like a typo for "-f" — confirm
// against flag.DBExecFileShort.
func newDBExecCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "exec [@src] [--f SCRIPT.sql] [-c 'SQL'] [--print]",
		Short: "Execute SQL script or command",
		Long: `Execute SQL script or command using the db-native tool.
If no source is specified, the active source is used.
If --file is specified, the SQL is read from that file; otherwise if --command
is specified, that command string is used; otherwise the SQL commands are
read from stdin.
If --print or --print-long are specified, the SQL is not executed, but instead
the db-native command is printed to stdout. Note that the output will include DB
credentials.`,
		Args: cobra.MaximumNArgs(1),
		ValidArgsFunction: completeHandle(1, true),
		RunE: execDBExec,
		Example: ` # Execute query.sql on @sakila_pg
$ sq db exec @sakila_pg -f query.sql
# Same as above, but use stdin
$ sq db exec @sakila_pg < query.sql
# Execute a command string against the active source
$ sq db exec -c 'SELECT 777'
777
# Print the db-native command, but don't execute it
$ sq db exec -f query.sql --print
psql -d 'postgres://alice:abc123@db.acme.com:5432/sales' -f query.sql
# Execute against an alternative catalog or schema
$ sq db exec @sakila_pg --schema inventory.public -f query.sql`,
	}
	// Plain stdout: the external tool handles its own output.
	cmdMarkPlainStdout(cmd)
	// --file and --command are alternative SQL inputs, so mutually exclusive.
	cmd.Flags().StringP(flag.DBExecFile, flag.DBExecFileShort, "", flag.DBExecFileUsage)
	cmd.Flags().StringP(flag.DBExecCmd, flag.DBExecCmdShort, "", flag.DBExecCmdUsage)
	cmd.MarkFlagsMutuallyExclusive(flag.DBExecFile, flag.DBExecCmd)
	// --print and --print-long both print (rather than execute) the
	// tool command, so they're mutually exclusive.
	cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage)
	cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage)
	cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd)
	// --src.schema: execute against an alternative catalog/schema.
	cmd.Flags().String(flag.ActiveSchema, "", flag.ActiveSchemaUsage)
	panicOn(cmd.RegisterFlagCompletionFunc(flag.ActiveSchema,
		activeSchemaCompleter{getActiveSourceViaArgs}.complete))
	return cmd
}
// execDBExec implements the "db exec" command. It resolves the target
// source, gathers the SQL input (--file, --command, or stdin), builds
// the db-native exec command (e.g. psql), and then either prints that
// command (--print/--print-long) or executes it.
//
// Only Postgres sources are currently supported; other driver types
// return an error.
func execDBExec(cmd *cobra.Command, args []string) error {
	var (
		ru  = run.FromContext(cmd.Context())
		src *source.Source
		err error

		// scriptFile is the (optional) path to the SQL file.
		// If empty, cmdString or stdin is used.
		scriptFile string

		// cmdString is the optional SQL command string.
		// If empty, scriptFile or stdin is used.
		// (The comment previously referred to a nonexistent
		// "scriptString" variable.)
		cmdString string

		verbose = cmdFlagBool(cmd, flag.Verbose)
	)

	if src, err = getCmdSource(cmd, args); err != nil {
		return err
	}

	errPrefix := "db exec: " + src.Handle
	if cmdFlagChanged(cmd, flag.DBExecFile) {
		if scriptFile = strings.TrimSpace(cmd.Flag(flag.DBExecFile).Value.String()); scriptFile == "" {
			return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBExecFile)
		}
	}
	if cmdFlagChanged(cmd, flag.DBExecCmd) {
		if cmdString = strings.TrimSpace(cmd.Flag(flag.DBExecCmd).Value.String()); cmdString == "" {
			return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBExecCmd)
		}
	}

	if err = applySourceOptions(cmd, src); err != nil {
		return err
	}

	var execCmd *execz.Cmd
	switch src.Type { //nolint:exhaustive
	case drivertype.Pg:
		params := &postgres.ExecToolParams{
			Verbose:    verbose,
			ScriptFile: scriptFile,
			CmdString:  cmdString,
			LongFlags:  cmdFlagChanged(cmd, flag.DBPrintLongToolCmd),
		}
		execCmd, err = postgres.ExecCmd(src, params)
	default:
		return errz.Errorf("%s: not supported for %s", errPrefix, src.Type)
	}
	if err != nil {
		return errz.Wrap(err, errPrefix)
	}

	execCmd.NoProgress = !OptProgress.Get(src.Options)
	execCmd.Label = src.Handle + ": " + execCmd.Name
	execCmd.Stdin = ru.Stdin
	execCmd.Stdout = ru.Stdout
	execCmd.Stderr = ru.ErrOut
	execCmd.ErrPrefix = errPrefix

	if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) {
		lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd)
		_, err = fmt.Fprintln(ru.Out, execCmd.String())
		return errz.Err(err)
	}

	// The type switch above already returned for any driver other than
	// drivertype.Pg, so the second type switch that previously lived
	// here was unreachable dead code; execute directly.
	lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd)
	return execz.Exec(cmd.Context(), execCmd)
}

259
cli/cmd_db_restore.go Normal file
View File

@ -0,0 +1,259 @@
package cli
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"github.com/neilotoole/sq/cli/flag"
"github.com/neilotoole/sq/cli/run"
"github.com/neilotoole/sq/drivers/postgres"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/execz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/source"
"github.com/neilotoole/sq/libsq/source/drivertype"
)
// newDBRestoreCmd returns the "db restore" command, which is the
// parent of the "db restore catalog" and "db restore cluster"
// subcommands. Invoked without a subcommand, it just prints help.
func newDBRestoreCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "restore",
		Short: "Restore db catalog or cluster from dump",
		RunE: func(cmd *cobra.Command, args []string) error {
			// The parent command is a no-op; show help.
			return cmd.Help()
		},
	}
	return cmd
}
// newDBRestoreCatalogCmd returns the "db restore catalog" command,
// which executes (or, with --print/--print-long, merely prints) the
// db-native restore command (e.g. pg_restore) for a source's catalog.
// The actual work happens in execDBRestoreCatalog.
func newDBRestoreCatalogCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "catalog @src [--from file.dump] [--print]",
		Short: "Restore db catalog from dump",
		Long: `Restore into @src from dump file, using the db-native restore tool.
If --from is specified, the dump is read from that file; otherwise stdin is used.
When --no-owner is specified, the source user will own the restored objects: the
ownership (and ACLs) from the dump file are disregarded.
If --print or --print-long are specified, the restore command is not executed, but
instead the db-native command is printed to stdout. Note that the command output
will include DB credentials. For a Postgres source, it would look something like:
pg_restore -d 'postgres://alice:abc123@localhost:5432/sales' backup.dump`,
		Args: cobra.ExactArgs(1),
		ValidArgsFunction: completeHandle(1, true),
		RunE: execDBRestoreCatalog,
		Example: ` # Restore @sakila_pg from backup.dump
$ sq db restore catalog @sakila_pg -f backup.dump
# With verbose output, and reading from stdin
$ sq db restore catalog -v @sakila_pg < backup.dump
# Don't use ownership from dump; the source user will own the restored objects
$ sq db restore catalog @sakila_pg --no-owner < backup.dump
# Print the db-native restore command, but don't execute it
$ sq db restore catalog @sakila_pg -f backup.dump --print`,
	}
	// Plain stdout: the external tool handles its own output.
	cmdMarkPlainStdout(cmd)
	// FIXME: add --src.Schema
	cmd.Flags().StringP(flag.DBRestoreFrom, flag.DBRestoreFromShort, "", flag.DBRestoreFromUsage)
	cmd.Flags().Bool(flag.DBRestoreNoOwner, false, flag.DBRestoreNoOwnerUsage)
	// NOTE(review): flag.FileOutput is registered here but
	// execDBRestoreCatalog never reads it; input comes from
	// flag.DBRestoreFrom. Possibly a copy-paste from the dump
	// commands — confirm.
	cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage)
	// --print and --print-long both print (rather than execute) the
	// tool command, so they're mutually exclusive.
	cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage)
	cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage)
	cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd)
	return cmd
}
// execDBRestoreCatalog implements the "db restore catalog" command.
// It resolves the target source from args[0], builds the db-native
// restore command (e.g. pg_restore) reading the dump from --from or
// stdin, and then either prints that command (--print/--print-long)
// or executes it.
//
// Only Postgres sources are currently supported; other driver types
// return an error.
func execDBRestoreCatalog(cmd *cobra.Command, args []string) error {
	var (
		ru  = run.FromContext(cmd.Context())
		src *source.Source
		err error

		// dumpFile is the (optional) path to the dump file.
		// If empty, stdin is used.
		// (The comment previously referred to a nonexistent "fpDump".)
		dumpFile string
	)

	// Args: cobra.ExactArgs(1) guarantees args[0] exists.
	if src, err = ru.Config.Collection.Get(args[0]); err != nil {
		return err
	}

	errPrefix := "db restore catalog: " + src.Handle
	if cmdFlagChanged(cmd, flag.DBRestoreFrom) {
		if dumpFile = strings.TrimSpace(cmd.Flag(flag.DBRestoreFrom).Value.String()); dumpFile == "" {
			return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBRestoreFrom)
		}
	}

	if err = applySourceOptions(cmd, src); err != nil {
		return err
	}

	verbose := cmdFlagBool(cmd, flag.Verbose)
	noOwner := cmdFlagBool(cmd, flag.DBRestoreNoOwner)

	var execCmd *execz.Cmd
	switch src.Type { //nolint:exhaustive
	case drivertype.Pg:
		params := &postgres.ToolParams{
			Verbose: verbose,
			NoOwner: noOwner,
			File:    dumpFile,
			// LongFlags was previously left unset here even though
			// --print-long is registered on this command; set it so
			// that --print-long renders long-form flags, matching the
			// dump commands.
			LongFlags: cmdFlagBool(cmd, flag.DBPrintLongToolCmd),
		}
		execCmd, err = postgres.RestoreCatalogCmd(src, params)
	default:
		return errz.Errorf("%s: not supported for %s", errPrefix, src.Type)
	}
	if err != nil {
		return errz.Wrap(err, errPrefix)
	}

	execCmd.NoProgress = !OptProgress.Get(src.Options)
	execCmd.Label = src.Handle + ": " + execCmd.Name
	execCmd.Stdin = ru.Stdin
	execCmd.Stdout = ru.Stdout
	execCmd.Stderr = ru.ErrOut
	execCmd.ErrPrefix = errPrefix

	if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) {
		lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd)
		_, err = fmt.Fprintln(ru.Out, execCmd.String())
		return errz.Err(err)
	}

	// The type switch above already returned for any driver other than
	// drivertype.Pg, so the second type switch that previously lived
	// here (whose default arm had an inconsistent "cmd not supported"
	// message) was unreachable dead code; execute directly.
	lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd)
	return execz.Exec(cmd.Context(), execCmd)
}
// newDBRestoreClusterCmd returns the "db restore cluster" command,
// which executes (or, with --print/--print-long, merely prints) the
// db-native cluster restore command for a source's entire cluster.
// The actual work happens in execDBRestoreCluster.
func newDBRestoreClusterCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use: "cluster @src [--from file.dump] [--print]",
		Short: "Restore db cluster from dump",
		Long: `Restore entire db cluster into @src from dump file, using the db-native restore
tool.
If --from is specified, the dump is read from that file; otherwise stdin is used.
When --no-owner is specified, the source user will own the restored objects: the
ownership (and ACLs) from the dump file are disregarded.
If --print or --print-long are specified, the restore command is not executed, but
instead the db-native command is printed to stdout. Note that the command output
will include DB credentials. For a Postgres source, it would look something like:
FIXME: example command
psql -d 'postgres://alice:abc123@localhost:5432/sales' backup.dump`,
		Args: cobra.ExactArgs(1),
		ValidArgsFunction: completeHandle(1, true),
		RunE: execDBRestoreCluster,
		Example: ` # Restore @sakila_pg from backup.dump
$ sq db restore cluster @sakila_pg -f backup.dump
# With verbose output, and reading from stdin
$ sq db restore cluster -v @sakila_pg < backup.dump
# Don't use ownership from dump; the source user will own the restored objects
$ sq db restore cluster @sakila_pg --no-owner < backup.dump
# Print the db-native restore command, but don't execute it
$ sq db restore cluster @sakila_pg -f backup.dump --print`,
	}
	// Plain stdout: the external tool handles its own output.
	cmdMarkPlainStdout(cmd)
	cmd.Flags().StringP(flag.DBRestoreFrom, flag.DBRestoreFromShort, "", flag.DBRestoreFromUsage)
	cmd.Flags().Bool(flag.DBRestoreNoOwner, false, flag.DBRestoreNoOwnerUsage)
	// NOTE(review): flag.FileOutput is registered here but
	// execDBRestoreCluster never reads it; input comes from
	// flag.DBRestoreFrom. Possibly a copy-paste from the dump
	// commands — confirm.
	cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage)
	// --print and --print-long both print (rather than execute) the
	// tool command, so they're mutually exclusive.
	cmd.Flags().Bool(flag.DBPrintToolCmd, false, flag.DBPrintToolCmdUsage)
	cmd.Flags().Bool(flag.DBPrintLongToolCmd, false, flag.DBPrintLongToolCmdUsage)
	cmd.MarkFlagsMutuallyExclusive(flag.DBPrintToolCmd, flag.DBPrintLongToolCmd)
	return cmd
}
// execDBRestoreCluster implements the "db restore cluster" command.
// It resolves the target source from args[0], builds the db-native
// cluster restore command reading the dump from --from or stdin, and
// then either prints that command (--print/--print-long) or executes
// it.
//
// Only Postgres sources are currently supported; other driver types
// return an error.
func execDBRestoreCluster(cmd *cobra.Command, args []string) error {
	var (
		ru  = run.FromContext(cmd.Context())
		src *source.Source
		err error

		// dumpFile is the (optional) path to the dump file.
		// If empty, stdin is used.
		dumpFile string
	)

	// Args: cobra.ExactArgs(1) guarantees args[0] exists.
	if src, err = ru.Config.Collection.Get(args[0]); err != nil {
		return err
	}

	errPrefix := "db restore cluster: " + src.Handle
	if cmdFlagChanged(cmd, flag.DBRestoreFrom) {
		if dumpFile = strings.TrimSpace(cmd.Flag(flag.DBRestoreFrom).Value.String()); dumpFile == "" {
			return errz.Errorf("%s: %s is specified, but empty", errPrefix, flag.DBRestoreFrom)
		}
	}

	if err = applySourceOptions(cmd, src); err != nil {
		return err
	}

	verbose := cmdFlagBool(cmd, flag.Verbose)
	// FIXME: get rid of noOwner from this command?
	// noOwner := cmdFlagBool(cmd, flag.RestoreNoOwner)

	var execCmd *execz.Cmd
	switch src.Type { //nolint:exhaustive
	case drivertype.Pg:
		params := &postgres.ToolParams{
			Verbose: verbose,
			File:    dumpFile,
			// LongFlags was previously left unset here even though
			// --print-long is registered on this command; set it so
			// that --print-long renders long-form flags, matching the
			// dump commands.
			LongFlags: cmdFlagBool(cmd, flag.DBPrintLongToolCmd),
		}
		execCmd, err = postgres.RestoreClusterCmd(src, params)
	default:
		return errz.Errorf("%s: not supported for %s", errPrefix, src.Type)
	}
	if err != nil {
		return errz.Wrap(err, errPrefix)
	}

	execCmd.NoProgress = !OptProgress.Get(src.Options)
	execCmd.Label = src.Handle + ": " + execCmd.Name
	execCmd.Stdin = ru.Stdin
	execCmd.Stdout = ru.Stdout
	execCmd.Stderr = ru.ErrOut
	execCmd.ErrPrefix = errPrefix

	if cmdFlagBool(cmd, flag.DBPrintToolCmd) || cmdFlagBool(cmd, flag.DBPrintLongToolCmd) {
		lg.FromContext(cmd.Context()).Info("Printing external cmd", lga.Cmd, execCmd)
		_, err = fmt.Fprintln(ru.Out, execCmd.String())
		return errz.Err(err)
	}

	// The type switch above already returned for any driver other than
	// drivertype.Pg, so the second type switch that previously lived
	// here was unreachable dead code; execute directly.
	lg.FromContext(cmd.Context()).Info("Executing external cmd", lga.Cmd, execCmd)
	return execz.Exec(cmd.Context(), execCmd)
}

View File

@ -121,7 +121,8 @@ func execInspect(cmd *cobra.Command, args []string) error {
// Handle flag.ActiveSchema (--src.schema=SCHEMA). This func will mutate
// src's Catalog and Schema fields if appropriate.
if err = processFlagActiveSchema(cmd, src); err != nil {
var srcModified bool
if srcModified, err = processFlagActiveSchema(cmd, src); err != nil {
return err
}
@ -129,6 +130,12 @@ func execInspect(cmd *cobra.Command, args []string) error {
return err
}
if srcModified {
if err = verifySourceCatalogSchema(ctx, ru, src); err != nil {
return err
}
}
grip, err := ru.Grips.Open(ctx, src)
if err != nil {
return errz.Wrapf(err, "failed to inspect %s", src.Handle)

View File

@ -42,7 +42,7 @@ source handles are files, and groups are directories.`,
$ sq mv production prod`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
addTextFormatFlags(cmd)
cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage)
cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage)

View File

@ -34,7 +34,7 @@ may have changed, if that source or group was removed.`,
$ sq rm @staging/sakila_db @staging/backup_db dev`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
addTextFormatFlags(cmd)
cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage)
cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage)

View File

@ -36,7 +36,7 @@ importing non-SQL data, or cross-database joins. If no argument provided, get th
source. Otherwise, set @HANDLE or an internal db as the scratch data source. The reserved handle "@scratch" resets the
`,
}
markCmdRequiresConfigLock(cmd)
cmdMarkRequiresConfigLock(cmd)
return cmd
}

View File

@ -17,7 +17,6 @@ import (
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/source"
@ -204,7 +203,7 @@ func execSLQPrint(ctx context.Context, ru *run.Run, mArgs map[string]string) err
//
// $ cat something.xlsx | sq @stdin.sheet1
func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, error) {
log, reg, grips, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Grips, ru.Config.Collection
log, reg, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Config.Collection
activeSrc := coll.Active()
if len(args) == 0 {
@ -212,6 +211,10 @@ func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string,
// but sq is receiving pipe input. Let's say the user does this:
//
// $ cat something.csv | sq # query becomes "@stdin.data"
//
// REVISIT: It's not clear that this is even reachable any more?
// Plus, it's a bit ugly in general. Was the code already changed
// to force providing at least one query arg?
if activeSrc == nil {
// Piped input would result in an active @stdin src. We don't
// have that; we don't have any active src.
@ -230,27 +233,26 @@ func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string,
}
tblName := source.MonotableName
if !drvr.DriverMetadata().Monotable {
// This isn't a monotable src, so we can't
// just select @stdin.data. Instead we'll select
// the first table name, as found in the source meta.
grip, err := grips.Open(ctx, activeSrc)
if err != nil {
return "", err
}
defer lg.WarnIfCloseError(log, lgm.CloseDB, grip)
srcMeta, err := grip.SourceMetadata(ctx, false)
db, sqlDrvr, err := ru.DB(ctx, activeSrc)
if err != nil {
return "", err
}
if len(srcMeta.Tables) == 0 {
tables, err := sqlDrvr.ListTableNames(ctx, db, "", true, true)
if err != nil {
return "", err
}
if len(tables) == 0 {
return "", errz.New(msgSrcNoData)
}
tblName = srcMeta.Tables[0].Name
tblName = tables[0]
if tblName == "" {
return "", errz.New(msgSrcEmptyTableName)
}
@ -325,7 +327,7 @@ func addQueryCmdFlags(cmd *cobra.Command) {
addTimeFormatOptsFlags(cmd)
cmd.Flags().StringP(flag.Output, flag.OutputShort, "", flag.OutputUsage)
cmd.Flags().StringP(flag.FileOutput, flag.FileOutputShort, "", flag.FileOutputUsage)
cmd.Flags().StringP(flag.Input, flag.InputShort, "", flag.InputUsage)
panicOn(cmd.Flags().MarkHidden(flag.Input)) // Hide for now; this is mostly used for testing.

View File

@ -14,7 +14,7 @@ import (
"github.com/neilotoole/sq/libsq/files"
)
// newXCmd returns the root "x" command, which is the container
// newXCmd returns the "x" command, which is the container
// for a set of hidden commands that are useful for development.
// The x commands are not intended for end users.
func newXCmd() *cobra.Command {

View File

@ -113,6 +113,65 @@ func completeHandle(max int, includeActive bool) completionFunc {
}
}
// completeCatalog returns a completionFunc that suggests catalog names.
// If srcArgPos is non-negative and within range of the command args, the
// arg at that position is interpreted as the source handle. Typically the
// handle is the first arg, but it could be elsewhere. For example, below
// args[0] is @sakila, thus srcArgPos should be 0:
//
//	$ sq db dump @sakila --catalog [COMPLETE]
//
// If srcArgPos < 0, the catalogs for the active source are suggested.
func completeCatalog(srcArgPos int) completionFunc {
	return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
		log, ru := logFrom(cmd), getRun(cmd)
		if err := preRun(cmd, ru); err != nil {
			lg.Unexpected(log, err)
			return nil, cobra.ShellCompDirectiveError
		}

		ctx := cmd.Context()
		coll := ru.Config.Collection

		// Resolve the source: from the designated arg if present,
		// otherwise fall back to the active source.
		var src *source.Source
		if srcArgPos >= 0 && srcArgPos < len(args) {
			got, err := coll.Get(args[srcArgPos])
			if err != nil {
				lg.Unexpected(log, err)
				return nil, cobra.ShellCompDirectiveError
			}
			src = got
		}
		if src == nil {
			if src = coll.Active(); src == nil {
				log.Debug("No active source, so no catalog completions")
				return nil, cobra.ShellCompDirectiveError
			}
		}

		db, drvr, err := ru.DB(ctx, src)
		if err != nil {
			lg.Unexpected(log, err)
			return nil, cobra.ShellCompDirectiveError
		}

		all, err := drvr.ListCatalogs(ctx, db)
		if err != nil {
			lg.Unexpected(log, err)
			return nil, cobra.ShellCompDirectiveError
		}

		// Keep only the catalogs matching what the user has typed so far,
		// and present them in sorted order.
		matches := make([]string, 0, len(all))
		for _, catalog := range all {
			if strings.HasPrefix(catalog, toComplete) {
				matches = append(matches, catalog)
			}
		}
		slices.Sort(matches)
		return matches, cobra.ShellCompDirectiveNoFileComp
	}
}
// completeGroup is a completionFunc that suggests groups.
// The max arg is the maximum number of completions. Set to 0
// for no limit.
@ -793,29 +852,21 @@ func getTableNamesForHandle(ctx context.Context, ru *run.Run, handle string) ([]
return nil, err
}
grip, err := ru.Grips.Open(ctx, src)
db, drvr, err := ru.DB(ctx, src)
if err != nil {
return nil, err
}
// TODO: We shouldn't have to load the full metadata just to get
// the table names. driver.SQLDriver should have a method ListTables.
md, err := grip.SourceMetadata(ctx, false)
if err != nil {
return nil, err
}
return md.TableNames(), nil
return drvr.ListTableNames(ctx, db, src.Schema, true, true)
}
// maybeFilterHandlesByActiveGroup filters the supplied handles by
// active group, if appropriate.
func maybeFilterHandlesByActiveGroup(ru *run.Run, toComplete string, suggestions []string) []string {
handleFilter := getActiveGroupHandleFilterPrefix(ru)
if handleFilter != "" {
if strings.HasPrefix(handleFilter, toComplete) {
suggestions = lo.Filter(suggestions, func(handle string, index int) bool {
return strings.HasPrefix(handle, handleFilter)
if groupPrefix := getActiveGroupHandleFilterPrefix(ru); groupPrefix != "" {
if strings.HasPrefix(groupPrefix, toComplete) {
suggestions = lo.Filter(suggestions, func(handle string, _ int) bool {
return strings.HasPrefix(handle, groupPrefix)
})
}
}

View File

@ -201,6 +201,8 @@ func (fs *Store) write(ctx context.Context, data []byte) error {
return errz.Wrapf(err, "failed to make parent dir of config file: %s", filepath.Dir(fs.Path))
}
// FIXME: Store.Save should do a two-step atomic write of the file.
if err := os.WriteFile(fs.Path, data, ioz.RWPerms); err != nil {
return errz.Wrap(err, "failed to save config file")
}

View File

@ -42,7 +42,7 @@ func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config,
query2 := td2.src.Handle + "." + td2.tblName
log := lg.FromContext(ctx).With("a", query1).With("b", query2)
pr := ru.Writers.Printing.Clone()
pr := ru.Writers.OutPrinting.Clone()
pr.EnableColor(false)
buf1, buf2 := &bytes.Buffer{}, &bytes.Buffer{}
@ -176,7 +176,7 @@ func execSourceDataDiff(ctx context.Context, ru *run.Run, cfg *Config, sd1, sd2
}
tblDataDiff = diffs[printIndex]
if err := Print(ctx, ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff); err != nil {
if err := Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDataDiff.header, tblDataDiff.diff); err != nil {
printErrCh <- err
return
}

View File

@ -158,7 +158,7 @@ func findRecordDiff(ctx context.Context, ru *run.Run, lines int,
row: i,
}
if err = populateRecordDiff(ctx, lines, ru.Writers.Printing, recDiff); err != nil {
if err = populateRecordDiff(ctx, lines, ru.Writers.OutPrinting, recDiff); err != nil {
return nil, err
}

View File

@ -47,7 +47,7 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config,
return err
}
if err = Print(ctx, ru.Out, ru.Writers.Printing, srcDiff.header, srcDiff.diff); err != nil {
if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, srcDiff.header, srcDiff.diff); err != nil {
return err
}
}
@ -57,7 +57,7 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config,
if err != nil {
return err
}
if err = Print(ctx, ru.Out, ru.Writers.Printing, propsDiff.header, propsDiff.diff); err != nil {
if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, propsDiff.header, propsDiff.diff); err != nil {
return err
}
}
@ -68,7 +68,7 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config,
return err
}
for _, tblDiff := range tblDiffs {
if err = Print(ctx, ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil {
if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDiff.header, tblDiff.diff); err != nil {
return err
}
}

View File

@ -51,7 +51,7 @@ func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Element
return err
}
if err = Print(ctx, ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil {
if err = Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDiff.header, tblDiff.diff); err != nil {
return err
}
}
@ -69,7 +69,7 @@ func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Element
return nil
}
return Print(ctx, ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff)
return Print(ctx, ru.Out, ru.Writers.OutPrinting, tblDataDiff.header, tblDataDiff.diff)
}
func buildTableStructureDiff(ctx context.Context, cfg *Config, showRowCounts bool,

View File

@ -19,6 +19,7 @@ import (
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/options"
"github.com/neilotoole/sq/libsq/core/termz"
)
// PrintError is the centralized function for printing
@ -88,8 +89,10 @@ func PrintError(ctx context.Context, ru *run.Run, err error) {
} else {
clnup = cleanup.New()
}
// getPrinting works even if cmd is nil
pr, _, errOut := getPrinting(cmd, clnup, opts, os.Stdout, os.Stderr)
// getOutputConfig works even if cmd is nil
fm := getFormat(cmd, opts)
outCfg := getOutputConfig(cmd, clnup, fm, opts, ru.Stdout, ru.Stderr)
errOut, pr := outCfg.errOut, outCfg.errOutPr
// Execute the cleanup before we print the error.
if cleanErr := clnup.Run(); cleanErr != nil {
log.Error("Cleanup failed", lga.Err, cleanErr)
@ -103,7 +106,7 @@ func PrintError(ctx context.Context, ru *run.Run, err error) {
}
// The user didn't want JSON, so we just print to stderr.
if isColorTerminal(os.Stderr) {
if termz.IsColorTerminal(os.Stderr) {
pr.Error.Fprintln(os.Stderr, "sq: "+err.Error())
} else {
fmt.Fprintln(os.Stderr, "sq: "+err.Error())

View File

@ -1,12 +1,15 @@
// Package flag holds CLI flags.
package flag
// FIXME: Need to update docs for use of src.schema to note the
// new "catalog." mechanism.
const (
ActiveSrc = "src"
ActiveSrcUsage = "Override active source for this query"
ActiveSchema = "src.schema"
ActiveSchemaUsage = "Override active schema or catalog.schema for this query"
ActiveSchemaUsage = "Override active schema (and/or catalog) for this query"
ConfigSrc = "src"
ConfigSrcUsage = "Config for source"
@ -67,12 +70,9 @@ const (
MonochromeShort = "M"
MonochromeUsage = "Don't colorize output"
NoProgress = "no-progress"
NoProgressUsage = "Don't show progress bar"
Output = "output"
OutputShort = "o"
OutputUsage = "Write output to <file> instead of stdout"
FileOutput = "output"
FileOutputShort = "o"
FileOutputUsage = "Write output to <file> instead of stdout"
// Input sets Run.Stdin to the named file. At this time, this is used
// mainly for debugging, so it's marked hidden by the CLI. I'm not
@ -210,6 +210,29 @@ const (
DiffAll = "all"
DiffAllShort = "a"
DiffAllUsage = "Compare everything (caution: may be slow)"
DBDumpCatalog = "catalog"
DBDumpCatalogUsage = "Dump the named catalog"
DBDumpNoOwner = "no-owner"
DBDumpNoOwnerUsage = "Don't set ownership or ACL"
DBPrintToolCmd = "print"
DBPrintToolCmdUsage = "Print the db-native tool command, but don't execute it"
DBPrintLongToolCmd = "print-long"
DBPrintLongToolCmdUsage = "Print the long-form db-native tool command, but don't execute it"
DBRestoreFrom = "from"
DBRestoreFromShort = "f"
DBRestoreFromUsage = "Restore from dump file; if omitted, read from stdin"
DBRestoreNoOwner = "no-owner"
DBRestoreNoOwnerUsage = "Don't use ownership or ACL from dump"
DBExecFile = "file"
DBExecFileShort = "f"
DBExecFileUsage = "Read SQL from <file> instead of stdin"
DBExecCmd = "command"
DBExecCmdShort = "c"
DBExecCmdUsage = "Execute SQL command string"
)
// OutputFormatFlags is the set of flags that control output format.

View File

@ -1,6 +1,7 @@
package cli
import (
"context"
"fmt"
"io"
"os"
@ -30,6 +31,7 @@ import (
"github.com/neilotoole/sq/libsq/core/options"
"github.com/neilotoole/sq/libsq/core/progress"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/core/termz"
"github.com/neilotoole/sq/libsq/core/timez"
)
@ -269,13 +271,15 @@ the rendered value is not an integer.
)
// newWriters returns an output.Writers instance configured per defaults and/or
// flags from cmd. The returned out2/errOut2 values may differ
// from the out/errOut args (e.g. decorated to support colorization).
// flags from cmd. The returned writers in [outputConfig] may differ from
// the stdout and stderr params (e.g. decorated to support colorization).
func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options,
out, errOut io.Writer,
) (w *output.Writers, out2, errOut2 io.Writer) {
var pr *output.Printing
pr, out2, errOut2 = getPrinting(cmd, clnup, o, out, errOut)
stdout, stderr io.Writer,
) (w *output.Writers, outCfg *outputConfig) {
// Invoke getFormat to see if the format was specified
// via config or flag.
fm := getFormat(cmd, o)
outCfg = getOutputConfig(cmd, clnup, fm, o, stdout, stderr)
log := logFrom(cmd)
// Package tablew has writer impls for each of the writer interfaces,
@ -283,50 +287,47 @@ func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options,
// flags and set the various writer fields depending upon which
// writers the format implements.
w = &output.Writers{
Printing: pr,
Record: tablew.NewRecordWriter(out2, pr),
Metadata: tablew.NewMetadataWriter(out2, pr),
Source: tablew.NewSourceWriter(out2, pr),
Ping: tablew.NewPingWriter(out2, pr),
Error: tablew.NewErrorWriter(errOut2, pr),
Version: tablew.NewVersionWriter(out2, pr),
Config: tablew.NewConfigWriter(out2, pr),
OutPrinting: outCfg.outPr,
ErrPrinting: outCfg.errOutPr,
Record: tablew.NewRecordWriter(outCfg.out, outCfg.outPr),
Metadata: tablew.NewMetadataWriter(outCfg.out, outCfg.outPr),
Source: tablew.NewSourceWriter(outCfg.out, outCfg.outPr),
Ping: tablew.NewPingWriter(outCfg.out, outCfg.outPr),
Error: tablew.NewErrorWriter(outCfg.errOut, outCfg.errOutPr),
Version: tablew.NewVersionWriter(outCfg.out, outCfg.outPr),
Config: tablew.NewConfigWriter(outCfg.out, outCfg.outPr),
}
if OptErrorFormat.Get(o) == format.JSON {
// This logic works because the only supported values are text and json.
w.Error = jsonw.NewErrorWriter(log, errOut2, pr)
w.Error = jsonw.NewErrorWriter(log, outCfg.errOut, outCfg.errOutPr)
}
// Invoke getFormat to see if the format was specified
// via config or flag.
fm := getFormat(cmd, o)
//nolint:exhaustive
switch fm {
case format.JSON:
// No format specified, use JSON
w.Metadata = jsonw.NewMetadataWriter(out2, pr)
w.Source = jsonw.NewSourceWriter(out2, pr)
w.Version = jsonw.NewVersionWriter(out2, pr)
w.Ping = jsonw.NewPingWriter(out2, pr)
w.Config = jsonw.NewConfigWriter(out2, pr)
w.Metadata = jsonw.NewMetadataWriter(outCfg.out, outCfg.outPr)
w.Source = jsonw.NewSourceWriter(outCfg.out, outCfg.outPr)
w.Version = jsonw.NewVersionWriter(outCfg.out, outCfg.outPr)
w.Ping = jsonw.NewPingWriter(outCfg.out, outCfg.outPr)
w.Config = jsonw.NewConfigWriter(outCfg.out, outCfg.outPr)
case format.Text:
// Don't delete this case, it's actually needed due to
// the slightly odd logic that determines format.
case format.TSV:
w.Ping = csvw.NewPingWriter(out2, csvw.Tab)
w.Ping = csvw.NewPingWriter(outCfg.out, csvw.Tab)
case format.CSV:
w.Ping = csvw.NewPingWriter(out2, csvw.Comma)
w.Ping = csvw.NewPingWriter(outCfg.out, csvw.Comma)
case format.YAML:
w.Config = yamlw.NewConfigWriter(out2, pr)
w.Metadata = yamlw.NewMetadataWriter(out2, pr)
w.Source = yamlw.NewSourceWriter(out2, pr)
w.Version = yamlw.NewVersionWriter(out2, pr)
w.Config = yamlw.NewConfigWriter(outCfg.out, outCfg.outPr)
w.Metadata = yamlw.NewMetadataWriter(outCfg.out, outCfg.outPr)
w.Source = yamlw.NewSourceWriter(outCfg.out, outCfg.outPr)
w.Version = yamlw.NewVersionWriter(outCfg.out, outCfg.outPr)
default:
}
@ -335,10 +336,10 @@ func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options,
// We can still continue, because w.Record was already set above.
log.Warn("No record writer impl for format", "format", fm)
} else {
w.Record = recwFn(out2, pr)
w.Record = recwFn(outCfg.out, outCfg.outPr)
}
return w, out2, errOut2
return w, outCfg
}
// getRecordWriterFunc returns a func that creates a new output.RecordWriter
@ -374,23 +375,71 @@ func getRecordWriterFunc(f format.Format) output.NewRecordWriterFunc {
}
}
// getPrinting returns a Printing instance and
// colorable or non-colorable writers. It is permissible
// for the cmd arg to be nil. The caller should use the returned
// io.Writer instances instead of the supplied writers, as they
// may be decorated for dealing with color, etc.
// The supplied opts must already have flags merged into it
// via getOptionsFromCmd.
//
// Be cautious making changes to getPrinting. This function must
// be absolutely bulletproof, as it's called by all commands, as well
// as by the error handling mechanism. So, be sure to always check
// for nil cmd, nil cmd.Context, etc.
func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Options,
out, errOut io.Writer,
) (pr *output.Printing, out2, errOut2 io.Writer) {
pr = output.NewPrinting()
// outputConfig is a container for the various output writers, pairing each
// decorated writer with its printing config, and retaining the original
// undecorated stdout/stderr for special cases.
type outputConfig struct {
	// outPr is the printing config for out.
	outPr *output.Printing
	// out is the output writer that should be used for stdout output.
	out io.Writer
	// stdout is the original stdout, which probably was os.Stdout.
	// It's referenced here for special cases.
	stdout io.Writer
	// errOutPr is the printing config for errOut.
	errOutPr *output.Printing
	// errOut is the output writer that should be used for stderr output.
	errOut io.Writer
	// stderr is the original stderr, which probably was os.Stderr.
	// It's referenced here for special cases.
	stderr io.Writer
}
// getOutputConfig returns the configured output writers for cmd. Generally
// speaking, the caller should use the outputConfig.out and outputConfig.errOut
// writers for program output, as they are decorated appropriately for dealing
// with colorization, progress bars, etc. In very rare cases, such as calling
// out to an external program (e.g. pg_dump), the original outputConfig.stdout
// and outputConfig.stderr may be used.
//
// The supplied opts must already have flags merged into it via getOptionsFromCmd.
//
// If the progress bar is enabled and possible (stderr is TTY etc.), then cmd's
// context is decorated with via [progress.NewContext].
//
// Be VERY cautious about making changes to getOutputConfig. This function must
// be absolutely bulletproof, as it's called by all commands, as well as by the
// error handling mechanism. So, be sure to always check for nil: any of the
// args could be nil, or their fields could be nil. Check EVERYTHING for nil.
//
// The returned outputConfig and all of its fields are guaranteed to be non-nil.
//
// See also: [OptMonochrome], [OptProgress], newWriters.
func getOutputConfig(cmd *cobra.Command, clnup *cleanup.Cleanup,
fm format.Format, opts options.Options, stdout, stderr io.Writer,
) (outCfg *outputConfig) {
if opts == nil {
opts = options.Options{}
}
var ctx context.Context
if cmd != nil {
ctx = cmd.Context()
}
if stdout == nil {
stdout = os.Stdout
}
if stderr == nil {
stderr = os.Stderr
}
outCfg = &outputConfig{stdout: stdout, stderr: stderr}
pr := output.NewPrinting()
pr.FormatDatetime = timez.FormatFunc(OptDatetimeFormat.Get(opts))
pr.FormatDatetimeAsNumber = OptDatetimeFormatAsNumber.Get(opts)
pr.FormatTime = timez.FormatFunc(OptTimeFormat.Get(opts))
@ -416,79 +465,104 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option
pr.ShowHeader = OptPrintHeader.Get(opts)
}
colorize := !OptMonochrome.Get(opts)
var (
prog *progress.Progress
noProg = !OptProgress.Get(opts)
progColors = progress.DefaultColors()
monochrome = OptMonochrome.Get(opts)
)
if cmdFlagChanged(cmd, flag.Output) {
// We're outputting to a file, thus no color.
colorize = false
}
if !colorize {
if monochrome {
color.NoColor = true
pr.EnableColor(false)
out2 = out
errOut2 = errOut
if cmd != nil && cmd.Context() != nil && OptProgress.Get(opts) && isTerminal(errOut) {
progColors := progress.DefaultColors()
progColors.EnableColor(false)
ctx := cmd.Context()
renderDelay := OptProgressDelay.Get(opts)
pb := progress.New(ctx, errOut2, renderDelay, progColors)
clnup.Add(pb.Stop)
// On first write to stdout, we remove the progress widget.
out2 = ioz.NotifyOnceWriter(out2, pb.Stop)
cmd.SetContext(progress.NewContext(ctx, pb))
}
return pr, out2, errOut2
}
// We do want to colorize
if !isColorTerminal(out) {
// But out can't be colorized.
color.NoColor = true
pr.EnableColor(false)
out2, errOut2 = out, errOut
return pr, out2, errOut2
}
// out can be colorized.
color.NoColor = false
pr.EnableColor(true)
out2 = colorable.NewColorable(out.(*os.File))
// Check if we can colorize errOut
if isColorTerminal(errOut) {
errOut2 = colorable.NewColorable(errOut.(*os.File))
progColors.EnableColor(false)
} else {
// errOut2 can't be colorized, but since we're colorizing
// out, we'll apply the non-colorable filter to errOut.
errOut2 = colorable.NewNonColorable(errOut)
color.NoColor = false
pr.EnableColor(true)
progColors.EnableColor(true)
}
if cmd != nil && cmd.Context() != nil && OptProgress.Get(opts) && isTerminal(errOut) {
progColors := progress.DefaultColors()
progColors.EnableColor(isColorTerminal(errOut))
outCfg.outPr = pr.Clone()
outCfg.errOutPr = pr.Clone()
pr = nil //nolint:wastedassign // Make sure we don't accidentally use pr again
ctx := cmd.Context()
renderDelay := OptProgressDelay.Get(opts)
pb := progress.New(ctx, errOut2, renderDelay, progColors)
clnup.Add(pb.Stop)
// On first write to stdout, we remove the progress widget.
out2 = ioz.NotifyOnceWriter(out2, pb.Stop)
cmd.SetContext(progress.NewContext(ctx, pb))
switch {
case termz.IsColorTerminal(stderr) && !monochrome:
// stderr is a color terminal and we're colorizing, thus
// we enable progress if allowed and if ctx is non-nil.
outCfg.errOut = colorable.NewColorable(stderr.(*os.File))
outCfg.errOutPr.EnableColor(true)
if ctx != nil && !noProg {
progColors.EnableColor(true)
prog = progress.New(ctx, outCfg.errOut, OptProgressDelay.Get(opts), progColors)
}
case termz.IsTerminal(stderr):
// stderr is a terminal, and won't have color output, but we still enable
// progress, if allowed and ctx is non-nil.
//
// But... slightly weirdly, we still need to wrap stderr in a colorable, or
// else the progress bar won't render correctly. But it's not a problem,
// because we'll just disable the colors directly.
outCfg.errOut = colorable.NewColorable(stderr.(*os.File))
outCfg.errOutPr.EnableColor(false)
if ctx != nil && !noProg {
progColors.EnableColor(false)
prog = progress.New(ctx, outCfg.errOut, OptProgressDelay.Get(opts), progColors)
}
default:
// stderr is a not a terminal at all. No color, no progress.
outCfg.errOut = colorable.NewNonColorable(stderr)
outCfg.errOutPr.EnableColor(false)
progColors.EnableColor(false)
prog = nil // Set to nil just to be explicit.
}
return pr, out2, errOut2
switch {
case cmdFlagChanged(cmd, flag.FileOutput) || fm == format.Raw:
// For file or raw output, we don't decorate stdout with
// any colorable decorator.
outCfg.out = stdout
outCfg.outPr.EnableColor(false)
case cmd != nil && cmdFlagChanged(cmd, flag.FileOutput):
// stdout is an actual regular file on disk, so no color.
outCfg.out = colorable.NewNonColorable(stdout)
outCfg.outPr.EnableColor(false)
case termz.IsColorTerminal(stdout) && !monochrome:
// stdout is a color terminal and we're colorizing.
outCfg.out = colorable.NewColorable(stdout.(*os.File))
outCfg.outPr.EnableColor(true)
case termz.IsTerminal(stderr):
// stdout is a terminal, but won't be colorized.
outCfg.out = colorable.NewNonColorable(stdout)
outCfg.outPr.EnableColor(false)
default:
// stdout is a not a terminal at all. No color.
outCfg.out = colorable.NewNonColorable(stdout)
outCfg.outPr.EnableColor(false)
}
if !noProg && prog != nil && cmd != nil && ctx != nil {
// The progress bar is enabled.
// Be sure to stop the progress bar eventually.
clnup.Add(prog.Stop)
// Also, stop the progress bar as soon as bytes are written
// to out, because we don't want the progress bar to
// corrupt the terminal output.
outCfg.out = ioz.NotifyOnceWriter(outCfg.out, prog.Stop)
cmd.SetContext(progress.NewContext(ctx, prog))
}
return outCfg
}
func getFormat(cmd *cobra.Command, o options.Options) format.Format {
var fm format.Format
switch {
case cmd == nil:
fm = OptFormat.Get(o)
case cmdFlagChanged(cmd, flag.TSV):
fm = format.TSV
case cmdFlagChanged(cmd, flag.CSV):

View File

@ -40,7 +40,7 @@ func (w *PingWriter) Open(srcs []*source.Source) error {
// Result implements output.PingWriter.
func (w *PingWriter) Result(src *source.Source, d time.Duration, err error) error {
w.pr.Number.Fprintf(w.out, "%-"+strconv.Itoa(w.handleWidthMax)+"s", src.Handle)
w.pr.Handle.Fprintf(w.out, "%-"+strconv.Itoa(w.handleWidthMax)+"s", src.Handle)
w.pr.Duration.Fprintf(w.out, "%10s ", d.Truncate(time.Millisecond).String())
// The ping result is one of:

View File

@ -153,7 +153,8 @@ type ConfigWriter interface {
// Writers is a container for the various output Writers.
type Writers struct {
Printing *Printing
OutPrinting *Printing
ErrPrinting *Printing
Record RecordWriter
Metadata MetadataWriter

View File

@ -70,8 +70,8 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args
ru := &run.Run{
Stdin: stdin,
Out: stdout,
ErrOut: stderr,
Stdout: stdout,
Stderr: stderr,
OptionsRegistry: &options.Registry{},
}
@ -158,35 +158,41 @@ func preRun(cmd *cobra.Command, ru *run.Run) error {
}
// If the --output=/some/file flag is set, then we need to
// override ru.Out (which is typically stdout) to point it at
// override ru.Stdout (which is typically stdout) to point it at
// the output destination file.
if cmdFlagChanged(ru.Cmd, flag.Output) {
fpath, _ := ru.Cmd.Flags().GetString(flag.Output)
//
if cmdFlagChanged(ru.Cmd, flag.FileOutput) && !cmdRequiresPlainStdout(ru.Cmd) {
fpath, _ := ru.Cmd.Flags().GetString(flag.FileOutput)
fpath, err := filepath.Abs(fpath)
if err != nil {
return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output)
return errz.Wrapf(err, "failed to get absolute path for --%s", flag.FileOutput)
}
// Ensure the parent dir exists
err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm)
if err != nil {
return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output)
return errz.Wrapf(err, "failed to make parent dir for --%s", flag.FileOutput)
}
f, err := os.Create(fpath)
if err != nil {
return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output)
return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.FileOutput)
}
ru.Cleanup.AddC(f) // Make sure the file gets closed eventually
ru.Out = f
ru.Stdout = f
}
cmdOpts, err := getOptionsFromCmd(ru.Cmd)
if err != nil {
return err
}
ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Out, ru.ErrOut)
var outCfg *outputConfig
ru.Writers, outCfg = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Stdout, ru.Stderr)
ru.Out = outCfg.out
ru.ErrOut = outCfg.errOut
if err = FinishRunInit(ctx, ru); err != nil {
return err
@ -330,20 +336,20 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error {
return nil
}
// markCmdRequiresConfigLock marks cmd as requiring a config lock.
// cmdMarkRequiresConfigLock marks cmd as requiring a config lock.
// Thus, before the command's RunE is invoked, the config lock
// is acquired (in preRun), and released on cleanup.
func markCmdRequiresConfigLock(cmd *cobra.Command) {
func cmdMarkRequiresConfigLock(cmd *cobra.Command) {
if cmd.Annotations == nil {
cmd.Annotations = make(map[string]string)
}
cmd.Annotations["config.lock"] = "true"
cmd.Annotations["config.lock"] = "true" //nolint:goconst
}
// cmdRequiresConfigLock returns true if markCmdRequiresConfigLock was
// cmdRequiresConfigLock returns true if cmdMarkRequiresConfigLock was
// previously invoked on cmd.
func cmdRequiresConfigLock(cmd *cobra.Command) bool {
return cmd.Annotations != nil && cmd.Annotations["config.lock"] == "true"
return cmd != nil && cmd.Annotations != nil && cmd.Annotations["config.lock"] == "true"
}
// lockReloadConfig acquires the lock for the config store, and updates the
@ -363,7 +369,7 @@ func cmdRequiresConfigLock(cmd *cobra.Command) bool {
// defer unlock()
// }
//
// However, in practice, most commands will invoke markCmdRequiresConfigLock
// However, in practice, most commands will invoke cmdMarkRequiresConfigLock
// instead of explicitly invoking lockReloadConfig.
func lockReloadConfig(cmd *cobra.Command) (unlock func(), err error) {
ctx := cmd.Context()
@ -433,3 +439,20 @@ func newProgressLockFunc(lock lockfile.Lockfile, msg string, timeout time.Durati
}, nil
}
}
// cmdMarkPlainStdout marks cmd to indicate that its stdout should not be
// decorated in any way, e.g. with color or progress bars. This is useful
// for binary output.
func cmdMarkPlainStdout(cmd *cobra.Command) {
	// FIXME: implement this in newWriters or such?
	if cmd.Annotations != nil {
		cmd.Annotations["stdout.plain"] = "true"
		return
	}
	cmd.Annotations = map[string]string{"stdout.plain": "true"}
}
// cmdRequiresPlainStdout reports whether cmdMarkPlainStdout was previously
// invoked on cmd. A nil cmd (or nil annotations) reports false.
func cmdRequiresPlainStdout(cmd *cobra.Command) bool {
	if cmd == nil || cmd.Annotations == nil {
		return false
	}
	return cmd.Annotations["stdout.plain"] == "true"
}

View File

@ -16,8 +16,10 @@ import (
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/options"
"github.com/neilotoole/sq/libsq/core/sqlz"
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/files"
"github.com/neilotoole/sq/libsq/source"
)
type runKey struct{}
@ -40,12 +42,28 @@ func FromContext(ctx context.Context) *Run {
// to all cobra exec funcs. The Close method should be invoked when
// the Run is no longer needed.
type Run struct {
// Out is the output destination, typically os.Stdout.
// Out is the output destination, typically a decorated writer over
// [Run.Stdout]. This writer should generally be used for program output,
// not [Run.Stdout].
Out io.Writer
// ErrOut is the error output destination, typically os.Stderr.
// Stdout is the original stdout file descriptor, which is typically
// the actual os.Stdout. Output should generally be written to [Run.Out]
// except for a few rare circumstances, such as executing an external
// program.
Stdout io.Writer
// ErrOut is the error output destination, typically a decorated writer
// over [Run.Stderr]. This writer should generally be used for error output,
// not [Run.Stderr].
ErrOut io.Writer
// Stderr is the original stderr file descriptor, which is typically
// the actual os.Stderr. Error output should generally be written to
// [Run.ErrOut] except for a few rare circumstances, such as executing an
// external program.
Stderr io.Writer
// ConfigStore manages config persistence.
ConfigStore config.Store
@ -108,6 +126,21 @@ func (ru *Run) Close() error {
return errz.Wrap(ru.Cleanup.Run(), "close run")
}
// DB is a convenience method that gets the sqlz.DB and driver.SQLDriver
// for src, opening a grip via ru.Grips. NOTE(review): callers elsewhere in
// this package don't close the returned DB — presumably its lifecycle is
// managed by ru.Grips; confirm before closing it in new call sites.
func (ru *Run) DB(ctx context.Context, src *source.Source) (sqlz.DB, driver.SQLDriver, error) {
	grip, err := ru.Grips.Open(ctx, src)
	if err != nil {
		return nil, nil, err
	}
	db, err := grip.DB(ctx)
	if err != nil {
		return nil, nil, err
	}
	return db, grip.SQLDriver(), nil
}
// NewQueryContext returns a *libsq.QueryContext constructed from ru.
func NewQueryContext(ru *Run, args map[string]string) *libsq.QueryContext {
return &libsq.QueryContext{

View File

@ -70,14 +70,16 @@ func determineSources(ctx context.Context, ru *run.Run, requireActive bool) erro
}
// activeSrcAndSchemaFromFlagsOrConfig gets the active source, either
// from flagActiveSrc or from srcs.Active. An error is returned
// from [flag.ActiveSrc] or from srcs.Active. An error is returned
// if the flag src is not found: if the flag src is found,
// it is set as the active src on coll. If the flag was not
// set and there is no active src in coll, (nil, nil) is
// returned.
//
// This source also checks flag.ActiveSchema, and changes the schema
// This source also checks [flag.ActiveSchema], and changes the catalog/schema
// of the source if the flag is set.
//
// See also: processFlagActiveSchema, verifySourceCatalogSchema.
func activeSrcAndSchemaFromFlagsOrConfig(ru *run.Run) (*source.Source, error) {
cmd, coll := ru.Cmd, ru.Config.Collection
var activeSrc *source.Source
@ -100,49 +102,148 @@ func activeSrcAndSchemaFromFlagsOrConfig(ru *run.Run) (*source.Source, error) {
activeSrc = coll.Active()
}
if err := processFlagActiveSchema(cmd, activeSrc); err != nil {
srcModified, err := processFlagActiveSchema(cmd, activeSrc)
if err != nil {
return nil, err
}
if srcModified {
if err = verifySourceCatalogSchema(ru.Cmd.Context(), ru, activeSrc); err != nil {
return nil, err
}
}
return activeSrc, nil
}
// processFlagActiveSchema processes the --src.schema flag, setting
// appropriate Source.Catalog and Source.Schema values on activeSrc.
// If flag.ActiveSchema is not set, this is no-op. If activeSrc is nil,
// an error is returned.
func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) error {
// verifySourceCatalogSchema verifies that src's non-empty [source.Source.Catalog]
// and [source.Source.Schema] fields are valid, in that they are referenceable
// in src's DB. If both fields are empty, this is a no-op. This function is
// typically used after the source's catalog or schema has been modified,
// e.g. via [flag.ActiveSchema].
//
// See also: processFlagActiveSchema.
func verifySourceCatalogSchema(ctx context.Context, ru *run.Run, src *source.Source) error {
	if src.Catalog == "" && src.Schema == "" {
		// Nothing to verify.
		return nil
	}

	db, drvr, err := ru.DB(ctx, src)
	if err != nil {
		return err
	}

	if src.Catalog != "" {
		ok, catErr := drvr.CatalogExists(ctx, db, src.Catalog)
		if catErr != nil {
			return catErr
		}
		if !ok {
			return errz.Errorf("%s: catalog {%s} doesn't exist or not referenceable", src.Handle, src.Catalog)
		}
	}

	if src.Schema == "" {
		return nil
	}
	ok, schErr := drvr.SchemaExists(ctx, db, src.Schema)
	if schErr != nil {
		return schErr
	}
	if !ok {
		return errz.Errorf("%s: schema {%s} doesn't exist or not referenceable", src.Handle, src.Schema)
	}
	return nil
}
// getCmdSource gets the source specified in args[0], or if args is empty, the
// active source, and calls applySourceOptions. If [flag.ActiveSchema] is set,
// the [source.Source.Catalog] and [source.Source.Schema] fields are configured
// (and validated) as appropriate.
//
// See: applySourceOptions, processFlagActiveSchema, verifySourceCatalogSchema.
func getCmdSource(cmd *cobra.Command, args []string) (*source.Source, error) {
	ru := run.FromContext(cmd.Context())

	var src *source.Source
	var err error
	if len(args) == 0 {
		if src = ru.Config.Collection.Active(); src == nil {
			return nil, errz.New("no active source")
		}
	} else {
		if src, err = ru.Config.Collection.Get(args[0]); err != nil {
			return nil, err
		}
	}

	// Handle flag.ActiveSchema (--src.schema=CATALOG.SCHEMA). This call may
	// mutate src's Catalog and Schema fields if appropriate.
	var srcModified bool
	if cmdFlagChanged(cmd, flag.ActiveSchema) {
		if srcModified, err = processFlagActiveSchema(cmd, src); err != nil {
			return nil, err
		}
	}

	// Apply source options unconditionally. Previously this was skipped via an
	// early return when flag.ActiveSchema wasn't set, contradicting this
	// function's documented contract that applySourceOptions is invoked.
	if err = applySourceOptions(cmd, src); err != nil {
		return nil, err
	}

	if srcModified {
		// Catalog/schema were changed: verify that they're referenceable.
		if err = verifySourceCatalogSchema(cmd.Context(), ru, src); err != nil {
			return nil, err
		}
	}

	return src, nil
}
// processFlagActiveSchema processes the --src.schema flag, setting appropriate
// [source.Source.Catalog] and [source.Source.Schema] values on activeSrc. If
// the src is modified by this function, modified returns true. If
// [flag.ActiveSchema] is not set, this is no-op. If activeSrc is nil, an error
// is returned.
//
// See also: verifySourceCatalogSchema.
func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) (modified bool, err error) {
ru := run.FromContext(cmd.Context())
if !cmdFlagChanged(cmd, flag.ActiveSchema) {
// Nothing to do here
return nil
return false, nil
}
if activeSrc == nil {
return errz.Errorf("active catalog/schema specified via --%s, but active source is nil",
return false, errz.Errorf("active catalog/schema specified via --%s, but active source is nil",
flag.ActiveSchema)
}
val, _ := cmd.Flags().GetString(flag.ActiveSchema)
if val = strings.TrimSpace(val); val == "" {
return errz.Errorf("active catalog/schema specified via --%s, but schema is empty",
return false, errz.Errorf("active catalog/schema specified via --%s, but schema is empty",
flag.ActiveSchema)
}
catalog, schema, err := ast.ParseCatalogSchema(val)
if err != nil {
return errz.Wrapf(err, "invalid active schema specified via --%s",
return false, errz.Wrapf(err, "invalid active schema specified via --%s",
flag.ActiveSchema)
}
if catalog != activeSrc.Catalog || schema != activeSrc.Schema {
modified = true
}
drvr, err := ru.DriverRegistry.SQLDriverFor(activeSrc.Type)
if err != nil {
return err
return false, err
}
if catalog != "" {
if !drvr.Dialect().Catalog {
return errz.Errorf("driver {%s} does not support catalog", activeSrc.Type)
return false, errz.Errorf("driver {%s} does not support catalog", activeSrc.Type)
}
activeSrc.Catalog = catalog
}
@ -150,7 +251,7 @@ func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) error
activeSrc.Schema = schema
}
return nil
return modified, nil
}
// checkStdinSource checks if there's stdin data (on pipe/redirect).

View File

@ -112,7 +112,9 @@ func newRun(ctx context.Context, tb testing.TB,
ru = &run.Run{
Stdin: os.Stdin,
Stdout: out,
Out: out,
Stderr: errOut,
ErrOut: errOut,
Config: cfg,
ConfigStore: cfgStore,

View File

@ -206,9 +206,12 @@ func TestIngest_DuplicateColumns(t *testing.T) {
require.NoError(t, err)
tr = testrun.New(ctx, t, tr).Hush()
require.NoError(t, tr.Exec("--csv", ".data"))
err = tr.Exec("--csv", ".data")
require.NoError(t, err)
wantHeaders := []string{"actor_id", "first_name", "last_name", "last_update", "actor_id_1"}
data := tr.BindCSV()
gotOut := tr.OutString()
_ = gotOut
require.Equal(t, wantHeaders, data[0])
// Make sure the data is correct

View File

@ -224,6 +224,17 @@ func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error)
return schemas, nil
}
// SchemaExists implements driver.SQLDriver. It reports whether schma is listed
// in INFORMATION_SCHEMA.SCHEMATA. An empty schma returns false with no error.
func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) {
	if schma == "" {
		return false, nil
	}

	const q = "SELECT COUNT(SCHEMA_NAME) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?"
	var count int
	// Scan before testing count: the previous single-line return evaluated
	// "count > 0" and the Scan call in the same statement, relying on
	// unspecified evaluation order of return operands relative to Scan's
	// mutation of count.
	if err := db.QueryRowContext(ctx, q, schma).Scan(&count); err != nil {
		return false, errw(err)
	}
	return count > 0, nil
}
// ListSchemaMetadata implements driver.SQLDriver.
func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metadata.Schema, error) {
log := lg.FromContext(ctx)
@ -273,6 +284,43 @@ func (d *driveri) CurrentCatalog(ctx context.Context, db sqlz.DB) (string, error
return catalog, nil
}
// ListTableNames implements driver.SQLDriver. It returns the names of tables
// and/or views in schma (or in the current database, if schma is empty),
// ordered alphabetically.
func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) {
	// Neither tables nor views requested: return an empty (non-nil) slice.
	if !tables && !views {
		return []string{}, nil
	}

	var typeClause string
	switch {
	case tables && views:
		typeClause = " AND (TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW')"
	case tables:
		typeClause = " AND TABLE_TYPE = 'BASE TABLE'"
	default: // views only
		typeClause = " AND TABLE_TYPE = 'VIEW'"
	}

	// Target the requested schema, defaulting to the connection's database.
	schemaExpr := "DATABASE()"
	var args []any
	if schma != "" {
		schemaExpr = "?"
		args = []any{schma}
	}

	q := "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = " +
		schemaExpr + typeClause + " ORDER BY TABLE_NAME"

	rows, err := db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, errw(err)
	}

	names, err := sqlz.RowsScanColumn[string](ctx, rows)
	if err != nil {
		return nil, errw(err)
	}
	return names, nil
}
// ListCatalogs implements driver.SQLDriver. MySQL does not really support catalogs,
// but this method simply delegates to CurrentCatalog, which returns the value
// found in INFORMATION_SCHEMA.SCHEMATA, i.e. "def".
@ -285,6 +333,12 @@ func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error
return []string{catalog}, nil
}
// CatalogExists implements driver.SQLDriver. It returns true if catalog is "def",
// and false otherwise, noting that MySQL doesn't really support catalogs.
func (d *driveri) CatalogExists(_ context.Context, _ sqlz.DB, catalog string) (bool, error) {
	return catalog == "def", nil
}
// AlterTableRename implements driver.SQLDriver.
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
q := fmt.Sprintf("RENAME TABLE `%s` TO `%s`", tbl, newName)

View File

@ -150,35 +150,9 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er
func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) {
log := lg.FromContext(ctx)
ctx = options.NewContext(ctx, src.Options)
dbCfg, err := pgxpool.ParseConfig(src.Location)
poolCfg, err := getPoolConfig(src, false)
if err != nil {
return nil, errw(err)
}
if src.Catalog != "" && src.Catalog != dbCfg.ConnConfig.Database {
// The catalog differs from the database in the connection string.
// OOTB, Postgres doesn't support cross-database references. So,
// we'll need to change the connection string to use the catalog
// as the database. Note that we don't modify src.Location, but it's
// not entirely clear if that's the correct approach. Are there any
// downsides to modifying it (as long as the modified Location is not
// persisted back to config)?
var u *dburl.URL
if u, err = dburl.Parse(src.Location); err != nil {
return nil, errw(err)
}
u.Path = src.Catalog
connStr := u.String()
dbCfg, err = pgxpool.ParseConfig(connStr)
if err != nil {
return nil, errw(err)
}
log.Debug("Using catalog as database in connection string",
lga.Src, src,
lga.Catalog, src.Catalog,
lga.Conn, location.Redact(connStr),
)
return nil, err
}
var opts []stdlib.OptionOpenDB
@ -196,7 +170,7 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro
log.Debug("Setting default schema (search_path) on Postgres DB connection",
lga.Src, src,
lga.Conn, location.Redact(dbCfg.ConnString()),
lga.Conn, location.Redact(poolCfg.ConnString()),
lga.Catalog, src.Catalog,
lga.Schema, src.Schema,
lga.Old, oldSearchPath,
@ -207,9 +181,9 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro
}))
}
dbCfg.ConnConfig.ConnectTimeout = driver.OptConnOpenTimeout.Get(src.Options)
poolCfg.ConnConfig.ConnectTimeout = driver.OptConnOpenTimeout.Get(src.Options)
db := stdlib.OpenDB(*dbCfg.ConnConfig, opts...)
db := stdlib.OpenDB(*poolCfg.ConnConfig, opts...)
driver.ConfigureDB(ctx, db, src.Options)
return db, nil
@ -430,6 +404,32 @@ ORDER BY datname`
return catalogs, nil
}
// SchemaExists implements driver.SQLDriver. It reports whether schma exists in
// the current database's information_schema.schemata. An empty schma returns
// false with no error.
func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) {
	if schma == "" {
		return false, nil
	}

	const q = `SELECT COUNT(schema_name) FROM information_schema.schemata
WHERE schema_name = $1 AND catalog_name = current_database()`
	var count int
	// Scan before testing count: the previous single-line return relied on
	// unspecified evaluation order of return operands relative to Scan's
	// mutation of count.
	if err := db.QueryRowContext(ctx, q, schma).Scan(&count); err != nil {
		return false, errw(err)
	}
	return count > 0, nil
}
// CatalogExists implements driver.SQLDriver. It reports whether catalog is a
// connectable, non-template database per pg_catalog.pg_database. An empty
// catalog returns false with no error.
func (d *driveri) CatalogExists(ctx context.Context, db sqlz.DB, catalog string) (bool, error) {
	if catalog == "" {
		return false, nil
	}

	const q = `SELECT COUNT(datname) FROM pg_catalog.pg_database
WHERE datistemplate = FALSE AND datallowconn = TRUE AND datname = $1`
	var count int
	// Scan before testing count: the previous single-line return relied on
	// unspecified evaluation order of return operands relative to Scan's
	// mutation of count.
	if err := db.QueryRowContext(ctx, q, catalog).Scan(&count); err != nil {
		return false, errw(err)
	}
	return count > 0, nil
}
// AlterTableRename implements driver.SQLDriver.
func (d *driveri) AlterTableRename(ctx context.Context, db sqlz.DB, tbl, newName string) error {
q := fmt.Sprintf(`ALTER TABLE %q RENAME TO %q`, tbl, newName)
@ -549,6 +549,43 @@ WHERE table_name = $1`
return count == 1, nil
}
// ListTableNames implements driver.SQLDriver. It returns the names of tables
// and/or views in schma (or in current_schema(), if schma is empty), ordered
// alphabetically.
func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) {
	// Neither tables nor views requested: return an empty (non-nil) slice.
	if !tables && !views {
		return []string{}, nil
	}

	var typeClause string
	switch {
	case tables && views:
		typeClause = " AND (table_type = 'BASE TABLE' OR table_type = 'VIEW')"
	case tables:
		typeClause = " AND table_type = 'BASE TABLE'"
	default: // views only
		typeClause = " AND table_type = 'VIEW'"
	}

	// Target the requested schema, defaulting to the session's current schema.
	schemaExpr := "current_schema()"
	var args []any
	if schma != "" {
		schemaExpr = "$1"
		args = []any{schma}
	}

	q := "SELECT table_name FROM information_schema.tables WHERE table_schema = " +
		schemaExpr + typeClause + " ORDER BY table_name"

	rows, err := db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, errw(err)
	}

	names, err := sqlz.RowsScanColumn[string](ctx, rows)
	if err != nil {
		return nil, errw(err)
	}
	return names, nil
}
// DropTable implements driver.SQLDriver.
func (d *driveri) DropTable(ctx context.Context, db sqlz.DB, tbl tablefq.T, ifExists bool) error {
var stmt string
@ -741,15 +778,66 @@ func (d *driveri) RecordMeta(ctx context.Context, colTypes []*sql.ColumnType) (
return recMeta, mungeFn, nil
}
// getPoolConfig returns the native postgres [*pgxpool.Config] for src, applying
// src's fields, such as [source.Source.Catalog] as appropriate. If
// includeConnTimeout is true, then 'connect_timeout' is included in the
// returned config; this is provided as an option, because the connection
// timeout is sometimes better handled via [context.WithTimeout].
func getPoolConfig(src *source.Source, includeConnTimeout bool) (*pgxpool.Config, error) {
	poolCfg, err := pgxpool.ParseConfig(src.Location)
	if err != nil {
		return nil, errw(err)
	}

	if src.Catalog != "" && src.Catalog != poolCfg.ConnConfig.Database {
		// The catalog differs from the database in the connection string.
		// OOTB, Postgres doesn't support cross-database references. So,
		// we'll need to change the connection string to use the catalog
		// as the database. Note that we don't modify src.Location, but it's
		// not entirely clear if that's the correct approach. Are there any
		// downsides to modifying it (as long as the modified Location is not
		// persisted back to config)?
		var u *dburl.URL
		if u, err = dburl.Parse(src.Location); err != nil {
			return nil, errw(err)
		}
		// Swap the URL path (the database) for the catalog, then re-parse.
		u.Path = src.Catalog
		connStr := u.String()
		poolCfg, err = pgxpool.ParseConfig(connStr)
		if err != nil {
			return nil, errw(err)
		}
	}

	if includeConnTimeout {
		srcTimeout := driver.OptConnOpenTimeout.Get(src.Options)
		// Set connect_timeout when srcTimeout is non-zero, or when it differs
		// from the already-configured value.
		// NOTE(review): the || below disagrees with the earlier comment that
		// said "non-zero AND differs". As written, srcTimeout == 0 with a
		// differing configured value also rewrites connect_timeout (to 0) —
		// confirm which behavior is intended.
		// REVISIT: We should actually always set it, otherwise the user's
		// envar PGCONNECT_TIMEOUT may override it?
		if srcTimeout > 0 || poolCfg.ConnConfig.ConnectTimeout != srcTimeout {
			var u *dburl.URL
			if u, err = dburl.Parse(poolCfg.ConnString()); err != nil {
				return nil, errw(err)
			}
			// Encode the timeout (in whole seconds) into the query string,
			// then re-parse to produce the final config.
			q := u.Query()
			q.Set("connect_timeout", strconv.Itoa(int(srcTimeout.Seconds())))
			u.RawQuery = q.Encode()
			poolCfg, err = pgxpool.ParseConfig(u.String())
			if err != nil {
				return nil, errw(err)
			}
		}
	}

	return poolCfg, nil
}
// doRetry executes fn with retry on isErrTooManyConnections. The maximum
// retry interval is read from ctx options via driver.OptMaxRetryInterval.
func doRetry(ctx context.Context, fn func() error) error {
	interval := driver.OptMaxRetryInterval.Get(options.FromContext(ctx))
	return retry.Do(ctx, interval, fn, isErrTooManyConnections)
}
// tblfmt formats a table name for use in a query. The arg can be a string,
// or a tablefq.T.
func tblfmt[T string | tablefq.T](tbl T) string {
	// Normalize to tablefq.T, then render with double-quoted identifiers.
	return tablefq.From(tbl).Render(stringz.DoubleQuote)
}

View File

@ -8,8 +8,17 @@ import (
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/kind"
"github.com/neilotoole/sq/libsq/core/schema"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/core/tablefq"
)
// tblfmt formats a table name for use in a query. The arg can be a string,
// or a tablefq.T.
func tblfmt[T string | tablefq.T](tbl T) string {
	// Normalize to tablefq.T, then render with double-quoted identifiers.
	return tablefq.From(tbl).Render(stringz.DoubleQuote)
}
func dbTypeNameFromKind(knd kind.Kind) string {
switch knd { //nolint:exhaustive
default:

332
drivers/postgres/tools.go Normal file
View File

@ -0,0 +1,332 @@
package postgres
import (
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/execz"
"github.com/neilotoole/sq/libsq/source"
)
// REVISIT: DumpCatalogCmd and DumpClusterCmd could be methods on driver.SQLDriver.
// TODO: Unify DumpCatalogCmd and DumpClusterCmd, as they're almost identical, probably
// in the form:
// DumpCatalogCmd(src *source.Source, all bool) (cmd []string, err error).
// ToolParams are parameters for postgres tools such as pg_dump and pg_restore.
//
// - https://www.postgresql.org/docs/9.6/app-pgdump.html
// - https://www.postgresql.org/docs/9.6/app-pgrestore.html.
// - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html
// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
//
// Not every flag is applicable to all tools.
type ToolParams struct {
	// File is the path to the dump file: the output file for dump commands,
	// and the input file for restore commands.
	File string

	// Verbose indicates verbose output (progress).
	Verbose bool

	// NoOwner won't output commands to set ownership of objects; the source's
	// connection user will own all objects. This also sets the --no-acl flag.
	// Maybe NoOwner should be named "no security" or similar?
	NoOwner bool

	// LongFlags indicates whether to use long flags, e.g. --no-owner instead
	// of -O.
	LongFlags bool
}

// flag returns the long or short form of the named flag, per p.LongFlags.
// The name arg should be one of the flagXxx constants; an unknown name
// yields the empty string.
func (p *ToolParams) flag(name string) string {
	if p.LongFlags {
		return flagsLong[name]
	}
	return flagsShort[name]
}
// DumpCatalogCmd returns the shell command to execute pg_dump for src.
// Example output:
//
//	pg_dump -Fc -d postgres://alice:vNgR6R@db.acme.com:5432/sales sales.dump
//
// Reference:
//
// - https://www.postgresql.org/docs/9.6/app-pgdump.html
// - https://www.postgresql.org/docs/9.6/app-pgrestore.html
//
// See also: [RestoreCatalogCmd].
func DumpCatalogCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) {
	// - https://www.postgresql.org/docs/9.6/app-pgdump.html
	// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
	// - https://gist.github.com/vielhuber/96eefdb3aff327bdf8230d753aaee1e1
	cfg, err := getPoolConfig(src, true)
	if err != nil {
		return nil, err
	}

	cmd := &execz.Cmd{Name: "pg_dump"}
	if p.Verbose {
		// Progress is reported on pg_dump's stderr.
		cmd.ProgressFromStderr = true
		cmd.Args = append(cmd.Args, p.flag(flagVerbose))
	}
	cmd.Args = append(cmd.Args, p.flag(flagClean), p.flag(flagIfExists)) // TODO: should be optional
	cmd.Args = append(cmd.Args, p.flag(flagFormatCustomArchive))
	if p.NoOwner {
		// You might expect we'd add --no-owner, but if we're outputting a custom
		// archive (-Fc), then --no-owner is the default. From the pg_dump docs:
		//
		//	This option is ignored when emitting an archive (non-text) output file.
		//	For the archive formats, you can specify the option when you call pg_restore.
		//
		// If we ultimately allow non-archive formats, then we'll need to add
		// special handling for --no-owner.
		cmd.Args = append(cmd.Args, p.flag(flagNoACL))
	}
	// The connection string doubles as the --dbname value.
	cmd.Args = append(cmd.Args, p.flag(flagDBName), cfg.ConnString())
	if p.File != "" {
		cmd.UsesOutputFile = p.File
		cmd.Args = append(cmd.Args, p.flag(flagFile), p.File)
	}
	return cmd, nil
}
// RestoreCatalogCmd returns the shell command to restore a pg catalog (db) from
// a dump produced by pg_dump ([DumpCatalogCmd]). Example command:
//
//	pg_restore -d postgres://alice:vNgR6R@db.acme.com:5432/sales sales.dump
//
// Reference:
//
// - https://www.postgresql.org/docs/9.6/app-pgrestore.html
// - https://www.postgresql.org/docs/9.6/app-pgdump.html
//
// See also: [DumpCatalogCmd].
func RestoreCatalogCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) {
	// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
	// - https://gist.github.com/vielhuber/96eefdb3aff327bdf8230d753aaee1e1
	cfg, err := getPoolConfig(src, true)
	if err != nil {
		return nil, err
	}

	cmd := &execz.Cmd{Name: "pg_restore", CmdDirPath: true}
	if p.Verbose {
		// Progress is reported on pg_restore's stderr.
		cmd.ProgressFromStderr = true
		cmd.Args = append(cmd.Args, p.flag(flagVerbose))
	}
	if p.NoOwner {
		// NoOwner sets both --no-owner and --no-acl. Maybe these should
		// be separate options.
		cmd.Args = append(cmd.Args, p.flag(flagNoACL), p.flag(flagNoOwner)) // -O is --no-owner
	}
	cmd.Args = append(cmd.Args,
		p.flag(flagClean),
		p.flag(flagIfExists),
		p.flag(flagCreate),
		p.flag(flagDBName),
		cfg.ConnString(),
	)
	if p.File != "" {
		// NOTE(review): p.File is the dump being read here, yet it's assigned
		// to UsesOutputFile — confirm that's the intended field.
		cmd.UsesOutputFile = p.File
		cmd.Args = append(cmd.Args, p.File)
	}
	return cmd, nil
}
// DumpClusterCmd returns the shell command to execute pg_dumpall for src.
// Example output (components concatenated with space):
//
//	PGPASSWORD=vNgR6R pg_dumpall -w -l sales -d postgres://alice:vNgR6R@db.acme.com:5432/sales -f cluster.dump
//
// Note that the dump produced by pg_dumpall is executed by psql, not pg_restore.
//
// - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html
// - https://www.postgresql.org/docs/9.6/app-psql.html
// - https://www.postgresql.org/docs/9.6/app-pgdump.html
// - https://www.postgresql.org/docs/9.6/app-pgrestore.html
// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
//
// See also: [RestoreClusterCmd].
func DumpClusterCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) {
	// - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html
	// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
	cfg, err := getPoolConfig(src, true)
	if err != nil {
		return nil, err
	}

	// The password is supplied via the PGPASSWORD envar; -w (--no-password)
	// below prevents an interactive password prompt.
	cmd := &execz.Cmd{
		Name:       "pg_dumpall",
		CmdDirPath: true,
		Env:        []string{"PGPASSWORD=" + cfg.ConnConfig.Password},
	}
	// FIXME: need mechanism to indicate that env contains password

	if p.Verbose {
		// Progress is reported on pg_dumpall's stderr.
		cmd.ProgressFromStderr = true
		cmd.Args = append(cmd.Args, p.flag(flagVerbose))
	}
	cmd.Args = append(cmd.Args, p.flag(flagClean), p.flag(flagIfExists)) // TODO: should be optional
	if p.NoOwner {
		// NoOwner sets both --no-owner and --no-acl. Maybe these should
		// be separate options.
		cmd.Args = append(cmd.Args, p.flag(flagNoACL), p.flag(flagNoOwner))
	}
	cmd.Args = append(cmd.Args,
		p.flag(flagNoPassword),
		p.flag(flagDatabase), cfg.ConnConfig.Database,
		p.flag(flagDBName), cfg.ConnString(),
	)
	if p.File != "" {
		cmd.Args = append(cmd.Args, p.flag(flagFile), p.File)
	}
	return cmd, nil
}
// RestoreClusterCmd returns the shell command to restore a pg cluster from a
// dump produced by pg_dumpall ([DumpClusterCmd]). Note that the dump produced
// by pg_dumpall is executed by psql, not pg_restore. Example command:
//
//	psql -d postgres://alice:vNgR6R@db.acme.com:5432/sales -f sales.dump
//
// Reference:
//
// - https://www.postgresql.org/docs/9.6/app-pg-dumpall.html
// - https://www.postgresql.org/docs/9.6/app-psql.html
// - https://www.postgresql.org/docs/9.6/app-pgdump.html
// - https://www.postgresql.org/docs/9.6/app-pgrestore.html
// - https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
//
// See also: [DumpClusterCmd].
func RestoreClusterCmd(src *source.Source, p *ToolParams) (*execz.Cmd, error) {
	// - https://gist.github.com/vielhuber/96eefdb3aff327bdf8230d753aaee1e1
	cfg, err := getPoolConfig(src, true)
	if err != nil {
		return nil, err
	}

	cmd := &execz.Cmd{Name: "psql", CmdDirPath: true}
	if p.Verbose {
		// Progress is reported on psql's stderr.
		cmd.ProgressFromStderr = true
		cmd.Args = append(cmd.Args, p.flag(flagVerbose))
	}
	cmd.Args = append(cmd.Args, p.flag(flagDBName), cfg.ConnString())
	if p.File != "" {
		// The dump file is fed to psql via --file.
		cmd.Args = append(cmd.Args, p.flag(flagFile), p.File)
	}
	return cmd, nil
}
// ExecToolParams are parameters for executing a SQL script or command
// string via psql. See ExecCmd.
type ExecToolParams struct {
	// ScriptFile is the path to the script file.
	// Only one of ScriptFile or CmdString will be set.
	ScriptFile string

	// CmdString is the literal SQL command string.
	// Only one of ScriptFile or CmdString will be set.
	CmdString string

	// Verbose indicates verbose output (progress).
	Verbose bool

	// LongFlags indicates whether to use long flags, e.g. --file instead of -f.
	LongFlags bool
}

// flag returns the long or short form of the named flag, per p.LongFlags.
// The name arg should be one of the flagXxx constants; an unknown name
// yields the empty string.
func (p *ExecToolParams) flag(name string) string {
	if p.LongFlags {
		return flagsLong[name]
	}
	return flagsShort[name]
}
// ExecCmd returns the shell command to execute psql with a script file
// or command string. Example command:
//
//	psql -d postgres://alice:vNgR6R@db.acme.com:5432/sales -f query.sql
//
// An error is returned if both p.ScriptFile and p.CmdString are set.
//
// See: https://www.postgresql.org/docs/9.6/app-psql.html.
func ExecCmd(src *source.Source, p *ExecToolParams) (*execz.Cmd, error) {
	// Validate mutual exclusion up front, before building the command:
	// previously this was checked only after args had already been appended.
	if p.ScriptFile != "" && p.CmdString != "" {
		return nil, errz.Errorf("only one of --file or --command may be set")
	}

	cfg, err := getPoolConfig(src, true)
	if err != nil {
		return nil, err
	}

	cmd := &execz.Cmd{Name: "psql"}
	if !p.Verbose {
		// Suppress psql chatter unless verbose output was requested.
		cmd.Args = append(cmd.Args, p.flag(flagQuiet))
	}
	cmd.Args = append(cmd.Args, p.flag(flagDBName), cfg.ConnString())
	if p.ScriptFile != "" {
		cmd.Args = append(cmd.Args, p.flag(flagFile), p.ScriptFile)
	}
	if p.CmdString != "" {
		cmd.Args = append(cmd.Args, p.flag(flagCommand), p.CmdString)
	}
	return cmd, nil
}
// flags for the pg_dump, pg_restore, pg_dumpall and psql programs. These are
// the canonical (long-form) names; flagsLong and flagsShort map each name to
// the form actually emitted, per ToolParams.LongFlags.
const (
	flagNoOwner             = "--no-owner"
	flagVerbose             = "--verbose"
	flagQuiet               = "--quiet"
	flagNoACL               = "--no-acl"
	flagCreate              = "--create"
	flagDBName              = "--dbname"
	flagDatabase            = "--database"
	flagFormatCustomArchive = "--format=custom"
	flagIfExists            = "--if-exists"
	flagClean               = "--clean"
	flagNoPassword          = "--no-password"
	flagFile                = "--file"
	flagCommand             = "--command"
)
// flagsLong maps each canonical flag name to its long form. It is an identity
// mapping, because the canonical names are the long forms.
var flagsLong = map[string]string{
	flagNoOwner:             flagNoOwner,
	flagVerbose:             flagVerbose,
	flagQuiet:               flagQuiet,
	flagNoACL:               flagNoACL,
	flagCreate:              flagCreate,
	flagDBName:              flagDBName,
	flagIfExists:            flagIfExists,
	flagFormatCustomArchive: flagFormatCustomArchive,
	flagClean:               flagClean,
	flagNoPassword:          flagNoPassword,
	flagDatabase:            flagDatabase,
	flagFile:                flagFile,
	flagCommand:             flagCommand,
}
// flagsShort maps each canonical flag name to its short form. flagIfExists has
// no short form, so it maps to its long form. Note that "-c" appears twice:
// it stands for --clean in pg_dump/pg_restore, but --command in psql; the two
// are never used by the same tool invocation.
var flagsShort = map[string]string{
	flagNoOwner:             "-O",
	flagVerbose:             "-v",
	flagQuiet:               "-q",
	flagNoACL:               "-x",
	flagCreate:              "-C",
	flagClean:               "-c",
	flagDBName:              "-d",
	flagFormatCustomArchive: "-Fc",
	flagIfExists:            flagIfExists,
	flagNoPassword:          "-w",
	flagDatabase:            "-l",
	flagFile:                "-f",
	flagCommand:             "-c",
}

View File

@ -682,6 +682,18 @@ func (d *driveri) CurrentSchema(ctx context.Context, db sqlz.DB) (string, error)
return name, nil
}
// SchemaExists implements driver.SQLDriver. It reports whether schma is listed
// in pragma_database_list (i.e. is an attached database, which is what SQLite
// treats as a schema). An empty schma returns false with no error.
func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) {
	if schma == "" {
		return false, nil
	}

	const q = `SELECT COUNT(name) FROM pragma_database_list WHERE name = ?`
	var count int
	// Scan before testing count: the previous single-line return relied on
	// unspecified evaluation order of return operands relative to Scan's
	// mutation of count.
	if err := db.QueryRowContext(ctx, q, schma).Scan(&count); err != nil {
		return false, errw(err)
	}
	return count > 0, nil
}
// ListSchemas implements driver.SQLDriver.
func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error) {
log := lg.FromContext(ctx)
@ -709,6 +721,44 @@ func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error)
return schemas, nil
}
// ListTableNames implements driver.SQLDriver. The returned names exclude
// any sqlite_ internal tables. If schma is empty, the main database's
// sqlite_master is queried; otherwise schma's sqlite_master is used.
func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) {
	// Neither tables nor views requested: return an empty (non-nil) slice.
	if !tables && !views {
		return []string{}, nil
	}

	var typeClause string
	switch {
	case tables && views:
		typeClause = " WHERE (type = 'table' OR type = 'view')"
	case tables:
		typeClause = " WHERE type = 'table'"
	default: // views only
		typeClause = " WHERE type = 'view'"
	}

	master := "sqlite_master"
	if schma != "" {
		master = stringz.DoubleQuote(schma) + ".sqlite_master"
	}

	// Exclude sqlite_ internal tables from the results.
	q := "SELECT name FROM " + master + typeClause +
		" AND name NOT LIKE 'sqlite_%'" + " ORDER BY name"

	rows, err := db.QueryContext(ctx, q)
	if err != nil {
		return nil, errw(err)
	}

	names, err := sqlz.RowsScanColumn[string](ctx, rows)
	if err != nil {
		return nil, errw(err)
	}
	return names, nil
}
// ListSchemaMetadata implements driver.SQLDriver.
// The returned metadata.Schema instances will have a Catalog
// value of "default", and an empty Owner value.
@ -728,6 +778,12 @@ func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metada
return schemas, nil
}
// CatalogExists implements driver.SQLDriver. SQLite does not support catalogs,
// so this method always returns false and an error.
func (d *driveri) CatalogExists(_ context.Context, _ sqlz.DB, _ string) (bool, error) {
	return false, errz.New("sqlite3: catalog mechanism not supported")
}
// CurrentCatalog implements driver.SQLDriver. SQLite does not support catalogs,
// so this method returns an error.
func (d *driveri) CurrentCatalog(_ context.Context, _ sqlz.DB) (string, error) {

View File

@ -402,6 +402,19 @@ func (d *driveri) ListSchemas(ctx context.Context, db sqlz.DB) ([]string, error)
return schemas, nil
}
// SchemaExists implements driver.SQLDriver. It reports whether schma exists in
// the current database's INFORMATION_SCHEMA.SCHEMATA. An empty schma returns
// false with no error.
func (d *driveri) SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error) {
	if schma == "" {
		return false, nil
	}

	const q = `SELECT COUNT(SCHEMA_NAME) FROM INFORMATION_SCHEMA.SCHEMATA
WHERE SCHEMA_NAME = @p1 AND CATALOG_NAME = DB_NAME()`
	var count int
	// Scan before testing count: the previous single-line return relied on
	// unspecified evaluation order of return operands relative to Scan's
	// mutation of count.
	if err := db.QueryRowContext(ctx, q, schma).Scan(&count); err != nil {
		return false, errw(err)
	}
	return count > 0, nil
}
// ListSchemaMetadata implements driver.SQLDriver.
func (d *driveri) ListSchemaMetadata(ctx context.Context, db sqlz.DB) ([]*metadata.Schema, error) {
log := lg.FromContext(ctx)
@ -450,6 +463,18 @@ func (d *driveri) CurrentCatalog(ctx context.Context, db sqlz.DB) (string, error
return name, nil
}
// CatalogExists implements driver.SQLDriver. It reports whether catalog is
// listed in sys.databases. An empty catalog returns false with no error.
func (d *driveri) CatalogExists(ctx context.Context, db sqlz.DB, catalog string) (bool, error) {
	if catalog == "" {
		return false, nil
	}

	const q = `SELECT COUNT(name) FROM sys.databases WHERE name = @p1`
	var count int
	// Scan before testing count: the previous single-line return relied on
	// unspecified evaluation order of return operands relative to Scan's
	// mutation of count.
	if err := db.QueryRowContext(ctx, q, catalog).Scan(&count); err != nil {
		return false, errw(err)
	}
	return count > 0, nil
}
// ListCatalogs implements driver.SQLDriver.
func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error) {
catalogs := make([]string, 1, 3)
@ -457,9 +482,7 @@ func (d *driveri) ListCatalogs(ctx context.Context, db sqlz.DB) ([]string, error
return nil, errw(err)
}
const q = `SELECT name FROM sys.databases
WHERE name != DB_NAME()
ORDER BY name`
const q = `SELECT name FROM sys.databases WHERE name != DB_NAME() ORDER BY name`
rows, err := db.QueryContext(ctx, q)
if err != nil {
@ -511,6 +534,44 @@ func (d *driveri) DropSchema(ctx context.Context, db sqlz.DB, schemaName string)
return nil
}
// ListTableNames implements driver.SQLDriver. It returns the names of tables
// and/or views in schma (or in SCHEMA_NAME(), if schma is empty), ordered
// alphabetically.
func (d *driveri) ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error) {
	var tblClause string
	switch {
	case tables && views:
		tblClause = " AND (TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW')"
	case tables:
		tblClause = " AND TABLE_TYPE = 'BASE TABLE'"
	case views:
		tblClause = " AND TABLE_TYPE = 'VIEW'"
	default:
		// Neither tables nor views requested: return an empty (non-nil) slice.
		return []string{}, nil
	}

	var args []any
	// Use uppercase TABLE_SCHEMA: the previous lowercase "table_schema" breaks
	// under case-sensitive database collations, and was inconsistent with the
	// other uppercase identifiers in this query.
	q := "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = "
	if schma == "" {
		q += "SCHEMA_NAME()"
	} else {
		q += "@p1"
		args = append(args, schma)
	}
	q += tblClause + " ORDER BY TABLE_NAME"

	rows, err := db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, errw(err)
	}

	names, err := sqlz.RowsScanColumn[string](ctx, rows)
	if err != nil {
		return nil, errw(err)
	}
	return names, nil
}
// CreateTable implements driver.SQLDriver.
func (d *driveri) CreateTable(ctx context.Context, db sqlz.DB, tblDef *schema.Table) error {
stmt := buildCreateTableStmt(tblDef)

View File

@ -42,20 +42,20 @@ func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab
// SourceMetadata implements driver.Grip.
func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) {
meta, err := g.impl.SourceMetadata(ctx, noSchema)
md, err := g.impl.SourceMetadata(ctx, noSchema)
if err != nil {
return nil, err
}
meta.Handle = g.src.Handle
meta.Location = g.src.Location
meta.Name, err = location.Filename(g.src.Location)
md.Handle = g.src.Handle
md.Location = g.src.Location
md.Name, err = location.Filename(g.src.Location)
if err != nil {
return nil, err
}
meta.FQName = meta.Name
return meta, nil
md.FQName = md.Name
return md, nil
}
// Close implements driver.Grip.

2
go.mod
View File

@ -56,7 +56,9 @@ require (
github.com/Masterminds/semver/v3 v3.2.1 // indirect
github.com/VividCortex/ewma v1.2.0 // indirect
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect
github.com/alecthomas/chroma v0.10.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/huandu/xstrings v1.4.0 // indirect

5
go.sum
View File

@ -25,6 +25,8 @@ github.com/a8m/tree v0.0.0-20240104212747-2c8764a5f17e h1:KMVieI1/Ub++GYfnhyFPoG
github.com/a8m/tree v0.0.0-20240104212747-2c8764a5f17e/go.mod h1:j5astEcUkZQX8lK+KKlQ3NRQ50f4EE8ZjyZpCz3mrH4=
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8=
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo=
github.com/alecthomas/chroma v0.10.0 h1:7XDcGkCQopCNKjZHfYrNLraA+M7e0fMiJ/Mfikbfjek=
github.com/alecthomas/chroma v0.10.0/go.mod h1:jtJATyUxlIORhUOFNA9NZDWGAQ8wpxQQqNSB4rjA/1s=
github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0=
github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30=
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
@ -38,6 +40,9 @@ github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ecnepsnai/osquery v1.0.1 h1:i96n/3uqcafKZtRYmXVNqekKbfrIm66q179mWZ/Y2Aw=

2
huzzah.txt Normal file
View File

@ -0,0 +1,2 @@
payment_id customer_id name staff_id rental_id amount payment_date last_update
1 71 Alice 71 0 3 500 2023-12-04T04:29:00Z 2023-12-04T04:29:00Z

View File

@ -198,14 +198,15 @@ func errorf(format string, v ...any) error {
}
// ParseCatalogSchema parses a string of the form 'catalog.schema'
// and returns the catalog and schema. An error is returned if the schema
// is empty (but catalog may be empty). Whitespace and quotes are handled
// and returns the catalog and schema. It is permissible for one of the
// components to be empty (but not both). Whitespace and quotes are handled
// correctly.
//
// Examples:
//
// `catalog.schema` -> "catalog", "schema", nil
// `schema` -> "", "schema", nil
// `catalog.` -> "catalog", "", nil
// `schema` -> "", "schema", nil
// `"my catalog"."my schema"` -> "my catalog", "my schema", nil
//
// An error is returned if s is empty.
@ -219,8 +220,18 @@ func ParseCatalogSchema(s string) (catalog, schema string, err error) {
// We'll hijack the existing parser code. A value "catalog.schema" is
// not valid, but ".catalog.schema" works as a selector.
//
// Being that we accept "catalog." as valid (indicating the default schema),
// we'll use a hack to make the parser work: we append a const string to
// the input (which the parser will think is the schema name), and later
// on we check for that const string, and set schema to empty string if
// that const is found.
const schemaNameHack = "DEFAULT_SCHEMA_HACK_be8hx64wd45vxusdebez2e6tega8ussy"
sel := "." + s
if strings.HasSuffix(s, ".") {
sel += schemaNameHack
}
a, err := Parse(lg.Discard(), sel)
if err != nil {
return "", "", errz.Errorf(errTpl, s)
@ -230,7 +241,9 @@ func ParseCatalogSchema(s string) (catalog, schema string, err error) {
return "", "", errz.Errorf(errTpl, s)
}
tblSel := NewInspector(a).FindFirstTableSelector()
insp := NewInspector(a)
tblSel := insp.FindFirstTableSelector()
if tblSel == nil {
return "", "", errz.Errorf(errTpl, s)
}
@ -243,6 +256,8 @@ func ParseCatalogSchema(s string) (catalog, schema string, err error) {
}
if schema == "" {
return "", "", errz.Errorf(errTpl, s)
} else if schema == schemaNameHack {
schema = ""
}
return catalog, schema, nil

View File

@ -16,7 +16,10 @@ func TestParseCatalogSchema(t *testing.T) {
wantErr bool
}{
{in: "", wantErr: true},
{in: ".", wantErr: true},
{in: "dbo", wantCatalog: "", wantSchema: "dbo"},
{in: "sakila.", wantCatalog: "sakila", wantSchema: ""},
{in: ".dbo", wantErr: true},
{in: "sakila.dbo", wantCatalog: "sakila", wantSchema: "dbo"},
{in: `"my catalog"."my schema"`, wantCatalog: "my catalog", wantSchema: "my schema"},
{in: `"my catalog""."my schema"`, wantErr: true},

View File

@ -266,6 +266,31 @@ func Chain(err error) []error {
return errs
}
// ExitCoder is an interface that an error type can implement to indicate
// that the program should exit with a specific status code.
// In particular, note that *exec.ExitError implements this interface.
// See also: [ExitCode], which searches an error chain for an ExitCoder.
type ExitCoder interface {
	// ExitCode returns the exit code indicated by the error, or -1 if
	// the error does not indicate a particular exit code.
	ExitCode() int
}
// ExitCode returns the exit code of the first error in err's chain
// that implements ExitCoder, otherwise -1.
func ExitCode(err error) (code int) {
	if err == nil {
		return -1
	}

	for _, e := range Chain(err) {
		// Deliberate direct type assertion (not errors.As): Chain has
		// already flattened the error tree.
		if coder, ok := e.(ExitCoder); ok { //nolint:errorlint
			return coder.ExitCode()
		}
	}
	return -1
}
// SprintTreeTypes returns a string representation of err's type tree.
// A multi-error is represented as a slice of its children.
func SprintTreeTypes(err error) string {

332
libsq/core/execz/execz.go Normal file
View File

@ -0,0 +1,332 @@
// Package execz builds on stdlib os/exec.
package execz
import (
"bytes"
"context"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
"github.com/neilotoole/sq/libsq/core/loz"
"github.com/neilotoole/sq/libsq/core/progress"
"github.com/neilotoole/sq/libsq/core/stringz"
"github.com/neilotoole/sq/libsq/core/termz"
"github.com/neilotoole/sq/libsq/source/location"
)
// Cmd represents an external command being prepared or run.
type Cmd struct {
	// Stdin is the command's stdin. If nil, [os.Stdin] is used.
	Stdin io.Reader
	// Stdout is the command's stdout. If nil, [os.Stdout] is used.
	Stdout io.Writer
	// Stderr is the command's stderr. If nil, [os.Stderr] is used.
	Stderr io.Writer
	// Name is the executable name, e.g. "pg_dump".
	Name string
	// Label is a human-readable label for the command, e.g. "@sakila: dump".
	// If empty, [Cmd.Name] is used.
	Label string
	// ErrPrefix is the prefix to use for error messages.
	ErrPrefix string
	// UsesOutputFile indicates that the command writes its output to this
	// filepath instead of stdout. If empty, stdout is being used.
	UsesOutputFile string
	// Args is the set of args to the command.
	Args []string
	// Env is the set of environment variables to set for the command,
	// each in "KEY=VALUE" form.
	Env []string
	// NoProgress indicates that progress messages should not be output.
	NoProgress bool
	// ProgressFromStderr indicates that the command outputs progress messages
	// on stderr.
	ProgressFromStderr bool
	// CmdDirPath controls whether the command's PATH will include the parent dir
	// of the command. This allows the command to access sibling commands in the
	// same dir, e.g. "pg_dumpall" needs to invoke "pg_dump".
	CmdDirPath bool
}
// String returns what the command would look like if executed in a shell:
// env vars, executable name, and args, each shell-escaped. Note that the
// returned string could contain sensitive information such as passwords,
// so it's not safe for logging. Instead, see [Cmd.LogValue].
func (c *Cmd) String() string {
	if c == nil {
		return ""
	}

	var sb strings.Builder
	for i, envar := range c.Env {
		if i > 0 {
			sb.WriteRune(' ')
		}
		sb.WriteString(stringz.ShellEscape(envar))
	}
	if sb.Len() > 0 {
		sb.WriteRune(' ')
	}

	sb.WriteString(stringz.ShellEscape(c.Name))
	for _, arg := range c.Args {
		sb.WriteRune(' ')
		sb.WriteString(stringz.ShellEscape(arg))
	}
	return sb.String()
}
// redactedCmd returns a redacted rendering of c, suitable for logging (but
// not execution). If escape is true, the string is also shell-escaped.
func (c *Cmd) redactedCmd(escape bool) string {
	if c == nil {
		return ""
	}

	// Assemble env, name, and args into a single space-joined string.
	// Empty env/args slices simply contribute nothing.
	parts := c.redactedEnv(escape)
	parts = append(parts, c.Name)
	parts = append(parts, c.redactedArgs(escape)...)
	return strings.Join(parts, " ")
}
// redactedEnv returns c's env with sensitive values (e.g. passwords in
// connection URLs) redacted. If escape is true, the values are also
// shell-escaped. A non-nil (possibly empty) slice is always returned.
func (c *Cmd) redactedEnv(escape bool) []string {
	if c == nil || len(c.Env) == 0 {
		return []string{}
	}

	envars := make([]string, len(c.Env))
	for i := range c.Env {
		key, val, ok := strings.Cut(c.Env[i], "=")
		if !ok {
			// No "=" separator. Shouldn't happen for well-formed envars,
			// but just in case, pass the value through untouched.
			if escape {
				envars[i] = stringz.ShellEscape(c.Env[i])
			} else {
				envars[i] = c.Env[i]
			}
			continue
		}

		// If the envar value is a SQL or HTTP location, redact it.
		if location.TypeOf(val).IsURL() {
			val = location.Redact(val)
		}

		if escape {
			envars[i] = key + "=" + stringz.ShellEscape(val)
		} else {
			envars[i] = key + "=" + val
		}
	}
	return envars
}
// redactedArgs returns c's args with sensitive (URL-like) values redacted.
// If escape is true, the values are also shell-escaped. A non-nil
// (possibly empty) slice is always returned.
func (c *Cmd) redactedArgs(escape bool) []string {
	if c == nil || len(c.Args) == 0 {
		return []string{}
	}

	args := make([]string, len(c.Args))
	for i, arg := range c.Args {
		switch {
		case location.TypeOf(arg).IsURL():
			// URL-like arg: redact, then optionally escape.
			arg = location.Redact(arg)
			if escape {
				arg = stringz.ShellEscape(arg)
			}
		case escape:
			arg = stringz.ShellEscape(arg)
		}
		args[i] = arg
	}
	return args
}
var _ slog.LogValuer = (*Cmd)(nil)

// LogValue implements [slog.LogValuer]. It redacts sensitive information
// (passwords etc.) from URL-like values.
func (c *Cmd) LogValue() slog.Value {
	if c == nil {
		return slog.Value{}
	}

	return slog.GroupValue(
		slog.String("name", c.Name),
		slog.String("exec", c.redactedCmd(false)),
	)
}
// Exec executes cmd, blocking until the command completes. Nil streams on
// cmd default to the process's standard streams. If a progress.Progress is
// on ctx and cmd.NoProgress is false, a filesize-based progress bar is
// rendered. If cmd.UsesOutputFile is set, that file is removed on error.
func Exec(ctx context.Context, cmd *Cmd) (err error) {
	log := lg.FromContext(ctx)
	defer func() {
		if err != nil && cmd.UsesOutputFile != "" {
			// If an error occurred, we want to remove the (possibly
			// partial) output file.
			lg.WarnIfError(log, lgm.RemoveFile, os.Remove(cmd.UsesOutputFile))
		}
	}()

	// Default any unset streams to the standard ones.
	if cmd.Stdin == nil {
		cmd.Stdin = os.Stdin
	}
	if cmd.Stdout == nil {
		cmd.Stdout = os.Stdout
	}
	if cmd.Stderr == nil {
		cmd.Stderr = os.Stderr
	}

	execCmd := exec.CommandContext(ctx, cmd.Name, cmd.Args...) //nolint:gosec
	if cmd.CmdDirPath {
		// Let the command invoke sibling commands in its own dir,
		// e.g. "pg_dumpall" needs to invoke "pg_dump".
		execCmd.Env = append(execCmd.Env, "PATH="+filepath.Dir(execCmd.Path))
	}
	execCmd.Env = append(execCmd.Env, cmd.Env...)
	execCmd.Stdin = cmd.Stdin
	execCmd.Stdout = cmd.Stdout

	// Capture stderr for error reporting, while still relaying it to
	// cmd.Stderr.
	stderrBuf := &bytes.Buffer{}
	execCmd.Stderr = io.MultiWriter(stderrBuf, cmd.Stderr)

	switch {
	case cmd.ProgressFromStderr:
		log.Debug("exec: command reports progress on stderr")
		// TODO: We really want to print stderr.
	case cmd.UsesOutputFile != "":
		log.Debug("exec: command writes output to file")
		// Truncate the file, ignoring any error (e.g. if it doesn't exist).
		_ = os.Truncate(cmd.UsesOutputFile, 0)
		if !cmd.NoProgress {
			// Render progress by polling the output file's size.
			bar := progress.FromContext(ctx).NewFilesizeCounter(
				loz.NonEmptyOf(cmd.Label, cmd.Name),
				nil,
				cmd.UsesOutputFile,
				progress.OptTimer,
			)
			defer bar.Stop()
		}
	default:
		log.Debug("exec: command writes output to stdout")
		// We're reduced to reading the size of stdout, but not if we're on a
		// terminal. If we are on a terminal, then the user will get to see the
		// command output in real-time and we don't need a progress bar.
		if !termz.IsTerminal(os.Stdout) {
			if f, ok := cmd.Stdout.(*os.File); ok && !cmd.NoProgress {
				bar := progress.FromContext(ctx).NewFilesizeCounter(
					loz.NonEmptyOf(cmd.Label, cmd.Name),
					f,
					"",
					progress.OptTimer,
				)
				defer bar.Stop()
			}
		}
	}

	if err = execCmd.Run(); err != nil {
		return newExecError(cmd.ErrPrefix, cmd, execCmd, stderrBuf, err)
	}
	return nil
}
var _ error = (*execError)(nil)

// execError is an error that occurred during command execution. It wraps
// the underlying error from exec.Cmd.Run, and carries the command's
// captured stderr output (if any) for inclusion in the error message.
type execError struct {
	// msg is the error message prefix, typically Cmd.ErrPrefix.
	msg string
	// execErr is the underlying error from command execution.
	execErr error
	// cmd is the execz command that was run.
	cmd *Cmd
	// execCmd is the underlying stdlib command.
	execCmd *exec.Cmd
	// errOut is the command's captured stderr output; may be nil.
	errOut []byte
}
// Error returns the error message: the prefix, the underlying error, and
// (when captured) the command's stderr output with any single trailing
// newline trimmed.
func (e *execError) Error() string {
	msg := e.msg + ": " + e.execErr.Error()
	if len(e.errOut) == 0 {
		return msg
	}

	stderrText := string(e.errOut)
	stderrText = strings.TrimSuffix(stderrText, "\r\n") // windows
	stderrText = strings.TrimSuffix(stderrText, "\n")
	return msg + ": " + stderrText
}
// Unwrap returns the underlying error, allowing errors.Is and errors.As
// to inspect the chain.
func (e *execError) Unwrap() error {
	return e.execErr
}
// ExitCode returns the exit code of the command execution if the underlying
// execution error was an *exec.ExitError, otherwise -1. This implements the
// errz.ExitCoder interface.
func (e *execError) ExitCode() int {
	if ee, ok := errz.As[*exec.ExitError](e.execErr); ok {
		return ee.ExitCode()
	}
	return -1
}
// newExecError creates a new execError wrapping execErr. If stderrBuf is
// non-nil, its contents are used to populate the errOut field; otherwise
// errOut may be nil.
func newExecError(msg string, cmd *Cmd, execCmd *exec.Cmd, stderrBuf *bytes.Buffer, execErr error) *execError {
	e := &execError{
		msg:     msg,
		execErr: execErr,
		cmd:     cmd,
		execCmd: execCmd,
	}
	// TODO: We should implement special handling for Lookup errors,
	// e.g. "pg_dump" not found.
	if stderrBuf != nil {
		e.errOut = stderrBuf.Bytes()
	}
	return e
}

View File

@ -11,6 +11,7 @@ const (
CloseHTTPResponseBody = "Close HTTP response body"
CloseFileReader = "Close file reader"
CloseFileWriter = "Close file writer"
CloseOutputFile = "Close output file"
CtxDone = "Context unexpectedly done"
OpenSrc = "Open source"
ReadDBRows = "Read DB rows"

View File

@ -202,3 +202,20 @@ func RemoveUnordered[T any](a []*T, v *T) []*T {
}
return a
}
// Cond returns a if cond is true, else b. It's basically the ternary operator.
func Cond[T any](cond bool, a, b T) T {
	if !cond {
		return b
	}
	return a
}
// NonEmptyOf returns a if a is non-zero, else b.
func NonEmptyOf[T comparable](a, b T) T {
	var zero T
	if a == zero {
		return b
	}
	return a
}

View File

@ -1,6 +1,8 @@
package progress
import (
"fmt"
"os"
"time"
humanize "github.com/dustin/go-humanize"
@ -38,6 +40,40 @@ func (p *Progress) NewByteCounter(msg string, size int64, opts ...Opt) *Bar {
return p.newBar(cfg, opts)
}
// NewFilesizeCounter returns a new indeterminate bar whose label metric is a
// filesize, or "-" if it can't be read. If f is non-nil, its size is used; else
// the file at path fp is used. The caller is ultimately responsible for calling
// Bar.Stop on the returned Bar.
func (p *Progress) NewFilesizeCounter(msg string, f *os.File, fp string, opts ...Opt) *Bar {
	if p == nil {
		// Nil Progress is a no-op; the nil Bar is safe to use.
		return nil
	}
	p.mu.Lock()
	defer p.mu.Unlock()

	// total=-1 marks the bar as indeterminate (spinner style).
	cfg := &barConfig{msg: msg, total: -1, style: spinnerStyle(p.colors.Filler)}

	// The decorator stats the file on each render, so the displayed size
	// tracks the file as it grows.
	d := decor.Any(func(statistics decor.Statistics) string {
		var fi os.FileInfo
		var err error
		if f != nil {
			fi, err = f.Stat()
		} else {
			fi, err = os.Stat(fp)
		}
		if err != nil {
			// Size unavailable; render a placeholder.
			return "-"
		}
		return fmt.Sprintf("% .1f", decor.SizeB1024(fi.Size()))
	})
	cfg.decorators = []decor.Decorator{colorize(d, p.colors.Size)}
	return p.newBar(cfg, opts)
}
// NewUnitCounter returns a new indeterminate bar whose label
// metric is the plural of the provided unit. The caller is ultimately
// responsible for calling Bar.Stop on the returned Bar.

View File

@ -117,6 +117,17 @@ func newElapsedSeconds(c *color.Color, startTime time.Time, wcc ...decor.WC) dec
return decor.Any(fn, wcc...)
}
// OptTimer is an Opt that causes the bar to display elapsed seconds.
var OptTimer = optElapsedSeconds{}

var _ Opt = optElapsedSeconds{}

// optElapsedSeconds is the Opt implementation behind OptTimer.
type optElapsedSeconds struct{}

// apply appends an elapsed-seconds decorator (starting from now) to the
// bar's config.
func (optElapsedSeconds) apply(p *Progress, cfg *barConfig) {
	cfg.decorators = append(cfg.decorators, newElapsedSeconds(p.colors.Size, time.Now(), decor.WCSyncSpace))
}
// OptMemUsage is an Opt that causes the bar to display program
// memory usage.
var OptMemUsage = optMemUsage{}

View File

@ -6,6 +6,8 @@ import (
"database/sql"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lgm"
)
// Execer abstracts the ExecContext method
@ -78,3 +80,36 @@ func RequireSingleConn(db DB) error {
return nil
}
// RowsScanColumn scans a single-column [*sql.Rows] into a slice of T. If the
// returned value could be null, use a nullable type, e.g. [sql.NullString].
// Arg rows is always closed. On any error, the returned slice is nil.
func RowsScanColumn[T any](ctx context.Context, rows *sql.Rows) (vals []T, err error) {
	defer func() {
		// Close rows unless the happy path below already closed it
		// (signaled by rows having been set to nil).
		if rows != nil {
			lg.WarnIfCloseError(lg.FromContext(ctx), lgm.CloseDBRows, rows)
		}
	}()

	// We want to return an empty slice rather than nil.
	vals = make([]T, 0)
	for rows.Next() {
		var val T
		if err = rows.Scan(&val); err != nil {
			return nil, errz.Err(err)
		}
		vals = append(vals, val)
	}

	// Check for errors encountered during iteration.
	if err = rows.Err(); err != nil {
		return nil, errz.Err(err)
	}

	// Close explicitly so a close error can be returned; setting rows to
	// nil prevents the deferred double-close above.
	err = rows.Close()
	rows = nil
	if err != nil {
		return nil, errz.Err(err)
	}

	return vals, nil
}

View File

@ -1,6 +1,7 @@
//go:build !windows
package cli
// Package termz contains a handful of terminal utilities.
package termz
import (
"io"
@ -9,8 +10,8 @@ import (
"golang.org/x/term"
)
// isTerminal returns true if w is a terminal.
func isTerminal(w io.Writer) bool {
// IsTerminal returns true if w is a terminal.
func IsTerminal(w io.Writer) bool {
switch v := w.(type) {
case *os.File:
return term.IsTerminal(int(v.Fd()))
@ -19,12 +20,12 @@ func isTerminal(w io.Writer) bool {
}
}
// isColorTerminal returns true if w is a colorable terminal.
// IsColorTerminal returns true if w is a colorable terminal.
// It respects [NO_COLOR], [FORCE_COLOR] and TERM=dumb environment variables.
//
// [NO_COLOR]: https://no-color.org/
// [FORCE_COLOR]: https://force-color.org/
func isColorTerminal(w io.Writer) bool {
func IsColorTerminal(w io.Writer) bool {
if os.Getenv("NO_COLOR") != "" {
return false
}

View File

@ -1,4 +1,4 @@
package cli
package termz
import (
"io"
@ -8,8 +8,8 @@ import (
"golang.org/x/term"
)
// isTerminal returns true if w is a terminal.
func isTerminal(w io.Writer) bool {
// IsTerminal returns true if w is a terminal.
func IsTerminal(w io.Writer) bool {
switch v := w.(type) {
case *os.File:
return term.IsTerminal(int(v.Fd()))
@ -18,7 +18,7 @@ func isTerminal(w io.Writer) bool {
}
}
// isColorTerminal returns true if w is a colorable terminal.
// IsColorTerminal returns true if w is a colorable terminal.
// It respects [NO_COLOR], [FORCE_COLOR] and TERM=dumb environment variables.
//
// Acknowledgement: This function is lifted from neilotoole/jsoncolor, but
@ -27,7 +27,7 @@ func isTerminal(w io.Writer) bool {
//
// [NO_COLOR]: https://no-color.org/
// [FORCE_COLOR]: https://force-color.org/
func isColorTerminal(w io.Writer) bool {
func IsColorTerminal(w io.Writer) bool {
if os.Getenv("NO_COLOR") != "" {
return false
}

View File

@ -150,6 +150,14 @@ type SQLDriver interface {
// DropSchema drops the named schema in db.
DropSchema(ctx context.Context, db sqlz.DB, schemaName string) error
// CatalogExists returns true if db can reference the named catalog. If
// catalog is empty string, false is returned.
CatalogExists(ctx context.Context, db sqlz.DB, catalog string) (bool, error)
// SchemaExists returns true if db can reference the named schema. If
// schma is empty string, false is returned.
SchemaExists(ctx context.Context, db sqlz.DB, schma string) (bool, error)
// Truncate truncates tbl in src. If arg reset is true, the
// identity counter for tbl should be reset, if supported
// by the driver. Some DB impls may reset the identity
@ -159,6 +167,11 @@ type SQLDriver interface {
// TableExists returns true if there's an existing table tbl in db.
TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error)
// ListTableNames lists the tables of schma in db. The "tables" and "views"
// args filter TABLE and VIEW types, respectively. If both are false, an empty
// slice is returned. If schma is empty, the current schema is used.
ListTableNames(ctx context.Context, db sqlz.DB, schma string, tables, views bool) ([]string, error)
// CopyTable copies fromTable into a new table toTable.
// If copyData is true, fromTable's data is also copied.
// Constraints (keys, defaults etc.) may not be copied. The

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"testing"
"github.com/stretchr/testify/assert"
@ -430,7 +431,7 @@ func TestRegistry_DriversMetadata_Doc(t *testing.T) {
}
}
func TestDatabase_TableMetadata(t *testing.T) { //nolint:tparallel
func TestGrip_TableMetadata(t *testing.T) { //nolint:tparallel
for _, handle := range sakila.SQLAll() {
handle := handle
@ -447,7 +448,7 @@ func TestDatabase_TableMetadata(t *testing.T) { //nolint:tparallel
}
}
func TestDatabase_SourceMetadata(t *testing.T) {
func TestGrip_SourceMetadata(t *testing.T) {
t.Parallel()
for _, handle := range sakila.SQLAll() {
@ -466,9 +467,99 @@ func TestDatabase_SourceMetadata(t *testing.T) {
}
}
// TestDatabase_SourceMetadata_concurrent tests the behavior of the
// TestSQLDriver_ListTableNames_ArgSchemaEmpty tests [driver.SQLDriver.ListTableNames]
// with an empty schema arg, in which case the current schema is used.
func TestSQLDriver_ListTableNames_ArgSchemaEmpty(t *testing.T) { //nolint:tparallel
	for _, handle := range sakila.SQLLatest() {
		handle := handle
		t.Run(handle, func(t *testing.T) {
			t.Parallel()

			th, _, drvr, _, db := testh.NewWith(t, handle)

			// tables=false, views=false: expect an empty (but non-nil) slice.
			got, err := drvr.ListTableNames(th.Context, db, "", false, false)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.True(t, len(got) == 0)

			// tables only: actor table present, film_list view absent.
			got, err = drvr.ListTableNames(th.Context, db, "", true, false)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.Contains(t, got, sakila.TblActor)
			require.NotContains(t, got, sakila.ViewFilmList)

			// views only: film_list view present, actor table absent.
			got, err = drvr.ListTableNames(th.Context, db, "", false, true)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.NotContains(t, got, sakila.TblActor)
			require.Contains(t, got, sakila.ViewFilmList)

			// tables and views: both present.
			got, err = drvr.ListTableNames(th.Context, db, "", true, true)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.Contains(t, got, sakila.TblActor)
			require.Contains(t, got, sakila.ViewFilmList)

			// The contract requires sorted results.
			gotCopy := append([]string(nil), got...)
			slices.Sort(gotCopy)
			require.Equal(t, got, gotCopy, "expected results to be sorted")
		})
	}
}
// TestSQLDriver_ListTableNames_ArgSchemaNotEmpty tests
// [driver.SQLDriver.ListTableNames] with a non-empty schema arg, verifying
// the expected table/view counts per driver.
func TestSQLDriver_ListTableNames_ArgSchemaNotEmpty(t *testing.T) { //nolint:tparallel
	testCases := []struct {
		handle     string
		schema     string
		wantTables int
		wantViews  int
	}{
		{handle: sakila.Pg12, schema: "public", wantTables: 25, wantViews: 5},
		{handle: sakila.MS19, schema: "dbo", wantTables: 17, wantViews: 5},
		{handle: sakila.SL3, schema: "main", wantTables: 16, wantViews: 5},
		{handle: sakila.My8, schema: "sakila", wantTables: 16, wantViews: 7},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.handle, func(t *testing.T) {
			t.Parallel()

			th, _, drvr, _, db := testh.NewWith(t, tc.handle)

			// tables=false, views=false: expect an empty (but non-nil) slice.
			got, err := drvr.ListTableNames(th.Context, db, tc.schema, false, false)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.True(t, len(got) == 0)

			// tables only.
			got, err = drvr.ListTableNames(th.Context, db, tc.schema, true, false)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.Len(t, got, tc.wantTables)

			// views only.
			got, err = drvr.ListTableNames(th.Context, db, tc.schema, false, true)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.Len(t, got, tc.wantViews)

			// tables and views combined.
			got, err = drvr.ListTableNames(th.Context, db, tc.schema, true, true)
			require.NoError(t, err)
			require.NotNil(t, got)
			require.Len(t, got, tc.wantTables+tc.wantViews)

			// The contract requires sorted results.
			gotCopy := append([]string(nil), got...)
			slices.Sort(gotCopy)
			require.Equal(t, got, gotCopy, "expected results to be sorted")
		})
	}
}
// TestGrip_SourceMetadata_concurrent tests the behavior of the
// drivers when SourceMetadata is invoked concurrently.
func TestDatabase_SourceMetadata_concurrent(t *testing.T) { //nolint:tparallel
func TestGrip_SourceMetadata_concurrent(t *testing.T) { //nolint:tparallel
const concurrency = 5
handles := sakila.SQLLatest()
@ -648,6 +739,89 @@ func TestSQLDriver_CurrentSchemaCatalog(t *testing.T) {
}
}
// TestSQLDriver_SchemaExists tests [driver.SQLDriver.SchemaExists] across
// drivers, including the empty-string and non-existent schema cases.
func TestSQLDriver_SchemaExists(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		handle string
		schema string
		wantOK bool
	}{
		{handle: sakila.SL3, schema: "main", wantOK: true},
		{handle: sakila.SL3, schema: "", wantOK: false},
		{handle: sakila.SL3, schema: "not_exist", wantOK: false},
		{handle: sakila.Pg, schema: "public", wantOK: true},
		{handle: sakila.Pg, schema: "information_schema", wantOK: true},
		{handle: sakila.Pg, schema: "not_exist", wantOK: false},
		{handle: sakila.Pg, schema: "", wantOK: false},
		{handle: sakila.My, schema: "sakila", wantOK: true},
		{handle: sakila.My, schema: "", wantOK: false},
		{handle: sakila.My, schema: "not_exist", wantOK: false},
		{handle: sakila.MS, schema: "dbo", wantOK: true},
		{handle: sakila.MS, schema: "sys", wantOK: true},
		{handle: sakila.MS, schema: "INFORMATION_SCHEMA", wantOK: true},
		{handle: sakila.MS, schema: "", wantOK: false},
		{handle: sakila.MS, schema: "not_exist", wantOK: false},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tu.Name(tc.handle, tc.schema, tc.wantOK), func(t *testing.T) {
			t.Parallel()

			th, _, drvr, _, db := testh.NewWith(t, tc.handle)
			ok, err := drvr.SchemaExists(th.Context, db, tc.schema)
			require.NoError(t, err)
			require.Equal(t, tc.wantOK, ok)
		})
	}
}
// TestSQLDriver_CatalogExists tests [driver.SQLDriver.CatalogExists] across
// drivers. Note that SQLite doesn't support catalogs, so those cases expect
// an error.
func TestSQLDriver_CatalogExists(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		handle  string
		catalog string
		wantOK  bool
		wantErr bool
	}{
		{handle: sakila.SL3, catalog: "default", wantErr: true},
		{handle: sakila.SL3, catalog: "not_exist", wantErr: true},
		{handle: sakila.SL3, catalog: "", wantErr: true},
		{handle: sakila.Pg, catalog: "sakila", wantOK: true},
		{handle: sakila.Pg, catalog: "postgres", wantOK: true},
		{handle: sakila.Pg, catalog: "not_exist", wantOK: false},
		{handle: sakila.Pg, catalog: "", wantOK: false},
		{handle: sakila.My, catalog: "def", wantOK: true},
		{handle: sakila.My, catalog: "not_exist", wantOK: false},
		{handle: sakila.My, catalog: "", wantOK: false},
		{handle: sakila.MS, catalog: "sakila", wantOK: true},
		{handle: sakila.MS, catalog: "model", wantOK: true},
		{handle: sakila.MS, catalog: "not_exist", wantOK: false},
		{handle: sakila.MS, catalog: "", wantOK: false},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tu.Name(tc.handle, tc.catalog, tc.wantOK), func(t *testing.T) {
			t.Parallel()

			th, _, drvr, _, db := testh.NewWith(t, tc.handle)
			ok, err := drvr.CatalogExists(th.Context, db, tc.catalog)
			// Check ok first: on the error path, ok should be false anyway.
			require.Equal(t, tc.wantOK, ok)
			if tc.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
func TestDriverCreateDropSchema(t *testing.T) {
testCases := []struct {
handle string

View File

@ -358,6 +358,11 @@ func TypeOf(loc string) Type {
return TypeFile
}
// IsURL returns true if t is TypeHTTP or TypeSQL.
func (t Type) IsURL() bool {
	switch t {
	case TypeHTTP, TypeSQL:
		return true
	default:
		return false
	}
}
// isHTTP tests if s is a well-structured HTTP or HTTPS url, and
// if so, returns the url and true.
func isHTTP(s string) (u *url.URL, ok bool) {

View File

@ -20,6 +20,10 @@ func main() {
defer func() {
cancelFn()
if err != nil {
if code := errz.ExitCode(err); code > 0 {
os.Exit(code)
}
os.Exit(1)
}
}()