convert bounded_traversal crate to new-style futures

Summary: Convert `bounded_traversal` crate to new-style futures

Reviewed By: krallin

Differential Revision: D19836232

fbshipit-source-id: 9296656da058c700b615a2e3fa915427e28fea96
This commit is contained in:
Pavel Aslanov 2020-02-12 03:48:20 -08:00 committed by Facebook Github Bot
parent b3779e4fc7
commit b862d0eaf1
6 changed files with 954 additions and 79 deletions

View File

@ -0,0 +1,59 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This software may be used and distributed according to the terms of the
* GNU General Public License found in the LICENSE file in the root
* directory of this source tree.
*/
use futures::ready;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
#[derive(Clone, Copy)]
pub(crate) struct NodeLocation<Index> {
pub node_index: Index, // node index inside execution tree
pub child_index: usize, // index inside parents children list
}
// This is essentially just a `.map` over futures `{FFut|UFut}`, this only exisists
// so it would be possible to name `FuturesUnoredered` type parameter.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) enum Job<In, UFut, FFut> {
Unfold { value: In, future: UFut },
Fold { value: In, future: FFut },
}
pub(crate) enum JobResult<In, UFutResult, FFutResult> {
Unfold { value: In, result: UFutResult },
Fold { value: In, result: FFutResult },
}
impl<In, UFut, FFut> Future for Job<In, UFut, FFut>
where
In: Clone,
UFut: Future,
FFut: Future,
{
type Output = JobResult<In, UFut::Output, FFut::Output>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
// see `impl<A, B> Future for Either<A, B>`
unsafe {
let result = match self.get_unchecked_mut() {
Job::Fold { value, future } => JobResult::Fold {
value: value.clone(),
result: ready!(Pin::new_unchecked(future).poll(cx)),
},
Job::Unfold { value, future } => JobResult::Unfold {
value: value.clone(),
result: ready!(Pin::new_unchecked(future).poll(cx)),
},
};
Poll::Ready(result)
}
}
}

View File

