Use cobra's builtin machinery for creating new types

This commit is contained in:
Kovid Goyal 2022-08-17 12:35:14 +05:30
parent a0bff4abab
commit 6c25f0cf4b
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 36 additions and 39 deletions

View File

@ -30,19 +30,6 @@ func GetTTYSize() (*unix.Winsize, error) {
return nil, fmt.Errorf("STDOUT is not a TTY")
}
func add_choices(cmd *cobra.Command, flags *pflag.FlagSet, choices []string, name string, usage string) *string {
cmd.Annotations["choices-"+name] = strings.Join(choices, "\000")
return flags.String(name, choices[0], usage)
}
func Choices(cmd *cobra.Command, name string, usage string, choices ...string) *string {
return add_choices(cmd, cmd.Flags(), choices, name, usage)
}
func PersistentChoices(cmd *cobra.Command, name string, usage string, choices ...string) *string {
return add_choices(cmd, cmd.PersistentFlags(), choices, name, usage)
}
func key_in_slice(vals []string, key string) bool {
for _, q := range vals {
if q == key {
@ -52,18 +39,33 @@ func key_in_slice(vals []string, key string) bool {
return false
}
func ValidateChoices(cmd *cobra.Command, args []string) error {
for key, val := range cmd.Annotations {
if strings.HasPrefix(key, "choices-") {
allowed := strings.Split(val, "\000")
name := key[len("choices-"):]
if cval, err := cmd.Flags().GetString(name); err == nil && !key_in_slice(allowed, cval) {
return fmt.Errorf("%s: Invalid value: %s. Allowed values are: %s", color.YellowString("--"+name), color.RedString(cval), strings.Join(allowed, ", "))
}
}
}
type ChoicesVal struct {
name, Choice string
allowed []string
}
type choicesVal ChoicesVal
func (i *choicesVal) String() string { return ChoicesVal(*i).Choice }
func (i *choicesVal) Type() string { return "choices" }
func (i *choicesVal) Set(s string) error {
(*i).Choice = s
return nil
}
func newChoicesVal(val ChoicesVal, p *ChoicesVal) *choicesVal {
*p = val
return (*choicesVal)(p)
}
func add_choices(flags *pflag.FlagSet, p *ChoicesVal, choices []string, name string, usage string) {
value := ChoicesVal{Choice: choices[0], allowed: choices}
flags.VarP(newChoicesVal(value, p), name, "", usage)
}
func Choices(flags *pflag.FlagSet, name string, usage string, choices ...string) *ChoicesVal {
p := new(ChoicesVal)
add_choices(flags, p, choices, name, usage)
return p
}
var stdout_is_terminal = false
var title_fmt = color.New(color.FgBlue, color.Bold).SprintFunc()
@ -356,24 +358,15 @@ func CreateCommand(cmd *cobra.Command) *cobra.Command {
cmd.RunE = func(cmd *cobra.Command, args []string) error {
if len(cmd.Commands()) > 0 {
if len(args) == 0 {
return fmt.Errorf("%s. Use %s -h to get a list of available sub-commands", err_fmt("No sub-command specified"), full_command_name(cmd))
return fmt.Errorf("%s. Use %s -h to get a list of available sub-commands", "No sub-command specified", full_command_name(cmd))
}
return fmt.Errorf("Not a valid subcommand: %s. Use %s -h to get a list of available sub-commands", err_fmt(args[0]), full_command_name(cmd))
return fmt.Errorf("Not a valid subcommand: %s. Use %s -h to get a list of available sub-commands", args[0], full_command_name(cmd))
}
return nil
}
}
cmd.SilenceErrors = true
cmd.SilenceUsage = true
orig_pre_run := cmd.PersistentPreRunE
cmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
err := ValidateChoices(cmd, args)
if err != nil || orig_pre_run == nil {
return err
}
return orig_pre_run(cmd, args)
}
cmd.PersistentFlags().SortFlags = false
cmd.Flags().SortFlags = false
return cmd
@ -386,6 +379,10 @@ func show_help(cmd *cobra.Command, args []string) {
show_usage(cmd)
}
func PrintError(err error) {
fmt.Println(err_fmt("Error")+":", err)
}
func Init(root *cobra.Command) {
vs := kitty.VersionString
if kitty.VCSRevision != "" {

View File

@ -152,7 +152,8 @@ var command_objects map[string]*cobra.Command = make(map[string]*cobra.Command)
func EntryPoint(tool_root *cobra.Command) *cobra.Command {
var at_root_command *cobra.Command
var to, password, password_file, password_env, use_password *string
var to, password, password_file, password_env *string
var use_password *cli.ChoicesVal
at_root_command = cli.CreateCommand(&cobra.Command{
Use: "@ [global options] command [command options] [command args]",
Short: "Control kitty remotely",
@ -163,7 +164,7 @@ func EntryPoint(tool_root *cobra.Command) *cobra.Command {
global_options.to_address_is_from_env_var = true
}
global_options.to_address = *to
q, err := get_password(*password, *password_file, *password_env, *use_password)
q, err := get_password(*password, *password_file, *password_env, use_password.Choice)
global_options.password = q
return err
},
@ -192,7 +193,7 @@ func EntryPoint(tool_root *cobra.Command) *cobra.Command {
"The name of an environment variable to read the password from."+
" Used if no :option:`--password-file` or :option:`--password` is supplied.")
use_password = cli.PersistentChoices(at_root_command, "use-password", "If no password is available, kitty will usually just send the remote control command without a password. This option can be used to force it to always or never use the supplied password.", "if-available", "always", "never")
use_password = cli.Choices(at_root_command.PersistentFlags(), "use-password", "If no password is available, kitty will usually just send the remote control command without a password. This option can be used to force it to always or never use the supplied password.", "if-available", "always", "never")
for cmd_name, reg_func := range all_commands {
c := reg_func(at_root_command)

View File

@ -1,7 +1,6 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
@ -19,7 +18,7 @@ func main() {
cli.Init(root)
if err := root.Execute(); err != nil {
fmt.Fprintln(os.Stderr, "Error:", err)
cli.PrintError(err)
os.Exit(1)
}
}