Add LanguageServer::on_io method, for observing JSON sent back and forth

This commit is contained in:
Max Brunsfeld 2023-04-20 17:47:51 -07:00
parent abdccf7393
commit 2dd4920625

View File

@ -20,10 +20,10 @@ use std::{
future::Future,
io::Write,
path::PathBuf,
str::FromStr,
str::{self, FromStr as _},
sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Arc,
Arc, Weak,
},
};
use std::{path::Path, process::Stdio};
@ -34,16 +34,18 @@ const CONTENT_LEN_HEADER: &str = "Content-Length: ";
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
type IoHandler = Box<dyn Send + FnMut(bool, &str)>;
pub struct LanguageServer {
server_id: LanguageServerId,
next_id: AtomicUsize,
outbound_tx: channel::Sender<Vec<u8>>,
outbound_tx: channel::Sender<String>,
name: String,
capabilities: ServerCapabilities,
code_action_kinds: Option<Vec<CodeActionKind>>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
executor: Arc<executor::Background>,
#[allow(clippy::type_complexity)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@ -56,9 +58,16 @@ pub struct LanguageServer {
#[repr(transparent)]
pub struct LanguageServerId(pub usize);
pub struct Subscription {
method: &'static str,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
pub enum Subscription {
Detached,
Notification {
method: &'static str,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
},
Io {
id: usize,
io_handlers: Weak<Mutex<HashMap<usize, IoHandler>>>,
},
}
#[derive(Serialize, Deserialize)]
@ -177,33 +186,40 @@ impl LanguageServer {
Stdout: AsyncRead + Unpin + Send + 'static,
F: FnMut(AnyNotification) + 'static + Send,
{
let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
let notification_handlers =
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let io_handlers = Arc::new(Mutex::new(HashMap::default()));
let input_task = cx.spawn(|cx| {
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
Self::handle_input(
stdout,
on_unhandled_notification,
notification_handlers,
response_handlers,
notification_handlers.clone(),
response_handlers.clone(),
io_handlers.clone(),
cx,
)
.log_err()
});
let (output_done_tx, output_done_rx) = barrier::channel();
let output_task = cx.background().spawn({
let response_handlers = response_handlers.clone();
Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err()
Self::handle_output(
stdin,
outbound_rx,
output_done_tx,
response_handlers.clone(),
io_handlers.clone(),
)
.log_err()
});
Self {
server_id,
notification_handlers,
response_handlers,
io_handlers,
name: Default::default(),
capabilities: Default::default(),
code_action_kinds,
@ -226,6 +242,7 @@ impl LanguageServer {
mut on_unhandled_notification: F,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
cx: AsyncAppContext,
) -> anyhow::Result<()>
where
@ -252,7 +269,13 @@ impl LanguageServer {
buffer.resize(message_len, 0);
stdout.read_exact(&mut buffer).await?;
log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
if let Ok(message) = str::from_utf8(&buffer) {
log::trace!("incoming message:{}", message);
for handler in io_handlers.lock().values_mut() {
handler(true, message);
}
}
if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
@ -291,9 +314,10 @@ impl LanguageServer {
async fn handle_output<Stdin>(
stdin: Stdin,
outbound_rx: channel::Receiver<Vec<u8>>,
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
) -> anyhow::Result<()>
where
Stdin: AsyncWrite + Unpin + Send + 'static,
@ -307,13 +331,17 @@ impl LanguageServer {
});
let mut content_len_buffer = Vec::new();
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
log::trace!("outgoing message:{}", message);
for handler in io_handlers.lock().values_mut() {
handler(false, &message);
}
content_len_buffer.clear();
write!(content_len_buffer, "{}", message.len()).unwrap();
stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
stdin.write_all(&content_len_buffer).await?;
stdin.write_all("\r\n\r\n".as_bytes()).await?;
stdin.write_all(&message).await?;
stdin.write_all(message.as_bytes()).await?;
stdin.flush().await?;
}
drop(output_done_tx);
@ -464,6 +492,19 @@ impl LanguageServer {
self.on_custom_request(T::METHOD, f)
}
#[must_use]
pub fn on_io<F>(&self, f: F) -> Subscription
where
F: 'static + Send + FnMut(bool, &str),
{
let id = self.next_id.fetch_add(1, SeqCst);
self.io_handlers.lock().insert(id, Box::new(f));
Subscription::Io {
id,
io_handlers: Arc::downgrade(&self.io_handlers),
}
}
pub fn remove_request_handler<T: request::Request>(&self) {
self.notification_handlers.lock().remove(T::METHOD);
}
@ -490,7 +531,7 @@ impl LanguageServer {
prev_handler.is_none(),
"registered multiple handlers for the same LSP method"
);
Subscription {
Subscription::Notification {
method,
notification_handlers: self.notification_handlers.clone(),
}
@ -537,7 +578,7 @@ impl LanguageServer {
},
};
if let Some(response) =
serde_json::to_vec(&response).log_err()
serde_json::to_string(&response).log_err()
{
outbound_tx.try_send(response).ok();
}
@ -560,7 +601,7 @@ impl LanguageServer {
message: error.to_string(),
}),
};
if let Some(response) = serde_json::to_vec(&response).log_err() {
if let Some(response) = serde_json::to_string(&response).log_err() {
outbound_tx.try_send(response).ok();
}
}
@ -572,7 +613,7 @@ impl LanguageServer {
prev_handler.is_none(),
"registered multiple handlers for the same LSP method"
);
Subscription {
Subscription::Notification {
method,
notification_handlers: self.notification_handlers.clone(),
}
@ -612,14 +653,14 @@ impl LanguageServer {
fn request_internal<T: request::Request>(
next_id: &AtomicUsize,
response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
outbound_tx: &channel::Sender<Vec<u8>>,
outbound_tx: &channel::Sender<String>,
params: T::Params,
) -> impl 'static + Future<Output = Result<T::Result>>
where
T::Result: 'static + Send,
{
let id = next_id.fetch_add(1, SeqCst);
let message = serde_json::to_vec(&Request {
let message = serde_json::to_string(&Request {
jsonrpc: JSON_RPC_VERSION,
id,
method: T::METHOD,
@ -662,10 +703,10 @@ impl LanguageServer {
}
fn notify_internal<T: notification::Notification>(
outbound_tx: &channel::Sender<Vec<u8>>,
outbound_tx: &channel::Sender<String>,
params: T::Params,
) -> Result<()> {
let message = serde_json::to_vec(&Notification {
let message = serde_json::to_string(&Notification {
jsonrpc: JSON_RPC_VERSION,
method: T::METHOD,
params,
@ -686,7 +727,7 @@ impl Drop for LanguageServer {
impl Subscription {
pub fn detach(mut self) {
self.method = "";
*(&mut self) = Self::Detached;
}
}
@ -698,7 +739,20 @@ impl fmt::Display for LanguageServerId {
impl Drop for Subscription {
fn drop(&mut self) {
self.notification_handlers.lock().remove(self.method);
match self {
Subscription::Detached => {}
Subscription::Notification {
method,
notification_handlers,
} => {
notification_handlers.lock().remove(method);
}
Subscription::Io { id, io_handlers } => {
if let Some(io_handlers) = io_handlers.upgrade() {
io_handlers.lock().remove(id);
}
}
}
}
}