@ -0,0 +1,312 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This software may be used and distributed according to the terms of the
* GNU General Public License found in the LICENSE file in the root
* directory of this source tree.
*/
use super::{
common::{Job, JobResult, NodeLocation},
Iter,
};
use futures::{ready, stream::FuturesUnordered, StreamExt};
use std::{
collections::{HashMap, VecDeque},
future::Future,
hash::Hash,
mem,
pin::Pin,
task::{Context, Poll},
};
/// `bounded_traversal_dag` traverses implicit asynchronous DAG specified by `init`
/// and `unfold` arguments, and it also does backward pass with `fold` operation.
/// All `unfold` and `fold` operations are executed in parallel if they do not
/// depend on each other (not related by ancestor-descendant relation in implicit DAG)
/// with amount of concurrency constrained by `scheduled_max`.
///
/// ## Difference between `bounded_traversal_dag` and `bounded_traversal`
/// Obvious difference is that `bounded_traversal_dag` correctly handles DAGs
/// (`bounded_traversal` treats all children references as distinct and its execution time
/// is proportional to number of paths from the root, since DAG can be constructed to contain
/// `O(exp(N))` path it might cause problems) but it comes with a price:
/// - `bounded_traversal_dag` keeps `Out` result of computation for all the nodes
/// but `bounded_traversal` only keeps results for nodes that have not been completely
/// evaluatated
/// - `In` has additional constraints to be `Eq + Hash + Clone`
/// - `Out` has additional constraint to be `Clone`
///
/// ## `init: In`
/// Is the root of the implicit tree to be traversed
///
/// ## `unfold: FnMut(In) -> impl Future<Output = Result<(OutCtx, impl IntoIterator<Item = In>), Err>>`
/// Asynchronous function which given input value produces list of its children. And context
/// associated with current node. If this list is empty, it is a leaf of the tree, and `fold`
/// will be run on this node.
///
/// ## `fold: FnMut(OutCtx, impl Iterator<Out>) -> impl Future<Item = Result<Out, Err>>`
/// Aynchronous function which given node context and output of `fold` for its chidlren
/// should produce new output value.
///
/// ## return value `impl Future<Output = Result<Option<Out>, Err>>`
/// Result of running fold operation on the root of the tree. `None` indiciate that cycle
/// has been found.
///
pub fn bounded_traversal_dag<Err, In, Ins, Out, OutCtx, Unfold, UFut, Fold, FFut>(
scheduled_max: usize,
init: In,
unfold: Unfold,
fold: Fold,
) -> impl Future<Output = Result<Option<Out>, Err>>
where
In: Eq + Hash + Clone,
Out: Clone,
Unfold: FnMut(In) -> UFut,
UFut: Future<Output = Result<(OutCtx, Ins), Err>>,
Ins: IntoIterator<Item = In>,
Fold: FnMut(OutCtx, Iter<Out>) -> FFut,
FFut: Future<Output = Result<Out, Err>>,
{
BoundedTraversalDAG::new(scheduled_max, init, unfold, fold)
}
struct Children<Out, OutCtx> {
context: OutCtx,
children: Vec<Option<Out>>,
children_left: usize,
}
enum Node<In, Out, OutCtx> {
Pending {
parents: Vec<NodeLocation<In>>, // nodes blocked by current node
children: Option<Children<Out, OutCtx>>, // present if node waits for children to be computed
},
Done(Out),
}
#[must_use = "futures do nothing unless polled"]
struct BoundedTraversalDAG<In, Out, OutCtx, Unfold, UFut, Fold, FFut> {
init: In,
unfold: Unfold,
fold: Fold,
scheduled_max: usize,
scheduled: FuturesUnordered<Job<In, UFut, FFut>>, // jobs being executed
unscheduled: VecDeque<Job<In, UFut, FFut>>, // as of yet unscheduled jobs
execution_tree: HashMap<In, Node<In, Out, OutCtx>>, // tree tracking execution process
}
impl<Err, In, Ins, Out, OutCtx, Unfold, UFut, Fold, FFut>
BoundedTraversalDAG<In, Out, OutCtx, Unfold, UFut, Fold, FFut>
where
In: Clone + Eq + Hash,
Out: Clone,
Unfold: FnMut(In) -> UFut,
UFut: Future<Output = Result<(OutCtx, Ins), Err>>,
Ins: IntoIterator<Item = In>,
Fold: FnMut(OutCtx, Iter<Out>) -> FFut,
FFut: Future<Output = Result<Out, Err>>,
{
fn new(scheduled_max: usize, init: In, unfold: Unfold, fold: Fold) -> Self {
let mut this = Self {
init: init.clone(),
unfold,
fold,
scheduled_max,
scheduled: FuturesUnordered::new(),
unscheduled: VecDeque::new(),
execution_tree: HashMap::new(),
};
let init_out = this.enqueue_unfold(
NodeLocation {
node_index: init.clone(),
child_index: 0,
},
init,
);
// can not be resolved since execution tree is empty
debug_assert!(init_out.is_none());
this
}
fn enqueue_unfold(&mut self, parent: NodeLocation<In>, value: In) -> Option<Out> {
match self.execution_tree.get_mut(&value) {
None => {
// schedule unfold for previously unseen `value`
self.execution_tree.insert(
value.clone(),
Node::Pending {
parents: vec![parent],
children: None,
},
);
self.unscheduled.push_front(Job::Unfold {
value: value.clone(),
future: (self.unfold)(value),
});
None
}
Some(Node::Pending { parents, .. }) => {
// we already have a node associated with the same input value,
// register as a dependency for this node.
parents.push(parent);
None
}
Some(Node::Done(result)) => Some(result.clone()),
}
}
fn enqueue_fold(&mut self, value: In, context: OutCtx, children: Iter<Out>) {
self.unscheduled.push_front(Job::Fold {
value,
future: (self.fold)(context, children),
});
}
fn process_unfold(&mut self, value: In, (context, children): (OutCtx, Ins)) {
// schedule unfold for node's children
let mut children_left = 0;
let children: Vec<_> = children
.into_iter()
.enumerate()
.map(|(child_index, child)| {
let out = self.enqueue_unfold(
NodeLocation {
node_index: value.clone(),
child_index,
},
child,
);
if out.is_none() {
children_left += 1;
}
out
})
.collect();
if children_left != 0 {
// update pending node with `wait` state
let node = self
.execution_tree
.get_mut(&value)
.expect("unfold referenced invalid node");
match node {
Node::Pending { children: wait, .. } => {
mem::replace(
wait,
Some(Children {
context,
children,
children_left,
}),
);
}
_ => unreachable!("running unfold for Node::Done"),
}
} else {
// do not have any dependencies (leaf node), schedule fold immediately
self.enqueue_fold(value, context, children.into_iter().flatten());
}
}
fn process_fold(&mut self, value: In, result: Out) {
// mark node as done
let node = self
.execution_tree
.get_mut(&value)
.expect("fold referenced invalid node");
let parents = match mem::replace(node, Node::Done(result.clone())) {
Node::Pending { parents, .. } => parents,
_ => unreachable!("running fold for Node::Done"),
};
// update all the parents wait for this result
for parent in parents {
self.update_location(parent, result.clone());
}
}
fn update_location(&mut self, loc: NodeLocation<In>, result: Out) {
let node = self
.execution_tree
.get_mut(&loc.node_index)
.expect("`update_location` referenced invalid node");
let children = match node {
Node::Pending { children, .. } => children,
_ => unreachable!("updating already resolved parent node"),
};
let no_children_left = {
// update parent
let mut children = children
.as_mut()
.expect("`update_location` referenced not blocked node");
debug_assert!(children.children[loc.child_index].is_none());
children.children[loc.child_index] = Some(result);
children.children_left -= 1;
children.children_left == 0
};
if no_children_left {
// all parents children have been completed, so we need
// to schedule fold operation for it
let Children {
context, children, ..
} = children
.take()
.expect("`update_location` reference node without children");
self.enqueue_fold(loc.node_index, context, children.into_iter().flatten());
}
}
}
impl<Err, In, Ins, Out, OutCtx, Unfold, UFut, Fold, FFut> Future
for BoundedTraversalDAG<In, Out, OutCtx, Unfold, UFut, Fold, FFut>
where
In: Eq + Hash + Clone,
Out: Clone,
Unfold: FnMut(In) -> UFut,
UFut: Future<Output = Result<(OutCtx, Ins), Err>>,
Ins: IntoIterator<Item = In>,
Fold: FnMut(OutCtx, Iter<Out>) -> FFut,
FFut: Future<Output = Result<Out, Err>>,
{
type Output = Result<Option<Out>, Err>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
loop {
if this.unscheduled.is_empty() && this.scheduled.is_empty() {
// we have not received result of with `value == init` and
// nothing is scheduled or unscheduled, it means that we have
// cycle dependency somewhere inside input graph
return Poll::Ready(Ok(None));
}
// schedule as many jobs as possible
for job in this.unscheduled.drain(
..std::cmp::min(
this.unscheduled.len(),
this.scheduled_max - this.scheduled.len(),
),
) {
this.scheduled.push(job);
}
// execute scheduled until it is blocked or done
if let Some(job_result) = ready!(this.scheduled.poll_next_unpin(cx)) {
match job_result {
JobResult::Unfold { value, result } => this.process_unfold(value, result?),
JobResult::Fold { value, result } => {
// we have computed value associated with `init` node
if value == this.init {
// all jobs have to be completed and execution_tree empty
assert!(this.unscheduled.is_empty());
assert!(this.scheduled.is_empty());
return Poll::Ready(Ok(Some(result?)));
}
this.process_fold(value, result?);
}
}
}
}
}
}

