From 944e036611b6710670778b9e295273862ae6cb96 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 28 Feb 2023 15:26:17 +0530 Subject: [PATCH] DRYer --- tools/cmd/ssh/askpass.go | 5 ++- tools/cmd/ssh/main.go | 6 +-- tools/utils/shm/shm.go | 81 ++++++++++++++++++++-------------- tools/utils/shm/shm_fs.go | 24 ++++------ tools/utils/shm/shm_syscall.go | 27 ++++-------- 5 files changed, 69 insertions(+), 74 deletions(-) diff --git a/tools/cmd/ssh/askpass.go b/tools/cmd/ssh/askpass.go index 304d369d7..a90081420 100644 --- a/tools/cmd/ssh/askpass.go +++ b/tools/cmd/ssh/askpass.go @@ -72,7 +72,10 @@ func RunSSHAskpass() { break } } - data = shm.ReadWithSize(data_shm, 1) + data, err = shm.ReadWithSize(data_shm, 1) + if err != nil { + fatal(fmt.Errorf("Failed to read from SHM file with error: %w", err)) + } response := "" if is_confirm { var ok bool diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index fb48137f1..8b3a9f100 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -65,11 +65,7 @@ func get_destination(hostname string) (username, hostname_for_match string) { } func read_data_from_shared_memory(shm_name string) ([]byte, error) { - data, err := shm.ReadWithSizeAndUnlink(shm_name, func(f *os.File) error { - s, err := f.Stat() - if err != nil { - return fmt.Errorf("Failed to stat SHM file with error: %w", err) - } + data, err := shm.ReadWithSizeAndUnlink(shm_name, func(s fs.FileInfo) error { if stat, ok := s.Sys().(unix.Stat_t); ok { if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) { return fmt.Errorf("Incorrect owner on SHM file") diff --git a/tools/utils/shm/shm.go b/tools/utils/shm/shm.go index 5f8618e29..02427183e 100644 --- a/tools/utils/shm/shm.go +++ b/tools/utils/shm/shm.go @@ -104,43 +104,58 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) { return } -func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) { - p := buf - for len(p) > 0 { - n, err := f.Read(p) - p = p[n:] - if err != nil { - if len(p) == 0 && errors.Is(err, io.EOF) { - err = nil - } - err = fmt.Errorf("Failed to read from SHM file with error: %w", err) - return buf[:len(buf)-len(p)], err - } - } - return buf, nil -} - -func read_with_size(f *os.File) ([]byte, error) { - szbuf := []byte{0, 0, 0, 0} - szbuf, err := read_till_buf_full(f, szbuf) - if err != nil { - return nil, err - } - size := int(binary.BigEndian.Uint32(szbuf)) - return read_till_buf_full(f, make([]byte, size)) -} +const NUM_BYTES_FOR_SIZE = 4 func WriteWithSize(self MMap, b []byte, at int) error { - szbuf := []byte{0, 0, 0, 0} - binary.BigEndian.PutUint32(szbuf, uint32(len(b))) - copy(self.Slice()[at:], szbuf) - copy(self.Slice()[at+4:], b) + if len(self.Slice()) < at+len(b)+NUM_BYTES_FOR_SIZE { + return io.ErrShortBuffer + } + binary.BigEndian.PutUint32(self.Slice()[at:], uint32(len(b))) + copy(self.Slice()[at+NUM_BYTES_FOR_SIZE:], b) return nil } -func ReadWithSize(self MMap, at int) []byte { - size := int(binary.BigEndian.Uint32(self.Slice()[at : at+4])) - return self.Slice()[at+4 : at+4+size] +func ReadWithSize(self MMap, at int) ([]byte, error) { + s := self.Slice()[at:] + if len(s) < NUM_BYTES_FOR_SIZE { + return nil, io.ErrShortBuffer + } + size := int(binary.BigEndian.Uint32(self.Slice()[at : at+NUM_BYTES_FOR_SIZE])) + s = s[NUM_BYTES_FOR_SIZE:] + if len(s) < size { + return nil, io.ErrShortBuffer + } + return s[:size], nil +} + +func ReadWithSizeAndUnlink(name string, file_callback ...func(fs.FileInfo) error) ([]byte, error) { + mmap, err := Open(name, 0) + if err != nil { + return nil, err + } + if len(file_callback) > 0 { + s, err := mmap.Stat() + if err != nil { + return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err) + } + for _, f := range file_callback { + err = f(s) + if err != nil { + return nil, err + } + } + } + defer func() { + mmap.Close() + mmap.Unlink() + }() + slice, err := ReadWithSize(mmap, 0) + if err != nil { + return nil, err + } + ans := make([]byte, len(slice)) + copy(ans, slice) + return ans, nil } func test_integration_with_python(args []string) (rc int, err error) { @@ -161,7 +176,7 @@ func test_integration_with_python(args []string) (rc int, err error) { if err != nil { return 1, err } - mmap, err := CreateTemp("shmtest-", uint64(len(data)+4)) + mmap, err := CreateTemp("shmtest-", uint64(len(data)+NUM_BYTES_FOR_SIZE)) if err != nil { return 1, err } diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index f505cf3ae..f0fd82003 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -142,21 +142,13 @@ func Open(name string, size uint64) (MMap, error) { if err != nil { return nil, err } + if size == 0 { + s, err := ans.Stat() + if err != nil { + ans.Close() + return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err) + } + size = uint64(s.Size()) + } return file_mmap(ans, size, READ, false, name) } - -func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) { - f, err := open(name) - if err != nil { - return nil, err - } - defer f.Close() - defer os.Remove(f.Name()) - for _, cb := range file_callback { - err = cb(f) - if err != nil { - return nil, err - } - } - return read_with_size(f) -} diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index b19b4cdda..c1718e826 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -4,7 +4,6 @@ package shm import ( - "encoding/binary" "errors" "fmt" "io/fs" @@ -159,23 +158,13 @@ func Open(name string, size uint64) (MMap, error) { if err != nil { return nil, err } + if size == 0 { + s, err := ans.Stat() + if err != nil { + ans.Close() + return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err) + } + size = uint64(s.Size()) + } return syscall_mmap(ans, size, READ, false) } - -func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) { - mmap, err := Open(name, 4) - if err != nil { - return nil, err - } - size := uint64(binary.BigEndian.Uint32(mmap.Slice())) - mmap.Close() - mmap, err = Open(name, 4+size) - if err != nil { - return nil, err - } - ans := make([]byte, size) - copy(ans, mmap.Slice()[4:]) - mmap.Close() - mmap.Unlink() - return ans, nil -}