From 4466b6b76a5db27b412e987a26852a7e18a7a23d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 7 Jul 2021 18:23:18 +0200 Subject: [PATCH] Refactor zed-rpc to work with websockets --- Cargo.lock | 68 ++++++++++++++++++++++ zed-rpc/Cargo.toml | 1 + zed-rpc/src/lib.rs | 2 + zed-rpc/src/peer.rs | 86 +++++++++++++--------------- zed-rpc/src/proto.rs | 131 +++++++++++-------------------------------- zed-rpc/src/test.rs | 64 +++++++++++++++++++++ 6 files changed, 208 insertions(+), 144 deletions(-) create mode 100644 zed-rpc/src/test.rs diff --git a/Cargo.lock b/Cargo.lock index b83ff00b7f..bdb3384da7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,6 +286,19 @@ dependencies = [ "syn", ] +[[package]] +name = "async-tungstenite" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8645e929ec7964448a901db9da30cd2ae8c7fecf4d6176af427837531dbbb63b" +dependencies = [ + "futures-io", + "futures-util", + "log", + "pin-project-lite", + "tungstenite", +] + [[package]] name = "atomic" version = "0.5.0" @@ -1713,6 +1726,12 @@ dependencies = [ "url", ] +[[package]] +name = "httparse" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a87b616e37e93c22fb19bcd386f02f3af5ea98a25670ad0fce773de23c5e68" + [[package]] name = "humantime" version = "2.1.0" @@ -1806,6 +1825,15 @@ dependencies = [ "adler32", ] +[[package]] +name = "input_buffer" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413" +dependencies = [ + "bytes 1.0.1", +] + [[package]] name = "instant" version = "0.1.9" @@ -3304,6 +3332,19 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "sha-1" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4cfa741c5832d0ef7fab46cabed29c2aae926db0b11bb2069edd8db5e64e16" +dependencies = [ + "block-buffer", + "cfg-if 1.0.0", + "cpufeatures", + "digest", + "opaque-debug", +] + [[package]] name = "sha1" version = "0.2.0" @@ -3920,6 +3961,26 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85e00391c1f3d171490a3f8bd79999b0002ae38d3da0d6a3a306c754b053d71b" +[[package]] +name = "tungstenite" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093" +dependencies = [ + "base64 0.13.0", + "byteorder", + "bytes 1.0.1", + "http", + "httparse", + "input_buffer", + "log", + "rand 0.8.3", + "sha-1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.13.0" @@ -4057,6 +4118,12 @@ dependencies = [ "xmlwriter", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "0.5.1" @@ -4356,6 +4423,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-lock", + "async-tungstenite", "base64 0.13.0", "futures", "log", diff --git a/zed-rpc/Cargo.toml b/zed-rpc/Cargo.toml index 7c03269a54..95700ff3d8 100644 --- a/zed-rpc/Cargo.toml +++ b/zed-rpc/Cargo.toml @@ -7,6 +7,7 @@ version = "0.1.0" [dependencies] anyhow = "1.0" async-lock = "2.4" +async-tungstenite = "0.14" base64 = "0.13" futures = "0.3" log = "0.4" diff --git a/zed-rpc/src/lib.rs b/zed-rpc/src/lib.rs index e9e18e92b8..fb09ec07a7 100644 --- a/zed-rpc/src/lib.rs +++ b/zed-rpc/src/lib.rs @@ -2,5 +2,7 @@ pub mod auth; mod peer; pub mod proto; pub mod rest; +#[cfg(test)] +mod test; pub use peer::*; diff --git a/zed-rpc/src/peer.rs b/zed-rpc/src/peer.rs index 87d07c694e..29c09ecf38 100644 --- a/zed-rpc/src/peer.rs +++ b/zed-rpc/src/peer.rs @@ -1,7 +1,12 @@ use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage}; use anyhow::{anyhow, Context, Result}; use async_lock::{Mutex, RwLock}; -use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt}; +use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; +use futures::{ + future::BoxFuture, + stream::{SplitSink, SplitStream}, + FutureExt, StreamExt, +}; use postage::{ mpsc, prelude::{Sink, Stream}, @@ -72,13 +77,13 @@ struct Connection { response_channels: ResponseChannels, } -pub struct ConnectionHandler { +pub struct ConnectionHandler { peer: Arc, connection_id: ConnectionId, response_channels: ResponseChannels, outgoing_rx: mpsc::Receiver, - reader: MessageStream, - writer: MessageStream, + writer: MessageStream, + reader: MessageStream, } type ResponseChannels = Arc>>>; @@ -131,10 +136,16 @@ impl Peer { pub async fn add_connection( self: &Arc, conn: Conn, - ) -> (ConnectionId, ConnectionHandler) + ) -> ( + ConnectionId, + ConnectionHandler, SplitStream>, + ) where - Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, + Conn: futures::Sink + + futures::Stream> + + Unpin, { + let (tx, rx) = conn.split(); let connection_id = ConnectionId( self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), @@ -150,8 +161,8 @@ impl Peer { connection_id, response_channels: connection.response_channels.clone(), outgoing_rx, - reader: MessageStream::new(conn.clone()), - writer: MessageStream::new(conn), + writer: MessageStream::new(tx), + reader: MessageStream::new(rx), }; self.connections .write() @@ -291,9 +302,10 @@ impl Peer { } } -impl ConnectionHandler +impl ConnectionHandler where - Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, + W: futures::Sink + Unpin, + R: futures::Stream> + Unpin, { pub async fn run(mut self) -> Result<()> { loop { @@ -402,38 +414,25 @@ impl fmt::Display for PeerId { #[cfg(test)] mod tests { use super::*; + use crate::test; use postage::oneshot; - use smol::{ - io::AsyncWriteExt, - net::unix::{UnixListener, UnixStream}, - }; - use std::io; - use tempdir::TempDir; #[test] fn test_request_response() { smol::block_on(async move { - // create socket - let socket_dir_path = TempDir::new("test-request-response").unwrap(); - let socket_path = socket_dir_path.path().join("test.sock"); - let listener = UnixListener::bind(&socket_path).unwrap(); - // create 2 clients connected to 1 server let server = Peer::new(); let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_conn_id, task1) = client1 - .add_connection(UnixStream::connect(&socket_path).await.unwrap()) - .await; - let (client2_conn_id, task2) = client2 - .add_connection(UnixStream::connect(&socket_path).await.unwrap()) - .await; - let (_, task3) = server - .add_connection(listener.accept().await.unwrap().0) - .await; - let (_, task4) = server - .add_connection(listener.accept().await.unwrap().0) - .await; + + let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional(); + let (client1_conn_id, task1) = client1.add_connection(client1_to_server_conn).await; + let (_, task2) = server.add_connection(server_to_client_1_conn).await; + + let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional(); + let (client2_conn_id, task3) = client2.add_connection(client2_to_server_conn).await; + let (_, task4) = server.add_connection(server_to_client_2_conn).await; + smol::spawn(task1.run()).detach(); smol::spawn(task2.run()).detach(); smol::spawn(task3.run()).detach(); @@ -553,11 +552,7 @@ mod tests { #[test] fn test_disconnect() { smol::block_on(async move { - let socket_dir_path = TempDir::new("drop-client").unwrap(); - let socket_path = socket_dir_path.path().join(".sock"); - let listener = UnixListener::bind(&socket_path).unwrap(); - let client_conn = UnixStream::connect(&socket_path).await.unwrap(); - let (mut server_conn, _) = listener.accept().await.unwrap(); + let (client_conn, mut server_conn) = test::Channel::bidirectional(); let client = Peer::new(); let (connection_id, handler) = client.add_connection(client_conn).await; @@ -571,20 +566,19 @@ mod tests { client.disconnect(connection_id).await; incoming_messages_ended_rx.recv().await; - - let err = server_conn.write(&[]).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); + assert!( + futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![])) + .await + .is_err() + ); }); } #[test] fn test_io_error() { smol::block_on(async move { - let socket_dir_path = TempDir::new("io-error").unwrap(); - let socket_path = socket_dir_path.path().join(".sock"); - let _listener = UnixListener::bind(&socket_path).unwrap(); - let mut client_conn = UnixStream::connect(&socket_path).await.unwrap(); - client_conn.close().await.unwrap(); + let (client_conn, server_conn) = test::Channel::bidirectional(); + drop(server_conn); let client = Peer::new(); let (connection_id, handler) = client.add_connection(client_conn).await; diff --git a/zed-rpc/src/proto.rs b/zed-rpc/src/proto.rs index 1626e5aad5..bb082c1783 100644 --- a/zed-rpc/src/proto.rs +++ b/zed-rpc/src/proto.rs @@ -1,7 +1,7 @@ -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _}; +use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; +use futures::{SinkExt as _, StreamExt as _}; use prost::Message; use std::{ - convert::TryInto, io, time::{Duration, SystemTime, UNIX_EPOCH}, }; @@ -81,66 +81,52 @@ message!(AddPeer); message!(RemovePeer); /// A stream of protobuf messages. -pub struct MessageStream { - byte_stream: T, - buffer: Vec, - upcoming_message_len: Option, +pub struct MessageStream { + stream: S, } -impl MessageStream { - pub fn new(byte_stream: T) -> Self { - Self { - byte_stream, - buffer: Default::default(), - upcoming_message_len: None, - } +impl MessageStream { + pub fn new(stream: S) -> Self { + Self { stream } } - pub fn inner_mut(&mut self) -> &mut T { - &mut self.byte_stream + pub fn inner_mut(&mut self) -> &mut S { + &mut self.stream } } -impl MessageStream +impl MessageStream where - T: AsyncWrite + Unpin, + S: futures::Sink + Unpin, { /// Write a given protobuf message to the stream. - pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> { - let message_len: u32 = message - .encoded_len() - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?; - self.buffer.clear(); - self.buffer.extend_from_slice(&message_len.to_be_bytes()); - message.encode(&mut self.buffer)?; - self.byte_stream.write_all(&self.buffer).await + pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> { + let mut buffer = Vec::with_capacity(message.encoded_len()); + message + .encode(&mut buffer) + .map_err(|err| io::Error::from(err))?; + self.stream.send(WebSocketMessage::Binary(buffer)).await?; + Ok(()) } } -impl MessageStream +impl MessageStream where - T: AsyncRead + Unpin, + S: futures::Stream> + Unpin, { /// Read a protobuf message of the given type from the stream. - pub async fn read_message(&mut self) -> io::Result { - loop { - if let Some(upcoming_message_len) = self.upcoming_message_len { - self.buffer.resize(upcoming_message_len, 0); - self.byte_stream.read_exact(&mut self.buffer).await?; - self.upcoming_message_len = None; - return Ok(Envelope::decode(self.buffer.as_slice())?); - } else { - self.buffer.resize(4, 0); - self.byte_stream.read_exact(&mut self.buffer).await?; - self.upcoming_message_len = Some(u32::from_be_bytes([ - self.buffer[0], - self.buffer[1], - self.buffer[2], - self.buffer[3], - ]) as usize); + pub async fn read_message(&mut self) -> Result { + while let Some(bytes) = self.stream.next().await { + match bytes? { + WebSocketMessage::Binary(bytes) => { + let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?; + return Ok(envelope); + } + WebSocketMessage::Close(_) => break, + _ => {} } } + Err(WebSocketError::ConnectionClosed) } } @@ -165,20 +151,12 @@ impl From for Timestamp { #[cfg(test)] mod tests { use super::*; - use std::{ - pin::Pin, - task::{Context, Poll}, - }; + use crate::test; #[test] fn test_round_trip_message() { smol::block_on(async { - let byte_stream = ChunkedStream { - bytes: Vec::new(), - read_offset: 0, - chunk_size: 3, - }; - + let stream = test::Channel::new(); let message1 = Auth { user_id: 5, access_token: "the-access-token".into(), @@ -191,7 +169,7 @@ mod tests { } .into_envelope(5, None, None); - let mut message_stream = MessageStream::new(byte_stream); + let mut message_stream = MessageStream::new(stream); message_stream.write_message(&message1).await.unwrap(); message_stream.write_message(&message2).await.unwrap(); let decoded_message1 = message_stream.read_message().await.unwrap(); @@ -200,47 +178,4 @@ mod tests { assert_eq!(decoded_message2, message2); }); } - - struct ChunkedStream { - bytes: Vec, - read_offset: usize, - chunk_size: usize, - } - - impl AsyncWrite for ChunkedStream { - fn poll_write( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let bytes_written = buf.len().min(self.chunk_size); - self.bytes.extend_from_slice(&buf[0..bytes_written]); - Poll::Ready(Ok(bytes_written)) - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - } - - impl AsyncRead for ChunkedStream { - fn poll_read( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let bytes_read = buf - .len() - .min(self.chunk_size) - .min(self.bytes.len() - self.read_offset); - let end_offset = self.read_offset + bytes_read; - buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]); - self.read_offset = end_offset; - Poll::Ready(Ok(bytes_read)) - } - } } diff --git a/zed-rpc/src/test.rs b/zed-rpc/src/test.rs new file mode 100644 index 0000000000..ad698a4094 --- /dev/null +++ b/zed-rpc/src/test.rs @@ -0,0 +1,64 @@ +use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +pub struct Channel { + tx: futures::channel::mpsc::UnboundedSender, + rx: futures::channel::mpsc::UnboundedReceiver, +} + +impl Channel { + pub fn new() -> Self { + let (tx, rx) = futures::channel::mpsc::unbounded(); + Self { tx, rx } + } + + pub fn bidirectional() -> (Self, Self) { + let (a_tx, a_rx) = futures::channel::mpsc::unbounded(); + let (b_tx, b_rx) = futures::channel::mpsc::unbounded(); + let a = Self { tx: a_tx, rx: b_rx }; + let b = Self { tx: b_tx, rx: a_rx }; + (a, b) + } +} + +impl futures::Sink for Channel { + type Error = WebSocketError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx) + .poll_ready(cx) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) + } + + fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> { + Pin::new(&mut self.tx) + .start_send(item) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx) + .poll_flush(cx) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx) + .poll_close(cx) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) + } +} + +impl futures::Stream for Channel { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.rx) + .poll_next(cx) + .map(|i| i.map(|i| Ok(i))) + } +}