1
1
mirror of https://github.com/wez/wezterm.git synced 2024-12-26 14:54:16 +03:00

refactor client to prepare for reconnecting

This commit is contained in:
Wez Furlong 2019-06-26 08:30:04 -07:00
parent abfd98b7c6
commit 96c61fcd53
2 changed files with 103 additions and 67 deletions

View File

@ -5,7 +5,7 @@ use crate::mux::domain::alloc_domain_id;
use crate::mux::domain::DomainId; use crate::mux::domain::DomainId;
use crate::mux::Mux; use crate::mux::Mux;
use crate::server::codec::*; use crate::server::codec::*;
use crate::server::domain::ClientDomain; use crate::server::domain::{ClientDomain, ClientDomainConfig};
use crate::server::pollable::*; use crate::server::pollable::*;
use crate::server::tab::ClientTab; use crate::server::tab::ClientTab;
use crate::server::UnixStream; use crate::server::UnixStream;
@ -104,7 +104,7 @@ fn process_unilateral(local_domain_id: DomainId, decoded: DecodedPdu) -> Fallibl
} }
fn client_thread( fn client_thread(
mut stream: Box<dyn ReadAndWrite>, mut reconnectable: Reconnectable,
local_domain_id: DomainId, local_domain_id: DomainId,
rx: PollableReceiver<ReaderMessage>, rx: PollableReceiver<ReaderMessage>,
) -> Fallible<()> { ) -> Fallible<()> {
@ -120,8 +120,8 @@ fn client_thread(
next_serial += 1; next_serial += 1;
promises.insert(serial, promise); promises.insert(serial, promise);
pdu.encode(&mut stream, serial)?; pdu.encode(reconnectable.stream(), serial)?;
stream.flush()?; reconnectable.stream().flush()?;
} }
}, },
Err(TryRecvError::Empty) => break, Err(TryRecvError::Empty) => break,
@ -129,10 +129,10 @@ fn client_thread(
}; };
} }
let mut poll_array = [rx.as_poll_fd(), stream.as_poll_fd()]; let mut poll_array = [rx.as_poll_fd(), reconnectable.stream().as_poll_fd()];
poll_for_read(&mut poll_array); poll_for_read(&mut poll_array);
if poll_array[1].revents != 0 || stream.has_read_buffered() { if poll_array[1].revents != 0 || reconnectable.stream().has_read_buffered() {
// When TLS is enabled on a stream, it may require a mixture of // When TLS is enabled on a stream, it may require a mixture of
// reads AND writes in order to satisfy a given read or write. // reads AND writes in order to satisfy a given read or write.
// As a result, we may appear ready to read a PDU, but may not // As a result, we may appear ready to read a PDU, but may not
@ -140,9 +140,9 @@ fn client_thread(
// Set to non-blocking mode while we try to decode a packet to // Set to non-blocking mode while we try to decode a packet to
// avoid blocking. // avoid blocking.
loop { loop {
stream.set_non_blocking(true)?; reconnectable.stream().set_non_blocking(true)?;
let res = Pdu::try_read_and_decode(&mut stream, &mut read_buffer); let res = Pdu::try_read_and_decode(reconnectable.stream(), &mut read_buffer);
stream.set_non_blocking(false)?; reconnectable.stream().set_non_blocking(false)?;
if let Some(decoded) = res? { if let Some(decoded) = res? {
log::trace!("decoded serial {}", decoded.serial); log::trace!("decoded serial {}", decoded.serial);
if decoded.serial == 0 { if decoded.serial == 0 {
@ -179,56 +179,28 @@ fn unix_connect_with_retry(path: &Path) -> Result<UnixStream, std::io::Error> {
Err(error) Err(error)
} }
impl Client { struct Reconnectable {
pub fn new(local_domain_id: DomainId, stream: Box<dyn ReadAndWrite>) -> Self { config: ClientDomainConfig,
let (sender, receiver) = pollable_channel().expect("failed to create pollable_channel"); stream: Option<Box<dyn ReadAndWrite>>,
}
thread::spawn(move || { impl Reconnectable {
if let Err(e) = client_thread(stream, local_domain_id, receiver) { fn new(config: ClientDomainConfig, stream: Option<Box<dyn ReadAndWrite>>) -> Self {
log::debug!("client thread ended: {}", e); Self { config, stream }
Future::with_executor(gui_executor().unwrap(), move || { }
let mux = Mux::get().unwrap();
let client_domain = mux
.get_domain(local_domain_id)
.ok_or_else(|| format_err!("no such domain {}", local_domain_id))?;
let client_domain =
client_domain
.downcast_ref::<ClientDomain>()
.ok_or_else(|| {
format_err!(
"domain {} is not a ClientDomain instance",
local_domain_id
)
})?;
client_domain.perform_detach();
Ok(())
});
}
});
Self { fn stream(&mut self) -> &mut Box<dyn ReadAndWrite> {
sender, self.stream.as_mut().unwrap()
local_domain_id, }
fn connect(&mut self) -> Fallible<()> {
match self.config.clone() {
ClientDomainConfig::Unix(unix_dom) => self.unix_connect(unix_dom),
ClientDomainConfig::Tls(tls) => self.tls_connect(tls),
} }
} }
pub fn local_domain_id(&self) -> DomainId { fn unix_connect(&mut self, unix_dom: UnixDomain) -> Fallible<()> {
self.local_domain_id
}
pub fn new_default_unix_domain(config: &Arc<Config>) -> Fallible<Self> {
let unix_dom = config
.unix_domains
.first()
.ok_or_else(|| err_msg("no default unix domain is configured"))?;
Self::new_unix_domain(alloc_domain_id(), config, unix_dom)
}
pub fn new_unix_domain(
local_domain_id: DomainId,
_config: &Arc<Config>,
unix_dom: &UnixDomain,
) -> Fallible<Self> {
let sock_path = unix_dom.socket_path(); let sock_path = unix_dom.socket_path();
info!("connect to {}", sock_path.display()); info!("connect to {}", sock_path.display());
@ -252,15 +224,12 @@ impl Client {
}; };
let stream: Box<dyn ReadAndWrite> = Box::new(stream); let stream: Box<dyn ReadAndWrite> = Box::new(stream);
Ok(Self::new(local_domain_id, stream)) self.stream.replace(stream);
Ok(())
} }
#[cfg(any(feature = "openssl", unix))] #[cfg(any(feature = "openssl", unix))]
pub fn new_tls( pub fn tls_connect(&mut self, tls_client: TlsDomainClient) -> Fallible<()> {
local_domain_id: DomainId,
_config: &Arc<Config>,
tls_client: &TlsDomainClient,
) -> Fallible<Self> {
use crate::server::listener::read_bytes; use crate::server::listener::read_bytes;
use openssl::ssl::{SslConnector, SslFiletype, SslMethod}; use openssl::ssl::{SslConnector, SslFiletype, SslMethod};
use openssl::x509::X509; use openssl::x509::X509;
@ -333,15 +302,12 @@ impl Client {
) )
})?, })?,
); );
Ok(Self::new(local_domain_id, stream)) self.stream.replace(stream);
Ok(())
} }
#[cfg(not(any(feature = "openssl", unix)))] #[cfg(not(any(feature = "openssl", unix)))]
pub fn new_tls( pub fn tls_connect(&mut self, tls_client: TlsDomainClient) -> Fallible<()> {
local_domain_id: DomainId,
_config: &Arc<Config>,
tls_client: &TlsDomainClient,
) -> Fallible<Self> {
use crate::server::listener::IdentitySource; use crate::server::listener::IdentitySource;
use native_tls::TlsConnector; use native_tls::TlsConnector;
use std::convert::TryInto; use std::convert::TryInto;
@ -383,7 +349,76 @@ impl Client {
e e
) )
})?); })?);
Ok(Self::new(local_domain_id, stream)) self.stream.replace(stream);
Ok(())
}
}
impl Client {
fn new(local_domain_id: DomainId, reconnectable: Reconnectable) -> Self {
let (sender, receiver) = pollable_channel().expect("failed to create pollable_channel");
thread::spawn(move || {
if let Err(e) = client_thread(reconnectable, local_domain_id, receiver) {
log::debug!("client thread ended: {}", e);
Future::with_executor(gui_executor().unwrap(), move || {
let mux = Mux::get().unwrap();
let client_domain = mux
.get_domain(local_domain_id)
.ok_or_else(|| format_err!("no such domain {}", local_domain_id))?;
let client_domain =
client_domain
.downcast_ref::<ClientDomain>()
.ok_or_else(|| {
format_err!(
"domain {} is not a ClientDomain instance",
local_domain_id
)
})?;
client_domain.perform_detach();
Ok(())
});
}
});
Self {
sender,
local_domain_id,
}
}
pub fn local_domain_id(&self) -> DomainId {
self.local_domain_id
}
pub fn new_default_unix_domain(config: &Arc<Config>) -> Fallible<Self> {
let unix_dom = config
.unix_domains
.first()
.ok_or_else(|| err_msg("no default unix domain is configured"))?;
Self::new_unix_domain(alloc_domain_id(), config, unix_dom)
}
pub fn new_unix_domain(
local_domain_id: DomainId,
_config: &Arc<Config>,
unix_dom: &UnixDomain,
) -> Fallible<Self> {
let mut reconnectable =
Reconnectable::new(ClientDomainConfig::Unix(unix_dom.clone()), None);
reconnectable.connect()?;
Ok(Self::new(local_domain_id, reconnectable))
}
pub fn new_tls(
local_domain_id: DomainId,
_config: &Arc<Config>,
tls_client: &TlsDomainClient,
) -> Fallible<Self> {
let mut reconnectable =
Reconnectable::new(ClientDomainConfig::Tls(tls_client.clone()), None);
reconnectable.connect()?;
Ok(Self::new(local_domain_id, reconnectable))
} }
pub fn send_pdu(&self, pdu: Pdu) -> Future<Pdu> { pub fn send_pdu(&self, pdu: Pdu) -> Future<Pdu> {

View File

@ -54,6 +54,7 @@ impl ClientInner {
} }
} }
#[derive(Clone, Debug)]
pub enum ClientDomainConfig { pub enum ClientDomainConfig {
Unix(UnixDomain), Unix(UnixDomain),
Tls(TlsDomainClient), Tls(TlsDomainClient),