This commit is contained in:
Kovid Goyal 2023-02-28 15:26:17 +05:30
parent 1b2fe90ed1
commit 944e036611
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 69 additions and 74 deletions

View File

@ -72,7 +72,10 @@ func RunSSHAskpass() {
break
}
}
data = shm.ReadWithSize(data_shm, 1)
data, err = shm.ReadWithSize(data_shm, 1)
if err != nil {
fatal(fmt.Errorf("Failed to read from SHM file with error: %w", err))
}
response := ""
if is_confirm {
var ok bool

View File

@ -65,11 +65,7 @@ func get_destination(hostname string) (username, hostname_for_match string) {
}
func read_data_from_shared_memory(shm_name string) ([]byte, error) {
data, err := shm.ReadWithSizeAndUnlink(shm_name, func(f *os.File) error {
s, err := f.Stat()
if err != nil {
return fmt.Errorf("Failed to stat SHM file with error: %w", err)
}
data, err := shm.ReadWithSizeAndUnlink(shm_name, func(s fs.FileInfo) error {
if stat, ok := s.Sys().(unix.Stat_t); ok {
if os.Getuid() != int(stat.Uid) || os.Getgid() != int(stat.Gid) {
return fmt.Errorf("Incorrect owner on SHM file")

View File

@ -104,43 +104,58 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) {
return
}
func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) {
p := buf
for len(p) > 0 {
n, err := f.Read(p)
p = p[n:]
if err != nil {
if len(p) == 0 && errors.Is(err, io.EOF) {
err = nil
}
err = fmt.Errorf("Failed to read from SHM file with error: %w", err)
return buf[:len(buf)-len(p)], err
}
}
return buf, nil
}
func read_with_size(f *os.File) ([]byte, error) {
szbuf := []byte{0, 0, 0, 0}
szbuf, err := read_till_buf_full(f, szbuf)
if err != nil {
return nil, err
}
size := int(binary.BigEndian.Uint32(szbuf))
return read_till_buf_full(f, make([]byte, size))
}
const NUM_BYTES_FOR_SIZE = 4
func WriteWithSize(self MMap, b []byte, at int) error {
szbuf := []byte{0, 0, 0, 0}
binary.BigEndian.PutUint32(szbuf, uint32(len(b)))
copy(self.Slice()[at:], szbuf)
copy(self.Slice()[at+4:], b)
if len(self.Slice()) < at+len(b)+NUM_BYTES_FOR_SIZE {
return io.ErrShortBuffer
}
binary.BigEndian.PutUint32(self.Slice()[at:], uint32(len(b)))
copy(self.Slice()[at+NUM_BYTES_FOR_SIZE:], b)
return nil
}
func ReadWithSize(self MMap, at int) []byte {
size := int(binary.BigEndian.Uint32(self.Slice()[at : at+4]))
return self.Slice()[at+4 : at+4+size]
func ReadWithSize(self MMap, at int) ([]byte, error) {
s := self.Slice()[at:]
if len(s) < NUM_BYTES_FOR_SIZE {
return nil, io.ErrShortBuffer
}
size := int(binary.BigEndian.Uint32(self.Slice()[at : at+NUM_BYTES_FOR_SIZE]))
s = s[NUM_BYTES_FOR_SIZE:]
if len(s) < size {
return nil, io.ErrShortBuffer
}
return s[:size], nil
}
func ReadWithSizeAndUnlink(name string, file_callback ...func(fs.FileInfo) error) ([]byte, error) {
mmap, err := Open(name, 0)
if err != nil {
return nil, err
}
if len(file_callback) > 0 {
s, err := mmap.Stat()
if err != nil {
return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err)
}
for _, f := range file_callback {
err = f(s)
if err != nil {
return nil, err
}
}
}
defer func() {
mmap.Close()
mmap.Unlink()
}()
slice, err := ReadWithSize(mmap, 0)
if err != nil {
return nil, err
}
ans := make([]byte, len(slice))
copy(ans, slice)
return ans, nil
}
func test_integration_with_python(args []string) (rc int, err error) {
@ -161,7 +176,7 @@ func test_integration_with_python(args []string) (rc int, err error) {
if err != nil {
return 1, err
}
mmap, err := CreateTemp("shmtest-", uint64(len(data)+4))
mmap, err := CreateTemp("shmtest-", uint64(len(data)+NUM_BYTES_FOR_SIZE))
if err != nil {
return 1, err
}

View File

@ -142,21 +142,13 @@ func Open(name string, size uint64) (MMap, error) {
if err != nil {
return nil, err
}
if size == 0 {
s, err := ans.Stat()
if err != nil {
ans.Close()
return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err)
}
size = uint64(s.Size())
}
return file_mmap(ans, size, READ, false, name)
}
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
f, err := open(name)
if err != nil {
return nil, err
}
defer f.Close()
defer os.Remove(f.Name())
for _, cb := range file_callback {
err = cb(f)
if err != nil {
return nil, err
}
}
return read_with_size(f)
}

View File

@ -4,7 +4,6 @@
package shm
import (
"encoding/binary"
"errors"
"fmt"
"io/fs"
@ -159,23 +158,13 @@ func Open(name string, size uint64) (MMap, error) {
if err != nil {
return nil, err
}
if size == 0 {
s, err := ans.Stat()
if err != nil {
ans.Close()
return nil, fmt.Errorf("Failed to stat SHM file with error: %w", err)
}
size = uint64(s.Size())
}
return syscall_mmap(ans, size, READ, false)
}
func ReadWithSizeAndUnlink(name string, file_callback ...func(*os.File) error) ([]byte, error) {
mmap, err := Open(name, 4)
if err != nil {
return nil, err
}
size := uint64(binary.BigEndian.Uint32(mmap.Slice()))
mmap.Close()
mmap, err = Open(name, 4+size)
if err != nil {
return nil, err
}
ans := make([]byte, size)
copy(ans, mmap.Slice()[4:])
mmap.Close()
mmap.Unlink()
return ans, nil
}