mirror of
https://github.com/wez/wezterm.git
synced 2024-12-23 05:12:40 +03:00
Fix file asyncread and asyncwrite future handling by keeping stateful futures
This commit is contained in:
parent
35043e9341
commit
50a0372e17
@ -654,7 +654,7 @@ impl SessionInner {
|
||||
let file_id = self.next_file_id;
|
||||
self.next_file_id += 1;
|
||||
|
||||
let file = File { file_id, tx: None };
|
||||
let file = File::new(file_id);
|
||||
(file_id, file)
|
||||
}
|
||||
|
||||
|
@ -35,7 +35,7 @@ impl Sftp {
|
||||
})))
|
||||
.await?;
|
||||
let mut result = rx.recv().await?;
|
||||
result.tx.replace(self.tx.clone());
|
||||
result.initialize_sender(self.tx.clone());
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
@ -51,7 +51,7 @@ impl Sftp {
|
||||
})))
|
||||
.await?;
|
||||
let mut result = rx.recv().await?;
|
||||
result.tx.replace(self.tx.clone());
|
||||
result.initialize_sender(self.tx.clone());
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
@ -67,7 +67,7 @@ impl Sftp {
|
||||
})))
|
||||
.await?;
|
||||
let mut result = rx.recv().await?;
|
||||
result.tx.replace(self.tx.clone());
|
||||
result.initialize_sender(self.tx.clone());
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
@ -83,7 +83,7 @@ impl Sftp {
|
||||
})))
|
||||
.await?;
|
||||
let mut result = rx.recv().await?;
|
||||
result.tx.replace(self.tx.clone());
|
||||
result.initialize_sender(self.tx.clone());
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
@ -1,155 +1,220 @@
|
||||
use super::{
|
||||
CloseFile, FlushFile, ReadFile, SessionRequest, SessionSender, SftpRequest, WriteFile,
|
||||
};
|
||||
use smol::channel::bounded;
|
||||
use smol::{channel::bounded, future::FutureExt};
|
||||
use std::{
|
||||
fmt,
|
||||
future::Future,
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
pub(crate) type FileId = usize;
|
||||
|
||||
/// A file handle to an SFTP connection.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct File {
|
||||
pub(crate) file_id: FileId,
|
||||
pub(crate) tx: Option<SessionSender>,
|
||||
tx: Option<SessionSender>,
|
||||
state: FileState,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct FileState {
|
||||
f_read: Option<Pin<Box<dyn Future<Output = io::Result<Vec<u8>>> + Send + Sync + 'static>>>,
|
||||
f_write: Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send + Sync + 'static>>>,
|
||||
f_flush: Option<Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + 'static>>>,
|
||||
f_close: Option<Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + 'static>>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for File {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("File")
|
||||
.field("file_id", &self.file_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl File {
|
||||
pub(crate) fn new(file_id: FileId) -> Self {
|
||||
Self {
|
||||
file_id,
|
||||
tx: None,
|
||||
state: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn initialize_sender(&mut self, sender: SessionSender) {
|
||||
self.tx.replace(sender);
|
||||
}
|
||||
}
|
||||
|
||||
impl smol::io::AsyncRead for File {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
use smol::future::FutureExt;
|
||||
async fn read(
|
||||
mut _self: std::pin::Pin<&mut File>,
|
||||
buf: &mut [u8],
|
||||
) -> std::io::Result<usize> {
|
||||
let data = _self
|
||||
.read(buf.len())
|
||||
) -> Poll<io::Result<usize>> {
|
||||
async fn read(tx: SessionSender, file_id: usize, len: usize) -> io::Result<Vec<u8>> {
|
||||
inner_read(tx, file_id, len)
|
||||
.await
|
||||
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x))?;
|
||||
let n = data.len();
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
|
||||
}
|
||||
let tx = self.tx.as_ref().unwrap().clone();
|
||||
let file_id = self.file_id;
|
||||
|
||||
buf.copy_from_slice(&data[..n]);
|
||||
let poll = self
|
||||
.state
|
||||
.f_read
|
||||
.get_or_insert_with(|| Box::pin(read(tx, file_id, buf.len())))
|
||||
.poll(cx);
|
||||
|
||||
Ok(n)
|
||||
if poll.is_ready() {
|
||||
self.state.f_read.take();
|
||||
}
|
||||
|
||||
Box::pin(read(self, buf)).poll(cx)
|
||||
match poll {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(Err(x)) => Poll::Ready(Err(x)),
|
||||
Poll::Ready(Ok(data)) => {
|
||||
let n = data.len();
|
||||
(&mut buf[..n]).copy_from_slice(&data[..n]);
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl smol::io::AsyncWrite for File {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
use smol::future::FutureExt;
|
||||
async fn write(mut _self: std::pin::Pin<&mut File>, buf: &[u8]) -> std::io::Result<usize> {
|
||||
_self
|
||||
.write(buf.to_vec())
|
||||
) -> Poll<io::Result<usize>> {
|
||||
async fn write(tx: SessionSender, file_id: usize, buf: Vec<u8>) -> io::Result<usize> {
|
||||
let n = buf.len();
|
||||
inner_write(tx, file_id, buf)
|
||||
.await
|
||||
.map(|_| buf.len())
|
||||
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x))
|
||||
.map(|_| n)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
|
||||
}
|
||||
|
||||
Box::pin(write(self, buf)).poll(cx)
|
||||
let tx = self.tx.as_ref().unwrap().clone();
|
||||
let file_id = self.file_id;
|
||||
|
||||
let poll = self
|
||||
.state
|
||||
.f_write
|
||||
.get_or_insert_with(|| Box::pin(write(tx, file_id, buf.to_vec())))
|
||||
.poll(cx);
|
||||
|
||||
if poll.is_ready() {
|
||||
self.state.f_write.take();
|
||||
}
|
||||
|
||||
poll
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
use smol::future::FutureExt;
|
||||
async fn flush(mut _self: std::pin::Pin<&mut File>) -> std::io::Result<()> {
|
||||
_self
|
||||
.flush()
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
async fn flush(tx: SessionSender, file_id: usize) -> io::Result<()> {
|
||||
inner_flush(tx, file_id)
|
||||
.await
|
||||
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x))
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
|
||||
}
|
||||
|
||||
Box::pin(flush(self)).poll(cx)
|
||||
let tx = self.tx.as_ref().unwrap().clone();
|
||||
let file_id = self.file_id;
|
||||
|
||||
let poll = self
|
||||
.state
|
||||
.f_flush
|
||||
.get_or_insert_with(|| Box::pin(flush(tx, file_id)))
|
||||
.poll(cx);
|
||||
|
||||
if poll.is_ready() {
|
||||
self.state.f_flush.take();
|
||||
}
|
||||
|
||||
poll
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
use smol::future::FutureExt;
|
||||
async fn close(mut _self: std::pin::Pin<&mut File>) -> std::io::Result<()> {
|
||||
_self
|
||||
.close()
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
async fn close(tx: SessionSender, file_id: usize) -> io::Result<()> {
|
||||
inner_close(tx, file_id)
|
||||
.await
|
||||
.map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x))
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
|
||||
}
|
||||
|
||||
Box::pin(close(self)).poll(cx)
|
||||
let tx = self.tx.as_ref().unwrap().clone();
|
||||
let file_id = self.file_id;
|
||||
|
||||
let poll = self
|
||||
.state
|
||||
.f_close
|
||||
.get_or_insert_with(|| Box::pin(close(tx, file_id)))
|
||||
.poll(cx);
|
||||
|
||||
if poll.is_ready() {
|
||||
self.state.f_close.take();
|
||||
}
|
||||
|
||||
poll
|
||||
}
|
||||
}
|
||||
|
||||
impl File {
|
||||
/// Writes some bytes to the file.
|
||||
async fn write(&mut self, data: Vec<u8>) -> anyhow::Result<()> {
|
||||
let (reply, rx) = bounded(1);
|
||||
self.tx
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.send(SessionRequest::Sftp(SftpRequest::WriteFile(WriteFile {
|
||||
file_id: self.file_id,
|
||||
data,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Reads some bytes from the file, returning a vector of bytes read.
|
||||
///
|
||||
/// If the vector is empty, this indicates that there are no more bytes
|
||||
/// to read at the moment.
|
||||
async fn read(&mut self, max_bytes: usize) -> anyhow::Result<Vec<u8>> {
|
||||
let (reply, rx) = bounded(1);
|
||||
self.tx
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.send(SessionRequest::Sftp(SftpRequest::ReadFile(ReadFile {
|
||||
file_id: self.file_id,
|
||||
max_bytes,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Flushes the remote file
|
||||
async fn flush(&mut self) -> anyhow::Result<()> {
|
||||
let (reply, rx) = bounded(1);
|
||||
self.tx
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.send(SessionRequest::Sftp(SftpRequest::FlushFile(FlushFile {
|
||||
file_id: self.file_id,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Closes the handle to the remote file
|
||||
async fn close(&mut self) -> anyhow::Result<()> {
|
||||
let (reply, rx) = bounded(1);
|
||||
self.tx
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.send(SessionRequest::Sftp(SftpRequest::CloseFile(CloseFile {
|
||||
file_id: self.file_id,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
/// Writes some bytes to the file.
|
||||
async fn inner_write(tx: SessionSender, file_id: usize, data: Vec<u8>) -> anyhow::Result<()> {
|
||||
let (reply, rx) = bounded(1);
|
||||
tx.send(SessionRequest::Sftp(SftpRequest::WriteFile(WriteFile {
|
||||
file_id,
|
||||
data,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Reads some bytes from the file, returning a vector of bytes read.
|
||||
///
|
||||
/// If the vector is empty, this indicates that there are no more bytes
|
||||
/// to read at the moment.
|
||||
async fn inner_read(
|
||||
tx: SessionSender,
|
||||
file_id: usize,
|
||||
max_bytes: usize,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let (reply, rx) = bounded(1);
|
||||
tx.send(SessionRequest::Sftp(SftpRequest::ReadFile(ReadFile {
|
||||
file_id,
|
||||
max_bytes,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Flushes the remote file
|
||||
async fn inner_flush(tx: SessionSender, file_id: usize) -> anyhow::Result<()> {
|
||||
let (reply, rx) = bounded(1);
|
||||
tx.send(SessionRequest::Sftp(SftpRequest::FlushFile(FlushFile {
|
||||
file_id,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Closes the handle to the remote file
|
||||
async fn inner_close(tx: SessionSender, file_id: usize) -> anyhow::Result<()> {
|
||||
let (reply, rx) = bounded(1);
|
||||
tx.send(SessionRequest::Sftp(SftpRequest::CloseFile(CloseFile {
|
||||
file_id,
|
||||
reply,
|
||||
})))
|
||||
.await?;
|
||||
let result = rx.recv().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ use wezterm_ssh::Session;
|
||||
|
||||
#[rstest]
|
||||
#[smol_potat::test]
|
||||
#[ignore]
|
||||
async fn should_support_async_reading(#[future] session: Session) {
|
||||
let session: Session = session.await;
|
||||
|
||||
@ -27,11 +26,17 @@ async fn should_support_async_reading(#[future] session: Session) {
|
||||
.expect("Failed to read file to string");
|
||||
|
||||
assert_eq!(contents, "some file contents");
|
||||
|
||||
// NOTE: Testing second time to ensure future is properly cleared
|
||||
let mut contents = String::new();
|
||||
remote_file
|
||||
.read_to_string(&mut contents)
|
||||
.await
|
||||
.expect("Failed to read file to string second time");
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[smol_potat::test]
|
||||
#[ignore]
|
||||
async fn should_support_async_writing(#[future] session: Session) {
|
||||
let session: Session = session.await;
|
||||
|
||||
@ -51,4 +56,58 @@ async fn should_support_async_writing(#[future] session: Session) {
|
||||
.expect("Failed to write to file");
|
||||
|
||||
file.assert("new contents for file");
|
||||
|
||||
// NOTE: Testing second time to ensure future is properly cleared
|
||||
remote_file
|
||||
.write_all(b"new contents for file")
|
||||
.await
|
||||
.expect("Failed to write to file second time");
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[smol_potat::test]
|
||||
async fn should_support_async_flush(#[future] session: Session) {
|
||||
let session: Session = session.await;
|
||||
|
||||
let temp = TempDir::new().unwrap();
|
||||
let file = temp.child("test-file");
|
||||
file.write_str("some file contents").unwrap();
|
||||
|
||||
let mut remote_file = session
|
||||
.sftp()
|
||||
.create(file.path())
|
||||
.await
|
||||
.expect("Failed to open remote file");
|
||||
|
||||
remote_file.flush().await.expect("Failed to flush file");
|
||||
|
||||
// NOTE: Testing second time to ensure future is properly cleared
|
||||
remote_file
|
||||
.flush()
|
||||
.await
|
||||
.expect("Failed to flush file second time");
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[smol_potat::test]
|
||||
async fn should_support_async_close(#[future] session: Session) {
|
||||
let session: Session = session.await;
|
||||
|
||||
let temp = TempDir::new().unwrap();
|
||||
let file = temp.child("test-file");
|
||||
file.write_str("some file contents").unwrap();
|
||||
|
||||
let mut remote_file = session
|
||||
.sftp()
|
||||
.create(file.path())
|
||||
.await
|
||||
.expect("Failed to open remote file");
|
||||
|
||||
remote_file.close().await.expect("Failed to close file");
|
||||
|
||||
// NOTE: Testing second time to ensure future is properly cleared
|
||||
remote_file
|
||||
.close()
|
||||
.await
|
||||
.expect("Failed to close file second time");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user