sq/cli/cmd_ping.go

220 lines
5.7 KiB
Go
Raw Normal View History

2020-08-06 20:58:47 +03:00
package cli
import (
"context"
"errors"
2020-08-06 20:58:47 +03:00
"time"
"github.com/samber/lo"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/lg"
2020-08-06 20:58:47 +03:00
"github.com/spf13/cobra"
2020-08-06 20:58:47 +03:00
"github.com/neilotoole/sq/cli/output"
"github.com/neilotoole/sq/libsq/core/errz"
2020-08-06 20:58:47 +03:00
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/source"
)
func newPingCmd() *cobra.Command {
2020-08-06 20:58:47 +03:00
cmd := &cobra.Command{
Use: "ping [@HANDLE|GROUP]*",
RunE: execPing,
ValidArgsFunction: completeHandleOrGroup,
Short: "Ping data sources",
Long: `Ping data sources (or groups of sources) to check connection health.
If no arguments provided, the active data source is pinged. Otherwise, ping
the specified sources or groups.
The exit code is 1 if ping fails for any of the sources.`,
Example: ` # Ping active data source.
$ sq ping
2020-08-06 20:58:47 +03:00
# Ping @my1 and @pg1.
$ sq ping @my1 @pg1
2020-08-06 20:58:47 +03:00
# Ping sources in the root group (i.e. all sources).
$ sq ping /
# Ping sources in the "prod" and "staging" groups.
$ sq ping prod staging
2020-08-06 20:58:47 +03:00
# Ping @my1 with 2s timeout.
$ sq ping @my1 --timeout 2s
# Output in TSV format.
$ sq ping --tsv @my1`,
2020-08-06 20:58:47 +03:00
}
cmd.Flags().BoolP(flagTable, flagTableShort, false, flagTableUsage)
cmd.Flags().BoolP(flagCSV, flagCSVShort, false, flagCSVUsage)
cmd.Flags().BoolP(flagTSV, flagTSVShort, false, flagTSVUsage)
cmd.Flags().BoolP(flagJSON, flagJSONShort, false, flagJSONUsage)
cmd.Flags().Duration(flagPingTimeout, time.Second*10, flagPingTimeoutUsage)
return cmd
2020-08-06 20:58:47 +03:00
}
func execPing(cmd *cobra.Command, args []string) error {
rc := RunContextFrom(cmd.Context())
cfg, ss := rc.Config, rc.Config.Sources
2020-08-06 20:58:47 +03:00
var srcs []*source.Source
// args can be:
// [empty] : ping active source
// @handle1 @handleN: ping multiple sources
// @handle1 group1: ping sources, or those in groups.
args = lo.Uniq(args)
if len(args) == 0 {
2020-08-06 20:58:47 +03:00
src := cfg.Sources.Active()
if src == nil {
return errz.New(msgNoActiveSrc)
}
srcs = []*source.Source{src}
} else {
for _, arg := range args {
switch {
case source.IsValidHandle(arg):
src, err := ss.Get(arg)
if err != nil {
return err
}
srcs = append(srcs, src)
case source.IsValidGroup(arg):
groupSrcs, err := ss.SourcesInGroup(arg)
if err != nil {
return err
}
srcs = append(srcs, groupSrcs...)
default:
return errz.Errorf("invalid arg: %s", arg)
2020-08-06 20:58:47 +03:00
}
}
}
srcs = lo.Uniq(srcs)
timeout := cfg.Defaults.PingTimeout
if cmdFlagChanged(cmd, flagPingTimeout) {
timeout, _ = cmd.Flags().GetDuration(flagPingTimeout)
2020-08-06 20:58:47 +03:00
}
rc.Log.Debug("Using ping timeout", lga.Val, timeout)
2020-08-06 20:58:47 +03:00
err := pingSources(cmd.Context(), rc.registry, srcs, rc.writers.pingw, timeout)
if errors.Is(err, context.Canceled) {
// It's common to cancel "sq ping". We don't want to print the cancel message.
return errNoMsg
}
return err
2020-08-06 20:58:47 +03:00
}
// pingSources pings each of the sources in srcs, and prints results
// to w. If any error occurs pinging any of srcs, that error is printed
// inline as part of the ping results, and an errNoMsg is returned.
//
// NOTE: This ping code has an ancient lineage, in that it was
// originally laid down before context.Context was a thing. Thus,
// the entire thing could probably be rewritten for simplicity.
func pingSources(ctx context.Context, dp driver.Provider, srcs []*source.Source,
w output.PingWriter, timeout time.Duration,
) error {
if err := w.Open(srcs); err != nil {
return err
}
log := lg.FromContext(ctx)
defer lg.WarnIfFuncError(log, "Close ping writer", w.Close)
2020-08-06 20:58:47 +03:00
resultCh := make(chan pingResult, len(srcs))
// pingErrExists is set to true if there was an error for
// any of the pings. This later determines if an error
// is returned from this func.
var pingErrExists bool
for _, src := range srcs {
go pingSource(ctx, dp, src, timeout, resultCh)
}
// This func doesn't check for context.Canceled itself; instead
// it checks if any of the goroutines return that value on
// resultCh.
for i := 0; i < len(srcs); i++ {
result := <-resultCh
switch {
case errors.Is(result.err, context.Canceled):
2020-08-06 20:58:47 +03:00
// If any one of the goroutines have received context.Canceled,
// then we'll bubble that up and ignore the remaining goroutines.
return context.Canceled
case errors.Is(result.err, context.DeadlineExceeded):
2020-08-06 20:58:47 +03:00
// If timeout occurred, set the duration to timeout.
result.duration = timeout
pingErrExists = true
case result.err != nil:
pingErrExists = true
}
err := w.Result(result.src, result.duration, result.err)
lg.WarnIfError(log, "Print ping result", err)
2020-08-06 20:58:47 +03:00
}
// If there's at least one error, we return the
// sentinel errNoMsg so that sq can os.Exit(1) without printing
// an additional error message (as the error message will already have
// been printed by PingWriter).
if pingErrExists {
return errNoMsg
}
return nil
}
// pingSource pings an individual driver.Source. It always returns a
// result on resultCh, even when ctx is done.
func pingSource(ctx context.Context, dp driver.Provider, src *source.Source, timeout time.Duration,
resultCh chan<- pingResult,
) {
2020-08-06 20:58:47 +03:00
drvr, err := dp.DriverFor(src.Type)
if err != nil {
resultCh <- pingResult{src: src, err: err}
return
}
if timeout > 0 {
var cancelFn context.CancelFunc
ctx, cancelFn = context.WithTimeout(ctx, timeout)
defer cancelFn()
}
doneCh := make(chan pingResult)
start := time.Now()
go func() {
err = drvr.Ping(ctx, src)
doneCh <- pingResult{src: src, duration: time.Since(start), err: err}
}()
select {
case <-ctx.Done():
resultCh <- pingResult{src: src, err: ctx.Err()}
case result := <-doneCh:
resultCh <- result
}
}
// pingResult is the outcome of pinging a single source, as delivered
// by pingSource on its result channel.
type pingResult struct {
	// src is the source that was pinged.
	src *source.Source

	// duration is how long the ping took. pingSource leaves it zero when
	// no ping completed (the driver couldn't be obtained, or ctx was done
	// first).
	duration time.Duration

	// err is the ping failure, or nil if the ping succeeded.
	err error
}