add variant of bounded_traversal_stream which accepts children in the form of a stream

Summary: Adds a version of `bounded_traversal_stream` where `unfold` returns a stream over children instead of an iterator. This function also applies back pressure to the iteration over children when there are too many unscheduled items.

Reviewed By: krallin

Differential Revision: D23931035

fbshipit-source-id: 2e2806653782d4e646dcdf4b2d4e624fd6543da8
This commit is contained in:
Pavel Aslanov 2020-10-07 03:36:45 -07:00 committed by Facebook GitHub Bot
parent ce9c900c76
commit daa5a0409a
3 changed files with 146 additions and 3 deletions

View File

@ -16,7 +16,7 @@ mod dag;
pub use dag::bounded_traversal_dag;
mod stream;
pub use stream::bounded_traversal_stream;
pub use stream::{bounded_traversal_stream, bounded_traversal_stream2};
mod common;

View File

@ -6,6 +6,7 @@
*/
use futures::{
future::{FutureExt, TryFutureExt},
ready,
stream::{self, FuturesUnordered, StreamExt},
Stream,
@ -64,3 +65,83 @@ where
}
})
}
/// Variant of `bounded_traversal_stream` in which `unfold` yields a node's children as a
/// stream instead of an iterator.
///
/// - `scheduled_max`: upper bound used when moving work from the unscheduled queue into the
///   set of concurrently polled futures. A value of `0` is treated as `1`: with no slots
///   nothing could ever be scheduled and polling would spin forever.
/// - `init`: the initial nodes to unfold.
/// - `unfold`: maps a node to a future resolving to the node's output together with a
///   stream over its children. This stream must be `Unpin`.
///
/// The returned stream yields the `Out` of every unfolded node. Errors produced by an
/// `unfold` future or by a children stream are yielded as `Err` items.
///
/// If the unscheduled queue grows too large, iteration over a children stream is suspended
/// (the stream itself is parked in the unscheduled queue) until there is room again.
pub fn bounded_traversal_stream2<In, Ins, Out, Unfold, UFut, UStream, UErr>(
    scheduled_max: usize,
    init: Ins,
    mut unfold: Unfold,
) -> impl Stream<Item = Result<Out, UErr>>
where
    Ins: IntoIterator<Item = In>,
    Unfold: FnMut(In) -> UFut,
    UFut: Future<Output = Result<(Out, UStream), UErr>>,
    UStream: Stream<Item = Result<In, UErr>> + Unpin,
{
    // The two kinds of work item that can complete:
    // - `Unfold`: a node was unfolded into its output and its children stream
    // - `Child`: a children stream produced its next element (or ended)
    enum Op<U, C> {
        Unfold(U),
        Child(C),
    }

    // With zero slots no future would ever be scheduled, and the loop below would
    // busy-spin on an empty `FuturesUnordered` without ever returning.
    let scheduled_max = scheduled_max.max(1);

    let init = init
        .into_iter()
        .map(|child| unfold(child).map_ok(Op::Unfold).right_future());
    let mut unscheduled = VecDeque::from_iter(init);
    let mut scheduled = FuturesUnordered::new();

    stream::poll_fn(move |cx| {
        loop {
            // No work left anywhere: the traversal is complete.
            if scheduled.is_empty() && unscheduled.is_empty() {
                return Poll::Ready(None);
            }

            // Top up the scheduled set from the front of the unscheduled queue.
            while scheduled.len() < scheduled_max {
                match unscheduled.pop_front() {
                    Some(op) => scheduled.push(op),
                    None => break,
                }
            }

            // `?` turns an error from any scheduled future into an `Err` item of the stream.
            if let Some(op) = ready!(scheduled.poll_next_unpin(cx)).transpose()? {
                match op {
                    Op::Unfold((out, children)) => {
                        // Queue the first poll of this node's children stream, then emit
                        // the node's output.
                        let children = stream_into_try_future(children)
                            .map_ok(Op::Child)
                            .left_future();
                        unscheduled.push_back(children);
                        return Poll::Ready(Some(Ok(out)));
                    }
                    Op::Child((Some(child), children)) => {
                        unscheduled.push_back(unfold(child).map_ok(Op::Unfold).right_future());
                        let children = stream_into_try_future(children)
                            .map_ok(Op::Child)
                            .left_future();
                        // This results in something like a BFS traversal (subject to the
                        // completion order of scheduled tasks) while the unscheduled queue
                        // is small enough; otherwise iteration over the children is
                        // suspended by parking the stream in the unscheduled queue.
                        if unscheduled.len() > scheduled_max {
                            // Too many unscheduled elements: pause this children stream.
                            unscheduled.push_back(children);
                        } else {
                            // Continue polling for more children.
                            scheduled.push(children);
                        }
                    }
                    // The children stream is exhausted; nothing more to schedule for it.
                    Op::Child((None, _)) => {}
                }
            }
        }
    })
}
/// Turns a fallible stream into a future that resolves once the stream produces its next
/// element.
///
/// The future resolves to:
/// - `Ok((Some(item), rest))` when the stream yielded `Ok(item)`,
/// - `Ok((None, rest))` when the stream ended,
/// - `Err(err)` when the stream yielded `Err(err)` (the remaining stream is dropped).
fn stream_into_try_future<S, O, E>(stream: S) -> impl Future<Output = Result<(Option<O>, S), E>>
where
    S: Stream<Item = Result<O, E>> + Unpin,
{
    stream.into_future().map(|(head, rest)| match head {
        Some(Ok(item)) => Ok((Some(item), rest)),
        Some(Err(err)) => Err(err),
        None => Ok((None, rest)),
    })
}

