async-compression: make sure decompression is framed and won't overread its input

Summary:
- Added test for checking that decompression isn't overreading its input
  - As a result removed ZSTD decompression which is overreading input
- Changes in decompression code propagated usage of BufRead in bundle2 parsing code
  - As a result a bug in OuterDecoder was found where the buf was not consumed if the header value is "0"

Reviewed By: jsgf

Differential Revision: D6557440

fbshipit-source-id: 89a9f4c8790017c5b86d28d467e45f687d7323f6
This commit is contained in:
Lukas Piatkowski 2018-01-15 10:31:25 -08:00 committed by Facebook Github Bot
parent 35e3c19f1c
commit f59d6a1c8b
9 changed files with 244 additions and 164 deletions

View File

@ -7,17 +7,16 @@
//! Non-blocking, buffered compression and decompression
use std::fmt::{self, Debug, Formatter};
use std::io::{self, Read};
use std::io::{self, BufRead, Read};
use bzip2::read::BzDecoder;
use bzip2::bufread::BzDecoder;
use tokio_io::AsyncRead;
use zstd::Decoder as ZstdDecoder;
use raw::RawDecoder;
pub struct Decompressor<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
d_type: DecompressorType,
inner: Box<RawDecoder<R> + 'a>,
@ -32,7 +31,7 @@ pub enum DecompressorType {
impl<'a, R> Decompressor<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
pub fn new(r: R, dt: DecompressorType) -> Self {
Decompressor {
@ -43,8 +42,9 @@ where
// https://github.com/alexcrichton/flate2-rs/issues/62 to be
// fixed
DecompressorType::Gzip => unimplemented!(),
// ZstdDecoder::new() should only fail on OOM, so just call unwrap here.
DecompressorType::Zstd => Box::new(ZstdDecoder::new(r).unwrap()),
// TODO: The zstd crate is not safe for decompressing Read input, because it is
// overconsuming it
DecompressorType::Zstd => unimplemented!(),
},
}
}
@ -65,16 +65,16 @@ where
}
}
impl<'a, R: AsyncRead + 'a> Read for Decompressor<'a, R> {
impl<'a, R: AsyncRead + BufRead + 'a> Read for Decompressor<'a, R> {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl<'a, R: AsyncRead + 'a> AsyncRead for Decompressor<'a, R> {}
impl<'a, R: AsyncRead + BufRead + 'a> AsyncRead for Decompressor<'a, R> {}
impl<'a, R: AsyncRead + 'a> Debug for Decompressor<'a, R> {
impl<'a, R: AsyncRead + BufRead + 'a> Debug for Decompressor<'a, R> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("Decompressor")
.field("decoder_type", &self.d_type)

View File

@ -33,5 +33,4 @@ mod test;
pub use compressor::{Compressor, CompressorType};
pub use decompressor::{Decompressor, DecompressorType};
pub const ZSTD_DEFAULT_LEVEL: i32 = 1;
pub use bzip2::Compression as Bzip2Compression;

View File

