1
1
mirror of https://github.com/wez/wezterm.git synced 2024-12-23 05:12:40 +03:00

mux: try to detect invalid data on stdout during connection

Use some heuristics to verify the data that is about to be parsed;
this can help to detect eg: data being output to stdout prior
to us sending any encoded data to the remote mux.

In addition, add a timeout to help avoid waiting forever in
the case that we didn't detect a problem.

refs: https://github.com/wez/wezterm/issues/1860
This commit is contained in:
Wez Furlong 2022-04-11 21:08:57 -07:00
parent 1d908457ae
commit 6b67ae842c
5 changed files with 83 additions and 34 deletions

1
Cargo.lock generated
View File

@ -612,6 +612,7 @@ dependencies = [
"serde",
"smol",
"termwiz",
"thiserror",
"varbincode",
"wezterm-term",
"zstd",

View File

@ -18,6 +18,7 @@ rangeset = { path = "../rangeset" }
serde = {version="1.0", features = ["rc", "derive"]}
smol = "1.2"
termwiz = { path = "../termwiz" }
thiserror = "1.0"
varbincode = "0.1"
wezterm-term = { path = "../term", features=["use_serde"] }
zstd = "0.6"

View File

@ -30,9 +30,14 @@ use std::sync::Arc;
use termwiz::hyperlink::Hyperlink;
use termwiz::image::{ImageData, TextureCoordinate};
use termwiz::surface::{Line, SequenceNo};
use thiserror::Error;
use wezterm_term::color::ColorPalette;
use wezterm_term::{Alert, ClipboardSelection, StableRowIndex};
#[derive(Error, Debug)]
#[error("Corrupt Response")]
pub struct CorruptResponse;
/// Returns the encoded length of the leb128 representation of value
fn encoded_length(value: u64) -> usize {
struct NullWrite {}
@ -164,6 +169,7 @@ struct Decoded {
/// See encode_raw() for the frame format.
async fn decode_raw_async<R: Unpin + AsyncRead + std::fmt::Debug>(
r: &mut R,
max_serial: Option<u64>,
) -> anyhow::Result<Decoded> {
let len = read_u64_async(r).await.context("reading PDU length")?;
let (len, is_compressed) = if (len & COMPRESSED_MASK) != 0 {
@ -172,6 +178,11 @@ async fn decode_raw_async<R: Unpin + AsyncRead + std::fmt::Debug>(
(len, false)
};
let serial = read_u64_async(r).await.context("reading PDU serial")?;
if let Some(max_serial) = max_serial {
if serial > max_serial && max_serial > 0 {
return Err(CorruptResponse).context("decode_raw");
}
}
let ident = read_u64_async(r).await.context("reading PDU ident")?;
let data_len =
match (len as usize).overflowing_sub(encoded_length(ident) + encoded_length(serial)) {
@ -373,12 +384,12 @@ macro_rules! pdu {
}
}
pub async fn decode_async<R>(r: &mut R) -> Result<DecodedPdu, Error>
pub async fn decode_async<R>(r: &mut R, max_serial: Option<u64>) -> Result<DecodedPdu, Error>
where R: std::marker::Unpin,
R: AsyncRead,
R: std::fmt::Debug
{
let decoded = decode_raw_async(r).await.context("decoding a PDU")?;
let decoded = decode_raw_async(r, max_serial).await.context("decoding a PDU")?;
match decoded.ident {
$(
$vers => {

View File

@ -31,6 +31,14 @@ use std::thread;
use std::time::Duration;
use thiserror::Error;
#[derive(Error, Debug)]
#[error("Timeout")]
struct Timeout;
#[derive(Error, Debug)]
#[error("ChannelSendError")]
struct ChannelSendError;
enum ReaderMessage {
SendPdu {
pdu: Pdu,
@ -288,34 +296,36 @@ async fn client_thread_async(
.context("encoding a PDU to send to the server")?;
stream.flush().await.context("flushing PDU to server")?;
}
Ok(ReaderMessage::Readable) => match Pdu::decode_async(&mut stream).await {
Ok(decoded) => {
log::trace!("decoded serial {}", decoded.serial);
if decoded.serial == 0 {
process_unilateral(local_domain_id, decoded)
.context("processing unilateral PDU from server")
.map_err(|e| {
log::error!("process_unilateral: {:?}", e);
e
})?;
} else if let Some(promise) = promises.map.remove(&decoded.serial) {
if promise.try_send(Ok(decoded.pdu)).is_err() {
return Err(NotReconnectableError::ClientWasDestroyed.into());
Ok(ReaderMessage::Readable) => {
match Pdu::decode_async(&mut stream, Some(next_serial)).await {
Ok(decoded) => {
log::trace!("decoded serial {}", decoded.serial);
if decoded.serial == 0 {
process_unilateral(local_domain_id, decoded)
.context("processing unilateral PDU from server")
.map_err(|e| {
log::error!("process_unilateral: {:?}", e);
e
})?;
} else if let Some(promise) = promises.map.remove(&decoded.serial) {
if promise.try_send(Ok(decoded.pdu)).is_err() {
return Err(NotReconnectableError::ClientWasDestroyed.into());
}
} else {
let reason =
format!("got serial {:?} without a corresponding promise", decoded);
promises.fail_all(&reason);
anyhow::bail!("{}", reason);
}
} else {
let reason =
format!("got serial {:?} without a corresponding promise", decoded);
}
Err(err) => {
let reason = format!("Error while decoding response pdu: {:#}", err);
log::error!("{}", reason);
promises.fail_all(&reason);
anyhow::bail!("{}", reason);
return Err(err).context("Error while decoding response pdu");
}
}
Err(err) => {
let reason = format!("Error while decoding response pdu: {:#}", err);
log::error!("{}", reason);
promises.fail_all(&reason);
return Err(err).context("Error while decoding response pdu");
}
},
}
Err(_) => {
return Err(NotReconnectableError::ClientWasDestroyed.into());
}
@ -1014,7 +1024,14 @@ impl Client {
&self,
ui: &ConnectionUI,
) -> anyhow::Result<GetCodecVersionResponse> {
match self.get_codec_version(GetCodecVersion {}).await {
match self
.get_codec_version(GetCodecVersion {})
.or(async {
smol::Timer::after(Duration::from_secs(60)).await;
Err(Timeout).context("Timeout")
})
.await
{
Ok(info) if info.codec_vers == CODEC_VERSION => {
log::trace!(
"Server version is {} (codec version {})",
@ -1037,14 +1054,31 @@ impl Client {
return Err(err.into());
}
Err(err) => {
let msg = format!(
"Please install the same version of wezterm on both \
log::trace!("{:?}", err);
let msg = if err.root_cause().is::<Timeout>() {
"Timed out while parsing the response from the server. \
This may be due to network connectivity issues"
.to_string()
} else if err.root_cause().is::<CorruptResponse>() {
"Received an implausible and likely corrupt response from \
the server. This can happen if the remote host outputs \
to stdout prior to running commands."
.to_string()
} else if err.root_cause().is::<ChannelSendError>() {
"Internal channel was closed prior to sending request. \
This may indicate that the remote host output invalid data \
to stdout prior to running the requested command"
.to_string()
} else {
format!(
"Please install the same version of wezterm on both \
the client and server! \
The server reported error '{}' while being asked for its \
version. This likely means that the server is older \
than the client.\n",
err
);
err
)
};
ui.output_str(&msg);
bail!("{}", msg);
}
@ -1142,8 +1176,10 @@ impl Client {
let (promise, rx) = bounded(1);
self.sender
.send(ReaderMessage::SendPdu { pdu, promise })
.await?;
rx.recv().await?
.await
.map_err(|_| ChannelSendError)
.context("send_pdu send")?;
rx.recv().await.context("send_pdu recv")?
}
rpc!(ping, Ping = (), Pong);

View File

@ -68,7 +68,7 @@ where
match smol::future::or(rx_msg, wait_for_read).await {
Ok(Item::Readable) => {
let decoded = match Pdu::decode_async(&mut stream).await {
let decoded = match Pdu::decode_async(&mut stream, None).await {
Ok(data) => data,
Err(err) => {
if let Some(err) = err.root_cause().downcast_ref::<std::io::Error>() {