diff --git a/filedescriptor/Cargo.toml b/filedescriptor/Cargo.toml index 415c31005..4840f9fdb 100644 --- a/filedescriptor/Cargo.toml +++ b/filedescriptor/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "filedescriptor" -version = "0.7.0" +version = "0.7.1" authors = ["Wez Furlong"] edition = "2018" repository = "https://github.com/wez/wzsh" diff --git a/filedescriptor/src/windows.rs b/filedescriptor/src/windows.rs index 3022185c6..37030ee8d 100644 --- a/filedescriptor/src/windows.rs +++ b/filedescriptor/src/windows.rs @@ -19,8 +19,8 @@ use winapi::um::processthreadsapi::*; use winapi::um::winbase::{FILE_TYPE_CHAR, FILE_TYPE_DISK, FILE_TYPE_PIPE}; use winapi::um::winnt::HANDLE; use winapi::um::winsock2::{ - accept, bind, closesocket, connect, getsockname, htonl, ioctlsocket, listen, WSAPoll, - WSASocketW, WSAStartup, INVALID_SOCKET, SOCKET, SOCK_STREAM, WSADATA, + accept, bind, closesocket, connect, getsockname, htonl, ioctlsocket, listen, recv, send, + WSAPoll, WSASocketW, WSAStartup, INVALID_SOCKET, SOCKET, SOCK_STREAM, WSADATA, WSA_FLAG_NO_HANDLE_INHERIT, }; pub use winapi::um::winsock2::{POLLERR, POLLHUP, POLLIN, POLLOUT, WSAPOLLFD as pollfd}; @@ -285,45 +285,80 @@ impl FromRawSocket for FileDescriptor { impl io::Read for FileDescriptor { fn read(&mut self, buf: &mut [u8]) -> Result { - let mut num_read = 0; - let ok = unsafe { - ReadFile( - self.handle.as_raw_handle() as *mut _, - buf.as_mut_ptr() as *mut _, - buf.len() as u32, - &mut num_read, - ptr::null_mut(), - ) - }; - if ok == 0 { - let err = IoError::last_os_error(); - if err.kind() == std::io::ErrorKind::BrokenPipe { - Ok(0) + if self.handle.is_socket_handle() { + // It's important to use the winsock functions to read/write + // even though ReadFile and WriteFile technically work; only + // the winsock functions respect non-blocking mode. + let num_read = unsafe { + recv( + self.as_socket_descriptor(), + buf.as_mut_ptr() as *mut _, + buf.len() as _, + 0, + ) + }; + if num_read < 0 { + Err(IoError::last_os_error()) } else { - Err(err) + Ok(num_read as usize) } } else { - Ok(num_read as usize) + let mut num_read = 0; + let ok = unsafe { + ReadFile( + self.handle.as_raw_handle() as *mut _, + buf.as_mut_ptr() as *mut _, + buf.len() as _, + &mut num_read, + ptr::null_mut(), + ) + }; + if ok == 0 { + let err = IoError::last_os_error(); + if err.kind() == std::io::ErrorKind::BrokenPipe { + Ok(0) + } else { + Err(err) + } + } else { + Ok(num_read as usize) + } } } } impl io::Write for FileDescriptor { fn write(&mut self, buf: &[u8]) -> Result { - let mut num_wrote = 0; - let ok = unsafe { - WriteFile( - self.handle.as_raw_handle() as *mut _, - buf.as_ptr() as *const _, - buf.len() as u32, - &mut num_wrote, - ptr::null_mut(), - ) - }; - if ok == 0 { - Err(IoError::last_os_error()) + if self.handle.is_socket_handle() { + let num_wrote = unsafe { + send( + self.as_socket_descriptor(), + buf.as_ptr() as *const _, + buf.len() as _, + 0, + ) + }; + if num_wrote < 0 { + Err(IoError::last_os_error()) + } else { + Ok(num_wrote as usize) + } } else { - Ok(num_wrote as usize) + let mut num_wrote = 0; + let ok = unsafe { + WriteFile( + self.handle.as_raw_handle() as *mut _, + buf.as_ptr() as *const _, + buf.len() as u32, + &mut num_wrote, + ptr::null_mut(), + ) + }; + if ok == 0 { + Err(IoError::last_os_error()) + } else { + Ok(num_wrote as usize) + } } } fn flush(&mut self) -> Result<(), io::Error> {