package cli

import (
	"context"
	"errors"
	"time"

	"github.com/neilotoole/sq/libsq/core/lg/lga"

	"github.com/neilotoole/sq/libsq/core/lg"

	"github.com/spf13/cobra"

	"github.com/neilotoole/sq/cli/output"
	"github.com/neilotoole/sq/libsq/core/errz"
	"github.com/neilotoole/sq/libsq/driver"
	"github.com/neilotoole/sq/libsq/source"
)

// newPingCmd constructs the "ping" command. The command accepts zero or
// more source handles as args, and registers the output format flags
// (table, CSV, TSV, JSON) along with the ping timeout and "all" flags.
// Execution is delegated to execPing.
func newPingCmd() *cobra.Command {
	const (
		longHelp = `Ping data sources to check connection health. If no arguments provided, the
active data source is pinged. Provide the handles of one or more sources
to ping those sources, or --all to ping all sources.

The exit code is 1 if ping fails for any of the sources.`

		exampleHelp = `  # ping active data source
  $ sq ping

  # ping all data sources
  $ sq ping --all

  # ping @my1 and @pg1
  $ sq ping @my1 @pg1

  # ping @my1 with 2s timeout
  $ sq ping @my1 --timeout=2s

  # output in TSV format
  $ sq ping --tsv @my1`
	)

	cmd := &cobra.Command{
		Use:               "ping [@HANDLE [@HANDLE_N]]",
		Short:             "Ping data sources",
		Long:              longHelp,
		Example:           exampleHelp,
		RunE:              execPing,
		ValidArgsFunction: completeHandle(0),
	}

	flags := cmd.Flags()

	// Output format flags.
	flags.BoolP(flagTable, flagTableShort, false, flagTableUsage)
	flags.BoolP(flagCSV, flagCSVShort, false, flagCSVUsage)
	flags.BoolP(flagTSV, flagTSVShort, false, flagTSVUsage)
	flags.BoolP(flagJSON, flagJSONShort, false, flagJSONUsage)

	// Ping behavior flags.
	flags.Duration(flagPingTimeout, 10*time.Second, flagPingTimeoutUsage)
	flags.BoolP(flagPingAll, flagPingAllShort, false, flagPingAllUsage)

	return cmd
}

// execPing is the RunE func for the ping command. It resolves the set
// of sources to ping, determines the ping timeout, and then delegates
// to pingSources.
//
// The sources are selected as follows:
//   - if the "all" flag is set: every source in the config;
//   - else if handles are supplied as args: each named source;
//   - otherwise: the active source (an error is returned if there isn't one).
//
// If pinging is canceled, errNoMsg is returned in place of
// context.Canceled, so that no cancellation message is printed.
func execPing(cmd *cobra.Command, args []string) error {
	rc := RunContextFrom(cmd.Context())
	cfg := rc.Config

	// The flag value is consulted only if the flag was explicitly set.
	pingAll := false
	if cmd.Flags().Changed(flagPingAll) {
		// The flag is registered as a bool in newPingCmd, so the
		// error from GetBool is deliberately ignored.
		pingAll, _ = cmd.Flags().GetBool(flagPingAll)
	}

	var srcs []*source.Source

	switch {
	case pingAll:
		srcs = cfg.Sources.Items()

	case len(args) > 0:
		// One or more handles supplied: each must be a legal
		// handle that resolves to a known source.
		for _, handle := range args {
			if err := source.VerifyLegalHandle(handle); err != nil {
				return err
			}

			src, err := cfg.Sources.Get(handle)
			if err != nil {
				return err
			}

			srcs = append(srcs, src)
		}

	default:
		// No args: fall back to the active source.
		active := cfg.Sources.Active()
		if active == nil {
			return errz.New(msgNoActiveSrc)
		}
		srcs = []*source.Source{active}
	}

	// The configured default timeout applies unless the timeout
	// flag was explicitly set.
	timeout := cfg.Defaults.PingTimeout
	if cmdFlagChanged(cmd, flagPingTimeout) {
		// The flag is registered as a duration in newPingCmd, so
		// the error from GetDuration is deliberately ignored.
		timeout, _ = cmd.Flags().GetDuration(flagPingTimeout)
	}

	rc.Log.Debug("Using ping timeout", lga.Val, timeout)

	err := pingSources(cmd.Context(), rc.registry, srcs, rc.writers.pingw, timeout)
	if !errors.Is(err, context.Canceled) {
		return err
	}

	// It's common for the user to cancel "sq ping", so the
	// cancel message is suppressed.
	return errNoMsg
}