View File

@ -5,12 +5,14 @@
* GNU General Public License version 2.
*/
use super::{bounded_traversal, bounded_traversal_dag, bounded_traversal_stream};
use super::{
bounded_traversal, bounded_traversal_dag, bounded_traversal_stream, bounded_traversal_stream2,
};
use anyhow::Error;
use futures::{
channel::oneshot::{channel, Sender},
future::{self, FutureExt},
stream::TryStreamExt,
stream::{self, TryStreamExt},
Future,
};
use lock_ext::LockExt;
@ -535,3 +537,63 @@ async fn test_bounded_traversal_stream() -> Result<(), Error> {
assert_eq!(handle.await??, BTreeSet::from_iter(0..6));
Ok(())
}
// Exercises `bounded_traversal_stream2` with `scheduled_max = 2` on a small tree and
// checks, tick by tick, which nodes have been unfolded so far.
#[tokio::test]
async fn test_bounded_traversal_stream2() -> Result<(), Error> {
// tree
//       0
//      / \
//     1   2
//    /   / \
//   5   3   4
let tree = Tree::new(
0,
vec![
Tree::new(1, vec![Tree::leaf(5)]),
Tree::new(2, vec![Tree::leaf(3), Tree::leaf(4)]),
],
);
let tick = Tick::new();
// `log` is written by the traversal itself; `reference` is built by hand below and the
// two are compared after every tick.
let log: StateLog<BTreeSet<usize>> = StateLog::new();
let reference: StateLog<BTreeSet<usize>> = StateLog::new();
let traverse = bounded_traversal_stream2(2, Some(tree), {
let tick = tick.clone();
let log = log.clone();
move |Tree { id, children }| {
let log = log.clone();
// Each unfold waits on `tick.sleep(1)` (presumably one tick -- confirm against `Tick`),
// records itself in `log`, and returns its id plus its children wrapped as a stream.
tick.sleep(1).map(move |now| {
log.unfold(id, now);
Ok::<_, Error>((id, stream::iter(children.into_iter().map(Ok))))
})
}
})
.try_collect::<BTreeSet<usize>>()
.boxed();
// Run the traversal on a spawned task so the test can advance `tick` manually.
let handle = tokio::spawn(traverse);
yield_now().await;
// Nothing has been unfolded before the first tick.
assert_eq!(log, reference);
// Tick 1: only the root is expected to be unfolded.
tick.tick().await;
reference.unfold(0, 1);
assert_eq!(log, reference);
// Tick 2: both children of the root are expected.
tick.tick().await;
reference.unfold(1, 2);
reference.unfold(2, 2);
assert_eq!(log, reference);
// Tick 3: nodes 5 and 3 are expected; node 4 is not unfolded yet.
tick.tick().await;
reference.unfold(5, 3);
reference.unfold(3, 3);
assert_eq!(log, reference);
// Tick 4: the remaining node 4 is expected.
tick.tick().await;
reference.unfold(4, 4);
assert_eq!(log, reference);
// The traversal must have yielded every node id exactly once (collected as a set).
assert_eq!(handle.await??, BTreeSet::from_iter(0..6));
Ok(())
}