
ssh: fix coordination of pty readers

At some recent point in history, I effectively broke multiple tabs in
`wezterm ssh HOST` by allowing them to contend in weird ways on locks,
leading to a horribly sluggish experience where multiple keypresses
in alternate tabs would appear to be swallowed until IO happened in
another tab.  Yuk!

This commit fixes that up by teaching channels how to wait cooperatively
and to attempt a read in all waiting channels when the fd becomes
readable.
commit 8f9d654301
parent 5f2f35971d
Author: Wez Furlong
Date:   2020-01-31 22:53:55 -08:00
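
The shape of the fix is a leader/follower wait: whichever SshReader finds no data ready either takes ownership of the single OS-level poll() on the shared session socket, or sleeps on a Condvar until the owner broadcasts; every woken reader then races to retry its own non-blocking read. Below is a minimal, std-only sketch of that coordination pattern, not wezterm's actual API: the State/Shared/read_one names are invented for illustration, a VecDeque stands in for the ssh channel, and a sleep stands in for poll(fd, POLLIN).

use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

/// Shared state guarded by the mutex: buffered data plus the flag that
/// marks whether one reader currently owns the OS-level wait.
struct State {
    pending: VecDeque<u8>,
    waiting_for_read: bool,
}

struct Shared {
    locked: Mutex<State>,
    read_waiters: Condvar,
}

/// Block until one byte is available, without holding the lock while waiting.
fn read_one(shared: &Shared) -> u8 {
    let mut state = shared.locked.lock().unwrap();
    loop {
        // Non-blocking attempt first, mirroring the WouldBlock check.
        if let Some(byte) = state.pending.pop_front() {
            return byte;
        }
        if state.waiting_for_read {
            // Another reader owns the poll; sleep on the condvar until it
            // broadcasts, then race back around to try the read again.
            state = shared.read_waiters.wait(state).unwrap();
        } else {
            // We own the wait; drop the lock so other readers and the
            // writer can make progress while we block.
            state.waiting_for_read = true;
            drop(state);

            // Stand-in for poll(fd, POLLIN): wait until data may exist.
            thread::sleep(Duration::from_millis(10));

            // Re-acquire the lock, release ownership of the wait, and wake
            // every waiter so all readers retry their non-blocking reads.
            state = shared.locked.lock().unwrap();
            state.waiting_for_read = false;
            shared.read_waiters.notify_all();
        }
    }
}

fn main() {
    let shared = Arc::new(Shared {
        locked: Mutex::new(State {
            pending: VecDeque::new(),
            waiting_for_read: false,
        }),
        read_waiters: Condvar::new(),
    });

    // A writer thread standing in for the remote end of the socket.
    let writer = {
        let shared = Arc::clone(&shared);
        thread::spawn(move || {
            for byte in 0u8..6 {
                thread::sleep(Duration::from_millis(25));
                shared.locked.lock().unwrap().pending.push_back(byte);
            }
        })
    };

    // Several readers contend for the same source, like tabs in `wezterm ssh`.
    let readers: Vec<_> = (0..3)
        .map(|id| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                for _ in 0..2 {
                    println!("reader {} got byte {}", id, read_one(&shared));
                }
            })
        })
        .collect();

    writer.join().unwrap();
    for r in readers {
        r.join().unwrap();
    }
}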

@@ -6,12 +6,12 @@
 //! before we can get to a point where `openpty` will be able to run.
 use crate::{Child, CommandBuilder, ExitStatus, MasterPty, PtyPair, PtySize, PtySystem, SlavePty};
 use anyhow::Context;
-use filedescriptor::AsRawSocketDescriptor;
+use filedescriptor::{AsRawSocketDescriptor, POLLIN};
 use ssh2::{Channel, Session};
 use std::collections::HashMap;
 use std::io::Result as IoResult;
 use std::io::{Read, Write};
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Condvar, Mutex};
 
 /// Represents a pty channel within a session.
 struct SshPty {
@@ -31,6 +31,15 @@ struct SessionInner {
     ptys: HashMap<usize, SshPty>,
     next_channel_id: usize,
     term: String,
+    /// an instance of SshReader owns the wait for read and subsequent
+    /// wakeup broadcast
+    waiting_for_read: bool,
+}
+
+#[derive(Debug)]
+struct SessionHolder {
+    locked_inner: Mutex<SessionInner>,
+    read_waiters: Condvar,
 }
 
 // An anemic impl of Debug to satisfy some indirect trait bounds
@@ -49,7 +58,7 @@ impl std::fmt::Debug for SessionInner {
 /// implements the `PtySystem` trait and exposes the `openpty` function
 /// that can be used to return a remote pty via ssh.
 pub struct SshSession {
-    inner: Arc<Mutex<SessionInner>>,
+    inner: Arc<SessionHolder>,
 }
 
 impl SshSession {
@@ -62,19 +71,23 @@ impl SshSession {
     /// the case that a pty needs to be allocated.
     pub fn new(session: Session, term: &str) -> Self {
         Self {
-            inner: Arc::new(Mutex::new(SessionInner {
-                session,
-                ptys: HashMap::new(),
-                next_channel_id: 1,
-                term: term.to_string(),
-            })),
+            inner: Arc::new(SessionHolder {
+                locked_inner: Mutex::new(SessionInner {
+                    session,
+                    ptys: HashMap::new(),
+                    next_channel_id: 1,
+                    term: term.to_string(),
+                    waiting_for_read: false,
+                }),
+                read_waiters: Condvar::new(),
+            }),
         }
     }
 }
 
 impl PtySystem for SshSession {
     fn openpty(&self, size: PtySize) -> anyhow::Result<PtyPair> {
-        let mut inner = self.inner.lock().unwrap();
+        let mut inner = self.inner.locked_inner.lock().unwrap();
         let mut channel = inner.session.channel_session()?;
         channel.handle_extended_data(ssh2::ExtendedData::Merge)?;
         channel.request_pty(
@@ -110,27 +123,21 @@
 #[derive(Clone, Debug)]
 struct PtyHandle {
     id: usize,
-    inner: Arc<Mutex<SessionInner>>,
+    inner: Arc<SessionHolder>,
 }
 
 impl PtyHandle {
     /// Acquire the session mutex and then perform a lambda on the Channel
     fn with_channel<R, F: FnMut(&mut Channel) -> R>(&self, mut f: F) -> R {
-        let mut inner = self.inner.lock().unwrap();
+        let mut inner = self.inner.locked_inner.lock().unwrap();
         f(&mut inner.ptys.get_mut(&self.id).unwrap().channel)
     }
 
     /// Acquire the session mutex and then perform a lambda on the SshPty
     fn with_pty<R, F: FnMut(&mut SshPty) -> R>(&self, mut f: F) -> R {
-        let mut inner = self.inner.lock().unwrap();
+        let mut inner = self.inner.locked_inner.lock().unwrap();
         f(&mut inner.ptys.get_mut(&self.id).unwrap())
     }
-
-    fn as_socket_descriptor(&self) -> filedescriptor::SocketDescriptor {
-        let inner = self.inner.lock().unwrap();
-        let stream = inner.session.tcp_stream();
-        stream.as_ref().unwrap().as_socket_descriptor()
-    }
 }
 
 struct SshMaster {
@@ -208,15 +215,19 @@ struct SshChild {
 
 impl Child for SshChild {
     fn try_wait(&mut self) -> IoResult<Option<ExitStatus>> {
-        self.pty.with_channel(|channel| {
-            if channel.eof() {
-                Ok(Some(ExitStatus::with_exit_code(
-                    channel.exit_status()? as u32
-                )))
-            } else {
-                Ok(None)
-            }
-        })
+        let mut lock = self.pty.inner.locked_inner.try_lock();
+        if let Ok(ref mut inner) = lock {
+            let ssh_pty = inner.ptys.get_mut(&self.pty.id).unwrap();
+            if ssh_pty.channel.eof() {
+                Ok(Some(ExitStatus::with_exit_code(
+                    ssh_pty.channel.exit_status()? as u32,
+                )))
+            } else {
+                Ok(None)
+            }
+        } else {
+            Ok(None)
+        }
     }
 
     fn kill(&mut self) -> IoResult<()> {
@@ -239,37 +250,56 @@ struct SshReader {
 
 impl Read for SshReader {
     fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
-        // A blocking read, but we don't want to own the mutex while we
-        // sleep, so we manually poll the underlying socket descriptor
-        // and then use a non-blocking read to read the actual data
-        let socket = self.pty.as_socket_descriptor();
         loop {
-            // Wait for input on the descriptor
-            let mut pfd = [filedescriptor::pollfd {
-                fd: socket,
-                events: filedescriptor::POLLIN,
-                revents: 0,
-            }];
-            filedescriptor::poll(&mut pfd, None).ok();
-
-            // a read won't block, so ask libssh2 for data from the
-            // associated channel, but do not block!
-            let res = {
-                let mut inner = self.pty.inner.lock().unwrap();
-                inner.session.set_blocking(false);
-                let res = inner.ptys.get_mut(&self.pty.id).unwrap().channel.read(buf);
-                inner.session.set_blocking(true);
-                res
-            };
-
-            // If we have data or an error, return it, otherwise let's
-            // try again!
-            match res {
-                Ok(len) => return Ok(len),
-                Err(err) => match err.kind() {
-                    std::io::ErrorKind::WouldBlock => continue,
-                    _ => return Err(err),
-                },
-            }
+            let mut inner = self.pty.inner.locked_inner.lock().unwrap();
+            inner.session.set_blocking(false);
+            let res = inner.ptys.get_mut(&self.pty.id).unwrap().channel.read(buf);
+            inner.session.set_blocking(true);
+            match res {
+                Ok(size) => return Ok(size),
+                Err(err) => match err.kind() {
+                    std::io::ErrorKind::WouldBlock => {}
+                    _ => return Err(err),
+                },
+            };
+
+            // No data available for this channel, so we'll wait.
+            // If we're the first SshReader to do this, we'll perform the
+            // OS level poll() call for ourselves, otherwise we'll block
+            // on the condvar
+            if inner.waiting_for_read {
+                self.pty.inner.read_waiters.wait(inner).ok();
+            } else {
+                let socket = inner
+                    .session
+                    .tcp_stream()
+                    .as_ref()
+                    .unwrap()
+                    .as_socket_descriptor();
+
+                // We own waiting for read
+                inner.waiting_for_read = true;
+
+                // Unlock and wait
+                drop(inner);
+
+                let mut pfd = [filedescriptor::pollfd {
+                    fd: socket,
+                    events: POLLIN,
+                    revents: 0,
+                }];
+                filedescriptor::poll(&mut pfd, None).ok();
+
+                // re-acquire the lock to release our ownership of the poll
+                // and to wake up the others
+                let mut inner = self.pty.inner.locked_inner.lock().unwrap();
+                inner.waiting_for_read = false;
+
+                // Wake all readers and we'll all race to read our next
+                // iteration
+                self.pty.inner.read_waiters.notify_all();
+                drop(inner);
+            }
         }
     }
 }
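
Two details of the new shape are worth noting. try_wait switched from with_channel to try_lock and returns Ok(None) when the session mutex is contended, so status polling no longer blocks behind a reader that momentarily owns the lock; it just reports "still running" and can try again later. And the reader deliberately uses notify_all rather than waking a single channel: only one SshReader performs the poll(), but any of the waiting channels may be the one with data, so every reader races to retry its non-blocking read and the losers simply go back to waiting on WouldBlock.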