// pingSources concurrently pings every source in srcs, writing each
// result to w as it arrives. A failed ping is reported inline via w
// as part of the results; if at least one ping failed, errNoMsg is
// returned after all results have been written. If any ping reports
// context.Canceled, pingSources returns context.Canceled immediately
// without waiting for the outstanding pings.
//
// An error from opening w is returned directly. Errors from writing a
// result or from closing w are logged as warnings, not returned.
//
// NOTE: The design of this ping code predates context.Context, so
// the whole thing could probably be rewritten more simply.
func pingSources(ctx context.Context, dp driver.Provider, srcs []*source.Source,
	w output.PingWriter, timeout time.Duration,
) error {
	if openErr := w.Open(srcs); openErr != nil {
		return openErr
	}

	log := lg.FromContext(ctx)
	defer lg.WarnIfFuncError(log, "Close ping writer", w.Close)

	// The channel has capacity for one result per source, so a ping
	// goroutine can always deliver its result, even if this func
	// has already returned.
	count := len(srcs)
	results := make(chan pingResult, count)
	for i := range srcs {
		go pingSource(ctx, dp, srcs[i], timeout, results)
	}

	// failed records whether any ping yielded an error, which
	// determines the return value of this func.
	failed := false

	// Cancellation isn't detected on ctx directly here: it's
	// detected via the results reported by the ping goroutines.
	for received := 0; received < count; received++ {
		res := <-results

		if errors.Is(res.err, context.Canceled) {
			// One canceled ping is sufficient: bubble it up, and
			// disregard whatever pings are still outstanding.
			return context.Canceled
		}

		if errors.Is(res.err, context.DeadlineExceeded) {
			// The ping timed out, so report the timeout as its duration.
			res.duration = timeout
		}

		if res.err != nil {
			failed = true
		}

		printErr := w.Result(res.src, res.duration, res.err)
		lg.WarnIfError(log, "Print ping result", printErr)
	}

	if !failed {
		return nil
	}

	// The ping failures have already been printed by the PingWriter.
	// Returning the sentinel errNoMsg lets sq exit with a non-zero
	// code without printing an additional error message.
	return errNoMsg
}

// pingSource pings an individual source.Source. It always sends exactly
// one pingResult on resultCh: when the driver for src can't be
// resolved, when the ping completes (successfully or not), or when
// ctx is done before the ping completes.
//
// If timeout is greater than zero, the ping is bounded by that
// duration, in which case the result's err is context.DeadlineExceeded
// if the timeout elapses first. If the driver lookup fails or ctx is
// done first, the result's duration is left at zero.
//
// The send on resultCh blocks until the value is accepted, so the
// caller should ensure that resultCh has capacity or a receiver.
func pingSource(ctx context.Context, dp driver.Provider, src *source.Source, timeout time.Duration,
	resultCh chan<- pingResult,
) {
	drvr, err := dp.DriverFor(src.Type)
	if err != nil {
		resultCh <- pingResult{src: src, err: err}
		return
	}

	if timeout > 0 {
		var cancelFn context.CancelFunc
		ctx, cancelFn = context.WithTimeout(ctx, timeout)
		defer cancelFn()
	}

	// doneCh has a buffer of one so that the ping goroutine below can
	// always deliver its result and exit, even if ctx was done first
	// and this func is no longer receiving. With an unbuffered channel,
	// that goroutine would block on its send forever.
	doneCh := make(chan pingResult, 1)
	start := time.Now()

	go func() {
		// Use a goroutine-local error, instead of writing to a
		// variable shared with the enclosing func.
		pingErr := drvr.Ping(ctx, src)
		doneCh <- pingResult{src: src, duration: time.Since(start), err: pingErr}
	}()

	select {
	case <-ctx.Done():
		resultCh <- pingResult{src: src, err: ctx.Err()}
	case result := <-doneCh:
		resultCh <- result
	}
}

// pingResult is the outcome of pinging a single source, as sent by
// pingSource on its result channel.
//
// Field src is the source that was pinged. Field duration is the time
// taken by the driver's Ping call; it is zero if the driver couldn't
// be resolved or if the context was done before the ping completed,
// and pingSources overwrites it with the timeout value when err is
// context.DeadlineExceeded. Field err is the error from driver lookup,
// from the ping itself, or from the context; it is nil if the ping
// succeeded.
type pingResult struct {
	src      *source.Source
	duration time.Duration
	err      error
}