1
1
mirror of https://github.com/wez/wezterm.git synced 2024-12-23 21:32:13 +03:00

Fix file asyncread and asyncwrite future handling by keeping stateful futures

This commit is contained in:
Chip Senkbeil 2021-09-25 02:18:00 -05:00 committed by Wez Furlong
parent 35043e9341
commit 50a0372e17
4 changed files with 244 additions and 120 deletions

View File

@ -654,7 +654,7 @@ impl SessionInner {
let file_id = self.next_file_id; let file_id = self.next_file_id;
self.next_file_id += 1; self.next_file_id += 1;
let file = File { file_id, tx: None }; let file = File::new(file_id);
(file_id, file) (file_id, file)
} }

View File

@ -35,7 +35,7 @@ impl Sftp {
}))) })))
.await?; .await?;
let mut result = rx.recv().await?; let mut result = rx.recv().await?;
result.tx.replace(self.tx.clone()); result.initialize_sender(self.tx.clone());
Ok(result) Ok(result)
} }
@ -51,7 +51,7 @@ impl Sftp {
}))) })))
.await?; .await?;
let mut result = rx.recv().await?; let mut result = rx.recv().await?;
result.tx.replace(self.tx.clone()); result.initialize_sender(self.tx.clone());
Ok(result) Ok(result)
} }
@ -67,7 +67,7 @@ impl Sftp {
}))) })))
.await?; .await?;
let mut result = rx.recv().await?; let mut result = rx.recv().await?;
result.tx.replace(self.tx.clone()); result.initialize_sender(self.tx.clone());
Ok(result) Ok(result)
} }
@ -83,7 +83,7 @@ impl Sftp {
}))) })))
.await?; .await?;
let mut result = rx.recv().await?; let mut result = rx.recv().await?;
result.tx.replace(self.tx.clone()); result.initialize_sender(self.tx.clone());
Ok(result) Ok(result)
} }

View File