View File

@ -6,8 +6,23 @@
* directory of this source tree.
*/
//! Read the documentation of [bounded_traversal](crate::bounded_traversal),
//! [bounded_traversal_dag](crate::bounded_traversal_dag) and
//! [bounded_traversal_stream](crate::bounded_traversal_stream)
mod tree;
pub use tree::bounded_traversal;
mod dag;
pub use dag::bounded_traversal_dag;
mod stream;
pub use stream::bounded_traversal_stream;
mod common;
#[cfg(test)]
mod tests;
/// A type used frequently in fold-like invocations inside this module
pub type Iter<Out> = std::iter::Flatten<std::vec::IntoIter<Option<Out>>>;

View File

@ -8,13 +8,10 @@
use futures::{
ready,
stream::{self, FuturesUnordered},
task::Poll,
Future, Stream,
stream::{self, FuturesUnordered, StreamExt},
Stream,
};
use std::collections::VecDeque;
use std::iter::FromIterator;
use std::pin::Pin;
use std::{collections::VecDeque, future::Future, iter::FromIterator, task::Poll};
/// `bounded_traversal_stream` traverses implicit asynchronous tree specified by `init`
/// and `unfold` arguments. All `unfold` operations are executed in parallel if they
@ -46,9 +43,8 @@ where
Ins: IntoIterator<Item = In>,
{
let mut unscheduled = VecDeque::from_iter(init);
let mut scheduled = Pin::new(Box::new(FuturesUnordered::new()));
let mut scheduled = FuturesUnordered::new();
stream::poll_fn(move |cx| loop {
let scheduled = scheduled.as_mut();
if scheduled.is_empty() && unscheduled.is_empty() {
return Poll::Ready(None);
}
@ -59,8 +55,7 @@ where
scheduled.push(unfold(item))
}
let poll = scheduled.poll_next(cx);
if let Some((out, children)) = ready!(poll).transpose()? {
if let Some((out, children)) = ready!(scheduled.poll_next_unpin(cx)).transpose()? {
for child in children {
unscheduled.push_front(child);
}

View File

@ -6,25 +6,24 @@
* directory of this source tree.
*/
use super::bounded_traversal_stream;
use super::{bounded_traversal, bounded_traversal_dag, bounded_traversal_stream};
use anyhow::Error;
use futures::{
channel::oneshot::{channel, Canceled, Sender},
future::{self, FutureExt, TryFutureExt},
channel::oneshot::{channel, Sender},
future::{self, FutureExt},
stream::TryStreamExt,
Future,
};
use lock_ext::LockExt;
use maplit::hashmap;
use pretty_assertions::assert_eq;
use std::{
cmp::{Ord, Ordering},
collections::{BTreeSet, BinaryHeap},
iter::FromIterator,
sync::{Arc, Mutex},
thread,
time::Duration,
};
use tokio::runtime::Runtime;
use tokio::task::yield_now;
// Tree for test purposes
struct Tree {
@ -88,7 +87,7 @@ impl Tick {
}
}
fn tick(&self) {
async fn tick(&self) {
let (current_time, done) = self.inner.with(|inner| {
inner.current_time += 1;
let mut done = Vec::new();
@ -105,20 +104,21 @@ impl Tick {
for sender in done {
sender.send(current_time).unwrap();
}
yield_now().await
}
fn sleep(&self, delay: usize) -> impl Future<Output = Result<usize, Canceled>> {
fn sleep(&self, delay: usize) -> impl Future<Output = usize> {
let this = self.clone();
let (send, recv) = channel();
future::lazy(move |_cx| {
async move {
let (send, recv) = channel();
this.inner.with(move |inner| {
inner.events.push(TickEvent {
time: inner.current_time + delay,
sender: send,
});
});
})
.then(|_| recv.map(|v| v))
recv.await.expect("peer closed")
}
}
}
@ -126,7 +126,7 @@ impl Tick {
#[derive(Debug, Eq, PartialEq, Hash, Clone, Ord, PartialOrd)]
enum State<V> {
Unfold { id: usize, time: usize },
Done { value: Option<V> },
Fold { id: usize, time: usize, value: V },
}
#[derive(Clone, Debug)]
@ -141,15 +141,15 @@ impl<V: Ord> StateLog<V> {
}
}
fn fold(&self, id: usize, time: usize, value: V) {
self.states
.with(move |states| states.insert(State::Fold { id, time, value }));
}
fn unfold(&self, id: usize, time: usize) {
self.states
.with(move |states| states.insert(State::Unfold { id, time }));
}
fn done(&self, value: Option<V>) {
self.states
.with(move |states| states.insert(State::Done { value }));
}
}
impl<V: Ord + Clone> PartialEq for StateLog<V> {
@ -158,55 +158,327 @@ impl<V: Ord + Clone> PartialEq for StateLog<V> {
}
}
#[test]
fn test_tick() -> Result<(), Error> {
use futures::stream::FuturesUnordered;
#[tokio::test]
async fn test_tick() -> Result<(), Error> {
let log = Arc::new(Mutex::new(Vec::new()));
let mut reference = Vec::new();
let tick = Tick::new();
let runtime = Runtime::new()?;
let futs: FuturesUnordered<
Box<dyn Future<Output = Result<(), Canceled>> + Sync + Send + Unpin>,
> = FuturesUnordered::new();
futs.push(Box::new(tick.sleep(3).map_ok({
let handle = tokio::spawn({
let log = log.clone();
move |t| log.with(|l| l.push((3, t)))
})));
futs.push(Box::new(tick.sleep(1).map_ok({
let log = log.clone();
move |t| log.with(|l| l.push((1, t)))
})));
futs.push(Box::new(tick.sleep(2).map_ok({
let log = log.clone();
move |t| log.with(|l| l.push((2, t)))
})));
runtime.spawn(futs.try_for_each(|f| future::ok(f)));
thread::sleep(Duration::from_millis(50));
let tick = tick.clone();
async move {
let f0 = tick.sleep(3).map(|t| log.with(|l| l.push((3usize, t))));
let f1 = tick.sleep(1).map(|t| log.with(|l| l.push((1usize, t))));
let f2 = tick.sleep(2).map(|t| log.with(|l| l.push((2usize, t))));
future::join3(f0, f1, f2).await;
}
});
yield_now().await;
let tick = move || {
tick.tick();
thread::sleep(Duration::from_millis(50));
};
tick();
reference.push((1, 1));
tick.tick().await;
reference.push((1usize, 1usize));
assert_eq!(log.with(|l| l.clone()), reference);
tick();
tick.tick().await;
reference.push((2, 2));
assert_eq!(log.with(|l| l.clone()), reference);
tick();
tick.tick().await;
reference.push((3, 3));
assert_eq!(log.with(|l| l.clone()), reference);
handle.await?;
Ok(())
}
#[test]
fn test_bounded_traversal_stream() -> Result<(), Error> {
#[tokio::test]
async fn test_bounded_traversal() -> Result<(), Error> {
// tree
// 0
// / \
// 1 2
// / / \
// 5 3 4
let tree = Tree::new(
0,
vec![
Tree::new(1, vec![Tree::leaf(5)]),
Tree::new(2, vec![Tree::leaf(3), Tree::leaf(4)]),
],
);
let tick = Tick::new();
let log: StateLog<String> = StateLog::new();
let reference: StateLog<String> = StateLog::new();
let traverse = bounded_traversal(
2, // level of parallelism
tree,
// unfold
{
let tick = tick.clone();
let log = log.clone();
move |Tree { id, children }| {
let log = log.clone();
tick.sleep(1).map(move |now| {
log.unfold(id, now);
Ok::<_, Error>((id, children))
})
}
},
// fold
{
let tick = tick.clone();
let log = log.clone();
move |id, children| {
let log = log.clone();
tick.sleep(1).map(move |now| {
let value = id.to_string() + &children.collect::<String>();
log.fold(id, now, value.clone());
Ok::<_, Error>(value)
})
}
},
)
.boxed();
let handle = tokio::spawn(traverse);
yield_now().await;
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(0, 1);
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(1, 2);
reference.unfold(2, 2);
assert_eq!(log, reference);
// only two unfolds executet because of the parallelism constraint
tick.tick().await;
reference.unfold(5, 3);
reference.unfold(4, 3);
assert_eq!(log, reference);
tick.tick().await;
reference.fold(4, 4, "4".to_string());
reference.fold(5, 4, "5".to_string());
assert_eq!(log, reference);
tick.tick().await;
reference.fold(1, 5, "15".to_string());
reference.unfold(3, 5);
assert_eq!(log, reference);
tick.tick().await;
reference.fold(3, 6, "3".to_string());
assert_eq!(log, reference);
tick.tick().await;
reference.fold(2, 7, "234".to_string());
assert_eq!(log, reference);
tick.tick().await;
reference.fold(0, 8, "015234".to_string());
assert_eq!(log, reference);
assert_eq!(handle.await??, "015234");
Ok(())
}
#[tokio::test]
async fn test_bounded_traversal_dag() -> Result<(), Error> {
// dag
// 0
// / \
// 1 2
// \ / \
// 3 4
// / \
// 5 6
// \ /
// 7
// |
// 4 - will be resolved by the time it is reached
let dag = hashmap! {
0 => vec![1, 2],
1 => vec![3],
2 => vec![3, 4],
3 => vec![5, 6],
4 => vec![],
5 => vec![7],
6 => vec![7],
7 => vec![4],
};
let tick = Tick::new();
let log: StateLog<String> = StateLog::new();
let reference: StateLog<String> = StateLog::new();
let traverse = bounded_traversal_dag(
2, // level of parallelism
0,
// unfold
{
let tick = tick.clone();
let log = log.clone();
move |id| {
let log = log.clone();
let children = dag.get(&id).cloned().unwrap_or_default();
tick.sleep(1).map(move |now| {
log.unfold(id, now);
Ok::<_, Error>((id, children))
})
}
},
// fold
{
let tick = tick.clone();
let log = log.clone();
move |id, children| {
let log = log.clone();
tick.sleep(1).map(move |now| {
let value = id.to_string() + &children.collect::<String>();
log.fold(id, now, value.clone());
Ok(value)
})
}
},
)
.boxed();
let handle = tokio::spawn(traverse);
yield_now().await;
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(0, 1);
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(1, 2);
reference.unfold(2, 2);
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(3, 3);
reference.unfold(4, 3);
assert_eq!(log, reference);
tick.tick().await;
reference.fold(4, 4, "4".to_string());
reference.unfold(6, 4);
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(5, 5);
reference.unfold(7, 5);
assert_eq!(log, reference);
tick.tick().await;
reference.fold(7, 6, "74".to_string());
assert_eq!(log, reference);
tick.tick().await;
reference.fold(5, 7, "574".to_string());
reference.fold(6, 7, "674".to_string());
assert_eq!(log, reference);
tick.tick().await;
reference.fold(3, 8, "3574674".to_string());
assert_eq!(log, reference);
tick.tick().await;
reference.fold(1, 9, "13574674".to_string());
reference.fold(2, 9, "235746744".to_string());
assert_eq!(log, reference);
tick.tick().await;
reference.fold(0, 10, "013574674235746744".to_string());
assert_eq!(log, reference);
assert_eq!(handle.await??, Some("013574674235746744".to_string()));
Ok(())
}
#[tokio::test]
async fn test_bounded_traversal_dag_with_cycle() -> Result<(), Error> {
// graph with cycle
// 0
// / \
// 1 2
// \ /
// 3
// |
// 2 <- forms cycle
let graph = hashmap! {
0 => vec![1, 2],
1 => vec![3],
2 => vec![3],
3 => vec![2],
};
let tick = Tick::new();
let log: StateLog<String> = StateLog::new();
let reference: StateLog<String> = StateLog::new();
let traverse = bounded_traversal_dag(
2, // level of parallelism
0,
// unfold
{
let tick = tick.clone();
let log = log.clone();
move |id| {
let log = log.clone();
let children = graph.get(&id).cloned().unwrap_or_default();
tick.sleep(1).map(move |now| {
log.unfold(id, now);
Ok::<_, Error>((id, children))
})
}
},
// fold
{
let tick = tick.clone();
let log = log.clone();
move |id, children| {
let log = log.clone();
tick.sleep(1).map(move |now| {
let value = id.to_string() + &children.collect::<String>();
log.fold(id, now, value.clone());
Ok(value)
})
}
},
)
.boxed();
let handle = tokio::spawn(traverse);
yield_now().await;
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(0, 1);
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(1, 2);
reference.unfold(2, 2);
assert_eq!(log, reference);
tick.tick().await;
reference.unfold(3, 3);
assert_eq!(log, reference);
assert_eq!(handle.await??, None); // cycle detected
Ok(())
}
#[tokio::test]
async fn test_bounded_traversal_stream() -> Result<(), Error> {
// tree
// 0
// / \
@ -224,50 +496,43 @@ fn test_bounded_traversal_stream() -> Result<(), Error> {
let tick = Tick::new();
let log: StateLog<BTreeSet<usize>> = StateLog::new();
let reference: StateLog<BTreeSet<usize>> = StateLog::new();
let rt = Runtime::new()?;
let traverse = bounded_traversal_stream(2, Some(tree), {
let tick = tick.clone();
let log = log.clone();
move |Tree { id, children }| {
let log = log.clone();
tick.sleep(1).map_ok(move |now| {
tick.sleep(1).map(move |now| {
log.unfold(id, now);
(id, children)
Ok::<_, Error>((id, children))
})
}
});
rt.spawn(traverse.try_collect().map_ok({
let log = log.clone();
move |items: Vec<usize>| log.done(Some(BTreeSet::from_iter(items)))
}));
})
.try_collect::<BTreeSet<usize>>()
.boxed();
let handle = tokio::spawn(traverse);
let tick = move || {
tick.tick();
thread::sleep(Duration::from_millis(50));
};
thread::sleep(Duration::from_millis(50));
yield_now().await;
assert_eq!(log, reference);
tick();
tick.tick().await;
reference.unfold(0, 1);
assert_eq!(log, reference);
tick();
tick.tick().await;
reference.unfold(1, 2);
reference.unfold(2, 2);
assert_eq!(log, reference);
tick();
tick.tick().await;
reference.unfold(5, 3);
reference.unfold(4, 3);
assert_eq!(log, reference);
tick();
tick.tick().await;
reference.unfold(3, 4);
reference.done(Some(BTreeSet::from_iter(0..6)));
assert_eq!(log, reference);
assert_eq!(handle.await??, BTreeSet::from_iter(0..6));
Ok(())
}

View File

@ -0,0 +1,229 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This software may be used and distributed according to the terms of the
* GNU General Public License found in the LICENSE file in the root
* directory of this source tree.
*/
use super::{
common::{Job, JobResult},
Iter,
};
use futures::{ready, stream::FuturesUnordered, StreamExt};
use std::{
collections::{HashMap, VecDeque},
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// `bounded_traversal` traverses implicit asynchronous tree specified by `init`
/// and `unfold` arguments, and it also does backward pass with `fold` operation.
/// All `unfold` and `fold` operations are executed in parallel if they do not
/// depend on each other (not related by ancestor-descendant relation in implicit tree)
/// with amount of concurrency constrained by `scheduled_max`.
///
/// ## `init: In`
/// Is the root of the implicit tree to be traversed
///
/// ## `unfold: FnMut(In) -> impl Future<Output = Result<(OutCtx, impl IntoIterator<Item = In>), Err>>`
/// Asynchronous function which given input value produces list of its children. And context
/// associated with current node. If this list is empty, it is a leaf of the tree, and `fold`
/// will be run on this node.
///
/// ## `fold: FnMut(OutCtx, impl Iterator<Out>) -> impl Future<Output = Result<Out, Err>>`
/// Aynchronous function which given node context and output of `fold` for its chidlren
/// should produce new output value.
///
/// ## return value `impl Future<Output = Result<Out, Err>>`
/// Result of running fold operation on the root of the tree.
///
pub fn bounded_traversal<Err, In, Ins, Out, OutCtx, Unfold, UFut, Fold, FFut>(
scheduled_max: usize,
init: In,
unfold: Unfold,
fold: Fold,
) -> impl Future<Output = Result<Out, Err>>
where
Unfold: FnMut(In) -> UFut,
UFut: Future<Output = Result<(OutCtx, Ins), Err>>,
Ins: IntoIterator<Item = In>,
Fold: FnMut(OutCtx, Iter<Out>) -> FFut,
FFut: Future<Output = Result<Out, Err>>,
{
BoundedTraversal::new(scheduled_max, init, unfold, fold)
}
// execution tree node
struct Node<Out, OutCtx> {
parent: NodeLocation, // location of this node relative to it's parent
context: OutCtx, // context associated with node
children: Vec<Option<Out>>, // results of children folds
children_left: usize, // number of unresolved children
}
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
struct NodeIndex(usize);
type NodeLocation = super::common::NodeLocation<NodeIndex>;
#[must_use = "futures do nothing unless polled"]
struct BoundedTraversal<Out, OutCtx, Unfold, UFut, Fold, FFut> {
unfold: Unfold,
fold: Fold,
scheduled_max: usize,
scheduled: FuturesUnordered<Job<NodeLocation, UFut, FFut>>, // jobs being executed
unscheduled: VecDeque<Job<NodeLocation, UFut, FFut>>, // as of yet unscheduled jobs
execution_tree: HashMap<NodeIndex, Node<Out, OutCtx>>, // tree tracking execution process
execution_tree_index: NodeIndex, // last allocated node index
}
impl<Err, In, Ins, Out, OutCtx, Unfold, UFut, Fold, FFut>
BoundedTraversal<Out, OutCtx, Unfold, UFut, Fold, FFut>
where
Unfold: FnMut(In) -> UFut,
UFut: Future<Output = Result<(OutCtx, Ins), Err>>,
Ins: IntoIterator<Item = In>,
Fold: FnMut(OutCtx, Iter<Out>) -> FFut,
FFut: Future<Output = Result<Out, Err>>,
{
fn new(scheduled_max: usize, init: In, unfold: Unfold, fold: Fold) -> Self {
let mut this = Self {
unfold,
fold,
scheduled_max,
scheduled: FuturesUnordered::new(),
unscheduled: VecDeque::new(),
execution_tree: HashMap::new(),
execution_tree_index: NodeIndex(0),
};
this.enqueue_unfold(
NodeLocation {
node_index: NodeIndex(0),
child_index: 0,
},
init,
);
this
}
fn enqueue_unfold(&mut self, parent: NodeLocation, value: In) {
self.unscheduled.push_front(Job::Unfold {
value: parent,
future: (self.unfold)(value),
});
}
fn enqueue_fold(&mut self, parent: NodeLocation, context: OutCtx, children: Iter<Out>) {
self.unscheduled.push_front(Job::Fold {
value: parent,
future: (self.fold)(context, children),
});
}
fn process_unfold(&mut self, parent: NodeLocation, (context, children): (OutCtx, Ins)) {
// allocate index
self.execution_tree_index = NodeIndex(self.execution_tree_index.0 + 1);
let node_index = self.execution_tree_index;
// schedule unfold for node's children
let count = children.into_iter().fold(0, |child_index, child| {
self.enqueue_unfold(
NodeLocation {
node_index,
child_index,
},
child,
);
child_index + 1
});
if count != 0 {
// allocate node
let mut children = Vec::new();
children.resize_with(count, || None);
self.execution_tree.insert(
node_index,
Node {
parent,
context,
children,
children_left: count,
},
);
} else {
// leaf node schedules fold for itself immediately
self.enqueue_fold(parent, context, Vec::new().into_iter().flatten());
}
}
fn process_fold(&mut self, parent: NodeLocation, result: Out) {
// update parent
let node = self
.execution_tree
.get_mut(&parent.node_index)
.expect("fold referenced invalid node");
debug_assert!(node.children[parent.child_index].is_none());
node.children[parent.child_index] = Some(result);
node.children_left -= 1;
if node.children_left == 0 {
// all parents children have been completed, so we need
// to schedule fold operation for it
let Node {
parent,
context,
children,
..
} = self
.execution_tree
.remove(&parent.node_index)
.expect("fold referenced invalid node");
self.enqueue_fold(parent, context, children.into_iter().flatten());
}
}
}
impl<Err, In, Ins, Out, OutCtx, Unfold, UFut, Fold, FFut> Future
for BoundedTraversal<Out, OutCtx, Unfold, UFut, Fold, FFut>
where
Unfold: FnMut(In) -> UFut,
UFut: Future<Output = Result<(OutCtx, Ins), Err>>,
Ins: IntoIterator<Item = In>,
Fold: FnMut(OutCtx, Iter<Out>) -> FFut,
FFut: Future<Output = Result<Out, Err>>,
{
type Output = Result<Out, Err>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
loop {
// schedule as many jobs as possible
for job in this.unscheduled.drain(
..std::cmp::min(
this.unscheduled.len(),
this.scheduled_max - this.scheduled.len(),
),
) {
this.scheduled.push(job);
}
// execute scheduled until it is blocked or done
if let Some(job_result) = ready!(this.scheduled.poll_next_unpin(cx)) {
match job_result {
JobResult::Unfold { value, result } => this.process_unfold(value, result?),
JobResult::Fold { value, result } => {
// `0` is special index which means whole tree have been executed
if value.node_index == NodeIndex(0) {
// all jobs have to be completed and execution_tree empty
assert!(this.execution_tree.is_empty());
assert!(this.unscheduled.is_empty());
assert!(this.scheduled.is_empty());
return Poll::Ready(result);
}
this.process_fold(value, result?);
}
}
}
}
}
}