diff --git a/Cargo.lock b/Cargo.lock index 370fc7dff4..7f7ae8a0fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -663,11 +663,11 @@ dependencies = [ [[package]] name = "async-native-tls" -version = "0.3.3" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e9e7a929bd34c68a82d58a4de7f86fffdaf97fb2af850162a7bb19dd7269b33" +checksum = "d57d4cec3c647232e1094dc013546c0b33ce785d8aeb251e1f20dfaf8a9a13fe" dependencies = [ - "async-std", + "futures-util", "native-tls", "thiserror", "url", @@ -876,17 +876,17 @@ dependencies = [ [[package]] name = "async-tungstenite" -version = "0.16.1" +version = "0.17.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5682ea0913e5c20780fe5785abacb85a411e7437bf52a1bedb93ddb3972cb8dd" +checksum = "a1b71b31561643aa8e7df3effe284fa83ab1a840e52294c5f4bd7bfd8b2becbb" dependencies = [ - "async-native-tls 0.3.3", + "async-native-tls 0.4.0", "async-std", "futures-io", "futures-util", "log", "pin-project-lite", - "tungstenite 0.16.0", + "tungstenite 0.17.3", ] [[package]] @@ -9636,15 +9636,13 @@ dependencies = [ [[package]] name = "sha-1" -version = "0.9.8" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" +checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" dependencies = [ - "block-buffer 0.9.0", "cfg-if", "cpufeatures", - "digest 0.9.0", - "opaque-debug", + "digest 0.10.7", ] [[package]] @@ -11679,9 +11677,9 @@ checksum = "2c591d83f69777866b9126b24c6dd9a18351f177e49d625920d19f989fd31cf8" [[package]] name = "tungstenite" -version = "0.16.0" +version = "0.17.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" +checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0" dependencies = [ "base64 0.13.1", "byteorder", diff --git a/Cargo.toml b/Cargo.toml index 5f5de1f831..f374acdd97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -296,7 +296,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8 async-recursion = "1.0.0" async-tar = "0.4.2" async-trait = "0.1" -async-tungstenite = { version = "0.16" } +async-tungstenite = { version = "0.17" } async-watch = "0.3.1" async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } base64 = "0.13" diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index e0c6690bb9..9559d12b6a 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -7,8 +7,9 @@ pub mod user; use anyhow::{anyhow, Context as _, Result}; use async_recursion::async_recursion; use async_tungstenite::tungstenite::{ + client::IntoClientRequest, error::Error as WebsocketError, - http::{Request, StatusCode}, + http::{HeaderValue, Request, StatusCode}, }; use clock::SystemClock; use collections::HashMap; @@ -235,6 +236,8 @@ pub enum EstablishConnectionError { #[error("{0}")] Http(#[from] http_client::Error), #[error("{0}")] + InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue), + #[error("{0}")] Io(#[from] std::io::Error), #[error("{0}")] Websocket(#[from] async_tungstenite::tungstenite::http::Error), @@ -1159,19 +1162,24 @@ impl Client { .ok() .unwrap_or_default(); - let request = Request::builder() - .header("Authorization", credentials.authorization_header()) - .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION) - .header("x-zed-app-version", app_version) - .header( - "x-zed-release-channel", - release_channel.map(|r| r.dev_name()).unwrap_or("unknown"), - ); - let http = self.http.clone(); + let credentials = credentials.clone(); let rpc_url = self.rpc_url(http, release_channel); cx.background_executor().spawn(async move { + use HttpOrHttps::*; + + #[derive(Debug)] + enum HttpOrHttps { + Http, + Https, + } + let mut rpc_url = rpc_url.await?; + let url_scheme = match rpc_url.scheme() { + "https" => Https, + "http" => Http, + _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?, + }; let rpc_host = rpc_url .host_str() .zip(rpc_url.port_or_known_default()) @@ -1180,10 +1188,37 @@ impl Client { log::info!("connected to rpc endpoint {}", rpc_url); - match rpc_url.scheme() { - "https" => { - rpc_url.set_scheme("wss").unwrap(); - let request = request.uri(rpc_url.as_str()).body(())?; + rpc_url + .set_scheme(match url_scheme { + Https => "wss", + Http => "ws", + }) + .unwrap(); + + // We call `into_client_request` to let `tungstenite` construct the WebSocket request + // for us from the RPC URL. + // + // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us. + let mut request = rpc_url.into_client_request()?; + + // We then modify the request to add our desired headers. + let request_headers = request.headers_mut(); + request_headers.insert( + "Authorization", + HeaderValue::from_str(&credentials.authorization_header())?, + ); + request_headers.insert( + "x-zed-protocol-version", + HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?, + ); + request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?); + request_headers.insert( + "x-zed-release-channel", + HeaderValue::from_str(&release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?, + ); + + match url_scheme { + Https => { let (stream, _) = async_tungstenite::async_std::client_async_tls(request, stream).await?; Ok(Connection::new( @@ -1192,9 +1227,7 @@ impl Client { .sink_map_err(|error| anyhow!(error)), )) } - "http" => { - rpc_url.set_scheme("ws").unwrap(); - let request = request.uri(rpc_url.as_str()).body(())?; + Http => { let (stream, _) = async_tungstenite::client_async(request, stream).await?; Ok(Connection::new( stream @@ -1202,7 +1235,6 @@ impl Client { .sink_map_err(|error| anyhow!(error)), )) } - _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?, } }) } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 1c5615e549..863ac1071b 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -12,7 +12,7 @@ use crate::{ executor::Executor, AppState, Error, RateLimit, RateLimiter, Result, }; -use anyhow::{anyhow, Context as _}; +use anyhow::{anyhow, bail, Context as _}; use async_tungstenite::tungstenite::{ protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage, }; @@ -1392,7 +1392,7 @@ pub async fn handle_websocket_request( let socket = socket .map_ok(to_tungstenite_message) .err_into() - .with(|message| async move { Ok(to_axum_message(message)) }); + .with(|message| async move { to_axum_message(message) }); let connection = Connection::new(Box::pin(socket)); async move { server @@ -5154,8 +5154,8 @@ async fn get_private_user_info( Ok(()) } -fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { - match message { +fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result { + let message = match message { TungsteniteMessage::Text(payload) => AxumMessage::Text(payload), TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload), TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload), @@ -5164,7 +5164,20 @@ fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { code: frame.code.into(), reason: frame.reason, })), - } + // We should never receive a frame while reading the message, according + // to the `tungstenite` maintainers: + // + // > It cannot occur when you read messages from the WebSocket, but it + // > can be used when you want to send the raw frames (e.g. you want to + // > send the frames to the WebSocket without composing the full message first). + // > + // > — https://github.com/snapview/tungstenite-rs/issues/268 + TungsteniteMessage::Frame(_) => { + bail!("received an unexpected frame while reading the message") + } + }; + + Ok(message) } fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {