add compression support

Summary:
Add streaming compression support to Mononoke. We do the compression inside our
own framing, to make sure it doesn't interfere with buffering, flushing, or
keepalives.
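
Concretely, each message keeps the existing netstring framing: a one-byte
stream identifier (0 = stdin, 1 = stdout, 2 = stderr, 3 = preamble) followed
by the payload, and only the payload is zstd-compressed when negotiated, so
frame boundaries, flushes and keepalives are unchanged on the wire:

  <length>:<stream byte><payload, zstd-compressed when enabled>,

For example, the encode_compressed test below frames a 53-byte stdin message
as 22:\x00<21-byte zstd block>,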

Reviewed By: krallin

Differential Revision: D28323591

fbshipit-source-id: 7192bedcf12a2d0ec025deed8b4fa3857eefd508
Johan Schuijt-Li, 2021-05-20 10:21:51 -07:00, committed by Facebook GitHub Bot
parent daf7eb2138
commit 0b87810652
6 changed files with 159 additions and 26 deletions


@@ -405,7 +405,7 @@ impl<'a> StdioRelay<'a> {
         // Wrap the socket with the ssh codec
         let (socket_read, socket_write) = tokio::io::split(socket);
         let rx = FramedRead::new(socket_read, SshDecoder::new());
-        let tx = FramedWrite::new(socket_write, SshEncoder::new());
+        let tx = FramedWrite::new(socket_write, SshEncoder::new(None)?);
         let preamble =
             stream::once(async { Ok(SshMsg::new(IoStream::Preamble(preamble), Bytes::new())) });


@@ -276,7 +276,7 @@ async fn handle_hgcli<S: MononokeStream>(conn: AcceptedConnection, stream: S) ->
     let (rx, tx) = tokio::io::split(stream);
-    let mut framed = FramedConn::setup(rx, tx);
+    let mut framed = FramedConn::setup(rx, tx, None)?;
     let preamble = match framed.rd.next().await.transpose()? {
         Some(maybe_preamble) => {
@@ -414,11 +414,11 @@ where
     R: AsyncRead + Send + std::marker::Unpin + 'static,
     W: AsyncWrite + Send + std::marker::Unpin + 'static,
 {
-    pub fn setup(rd: R, wr: W) -> Self {
+    pub fn setup(rd: R, wr: W, compression_writes: Option<i32>) -> Result<Self> {
         // NOTE: FramedRead does buffering, so no need to wrap with a BufReader here.
         let rd = FramedRead::new(rd, SshDecoder::new());
-        let wr = FramedWrite::new(wr, SshEncoder::new());
-        Self { rd, wr }
+        let wr = FramedWrite::new(wr, SshEncoder::new(compression_writes)?);
+        Ok(Self { rd, wr })
     }
 }


@@ -26,9 +26,11 @@ use crate::connection_acceptor::{self, AcceptedConnection, Acceptor, FramedConn,
 use qps::Qps;

+const HEADER_CLIENT_COMPRESSION: &str = "x-client-compression";
 const HEADER_CLIENT_DEBUG: &str = "x-client-debug";
 const HEADER_WEBSOCKET_KEY: &str = "sec-websocket-key";
 const HEADER_WEBSOCKET_ACCEPT: &str = "sec-websocket-accept";
+const HEADER_MONONOKE_ENCODING: &str = "x-mononoke-encoding";
 const HEADER_MONONOKE_HOST: &str = "x-mononoke-host";
 const HEADER_REVPROXY_REGION: &str = "x-fb-revproxy-region";
@@ -190,21 +192,37 @@ where
         let websocket_key = calculate_websocket_accept(req.headers());

-        let res = Response::builder()
+        let mut builder = Response::builder()
             .status(http::StatusCode::SWITCHING_PROTOCOLS)
             .header(http::header::CONNECTION, "upgrade")
             .header(http::header::UPGRADE, "websocket")
-            .header(HEADER_WEBSOCKET_ACCEPT, websocket_key)
-            .body(Body::empty())
-            .map_err(HttpError::internal)?;
+            .header(HEADER_WEBSOCKET_ACCEPT, websocket_key);

         let metadata = try_convert_headers_to_metadata(self.conn.is_trusted, &req.headers())
             .await
             .context("Invalid metadata")
             .map_err(HttpError::BadRequest)?;

+        let zstd_level = 3i32;
+        let compression = match req.headers().get(HEADER_CLIENT_COMPRESSION) {
+            Some(header_value) => match header_value.as_bytes() {
+                b"zstd=stdin" if zstd_level > 0 => Ok(Some(zstd_level)),
+                _ => Err(anyhow!(
+                    "'{:?}' is not a recognized compression value",
+                    header_value
+                )),
+            },
+            None => Ok(None),
+        }
+        .map_err(HttpError::BadRequest)?;

         let debug = req.headers().get(HEADER_CLIENT_DEBUG).is_some();

+        match compression {
+            Some(zstd_level) => {
+                builder = builder.header(HEADER_MONONOKE_ENCODING, format!("zstd={}", zstd_level));
+            }
+            _ => {}
+        };
+
+        let res = builder.body(Body::empty()).map_err(HttpError::internal)?;

         let this = self.clone();
         let fut = async move {
@@ -219,8 +237,7 @@
             let (rx, tx) = tokio::io::split(io);
             let rx = AsyncReadExt::chain(Cursor::new(read_buf), rx);

-            let framed = FramedConn::setup(rx, tx);
+            let framed = FramedConn::setup(rx, tx, compression)?;
             connection_acceptor::handle_wireproto(this.conn, framed, reponame, metadata, debug)
                 .await
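
For illustration, here is roughly what the negotiation looks like from the
client's side, using the header constants above (the request path and the
websocket key values are placeholders, not taken from this diff):

  GET /repo HTTP/1.1
  connection: upgrade
  upgrade: websocket
  sec-websocket-key: <nonce>
  x-client-compression: zstd=stdin

  HTTP/1.1 101 Switching Protocols
  connection: upgrade
  upgrade: websocket
  sec-websocket-accept: <hash of nonce>
  x-mononoke-encoding: zstd=3

A client that omits x-client-compression gets no x-mononoke-encoding header
back, and the connection stays uncompressed.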


@@ -20,6 +20,8 @@ session_id = { version = "0.1.0", path = "../server/session_id" }
 tokio = { version = "1.5", features = ["full", "test-util"] }
 tokio-util = { version = "0.6", features = ["full"] }
 trust-dns-resolver = "0.20"
+zstd = "=0.7.0+zstd.1.4.9"
+zstd-safe = "=3.1.0+zstd.1.4.9"

 [patch.crates-io]
 addr2line = { git = "https://github.com/gimli-rs/addr2line.git", rev = "0b6b6018b5b252a18e628fba03885f7d21844b3c" }


@@ -26,6 +26,7 @@ use session_id::{generate_session_id, SessionId};
 use tokio::time::timeout;
 use tokio_util::codec::{Decoder, Encoder};
 use trust_dns_resolver::TokioAsyncResolver;
+use zstd::stream::raw::{Encoder as ZstdEncoder, InBuffer, Operation, OutBuffer};

 use netstring::{NetstringDecoder, NetstringEncoder};
@@ -35,8 +36,10 @@ pub use priority::Priority;
 #[derive(Debug)]
 pub struct SshDecoder(NetstringDecoder);

-#[derive(Debug)]
-pub struct SshEncoder(NetstringEncoder<Bytes>);
+pub struct SshEncoder {
+    netstring: NetstringEncoder<Bytes>,
+    compressor: Option<ZstdEncoder<'static>>,
+}

 pub struct Stdio {
     pub metadata: Arc<Metadata>,
@@ -315,8 +318,48 @@ impl Decoder for SshDecoder {
     }
 }

 impl SshEncoder {
-    pub fn new() -> Self {
-        SshEncoder(NetstringEncoder::default())
+    pub fn new(compression_level: Option<i32>) -> Result<Self> {
+        match compression_level {
+            Some(level) => Ok(SshEncoder {
+                netstring: NetstringEncoder::default(),
+                compressor: Some(ZstdEncoder::new(level)?),
+            }),
+            _ => Ok(SshEncoder {
+                netstring: NetstringEncoder::default(),
+                compressor: None,
+            }),
+        }
+    }
+
+    fn compress_into<'a>(&mut self, out: &mut BytesMut, input: &'a [u8]) -> Result<()> {
+        match &mut self.compressor {
+            Some(compressor) => {
+                let buflen = zstd_safe::compress_bound(input.len());
+                if buflen >= zstd_safe::dstream_out_size() {
+                    return Err(anyhow!(
+                        "block is too big to compress into a single zstd block"
+                    ));
+                }
+
+                let mut src = InBuffer::around(input);
+                let mut dst = vec![0u8; buflen];
+                let mut dst = OutBuffer::around(&mut dst);
+
+                while src.pos < src.src.len() {
+                    compressor.run(&mut src, &mut dst)?;
+                }
+                loop {
+                    let remaining = compressor.flush(&mut dst)?;
+                    if remaining == 0 {
+                        break;
+                    }
+                }
+
+                out.put_slice(dst.as_slice());
+            }
+            None => out.put_slice(input),
+        };
+        Ok(())
     }
 }
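
This commit only adds the compressing write path. As a rough sketch of what
the matching read path on a receiving peer could look like with the same raw
zstd API (the helper's name and shape are assumptions, not code from this
commit):

  use anyhow::Result;
  use zstd::stream::raw::{Decoder as ZstdDecoder, InBuffer, Operation, OutBuffer};

  // Hypothetical counterpart to compress_into. The decoder must be kept
  // across frames: the encoder only flushes the zstd stream per message and
  // never ends it, so decompression state carries over between frames.
  fn decompress_block(decompressor: &mut ZstdDecoder<'static>, input: &[u8]) -> Result<Vec<u8>> {
      // dstream_out_size() is enough for one block, mirroring the
      // compress_bound() check on the compression side.
      let mut dst_buf = vec![0u8; zstd_safe::dstream_out_size()];
      let mut dst = OutBuffer::around(&mut dst_buf);
      let mut src = InBuffer::around(input);
      while src.pos < src.src.len() {
          decompressor.run(&mut src, &mut dst)?;
      }
      Ok(dst.as_slice().to_vec())
  }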
@@ -325,21 +368,22 @@ impl Encoder<SshMsg> for SshEncoder {
     fn encode(&mut self, msg: SshMsg, buf: &mut BytesMut) -> io::Result<()> {
         let mut v = BytesMut::with_capacity(1 + msg.1.len());

         match msg.0 {
             IoStream::Stdin => {
                 v.put_u8(0);
-                v.put_slice(&msg.1);
-                Ok(self.0.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
+                self.compress_into(&mut v, &msg.1).map_err(ioerr_cvt)?;
+                Ok(self.netstring.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
             }
             IoStream::Stdout => {
                 v.put_u8(1);
-                v.put_slice(&msg.1);
-                Ok(self.0.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
+                self.compress_into(&mut v, &msg.1).map_err(ioerr_cvt)?;
+                Ok(self.netstring.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
             }
             IoStream::Stderr => {
                 v.put_u8(2);
-                v.put_slice(&msg.1);
-                Ok(self.0.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
+                self.compress_into(&mut v, &msg.1).map_err(ioerr_cvt)?;
+                Ok(self.netstring.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
             }
             IoStream::Preamble(preamble) => {
                 // msg.1 is ignored in preamble
@@ -347,7 +391,7 @@ impl Encoder<SshMsg> for SshEncoder {
                 v.put_u8(3);
                 let preamble = serde_json::to_vec(&preamble)?;
                 v.extend_from_slice(&preamble);
-                Ok(self.0.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
+                Ok(self.netstring.encode(v.freeze(), buf).map_err(ioerr_cvt)?)
             }
         }
     }
@@ -421,7 +465,7 @@ mod test {
     #[test]
     fn encode_simple() {
         let mut buf = BytesMut::with_capacity(1024);
-        let mut encoder = SshEncoder::new();
+        let mut encoder = SshEncoder::new(None).unwrap();

         encoder
             .encode(SshMsg::new(Stdin, b"ls -l".bytes()), &mut buf)
@@ -433,7 +477,7 @@
     #[test]
     fn encode_zero() {
         let mut buf = BytesMut::with_capacity(1024);
-        let mut encoder = SshEncoder::new();
+        let mut encoder = SshEncoder::new(None).unwrap();

         encoder
             .encode(SshMsg::new(Stdin, b"".bytes()), &mut buf)
@@ -445,7 +489,7 @@
     #[test]
     fn encode_one() {
         let mut buf = BytesMut::with_capacity(1024);
-        let mut encoder = SshEncoder::new();
+        let mut encoder = SshEncoder::new(None).unwrap();

         encoder
             .encode(SshMsg::new(Stdin, b"X".bytes()), &mut buf)
@@ -457,7 +501,7 @@
     #[test]
     fn encode_multi() {
         let mut buf = BytesMut::with_capacity(1024);
-        let mut encoder = SshEncoder::new();
+        let mut encoder = SshEncoder::new(None).unwrap();

         encoder
             .encode(SshMsg::new(Stdin, b"X".bytes()), &mut buf)
@@ -472,6 +516,34 @@
         assert_eq!(buf.as_ref(), b"2:\x00X,2:\x01Y,2:\x02Z,");
     }

+    #[test]
+    fn encode_compressed() {
+        let mut buf = BytesMut::with_capacity(1024);
+        let mut encoder = SshEncoder::new(Some(3)).unwrap();
+
+        encoder
+            .encode(
+                SshMsg::new(
+                    Stdin,
+                    b"hello hello hello hello hello hello hello hello hello".bytes(),
+                ),
+                &mut buf,
+            )
+            .expect("encode failed");
+
+        assert_eq!(buf.as_ref(), b"22:\x00\x28\xb5\x2f\xfd\x00\x58\x64\x00\x00\x30\x68\x65\x6c\x6c\x6f\x20\x01\x00\x24\x2a\x45\x2c");
+    }
+
+    #[test]
+    fn encode_compressed_too_big() {
+        let mut buf = BytesMut::with_capacity(1024);
+        let mut encoder = SshEncoder::new(Some(3)).unwrap();
+        // 1MB, which is larger than the 128KB zstd streaming buffer
+        let message = vec![0u8; 1048576];
+
+        let result = encoder.encode(SshMsg::new(Stdin, message.as_slice().bytes()), &mut buf);
+        assert!(result.is_err());
+    }
+
     #[test]
     fn decode_simple() {
         let mut buf = BytesMut::with_capacity(1024);


@@ -0,0 +1,42 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This software may be used and distributed according to the terms of the
+# GNU General Public License found in the LICENSE file in the root
+# directory of this source tree.
+
+  $ . "${TEST_FIXTURES}/library.sh"
+
+setup configuration
+  $ MONONOKE_DIRECT_PEER=1
+  $ setup_common_config
+  $ cd $TESTTMP
+
+setup repo with 1MB file, which is larger than the zstd stream buffer size
+  $ hginit_treemanifest repo-hg
+  $ cd repo-hg
+  $ printf '=%.0s' {1..1048576} > a
+  $ hg add a
+  $ hg ci -ma
+
+setup master bookmarks
+  $ hg bookmark master_bookmark -r 'tip'
+  $ cd $TESTTMP
+  $ blobimport repo-hg/.hg repo
+  $ rm -rf repo-hg
+
+start mononoke
+  $ mononoke
+  $ wait_for_mononoke
+
+clone and checkout the repository with compression enabled
+  $ hg clone -U --shallow --debug "mononoke://$(mononoke_address)/repo" --config mononokepeer.compression=true 2>&1 | grep zstd
+  zstd compression on the wire is enabled
+  $ cd repo
+  $ hgmn checkout master_bookmark --config mononokepeer.compression=true
+  1 files updated, 0 files merged, 0 files removed, 0 files unresolved
+  (activating bookmark master_bookmark)
+
+without compression again, no zstd indicator that compression is used
+  $ hgmn pull --debug 2>&1 | grep -P "(zstd|pulling|checking)"
+  pulling from mononoke://* (glob)
+  checking for updated bookmarks