Conrad Irwin 2023-10-26 11:48:25 +02:00
parent 693246ba26
commit 71ad3e1b20
9 changed files with 3545 additions and 0 deletions

44
crates/rpc2/Cargo.toml Normal file

@@ -0,0 +1,44 @@
[package]
description = "Shared logic for communication between the Zed app and the zed.dev server"
edition = "2021"
name = "rpc"
version = "0.1.0"
publish = false
[lib]
path = "src/rpc.rs"
doctest = false
[features]
test-support = ["collections/test-support", "gpui/test-support"]
[dependencies]
clock = { path = "../clock" }
collections = { path = "../collections" }
gpui = { path = "../gpui", optional = true }
util = { path = "../util" }
anyhow.workspace = true
async-lock = "2.4"
async-tungstenite = "0.16"
base64 = "0.13"
futures.workspace = true
parking_lot.workspace = true
prost.workspace = true
rand.workspace = true
rsa = "0.4"
serde.workspace = true
serde_derive.workspace = true
smol-timeout = "0.6"
tracing = { version = "0.1.34", features = ["log"] }
zstd = "0.11"
[build-dependencies]
prost-build = "0.9"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
smol.workspace = true
tempdir.workspace = true
ctor.workspace = true
env_logger.workspace = true

8
crates/rpc2/build.rs Normal file

@@ -0,0 +1,8 @@
fn main() {
let mut build = prost_build::Config::new();
// build.protoc_arg("--experimental_allow_proto3_optional");
build
.type_attribute(".", "#[derive(serde::Serialize)]")
.compile_protos(&["proto/zed.proto"], &["proto"])
.unwrap();
}
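
The build script's output is consumed in src/proto.rs below; for context, a minimal sketch of that consumption (prost-build writes one .rs file per protobuf package into OUT_DIR, here presumably zed.messages):

// src/proto.rs pulls the prost-build output into the crate, so the generated
// message structs (each carrying the derived serde::Serialize) become part of
// the `proto` module.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));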

1559
crates/rpc2/proto/zed.proto Normal file

File diff suppressed because it is too large

136
crates/rpc2/src/auth.rs Normal file

@@ -0,0 +1,136 @@
use anyhow::{Context, Result};
use rand::{thread_rng, Rng as _};
use rsa::{PublicKey as _, PublicKeyEncoding, RSAPrivateKey, RSAPublicKey};
use std::convert::TryFrom;
pub struct PublicKey(RSAPublicKey);
pub struct PrivateKey(RSAPrivateKey);
/// Generate a public and private key for asymmetric encryption.
pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
let mut rng = thread_rng();
let bits = 1024;
let private_key = RSAPrivateKey::new(&mut rng, bits)?;
let public_key = RSAPublicKey::from(&private_key);
Ok((PublicKey(public_key), PrivateKey(private_key)))
}
/// Generate a random 64-character base64 string.
pub fn random_token() -> String {
let mut rng = thread_rng();
let mut token_bytes = [0; 48];
for byte in token_bytes.iter_mut() {
*byte = rng.gen();
}
base64::encode_config(token_bytes, base64::URL_SAFE)
}
impl PublicKey {
/// Encrypt a string, returning a base64-encoded ciphertext that can only be decrypted with the
/// corresponding private key.
pub fn encrypt_string(&self, string: &str) -> Result<String> {
let mut rng = thread_rng();
let bytes = string.as_bytes();
let encrypted_bytes = self
.0
.encrypt(&mut rng, PADDING_SCHEME, bytes)
.context("failed to encrypt string with public key")?;
let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
Ok(encrypted_string)
}
}
impl PrivateKey {
/// Decrypt a base64-encoded string that was encrypted by the corresponding public key.
pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
.context("failed to base64-decode encrypted string")?;
let bytes = self
.0
.decrypt(PADDING_SCHEME, &encrypted_bytes)
.context("failed to decrypt string with private key")?;
let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
Ok(string)
}
}
impl TryFrom<PublicKey> for String {
type Error = anyhow::Error;
fn try_from(key: PublicKey) -> Result<Self> {
let bytes = key.0.to_pkcs1().context("failed to serialize public key")?;
let string = base64::encode_config(&bytes, base64::URL_SAFE);
Ok(string)
}
}
impl TryFrom<String> for PublicKey {
type Error = anyhow::Error;
fn try_from(value: String) -> Result<Self> {
let bytes = base64::decode_config(&value, base64::URL_SAFE)
.context("failed to base64-decode public key string")?;
let key = Self(RSAPublicKey::from_pkcs1(&bytes).context("failed to parse public key")?);
Ok(key)
}
}
const PADDING_SCHEME: rsa::PaddingScheme = rsa::PaddingScheme::PKCS1v15Encrypt;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_encrypt_and_decrypt_token() {
// CLIENT:
// * generate a keypair for asymmetric encryption
// * serialize the public key to send it to the server.
let (public, private) = keypair().unwrap();
let public_string = String::try_from(public).unwrap();
assert_printable(&public_string);
// SERVER:
// * parse the public key
// * generate a random token.
// * encrypt the token using the public key.
let public = PublicKey::try_from(public_string).unwrap();
let token = random_token();
let encrypted_token = public.encrypt_string(&token).unwrap();
assert_eq!(token.len(), 64);
assert_ne!(encrypted_token, token);
assert_printable(&token);
assert_printable(&encrypted_token);
// CLIENT:
// * decrypt the token using the private key.
let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
assert_eq!(decrypted_token, token);
}
#[test]
fn test_tokens_are_always_url_safe() {
for _ in 0..5 {
let token = random_token();
let (public_key, _) = keypair().unwrap();
let encrypted_token = public_key.encrypt_string(&token).unwrap();
let public_key_str = String::try_from(public_key).unwrap();
assert_printable(&token);
assert_printable(&public_key_str);
assert_printable(&encrypted_token);
}
}
fn assert_printable(token: &str) {
for c in token.chars() {
assert!(
c.is_ascii_graphic(),
"token {:?} has non-printable char {}",
token,
c
);
assert_ne!(c, '/', "token {:?} is not URL-safe", token);
assert_ne!(c, '&', "token {:?} is not URL-safe", token);
}
}
}

112
crates/rpc2/src/conn.rs Normal file

@@ -0,0 +1,112 @@
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use futures::{SinkExt as _, StreamExt as _};
pub struct Connection {
pub(crate) tx:
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
pub(crate) rx: Box<
dyn 'static
+ Send
+ Unpin
+ futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
>,
}
impl Connection {
pub fn new<S>(stream: S) -> Self
where
S: 'static
+ Send
+ Unpin
+ futures::Sink<WebSocketMessage, Error = anyhow::Error>
+ futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
{
let (tx, rx) = stream.split();
Self {
tx: Box::new(tx),
rx: Box::new(rx),
}
}
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), anyhow::Error> {
self.tx.send(message).await
}
#[cfg(any(test, feature = "test-support"))]
pub fn in_memory(
executor: std::sync::Arc<gpui::executor::Background>,
) -> (Self, Self, std::sync::Arc<std::sync::atomic::AtomicBool>) {
use std::sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
};
let killed = Arc::new(AtomicBool::new(false));
let (a_tx, a_rx) = channel(killed.clone(), executor.clone());
let (b_tx, b_rx) = channel(killed.clone(), executor);
return (
Self { tx: a_tx, rx: b_rx },
Self { tx: b_tx, rx: a_rx },
killed,
);
#[allow(clippy::type_complexity)]
fn channel(
killed: Arc<AtomicBool>,
executor: Arc<gpui::executor::Background>,
) -> (
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>>,
) {
use anyhow::anyhow;
use futures::channel::mpsc;
use std::io::{Error, ErrorKind};
let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
let tx = tx.sink_map_err(|error| anyhow!(error)).with({
let killed = killed.clone();
let executor = Arc::downgrade(&executor);
move |msg| {
let killed = killed.clone();
let executor = executor.clone();
Box::pin(async move {
if let Some(executor) = executor.upgrade() {
executor.simulate_random_delay().await;
}
// Writes to a half-open TCP connection will error.
if killed.load(SeqCst) {
std::io::Result::Err(Error::new(ErrorKind::Other, "connection lost"))?;
}
Ok(msg)
})
}
});
let rx = rx.then({
let killed = killed;
let executor = Arc::downgrade(&executor);
move |msg| {
let killed = killed.clone();
let executor = executor.clone();
Box::pin(async move {
if let Some(executor) = executor.upgrade() {
executor.simulate_random_delay().await;
}
// Reads from a half-open TCP connection will hang.
if killed.load(SeqCst) {
futures::future::pending::<()>().await;
}
Ok(msg)
})
}
});
(Box::new(tx), Box::new(rx))
}
}
}
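
Connection::new requires both halves of the stream to use anyhow::Error, which a raw async-tungstenite websocket stream does not. A minimal sketch of the adaptation, assuming the websocket connect and TLS setup happen elsewhere (the helper name is hypothetical):

use anyhow::anyhow;
use async_tungstenite::tungstenite::{Error as WsError, Message as WebSocketMessage};
use futures::{SinkExt as _, TryStreamExt as _};
use rpc::Connection;

// Map tungstenite errors to anyhow::Error on both the sink and stream halves
// before handing the stream to `Connection::new`.
fn connection_from_websocket<S>(stream: S) -> Connection
where
    S: 'static
        + Send
        + Unpin
        + futures::Sink<WebSocketMessage, Error = WsError>
        + futures::Stream<Item = Result<WebSocketMessage, WsError>>,
{
    let stream = stream
        .map_err(|error| anyhow!(error))
        .sink_map_err(|error| anyhow!(error));
    Connection::new(stream)
}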

70
crates/rpc2/src/macros.rs Normal file

@@ -0,0 +1,70 @@
#[macro_export]
macro_rules! messages {
($(($name:ident, $priority:ident)),* $(,)?) => {
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
match envelope.payload {
$(Some(envelope::Payload::$name(payload)) => {
Some(Box::new(TypedEnvelope {
sender_id,
original_sender_id: envelope.original_sender_id.map(|original_sender| PeerId {
owner_id: original_sender.owner_id,
id: original_sender.id
}),
message_id: envelope.id,
payload,
}))
}, )*
_ => None
}
}
$(
impl EnvelopedMessage for $name {
const NAME: &'static str = std::stringify!($name);
const PRIORITY: MessagePriority = MessagePriority::$priority;
fn into_envelope(
self,
id: u32,
responding_to: Option<u32>,
original_sender_id: Option<PeerId>,
) -> Envelope {
Envelope {
id,
responding_to,
original_sender_id,
payload: Some(envelope::Payload::$name(self)),
}
}
fn from_envelope(envelope: Envelope) -> Option<Self> {
if let Some(envelope::Payload::$name(msg)) = envelope.payload {
Some(msg)
} else {
None
}
}
}
)*
};
}
#[macro_export]
macro_rules! request_messages {
($(($request_name:ident, $response_name:ident)),* $(,)?) => {
$(impl RequestMessage for $request_name {
type Response = $response_name;
})*
};
}
#[macro_export]
macro_rules! entity_messages {
($id_field:ident, $($name:ident),* $(,)?) => {
$(impl EntityMessage for $name {
fn remote_entity_id(&self) -> u64 {
self.$id_field
}
})*
};
}
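
These macros are invoked once, in src/proto.rs below, over the generated protobuf types. In miniature the invocation looks like this (a trimmed-down view of the real one, shown only to illustrate the macro shapes):

// `messages!` implements EnvelopedMessage for each type and registers it with
// build_typed_envelope; `request_messages!` pairs a request with its response
// type; `entity_messages!` routes a message to an entity via the named id field.
messages!(
    (Ping, Foreground),
    (Ack, Foreground),
    (UpdateBuffer, Foreground),
);
request_messages!((Ping, Ack));
entity_messages!(project_id, UpdateBuffer);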

933
crates/rpc2/src/peer.rs Normal file

@@ -0,0 +1,933 @@
use super::{
proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
Connection,
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::{
channel::{mpsc, oneshot},
stream::BoxStream,
FutureExt, SinkExt, StreamExt, TryFutureExt,
};
use parking_lot::{Mutex, RwLock};
use serde::{ser::SerializeStruct, Serialize};
use std::{fmt, sync::atomic::Ordering::SeqCst};
use std::{
future::Future,
marker::PhantomData,
sync::{
atomic::{self, AtomicU32},
Arc,
},
time::Duration,
};
use tracing::instrument;
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
pub owner_id: u32,
pub id: u32,
}
impl Into<PeerId> for ConnectionId {
fn into(self) -> PeerId {
PeerId {
owner_id: self.owner_id,
id: self.id,
}
}
}
impl From<PeerId> for ConnectionId {
fn from(peer_id: PeerId) -> Self {
Self {
owner_id: peer_id.owner_id,
id: peer_id.id,
}
}
}
impl fmt::Display for ConnectionId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.owner_id, self.id)
}
}
pub struct Receipt<T> {
pub sender_id: ConnectionId,
pub message_id: u32,
payload_type: PhantomData<T>,
}
impl<T> Clone for Receipt<T> {
fn clone(&self) -> Self {
Self {
sender_id: self.sender_id,
message_id: self.message_id,
payload_type: PhantomData,
}
}
}
impl<T> Copy for Receipt<T> {}
#[derive(Clone, Debug)]
pub struct TypedEnvelope<T> {
pub sender_id: ConnectionId,
pub original_sender_id: Option<PeerId>,
pub message_id: u32,
pub payload: T,
}
impl<T> TypedEnvelope<T> {
pub fn original_sender_id(&self) -> Result<PeerId> {
self.original_sender_id
.ok_or_else(|| anyhow!("missing original_sender_id"))
}
}
impl<T: RequestMessage> TypedEnvelope<T> {
pub fn receipt(&self) -> Receipt<T> {
Receipt {
sender_id: self.sender_id,
message_id: self.message_id,
payload_type: PhantomData,
}
}
}
pub struct Peer {
epoch: AtomicU32,
pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
next_connection_id: AtomicU32,
}
#[derive(Clone, Serialize)]
pub struct ConnectionState {
#[serde(skip)]
outgoing_tx: mpsc::UnboundedSender<proto::Message>,
next_message_id: Arc<AtomicU32>,
#[allow(clippy::type_complexity)]
#[serde(skip)]
response_channels:
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
impl Peer {
pub fn new(epoch: u32) -> Arc<Self> {
Arc::new(Self {
epoch: AtomicU32::new(epoch),
connections: Default::default(),
next_connection_id: Default::default(),
})
}
pub fn epoch(&self) -> u32 {
self.epoch.load(SeqCst)
}
#[instrument(skip_all)]
pub fn add_connection<F, Fut, Out>(
self: &Arc<Self>,
connection: Connection,
create_timer: F,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
)
where
F: Send + Fn(Duration) -> Fut,
Fut: Send + Future<Output = Out>,
Out: Send,
{
// For outgoing messages, use an unbounded channel so that application code
// can always send messages without yielding. For incoming messages, use a
// bounded channel so that other peers will receive backpressure if they send
// messages faster than this peer can process them.
#[cfg(any(test, feature = "test-support"))]
const INCOMING_BUFFER_SIZE: usize = 1;
#[cfg(not(any(test, feature = "test-support")))]
const INCOMING_BUFFER_SIZE: usize = 64;
let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
let connection_id = ConnectionId {
owner_id: self.epoch.load(SeqCst),
id: self.next_connection_id.fetch_add(1, SeqCst),
};
let connection_state = ConnectionState {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(connection.rx);
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let handle_io = async move {
tracing::trace!(%connection_id, "handle io future: start");
let _end_connection = util::defer(|| {
response_channels.lock().take();
this.connections.write().remove(&connection_id);
tracing::trace!(%connection_id, "handle io future: end");
});
// Send messages on this frequency so the connection isn't closed.
let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
futures::pin_mut!(keepalive_timer);
// Disconnect if we don't receive messages at least this frequently.
let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
futures::pin_mut!(receive_timeout);
loop {
tracing::trace!(%connection_id, "outer loop iteration start");
let read_message = reader.read().fuse();
futures::pin_mut!(read_message);
loop {
tracing::trace!(%connection_id, "inner loop iteration start");
futures::select_biased! {
outgoing = outgoing_rx.next().fuse() => match outgoing {
Some(outgoing) => {
tracing::trace!(%connection_id, "outgoing rpc message: writing");
futures::select_biased! {
result = writer.write(outgoing).fuse() => {
tracing::trace!(%connection_id, "outgoing rpc message: done writing");
result.context("failed to write RPC message")?;
tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
}
_ = create_timer(WRITE_TIMEOUT).fuse() => {
tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
Err(anyhow!("timed out writing message"))?;
}
}
}
None => {
tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
return Ok(())
},
},
_ = keepalive_timer => {
tracing::trace!(%connection_id, "keepalive interval: pinging");
futures::select_biased! {
result = writer.write(proto::Message::Ping).fuse() => {
tracing::trace!(%connection_id, "keepalive interval: done pinging");
result.context("failed to send keepalive")?;
tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
}
_ = create_timer(WRITE_TIMEOUT).fuse() => {
tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
Err(anyhow!("timed out sending keepalive"))?;
}
}
}
incoming = read_message => {
let incoming = incoming.context("error reading rpc message from socket")?;
tracing::trace!(%connection_id, "incoming rpc message: received");
tracing::trace!(%connection_id, "receive timeout: resetting");
receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
if let proto::Message::Envelope(incoming) = incoming {
tracing::trace!(%connection_id, "incoming rpc message: processing");
futures::select_biased! {
result = incoming_tx.send(incoming).fuse() => match result {
Ok(_) => {
tracing::trace!(%connection_id, "incoming rpc message: processed");
}
Err(_) => {
tracing::trace!(%connection_id, "incoming rpc message: channel closed");
return Ok(())
}
},
_ = create_timer(WRITE_TIMEOUT).fuse() => {
tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
Err(anyhow!("timed out processing incoming message"))?
}
}
}
break;
},
_ = receive_timeout => {
tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
Err(anyhow!("delay between messages too long"))?
}
}
}
}
};
let response_channels = connection_state.response_channels.clone();
self.connections
.write()
.insert(connection_id, connection_state);
let incoming_rx = incoming_rx.filter_map(move |incoming| {
let response_channels = response_channels.clone();
async move {
let message_id = incoming.id;
tracing::trace!(?incoming, "incoming message future: start");
let _end = util::defer(move || {
tracing::trace!(%connection_id, message_id, "incoming message future: end");
});
if let Some(responding_to) = incoming.responding_to {
tracing::trace!(
%connection_id,
message_id,
responding_to,
"incoming response: received"
);
let channel = response_channels.lock().as_mut()?.remove(&responding_to);
if let Some(tx) = channel {
let requester_resumed = oneshot::channel();
if let Err(error) = tx.send((incoming, requester_resumed.0)) {
tracing::trace!(
%connection_id,
message_id,
responding_to = responding_to,
?error,
"incoming response: request future dropped",
);
}
tracing::trace!(
%connection_id,
message_id,
responding_to,
"incoming response: waiting to resume requester"
);
let _ = requester_resumed.1.await;
tracing::trace!(
%connection_id,
message_id,
responding_to,
"incoming response: requester resumed"
);
} else {
tracing::warn!(
%connection_id,
message_id,
responding_to,
"incoming response: unknown request"
);
}
None
} else {
tracing::trace!(%connection_id, message_id, "incoming message: received");
proto::build_typed_envelope(connection_id, incoming).or_else(|| {
tracing::error!(
%connection_id,
message_id,
"unable to construct a typed envelope"
);
None
})
}
}
});
(connection_id, handle_io, incoming_rx.boxed())
}
#[cfg(any(test, feature = "test-support"))]
pub fn add_test_connection(
self: &Arc<Self>,
connection: Connection,
executor: Arc<gpui::executor::Background>,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
) {
let executor = executor.clone();
self.add_connection(connection, move |duration| executor.timer(duration))
}
pub fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().remove(&connection_id);
}
pub fn reset(&self, epoch: u32) {
self.teardown();
self.next_connection_id.store(0, SeqCst);
self.epoch.store(epoch, SeqCst);
}
pub fn teardown(&self) {
self.connections.write().clear();
}
pub fn request<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<T::Response>> {
self.request_internal(None, receiver_id, request)
.map_ok(|envelope| envelope.payload)
}
pub fn request_envelope<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
self.request_internal(None, receiver_id, request)
}
pub fn forward_request<T: RequestMessage>(
&self,
sender_id: ConnectionId,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<T::Response>> {
self.request_internal(Some(sender_id), receiver_id, request)
.map_ok(|envelope| envelope.payload)
}
pub fn request_internal<T: RequestMessage>(
&self,
original_sender_id: Option<ConnectionId>,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
let (tx, rx) = oneshot::channel();
let send = self.connection_state(receiver_id).and_then(|connection| {
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
connection
.response_channels
.lock()
.as_mut()
.ok_or_else(|| anyhow!("connection was closed"))?
.insert(message_id, tx);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(request.into_envelope(
message_id,
None,
original_sender_id.map(Into::into),
)))
.map_err(|_| anyhow!("connection was closed"))?;
Ok(())
});
async move {
send?;
let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
Err(anyhow!(
"RPC request {} failed - {}",
T::NAME,
error.message
))
} else {
Ok(TypedEnvelope {
message_id: response.id,
sender_id: receiver_id,
original_sender_id: response.original_sender_id,
payload: T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type"))?,
})
}
}
}
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
let connection = self.connection_state(receiver_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(
message.into_envelope(message_id, None, None),
))?;
Ok(())
}
pub fn forward_send<T: EnvelopedMessage>(
&self,
sender_id: ConnectionId,
receiver_id: ConnectionId,
message: T,
) -> Result<()> {
let connection = self.connection_state(receiver_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(message.into_envelope(
message_id,
None,
Some(sender_id.into()),
)))?;
Ok(())
}
pub fn respond<T: RequestMessage>(
&self,
receipt: Receipt<T>,
response: T::Response,
) -> Result<()> {
let connection = self.connection_state(receipt.sender_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(response.into_envelope(
message_id,
Some(receipt.message_id),
None,
)))?;
Ok(())
}
pub fn respond_with_error<T: RequestMessage>(
&self,
receipt: Receipt<T>,
response: proto::Error,
) -> Result<()> {
let connection = self.connection_state(receipt.sender_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(response.into_envelope(
message_id,
Some(receipt.message_id),
None,
)))?;
Ok(())
}
pub fn respond_with_unhandled_message(
&self,
envelope: Box<dyn AnyTypedEnvelope>,
) -> Result<()> {
let connection = self.connection_state(envelope.sender_id())?;
let response = proto::Error {
message: format!("message {} was not handled", envelope.payload_type_name()),
};
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(response.into_envelope(
message_id,
Some(envelope.message_id()),
None,
)))?;
Ok(())
}
fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
let connections = self.connections.read();
let connection = connections
.get(&connection_id)
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
Ok(connection.clone())
}
}
impl Serialize for Peer {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("Peer", 2)?;
state.serialize_field("connections", &*self.connections.read())?;
state.end()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TypedEnvelope;
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use gpui::TestAppContext;
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}
#[gpui::test(iterations = 50)]
async fn test_request_response(cx: &mut TestAppContext) {
let executor = cx.foreground();
// create 2 clients connected to 1 server
let server = Peer::new(0);
let client1 = Peer::new(0);
let client2 = Peer::new(0);
let (client1_to_server_conn, server_to_client_1_conn, _kill) =
Connection::in_memory(cx.background());
let (client1_conn_id, io_task1, client1_incoming) =
client1.add_test_connection(client1_to_server_conn, cx.background());
let (_, io_task2, server_incoming1) =
server.add_test_connection(server_to_client_1_conn, cx.background());
let (client2_to_server_conn, server_to_client_2_conn, _kill) =
Connection::in_memory(cx.background());
let (client2_conn_id, io_task3, client2_incoming) =
client2.add_test_connection(client2_to_server_conn, cx.background());
let (_, io_task4, server_incoming2) =
server.add_test_connection(server_to_client_2_conn, cx.background());
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
executor.spawn(io_task3).detach();
executor.spawn(io_task4).detach();
executor
.spawn(handle_messages(server_incoming1, server.clone()))
.detach();
executor
.spawn(handle_messages(client1_incoming, client1.clone()))
.detach();
executor
.spawn(handle_messages(server_incoming2, server.clone()))
.detach();
executor
.spawn(handle_messages(client2_incoming, client2.clone()))
.detach();
assert_eq!(
client1
.request(client1_conn_id, proto::Ping {},)
.await
.unwrap(),
proto::Ack {}
);
assert_eq!(
client2
.request(client2_conn_id, proto::Ping {},)
.await
.unwrap(),
proto::Ack {}
);
assert_eq!(
client1
.request(client1_conn_id, proto::Test { id: 1 },)
.await
.unwrap(),
proto::Test { id: 1 }
);
assert_eq!(
client2
.request(client2_conn_id, proto::Test { id: 2 })
.await
.unwrap(),
proto::Test { id: 2 }
);
client1.disconnect(client1_conn_id);
client2.disconnect(client1_conn_id);
async fn handle_messages(
mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
peer: Arc<Peer>,
) -> Result<()> {
while let Some(envelope) = messages.next().await {
let envelope = envelope.into_any();
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt();
peer.respond(receipt, proto::Ack {})?
} else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
{
peer.respond(envelope.receipt(), envelope.payload.clone())?
} else {
panic!("unknown message type");
}
}
Ok(())
}
}
#[gpui::test(iterations = 50)]
async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
let executor = cx.foreground();
let server = Peer::new(0);
let client = Peer::new(0);
let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(cx.background());
let (client_to_server_conn_id, io_task1, mut client_incoming) =
client.add_test_connection(client_to_server_conn, cx.background());
let (server_to_client_conn_id, io_task2, mut server_incoming) =
server.add_test_connection(server_to_client_conn, cx.background());
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
executor
.spawn(async move {
let request = server_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Ping>>()
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 1".to_string(),
},
)
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 2".to_string(),
},
)
.unwrap();
server.respond(request.receipt(), proto::Ack {}).unwrap();
// Prevent the connection from being dropped
server_incoming.next().await;
})
.detach();
let events = Arc::new(Mutex::new(Vec::new()));
let response = client.request(client_to_server_conn_id, proto::Ping {});
let response_task = executor.spawn({
let events = events.clone();
async move {
response.await.unwrap();
events.lock().push("response".to_string());
}
});
executor
.spawn({
let events = events.clone();
async move {
let incoming1 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming1.payload.message);
let incoming2 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming2.payload.message);
// Prevent the connection from being dropped
client_incoming.next().await;
}
})
.detach();
response_task.await;
assert_eq!(
&*events.lock(),
&[
"message 1".to_string(),
"message 2".to_string(),
"response".to_string()
]
);
}
#[gpui::test(iterations = 50)]
async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
let executor = cx.foreground();
let server = Peer::new(0);
let client = Peer::new(0);
let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(cx.background());
let (client_to_server_conn_id, io_task1, mut client_incoming) =
client.add_test_connection(client_to_server_conn, cx.background());
let (server_to_client_conn_id, io_task2, mut server_incoming) =
server.add_test_connection(server_to_client_conn, cx.background());
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
executor
.spawn(async move {
let request1 = server_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Ping>>()
.unwrap();
let request2 = server_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Ping>>()
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 1".to_string(),
},
)
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 2".to_string(),
},
)
.unwrap();
server.respond(request1.receipt(), proto::Ack {}).unwrap();
server.respond(request2.receipt(), proto::Ack {}).unwrap();
// Prevent the connection from being dropped
server_incoming.next().await;
})
.detach();
let events = Arc::new(Mutex::new(Vec::new()));
let request1 = client.request(client_to_server_conn_id, proto::Ping {});
let request1_task = executor.spawn(request1);
let request2 = client.request(client_to_server_conn_id, proto::Ping {});
let request2_task = executor.spawn({
let events = events.clone();
async move {
request2.await.unwrap();
events.lock().push("response 2".to_string());
}
});
executor
.spawn({
let events = events.clone();
async move {
let incoming1 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming1.payload.message);
let incoming2 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming2.payload.message);
// Prevent the connection from being dropped
client_incoming.next().await;
}
})
.detach();
// Allow the request to make some progress before dropping it.
cx.background().simulate_random_delay().await;
drop(request1_task);
request2_task.await;
assert_eq!(
&*events.lock(),
&[
"message 1".to_string(),
"message 2".to_string(),
"response 2".to_string()
]
);
}
#[gpui::test(iterations = 50)]
async fn test_disconnect(cx: &mut TestAppContext) {
let executor = cx.foreground();
let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
let client = Peer::new(0);
let (connection_id, io_handler, mut incoming) =
client.add_test_connection(client_conn, cx.background());
let (io_ended_tx, io_ended_rx) = oneshot::channel();
executor
.spawn(async move {
io_handler.await.ok();
io_ended_tx.send(()).unwrap();
})
.detach();
let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
executor
.spawn(async move {
incoming.next().await;
messages_ended_tx.send(()).unwrap();
})
.detach();
client.disconnect(connection_id);
let _ = io_ended_rx.await;
let _ = messages_ended_rx.await;
assert!(server_conn
.send(WebSocketMessage::Binary(vec![]))
.await
.is_err());
}
#[gpui::test(iterations = 50)]
async fn test_io_error(cx: &mut TestAppContext) {
let executor = cx.foreground();
let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
let client = Peer::new(0);
let (connection_id, io_handler, mut incoming) =
client.add_test_connection(client_conn, cx.background());
executor.spawn(io_handler).detach();
executor
.spawn(async move { incoming.next().await })
.detach();
let response = executor.spawn(client.request(connection_id, proto::Ping {}));
let _request = server_conn.rx.next().await.unwrap().unwrap();
drop(server_conn);
assert_eq!(
response.await.unwrap_err().to_string(),
"connection was closed"
);
}
}
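
The tests above rely on the test-only add_test_connection and Connection::in_memory helpers. A minimal sketch of production-style wiring, assuming an already-established Connection and a smol-based executor (both assumptions outside this crate):

use futures::StreamExt as _;
use rpc::{proto, Connection, Peer};

async fn connect_and_ping(connection: Connection) -> anyhow::Result<()> {
    let peer = Peer::new(0);
    let (connection_id, io_task, mut incoming) =
        peer.add_connection(connection, |duration| smol::Timer::after(duration));
    // Drive reads, writes, keepalives, and timeouts in the background.
    smol::spawn(io_task).detach();
    // Incoming messages must be drained, or responses can never be routed back
    // to pending requests.
    smol::spawn(async move { while incoming.next().await.is_some() {} }).detach();
    // Issue a request and await its typed response.
    let _ack: proto::Ack = peer.request(connection_id, proto::Ping {}).await?;
    Ok(())
}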

674
crates/rpc2/src/proto.rs Normal file

@@ -0,0 +1,674 @@
#![allow(non_snake_case)]
use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
use anyhow::{anyhow, Result};
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use collections::HashMap;
use futures::{SinkExt as _, StreamExt as _};
use prost::Message as _;
use serde::Serialize;
use std::any::{Any, TypeId};
use std::{
cmp,
fmt::Debug,
io, iter,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use std::{fmt, mem};
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
const NAME: &'static str;
const PRIORITY: MessagePriority;
fn into_envelope(
self,
id: u32,
responding_to: Option<u32>,
original_sender_id: Option<PeerId>,
) -> Envelope;
fn from_envelope(envelope: Envelope) -> Option<Self>;
}
pub trait EntityMessage: EnvelopedMessage {
fn remote_entity_id(&self) -> u64;
}
pub trait RequestMessage: EnvelopedMessage {
type Response: EnvelopedMessage;
}
pub trait AnyTypedEnvelope: 'static + Send + Sync {
fn payload_type_id(&self) -> TypeId;
fn payload_type_name(&self) -> &'static str;
fn as_any(&self) -> &dyn Any;
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
fn is_background(&self) -> bool;
fn original_sender_id(&self) -> Option<PeerId>;
fn sender_id(&self) -> ConnectionId;
fn message_id(&self) -> u32;
}
pub enum MessagePriority {
Foreground,
Background,
}
impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
fn payload_type_id(&self) -> TypeId {
TypeId::of::<T>()
}
fn payload_type_name(&self) -> &'static str {
T::NAME
}
fn as_any(&self) -> &dyn Any {
self
}
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
self
}
fn is_background(&self) -> bool {
matches!(T::PRIORITY, MessagePriority::Background)
}
fn original_sender_id(&self) -> Option<PeerId> {
self.original_sender_id
}
fn sender_id(&self) -> ConnectionId {
self.sender_id
}
fn message_id(&self) -> u32 {
self.message_id
}
}
impl PeerId {
pub fn from_u64(peer_id: u64) -> Self {
let owner_id = (peer_id >> 32) as u32;
let id = peer_id as u32;
Self { owner_id, id }
}
pub fn as_u64(self) -> u64 {
((self.owner_id as u64) << 32) | (self.id as u64)
}
}
impl Copy for PeerId {}
impl Eq for PeerId {}
impl Ord for PeerId {
fn cmp(&self, other: &Self) -> cmp::Ordering {
self.owner_id
.cmp(&other.owner_id)
.then_with(|| self.id.cmp(&other.id))
}
}
impl PartialOrd for PeerId {
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
Some(self.cmp(other))
}
}
impl std::hash::Hash for PeerId {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.owner_id.hash(state);
self.id.hash(state);
}
}
impl fmt::Display for PeerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.owner_id, self.id)
}
}
messages!(
(Ack, Foreground),
(AddProjectCollaborator, Foreground),
(ApplyCodeAction, Background),
(ApplyCodeActionResponse, Background),
(ApplyCompletionAdditionalEdits, Background),
(ApplyCompletionAdditionalEditsResponse, Background),
(BufferReloaded, Foreground),
(BufferSaved, Foreground),
(Call, Foreground),
(CallCanceled, Foreground),
(CancelCall, Foreground),
(CopyProjectEntry, Foreground),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
(ChannelMessageSent, Foreground),
(CreateProjectEntry, Foreground),
(CreateRoom, Foreground),
(CreateRoomResponse, Foreground),
(DeclineCall, Foreground),
(DeleteProjectEntry, Foreground),
(Error, Foreground),
(ExpandProjectEntry, Foreground),
(Follow, Foreground),
(FollowResponse, Foreground),
(FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground),
(GetCodeActions, Background),
(GetCodeActionsResponse, Background),
(GetHover, Background),
(GetHoverResponse, Background),
(GetChannelMessages, Background),
(GetChannelMessagesResponse, Background),
(SendChannelMessage, Background),
(SendChannelMessageResponse, Background),
(GetCompletions, Background),
(GetCompletionsResponse, Background),
(GetDefinition, Background),
(GetDefinitionResponse, Background),
(GetTypeDefinition, Background),
(GetTypeDefinitionResponse, Background),
(GetDocumentHighlights, Background),
(GetDocumentHighlightsResponse, Background),
(GetReferences, Background),
(GetReferencesResponse, Background),
(GetProjectSymbols, Background),
(GetProjectSymbolsResponse, Background),
(GetUsers, Foreground),
(Hello, Foreground),
(IncomingCall, Foreground),
(InviteChannelMember, Foreground),
(UsersResponse, Foreground),
(JoinProject, Foreground),
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
(JoinChannelChat, Foreground),
(JoinChannelChatResponse, Foreground),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
(LeaveRoom, Foreground),
(OpenBufferById, Background),
(OpenBufferByPath, Background),
(OpenBufferForSymbol, Background),
(OpenBufferForSymbolResponse, Background),
(OpenBufferResponse, Background),
(PerformRename, Background),
(PerformRenameResponse, Background),
(OnTypeFormatting, Background),
(OnTypeFormattingResponse, Background),
(InlayHints, Background),
(InlayHintsResponse, Background),
(ResolveInlayHint, Background),
(ResolveInlayHintResponse, Background),
(RefreshInlayHints, Foreground),
(Ping, Foreground),
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ExpandProjectEntryResponse, Foreground),
(ProjectEntryResponse, Foreground),
(RejoinRoom, Foreground),
(RejoinRoomResponse, Foreground),
(RemoveContact, Foreground),
(RemoveChannelMember, Foreground),
(RemoveChannelMessage, Foreground),
(ReloadBuffers, Foreground),
(ReloadBuffersResponse, Foreground),
(RemoveProjectCollaborator, Foreground),
(RenameProjectEntry, Foreground),
(RequestContact, Foreground),
(RespondToContactRequest, Foreground),
(RespondToChannelInvite, Foreground),
(JoinChannel, Foreground),
(RoomUpdated, Foreground),
(SaveBuffer, Foreground),
(RenameChannel, Foreground),
(RenameChannelResponse, Foreground),
(SetChannelMemberAdmin, Foreground),
(SearchProject, Background),
(SearchProjectResponse, Background),
(ShareProject, Foreground),
(ShareProjectResponse, Foreground),
(ShowContacts, Foreground),
(StartLanguageServer, Foreground),
(SynchronizeBuffers, Foreground),
(SynchronizeBuffersResponse, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
(Test, Foreground),
(Unfollow, Foreground),
(UnshareProject, Foreground),
(UpdateBuffer, Foreground),
(UpdateBufferFile, Foreground),
(UpdateContacts, Foreground),
(DeleteChannel, Foreground),
(MoveChannel, Foreground),
(LinkChannel, Foreground),
(UnlinkChannel, Foreground),
(UpdateChannels, Foreground),
(UpdateDiagnosticSummary, Foreground),
(UpdateFollowers, Foreground),
(UpdateInviteInfo, Foreground),
(UpdateLanguageServer, Foreground),
(UpdateParticipantLocation, Foreground),
(UpdateProject, Foreground),
(UpdateProjectCollaborator, Foreground),
(UpdateWorktree, Foreground),
(UpdateWorktreeSettings, Foreground),
(UpdateDiffBase, Foreground),
(GetPrivateUserInfo, Foreground),
(GetPrivateUserInfoResponse, Foreground),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(JoinChannelBuffer, Foreground),
(JoinChannelBufferResponse, Foreground),
(LeaveChannelBuffer, Background),
(UpdateChannelBuffer, Foreground),
(UpdateChannelBufferCollaborators, Foreground),
(AckBufferOperation, Background),
(AckChannelMessage, Background),
);
request_messages!(
(ApplyCodeAction, ApplyCodeActionResponse),
(
ApplyCompletionAdditionalEdits,
ApplyCompletionAdditionalEditsResponse
),
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
(CreateChannel, CreateChannelResponse),
(DeclineCall, Ack),
(DeleteProjectEntry, ProjectEntryResponse),
(ExpandProjectEntry, ExpandProjectEntryResponse),
(Follow, FollowResponse),
(FormatBuffers, FormatBuffersResponse),
(GetCodeActions, GetCodeActionsResponse),
(GetHover, GetHoverResponse),
(GetCompletions, GetCompletionsResponse),
(GetDefinition, GetDefinitionResponse),
(GetTypeDefinition, GetTypeDefinitionResponse),
(GetDocumentHighlights, GetDocumentHighlightsResponse),
(GetReferences, GetReferencesResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),
(FuzzySearchUsers, UsersResponse),
(GetUsers, UsersResponse),
(InviteChannelMember, Ack),
(JoinProject, JoinProjectResponse),
(JoinRoom, JoinRoomResponse),
(JoinChannelChat, JoinChannelChatResponse),
(LeaveRoom, Ack),
(RejoinRoom, RejoinRoomResponse),
(IncomingCall, Ack),
(OpenBufferById, OpenBufferResponse),
(OpenBufferByPath, OpenBufferResponse),
(OpenBufferForSymbol, OpenBufferForSymbolResponse),
(Ping, Ack),
(PerformRename, PerformRenameResponse),
(PrepareRename, PrepareRenameResponse),
(OnTypeFormatting, OnTypeFormattingResponse),
(InlayHints, InlayHintsResponse),
(ResolveInlayHint, ResolveInlayHintResponse),
(RefreshInlayHints, Ack),
(ReloadBuffers, ReloadBuffersResponse),
(RequestContact, Ack),
(RemoveChannelMember, Ack),
(RemoveContact, Ack),
(RespondToContactRequest, Ack),
(RespondToChannelInvite, Ack),
(SetChannelMemberAdmin, Ack),
(SendChannelMessage, SendChannelMessageResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMembers, GetChannelMembersResponse),
(JoinChannel, JoinRoomResponse),
(RemoveChannelMessage, Ack),
(DeleteChannel, Ack),
(RenameProjectEntry, ProjectEntryResponse),
(RenameChannel, RenameChannelResponse),
(LinkChannel, Ack),
(UnlinkChannel, Ack),
(MoveChannel, Ack),
(SaveBuffer, BufferSaved),
(SearchProject, SearchProjectResponse),
(ShareProject, ShareProjectResponse),
(SynchronizeBuffers, SynchronizeBuffersResponse),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(Test, Test),
(UpdateBuffer, Ack),
(UpdateParticipantLocation, Ack),
(UpdateProject, Ack),
(UpdateWorktree, Ack),
(JoinChannelBuffer, JoinChannelBufferResponse),
(LeaveChannelBuffer, Ack)
);
entity_messages!(
project_id,
AddProjectCollaborator,
ApplyCodeAction,
ApplyCompletionAdditionalEdits,
BufferReloaded,
BufferSaved,
CopyProjectEntry,
CreateBufferForPeer,
CreateProjectEntry,
DeleteProjectEntry,
ExpandProjectEntry,
FormatBuffers,
GetCodeActions,
GetCompletions,
GetDefinition,
GetTypeDefinition,
GetDocumentHighlights,
GetHover,
GetReferences,
GetProjectSymbols,
JoinProject,
LeaveProject,
OpenBufferById,
OpenBufferByPath,
OpenBufferForSymbol,
PerformRename,
OnTypeFormatting,
InlayHints,
ResolveInlayHint,
RefreshInlayHints,
PrepareRename,
ReloadBuffers,
RemoveProjectCollaborator,
RenameProjectEntry,
SaveBuffer,
SearchProject,
StartLanguageServer,
SynchronizeBuffers,
UnshareProject,
UpdateBuffer,
UpdateBufferFile,
UpdateDiagnosticSummary,
UpdateLanguageServer,
UpdateProject,
UpdateProjectCollaborator,
UpdateWorktree,
UpdateWorktreeSettings,
UpdateDiffBase
);
entity_messages!(
channel_id,
ChannelMessageSent,
UpdateChannelBuffer,
RemoveChannelMessage,
UpdateChannelBufferCollaborators,
);
const KIB: usize = 1024;
const MIB: usize = KIB * 1024;
const MAX_BUFFER_LEN: usize = MIB;
/// A stream of protobuf messages.
pub struct MessageStream<S> {
stream: S,
encoding_buffer: Vec<u8>,
}
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
Envelope(Envelope),
Ping,
Pong,
}
impl<S> MessageStream<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
encoding_buffer: Vec::new(),
}
}
pub fn inner_mut(&mut self) -> &mut S {
&mut self.stream
}
}
impl<S> MessageStream<S>
where
S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
{
pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
#[cfg(any(test, feature = "test-support"))]
const COMPRESSION_LEVEL: i32 = -7;
#[cfg(not(any(test, feature = "test-support")))]
const COMPRESSION_LEVEL: i32 = 4;
match message {
Message::Envelope(message) => {
self.encoding_buffer.reserve(message.encoded_len());
message
.encode(&mut self.encoding_buffer)
.map_err(io::Error::from)?;
let buffer =
zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
.unwrap();
self.encoding_buffer.clear();
self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
}
Message::Ping => {
self.stream
.send(WebSocketMessage::Ping(Default::default()))
.await?;
}
Message::Pong => {
self.stream
.send(WebSocketMessage::Pong(Default::default()))
.await?;
}
}
Ok(())
}
}
impl<S> MessageStream<S>
where
S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
{
pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
while let Some(bytes) = self.stream.next().await {
match bytes? {
WebSocketMessage::Binary(bytes) => {
zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
let envelope = Envelope::decode(self.encoding_buffer.as_slice())
.map_err(io::Error::from)?;
self.encoding_buffer.clear();
self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
return Ok(Message::Envelope(envelope));
}
WebSocketMessage::Ping(_) => return Ok(Message::Ping),
WebSocketMessage::Pong(_) => return Ok(Message::Pong),
WebSocketMessage::Close(_) => break,
_ => {}
}
}
Err(anyhow!("connection closed"))
}
}
impl From<Timestamp> for SystemTime {
fn from(val: Timestamp) -> Self {
UNIX_EPOCH
.checked_add(Duration::new(val.seconds, val.nanos))
.unwrap()
}
}
impl From<SystemTime> for Timestamp {
fn from(time: SystemTime) -> Self {
let duration = time.duration_since(UNIX_EPOCH).unwrap();
Self {
seconds: duration.as_secs(),
nanos: duration.subsec_nanos(),
}
}
}
impl From<u128> for Nonce {
fn from(nonce: u128) -> Self {
let upper_half = (nonce >> 64) as u64;
let lower_half = nonce as u64;
Self {
upper_half,
lower_half,
}
}
}
impl From<Nonce> for u128 {
fn from(nonce: Nonce) -> Self {
let upper_half = (nonce.upper_half as u128) << 64;
let lower_half = nonce.lower_half as u128;
upper_half | lower_half
}
}
pub fn split_worktree_update(
mut message: UpdateWorktree,
max_chunk_size: usize,
) -> impl Iterator<Item = UpdateWorktree> {
let mut done_files = false;
let mut repository_map = message
.updated_repositories
.into_iter()
.map(|repo| (repo.work_directory_id, repo))
.collect::<HashMap<_, _>>();
iter::from_fn(move || {
if done_files {
return None;
}
let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
let updated_entries: Vec<_> = message
.updated_entries
.drain(..updated_entries_chunk_size)
.collect();
let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
let removed_entries = message
.removed_entries
.drain(..removed_entries_chunk_size)
.collect();
done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
let mut updated_repositories = Vec::new();
if !repository_map.is_empty() {
for entry in &updated_entries {
if let Some(repo) = repository_map.remove(&entry.id) {
updated_repositories.push(repo)
}
}
}
let removed_repositories = if done_files {
mem::take(&mut message.removed_repositories)
} else {
Default::default()
};
if done_files {
updated_repositories.extend(mem::take(&mut repository_map).into_values());
}
Some(UpdateWorktree {
project_id: message.project_id,
worktree_id: message.worktree_id,
root_name: message.root_name.clone(),
abs_path: message.abs_path.clone(),
updated_entries,
removed_entries,
scan_id: message.scan_id,
is_last_update: done_files && message.is_last_update,
updated_repositories,
removed_repositories,
})
})
}
#[cfg(test)]
mod tests {
use super::*;
#[gpui::test]
async fn test_buffer_size() {
let (tx, rx) = futures::channel::mpsc::unbounded();
let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
sink.write(Message::Envelope(Envelope {
payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
root_name: "abcdefg".repeat(10),
..Default::default()
})),
..Default::default()
}))
.await
.unwrap();
assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
sink.write(Message::Envelope(Envelope {
payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
root_name: "abcdefg".repeat(1000000),
..Default::default()
})),
..Default::default()
}))
.await
.unwrap();
assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
let mut stream = MessageStream::new(rx.map(anyhow::Ok));
stream.read().await.unwrap();
assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
stream.read().await.unwrap();
assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
}
#[gpui::test]
fn test_converting_peer_id_from_and_to_u64() {
let peer_id = PeerId {
owner_id: 10,
id: 3,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
let peer_id = PeerId {
owner_id: u32::MAX,
id: 3,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
let peer_id = PeerId {
owner_id: 10,
id: u32::MAX,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
let peer_id = PeerId {
owner_id: u32::MAX,
id: u32::MAX,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
}
}
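
split_worktree_update has no test here; a minimal sketch of how a caller might chunk a large worktree update before sending it (the chunk size and wrapper function are assumptions):

use rpc::proto::{self, split_worktree_update};
use rpc::{ConnectionId, Peer};

// Split one potentially huge UpdateWorktree into bounded chunks and forward
// each one; only the final chunk keeps `is_last_update` set.
fn send_worktree_update(
    peer: &Peer,
    connection_id: ConnectionId,
    update: proto::UpdateWorktree,
) -> anyhow::Result<()> {
    const MAX_CHUNK_SIZE: usize = 256; // assumed chunk size
    for chunk in split_worktree_update(update, MAX_CHUNK_SIZE) {
        peer.send(connection_id, chunk)?;
    }
    Ok(())
}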

9
crates/rpc2/src/rpc.rs Normal file

@@ -0,0 +1,9 @@
pub mod auth;
mod conn;
mod peer;
pub mod proto;
pub use conn::Connection;
pub use peer::*;
mod macros;
pub const PROTOCOL_VERSION: u32 = 64;