diff --git a/src/server/client.rs b/src/server/client.rs index a7e04fe16..29a30f909 100644 --- a/src/server/client.rs +++ b/src/server/client.rs @@ -5,7 +5,7 @@ use crate::mux::domain::alloc_domain_id; use crate::mux::domain::DomainId; use crate::mux::Mux; use crate::server::codec::*; -use crate::server::domain::ClientDomain; +use crate::server::domain::{ClientDomain, ClientDomainConfig}; use crate::server::pollable::*; use crate::server::tab::ClientTab; use crate::server::UnixStream; @@ -104,7 +104,7 @@ fn process_unilateral(local_domain_id: DomainId, decoded: DecodedPdu) -> Fallibl } fn client_thread( - mut stream: Box, + mut reconnectable: Reconnectable, local_domain_id: DomainId, rx: PollableReceiver, ) -> Fallible<()> { @@ -120,8 +120,8 @@ fn client_thread( next_serial += 1; promises.insert(serial, promise); - pdu.encode(&mut stream, serial)?; - stream.flush()?; + pdu.encode(reconnectable.stream(), serial)?; + reconnectable.stream().flush()?; } }, Err(TryRecvError::Empty) => break, @@ -129,10 +129,10 @@ fn client_thread( }; } - let mut poll_array = [rx.as_poll_fd(), stream.as_poll_fd()]; + let mut poll_array = [rx.as_poll_fd(), reconnectable.stream().as_poll_fd()]; poll_for_read(&mut poll_array); - if poll_array[1].revents != 0 || stream.has_read_buffered() { + if poll_array[1].revents != 0 || reconnectable.stream().has_read_buffered() { // When TLS is enabled on a stream, it may require a mixture of // reads AND writes in order to satisfy a given read or write. // As a result, we may appear ready to read a PDU, but may not @@ -140,9 +140,9 @@ fn client_thread( // Set to non-blocking mode while we try to decode a packet to // avoid blocking. loop { - stream.set_non_blocking(true)?; - let res = Pdu::try_read_and_decode(&mut stream, &mut read_buffer); - stream.set_non_blocking(false)?; + reconnectable.stream().set_non_blocking(true)?; + let res = Pdu::try_read_and_decode(reconnectable.stream(), &mut read_buffer); + reconnectable.stream().set_non_blocking(false)?; if let Some(decoded) = res? { log::trace!("decoded serial {}", decoded.serial); if decoded.serial == 0 { @@ -179,56 +179,28 @@ fn unix_connect_with_retry(path: &Path) -> Result { Err(error) } -impl Client { - pub fn new(local_domain_id: DomainId, stream: Box) -> Self { - let (sender, receiver) = pollable_channel().expect("failed to create pollable_channel"); +struct Reconnectable { + config: ClientDomainConfig, + stream: Option>, +} - thread::spawn(move || { - if let Err(e) = client_thread(stream, local_domain_id, receiver) { - log::debug!("client thread ended: {}", e); - Future::with_executor(gui_executor().unwrap(), move || { - let mux = Mux::get().unwrap(); - let client_domain = mux - .get_domain(local_domain_id) - .ok_or_else(|| format_err!("no such domain {}", local_domain_id))?; - let client_domain = - client_domain - .downcast_ref::() - .ok_or_else(|| { - format_err!( - "domain {} is not a ClientDomain instance", - local_domain_id - ) - })?; - client_domain.perform_detach(); - Ok(()) - }); - } - }); +impl Reconnectable { + fn new(config: ClientDomainConfig, stream: Option>) -> Self { + Self { config, stream } + } - Self { - sender, - local_domain_id, + fn stream(&mut self) -> &mut Box { + self.stream.as_mut().unwrap() + } + + fn connect(&mut self) -> Fallible<()> { + match self.config.clone() { + ClientDomainConfig::Unix(unix_dom) => self.unix_connect(unix_dom), + ClientDomainConfig::Tls(tls) => self.tls_connect(tls), } } - pub fn local_domain_id(&self) -> DomainId { - self.local_domain_id - } - - pub fn new_default_unix_domain(config: &Arc) -> Fallible { - let unix_dom = config - .unix_domains - .first() - .ok_or_else(|| err_msg("no default unix domain is configured"))?; - Self::new_unix_domain(alloc_domain_id(), config, unix_dom) - } - - pub fn new_unix_domain( - local_domain_id: DomainId, - _config: &Arc, - unix_dom: &UnixDomain, - ) -> Fallible { + fn unix_connect(&mut self, unix_dom: UnixDomain) -> Fallible<()> { let sock_path = unix_dom.socket_path(); info!("connect to {}", sock_path.display()); @@ -252,15 +224,12 @@ impl Client { }; let stream: Box = Box::new(stream); - Ok(Self::new(local_domain_id, stream)) + self.stream.replace(stream); + Ok(()) } #[cfg(any(feature = "openssl", unix))] - pub fn new_tls( - local_domain_id: DomainId, - _config: &Arc, - tls_client: &TlsDomainClient, - ) -> Fallible { + pub fn tls_connect(&mut self, tls_client: TlsDomainClient) -> Fallible<()> { use crate::server::listener::read_bytes; use openssl::ssl::{SslConnector, SslFiletype, SslMethod}; use openssl::x509::X509; @@ -333,15 +302,12 @@ impl Client { ) })?, ); - Ok(Self::new(local_domain_id, stream)) + self.stream.replace(stream); + Ok(()) } #[cfg(not(any(feature = "openssl", unix)))] - pub fn new_tls( - local_domain_id: DomainId, - _config: &Arc, - tls_client: &TlsDomainClient, - ) -> Fallible { + pub fn tls_connect(&mut self, tls_client: TlsDomainClient) -> Fallible<()> { use crate::server::listener::IdentitySource; use native_tls::TlsConnector; use std::convert::TryInto; @@ -383,7 +349,76 @@ impl Client { e ) })?); - Ok(Self::new(local_domain_id, stream)) + self.stream.replace(stream); + Ok(()) + } +} + +impl Client { + fn new(local_domain_id: DomainId, reconnectable: Reconnectable) -> Self { + let (sender, receiver) = pollable_channel().expect("failed to create pollable_channel"); + + thread::spawn(move || { + if let Err(e) = client_thread(reconnectable, local_domain_id, receiver) { + log::debug!("client thread ended: {}", e); + Future::with_executor(gui_executor().unwrap(), move || { + let mux = Mux::get().unwrap(); + let client_domain = mux + .get_domain(local_domain_id) + .ok_or_else(|| format_err!("no such domain {}", local_domain_id))?; + let client_domain = + client_domain + .downcast_ref::() + .ok_or_else(|| { + format_err!( + "domain {} is not a ClientDomain instance", + local_domain_id + ) + })?; + client_domain.perform_detach(); + Ok(()) + }); + } + }); + + Self { + sender, + local_domain_id, + } + } + + pub fn local_domain_id(&self) -> DomainId { + self.local_domain_id + } + + pub fn new_default_unix_domain(config: &Arc) -> Fallible { + let unix_dom = config + .unix_domains + .first() + .ok_or_else(|| err_msg("no default unix domain is configured"))?; + Self::new_unix_domain(alloc_domain_id(), config, unix_dom) + } + + pub fn new_unix_domain( + local_domain_id: DomainId, + _config: &Arc, + unix_dom: &UnixDomain, + ) -> Fallible { + let mut reconnectable = + Reconnectable::new(ClientDomainConfig::Unix(unix_dom.clone()), None); + reconnectable.connect()?; + Ok(Self::new(local_domain_id, reconnectable)) + } + + pub fn new_tls( + local_domain_id: DomainId, + _config: &Arc, + tls_client: &TlsDomainClient, + ) -> Fallible { + let mut reconnectable = + Reconnectable::new(ClientDomainConfig::Tls(tls_client.clone()), None); + reconnectable.connect()?; + Ok(Self::new(local_domain_id, reconnectable)) } pub fn send_pdu(&self, pdu: Pdu) -> Future { diff --git a/src/server/domain.rs b/src/server/domain.rs index 0a18f7ee7..6dfdf9d73 100644 --- a/src/server/domain.rs +++ b/src/server/domain.rs @@ -54,6 +54,7 @@ impl ClientInner { } } +#[derive(Clone, Debug)] pub enum ClientDomainConfig { Unix(UnixDomain), Tls(TlsDomainClient),