From 6c25f0cf4b87224a6720f5a6ef2829bfbea2e2a2 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Wed, 17 Aug 2022 12:35:14 +0530 Subject: [PATCH] Use cobra's builtin machinery for creating new types --- tools/cli/infrastructure.go | 65 ++++++++++++++++++------------------- tools/cmd/at/main.go | 7 ++-- tools/cmd/main.go | 3 +- 3 files changed, 36 insertions(+), 39 deletions(-) diff --git a/tools/cli/infrastructure.go b/tools/cli/infrastructure.go index d705187f4..a72d25a98 100644 --- a/tools/cli/infrastructure.go +++ b/tools/cli/infrastructure.go @@ -30,19 +30,6 @@ func GetTTYSize() (*unix.Winsize, error) { return nil, fmt.Errorf("STDOUT is not a TTY") } -func add_choices(cmd *cobra.Command, flags *pflag.FlagSet, choices []string, name string, usage string) *string { - cmd.Annotations["choices-"+name] = strings.Join(choices, "\000") - return flags.String(name, choices[0], usage) -} - -func Choices(cmd *cobra.Command, name string, usage string, choices ...string) *string { - return add_choices(cmd, cmd.Flags(), choices, name, usage) -} - -func PersistentChoices(cmd *cobra.Command, name string, usage string, choices ...string) *string { - return add_choices(cmd, cmd.PersistentFlags(), choices, name, usage) -} - func key_in_slice(vals []string, key string) bool { for _, q := range vals { if q == key { @@ -52,18 +39,33 @@ func key_in_slice(vals []string, key string) bool { return false } -func ValidateChoices(cmd *cobra.Command, args []string) error { - for key, val := range cmd.Annotations { - if strings.HasPrefix(key, "choices-") { - allowed := strings.Split(val, "\000") - name := key[len("choices-"):] - if cval, err := cmd.Flags().GetString(name); err == nil && !key_in_slice(allowed, cval) { - return fmt.Errorf("%s: Invalid value: %s. Allowed values are: %s", color.YellowString("--"+name), color.RedString(cval), strings.Join(allowed, ", ")) - } - } - } +type ChoicesVal struct { + name, Choice string + allowed []string +} +type choicesVal ChoicesVal + +func (i *choicesVal) String() string { return ChoicesVal(*i).Choice } +func (i *choicesVal) Type() string { return "choices" } +func (i *choicesVal) Set(s string) error { + (*i).Choice = s return nil } +func newChoicesVal(val ChoicesVal, p *ChoicesVal) *choicesVal { + *p = val + return (*choicesVal)(p) +} + +func add_choices(flags *pflag.FlagSet, p *ChoicesVal, choices []string, name string, usage string) { + value := ChoicesVal{Choice: choices[0], allowed: choices} + flags.VarP(newChoicesVal(value, p), name, "", usage) +} + +func Choices(flags *pflag.FlagSet, name string, usage string, choices ...string) *ChoicesVal { + p := new(ChoicesVal) + add_choices(flags, p, choices, name, usage) + return p +} var stdout_is_terminal = false var title_fmt = color.New(color.FgBlue, color.Bold).SprintFunc() @@ -356,24 +358,15 @@ func CreateCommand(cmd *cobra.Command) *cobra.Command { cmd.RunE = func(cmd *cobra.Command, args []string) error { if len(cmd.Commands()) > 0 { if len(args) == 0 { - return fmt.Errorf("%s. Use %s -h to get a list of available sub-commands", err_fmt("No sub-command specified"), full_command_name(cmd)) + return fmt.Errorf("%s. Use %s -h to get a list of available sub-commands", "No sub-command specified", full_command_name(cmd)) } - return fmt.Errorf("Not a valid subcommand: %s. Use %s -h to get a list of available sub-commands", err_fmt(args[0]), full_command_name(cmd)) + return fmt.Errorf("Not a valid subcommand: %s. Use %s -h to get a list of available sub-commands", args[0], full_command_name(cmd)) } return nil } } cmd.SilenceErrors = true cmd.SilenceUsage = true - orig_pre_run := cmd.PersistentPreRunE - cmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { - err := ValidateChoices(cmd, args) - if err != nil || orig_pre_run == nil { - return err - } - return orig_pre_run(cmd, args) - } - cmd.PersistentFlags().SortFlags = false cmd.Flags().SortFlags = false return cmd @@ -386,6 +379,10 @@ func show_help(cmd *cobra.Command, args []string) { show_usage(cmd) } +func PrintError(err error) { + fmt.Println(err_fmt("Error")+":", err) +} + func Init(root *cobra.Command) { vs := kitty.VersionString if kitty.VCSRevision != "" { diff --git a/tools/cmd/at/main.go b/tools/cmd/at/main.go index 70848e4e3..c07113ac2 100644 --- a/tools/cmd/at/main.go +++ b/tools/cmd/at/main.go @@ -152,7 +152,8 @@ var command_objects map[string]*cobra.Command = make(map[string]*cobra.Command) func EntryPoint(tool_root *cobra.Command) *cobra.Command { var at_root_command *cobra.Command - var to, password, password_file, password_env, use_password *string + var to, password, password_file, password_env *string + var use_password *cli.ChoicesVal at_root_command = cli.CreateCommand(&cobra.Command{ Use: "@ [global options] command [command options] [command args]", Short: "Control kitty remotely", @@ -163,7 +164,7 @@ func EntryPoint(tool_root *cobra.Command) *cobra.Command { global_options.to_address_is_from_env_var = true } global_options.to_address = *to - q, err := get_password(*password, *password_file, *password_env, *use_password) + q, err := get_password(*password, *password_file, *password_env, use_password.Choice) global_options.password = q return err }, @@ -192,7 +193,7 @@ func EntryPoint(tool_root *cobra.Command) *cobra.Command { "The name of an environment variable to read the password from."+ " Used if no :option:`--password-file` or :option:`--password` is supplied.") - use_password = cli.PersistentChoices(at_root_command, "use-password", "If no password is available, kitty will usually just send the remote control command without a password. This option can be used to force it to always or never use the supplied password.", "if-available", "always", "never") + use_password = cli.Choices(at_root_command.PersistentFlags(), "use-password", "If no password is available, kitty will usually just send the remote control command without a password. This option can be used to force it to always or never use the supplied password.", "if-available", "always", "never") for cmd_name, reg_func := range all_commands { c := reg_func(at_root_command) diff --git a/tools/cmd/main.go b/tools/cmd/main.go index f09e0a016..8ee111c2b 100644 --- a/tools/cmd/main.go +++ b/tools/cmd/main.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "os" "github.com/spf13/cobra" @@ -19,7 +18,7 @@ func main() { cli.Init(root) if err := root.Execute(); err != nil { - fmt.Fprintln(os.Stderr, "Error:", err) + cli.PrintError(err) os.Exit(1) } }