Go SHM API to read simple data with size from SHM name

This commit is contained in:
Kovid Goyal 2023-02-20 21:23:23 +05:30
parent 3f829ccdde
commit 5a8d903a4d
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
7 changed files with 229 additions and 4 deletions

30
kitty_tests/shm.py Normal file
View File

@ -0,0 +1,30 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2023, Kovid Goyal <kovid at kovidgoyal.net>
import os
import subprocess
from kitty.constants import kitten_exe
from kitty.fast_data_types import shm_unlink
from kitty.shm import SharedMemory
from . import BaseTest
class SHMTest(BaseTest):
def test_shm_with_kitten(self):
data = os.urandom(333)
with SharedMemory(size=363) as shm:
shm.write_data_with_size(data)
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'read', shm.name], stdout=subprocess.PIPE)
self.assertEqual(cp.returncode, 0)
self.assertEqual(cp.stdout, data)
self.assertRaises(FileNotFoundError, shm_unlink, shm.name)
cp = subprocess.run([kitten_exe(), '__pytest__', 'shm', 'write'], input=data, stdout=subprocess.PIPE)
self.assertEqual(cp.returncode, 0)
name = cp.stdout.decode().strip()
with SharedMemory(name=name, unlink_on_exit=True) as shm:
q = shm.read_data_with_size()
self.assertEqual(data, q)

20
tools/cmd/pytest/main.go Normal file
View File

@ -0,0 +1,20 @@
// License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
package pytest
import (
"fmt"
"kitty/tools/cli"
"kitty/tools/utils/shm"
)
var _ = fmt.Print
func EntryPoint(root *cli.Command) {
root = root.AddSubCommand(&cli.Command{
Name: "__pytest__",
Hidden: true,
})
shm.TestEntryPoint(root)
}

View File

@ -5,10 +5,13 @@ package ssh
import (
"errors"
"fmt"
"net/url"
"os"
"os/user"
"strings"
"kitty/tools/cli"
"kitty/tools/tty"
"golang.org/x/exp/maps"
"golang.org/x/sys/unix"
@ -16,6 +19,74 @@ import (
var _ = fmt.Print
func get_destination(hostname string) (username, hostname_for_match string) {
u, err := user.Current()
if err == nil {
username = u.Username
}
hostname_for_match = hostname
if strings.HasPrefix(hostname, "ssh://") {
p, err := url.Parse(hostname)
if err == nil {
hostname_for_match = p.Hostname()
if p.User.Username() != "" {
username = p.User.Username()
}
}
} else if strings.Contains(hostname, "@") && hostname[0] != '@' {
username, hostname_for_match, _ = strings.Cut(hostname, "@")
}
if strings.Contains(hostname, "@") && hostname[0] != '@' {
_, hostname_for_match, _ = strings.Cut(hostname_for_match, "@")
}
hostname_for_match, _, _ = strings.Cut(hostname_for_match, ":")
return
}
func add_cloned_env(val string) map[string]string {
return nil // TODO: Implement me
}
func parse_kitten_args(found_extra_args []string, username, hostname_for_match string) (overrides []string, literal_env map[string]string) {
literal_env = make(map[string]string)
overrides = make([]string, 0, 4)
for i, a := range found_extra_args {
if i%2 == 0 {
continue
}
if key, val, found := strings.Cut(a, "="); found {
if key == "clone_env" {
le := add_cloned_env(val)
if le != nil {
literal_env = le
}
} else if key != "hostname" {
overrides = append(overrides, key+" "+val)
}
}
}
if len(overrides) > 0 {
overrides = append([]string{"hostname " + username + "@" + hostname_for_match}, overrides...)
}
return
}
func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) {
cmd := append([]string{ssh_exe()}, ssh_args...)
hostname, remote_args := server_args[0], server_args[1:]
if len(remote_args) == 0 {
cmd = append(cmd, "-t")
}
insertion_point := len(cmd)
cmd = append(cmd, "--", hostname)
uname, hostname_for_match := get_destination(hostname)
overrides, literal_env := parse_kitten_args(found_extra_args, uname, hostname_for_match)
if insertion_point > 0 && overrides != nil && literal_env != nil {
}
// TODO: Implement me
return
}
func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
if len(args) > 0 {
switch args[0] {
@ -44,10 +115,13 @@ func main(cmd *cli.Command, o *Options, args []string) (rc int, err error) {
}
return 1, unix.Exec(ssh_exe(), append([]string{"ssh"}, args...), os.Environ())
}
if false {
return len(ssh_args) + len(server_args), nil
if os.Getenv("KITTY_WINDOW_ID") == "" || os.Getenv("KITTY_PID") == "" {
return 1, fmt.Errorf("The SSH kitten is meant to run inside a kitty window")
}
return
if !tty.IsTerminal(os.Stdin.Fd()) {
return 1, fmt.Errorf("The SSH kitten is meant for interactive use only, STDIN must be a terminal")
}
return run_ssh(ssh_args, server_args, found_extra_args)
}
func EntryPoint(parent *cli.Command) {

View File

@ -10,6 +10,7 @@ import (
"kitty/tools/cmd/clipboard"
"kitty/tools/cmd/edit_in_kitty"
"kitty/tools/cmd/icat"
"kitty/tools/cmd/pytest"
"kitty/tools/cmd/ssh"
"kitty/tools/cmd/unicode_input"
"kitty/tools/cmd/update_self"
@ -35,6 +36,8 @@ func KittyToolEntryPoints(root *cli.Command) {
ssh.EntryPoint(root)
// unicode_input
unicode_input.EntryPoint(root)
// __pytest__
pytest.EntryPoint(root)
// __hold_till_enter__
root.AddSubCommand(&cli.Command{
Name: "__hold_till_enter__",

View File

@ -5,13 +5,17 @@ package shm
import (
"crypto/rand"
"encoding/base32"
"encoding/binary"
"errors"
"fmt"
"io"
not_rand "math/rand"
"os"
"strconv"
"strings"
"kitty/tools/cli"
"golang.org/x/sys/unix"
)
@ -109,3 +113,69 @@ func truncate_or_unlink(ans *os.File, size uint64) (err error) {
}
return
}
func read_till_buf_full(f *os.File, buf []byte) ([]byte, error) {
p := buf
for len(p) > 0 {
n, err := f.Read(p)
p = p[n:]
if err != nil {
if len(p) == 0 && errors.Is(err, io.EOF) {
err = nil
}
return buf[:len(buf)-len(p)], err
}
}
return buf, nil
}
func read_with_size(f *os.File) ([]byte, error) {
szbuf := []byte{0, 0, 0, 0}
szbuf, err := read_till_buf_full(f, szbuf)
if err != nil {
return nil, err
}
size := int(binary.BigEndian.Uint32(szbuf))
return read_till_buf_full(f, make([]byte, size))
}
func test_integration_with_python(args []string) (rc int, err error) {
switch args[0] {
default:
return 1, fmt.Errorf("Unknown test type: %s", args[0])
case "read":
data, err := ReadWithSizeAndUnlink(args[1])
if err != nil {
return 1, err
}
_, err = os.Stdout.Write(data)
if err != nil {
return 1, err
}
case "write":
data, err := io.ReadAll(os.Stdin)
if err != nil {
return 1, err
}
mmap, err := CreateTemp("shmtest-", uint64(len(data)+4))
if err != nil {
return 1, err
}
defer mmap.Close()
binary.BigEndian.PutUint32(mmap.Slice(), uint32(len(data)))
copy(mmap.Slice()[4:], data)
fmt.Println(mmap.Name())
}
return 0, nil
}
func TestEntryPoint(root *cli.Command) {
root.AddSubCommand(&cli.Command{
Name: "shm",
OnlyArgsAllowed: true,
Run: func(cmd *cli.Command, args []string) (rc int, err error) {
return test_integration_with_python(args)
},
})
}

View File

@ -113,7 +113,7 @@ func create_temp(pattern string, size uint64) (ans MMap, err error) {
return file_mmap(f, size, WRITE, true, special_name)
}
func Open(name string, size uint64) (MMap, error) {
func open(name string) (*os.File, error) {
ans, err := os.OpenFile(file_path_from_name(name), os.O_RDONLY, 0)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
@ -123,5 +123,23 @@ func Open(name string, size uint64) (MMap, error) {
}
return nil, err
}
return ans, nil
}
func Open(name string, size uint64) (MMap, error) {
ans, err := open(name)
if err != nil {
return nil, err
}
return file_mmap(ans, size, READ, false, name)
}
func ReadWithSizeAndUnlink(name string) ([]byte, error) {
f, err := open(name)
if err != nil {
return nil, err
}
defer f.Close()
defer os.Remove(f.Name())
return read_with_size(f)
}

View File

@ -151,3 +151,13 @@ func Open(name string, size uint64) (MMap, error) {
}
return syscall_mmap(ans, size, READ, false)
}
func ReadWithSizeAndUnlink(name string) ([]byte, error) {
f, err := shm_open(name, os.O_RDONLY, 0)
if err != nil {
return nil, err
}
defer f.Close()
defer shm_unlink(f.Name())
return read_with_size(f)
}