mirror of
https://github.com/wez/wezterm.git
synced 2024-12-23 21:32:13 +03:00
ssh: fix coordination of pty readers
At some recent point in history, I effective broke multiple tabs in `wezterm ssh HOST` by allowing them to contend in weird ways on locks, leading to a horribly sluggish experience where multiple keypresses in alternate tabs would appear to be swallowed until IO happened in another tab. Yuk! This commit fixes that up by teaching channels how to wait cooperatively and to attempt a read in all waiting channels when the fd becomes readable.
This commit is contained in:
parent
5f2f35971d
commit
8f9d654301
126
pty/src/ssh.rs
126
pty/src/ssh.rs
@ -6,12 +6,12 @@
|
||||
//! before we can get to a point where `openpty` will be able to run.
|
||||
use crate::{Child, CommandBuilder, ExitStatus, MasterPty, PtyPair, PtySize, PtySystem, SlavePty};
|
||||
use anyhow::Context;
|
||||
use filedescriptor::AsRawSocketDescriptor;
|
||||
use filedescriptor::{AsRawSocketDescriptor, POLLIN};
|
||||
use ssh2::{Channel, Session};
|
||||
use std::collections::HashMap;
|
||||
use std::io::Result as IoResult;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::{Arc, Condvar, Mutex};
|
||||
|
||||
/// Represents a pty channel within a session.
|
||||
struct SshPty {
|
||||
@ -31,6 +31,15 @@ struct SessionInner {
|
||||
ptys: HashMap<usize, SshPty>,
|
||||
next_channel_id: usize,
|
||||
term: String,
|
||||
/// an instance of SshReader owns the wait for read and subsequent
|
||||
/// wakeup broadcast
|
||||
waiting_for_read: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SessionHolder {
|
||||
locked_inner: Mutex<SessionInner>,
|
||||
read_waiters: Condvar,
|
||||
}
|
||||
|
||||
// An anemic impl of Debug to satisfy some indirect trait bounds
|
||||
@ -49,7 +58,7 @@ impl std::fmt::Debug for SessionInner {
|
||||
/// implements the `PtySystem` trait and exposes the `openpty` function
|
||||
/// that can be used to return a remote pty via ssh.
|
||||
pub struct SshSession {
|
||||
inner: Arc<Mutex<SessionInner>>,
|
||||
inner: Arc<SessionHolder>,
|
||||
}
|
||||
|
||||
impl SshSession {
|
||||
@ -62,19 +71,23 @@ impl SshSession {
|
||||
/// the case that a pty needs to be allocated.
|
||||
pub fn new(session: Session, term: &str) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(SessionInner {
|
||||
session,
|
||||
ptys: HashMap::new(),
|
||||
next_channel_id: 1,
|
||||
term: term.to_string(),
|
||||
})),
|
||||
inner: Arc::new(SessionHolder {
|
||||
locked_inner: Mutex::new(SessionInner {
|
||||
session,
|
||||
ptys: HashMap::new(),
|
||||
next_channel_id: 1,
|
||||
term: term.to_string(),
|
||||
waiting_for_read: false,
|
||||
}),
|
||||
read_waiters: Condvar::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PtySystem for SshSession {
|
||||
fn openpty(&self, size: PtySize) -> anyhow::Result<PtyPair> {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
let mut inner = self.inner.locked_inner.lock().unwrap();
|
||||
let mut channel = inner.session.channel_session()?;
|
||||
channel.handle_extended_data(ssh2::ExtendedData::Merge)?;
|
||||
channel.request_pty(
|
||||
@ -110,27 +123,21 @@ impl PtySystem for SshSession {
|
||||
#[derive(Clone, Debug)]
|
||||
struct PtyHandle {
|
||||
id: usize,
|
||||
inner: Arc<Mutex<SessionInner>>,
|
||||
inner: Arc<SessionHolder>,
|
||||
}
|
||||
|
||||
impl PtyHandle {
|
||||
/// Acquire the session mutex and then perform a lambda on the Channel
|
||||
fn with_channel<R, F: FnMut(&mut Channel) -> R>(&self, mut f: F) -> R {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
let mut inner = self.inner.locked_inner.lock().unwrap();
|
||||
f(&mut inner.ptys.get_mut(&self.id).unwrap().channel)
|
||||
}
|
||||
|
||||
/// Acquire the session mutex and then perform a lambda on the SshPty
|
||||
fn with_pty<R, F: FnMut(&mut SshPty) -> R>(&self, mut f: F) -> R {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
let mut inner = self.inner.locked_inner.lock().unwrap();
|
||||
f(&mut inner.ptys.get_mut(&self.id).unwrap())
|
||||
}
|
||||
|
||||
fn as_socket_descriptor(&self) -> filedescriptor::SocketDescriptor {
|
||||
let inner = self.inner.lock().unwrap();
|
||||
let stream = inner.session.tcp_stream();
|
||||
stream.as_ref().unwrap().as_socket_descriptor()
|
||||
}
|
||||
}
|
||||
|
||||
struct SshMaster {
|
||||
@ -208,15 +215,19 @@ struct SshChild {
|
||||
|
||||
impl Child for SshChild {
|
||||
fn try_wait(&mut self) -> IoResult<Option<ExitStatus>> {
|
||||
self.pty.with_channel(|channel| {
|
||||
if channel.eof() {
|
||||
let mut lock = self.pty.inner.locked_inner.try_lock();
|
||||
if let Ok(ref mut inner) = lock {
|
||||
let ssh_pty = inner.ptys.get_mut(&self.pty.id).unwrap();
|
||||
if ssh_pty.channel.eof() {
|
||||
Ok(Some(ExitStatus::with_exit_code(
|
||||
channel.exit_status()? as u32
|
||||
ssh_pty.channel.exit_status()? as u32,
|
||||
)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn kill(&mut self) -> IoResult<()> {
|
||||
@ -239,37 +250,56 @@ struct SshReader {
|
||||
|
||||
impl Read for SshReader {
|
||||
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
|
||||
// A blocking read, but we don't want to own the mutex while we
|
||||
// sleep, so we manually poll the underlying socket descriptor
|
||||
// and then use a non-blocking read to read the actual data
|
||||
let socket = self.pty.as_socket_descriptor();
|
||||
loop {
|
||||
// Wait for input on the descriptor
|
||||
let mut pfd = [filedescriptor::pollfd {
|
||||
fd: socket,
|
||||
events: filedescriptor::POLLIN,
|
||||
revents: 0,
|
||||
}];
|
||||
filedescriptor::poll(&mut pfd, None).ok();
|
||||
let mut inner = self.pty.inner.locked_inner.lock().unwrap();
|
||||
|
||||
// a read won't block, so ask libssh2 for data from the
|
||||
// associated channel, but do not block!
|
||||
let res = {
|
||||
let mut inner = self.pty.inner.lock().unwrap();
|
||||
inner.session.set_blocking(false);
|
||||
let res = inner.ptys.get_mut(&self.pty.id).unwrap().channel.read(buf);
|
||||
inner.session.set_blocking(true);
|
||||
res
|
||||
};
|
||||
|
||||
// If we have data or an error, return it, otherwise let's
|
||||
// try again!
|
||||
inner.session.set_blocking(false);
|
||||
let res = inner.ptys.get_mut(&self.pty.id).unwrap().channel.read(buf);
|
||||
inner.session.set_blocking(true);
|
||||
match res {
|
||||
Ok(len) => return Ok(len),
|
||||
Ok(size) => return Ok(size),
|
||||
Err(err) => match err.kind() {
|
||||
std::io::ErrorKind::WouldBlock => continue,
|
||||
std::io::ErrorKind::WouldBlock => {}
|
||||
_ => return Err(err),
|
||||
},
|
||||
};
|
||||
|
||||
// No data available for this channel, so we'll wait.
|
||||
// If we're the first SshReader to do this, we'll perform the
|
||||
// OS level poll() call for ourselves, otherwise we'll block
|
||||
// on the condvar
|
||||
if inner.waiting_for_read {
|
||||
self.pty.inner.read_waiters.wait(inner).ok();
|
||||
} else {
|
||||
let socket = inner
|
||||
.session
|
||||
.tcp_stream()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.as_socket_descriptor();
|
||||
|
||||
// We own waiting for read
|
||||
inner.waiting_for_read = true;
|
||||
|
||||
// Unlock and wait
|
||||
drop(inner);
|
||||
|
||||
let mut pfd = [filedescriptor::pollfd {
|
||||
fd: socket,
|
||||
events: POLLIN,
|
||||
revents: 0,
|
||||
}];
|
||||
filedescriptor::poll(&mut pfd, None).ok();
|
||||
|
||||
// re-acquire the lock to release our ownership of the poll
|
||||
// and to wake up the others
|
||||
let mut inner = self.pty.inner.locked_inner.lock().unwrap();
|
||||
inner.waiting_for_read = false;
|
||||
|
||||
// Wake all readers and we'll all race to read our next
|
||||
// iteration
|
||||
self.pty.inner.read_waiters.notify_all();
|
||||
drop(inner);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user