diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index d242deead4..d27332cbde 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,16 +1,17 @@ mod request; mod sign_in; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{anyhow, Context, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use client::Client; +use collections::HashMap; use futures::{future::Shared, Future, FutureExt, TryFutureExt}; use gpui::{ actions, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, }; -use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, ToPointUtf16}; +use language::{point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, Language, ToPointUtf16}; use log::{debug, error}; use lsp::LanguageServer; use node_runtime::NodeRuntime; @@ -92,6 +93,7 @@ enum CopilotServer { Started { server: Arc, status: SignInStatus, + subscriptions_by_buffer_id: HashMap, }, } @@ -275,6 +277,7 @@ impl Copilot { this.server = CopilotServer::Started { server, status: SignInStatus::SignedOut, + subscriptions_by_buffer_id: Default::default(), }; this.update_sign_in_status(status, cx); } @@ -288,7 +291,7 @@ impl Copilot { } fn sign_in(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status } = &mut self.server { + if let CopilotServer::Started { server, status, .. } = &mut self.server { let task = match status { SignInStatus::Authorized { .. } | SignInStatus::Unauthorized { .. } => { Task::ready(Ok(())).shared() @@ -373,7 +376,7 @@ impl Copilot { } fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { - if let CopilotServer::Started { server, status } = &mut self.server { + if let CopilotServer::Started { server, status, .. } = &mut self.server { *status = SignInStatus::SignedOut; cx.notify(); @@ -410,43 +413,8 @@ impl Copilot { cx.foreground().spawn(start_task) } - pub fn completion( - &self, - buffer: &ModelHandle, - position: T, - cx: &mut ModelContext, - ) -> Task>> - where - T: ToPointUtf16, - { - let server = match self.authorized_server() { - Ok(server) => server, - Err(error) => return Task::ready(Err(error)), - }; - - let buffer = buffer.read(cx); - - if !buffer.file().map(|file| file.is_local()).unwrap_or(true) { - return Task::ready(Err(anyhow!("Copilot only works locally"))); - } - - let buffer = buffer.snapshot(); - let request = server.request::( - build_completion_params(&buffer, position, cx).unwrap(), - ); - cx.background().spawn(async move { - let result = request.await?; - let completion = result - .completions - .into_iter() - .next() - .map(|completion| completion_from_lsp(completion, &buffer)); - anyhow::Ok(completion) - }) - } - - pub fn completions_cycling( - &self, + pub fn completions( + &mut self, buffer: &ModelHandle, position: T, cx: &mut ModelContext, @@ -454,27 +422,150 @@ impl Copilot { where T: ToPointUtf16, { - let server = match self.authorized_server() { - Ok(server) => server, - Err(error) => return Task::ready(Err(error)), + self.request_completions::(buffer, position, cx) + } + + pub fn completions_cycling( + &mut self, + buffer: &ModelHandle, + position: T, + cx: &mut ModelContext, + ) -> Task>> + where + T: ToPointUtf16, + { + self.request_completions::(buffer, position, cx) + } + + fn request_completions( + &mut self, + buffer: &ModelHandle, + position: T, + cx: &mut ModelContext, + ) -> Task>> + where + R: lsp::request::Request< + Params = request::GetCompletionsParams, + Result = request::GetCompletionsResult, + >, + T: ToPointUtf16, + { + let buffer_id = buffer.id(); + let uri: lsp::Url = format!("buffer://{}", buffer_id).parse().unwrap(); + let snapshot = buffer.read(cx).snapshot(); + let server = match &mut self.server { + CopilotServer::Starting { .. } => { + return Task::ready(Err(anyhow!("copilot is still starting"))) + } + CopilotServer::Disabled => return Task::ready(Err(anyhow!("copilot is disabled"))), + CopilotServer::Error(error) => { + return Task::ready(Err(anyhow!( + "copilot was not started because of an error: {}", + error + ))) + } + CopilotServer::Started { + server, + status, + subscriptions_by_buffer_id, + } => { + if matches!(status, SignInStatus::Authorized { .. }) { + subscriptions_by_buffer_id + .entry(buffer_id) + .or_insert_with(|| { + server + .notify::( + lsp::DidOpenTextDocumentParams { + text_document: lsp::TextDocumentItem { + uri: uri.clone(), + language_id: id_for_language( + buffer.read(cx).language(), + ), + version: 0, + text: snapshot.text(), + }, + }, + ) + .log_err(); + + let uri = uri.clone(); + cx.observe_release(buffer, move |this, _, _| { + if let CopilotServer::Started { + server, + subscriptions_by_buffer_id, + .. + } = &mut this.server + { + server + .notify::( + lsp::DidCloseTextDocumentParams { + text_document: lsp::TextDocumentIdentifier::new( + uri.clone(), + ), + }, + ) + .log_err(); + subscriptions_by_buffer_id.remove(&buffer_id); + } + }) + }); + + server.clone() + } else { + return Task::ready(Err(anyhow!("must sign in before using copilot"))); + } + } }; - let buffer = buffer.read(cx); + let settings = cx.global::(); + let position = position.to_point_utf16(&snapshot); + let language = snapshot.language_at(position); + let language_name = language.map(|language| language.name()); + let language_name = language_name.as_deref(); - if !buffer.file().map(|file| file.is_local()).unwrap_or(true) { - return Task::ready(Err(anyhow!("Copilot only works locally"))); + let path; + let relative_path; + if let Some(file) = snapshot.file() { + if let Some(file) = file.as_local() { + path = file.abs_path(cx); + } else { + path = file.full_path(cx); + } + relative_path = file.path().to_path_buf(); + } else { + path = PathBuf::new(); + relative_path = PathBuf::new(); } - let buffer = buffer.snapshot(); - let request = server.request::( - build_completion_params(&buffer, position, cx).unwrap(), - ); + let params = request::GetCompletionsParams { + doc: request::GetCompletionsDocument { + source: snapshot.text(), + tab_size: settings.tab_size(language_name).into(), + indent_size: 1, + insert_spaces: !settings.hard_tabs(language_name), + uri, + path: path.to_string_lossy().into(), + relative_path: relative_path.to_string_lossy().into(), + language_id: id_for_language(language), + position: point_to_lsp(position), + version: 0, + }, + }; cx.background().spawn(async move { - let result = request.await?; + let result = server.request::(params).await?; let completions = result .completions .into_iter() - .map(|completion| completion_from_lsp(completion, &buffer)) + .map(|completion| { + let start = snapshot + .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left); + let end = + snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); + Completion { + range: snapshot.anchor_before(start)..snapshot.anchor_after(end), + text: completion.text, + } + }) .collect(); anyhow::Ok(completions) }) @@ -516,85 +607,14 @@ impl Copilot { cx.notify(); } } - - fn authorized_server(&self) -> Result> { - match &self.server { - CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")), - CopilotServer::Disabled => Err(anyhow!("copilot is disabled")), - CopilotServer::Error(error) => Err(anyhow!( - "copilot was not started because of an error: {}", - error - )), - CopilotServer::Started { server, status } => { - if matches!(status, SignInStatus::Authorized { .. }) { - Ok(server.clone()) - } else { - Err(anyhow!("must sign in before using copilot")) - } - } - } - } } -fn build_completion_params( - buffer: &BufferSnapshot, - position: T, - cx: &AppContext, -) -> anyhow::Result -where - T: ToPointUtf16, -{ - let position = position.to_point_utf16(&buffer); - let language_name = buffer.language_at(position).map(|language| language.name()); - let language_name = language_name.as_deref(); - - let path; - let relative_path; - if let Some(file) = buffer.file() { - if let Some(file) = file.as_local() { - path = file.abs_path(cx); - } else { - path = file.full_path(cx); - } - relative_path = file.path().to_path_buf(); - } else { - path = PathBuf::from("/untitled"); - relative_path = PathBuf::from("untitled"); - } - - let settings = cx.global::(); - let language_id = match language_name { +fn id_for_language(language: Option<&Arc>) -> String { + let language_name = language.map(|language| language.name()); + match language_name.as_deref() { Some("Plain Text") => "plaintext".to_string(), Some(language_name) => language_name.to_lowercase(), None => "plaintext".to_string(), - }; - - let Ok(uri) = lsp::Url::from_file_path(&path) else { - bail!("Failed convert file path") - }; - - Ok(request::GetCompletionsParams { - doc: request::GetCompletionsDocument { - source: buffer.text(), - tab_size: settings.tab_size(language_name).into(), - indent_size: 1, - insert_spaces: !settings.hard_tabs(language_name), - uri, - path: path.to_string_lossy().into(), - relative_path: relative_path.to_string_lossy().into(), - language_id, - position: point_to_lsp(position), - version: 0, - }, - }) -} - -fn completion_from_lsp(completion: request::Completion, buffer: &BufferSnapshot) -> Completion { - let start = buffer.clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left); - let end = buffer.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); - Completion { - range: buffer.anchor_before(start)..buffer.anchor_after(end), - text: completion.text, } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index a5722e183e..5c6470fe90 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -2843,14 +2843,14 @@ impl Editor { self.copilot_state.pending_refresh = cx.spawn_weak(|this, mut cx| async move { let (completion, completions_cycling) = copilot.update(&mut cx, |copilot, cx| { ( - copilot.completion(&buffer, buffer_position, cx), + copilot.completions(&buffer, buffer_position, cx), copilot.completions_cycling(&buffer, buffer_position, cx), ) }); let (completion, completions_cycling) = futures::join!(completion, completions_cycling); let mut completions = Vec::new(); - completions.extend(completion.log_err().flatten()); + completions.extend(completion.log_err().into_iter().flatten()); completions.extend(completions_cycling.log_err().into_iter().flatten()); this.upgrade(&cx)?.update(&mut cx, |this, cx| { if !completions.is_empty() {