diff --git a/Cargo.lock b/Cargo.lock index 515b7571a0..c0e93600d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2351,6 +2351,7 @@ dependencies = [ "thiserror", "time", "tiny_http", + "tokio-socks", "url", "util", "windows 0.58.0", @@ -11349,6 +11350,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-socks" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" +dependencies = [ + "either", + "futures-io", + "futures-util", + "thiserror", +] + [[package]] name = "tokio-stream" version = "0.1.15" diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index da470fb1e6..72ca8ffc24 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -48,6 +48,7 @@ text.workspace = true thiserror.workspace = true time.workspace = true tiny_http = "0.8" +tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] } url.workspace = true util.workspace = true worktree.workspace = true diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index e916dcbd2c..ccd7bda344 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1,6 +1,7 @@ #[cfg(any(test, feature = "test-support"))] pub mod test; +mod socks; pub mod telemetry; pub mod user; @@ -31,6 +32,7 @@ use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, Requ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; +use socks::connect_socks_proxy_stream; use std::fmt; use std::pin::Pin; use std::{ @@ -1177,6 +1179,7 @@ impl Client { .unwrap_or_default(); let http = self.http.clone(); + let proxy = http.proxy().cloned(); let credentials = credentials.clone(); let rpc_url = self.rpc_url(http, release_channel); cx.background_executor().spawn(async move { @@ -1198,7 +1201,7 @@ impl Client { .host_str() .zip(rpc_url.port_or_known_default()) .ok_or_else(|| anyhow!("missing host in rpc url"))?; - let stream = smol::net::TcpStream::connect(rpc_host).await?; + let stream = connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await?; log::info!("connected to rpc endpoint {}", rpc_url); diff --git a/crates/client/src/socks.rs b/crates/client/src/socks.rs new file mode 100644 index 0000000000..de4300b1d6 --- /dev/null +++ b/crates/client/src/socks.rs @@ -0,0 +1,68 @@ +//! socks proxy +use anyhow::{anyhow, Result}; +use futures::io::{AsyncRead, AsyncWrite}; +use http_client::Uri; +use tokio_socks::{ + io::Compat, + tcp::{Socks4Stream, Socks5Stream}, +}; + +pub(crate) async fn connect_socks_proxy_stream( + proxy: Option<&Uri>, + rpc_host: (&str, u16), +) -> Result> { + let stream = match parse_socks_proxy(proxy) { + Some((socks_proxy, SocksVersion::V4)) => { + let stream = Socks4Stream::connect_with_socket( + Compat::new(smol::net::TcpStream::connect(socks_proxy).await?), + rpc_host, + ) + .await + .map_err(|err| anyhow!("error connecting to socks {}", err))?; + Box::new(stream) as Box + } + Some((socks_proxy, SocksVersion::V5)) => Box::new( + Socks5Stream::connect_with_socket( + Compat::new(smol::net::TcpStream::connect(socks_proxy).await?), + rpc_host, + ) + .await + .map_err(|err| anyhow!("error connecting to socks {}", err))?, + ) as Box, + None => Box::new(smol::net::TcpStream::connect(rpc_host).await?) as Box, + }; + Ok(stream) +} + +fn parse_socks_proxy(proxy: Option<&Uri>) -> Option<((String, u16), SocksVersion)> { + let Some(proxy_uri) = proxy else { + return None; + }; + let Some(scheme) = proxy_uri.scheme_str() else { + return None; + }; + let socks_version = if scheme.starts_with("socks4") { + // socks4 + SocksVersion::V4 + } else if scheme.starts_with("socks") { + // socks, socks5 + SocksVersion::V5 + } else { + return None; + }; + if let (Some(host), Some(port)) = (proxy_uri.host(), proxy_uri.port_u16()) { + Some(((host.to_string(), port), socks_version)) + } else { + None + } +} + +// private helper structs and traits + +enum SocksVersion { + V4, + V5, +} + +pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send + 'static {} +impl AsyncReadWrite for T {}