1
1
mirror of https://github.com/wez/wezterm.git synced 2024-11-10 15:04:32 +03:00

refactor client to prepare for reconnecting

This commit is contained in:
Wez Furlong 2019-06-26 08:30:04 -07:00
parent abfd98b7c6
commit 96c61fcd53
2 changed files with 103 additions and 67 deletions

View File

@ -5,7 +5,7 @@ use crate::mux::domain::alloc_domain_id;
use crate::mux::domain::DomainId;
use crate::mux::Mux;
use crate::server::codec::*;
use crate::server::domain::ClientDomain;
use crate::server::domain::{ClientDomain, ClientDomainConfig};
use crate::server::pollable::*;
use crate::server::tab::ClientTab;
use crate::server::UnixStream;
@ -104,7 +104,7 @@ fn process_unilateral(local_domain_id: DomainId, decoded: DecodedPdu) -> Fallibl
}
fn client_thread(
mut stream: Box<dyn ReadAndWrite>,
mut reconnectable: Reconnectable,
local_domain_id: DomainId,
rx: PollableReceiver<ReaderMessage>,
) -> Fallible<()> {
@ -120,8 +120,8 @@ fn client_thread(
next_serial += 1;
promises.insert(serial, promise);
pdu.encode(&mut stream, serial)?;
stream.flush()?;
pdu.encode(reconnectable.stream(), serial)?;
reconnectable.stream().flush()?;
}
},
Err(TryRecvError::Empty) => break,
@ -129,10 +129,10 @@ fn client_thread(
};
}
let mut poll_array = [rx.as_poll_fd(), stream.as_poll_fd()];
let mut poll_array = [rx.as_poll_fd(), reconnectable.stream().as_poll_fd()];
poll_for_read(&mut poll_array);
if poll_array[1].revents != 0 || stream.has_read_buffered() {
if poll_array[1].revents != 0 || reconnectable.stream().has_read_buffered() {
// When TLS is enabled on a stream, it may require a mixture of
// reads AND writes in order to satisfy a given read or write.
// As a result, we may appear ready to read a PDU, but may not
@ -140,9 +140,9 @@ fn client_thread(
// Set to non-blocking mode while we try to decode a packet to
// avoid blocking.
loop {
stream.set_non_blocking(true)?;
let res = Pdu::try_read_and_decode(&mut stream, &mut read_buffer);
stream.set_non_blocking(false)?;
reconnectable.stream().set_non_blocking(true)?;
let res = Pdu::try_read_and_decode(reconnectable.stream(), &mut read_buffer);
reconnectable.stream().set_non_blocking(false)?;
if let Some(decoded) = res? {
log::trace!("decoded serial {}", decoded.serial);
if decoded.serial == 0 {
@ -179,56 +179,28 @@ fn unix_connect_with_retry(path: &Path) -> Result<UnixStream, std::io::Error> {
Err(error)
}
impl Client {
pub fn new(local_domain_id: DomainId, stream: Box<dyn ReadAndWrite>) -> Self {
let (sender, receiver) = pollable_channel().expect("failed to create pollable_channel");
struct Reconnectable {
config: ClientDomainConfig,
stream: Option<Box<dyn ReadAndWrite>>,
}
thread::spawn(move || {
if let Err(e) = client_thread(stream, local_domain_id, receiver) {
log::debug!("client thread ended: {}", e);
Future::with_executor(gui_executor().unwrap(), move || {
let mux = Mux::get().unwrap();
let client_domain = mux
.get_domain(local_domain_id)
.ok_or_else(|| format_err!("no such domain {}", local_domain_id))?;
let client_domain =
client_domain
.downcast_ref::<ClientDomain>()
.ok_or_else(|| {
format_err!(
"domain {} is not a ClientDomain instance",
local_domain_id
)
})?;
client_domain.perform_detach();
Ok(())
});
}
});
impl Reconnectable {
fn new(config: ClientDomainConfig, stream: Option<Box<dyn ReadAndWrite>>) -> Self {
Self { config, stream }
}
Self {
sender,
local_domain_id,
fn stream(&mut self) -> &mut Box<dyn ReadAndWrite> {
self.stream.as_mut().unwrap()
}
fn connect(&mut self) -> Fallible<()> {
match self.config.clone() {
ClientDomainConfig::Unix(unix_dom) => self.unix_connect(unix_dom),
ClientDomainConfig::Tls(tls) => self.tls_connect(tls),
}
}
pub fn local_domain_id(&self) -> DomainId {
self.local_domain_id
}
pub fn new_default_unix_domain(config: &Arc<Config>) -> Fallible<Self> {
let unix_dom = config
.unix_domains
.first()
.ok_or_else(|| err_msg("no default unix domain is configured"))?;
Self::new_unix_domain(alloc_domain_id(), config, unix_dom)
}
pub fn new_unix_domain(
local_domain_id: DomainId,
_config: &Arc<Config>,
unix_dom: &UnixDomain,
) -> Fallible<Self> {
fn unix_connect(&mut self, unix_dom: UnixDomain) -> Fallible<()> {
let sock_path = unix_dom.socket_path();
info!("connect to {}", sock_path.display());
@ -252,15 +224,12 @@ impl Client {
};
let stream: Box<dyn ReadAndWrite> = Box::new(stream);
Ok(Self::new(local_domain_id, stream))
self.stream.replace(stream);
Ok(())
}
#[cfg(any(feature = "openssl", unix))]
pub fn new_tls(
local_domain_id: DomainId,
_config: &Arc<Config>,
tls_client: &TlsDomainClient,
) -> Fallible<Self> {
pub fn tls_connect(&mut self, tls_client: TlsDomainClient) -> Fallible<()> {
use crate::server::listener::read_bytes;
use openssl::ssl::{SslConnector, SslFiletype, SslMethod};
use openssl::x509::X509;
@ -333,15 +302,12 @@ impl Client {
)
})?,
);
Ok(Self::new(local_domain_id, stream))
self.stream.replace(stream);
Ok(())
}
#[cfg(not(any(feature = "openssl", unix)))]
pub fn new_tls(
local_domain_id: DomainId,
_config: &Arc<Config>,
tls_client: &TlsDomainClient,
) -> Fallible<Self> {
pub fn tls_connect(&mut self, tls_client: TlsDomainClient) -> Fallible<()> {
use crate::server::listener::IdentitySource;
use native_tls::TlsConnector;
use std::convert::TryInto;
@ -383,7 +349,76 @@ impl Client {
e
)
})?);
Ok(Self::new(local_domain_id, stream))
self.stream.replace(stream);
Ok(())
}
}
impl Client {
fn new(local_domain_id: DomainId, reconnectable: Reconnectable) -> Self {
let (sender, receiver) = pollable_channel().expect("failed to create pollable_channel");
thread::spawn(move || {
if let Err(e) = client_thread(reconnectable, local_domain_id, receiver) {
log::debug!("client thread ended: {}", e);
Future::with_executor(gui_executor().unwrap(), move || {
let mux = Mux::get().unwrap();
let client_domain = mux
.get_domain(local_domain_id)
.ok_or_else(|| format_err!("no such domain {}", local_domain_id))?;
let client_domain =
client_domain
.downcast_ref::<ClientDomain>()
.ok_or_else(|| {
format_err!(
"domain {} is not a ClientDomain instance",
local_domain_id
)
})?;
client_domain.perform_detach();
Ok(())
});
}
});
Self {
sender,
local_domain_id,
}
}
pub fn local_domain_id(&self) -> DomainId {
self.local_domain_id
}
pub fn new_default_unix_domain(config: &Arc<Config>) -> Fallible<Self> {
let unix_dom = config
.unix_domains
.first()
.ok_or_else(|| err_msg("no default unix domain is configured"))?;
Self::new_unix_domain(alloc_domain_id(), config, unix_dom)
}
pub fn new_unix_domain(
local_domain_id: DomainId,
_config: &Arc<Config>,
unix_dom: &UnixDomain,
) -> Fallible<Self> {
let mut reconnectable =
Reconnectable::new(ClientDomainConfig::Unix(unix_dom.clone()), None);
reconnectable.connect()?;
Ok(Self::new(local_domain_id, reconnectable))
}
pub fn new_tls(
local_domain_id: DomainId,
_config: &Arc<Config>,
tls_client: &TlsDomainClient,
) -> Fallible<Self> {
let mut reconnectable =
Reconnectable::new(ClientDomainConfig::Tls(tls_client.clone()), None);
reconnectable.connect()?;
Ok(Self::new(local_domain_id, reconnectable))
}
pub fn send_pdu(&self, pdu: Pdu) -> Future<Pdu> {

View File

@ -54,6 +54,7 @@ impl ClientInner {
}
}
#[derive(Clone, Debug)]
pub enum ClientDomainConfig {
Unix(UnixDomain),
Tls(TlsDomainClient),