Clean up tasks properly when dropping a FakeLanguageServer

* Make sure the fake's IO tasks are stopped
* Ensure that the fake's stdout is closed, so that the corresponding language
  server's IO tasks are woken up and halted.
This commit is contained in:
Max Brunsfeld 2022-03-01 13:26:59 -08:00
parent 0e6686916c
commit 74469a46ba

View File

@ -476,18 +476,22 @@ impl Drop for Subscription {
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
handlers: Arc<
Mutex<
HashMap<
&'static str,
Box<dyn Send + FnMut(usize, &[u8], gpui::AsyncAppContext) -> Vec<u8>>,
>,
>,
>,
handlers: FakeLanguageServerHandlers,
outgoing_tx: futures::channel::mpsc::UnboundedSender<Vec<u8>>,
incoming_rx: futures::channel::mpsc::UnboundedReceiver<Vec<u8>>,
_input_task: Task<Result<()>>,
_output_task: Task<Result<()>>,
}
type FakeLanguageServerHandlers = Arc<
Mutex<
HashMap<
&'static str,
Box<dyn Send + FnMut(usize, &[u8], gpui::AsyncAppContext) -> Vec<u8>>,
>,
>,
>;
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
pub fn fake(cx: &mut gpui::MutableAppContext) -> (Arc<Self>, FakeLanguageServer) {
@ -533,59 +537,69 @@ impl FakeLanguageServer {
let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded();
let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
let this = Self {
outgoing_tx: outgoing_tx.clone(),
incoming_rx,
handlers: Default::default(),
};
let handlers = FakeLanguageServerHandlers::default();
// Receive incoming messages
let handlers = this.handlers.clone();
cx.spawn(|cx| async move {
let mut buffer = Vec::new();
let mut stdin = smol::io::BufReader::new(stdin);
while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
cx.background().simulate_random_delay().await;
if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
let input_task = cx.spawn(|cx| {
let handlers = handlers.clone();
let outgoing_tx = outgoing_tx.clone();
async move {
let mut buffer = Vec::new();
let mut stdin = smol::io::BufReader::new(stdin);
while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
cx.background().simulate_random_delay().await;
if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
if let Some(handler) = handlers.lock().get_mut(request.method) {
let response =
handler(request.id, request.params.get().as_bytes(), cx.clone());
log::debug!("handled lsp request. method:{}", request.method);
outgoing_tx.unbounded_send(response)?;
} else {
log::debug!("unhandled lsp request. method:{}", request.method);
outgoing_tx.unbounded_send(
serde_json::to_vec(&AnyResponse {
let response;
if let Some(handler) = handlers.lock().get_mut(request.method) {
response =
handler(request.id, request.params.get().as_bytes(), cx.clone());
log::debug!("handled lsp request. method:{}", request.method);
} else {
response = serde_json::to_vec(&AnyResponse {
id: request.id,
error: Some(Error {
message: "no handler".to_string(),
}),
result: None,
})
.unwrap(),
)?;
.unwrap();
log::debug!("unhandled lsp request. method:{}", request.method);
}
outgoing_tx.unbounded_send(response)?;
} else {
incoming_tx.unbounded_send(buffer.clone())?;
}
} else {
incoming_tx.unbounded_send(buffer.clone())?;
}
Ok::<_, anyhow::Error>(())
}
Ok::<_, anyhow::Error>(())
})
.detach();
});
// Send outgoing messages
cx.background()
.spawn(async move {
let mut stdout = smol::io::BufWriter::new(stdout);
while let Some(notification) = outgoing_rx.next().await {
Self::send(&mut stdout, &notification).await;
}
})
.detach();
let output_task = cx.background().spawn(async move {
let mut stdout = smol::io::BufWriter::new(PipeWriterCloseOnDrop(stdout));
while let Some(message) = outgoing_rx.next().await {
stdout
.write_all(CONTENT_LEN_HEADER.as_bytes())
.await
.unwrap();
stdout
.write_all((format!("{}", message.len())).as_bytes())
.await
.unwrap();
stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
stdout.write_all(&message).await.unwrap();
stdout.flush().await.unwrap();
}
Ok(())
});
this
Self {
outgoing_tx,
incoming_rx,
handlers,
_input_task: input_task,
_output_task: output_task,
}
}
pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
@ -665,20 +679,6 @@ impl FakeLanguageServer {
.await;
}
async fn send(stdout: &mut smol::io::BufWriter<async_pipe::PipeWriter>, message: &[u8]) {
stdout
.write_all(CONTENT_LEN_HEADER.as_bytes())
.await
.unwrap();
stdout
.write_all((format!("{}", message.len())).as_bytes())
.await
.unwrap();
stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
stdout.write_all(&message).await.unwrap();
stdout.flush().await.unwrap();
}
async fn receive(
stdin: &mut smol::io::BufReader<async_pipe::PipeReader>,
buffer: &mut Vec<u8>,
@ -699,6 +699,44 @@ impl FakeLanguageServer {
}
}
struct PipeWriterCloseOnDrop(async_pipe::PipeWriter);
impl Drop for PipeWriterCloseOnDrop {
fn drop(&mut self) {
self.0.close().ok();
}
}
impl AsyncWrite for PipeWriterCloseOnDrop {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
let pipe = &mut self.0;
smol::pin!(pipe);
pipe.poll_write(cx, buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let pipe = &mut self.0;
smol::pin!(pipe);
pipe.poll_flush(cx)
}
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let pipe = &mut self.0;
smol::pin!(pipe);
pipe.poll_close(cx)
}
}
#[cfg(test)]
mod tests {
use super::*;