From 0436de244ca8dec3744af57e9d03ccf9445f4d13 Mon Sep 17 00:00:00 2001 From: Thomas Orozco Date: Tue, 30 Jul 2019 09:42:39 -0700 Subject: [PATCH] rust: add stream_clone to futures-ext Summary: stream_clone() takes a stream of cloneable items (and errors) and clones it n ways to n streams. There's no buffering - all output streams must consume each item before the next input is consumed. Output streams can be dropped independently; the input is dropped if all outputs are dropped. Reviewed By: Imxset21 Differential Revision: D15746068 fbshipit-source-id: 7cf1e92b36449ae2112c91ef393d885e9e16c0ae --- futures-ext/src/lib.rs | 2 + futures-ext/src/stream_clone.rs | 182 ++++++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 futures-ext/src/stream_clone.rs diff --git a/futures-ext/src/lib.rs b/futures-ext/src/lib.rs index 32a2f30ee7..f65498dc85 100644 --- a/futures-ext/src/lib.rs +++ b/futures-ext/src/lib.rs @@ -32,6 +32,7 @@ mod futures_ordered; pub mod io; mod launch; mod select_all; +mod stream_clone; mod stream_wrappers; mod streamfork; @@ -39,6 +40,7 @@ pub use crate::bytes_stream::{BytesStream, BytesStreamFuture}; pub use crate::futures_ordered::{futures_ordered, FuturesOrdered}; pub use crate::launch::top_level_launch; pub use crate::select_all::{select_all, SelectAll}; +pub use crate::stream_clone::stream_clone; pub use crate::stream_wrappers::{ BoxStreamWrapper, CollectNoConsume, CollectTo, StreamWrapper, TakeWhile, }; diff --git a/futures-ext/src/stream_clone.rs b/futures-ext/src/stream_clone.rs new file mode 100644 index 0000000000..5fb6fae630 --- /dev/null +++ b/futures-ext/src/stream_clone.rs @@ -0,0 +1,182 @@ +// Copyright (c) 2019-present, Facebook, Inc. +// All Rights Reserved. +// +// This software may be used and distributed according to the terms of the +// GNU General Public License version 2 or any later version. + +use futures::{ + future, + prelude::*, + stream::{Fuse, Stream}, + sync::mpsc, + AsyncSink, +}; +use std::mem; + +/// Given an input Stream, return clones of that stream. +/// This requires both the item and the error to be cloneable. +/// This provides a single element of buffering - all clones +/// must consume each element before the original can make progress. +pub fn stream_clone( + s: impl Stream + Send + 'static, + copies: usize, +) -> Vec + Send + 'static> { + stream_clone_with_spawner(s, copies, tokio::executor::DefaultExecutor::current()) +} + +/// Given an input Stream, return clones of that stream. +/// This requires both the item and the error to be cloneable. +/// This provides a single element of buffering - all clones +/// must consume each element before the original can make progress. +/// This takes a `future::Executor` to spawn the copying task onto. +pub fn stream_clone_with_spawner( + stream: S, + copies: usize, + spawner: impl future::Executor>, +) -> Vec + Send + 'static> +where + S: Stream + Send + 'static, + S::Item: Clone + Send + 'static, + S::Error: Clone + Send + 'static, +{ + let (senders, recvs): (Vec<_>, Vec<_>) = (0..copies).map(|_| mpsc::channel(1)).unzip(); + + let core = CloneCore { + inner: stream.fuse(), + pending: false, + senders, + }; + + spawner.execute(core).expect("Spawning core failed"); + + recvs + .into_iter() + .map(|rx| rx.then(|v| v.unwrap())) + .collect() +} + +pub struct CloneCore { + /// Input stream + inner: Fuse, + /// True while some sender is still accepting a result + pending: bool, + /// Downsteam streams + senders: Vec>>, +} + +impl Future for CloneCore +where + S: Stream, + S::Item: Clone, + S::Error: Clone, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + loop { + if !self.pending { + // Initial state - we need to get a new value from the input, and all senders + // are ready for it. + let val = match self.inner.poll() { + Ok(Async::Ready(Some(val))) => Ok(val), + Ok(Async::Ready(None)) => break Ok(Async::Ready(())), + Ok(Async::NotReady) => break Ok(Async::NotReady), + Err(err) => Err(err), + }; + + let senders: Result, _> = mem::replace(&mut self.senders, Vec::new()) + .into_iter() + .filter_map(|mut tx| { + // Try sending. If the channel isn't ready then it (probably) means the + // receiver has gone away so just drop it. + match tx.start_send(val.clone()) { + Err(err) => Some(Err(err)), + Ok(AsyncSink::Ready) => Some(Ok(tx)), + Ok(AsyncSink::NotReady(_)) => None, + } + }) + .collect(); + + self.senders = senders.expect("start_send failed unexpectedly"); + self.pending = !self.senders.is_empty(); + } + + if self.pending { + // Drive sends to completion + let mut done = true; + + for tx in &mut self.senders { + match tx.poll_complete() { + Err(_) => return Err(()), + Ok(Async::Ready(())) => (), + Ok(Async::NotReady) => { + done = false; + } + } + } + + self.pending = !done; + } + + // If we've lost all our senders then we're done + if self.senders.is_empty() { + break Ok(Async::Ready(())); + } + + // If we've still got incomplete senders, then break out + if self.pending { + break Ok(Async::NotReady); + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use futures::{future, stream}; + + #[test] + fn simple() { + let vec = vec![1, 2, 3, 4, 6]; + let s = stream::iter_ok::<_, ()>(vec.clone()); + + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let res = rt.block_on(future::lazy(|| { + let c = stream_clone(s, 5); + let c = c.into_iter().map(|c| c.collect()); + let c = future::join_all(c); + + c + })); + + for (idx, v) in res.unwrap().into_iter().enumerate() { + assert_eq!(v, vec, "idx {} mismatch", idx); + } + } + + #[test] + fn err() { + let vec = vec![Ok(1), Ok(2), Ok(3), Err("badness"), Ok(4)]; + let s = stream::iter_result(vec.clone()); + + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let res: Result<_, ()> = rt.block_on(future::lazy(|| { + let c = stream_clone(s, 5); + let c = c.into_iter().map(|c| c.then(Result::Ok).collect()); + let c = future::join_all(c); + + c + })); + + // Fuse keeps going after errors, so we get the entire vector. + for (idx, v) in res.unwrap().into_iter().enumerate() { + assert_eq!(v, vec, "idx {} mismatch", idx); + } + } + + // TODO some test with blocking consumers +}