mirror of
https://github.com/facebook/sapling.git
synced 2024-10-10 16:57:49 +03:00
async-compression: make sure decompression is framed and won't overread its input
Summary: - Added a test checking that decompression does not overread its input - As a result removed ZSTD decompression, which is overreading input - Changes in decompression code propagated usage of BufRead in bundle2 parsing code - As a result a bug in OuterDecoder was found where the buf was not consumed if the header value is "0" Reviewed By: jsgf Differential Revision: D6557440 fbshipit-source-id: 89a9f4c8790017c5b86d28d467e45f687d7323f6
This commit is contained in:
parent
35e3c19f1c
commit
f59d6a1c8b
@ -7,17 +7,16 @@
|
||||
//! Non-blocking, buffered compression and decompression
|
||||
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::io::{self, Read};
|
||||
use std::io::{self, BufRead, Read};
|
||||
|
||||
use bzip2::read::BzDecoder;
|
||||
use bzip2::bufread::BzDecoder;
|
||||
use tokio_io::AsyncRead;
|
||||
use zstd::Decoder as ZstdDecoder;
|
||||
|
||||
use raw::RawDecoder;
|
||||
|
||||
pub struct Decompressor<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
d_type: DecompressorType,
|
||||
inner: Box<RawDecoder<R> + 'a>,
|
||||
@ -32,7 +31,7 @@ pub enum DecompressorType {
|
||||
|
||||
impl<'a, R> Decompressor<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
pub fn new(r: R, dt: DecompressorType) -> Self {
|
||||
Decompressor {
|
||||
@ -43,8 +42,9 @@ where
|
||||
// https://github.com/alexcrichton/flate2-rs/issues/62 to be
|
||||
// fixed
|
||||
DecompressorType::Gzip => unimplemented!(),
|
||||
// ZstdDecoder::new() should only fail on OOM, so just call unwrap here.
|
||||
DecompressorType::Zstd => Box::new(ZstdDecoder::new(r).unwrap()),
|
||||
// TODO: The zstd crate is not safe for decompressing Read input, because it is
|
||||
// overconsuming it
|
||||
DecompressorType::Zstd => unimplemented!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -65,16 +65,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R: AsyncRead + 'a> Read for Decompressor<'a, R> {
|
||||
impl<'a, R: AsyncRead + BufRead + 'a> Read for Decompressor<'a, R> {
|
||||
#[inline]
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.inner.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R: AsyncRead + 'a> AsyncRead for Decompressor<'a, R> {}
|
||||
impl<'a, R: AsyncRead + BufRead + 'a> AsyncRead for Decompressor<'a, R> {}
|
||||
|
||||
impl<'a, R: AsyncRead + 'a> Debug for Decompressor<'a, R> {
|
||||
impl<'a, R: AsyncRead + BufRead + 'a> Debug for Decompressor<'a, R> {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
f.debug_struct("Decompressor")
|
||||
.field("decoder_type", &self.d_type)
|
||||
|
@ -33,5 +33,4 @@ mod test;
|
||||
pub use compressor::{Compressor, CompressorType};
|
||||
pub use decompressor::{Decompressor, DecompressorType};
|
||||
|
||||
pub const ZSTD_DEFAULT_LEVEL: i32 = 1;
|
||||
pub use bzip2::Compression as Bzip2Compression;
|
||||
|
@ -6,23 +6,23 @@
|
||||
|
||||
//! Raw upstream decoders, plus a uniform interface for accessing them.
|
||||
|
||||
use std::io::{self, Read, Write};
|
||||
use std::io::{self, BufRead, Read, Write};
|
||||
use std::result;
|
||||
|
||||
use futures::Poll;
|
||||
use tokio_io::AsyncWrite;
|
||||
|
||||
use bzip2::read::BzDecoder;
|
||||
use bzip2::bufread::BzDecoder;
|
||||
use bzip2::write::BzEncoder;
|
||||
use zstd::{Decoder as ZstdDecoder, Encoder as ZstdEncoder};
|
||||
use zstd::Encoder as ZstdEncoder;
|
||||
|
||||
pub trait RawDecoder<R: Read>: Read {
|
||||
pub trait RawDecoder<R: BufRead>: Read {
|
||||
fn get_ref(&self) -> &R;
|
||||
fn get_mut(&mut self) -> &mut R;
|
||||
fn into_inner(self: Box<Self>) -> R;
|
||||
}
|
||||
|
||||
impl<R: Read> RawDecoder<R> for BzDecoder<R> {
|
||||
impl<R: BufRead> RawDecoder<R> for BzDecoder<R> {
|
||||
#[inline]
|
||||
fn get_ref(&self) -> &R {
|
||||
BzDecoder::get_ref(self)
|
||||
@ -39,23 +39,6 @@ impl<R: Read> RawDecoder<R> for BzDecoder<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Read> RawDecoder<R> for ZstdDecoder<R> {
|
||||
#[inline]
|
||||
fn get_ref(&self) -> &R {
|
||||
ZstdDecoder::get_ref(self)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_mut(&mut self) -> &mut R {
|
||||
ZstdDecoder::get_mut(self)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn into_inner(self: Box<Self>) -> R {
|
||||
ZstdDecoder::finish(*self)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait RawEncoder<W>: AsyncWrite
|
||||
where
|
||||
W: AsyncWrite + Send,
|
||||
|
@ -4,30 +4,52 @@
|
||||
// This software may be used and distributed according to the terms of the
|
||||
// GNU General Public License version 2 or any later version.
|
||||
|
||||
use std::io::{self, Cursor, Write};
|
||||
use std::io::{self, BufReader, Cursor, Read, Write};
|
||||
|
||||
use bzip2;
|
||||
use futures::{Async, Poll};
|
||||
use quickcheck::TestResult;
|
||||
use quickcheck::{Arbitrary, Gen, TestResult};
|
||||
use tokio_core::reactor::Core;
|
||||
use tokio_io::AsyncWrite;
|
||||
use tokio_io::io::read_to_end;
|
||||
|
||||
use retry::retry_write;
|
||||
|
||||
use ZSTD_DEFAULT_LEVEL;
|
||||
use compressor::{Compressor, CompressorType};
|
||||
use decompressor::Decompressor;
|
||||
use membuf::MemBuf;
|
||||
use metered::{MeteredRead, MeteredWrite};
|
||||
|
||||
quickcheck! {
|
||||
fn test_bzip2_roundtrip(input: Vec<u8>) -> TestResult {
|
||||
roundtrip(CompressorType::Bzip2(bzip2::Compression::Default), &input)
|
||||
fn test_bzip2_roundtrip(cmprs: BzipCompression, input: Vec<u8>) -> TestResult {
|
||||
roundtrip(CompressorType::Bzip2(cmprs.0), &input)
|
||||
}
|
||||
|
||||
fn test_zstd_roundtrip(input: Vec<u8>) -> TestResult {
|
||||
roundtrip(CompressorType::Zstd { level: ZSTD_DEFAULT_LEVEL }, &input)
|
||||
fn test_bzip_overreading(
|
||||
cmprs: BzipCompression,
|
||||
compressable_input: Vec<u8>,
|
||||
extra_input: Vec<u8>
|
||||
) -> TestResult {
|
||||
check_overreading(
|
||||
CompressorType::Bzip2(cmprs.0),
|
||||
compressable_input.as_slice(),
|
||||
extra_input.as_slice(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BzipCompression(bzip2::Compression);
|
||||
impl Arbitrary for BzipCompression {
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
BzipCompression(
|
||||
g.choose(&[
|
||||
bzip2::Compression::Fastest,
|
||||
bzip2::Compression::Best,
|
||||
bzip2::Compression::Default,
|
||||
]).unwrap()
|
||||
.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,12 +68,15 @@ fn roundtrip(ct: CompressorType, input: &[u8]) -> TestResult {
|
||||
// Turn the MeteredWrite<Cursor> into a Cursor
|
||||
let compressed_buf = compressed_buf.into_inner();
|
||||
|
||||
let read_buf = MeteredRead::new(MemBuf::new(32 * 1024));
|
||||
let mut decoder = MeteredRead::new(Decompressor::new(read_buf, ct.decompressor_type()));
|
||||
let decoder = {
|
||||
let mut read_buf = BufReader::new(MeteredRead::new(MemBuf::new(32 * 1024)));
|
||||
|
||||
assert_matches!(decoder.get_mut().get_mut().get_mut().write_buf(compressed_buf.get_ref()),
|
||||
Ok(l) if l as u64 == compressed_buf.position());
|
||||
decoder.get_mut().get_mut().get_mut().mark_eof();
|
||||
assert_matches!(read_buf.get_mut().get_mut().write_buf(compressed_buf.get_ref()),
|
||||
Ok(l) if l as u64 == compressed_buf.position());
|
||||
read_buf.get_mut().get_mut().mark_eof();
|
||||
|
||||
MeteredRead::new(Decompressor::new(read_buf, ct.decompressor_type()))
|
||||
};
|
||||
|
||||
let result = Vec::with_capacity(32 * 1024);
|
||||
let read_future = read_to_end(decoder, result);
|
||||
@ -60,7 +85,7 @@ fn roundtrip(ct: CompressorType, input: &[u8]) -> TestResult {
|
||||
let (decoder, result) = core.run(read_future).unwrap();
|
||||
assert_eq!(decoder.total_thru(), input.len() as u64);
|
||||
assert_eq!(
|
||||
decoder.get_ref().get_ref().total_thru(),
|
||||
decoder.get_ref().get_ref().get_ref().total_thru(),
|
||||
compressed_buf.position()
|
||||
);
|
||||
assert_eq!(input, result.as_slice());
|
||||
@ -122,3 +147,53 @@ fn test_fail_after_retry() {
|
||||
let status = retry_write(&mut writer, &buffer);
|
||||
assert_eq!(status.unwrap_err().kind(), io::ErrorKind::NotFound);
|
||||
}
|
||||
|
||||
fn check_overreading(c_type: CompressorType, data: &[u8], extra: &[u8]) -> TestResult {
|
||||
const EXTRA_SPACE: usize = 8;
|
||||
|
||||
let mut decompressor = {
|
||||
let mut compressor = Compressor::new(Cursor::new(Vec::new()), c_type);
|
||||
compressor.write_all(data).unwrap();
|
||||
compressor.flush().unwrap();
|
||||
let mut compressed = compressor.try_finish().unwrap().into_inner();
|
||||
|
||||
for u in extra {
|
||||
compressed.push(*u);
|
||||
}
|
||||
|
||||
Decompressor::new(
|
||||
BufReader::new(Cursor::new(compressed)),
|
||||
c_type.decompressor_type(),
|
||||
)
|
||||
};
|
||||
|
||||
{
|
||||
let mut buf = vec![0u8; data.len() + extra.len() + EXTRA_SPACE];
|
||||
let mut expected = Vec::new();
|
||||
expected.extend_from_slice(data);
|
||||
expected.extend_from_slice(vec![0u8; extra.len() + EXTRA_SPACE].as_slice());
|
||||
|
||||
if !(decompressor.read(buf.as_mut_slice()).unwrap() == data.len() && buf == expected) {
|
||||
return TestResult::error(format!("decoding failed, buf: {:?}", buf));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut buf = vec![0u8; data.len() + extra.len() + EXTRA_SPACE];
|
||||
|
||||
if !(decompressor.read(buf.as_mut_slice()).unwrap() == 0 && buf == vec![0u8; buf.len()]) {
|
||||
return TestResult::error(format!("detecting eof failed, buf: {:?}", buf));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut buf = Vec::new();
|
||||
|
||||
let mut remainder = decompressor.into_inner();
|
||||
if !(remainder.read_to_end(&mut buf).unwrap() == extra.len() && buf.as_slice() == extra) {
|
||||
return TestResult::error(format!("leaving remainder failed, buf: {:?}", buf));
|
||||
}
|
||||
}
|
||||
|
||||
TestResult::passed()
|
||||
}
|
||||
|
@ -5,12 +5,13 @@
|
||||
// GNU General Public License version 2 or any later version.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io::{self, Cursor, Read};
|
||||
use std::io::{self, BufReader, Cursor, Read};
|
||||
use std::iter;
|
||||
use std::str::{self, FromStr};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{Async, Stream};
|
||||
use futures_ext::io::Either;
|
||||
use nom::{is_alphanumeric, is_digit, ErrorKind, FindSubstring, IResult, Needed, Slice};
|
||||
use slog;
|
||||
use tokio_io::AsyncRead;
|
||||
@ -23,7 +24,6 @@ use batch;
|
||||
use errors;
|
||||
use errors::*;
|
||||
|
||||
|
||||
/// Parse an unsigned decimal integer. If it reaches the end of input, it returns Incomplete,
|
||||
/// as there may be more digits following
|
||||
fn digit<F: Fn(u8) -> bool>(input: &[u8], isdigit: F) -> IResult<&[u8], &[u8]> {
|
||||
@ -78,12 +78,7 @@ fn ident_complete(input: &[u8]) -> IResult<&[u8], &[u8]> {
|
||||
/// but I don't know if that ever happens in practice.)
|
||||
named!(
|
||||
param_star<HashMap<Vec<u8>, Vec<u8>>>,
|
||||
do_parse!(
|
||||
tag!(b"* ") >>
|
||||
count: integer >> tag!(b"\n") >>
|
||||
res: apply!(params, count) >>
|
||||
(res)
|
||||
)
|
||||
do_parse!(tag!(b"* ") >> count: integer >> tag!(b"\n") >> res: apply!(params, count) >> (res))
|
||||
);
|
||||
|
||||
/// A named parameter is a name followed by a decimal integer of the number of
|
||||
@ -93,10 +88,8 @@ named!(
|
||||
named!(
|
||||
param_kv<HashMap<Vec<u8>, Vec<u8>>>,
|
||||
do_parse!(
|
||||
key: ident >> tag!(b" ") >>
|
||||
len: integer >> tag!(b"\n") >>
|
||||
val: take!(len) >>
|
||||
(iter::once((key.to_vec(), val.to_vec())).collect())
|
||||
key: ident >> tag!(b" ") >> len: integer >> tag!(b"\n") >> val: take!(len)
|
||||
>> (iter::once((key.to_vec(), val.to_vec())).collect())
|
||||
)
|
||||
);
|
||||
|
||||
@ -142,11 +135,7 @@ fn notcomma(b: u8) -> bool {
|
||||
named!(
|
||||
batch_param_escaped<(Vec<u8>, Vec<u8>)>,
|
||||
map_res!(
|
||||
do_parse!(
|
||||
key: take_until_and_consume1!("=") >>
|
||||
val: take_while!(notcomma) >>
|
||||
((key, val))
|
||||
),
|
||||
do_parse!(key: take_until_and_consume1!("=") >> val: take_while!(notcomma) >> ((key, val))),
|
||||
|(k, v)| Ok::<_, Error>((batch::unescape(k)?, batch::unescape(v)?))
|
||||
)
|
||||
);
|
||||
@ -224,9 +213,8 @@ fn notsemi(b: u8) -> bool {
|
||||
named!(
|
||||
cmd<(Vec<u8>, Vec<u8>)>,
|
||||
do_parse!(
|
||||
cmd: take_until_and_consume1!(" ") >>
|
||||
args: take_while!(notsemi) >>
|
||||
((cmd.to_vec(), args.to_vec()))
|
||||
cmd: take_until_and_consume1!(" ") >> args: take_while!(notsemi)
|
||||
>> ((cmd.to_vec(), args.to_vec()))
|
||||
)
|
||||
);
|
||||
|
||||
@ -290,22 +278,25 @@ fn bundle2stream(inp: &[u8]) -> IResult<&[u8], Bytes> {
|
||||
// Reaching the end of the buffer just means we need more input, not that there is no
|
||||
// more input. So remap EOF to WouldBlock.
|
||||
#[derive(Debug)]
|
||||
struct EofCursor<T>(Cursor<T>);
|
||||
struct EofCursor<T>(Cursor<T>, bool);
|
||||
impl<T: AsRef<[u8]>> Read for EofCursor<T> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self.0.read(buf) {
|
||||
Ok(0) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
Ok(v) => Ok(v),
|
||||
Err(err) => Err(err),
|
||||
if self.1 {
|
||||
self.0.read(buf)
|
||||
} else {
|
||||
match self.0.read(buf) {
|
||||
Ok(0) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
Ok(v) => Ok(v),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T: AsRef<[u8]>> AsyncRead for EofCursor<T> {}
|
||||
|
||||
let mut cur = EofCursor(Cursor::new(inp));
|
||||
{
|
||||
let mut cur = {
|
||||
let logger = slog::Logger::root(slog::Discard, o!());
|
||||
let mut b2 = Bundle2Stream::new(&mut cur, logger);
|
||||
let mut b2 = Bundle2Stream::new(BufReader::new(EofCursor(Cursor::new(inp), false)), logger);
|
||||
|
||||
loop {
|
||||
match b2.poll() {
|
||||
@ -321,9 +312,18 @@ fn bundle2stream(inp: &[u8]) -> IResult<&[u8], Bytes> {
|
||||
}
|
||||
}
|
||||
}
|
||||
b2.into_end().unwrap()
|
||||
};
|
||||
|
||||
let (stream, rest) = inp.split_at(cur.0.position() as usize);
|
||||
match cur.get_mut() {
|
||||
&mut Either::A(ref mut r) => r.get_mut().get_mut().get_mut().1 = true,
|
||||
&mut Either::B(ref mut r) => r.get_mut().get_mut().get_mut().get_mut().1 = true,
|
||||
}
|
||||
|
||||
let mut x = Vec::new();
|
||||
cur.read_to_end(&mut x).unwrap();
|
||||
let x = inp.len() - x.len();
|
||||
let (stream, rest) = inp.split_at(x);
|
||||
IResult::Done(rest, Bytes::from(stream))
|
||||
}
|
||||
|
||||
@ -334,8 +334,7 @@ fn bundle2stream(inp: &[u8]) -> IResult<&[u8], Bytes> {
|
||||
fn parse_command<'a, C, F, T>(
|
||||
inp: &'a [u8],
|
||||
cmd: C,
|
||||
parse_params: fn(&[u8], usize)
|
||||
-> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
|
||||
parse_params: fn(&[u8], usize) -> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
|
||||
nargs: usize,
|
||||
func: F,
|
||||
) -> IResult<&'a [u8], T>
|
||||
@ -344,9 +343,10 @@ where
|
||||
C: AsRef<[u8]>,
|
||||
{
|
||||
let cmd = cmd.as_ref();
|
||||
let res = do_parse!(inp,
|
||||
tag!(cmd) >> tag!("\n") >>
|
||||
p: call!(parse_params, nargs) >> (p));
|
||||
let res = do_parse!(
|
||||
inp,
|
||||
tag!(cmd) >> tag!("\n") >> p: call!(parse_params, nargs) >> (p)
|
||||
);
|
||||
|
||||
match res {
|
||||
IResult::Done(rest, v) => {
|
||||
@ -473,8 +473,7 @@ pub fn parse_request(buf: &mut BytesMut) -> Result<Option<Request>> {
|
||||
/// return them as a SingleRequest::Unbundle object for later processing.
|
||||
fn unbundle(
|
||||
inp: &[u8],
|
||||
parse_params: fn(&[u8], usize)
|
||||
-> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
|
||||
parse_params: fn(&[u8], usize) -> IResult<&[u8], HashMap<Vec<u8>, Vec<u8>>>,
|
||||
) -> IResult<&[u8], SingleRequest> {
|
||||
// Use this as a syntactic proxy for SingleRequest::Unbundle, which works because
|
||||
// SingleRequest's values are struct-like enums, and this is a struct, so the command macro is
|
||||
@ -482,15 +481,14 @@ fn unbundle(
|
||||
struct UnbundleCmd {
|
||||
heads: Vec<String>,
|
||||
}
|
||||
do_parse!(inp,
|
||||
do_parse!(
|
||||
inp,
|
||||
unbundle: command!("unbundle", UnbundleCmd, parse_params, {
|
||||
heads => stringlist,
|
||||
}) >>
|
||||
stream: bundle2stream >>
|
||||
(SingleRequest::Unbundle {
|
||||
heads: unbundle.heads,
|
||||
stream: stream
|
||||
})
|
||||
}) >> stream: bundle2stream >> (SingleRequest::Unbundle {
|
||||
heads: unbundle.heads,
|
||||
stream: stream,
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
@ -885,7 +883,8 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_pair() {
|
||||
let p = b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
|
||||
let p =
|
||||
b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
|
||||
assert_eq!(
|
||||
pair(p),
|
||||
IResult::Done(&b""[..], (nodehash::NULL_HASH, nodehash::NULL_HASH))
|
||||
@ -900,7 +899,8 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_pairlist() {
|
||||
let p = b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000 \
|
||||
let p =
|
||||
b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000 \
|
||||
0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
|
||||
assert_eq!(
|
||||
pairlist(p),
|
||||
@ -913,7 +913,8 @@ mod test {
|
||||
)
|
||||
);
|
||||
|
||||
let p = b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
|
||||
let p =
|
||||
b"0000000000000000000000000000000000000000-0000000000000000000000000000000000000000";
|
||||
assert_eq!(
|
||||
pairlist(p),
|
||||
IResult::Done(&b""[..], vec![(nodehash::NULL_HASH, nodehash::NULL_HASH)])
|
||||
@ -922,7 +923,8 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_hashlist() {
|
||||
let p = b"0000000000000000000000000000000000000000 0000000000000000000000000000000000000000 \
|
||||
let p =
|
||||
b"0000000000000000000000000000000000000000 0000000000000000000000000000000000000000 \
|
||||
0000000000000000000000000000000000000000 0000000000000000000000000000000000000000";
|
||||
assert_eq!(
|
||||
hashlist(p),
|
||||
@ -1104,10 +1106,11 @@ mod test_parse {
|
||||
|
||||
#[test]
|
||||
fn test_parse_between() {
|
||||
let inp = "between\n\
|
||||
pairs 163\n\
|
||||
1111111111111111111111111111111111111111-2222222222222222222222222222222222222222 \
|
||||
3333333333333333333333333333333333333333-4444444444444444444444444444444444444444";
|
||||
let inp =
|
||||
"between\n\
|
||||
pairs 163\n\
|
||||
1111111111111111111111111111111111111111-2222222222222222222222222222222222222222 \
|
||||
3333333333333333333333333333333333333333-4444444444444444444444444444444444444444";
|
||||
test_parse(
|
||||
inp,
|
||||
Request::Single(SingleRequest::Between {
|
||||
@ -1125,10 +1128,11 @@ mod test_parse {
|
||||
|
||||
#[test]
|
||||
fn test_parse_branches() {
|
||||
let inp = "branches\n\
|
||||
nodes 163\n\
|
||||
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222 \
|
||||
3333333333333333333333333333333333333333 4444444444444444444444444444444444444444";
|
||||
let inp =
|
||||
"branches\n\
|
||||
nodes 163\n\
|
||||
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222 \
|
||||
3333333333333333333333333333333333333333 4444444444444444444444444444444444444444";
|
||||
test_parse(
|
||||
inp,
|
||||
Request::Single(SingleRequest::Branches {
|
||||
@ -1153,9 +1157,10 @@ mod test_parse {
|
||||
|
||||
#[test]
|
||||
fn test_parse_changegroup() {
|
||||
let inp = "changegroup\n\
|
||||
roots 81\n\
|
||||
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222";
|
||||
let inp =
|
||||
"changegroup\n\
|
||||
roots 81\n\
|
||||
1111111111111111111111111111111111111111 2222222222222222222222222222222222222222";
|
||||
|
||||
test_parse(
|
||||
inp,
|
||||
@ -1167,11 +1172,12 @@ mod test_parse {
|
||||
|
||||
#[test]
|
||||
fn test_parse_changegroupsubset() {
|
||||
let inp = "changegroupsubset\n\
|
||||
heads 40\n\
|
||||
1111111111111111111111111111111111111111\
|
||||
bases 81\n\
|
||||
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333";
|
||||
let inp =
|
||||
"changegroupsubset\n\
|
||||
heads 40\n\
|
||||
1111111111111111111111111111111111111111\
|
||||
bases 81\n\
|
||||
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333";
|
||||
|
||||
test_parse(
|
||||
inp,
|
||||
@ -1222,18 +1228,19 @@ mod test_parse {
|
||||
);
|
||||
|
||||
// with arguments
|
||||
let inp = "getbundle\n\
|
||||
* 5\n\
|
||||
heads 40\n\
|
||||
1111111111111111111111111111111111111111\
|
||||
common 81\n\
|
||||
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333\
|
||||
bundlecaps 14\n\
|
||||
cap1,CAP2,cap3\
|
||||
listkeys 9\n\
|
||||
key1,key2\
|
||||
extra 5\n\
|
||||
extra";
|
||||
let inp =
|
||||
"getbundle\n\
|
||||
* 5\n\
|
||||
heads 40\n\
|
||||
1111111111111111111111111111111111111111\
|
||||
common 81\n\
|
||||
2222222222222222222222222222222222222222 3333333333333333333333333333333333333333\
|
||||
bundlecaps 14\n\
|
||||
cap1,CAP2,cap3\
|
||||
listkeys 9\n\
|
||||
key1,key2\
|
||||
extra 5\n\
|
||||
extra";
|
||||
test_parse(
|
||||
inp,
|
||||
Request::Single(SingleRequest::Getbundle(GetbundleArgs {
|
||||
|
@ -7,12 +7,15 @@
|
||||
//! Overall coordinator for parsing bundle2 streams.
|
||||
|
||||
use std::fmt::{self, Debug, Display, Formatter};
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::mem;
|
||||
|
||||
use futures::{Async, Poll, Stream};
|
||||
use slog;
|
||||
|
||||
use async_compression::Decompressor;
|
||||
use futures_ext::{AsyncReadExt, FramedStream, ReadLeadingBuffer, StreamWrapper};
|
||||
use futures_ext::io::Either;
|
||||
use tokio_io::AsyncRead;
|
||||
|
||||
use Bundle2Item;
|
||||
@ -24,7 +27,7 @@ use stream_start::StartDecoder;
|
||||
#[derive(Debug)]
|
||||
pub struct Bundle2Stream<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
inner: Bundle2StreamInner,
|
||||
current_stream: CurrentStream<'a, R>,
|
||||
@ -38,18 +41,20 @@ struct Bundle2StreamInner {
|
||||
|
||||
enum CurrentStream<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
Start(FramedStream<R, StartDecoder>),
|
||||
Outer(OuterStream<'a, ReadLeadingBuffer<R>>),
|
||||
Inner(BoxInnerStream<'a, ReadLeadingBuffer<R>>),
|
||||
Outer(OuterStream<'a, BufReader<ReadLeadingBuffer<R>>>),
|
||||
Inner(BoxInnerStream<'a, BufReader<ReadLeadingBuffer<R>>>),
|
||||
Invalid,
|
||||
End,
|
||||
End(ReadLeadingBuffer<
|
||||
Either<BufReader<ReadLeadingBuffer<R>>, Decompressor<'a, BufReader<ReadLeadingBuffer<R>>>>,
|
||||
>),
|
||||
}
|
||||
|
||||
impl<'a, R> CurrentStream<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
pub fn take(&mut self) -> Self {
|
||||
mem::replace(self, CurrentStream::Invalid)
|
||||
@ -58,7 +63,7 @@ where
|
||||
|
||||
impl<'a, R> Display for CurrentStream<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
use self::CurrentStream::*;
|
||||
@ -68,7 +73,7 @@ where
|
||||
&Outer(_) => "outer",
|
||||
&Inner(_) => "inner",
|
||||
&Invalid => "invalid",
|
||||
&End => "end",
|
||||
&End(_) => "end",
|
||||
};
|
||||
write!(fmt, "{}", s)
|
||||
}
|
||||
@ -76,7 +81,7 @@ where
|
||||
|
||||
impl<'a, R> Debug for CurrentStream<'a, R>
|
||||
where
|
||||
R: AsyncRead + Debug + 'a,
|
||||
R: AsyncRead + BufRead + Debug + 'a,
|
||||
{
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match self {
|
||||
@ -86,14 +91,14 @@ where
|
||||
// part_inner::BoolFuture doesn't implement Debug.
|
||||
&CurrentStream::Inner(_) => write!(f, "Inner(inner_stream)"),
|
||||
&CurrentStream::Invalid => write!(f, "Invalid"),
|
||||
&CurrentStream::End => write!(f, "End"),
|
||||
&CurrentStream::End(_) => write!(f, "End"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R> Bundle2Stream<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
pub fn new(read: R, logger: slog::Logger) -> Bundle2Stream<'a, R> {
|
||||
Bundle2Stream {
|
||||
@ -108,11 +113,27 @@ where
|
||||
pub fn app_errors(&self) -> &[ErrorKind] {
|
||||
&self.inner.app_errors
|
||||
}
|
||||
|
||||
pub fn into_end(
|
||||
self,
|
||||
) -> Option<
|
||||
ReadLeadingBuffer<
|
||||
Either<
|
||||
BufReader<ReadLeadingBuffer<R>>,
|
||||
Decompressor<'a, BufReader<ReadLeadingBuffer<R>>>,
|
||||
>,
|
||||
>,
|
||||
> {
|
||||
match self.current_stream {
|
||||
CurrentStream::End(ret) => Some(ret),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R> Stream for Bundle2Stream<'a, R>
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
type Item = Bundle2Item;
|
||||
type Error = Error;
|
||||
@ -132,7 +153,7 @@ impl Bundle2StreamInner {
|
||||
current_stream: CurrentStream<'a, R>,
|
||||
) -> (Poll<Option<Bundle2Item>, Error>, CurrentStream<'a, R>)
|
||||
where
|
||||
R: AsyncRead + 'a,
|
||||
R: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
match current_stream {
|
||||
CurrentStream::Start(mut stream) => {
|
||||
@ -143,7 +164,11 @@ impl Bundle2StreamInner {
|
||||
(Ok(Async::Ready(None)), CurrentStream::Start(stream))
|
||||
}
|
||||
Ok(Async::Ready(Some(start))) => {
|
||||
match outer_stream(&start, stream.into_inner_leading(), &self.logger) {
|
||||
match outer_stream(
|
||||
&start,
|
||||
BufReader::new(stream.into_inner_leading()),
|
||||
&self.logger,
|
||||
) {
|
||||
Err(e) => {
|
||||
// Can't do much if reading stream level params
|
||||
// failed -- go to the invalid state.
|
||||
@ -185,7 +210,10 @@ impl Bundle2StreamInner {
|
||||
}
|
||||
Ok(Async::Ready(Some(OuterFrame::StreamEnd))) => {
|
||||
// No more parts to go.
|
||||
(Ok(Async::Ready(None)), CurrentStream::End)
|
||||
(
|
||||
Ok(Async::Ready(None)),
|
||||
CurrentStream::End(stream.into_inner_leading()),
|
||||
)
|
||||
}
|
||||
_ => panic!("Expected a header or StreamEnd!"),
|
||||
}
|
||||
@ -210,7 +238,7 @@ impl Bundle2StreamInner {
|
||||
Err(ErrorKind::Bundle2Decode("corrupt byte stream".into()).into()),
|
||||
CurrentStream::Invalid,
|
||||
),
|
||||
CurrentStream::End => (Ok(Async::Ready(None)), CurrentStream::End),
|
||||
CurrentStream::End(s) => (Ok(Async::Ready(None)), CurrentStream::End(s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
#![deny(warnings)]
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io::BufRead;
|
||||
use std::str;
|
||||
|
||||
use futures::{future, Stream};
|
||||
@ -54,14 +55,14 @@ type WrappedStream<'a, T> = Map<
|
||||
pub trait InnerStream<'a, T>
|
||||
: Stream<Item = InnerPart, Error = Error> + BoxStreamWrapper<WrappedStream<'a, T>>
|
||||
where
|
||||
T: AsyncRead + 'a,
|
||||
T: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
}
|
||||
|
||||
impl<'a, T, U> InnerStream<'a, T> for U
|
||||
where
|
||||
U: Stream<Item = InnerPart, Error = Error> + BoxStreamWrapper<WrappedStream<'a, T>>,
|
||||
T: AsyncRead + 'a,
|
||||
T: AsyncRead + BufRead + 'a,
|
||||
{
|
||||
}
|
||||
|
||||
@ -126,7 +127,7 @@ pub fn validate_header(header: PartHeader) -> Result<Option<PartHeader>> {
|
||||
}
|
||||
|
||||
/// Convert an OuterStream into an InnerStream using the part header.
|
||||
pub fn inner_stream<'a, R: AsyncRead + 'a>(
|
||||
pub fn inner_stream<'a, R: AsyncRead + BufRead + 'a>(
|
||||
header: &PartHeader,
|
||||
stream: OuterStream<'a, R>,
|
||||
logger: &slog::Logger,
|
||||
|
@ -8,6 +8,7 @@
|
||||
//! stream-level parameters (see `stream_start` for those). This parses bundle2
|
||||
//! part headers and puts together chunks for inner codecs to parse.
|
||||
|
||||
use std::io::BufRead;
|
||||
use std::mem;
|
||||
|
||||
use ascii::AsciiString;
|
||||
@ -25,7 +26,7 @@ use part_inner::validate_header;
|
||||
use types::StreamHeader;
|
||||
use utils::{get_decompressor_type, BytesExt};
|
||||
|
||||
pub fn outer_stream<'a, R: AsyncRead>(
|
||||
pub fn outer_stream<'a, R: AsyncRead + BufRead>(
|
||||
stream_header: &StreamHeader,
|
||||
r: R,
|
||||
logger: &slog::Logger,
|
||||
@ -139,12 +140,12 @@ impl OuterDecoder {
|
||||
return (Ok(None), OuterState::Header);
|
||||
}
|
||||
|
||||
let _ = buf.split_to(4);
|
||||
if header_len == 0 {
|
||||
// A zero-length header indicates that the stream has ended.
|
||||
return (Ok(Some(OuterFrame::StreamEnd)), OuterState::StreamEnd);
|
||||
}
|
||||
|
||||
let _ = buf.split_to(4);
|
||||
let part_header = Self::decode_header(buf.split_to(header_len).freeze());
|
||||
if let Err(e) = part_header {
|
||||
let next = match e.downcast::<ErrorKind>() {
|
||||
|
@ -7,7 +7,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::convert::From;
|
||||
use std::fmt::Debug;
|
||||
use std::io::{self, Cursor};
|
||||
use std::io::{self, BufRead, BufReader, Cursor};
|
||||
use std::str::FromStr;
|
||||
|
||||
use futures::stream::Stream;
|
||||
@ -16,7 +16,7 @@ use slog_term;
|
||||
use tokio_core::reactor::Core;
|
||||
use tokio_io::AsyncRead;
|
||||
|
||||
use async_compression::{Bzip2Compression, CompressorType, ZSTD_DEFAULT_LEVEL};
|
||||
use async_compression::{Bzip2Compression, CompressorType};
|
||||
use async_compression::membuf::MemBuf;
|
||||
use futures_ext::StreamExt;
|
||||
use mercurial_types::{MPath, NodeHash, RepoPath, NULL_HASH};
|
||||
@ -72,19 +72,12 @@ fn parse_uncompressed(read_ops: PartialWithErrors<GenWouldBlock>) {
|
||||
#[test]
|
||||
fn test_parse_unknown_compression() {
|
||||
let mut core = Core::new().unwrap();
|
||||
let bundle2_buf = MemBuf::from(Vec::from(UNKNOWN_COMPRESSION_BUNDLE2));
|
||||
let bundle2_buf = BufReader::new(MemBuf::from(Vec::from(UNKNOWN_COMPRESSION_BUNDLE2)));
|
||||
let outer_stream_err = parse_stream_start(&mut core, bundle2_buf, Some("IL")).unwrap_err();
|
||||
assert_matches!(outer_stream_err.downcast::<ErrorKind>().unwrap(),
|
||||
ErrorKind::Bundle2Decode(ref msg) if msg == "unknown compression 'IL'");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_bundle_roundtrip_zstd() {
|
||||
empty_bundle_roundtrip(Some(CompressorType::Zstd {
|
||||
level: ZSTD_DEFAULT_LEVEL,
|
||||
}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_bundle_roundtrip_bzip() {
|
||||
empty_bundle_roundtrip(Some(CompressorType::Bzip2(Bzip2Compression::Default)));
|
||||
@ -133,13 +126,6 @@ fn empty_bundle_roundtrip(ct: Option<CompressorType>) {
|
||||
assert!(item.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_part_zstd() {
|
||||
unknown_part(Some(CompressorType::Zstd {
|
||||
level: ZSTD_DEFAULT_LEVEL,
|
||||
}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_part_bzip() {
|
||||
unknown_part(Some(CompressorType::Bzip2(Bzip2Compression::Default)));
|
||||
@ -200,7 +186,7 @@ fn parse_bundle(
|
||||
let mut core = Core::new().unwrap();
|
||||
|
||||
let bundle2_buf = MemBuf::from(Vec::from(input));
|
||||
let partial_read = PartialAsyncRead::new(bundle2_buf, read_ops);
|
||||
let partial_read = BufReader::new(PartialAsyncRead::new(bundle2_buf, read_ops));
|
||||
let stream = parse_stream_start(&mut core, partial_read, compression).unwrap();
|
||||
|
||||
let (res, stream) = core.next_stream(stream);
|
||||
@ -222,7 +208,7 @@ fn parse_bundle(
|
||||
assert!(res.is_none());
|
||||
}
|
||||
|
||||
fn verify_cg2<'a, R: AsyncRead + 'a>(
|
||||
fn verify_cg2<'a, R: AsyncRead + BufRead + 'a>(
|
||||
core: &mut Core,
|
||||
stream: Bundle2Stream<'a, R>,
|
||||
) -> Bundle2Stream<'a, R> {
|
||||
@ -347,7 +333,7 @@ fn parse_wirepack(read_ops: PartialWithErrors<GenWouldBlock>) {
|
||||
let mut core = Core::new().unwrap();
|
||||
|
||||
let cursor = Cursor::new(WIREPACK_BUNDLE2);
|
||||
let partial_read = PartialAsyncRead::new(cursor, read_ops);
|
||||
let partial_read = BufReader::new(PartialAsyncRead::new(cursor, read_ops));
|
||||
|
||||
let stream = parse_stream_start(&mut core, partial_read, None).unwrap();
|
||||
let collect_fut = stream.collect_no_consume();
|
||||
@ -479,7 +465,7 @@ fn path(bytes: &[u8]) -> MPath {
|
||||
MPath::new(bytes).unwrap()
|
||||
}
|
||||
|
||||
fn parse_stream_start<'a, R: AsyncRead + 'a>(
|
||||
fn parse_stream_start<'a, R: AsyncRead + BufRead + 'a>(
|
||||
core: &mut Core,
|
||||
reader: R,
|
||||
compression: Option<&str>,
|
||||
@ -512,7 +498,7 @@ fn make_root_logger() -> Logger {
|
||||
Logger::root(slog_term::FullFormat::new(plain).build().fuse(), o!())
|
||||
}
|
||||
|
||||
fn next_cg2_part<'a, R: AsyncRead + 'a>(
|
||||
fn next_cg2_part<'a, R: AsyncRead + BufRead + 'a>(
|
||||
core: &mut Core,
|
||||
stream: Bundle2Stream<'a, R>,
|
||||
) -> (changegroup::Part, Bundle2Stream<'a, R>) {
|
||||
|
Loading…
Reference in New Issue
Block a user