@ -1,101 +1,172 @@
use super::{ use super::{
CloseFile, FlushFile, ReadFile, SessionRequest, SessionSender, SftpRequest, WriteFile, CloseFile, FlushFile, ReadFile, SessionRequest, SessionSender, SftpRequest, WriteFile,
}; };
use smol::channel::bounded; use smol::{channel::bounded, future::FutureExt};
use std::{
fmt,
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};
pub(crate) type FileId = usize; pub(crate) type FileId = usize;
/// A file handle to an SFTP connection. /// A file handle to an SFTP connection.
#[derive(Clone, Debug)]
pub struct File { pub struct File {
pub(crate) file_id: FileId, pub(crate) file_id: FileId,
pub(crate) tx: Option<SessionSender>, tx: Option<SessionSender>,
state: FileState,
}
#[derive(Default)]
struct FileState {
f_read: Option<Pin<Box<dyn Future<Output = io::Result<Vec<u8>>> + Send + Sync + 'static>>>,
f_write: Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send + Sync + 'static>>>,
f_flush: Option<Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + 'static>>>,
f_close: Option<Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + 'static>>>,
}
impl fmt::Debug for File {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("File")
.field("file_id", &self.file_id)
.finish()
}
}
impl File {
pub(crate) fn new(file_id: FileId) -> Self {
Self {
file_id,
tx: None,
state: Default::default(),
}
}
pub(crate) fn initialize_sender(&mut self, sender: SessionSender) {
self.tx.replace(sender);
}
} }
impl smol::io::AsyncRead for File { impl smol::io::AsyncRead for File {
fn poll_read( fn poll_read(
self: std::pin::Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut Context<'_>,
buf: &mut [u8], buf: &mut [u8],
) -> std::task::Poll<std::io::Result<usize>> { ) -> Poll<io::Result<usize>> {
use smol::future::FutureExt; async fn read(tx: SessionSender, file_id: usize, len: usize) -> io::Result<Vec<u8>> {
async fn read( inner_read(tx, file_id, len)
mut _self: std::pin::Pin<&mut File>,
buf: &mut [u8],
) -> std::io::Result<usize> {
let data = _self
.read(buf.len())
.await .await
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x))?; .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
let n = data.len(); }
let tx = self.tx.as_ref().unwrap().clone();
let file_id = self.file_id;
buf.copy_from_slice(&data[..n]); let poll = self
.state
.f_read
.get_or_insert_with(|| Box::pin(read(tx, file_id, buf.len())))
.poll(cx);
Ok(n) if poll.is_ready() {
self.state.f_read.take();
} }
Box::pin(read(self, buf)).poll(cx) match poll {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(x)) => Poll::Ready(Err(x)),
Poll::Ready(Ok(data)) => {
let n = data.len();
(&mut buf[..n]).copy_from_slice(&data[..n]);
Poll::Ready(Ok(n))
}
}
} }
} }
impl smol::io::AsyncWrite for File { impl smol::io::AsyncWrite for File {
fn poll_write( fn poll_write(
self: std::pin::Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> { ) -> Poll<io::Result<usize>> {
use smol::future::FutureExt; async fn write(tx: SessionSender, file_id: usize, buf: Vec<u8>) -> io::Result<usize> {
async fn write(mut _self: std::pin::Pin<&mut File>, buf: &[u8]) -> std::io::Result<usize> { let n = buf.len();
_self inner_write(tx, file_id, buf)
.write(buf.to_vec())
.await .await
.map(|_| buf.len()) .map(|_| n)
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x)) .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
} }
Box::pin(write(self, buf)).poll(cx) let tx = self.tx.as_ref().unwrap().clone();
let file_id = self.file_id;
let poll = self
.state
.f_write
.get_or_insert_with(|| Box::pin(write(tx, file_id, buf.to_vec())))
.poll(cx);
if poll.is_ready() {
self.state.f_write.take();
} }
fn poll_flush( poll
self: std::pin::Pin<&mut Self>, }
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
use smol::future::FutureExt; async fn flush(tx: SessionSender, file_id: usize) -> io::Result<()> {
async fn flush(mut _self: std::pin::Pin<&mut File>) -> std::io::Result<()> { inner_flush(tx, file_id)
_self
.flush()
.await .await
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x)) .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
} }
Box::pin(flush(self)).poll(cx) let tx = self.tx.as_ref().unwrap().clone();
let file_id = self.file_id;
let poll = self
.state
.f_flush
.get_or_insert_with(|| Box::pin(flush(tx, file_id)))
.poll(cx);
if poll.is_ready() {
self.state.f_flush.take();
} }
fn poll_close( poll
self: std::pin::Pin<&mut Self>, }
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
use smol::future::FutureExt; async fn close(tx: SessionSender, file_id: usize) -> io::Result<()> {
async fn close(mut _self: std::pin::Pin<&mut File>) -> std::io::Result<()> { inner_close(tx, file_id)
_self
.close()
.await .await
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x)) .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
} }
Box::pin(close(self)).poll(cx) let tx = self.tx.as_ref().unwrap().clone();
let file_id = self.file_id;
let poll = self
.state
.f_close
.get_or_insert_with(|| Box::pin(close(tx, file_id)))
.poll(cx);
if poll.is_ready() {
self.state.f_close.take();
}
poll
} }
} }
impl File {
/// Writes some bytes to the file. /// Writes some bytes to the file.
async fn write(&mut self, data: Vec<u8>) -> anyhow::Result<()> { async fn inner_write(tx: SessionSender, file_id: usize, data: Vec<u8>) -> anyhow::Result<()> {
let (reply, rx) = bounded(1); let (reply, rx) = bounded(1);
self.tx tx.send(SessionRequest::Sftp(SftpRequest::WriteFile(WriteFile {
.as_ref() file_id,
.unwrap()
.send(SessionRequest::Sftp(SftpRequest::WriteFile(WriteFile {
file_id: self.file_id,
data, data,
reply, reply,
}))) })))
@ -108,13 +179,14 @@ impl File {
/// ///
/// If the vector is empty, this indicates that there are no more bytes /// If the vector is empty, this indicates that there are no more bytes
/// to read at the moment. /// to read at the moment.
async fn read(&mut self, max_bytes: usize) -> anyhow::Result<Vec<u8>> { async fn inner_read(
tx: SessionSender,
file_id: usize,
max_bytes: usize,
) -> anyhow::Result<Vec<u8>> {
let (reply, rx) = bounded(1); let (reply, rx) = bounded(1);
self.tx tx.send(SessionRequest::Sftp(SftpRequest::ReadFile(ReadFile {
.as_ref() file_id,
.unwrap()
.send(SessionRequest::Sftp(SftpRequest::ReadFile(ReadFile {
file_id: self.file_id,
max_bytes, max_bytes,
reply, reply,
}))) })))
@ -124,13 +196,10 @@ impl File {
} }
/// Flushes the remote file /// Flushes the remote file
async fn flush(&mut self) -> anyhow::Result<()> { async fn inner_flush(tx: SessionSender, file_id: usize) -> anyhow::Result<()> {
let (reply, rx) = bounded(1); let (reply, rx) = bounded(1);
self.tx tx.send(SessionRequest::Sftp(SftpRequest::FlushFile(FlushFile {
.as_ref() file_id,
.unwrap()
.send(SessionRequest::Sftp(SftpRequest::FlushFile(FlushFile {
file_id: self.file_id,
reply, reply,
}))) })))
.await?; .await?;
@ -139,17 +208,13 @@ impl File {
} }
/// Closes the handle to the remote file /// Closes the handle to the remote file
async fn close(&mut self) -> anyhow::Result<()> { async fn inner_close(tx: SessionSender, file_id: usize) -> anyhow::Result<()> {
let (reply, rx) = bounded(1); let (reply, rx) = bounded(1);
self.tx tx.send(SessionRequest::Sftp(SftpRequest::CloseFile(CloseFile {
.as_ref() file_id,
.unwrap()
.send(SessionRequest::Sftp(SftpRequest::CloseFile(CloseFile {
file_id: self.file_id,
reply, reply,
}))) })))
.await?; .await?;
let result = rx.recv().await?; let result = rx.recv().await?;
Ok(result) Ok(result)
} }
}

View File

@ -6,7 +6,6 @@ use wezterm_ssh::Session;
#[rstest] #[rstest]
#[smol_potat::test] #[smol_potat::test]
#[ignore]
async fn should_support_async_reading(#[future] session: Session) { async fn should_support_async_reading(#[future] session: Session) {
let session: Session = session.await; let session: Session = session.await;
@ -27,11 +26,17 @@ async fn should_support_async_reading(#[future] session: Session) {
.expect("Failed to read file to string"); .expect("Failed to read file to string");
assert_eq!(contents, "some file contents"); assert_eq!(contents, "some file contents");
// NOTE: Testing second time to ensure future is properly cleared
let mut contents = String::new();
remote_file
.read_to_string(&mut contents)
.await
.expect("Failed to read file to string second time");
} }
#[rstest] #[rstest]
#[smol_potat::test] #[smol_potat::test]
#[ignore]
async fn should_support_async_writing(#[future] session: Session) { async fn should_support_async_writing(#[future] session: Session) {
let session: Session = session.await; let session: Session = session.await;
@ -51,4 +56,58 @@ async fn should_support_async_writing(#[future] session: Session) {
.expect("Failed to write to file"); .expect("Failed to write to file");
file.assert("new contents for file"); file.assert("new contents for file");
// NOTE: Testing second time to ensure future is properly cleared
remote_file
.write_all(b"new contents for file")
.await
.expect("Failed to write to file second time");
}
#[rstest]
#[smol_potat::test]
async fn should_support_async_flush(#[future] session: Session) {
let session: Session = session.await;
let temp = TempDir::new().unwrap();
let file = temp.child("test-file");
file.write_str("some file contents").unwrap();
let mut remote_file = session
.sftp()
.create(file.path())
.await
.expect("Failed to open remote file");
remote_file.flush().await.expect("Failed to flush file");
// NOTE: Testing second time to ensure future is properly cleared
remote_file
.flush()
.await
.expect("Failed to flush file second time");
}
#[rstest]
#[smol_potat::test]
async fn should_support_async_close(#[future] session: Session) {
let session: Session = session.await;
let temp = TempDir::new().unwrap();
let file = temp.child("test-file");
file.write_str("some file contents").unwrap();
let mut remote_file = session
.sftp()
.create(file.path())
.await
.expect("Failed to open remote file");
remote_file.close().await.expect("Failed to close file");
// NOTE: Testing second time to ensure future is properly cleared
remote_file
.close()
.await
.expect("Failed to close file second time");
} }