@ -6,23 +6,23 @@
//! Raw upstream decoders, plus a uniform interface for accessing them.
use std::io::{self, Read, Write};
use std::io::{self, BufRead, Read, Write};
use std::result;
use futures::Poll;
use tokio_io::AsyncWrite;
use bzip2::read::BzDecoder;
use bzip2::bufread::BzDecoder;
use bzip2::write::BzEncoder;
use zstd::{Decoder as ZstdDecoder, Encoder as ZstdEncoder};
use zstd::Encoder as ZstdEncoder;
pub trait RawDecoder<R: Read>: Read {
pub trait RawDecoder<R: BufRead>: Read {
fn get_ref(&self) -> &R;
fn get_mut(&mut self) -> &mut R;
fn into_inner(self: Box<Self>) -> R;
}
impl<R: Read> RawDecoder<R> for BzDecoder<R> {
impl<R: BufRead> RawDecoder<R> for BzDecoder<R> {
#[inline]
fn get_ref(&self) -> &R {
BzDecoder::get_ref(self)
@ -39,23 +39,6 @@ impl<R: Read> RawDecoder<R> for BzDecoder<R> {
}
}
impl<R: Read> RawDecoder<R> for ZstdDecoder<R> {
#[inline]
fn get_ref(&self) -> &R {
ZstdDecoder::get_ref(self)
}
#[inline]
fn get_mut(&mut self) -> &mut R {
ZstdDecoder::get_mut(self)
}
#[inline]
fn into_inner(self: Box<Self>) -> R {
ZstdDecoder::finish(*self)
}
}
pub trait RawEncoder<W>: AsyncWrite
where
W: AsyncWrite + Send,

View File

@ -4,30 +4,52 @@
// This software may be used and distributed according to the terms of the
// GNU General Public License version 2 or any later version.
use std::io::{self, Cursor, Write};
use std::io::{self, BufReader, Cursor, Read, Write};
use bzip2;
use futures::{Async, Poll};
use quickcheck::TestResult;
use quickcheck::{Arbitrary, Gen, TestResult};
use tokio_core::reactor::Core;
use tokio_io::AsyncWrite;
use tokio_io::io::read_to_end;
use retry::retry_write;
use ZSTD_DEFAULT_LEVEL;
use compressor::{Compressor, CompressorType};
use decompressor::Decompressor;
use membuf::MemBuf;
use metered::{MeteredRead, MeteredWrite};
quickcheck! {
fn test_bzip2_roundtrip(input: Vec<u8>) -> TestResult {
roundtrip(CompressorType::Bzip2(bzip2::Compression::Default), &input)
fn test_bzip2_roundtrip(cmprs: BzipCompression, input: Vec<u8>) -> TestResult {
roundtrip(CompressorType::Bzip2(cmprs.0), &input)
}
fn test_zstd_roundtrip(input: Vec<u8>) -> TestResult {
roundtrip(CompressorType::Zstd { level: ZSTD_DEFAULT_LEVEL }, &input)
fn test_bzip_overreading(
cmprs: BzipCompression,
compressable_input: Vec<u8>,
extra_input: Vec<u8>
) -> TestResult {
check_overreading(
CompressorType::Bzip2(cmprs.0),
compressable_input.as_slice(),
extra_input.as_slice(),
)
}
}
#[derive(Debug, Clone)]
struct BzipCompression(bzip2::Compression);
impl Arbitrary for BzipCompression {
fn arbitrary<G: Gen>(g: &mut G) -> Self {
BzipCompression(
g.choose(&[
bzip2::Compression::Fastest,
bzip2::Compression::Best,
bzip2::Compression::Default,
]).unwrap()
.clone(),
)
}
}
@ -46,12 +68,15 @@ fn roundtrip(ct: CompressorType, input: &[u8]) -> TestResult {
// Turn the MeteredWrite<Cursor> into a Cursor
let compressed_buf = compressed_buf.into_inner();
let read_buf = MeteredRead::new(MemBuf::new(32 * 1024));
let mut decoder = MeteredRead::new(Decompressor::new(read_buf, ct.decompressor_type()));
let decoder = {
let mut read_buf = BufReader::new(MeteredRead::new(MemBuf::new(32 * 1024)));
assert_matches!(decoder.get_mut().get_mut().get_mut().write_buf(compressed_buf.get_ref()),
Ok(l) if l as u64 == compressed_buf.position());
decoder.get_mut().get_mut().get_mut().mark_eof();
assert_matches!(read_buf.get_mut().get_mut().write_buf(compressed_buf.get_ref()),
Ok(l) if l as u64 == compressed_buf.position());
read_buf.get_mut().get_mut().mark_eof();
MeteredRead::new(Decompressor::new(read_buf, ct.decompressor_type()))
};
let result = Vec::with_capacity(32 * 1024);
let read_future = read_to_end(decoder, result);
@ -60,7 +85,7 @@ fn roundtrip(ct: CompressorType, input: &[u8]) -> TestResult {
let (decoder, result) = core.run(read_future).unwrap();
assert_eq!(decoder.total_thru(), input.len() as u64);
assert_eq!(
decoder.get_ref().get_ref().total_thru(),
decoder.get_ref().get_ref().get_ref().total_thru(),
compressed_buf.position()
);
assert_eq!(input, result.as_slice());
@ -122,3 +147,53 @@ fn test_fail_after_retry() {
let status = retry_write(&mut writer, &buffer);
assert_eq!(status.unwrap_err().kind(), io::ErrorKind::NotFound);
}
fn check_overreading(c_type: CompressorType, data: &[u8], extra: &[u8]) -> TestResult {
const EXTRA_SPACE: usize = 8;
let mut decompressor = {
let mut compressor = Compressor::new(Cursor::new(Vec::new()), c_type);
compressor.write_all(data).unwrap();
compressor.flush().unwrap();
let mut compressed = compressor.try_finish().unwrap().into_inner();
for u in extra {
compressed.push(*u);
}
Decompressor::new(
BufReader::new(Cursor::new(compressed)),
c_type.decompressor_type(),
)
};
{
let mut buf = vec![0u8; data.len() + extra.len() + EXTRA_SPACE];
let mut expected = Vec::new();
expected.extend_from_slice(data);
expected.extend_from_slice(vec![0u8; extra.len() + EXTRA_SPACE].as_slice());
if !(decompressor.read(buf.as_mut_slice()).unwrap() == data.len() && buf == expected) {
return TestResult::error(format!("decoding failed, buf: {:?}", buf));
}
}
{
let mut buf = vec![0u8; data.len() + extra.len() + EXTRA_SPACE];
if !(decompressor.read(buf.as_mut_slice()).unwrap() == 0 && buf == vec![0u8; buf.len()]) {
return TestResult::error(format!("detecting eof failed, buf: {:?}", buf));
}
}
{
let mut buf = Vec::new();
let mut remainder = decompressor.into_inner();
if !(remainder.read_to_end(&mut buf).unwrap() == extra.len() && buf.as_slice() == extra) {
return TestResult::error(format!("leaving remainder failed, buf: {:?}", buf));
}
}
TestResult::passed()
}

View File

@ -5,12 +5,13 @@
// GNU General Public License version 2 or any later version.
use std::collections::HashMap;
use std::io::{self, Cursor, Read};
use std::io::{self, BufReader, Cursor, Read};
use std::iter;
use std::str::{self, FromStr};
use bytes::{Bytes, BytesMut};
use futures::{Async, Stream};
use futures_ext::io::Either;
use nom::{is_alphanumeric, is_digit, ErrorKind, FindSubstring, IResult, Needed, Slice};
use slog;
use tokio_io::AsyncRead;
@ -23,7 +24,6 @@ use batch;
use errors;
use errors::*;
/// Parse an unsigned decimal integer. If it reaches the end of input, it returns Incomplete,
/// as there may be more digits following
fn digit<F: Fn(u8) -> bool>(input: &[u8], isdigit: F) -> IResult<&[u8], &[u8]> {
@ -78,12 +78,7 @@ fn ident_complete(input: &[u8]) -> IResult<&[u8], &[u8]> {
/// but I don't know if that ever happens in practice.)
named!(
param_star<HashMap<Vec<u8>, Vec<u8>>>,
do_parse!(
tag!(b"* ") >>
count: integer >> tag!(b"\n") >>
res: apply!(params, count) >>
(res)
)
do_parse!(tag!(b"* ") >> count: integer >> tag!(b"\n") >> res: apply!(params, count) >> (res))
);
/// A named parameter is a name followed by a decimal integer of the number of
@ -93,10 +88,8 @@ named!(
named!(
param_kv<HashMap<Vec<u8>, Vec<u8>>>,
do_parse!(
key: ident >> tag!(b" ") >>
len: integer >> tag!(b"\n") >>
val: take!(len) >>
(iter::once((key.to_vec(), val.to_vec())).collect())
key: ident >> tag!(b" ") >> len: integer >> tag!(b"\n") >> val: take!(len)
>> (iter::once((key.to_vec(), val.to_vec())).collect())
)
);
@ -142,11 +135,7 @@ fn notcomma(b: u8) -> bool {
named!(
batch_param_escaped<(Vec<u8>, Vec<u8>)>,
map_res!(
do_parse!(
key: take_until_and_consume1!("=") >>
val: take_while!(notcomma) >>
((key, val))
),
do_parse!(key: take_until_and_consume1!("=") >> val: take_while!(notcomma) >> ((key, val))),
|(k, v)| Ok::<_, Error>((batch::unescape(k)?, batch::unescape(v)?))
)
);
@ -224,9 +213,8 @@ fn notsemi(b: u8) -> bool {
named!(
cmd<(Vec<u8>, Vec<u8>)>,
do_parse!(
cmd: take_until_and_consume1!(" ") >>
args: take_while!(notsemi) >>
((cmd.to_vec(), args.to_vec()))
cmd: take_until_and_consume1!(" ") >> args: take_while!(notsemi)
>> ((cmd.to_vec(), args.to_vec()))
)
);
@ -290,22 +278,25 @@ fn bundle2stream(inp: &[u8]) -> IResult<&[u8], Bytes> {
// Reaching the end of the buffer just means we need more input, not that there is no
// more input. So remap EOF to WouldBlock.
#[derive(Debug)]
struct EofCursor<T>(Cursor<T>);
struct EofCursor<T>(Cursor<T>, bool);
impl<T: AsRef<[u8]>> Read for EofCursor<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.0.read(buf) {
Ok(0) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
Ok(v) => Ok(v),
Err(err) => Err(err),
if self.1 {
self.0.read(buf)
} else {
match self.0.read(buf) {
Ok(0) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
Ok(v) => Ok(v),
Err(err) => Err(err),
}
}
}
}
impl<T: AsRef<[u8]>> AsyncRead for EofCursor<T> {}
let mut cur = EofCursor(Cursor::new(inp));
{
let mut cur = {
let logger = slog::Logger::root(slog::Discard, o!());
let mut b2 = Bundle2Stream::new(&mut cur, logger);
let mut b2 = Bundle2Stream::new(BufReader::new(EofCursor(Cursor::new(inp), false)), logger);
loop {
match b2.poll() {
@ -321,9 +312,18 @@ fn bundle2stream(inp: &[u8]) -> IResult<&[u8], Bytes> {
}
}
}
b2.into_end().unwrap()
};
let (stream, rest) = inp.split_at(cur.0.position() as usize);
match cur.get_mut() {
&mut Either::A(ref mut r) => r.get_mut().get_mut().get_mut().1 = true,
&mut Either::B(ref mut r) => r.get_mut().get_mut().get_mut().get_mut().1 = true,
}
let mut x = Vec::new();
cur.read_to_end(&mut x).unwrap();
let x = inp.len() - x.len();
let (stream, rest) = inp.split_at(x);
IResult::Done(rest, Bytes::from(stream))
}
@ -334,8 +334,7 @@ fn bundle2stream(inp: &[u8]) -> IResult<&[u8], Bytes> {
fn parse_command<'a, C, F, T>(
inp: &'a [u8],
cmd: C,
parse_params: fn(&[u8], usize)
-> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
parse_params: fn(&[u8], usize) -> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
nargs: usize,
func: F,
) -> IResult<&'a [u8], T>
@ -344,9 +343,10 @@ where
C: AsRef<[u8]>,
{
let cmd = cmd.as_ref();
let res = do_parse!(inp,
tag!(cmd) >> tag!("\n") >>
p: call!(parse_params, nargs) >> (p));
let res = do_parse!(
inp,
tag!(cmd) >> tag!("\n") >> p: call!(parse_params, nargs) >> (p)
);
match res {
IResult::Done(rest, v) => {
@ -473,8 +473,7 @@ pub fn parse_request(buf: &mut BytesMut) -> Result<Option<Request>> {
/// return them as a SingleRequest::Unbundle object for later processing.
fn unbundle(
inp: &[u8],
parse_params: fn(&[u8], usize)
-> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
parse_params: fn(&[u8], usize) -> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
) -> IResult<&[u8], SingleRequest> {
// Use this as a syntactic proxy for SingleRequest::Unbundle, which works because
// SingleRequest's values are struct-like enums, and this is a struct, so the command macro is
@ -482,15 +481,14 @@ fn unbundle(
struct UnbundleCmd {
heads: Vec<String>,
}
do_parse!(inp,
do_parse!(
inp,
unbundle: command!("unbundle", UnbundleCmd, parse_params, {
heads => stringlist,
}) >>
stream: bundle2stream >>
(SingleRequest::Unbundle {
heads: unbundle.heads,
stream: stream
})
}) >> stream: bundle2stream >> (SingleRequest::Unbundle {
heads: unbundle.heads,
stream: stream,
})
)
}
@ -885,7 +883,8 @@ mod test {
#[test]
fn test_pair() {
let p = b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
let p =
b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
assert_eq!(
pair(p),
IResult::Done(&b""[..], (nodehash::NULL_HASH, nodehash::NULL_HASH))
@ -900,7 +899,8 @@ mod test {
#[test]
fn test_pairlist() {
let p = b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000 \
let p =
b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000 \
0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
assert_eq!(
pairlist(p),
@ -913,7 +913,8 @@ mod test {
)
);
let p = b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
let p =
b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
assert_eq!(
pairlist(p),
IResult::Done(&b""[..], vec![(nodehash::NULL_HASH, nodehash::NULL_HASH)])
@ -922,7 +923,8 @@ mod test {
#[test]
fn test_hashlist() {
let p = b"0000000000000000000000000000000000000000 0000000000000000000000000000000000000000 \
let p =
b"0000000000000000000000000000000000000000 0000000000000000000000000000000000000000 \
0000000000000000000000000000000000000000 0000000000000000000000000000000000000000";
assert_eq!(
hashlist(p),
@ -1104,10 +1106,11 @@ mod test_parse {
#[test]
fn test_parse_between() {
let inp = "between\n\
pairs 163\n\
1111111111111111111111111111111111111111-2222222222222222222222222222222222222222 \
3333333333333333333333333333333333333333-4444444444444444444444444444444444444444";
let inp =
"between\n\
pairs 163\n\
1111111111111111111111111111111111111111-2222222222222222222222222222222222222222 \
3333333333333333333333333333333333333333-4444444444444444444444444444444444444444";
test_parse(
inp,
Request::Single(SingleRequest::Between {
@ -1125,10 +1128,11 @@ mod test_parse {
#[test]
fn test_parse_branches() {
let inp = "branches\n\
nodes 163\n\
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222 \
3333333333333333333333333333333333333333 4444444444444444444444444444444444444444";
let inp =
"branches\n\
nodes 163\n\
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222 \
3333333333333333333333333333333333333333 4444444444444444444444444444444444444444";
test_parse(
inp,
Request::Single(SingleRequest::Branches {
@ -1153,9 +1157,10 @@ mod test_parse {
#[test]
fn test_parse_changegroup() {
let inp = "changegroup\n\
roots 81\n\
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222";
let inp =
"changegroup\n\
roots 81\n\
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222";
test_parse(
inp,
@ -1167,11 +1172,12 @@ mod test_parse {
#[test]
fn test_parse_changegroupsubset() {
let inp = "changegroupsubset\n\
heads 40\n\
1111111111111111111111111111111111111111\
bases 81\n\
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333";
let inp =
"changegroupsubset\n\
heads 40\n\
1111111111111111111111111111111111111111\
bases 81\n\
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333";
test_parse(
inp,
@ -1222,18 +1228,19 @@ mod test_parse {
);
// with arguments
let inp = "getbundle\n\
* 5\n\
heads 40\n\
1111111111111111111111111111111111111111\
common 81\n\
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333\
bundlecaps 14\n\
cap1,CAP2,cap3\
listkeys 9\n\
key1,key2\
extra 5\n\
extra";
let inp =
"getbundle\n\
* 5\n\
heads 40\n\
1111111111111111111111111111111111111111\
common 81\n\
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333\
bundlecaps 14\n\
cap1,CAP2,cap3\
listkeys 9\n\
key1,key2\
extra 5\n\
extra";
test_parse(
inp,
Request::Single(SingleRequest::Getbundle(GetbundleArgs {

View File

@ -7,12 +7,15 @@
//! Overall coordinator for parsing bundle2 streams.
use std::fmt::{self, Debug, Display, Formatter};
use std::io::{BufRead, BufReader};
use std::mem;
use futures::{Async, Poll, Stream};
use slog;
use async_compression::Decompressor;
use futures_ext::{AsyncReadExt, FramedStream, ReadLeadingBuffer, StreamWrapper};
use futures_ext::io::Either;
use tokio_io::AsyncRead;
use Bundle2Item;
@ -24,7 +27,7 @@ use stream_start::StartDecoder;
#[derive(Debug)]
pub struct Bundle2Stream<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
inner: Bundle2StreamInner,
current_stream: CurrentStream<'a, R>,
@ -38,18 +41,20 @@ struct Bundle2StreamInner {
enum CurrentStream<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
Start(FramedStream<R, StartDecoder>),
Outer(OuterStream<'a, ReadLeadingBuffer<R>>),
Inner(BoxInnerStream<'a, ReadLeadingBuffer<R>>),
Outer(OuterStream<'a, BufReader<ReadLeadingBuffer<R>>>),
Inner(BoxInnerStream<'a, BufReader<ReadLeadingBuffer<R>>>),
Invalid,
End,
End(ReadLeadingBuffer<
Either<BufReader<ReadLeadingBuffer<R>>, Decompressor<'a, BufReader<ReadLeadingBuffer<R>>>>,
>),
}
impl<'a, R> CurrentStream<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
pub fn take(&mut self) -> Self {
mem::replace(self, CurrentStream::Invalid)
@ -58,7 +63,7 @@ where
impl<'a, R> Display for CurrentStream<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
use self::CurrentStream::*;
@ -68,7 +73,7 @@ where
&Outer(_) => "outer",
&Inner(_) => "inner",
&Invalid => "invalid",
&End => "end",
&End(_) => "end",
};
write!(fmt, "{}", s)
}
@ -76,7 +81,7 @@ where
impl<'a, R> Debug for CurrentStream<'a, R>
where
R: AsyncRead + Debug + 'a,
R: AsyncRead + BufRead + Debug + 'a,
{
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
@ -86,14 +91,14 @@ where
// part_inner::BoolFuture doesn't implement Debug.
&CurrentStream::Inner(_) => write!(f, "Inner(inner_stream)"),
&CurrentStream::Invalid => write!(f, "Invalid"),
&CurrentStream::End => write!(f, "End"),
&CurrentStream::End(_) => write!(f, "End"),
}
}
}
impl<'a, R> Bundle2Stream<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
pub fn new(read: R, logger: slog::Logger) -> Bundle2Stream<'a, R> {
Bundle2Stream {
@ -108,11 +113,27 @@ where
pub fn app_errors(&self) -> &[ErrorKind] {
&self.inner.app_errors
}
pub fn into_end(
self,
) -> Option<
ReadLeadingBuffer<
Either<
BufReader<ReadLeadingBuffer<R>>,
Decompressor<'a, BufReader<ReadLeadingBuffer<R>>>,
>,
>,
> {
match self.current_stream {
CurrentStream::End(ret) => Some(ret),
_ => None,
}
}
}
impl<'a, R> Stream for Bundle2Stream<'a, R>
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
type Item = Bundle2Item;
type Error = Error;
@ -132,7 +153,7 @@ impl Bundle2StreamInner {
current_stream: CurrentStream<'a, R>,
) -> (Poll<Option<Bundle2Item>, Error>, CurrentStream<'a, R>)
where
R: AsyncRead + 'a,
R: AsyncRead + BufRead + 'a,
{
match current_stream {
CurrentStream::Start(mut stream) => {
@ -143,7 +164,11 @@ impl Bundle2StreamInner {
(Ok(Async::Ready(None)), CurrentStream::Start(stream))
}
Ok(Async::Ready(Some(start))) => {
match outer_stream(&start, stream.into_inner_leading(), &self.logger) {
match outer_stream(
&start,
BufReader::new(stream.into_inner_leading()),
&self.logger,
) {
Err(e) => {
// Can't do much if reading stream level params
// failed -- go to the invalid state.
@ -185,7 +210,10 @@ impl Bundle2StreamInner {
}
Ok(Async::Ready(Some(OuterFrame::StreamEnd))) => {
// No more parts to go.
(Ok(Async::Ready(None)), CurrentStream::End)
(
Ok(Async::Ready(None)),
CurrentStream::End(stream.into_inner_leading()),
)
}
_ => panic!("Expected a header or StreamEnd!"),
}
@ -210,7 +238,7 @@ impl Bundle2StreamInner {
Err(ErrorKind::Bundle2Decode("corrupt byte stream".into()).into()),
CurrentStream::Invalid,
),
CurrentStream::End => (Ok(Async::Ready(None)), CurrentStream::End),
CurrentStream::End(s) => (Ok(Async::Ready(None)), CurrentStream::End(s)),
}
}
}

View File

@ -8,6 +8,7 @@
#![deny(warnings)]
use std::collections::{HashMap, HashSet};
use std::io::BufRead;
use std::str;
use futures::{future, Stream};
@ -54,14 +55,14 @@ type WrappedStream<'a, T> = Map<
pub trait InnerStream<'a, T>
: Stream<Item = InnerPart, Error = Error> + BoxStreamWrapper<WrappedStream<'a, T>>
where
T: AsyncRead + 'a,
T: AsyncRead + BufRead + 'a,
{
}
impl<'a, T, U> InnerStream<'a, T> for U
where
U: Stream<Item = InnerPart, Error = Error> + BoxStreamWrapper<WrappedStream<'a, T>>,
T: AsyncRead + 'a,
T: AsyncRead + BufRead + 'a,
{
}
@ -126,7 +127,7 @@ pub fn validate_header(header: PartHeader) -> Result<Option<PartHeader>> {
}
/// Convert an OuterStream into an InnerStream using the part header.
pub fn inner_stream<'a, R: AsyncRead + 'a>(
pub fn inner_stream<'a, R: AsyncRead + BufRead + 'a>(
header: &PartHeader,
stream: OuterStream<'a, R>,
logger: &slog::Logger,

View File

@ -8,6 +8,7 @@
//! stream-level parameters (see `stream_start` for those). This parses bundle2
//! part headers and puts together chunks for inner codecs to parse.
use std::io::BufRead;
use std::mem;
use ascii::AsciiString;
@ -25,7 +26,7 @@ use part_inner::validate_header;
use types::StreamHeader;
use utils::{get_decompressor_type, BytesExt};
pub fn outer_stream<'a, R: AsyncRead>(
pub fn outer_stream<'a, R: AsyncRead + BufRead>(
stream_header: &StreamHeader,
r: R,
logger: &slog::Logger,
@ -139,12 +140,12 @@ impl OuterDecoder {
return (Ok(None), OuterState::Header);
}
let _ = buf.split_to(4);
if header_len == 0 {
// A zero-length header indicates that the stream has ended.
return (Ok(Some(OuterFrame::StreamEnd)), OuterState::StreamEnd);
}
let _ = buf.split_to(4);
let part_header = Self::decode_header(buf.split_to(header_len).freeze());
if let Err(e) = part_header {
let next = match e.downcast::<ErrorKind>() {

View File

@ -7,7 +7,7 @@
use std::collections::HashMap;
use std::convert::From;
use std::fmt::Debug;
use std::io::{self, Cursor};
use std::io::{self, BufRead, BufReader, Cursor};
use std::str::FromStr;
use futures::stream::Stream;
@ -16,7 +16,7 @@ use slog_term;
use tokio_core::reactor::Core;
use tokio_io::AsyncRead;
use async_compression::{Bzip2Compression, CompressorType, ZSTD_DEFAULT_LEVEL};
use async_compression::{Bzip2Compression, CompressorType};
use async_compression::membuf::MemBuf;
use futures_ext::StreamExt;
use mercurial_types::{MPath, NodeHash, RepoPath, NULL_HASH};
@ -72,19 +72,12 @@ fn parse_uncompressed(read_ops: PartialWithErrors<GenWouldBlock>) {
#[test]
fn test_parse_unknown_compression() {
let mut core = Core::new().unwrap();
let bundle2_buf = MemBuf::from(Vec::from(UNKNOWN_COMPRESSION_BUNDLE2));
let bundle2_buf = BufReader::new(MemBuf::from(Vec::from(UNKNOWN_COMPRESSION_BUNDLE2)));
let outer_stream_err = parse_stream_start(&mut core, bundle2_buf, Some("IL")).unwrap_err();
assert_matches!(outer_stream_err.downcast::<ErrorKind>().unwrap(),
ErrorKind::Bundle2Decode(ref msg) if msg == "unknown compression 'IL'");
}
#[test]
fn test_empty_bundle_roundtrip_zstd() {
empty_bundle_roundtrip(Some(CompressorType::Zstd {
level: ZSTD_DEFAULT_LEVEL,
}));
}
#[test]
fn test_empty_bundle_roundtrip_bzip() {
empty_bundle_roundtrip(Some(CompressorType::Bzip2(Bzip2Compression::Default)));
@ -133,13 +126,6 @@ fn empty_bundle_roundtrip(ct: Option<CompressorType>) {
assert!(item.is_none());
}
#[test]
fn test_unknown_part_zstd() {
unknown_part(Some(CompressorType::Zstd {
level: ZSTD_DEFAULT_LEVEL,
}));
}
#[test]
fn test_unknown_part_bzip() {
unknown_part(Some(CompressorType::Bzip2(Bzip2Compression::Default)));
@ -200,7 +186,7 @@ fn parse_bundle(
let mut core = Core::new().unwrap();
let bundle2_buf = MemBuf::from(Vec::from(input));
let partial_read = PartialAsyncRead::new(bundle2_buf, read_ops);
let partial_read = BufReader::new(PartialAsyncRead::new(bundle2_buf, read_ops));
let stream = parse_stream_start(&mut core, partial_read, compression).unwrap();
let (res, stream) = core.next_stream(stream);
@ -222,7 +208,7 @@ fn parse_bundle(
assert!(res.is_none());
}
fn verify_cg2<'a, R: AsyncRead + 'a>(
fn verify_cg2<'a, R: AsyncRead + BufRead + 'a>(
core: &mut Core,
stream: Bundle2Stream<'a, R>,
) -> Bundle2Stream<'a, R> {
@ -347,7 +333,7 @@ fn parse_wirepack(read_ops: PartialWithErrors<GenWouldBlock>) {
let mut core = Core::new().unwrap();
let cursor = Cursor::new(WIREPACK_BUNDLE2);
let partial_read = PartialAsyncRead::new(cursor, read_ops);
let partial_read = BufReader::new(PartialAsyncRead::new(cursor, read_ops));
let stream = parse_stream_start(&mut core, partial_read, None).unwrap();
let collect_fut = stream.collect_no_consume();
@ -479,7 +465,7 @@ fn path(bytes: &[u8]) -> MPath {
MPath::new(bytes).unwrap()
}
fn parse_stream_start<'a, R: AsyncRead + 'a>(
fn parse_stream_start<'a, R: AsyncRead + BufRead + 'a>(
core: &mut Core,
reader: R,
compression: Option<&str>,
@ -512,7 +498,7 @@ fn make_root_logger() -> Logger {
Logger::root(slog_term::FullFormat::new(plain).build().fuse(), o!())
}
fn next_cg2_part<'a, R: AsyncRead + 'a>(
fn next_cg2_part<'a, R: AsyncRead + BufRead + 'a>(
core: &mut Core,
stream: Bundle2Stream<'a, R>,
) -> (changegroup::Part, Bundle2Stream<'a, R>) {