mirror of
https://github.com/wez/wezterm.git
synced 2024-12-23 05:12:40 +03:00
mux: try to detect invalid data on stdout during connection
Use some heuristics to verify the data that is about to be parsed; this can help to detect eg: data being output to stdout prior to us sending any encoded data to the remote mux. In addition, add a timeout to help avoid waiting forever in the case that we didn't detect a problem. refs: https://github.com/wez/wezterm/issues/1860
This commit is contained in:
parent
1d908457ae
commit
6b67ae842c
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -612,6 +612,7 @@ dependencies = [
|
||||
"serde",
|
||||
"smol",
|
||||
"termwiz",
|
||||
"thiserror",
|
||||
"varbincode",
|
||||
"wezterm-term",
|
||||
"zstd",
|
||||
|
@ -18,6 +18,7 @@ rangeset = { path = "../rangeset" }
|
||||
serde = {version="1.0", features = ["rc", "derive"]}
|
||||
smol = "1.2"
|
||||
termwiz = { path = "../termwiz" }
|
||||
thiserror = "1.0"
|
||||
varbincode = "0.1"
|
||||
wezterm-term = { path = "../term", features=["use_serde"] }
|
||||
zstd = "0.6"
|
||||
|
@ -30,9 +30,14 @@ use std::sync::Arc;
|
||||
use termwiz::hyperlink::Hyperlink;
|
||||
use termwiz::image::{ImageData, TextureCoordinate};
|
||||
use termwiz::surface::{Line, SequenceNo};
|
||||
use thiserror::Error;
|
||||
use wezterm_term::color::ColorPalette;
|
||||
use wezterm_term::{Alert, ClipboardSelection, StableRowIndex};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Corrupt Response")]
|
||||
pub struct CorruptResponse;
|
||||
|
||||
/// Returns the encoded length of the leb128 representation of value
|
||||
fn encoded_length(value: u64) -> usize {
|
||||
struct NullWrite {}
|
||||
@ -164,6 +169,7 @@ struct Decoded {
|
||||
/// See encode_raw() for the frame format.
|
||||
async fn decode_raw_async<R: Unpin + AsyncRead + std::fmt::Debug>(
|
||||
r: &mut R,
|
||||
max_serial: Option<u64>,
|
||||
) -> anyhow::Result<Decoded> {
|
||||
let len = read_u64_async(r).await.context("reading PDU length")?;
|
||||
let (len, is_compressed) = if (len & COMPRESSED_MASK) != 0 {
|
||||
@ -172,6 +178,11 @@ async fn decode_raw_async<R: Unpin + AsyncRead + std::fmt::Debug>(
|
||||
(len, false)
|
||||
};
|
||||
let serial = read_u64_async(r).await.context("reading PDU serial")?;
|
||||
if let Some(max_serial) = max_serial {
|
||||
if serial > max_serial && max_serial > 0 {
|
||||
return Err(CorruptResponse).context("decode_raw");
|
||||
}
|
||||
}
|
||||
let ident = read_u64_async(r).await.context("reading PDU ident")?;
|
||||
let data_len =
|
||||
match (len as usize).overflowing_sub(encoded_length(ident) + encoded_length(serial)) {
|
||||
@ -373,12 +384,12 @@ macro_rules! pdu {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn decode_async<R>(r: &mut R) -> Result<DecodedPdu, Error>
|
||||
pub async fn decode_async<R>(r: &mut R, max_serial: Option<u64>) -> Result<DecodedPdu, Error>
|
||||
where R: std::marker::Unpin,
|
||||
R: AsyncRead,
|
||||
R: std::fmt::Debug
|
||||
{
|
||||
let decoded = decode_raw_async(r).await.context("decoding a PDU")?;
|
||||
let decoded = decode_raw_async(r, max_serial).await.context("decoding a PDU")?;
|
||||
match decoded.ident {
|
||||
$(
|
||||
$vers => {
|
||||
|
@ -31,6 +31,14 @@ use std::thread;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Timeout")]
|
||||
struct Timeout;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("ChannelSendError")]
|
||||
struct ChannelSendError;
|
||||
|
||||
enum ReaderMessage {
|
||||
SendPdu {
|
||||
pdu: Pdu,
|
||||
@ -288,34 +296,36 @@ async fn client_thread_async(
|
||||
.context("encoding a PDU to send to the server")?;
|
||||
stream.flush().await.context("flushing PDU to server")?;
|
||||
}
|
||||
Ok(ReaderMessage::Readable) => match Pdu::decode_async(&mut stream).await {
|
||||
Ok(decoded) => {
|
||||
log::trace!("decoded serial {}", decoded.serial);
|
||||
if decoded.serial == 0 {
|
||||
process_unilateral(local_domain_id, decoded)
|
||||
.context("processing unilateral PDU from server")
|
||||
.map_err(|e| {
|
||||
log::error!("process_unilateral: {:?}", e);
|
||||
e
|
||||
})?;
|
||||
} else if let Some(promise) = promises.map.remove(&decoded.serial) {
|
||||
if promise.try_send(Ok(decoded.pdu)).is_err() {
|
||||
return Err(NotReconnectableError::ClientWasDestroyed.into());
|
||||
Ok(ReaderMessage::Readable) => {
|
||||
match Pdu::decode_async(&mut stream, Some(next_serial)).await {
|
||||
Ok(decoded) => {
|
||||
log::trace!("decoded serial {}", decoded.serial);
|
||||
if decoded.serial == 0 {
|
||||
process_unilateral(local_domain_id, decoded)
|
||||
.context("processing unilateral PDU from server")
|
||||
.map_err(|e| {
|
||||
log::error!("process_unilateral: {:?}", e);
|
||||
e
|
||||
})?;
|
||||
} else if let Some(promise) = promises.map.remove(&decoded.serial) {
|
||||
if promise.try_send(Ok(decoded.pdu)).is_err() {
|
||||
return Err(NotReconnectableError::ClientWasDestroyed.into());
|
||||
}
|
||||
} else {
|
||||
let reason =
|
||||
format!("got serial {:?} without a corresponding promise", decoded);
|
||||
promises.fail_all(&reason);
|
||||
anyhow::bail!("{}", reason);
|
||||
}
|
||||
} else {
|
||||
let reason =
|
||||
format!("got serial {:?} without a corresponding promise", decoded);
|
||||
}
|
||||
Err(err) => {
|
||||
let reason = format!("Error while decoding response pdu: {:#}", err);
|
||||
log::error!("{}", reason);
|
||||
promises.fail_all(&reason);
|
||||
anyhow::bail!("{}", reason);
|
||||
return Err(err).context("Error while decoding response pdu");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let reason = format!("Error while decoding response pdu: {:#}", err);
|
||||
log::error!("{}", reason);
|
||||
promises.fail_all(&reason);
|
||||
return Err(err).context("Error while decoding response pdu");
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(NotReconnectableError::ClientWasDestroyed.into());
|
||||
}
|
||||
@ -1014,7 +1024,14 @@ impl Client {
|
||||
&self,
|
||||
ui: &ConnectionUI,
|
||||
) -> anyhow::Result<GetCodecVersionResponse> {
|
||||
match self.get_codec_version(GetCodecVersion {}).await {
|
||||
match self
|
||||
.get_codec_version(GetCodecVersion {})
|
||||
.or(async {
|
||||
smol::Timer::after(Duration::from_secs(60)).await;
|
||||
Err(Timeout).context("Timeout")
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(info) if info.codec_vers == CODEC_VERSION => {
|
||||
log::trace!(
|
||||
"Server version is {} (codec version {})",
|
||||
@ -1037,14 +1054,31 @@ impl Client {
|
||||
return Err(err.into());
|
||||
}
|
||||
Err(err) => {
|
||||
let msg = format!(
|
||||
"Please install the same version of wezterm on both \
|
||||
log::trace!("{:?}", err);
|
||||
let msg = if err.root_cause().is::<Timeout>() {
|
||||
"Timed out while parsing the response from the server. \
|
||||
This may be due to network connectivity issues"
|
||||
.to_string()
|
||||
} else if err.root_cause().is::<CorruptResponse>() {
|
||||
"Received an implausible and likely corrupt response from \
|
||||
the server. This can happen if the remote host outputs \
|
||||
to stdout prior to running commands."
|
||||
.to_string()
|
||||
} else if err.root_cause().is::<ChannelSendError>() {
|
||||
"Internal channel was closed prior to sending request. \
|
||||
This may indicate that the remote host output invalid data \
|
||||
to stdout prior to running the requested command"
|
||||
.to_string()
|
||||
} else {
|
||||
format!(
|
||||
"Please install the same version of wezterm on both \
|
||||
the client and server! \
|
||||
The server reported error '{}' while being asked for its \
|
||||
version. This likely means that the server is older \
|
||||
than the client.\n",
|
||||
err
|
||||
);
|
||||
err
|
||||
)
|
||||
};
|
||||
ui.output_str(&msg);
|
||||
bail!("{}", msg);
|
||||
}
|
||||
@ -1142,8 +1176,10 @@ impl Client {
|
||||
let (promise, rx) = bounded(1);
|
||||
self.sender
|
||||
.send(ReaderMessage::SendPdu { pdu, promise })
|
||||
.await?;
|
||||
rx.recv().await?
|
||||
.await
|
||||
.map_err(|_| ChannelSendError)
|
||||
.context("send_pdu send")?;
|
||||
rx.recv().await.context("send_pdu recv")?
|
||||
}
|
||||
|
||||
rpc!(ping, Ping = (), Pong);
|
||||
|
@ -68,7 +68,7 @@ where
|
||||
|
||||
match smol::future::or(rx_msg, wait_for_read).await {
|
||||
Ok(Item::Readable) => {
|
||||
let decoded = match Pdu::decode_async(&mut stream).await {
|
||||
let decoded = match Pdu::decode_async(&mut stream, None).await {
|
||||
Ok(data) => data,
|
||||
Err(err) => {
|
||||
if let Some(err) = err.root_cause().downcast_ref::<std::io::Error>() {
|
||||
|
Loading…
Reference in New Issue
Block a user