Port parsing of ssh args

This commit is contained in:
Kovid Goyal 2023-02-20 17:18:43 +05:30
parent 12c8af60dc
commit 97b9572bec
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 155 additions and 0 deletions

View File

@ -11,6 +11,15 @@ import (
var _ = fmt.Print
func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
if len(args) > 0 {
switch args[0] {
case "use-python":
args = args[1:] // backwards compat from when we had a python implementation
case "-h", "--help":
cmd.ShowHelp()
return
}
}
return
}

View File

@ -80,3 +80,112 @@ func SSHOptions() map[string]string {
query_ssh_for_options_once.Do(get_ssh_options)
return ssh_options
}
func GetSSHCLI() (boolean_ssh_args *utils.Set[string], other_ssh_args *utils.Set[string]) {
other_ssh_args, boolean_ssh_args = utils.NewSet[string](32), utils.NewSet[string](32)
for k, v := range SSHOptions() {
k = "-" + k
if v == "" {
boolean_ssh_args.Add(k)
} else {
other_ssh_args.Add(k)
}
}
return
}
func is_extra_arg(arg string, extra_args []string) string {
for _, x := range extra_args {
if arg == x || strings.HasPrefix(arg, x+"=") {
return x
}
}
return ""
}
type ErrInvalidSSHArgs struct {
Msg string
}
func (self *ErrInvalidSSHArgs) Error() string {
return self.Msg
}
func ParseSSHArgs(args []string, extra_args ...string) (ssh_args []string, server_args []string, passthrough bool, found_extra_args []string, err error) {
if extra_args == nil {
extra_args = []string{}
}
if len(args) == 0 {
passthrough = true
return
}
passthrough_args := map[string]bool{"-N": true, "-n": true, "-f": true, "-G": true, "-T": true}
boolean_ssh_args, other_ssh_args := GetSSHCLI()
ssh_args, server_args, found_extra_args = make([]string, 0, 16), make([]string, 0, 16), make([]string, 0, 16)
expecting_option_val := false
stop_option_processing := false
expecting_extra_val := ""
for _, argument := range args {
if len(server_args) > 1 || stop_option_processing {
server_args = append(server_args, argument)
continue
}
if strings.HasPrefix(argument, "-") && !expecting_option_val {
if argument == "--" {
stop_option_processing = true
continue
}
if len(extra_args) > 0 {
matching_ex := is_extra_arg(argument, extra_args)
if matching_ex != "" {
_, exval, found := strings.Cut(argument, "=")
if found {
found_extra_args = append(found_extra_args, matching_ex, exval)
} else {
expecting_extra_val = matching_ex
expecting_option_val = true
}
continue
}
}
// could be a multi-character option
all_args := []rune(argument[1:])
for i, ch := range all_args {
arg := "-" + string(ch)
if passthrough_args[arg] {
passthrough = true
}
if boolean_ssh_args.Has(arg) {
ssh_args = append(ssh_args, arg)
continue
}
if other_ssh_args.Has(arg) {
ssh_args = append(ssh_args, arg)
if i+1 < len(all_args) {
ssh_args = append(ssh_args, string(all_args[i+1:]))
} else {
expecting_option_val = true
}
break
}
err = &ErrInvalidSSHArgs{Msg: "unknown option -- " + arg[1:]}
return
}
continue
}
if expecting_option_val {
if expecting_extra_val != "" {
found_extra_args = append(found_extra_args, expecting_extra_val, argument)
} else {
ssh_args = append(ssh_args, argument)
}
expecting_option_val = false
continue
}
server_args = append(server_args, argument)
}
if len(server_args) == 0 {
err = &ErrInvalidSSHArgs{Msg: "No server to connect to specified"}
}
return
}

View File

@ -5,6 +5,10 @@ package ssh
import (
"fmt"
"testing"
"kitty/tools/utils/shlex"
"github.com/google/go-cmp/cmp"
)
var _ = fmt.Print
@ -15,3 +19,36 @@ func TestGetSSHOptions(t *testing.T) {
t.Fatalf("Unexpected set of SSH options: %#v", m)
}
}
func TestParseSSHArgs(t *testing.T) {
split := func(x string) []string {
ans, err := shlex.Split(x)
if err != nil {
t.Fatal(err)
}
return ans
}
p := func(args, expected_ssh_args, expected_server_args, expected_extra_args string, expected_passthrough bool) {
ssh_args, server_args, passthrough, extra_args, err := ParseSSHArgs(split(args), "--kitten")
if err != nil {
t.Fatal(err)
}
check := func(a, b any) {
diff := cmp.Diff(a, b)
if diff != "" {
t.Fatalf("Unexpected value for args: %s\n%s", args, diff)
}
}
check(split(expected_ssh_args), ssh_args)
check(split(expected_server_args), server_args)
check(split(expected_extra_args), extra_args)
check(expected_passthrough, passthrough)
}
p(`localhost`, ``, `localhost`, ``, false)
p(`-- localhost`, ``, `localhost`, ``, false)
p(`-46p23 localhost sh -c "a b"`, `-4 -6 -p 23`, `localhost sh -c "a b"`, ``, false)
p(`-46p23 -S/moose -W x:6 -- localhost sh -c "a b"`, `-4 -6 -p 23 -S /moose -W x:6`, `localhost sh -c "a b"`, ``, false)
p(`--kitten=abc -np23 --kitten xyz host`, `-n -p 23`, `host`, `--kitten abc --kitten xyz`, true)
}