diff --git a/Cargo.lock b/Cargo.lock index 3417a62d82..894b827c4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2316,6 +2316,7 @@ dependencies = [ "sha2 0.10.7", "sqlx", "subtle", + "supermaven_api", "telemetry_events", "text", "theme", @@ -2512,30 +2513,10 @@ dependencies = [ "async-compression", "async-std", "async-tar", + "client", "clock", "collections", "command_palette_hooks", - "fs", - "futures 0.3.28", - "gpui", - "language", - "lsp", - "node_runtime", - "parking_lot", - "rpc", - "serde", - "settings", - "smol", - "util", -] - -[[package]] -name = "copilot_ui" -version = "0.1.0" -dependencies = [ - "anyhow", - "client", - "copilot", "editor", "fs", "futures 0.3.28", @@ -2544,14 +2525,18 @@ dependencies = [ "language", "lsp", "menu", + "node_runtime", + "parking_lot", "project", + "rpc", + "serde", "serde_json", "settings", + "smol", "theme", "ui", "util", "workspace", - "zed_actions", ] [[package]] @@ -5143,6 +5128,30 @@ dependencies = [ "syn 2.0.59", ] +[[package]] +name = "inline_completion_button" +version = "0.1.0" +dependencies = [ + "anyhow", + "copilot", + "editor", + "fs", + "futures 0.3.28", + "gpui", + "indoc", + "language", + "lsp", + "project", + "serde_json", + "settings", + "supermaven", + "theme", + "ui", + "util", + "workspace", + "zed_actions", +] + [[package]] name = "inotify" version = "0.9.6" @@ -5548,6 +5557,7 @@ dependencies = [ "anyhow", "client", "collections", + "copilot", "editor", "env_logger", "futures 0.3.28", @@ -7422,7 +7432,6 @@ dependencies = [ "client", "clock", "collections", - "copilot", "env_logger", "fs", "futures 0.3.28", @@ -9594,6 +9603,43 @@ dependencies = [ "rayon", ] +[[package]] +name = "supermaven" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "collections", + "editor", + "env_logger", + "futures 0.3.28", + "gpui", + "language", + "log", + "postage", + "project", + "serde", + "serde_json", + "settings", + "smol", + "supermaven_api", + "theme", + "ui", + "util", +] + +[[package]] +name = "supermaven_api" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.28", + "serde", + "serde_json", + "smol", + "util", +] + [[package]] name = "sval" version = "2.8.0" @@ -11798,12 +11844,12 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "copilot_ui", "db", "editor", "extensions_ui", "fuzzy", "gpui", + "inline_completion_button", "install_cli", "picker", "project", @@ -12683,7 +12729,6 @@ dependencies = [ "collections", "command_palette", "copilot", - "copilot_ui", "db", "dev_server_projects", "diagnostics", @@ -12700,6 +12745,7 @@ dependencies = [ "gpui", "headless", "image_viewer", + "inline_completion_button", "install_cli", "isahc", "journal", @@ -12730,6 +12776,7 @@ dependencies = [ "settings", "simplelog", "smol", + "supermaven", "tab_switcher", "task", "tasks_ui", diff --git a/Cargo.toml b/Cargo.toml index ca0e5f35bd..67ce732b61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ members = [ "crates/command_palette", "crates/command_palette_hooks", "crates/copilot", - "crates/copilot_ui", "crates/db", "crates/diagnostics", "crates/editor", @@ -42,6 +41,7 @@ members = [ "crates/gpui_macros", "crates/headless", "crates/image_viewer", + "crates/inline_completion_button", "crates/install_cli", "crates/journal", "crates/language", @@ -86,6 +86,8 @@ members = [ "crates/storybook", "crates/sum_tree", "crates/tab_switcher", + "crates/supermaven", + "crates/supermaven_api", "crates/terminal", "crates/terminal_view", "crates/text", @@ -159,7 +161,6 @@ color = { path = "crates/color" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } copilot = { path = "crates/copilot" } -copilot_ui = { path = "crates/copilot_ui" } db = { path = "crates/db" } diagnostics = { path = "crates/diagnostics" } editor = { path = "crates/editor" } @@ -180,6 +181,7 @@ gpui_macros = { path = "crates/gpui_macros" } headless = { path = "crates/headless" } install_cli = { path = "crates/install_cli" } image_viewer = { path = "crates/image_viewer" } +inline_completion_button = { path = "crates/inline_completion_button" } journal = { path = "crates/journal" } language = { path = "crates/language" } language_selector = { path = "crates/language_selector" } @@ -220,6 +222,8 @@ settings = { path = "crates/settings" } snippet = { path = "crates/snippet" } sqlez = { path = "crates/sqlez" } sqlez_macros = { path = "crates/sqlez_macros" } +supermaven = { path = "crates/supermaven" } +supermaven_api = { path = "crates/supermaven_api"} story = { path = "crates/story" } storybook = { path = "crates/storybook" } sum_tree = { path = "crates/sum_tree" } diff --git a/assets/icons/supermaven.svg b/assets/icons/supermaven.svg new file mode 100644 index 0000000000..19837fbf56 --- /dev/null +++ b/assets/icons/supermaven.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/assets/icons/supermaven_disabled.svg b/assets/icons/supermaven_disabled.svg new file mode 100644 index 0000000000..39ff8a6122 --- /dev/null +++ b/assets/icons/supermaven_disabled.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/assets/icons/supermaven_error.svg b/assets/icons/supermaven_error.svg new file mode 100644 index 0000000000..669322b97d --- /dev/null +++ b/assets/icons/supermaven_error.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/icons/supermaven_init.svg b/assets/icons/supermaven_init.svg new file mode 100644 index 0000000000..b919d5559b --- /dev/null +++ b/assets/icons/supermaven_init.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index c8560d7f15..de6da01f87 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -12,8 +12,8 @@ "base_keymap": "VSCode", // Features that can be globally enabled or disabled "features": { - // Show Copilot icon in status bar - "copilot": true + // Which inline completion provider to use. + "inline_completion_provider": "copilot" }, // The name of a font to use for rendering text in the editor "buffer_font_family": "Zed Mono", diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index a96a23b166..aeaae1f34d 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] @@ -141,7 +141,7 @@ pub enum TextDelta { } pub async fn stream_completion( - client: &dyn HttpClient, + client: Arc, api_url: &str, api_key: &str, request: Request, diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index e8dbcf851a..5e719739ae 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -39,6 +39,7 @@ live_kit_server.workspace = true log.workspace = true nanoid.workspace = true open_ai.workspace = true +supermaven_api.workspace = true parking_lot.workspace = true prometheus = "0.13" prost.workspace = true diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 8bd6a71514..271b146b0b 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -172,6 +172,11 @@ spec: secretKeyRef: name: slack key: panics_webhook + - name: SUPERMAVEN_ADMIN_API_KEY + valueFrom: + secretKeyRef: + name: supermaven + key: api_key - name: INVITE_LINK_PREFIX value: ${INVITE_LINK_PREFIX} - name: RUST_BACKTRACE diff --git a/crates/collab/src/completion.rs b/crates/collab/src/completion.rs new file mode 100644 index 0000000000..dd1f4b3be6 --- /dev/null +++ b/crates/collab/src/completion.rs @@ -0,0 +1,2 @@ +use anyhow::{anyhow, Result}; +use rpc::proto; diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 925d192fc0..ae83fccb98 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -138,6 +138,7 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, + pub supermaven_admin_api_key: Option>, } impl Config { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e4a83a4338..59f811f0b5 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -34,6 +34,7 @@ pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL}; use sha2::Digest; +use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use futures::{ channel::oneshot, @@ -148,7 +149,8 @@ struct Session { peer: Arc, connection_pool: Arc>, live_kit_client: Option>, - http_client: IsahcHttpClient, + supermaven_client: Option>, + http_client: Arc, rate_limiter: Arc, _executor: Executor, } @@ -189,6 +191,14 @@ impl Session { } } + fn is_staff(&self) -> bool { + match &self.principal { + Principal::User(user) => user.admin, + Principal::Impersonated { .. } => true, + Principal::DevServer(_) => false, + } + } + fn dev_server_id(&self) -> Option { match &self.principal { Principal::User(_) | Principal::Impersonated { .. } => None, @@ -233,6 +243,14 @@ impl UserSession { pub fn user_id(&self) -> UserId { self.0.user_id().unwrap() } + + pub fn email(&self) -> Option { + match &self.0.principal { + Principal::User(user) => user.email_address.clone(), + Principal::Impersonated { user, .. } => user.email_address.clone(), + Principal::DevServer(..) => None, + } + } } impl Deref for UserSession { @@ -561,6 +579,7 @@ impl Server { .add_request_handler(user_handler(get_private_user_info)) .add_message_handler(user_message_handler(acknowledge_channel_message)) .add_message_handler(user_message_handler(acknowledge_buffer_version)) + .add_request_handler(user_handler(get_supermaven_api_key)) .add_streaming_request_handler({ let app_state = app_state.clone(); move |request, response, session| { @@ -938,13 +957,22 @@ impl Server { tracing::info!("connection opened"); let http_client = match IsahcHttpClient::new() { - Ok(http_client) => http_client, + Ok(http_client) => Arc::new(http_client), Err(error) => { tracing::error!(?error, "failed to create HTTP client"); return; } }; + let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() { + Some(Arc::new(SupermavenAdminApi::new( + supermaven_admin_api_key.to_string(), + http_client.clone(), + ))) + } else { + None + }; + let session = Session { principal: principal.clone(), connection_id, @@ -955,6 +983,7 @@ impl Server { http_client, rate_limiter: this.app_state.rate_limiter.clone(), _executor: executor.clone(), + supermaven_client, }; if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await { @@ -4210,7 +4239,7 @@ async fn complete_with_open_ai( api_key: Arc, ) -> Result<()> { let mut completion_stream = open_ai::stream_completion( - &session.http_client, + session.http_client.as_ref(), OPEN_AI_API_URL, &api_key, crate::ai::language_model_request_to_open_ai(request)?, @@ -4274,7 +4303,7 @@ async fn complete_with_google_ai( api_key: Arc, ) -> Result<()> { let mut stream = google_ai::stream_generate_content( - &session.http_client, + session.http_client.clone(), google_ai::API_URL, api_key.as_ref(), crate::ai::language_model_request_to_google_ai(request)?, @@ -4358,7 +4387,7 @@ async fn complete_with_anthropic( .collect(); let mut stream = anthropic::stream_completion( - &session.http_client, + session.http_client.clone(), "https://api.anthropic.com", &api_key, anthropic::Request { @@ -4482,7 +4511,7 @@ async fn count_tokens_with_language_model( let api_key = google_ai_api_key .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; let tokens_response = google_ai::count_tokens( - &session.http_client, + session.http_client.as_ref(), google_ai::API_URL, &api_key, crate::ai::count_tokens_request_to_google_ai(request)?, @@ -4530,7 +4559,7 @@ async fn compute_embeddings( let embeddings = match request.model.as_str() { "openai/text-embedding-3-small" => { open_ai::embed( - &session.http_client, + session.http_client.as_ref(), OPEN_AI_API_URL, &api_key, OpenAiEmbeddingModel::TextEmbedding3Small, @@ -4602,6 +4631,37 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<() } } +/// Get a Supermaven API key for the user +async fn get_supermaven_api_key( + _request: proto::GetSupermavenApiKey, + response: Response, + session: UserSession, +) -> Result<()> { + let user_id: String = session.user_id().to_string(); + if !session.is_staff() { + return Err(anyhow!("supermaven not enabled for this account"))?; + } + + let email = session + .email() + .ok_or_else(|| anyhow!("user must have an email"))?; + + let supermaven_admin_api = session + .supermaven_client + .as_ref() + .ok_or_else(|| anyhow!("supermaven not configured"))?; + + let result = supermaven_admin_api + .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email }) + .await?; + + response.send(proto::GetSupermavenApiKeyResponse { + api_key: result.api_key, + })?; + + Ok(()) +} + /// Start receiving chat updates for a channel async fn join_channel_chat( request: proto::JoinChannelChat, diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 2fec21f76e..3a456a328e 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -655,6 +655,7 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, + supermaven_admin_api_key: None, }, }) } diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 609bd0e3a8..3f38a81f5b 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -27,28 +27,38 @@ anyhow.workspace = true async-compression.workspace = true async-tar.workspace = true collections.workspace = true +client.workspace = true command_palette_hooks.workspace = true +editor.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true lsp.workspace = true +menu.workspace = true node_runtime.workspace = true parking_lot.workspace = true +project.workspace = true serde.workspace = true settings.workspace = true smol.workspace = true +ui.workspace = true util.workspace = true +workspace.workspace = true [target.'cfg(windows)'.dependencies] async-std = { version = "1.12.0", features = ["unstable"] } [dev-dependencies] clock.workspace = true +indoc.workspace = true +serde_json.workspace = true collections = { workspace = true, features = ["test-support"] } fs = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } lsp = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } rpc = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } util = { workspace = true, features = ["test-support"] } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 99f94b5511..577f335d2a 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,4 +1,7 @@ +mod copilot_completion_provider; pub mod request; +mod sign_in; + use anyhow::{anyhow, Context as _, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; @@ -10,9 +13,9 @@ use gpui::{ ModelContext, Task, WeakModel, }; use language::{ - language_settings::{all_language_settings, language_settings}, - point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, - LanguageServerName, PointUtf16, ToPointUtf16, + language_settings::{all_language_settings, language_settings, InlineCompletionProvider}, + point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, + ToPointUtf16, }; use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId}; use node_runtime::NodeRuntime; @@ -32,6 +35,9 @@ use util::{ fs::remove_matching, github::latest_github_release, http::HttpClient, maybe, paths, ResultExt, }; +pub use copilot_completion_provider::CopilotCompletionProvider; +pub use sign_in::CopilotCodeVerification; + actions!( copilot, [ @@ -144,7 +150,6 @@ impl CopilotServer { } struct RunningCopilotServer { - name: LanguageServerName, lsp: Arc, sign_in_status: SignInStatus, registered_buffers: HashMap, @@ -354,7 +359,9 @@ impl Copilot { let server_id = self.server_id; let http = self.http.clone(); let node_runtime = self.node_runtime.clone(); - if all_language_settings(None, cx).copilot_enabled(None, None) { + if all_language_settings(None, cx).inline_completions.provider + == InlineCompletionProvider::Copilot + { if matches!(self.server, CopilotServer::Disabled) { let start_task = cx .spawn(move |this, cx| { @@ -393,7 +400,6 @@ impl Copilot { http: http.clone(), node_runtime, server: CopilotServer::Running(RunningCopilotServer { - name: LanguageServerName(Arc::from("copilot")), lsp: Arc::new(server), sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), @@ -467,7 +473,6 @@ impl Copilot { match server { Ok((server, status)) => { this.server = CopilotServer::Running(RunningCopilotServer { - name: LanguageServerName(Arc::from("copilot")), lsp: server, sign_in_status: SignInStatus::SignedOut, registered_buffers: Default::default(), @@ -607,9 +612,9 @@ impl Copilot { cx.background_executor().spawn(start_task) } - pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc)> { + pub fn language_server(&self) -> Option<&Arc> { if let CopilotServer::Running(server) = &self.server { - Some((&server.name, &server.lsp)) + Some(&server.lsp) } else { None } diff --git a/crates/copilot_ui/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs similarity index 94% rename from crates/copilot_ui/src/copilot_completion_provider.rs rename to crates/copilot/src/copilot_completion_provider.rs index c6226c7bb1..970145a10f 100644 --- a/crates/copilot_ui/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -1,10 +1,12 @@ +use crate::{Completion, Copilot}; use anyhow::Result; use client::telemetry::Telemetry; -use copilot::Copilot; use editor::{Direction, InlineCompletionProvider}; use gpui::{AppContext, EntityId, Model, ModelContext, Task}; -use language::language_settings::AllLanguageSettings; -use language::{language_settings::all_language_settings, Buffer, OffsetRangeExt, ToOffset}; +use language::{ + language_settings::{all_language_settings, AllLanguageSettings}, + Buffer, OffsetRangeExt, ToOffset, +}; use settings::Settings; use std::{path::Path, sync::Arc, time::Duration}; @@ -13,7 +15,7 @@ pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); pub struct CopilotCompletionProvider { cycled: bool, buffer_id: Option, - completions: Vec, + completions: Vec, active_completion_index: usize, file_extension: Option, pending_refresh: Task>, @@ -42,11 +44,11 @@ impl CopilotCompletionProvider { self } - fn active_completion(&self) -> Option<&copilot::Completion> { + fn active_completion(&self) -> Option<&Completion> { self.completions.get(self.active_completion_index) } - fn push_completion(&mut self, new_completion: copilot::Completion) { + fn push_completion(&mut self, new_completion: Completion) { for completion in &self.completions { if completion.text == new_completion.text && completion.range == new_completion.range { return; @@ -71,7 +73,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider { let file = buffer.file(); let language = buffer.language_at(cursor_position); let settings = all_language_settings(file, cx); - settings.copilot_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) + settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) } fn refresh( @@ -196,7 +198,10 @@ impl InlineCompletionProvider for CopilotCompletionProvider { fn discard(&mut self, cx: &mut ModelContext) { let settings = AllLanguageSettings::get_global(cx); - if !settings.copilot.feature_enabled { + + let copilot_enabled = settings.inline_completions_enabled(None, None); + + if !copilot_enabled { return; } @@ -298,7 +303,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); // When inserting, ensure autocompletion is favored over Copilot suggestions. cx.set_state(indoc! {" @@ -318,7 +325,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -360,7 +367,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -393,7 +400,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -426,7 +433,7 @@ mod tests { // After debouncing, new Copilot completions should be requested. handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot2".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 5)), ..Default::default() @@ -503,7 +510,7 @@ mod tests { }); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: " let x = 4;".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)), ..Default::default() @@ -553,7 +560,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); // Setup the editor with a completion request. cx.set_state(indoc! {" @@ -573,7 +582,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -615,7 +624,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.123. copilot\n 456".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -675,7 +684,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); cx.set_state(indoc! {" one @@ -685,7 +696,7 @@ mod tests { handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "two.foo()".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)), ..Default::default() @@ -756,13 +767,13 @@ mod tests { let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); editor .update(cx, |editor, cx| { - editor.set_inline_completion_provider(copilot_provider, cx) + editor.set_inline_completion_provider(Some(copilot_provider), cx) }) .unwrap(); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "b = 2 + a".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 5)), ..Default::default() @@ -788,7 +799,7 @@ mod tests { handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "d = 4 + c".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 6)), ..Default::default() @@ -833,7 +844,7 @@ mod tests { async fn test_copilot_disabled_globs(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx, |settings| { settings - .copilot + .inline_completions .get_or_insert(Default::default()) .disabled_globs = Some(vec![".env*".to_string()]); }); @@ -888,15 +899,15 @@ mod tests { let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); editor .update(cx, |editor, cx| { - editor.set_inline_completion_provider(copilot_provider, cx) + editor.set_inline_completion_provider(Some(copilot_provider), cx) }) .unwrap(); let mut copilot_requests = copilot_lsp - .handle_request::( + .handle_request::( move |_params, _cx| async move { - Ok(copilot::request::GetCompletionsResult { - completions: vec![copilot::request::Completion { + Ok(crate::request::GetCompletionsResult { + completions: vec![crate::request::Completion { text: "next line".into(), range: lsp::Range::new( lsp::Position::new(1, 0), @@ -931,21 +942,21 @@ mod tests { fn handle_copilot_completion_request( lsp: &lsp::FakeLanguageServer, - completions: Vec, - completions_cycling: Vec, + completions: Vec, + completions_cycling: Vec, ) { - lsp.handle_request::(move |_params, _cx| { + lsp.handle_request::(move |_params, _cx| { let completions = completions.clone(); async move { - Ok(copilot::request::GetCompletionsResult { + Ok(crate::request::GetCompletionsResult { completions: completions.clone(), }) } }); - lsp.handle_request::(move |_params, _cx| { + lsp.handle_request::(move |_params, _cx| { let completions_cycling = completions_cycling.clone(); async move { - Ok(copilot::request::GetCompletionsResult { + Ok(crate::request::GetCompletionsResult { completions: completions_cycling.clone(), }) } diff --git a/crates/copilot_ui/src/sign_in.rs b/crates/copilot/src/sign_in.rs similarity index 98% rename from crates/copilot_ui/src/sign_in.rs rename to crates/copilot/src/sign_in.rs index 396b2367f9..abf7252fef 100644 --- a/crates/copilot_ui/src/sign_in.rs +++ b/crates/copilot/src/sign_in.rs @@ -1,4 +1,4 @@ -use copilot::{request::PromptUserDeviceFlow, Copilot, Status}; +use crate::{request::PromptUserDeviceFlow, Copilot, Status}; use gpui::{ div, svg, AppContext, ClipboardItem, DismissEvent, Element, EventEmitter, FocusHandle, FocusableView, InteractiveElement, IntoElement, Model, MouseDownEvent, ParentElement, Render, @@ -26,7 +26,7 @@ impl EventEmitter for CopilotCodeVerification {} impl ModalView for CopilotCodeVerification {} impl CopilotCodeVerification { - pub(crate) fn new(copilot: &Model, cx: &mut ViewContext) -> Self { + pub fn new(copilot: &Model, cx: &mut ViewContext) -> Self { let status = copilot.read(cx).status(); Self { status, diff --git a/crates/copilot_ui/src/copilot_button.rs b/crates/copilot_ui/src/copilot_button.rs deleted file mode 100644 index b228a10839..0000000000 --- a/crates/copilot_ui/src/copilot_button.rs +++ /dev/null @@ -1,403 +0,0 @@ -use crate::sign_in::CopilotCodeVerification; -use anyhow::Result; -use copilot::{Copilot, SignOut, Status}; -use editor::{scroll::Autoscroll, Editor}; -use fs::Fs; -use gpui::{ - div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement, - Render, Subscription, View, ViewContext, WeakView, WindowContext, -}; -use language::{ - language_settings::{self, all_language_settings, AllLanguageSettings}, - File, Language, -}; -use settings::{update_settings_file, Settings, SettingsStore}; -use std::{path::Path, sync::Arc}; -use util::{paths, ResultExt}; -use workspace::notifications::NotificationId; -use workspace::{ - create_and_open_local_file, - item::ItemHandle, - ui::{ - popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip, - }, - StatusItemView, Toast, Workspace, -}; -use zed_actions::OpenBrowser; - -const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot"; - -struct CopilotStartingToast; - -struct CopilotErrorToast; - -pub struct CopilotButton { - editor_subscription: Option<(Subscription, usize)>, - editor_enabled: Option, - language: Option>, - file: Option>, - fs: Arc, -} - -impl Render for CopilotButton { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let all_language_settings = all_language_settings(None, cx); - if !all_language_settings.copilot.feature_enabled { - return div(); - } - - let Some(copilot) = Copilot::global(cx) else { - return div(); - }; - let status = copilot.read(cx).status(); - - let enabled = self - .editor_enabled - .unwrap_or_else(|| all_language_settings.copilot_enabled(None, None)); - - let icon = match status { - Status::Error(_) => IconName::CopilotError, - Status::Authorized => { - if enabled { - IconName::Copilot - } else { - IconName::CopilotDisabled - } - } - _ => IconName::CopilotInit, - }; - - if let Status::Error(e) = status { - return div().child( - IconButton::new("copilot-error", icon) - .icon_size(IconSize::Small) - .on_click(cx.listener(move |_, _, cx| { - if let Some(workspace) = cx.window_handle().downcast::() { - workspace - .update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - format!("Copilot can't be started: {}", e), - ) - .on_click( - "Reinstall Copilot", - |cx| { - if let Some(copilot) = Copilot::global(cx) { - copilot - .update(cx, |copilot, cx| { - copilot.reinstall(cx) - }) - .detach(); - } - }, - ), - cx, - ); - }) - .ok(); - } - })) - .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), - ); - } - let this = cx.view().clone(); - - div().child( - popover_menu("copilot") - .menu(move |cx| match status { - Status::Authorized => { - Some(this.update(cx, |this, cx| this.build_copilot_menu(cx))) - } - _ => Some(this.update(cx, |this, cx| this.build_copilot_start_menu(cx))), - }) - .anchor(AnchorCorner::BottomRight) - .trigger( - IconButton::new("copilot-icon", icon) - .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), - ), - ) - } -} - -impl CopilotButton { - pub fn new(fs: Arc, cx: &mut ViewContext) -> Self { - if let Some(copilot) = Copilot::global(cx) { - cx.observe(&copilot, |_, _, cx| cx.notify()).detach() - } - - cx.observe_global::(move |_, cx| cx.notify()) - .detach(); - - Self { - editor_subscription: None, - editor_enabled: None, - language: None, - file: None, - fs, - } - } - - pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext) -> View { - let fs = self.fs.clone(); - ContextMenu::build(cx, |menu, _| { - menu.entry("Sign In", None, initiate_sign_in).entry( - "Disable Copilot", - None, - move |cx| hide_copilot(fs.clone(), cx), - ) - }) - } - - pub fn build_copilot_menu(&mut self, cx: &mut ViewContext) -> View { - let fs = self.fs.clone(); - - ContextMenu::build(cx, move |mut menu, cx| { - if let Some(language) = self.language.clone() { - let fs = fs.clone(); - let language_enabled = - language_settings::language_settings(Some(&language), None, cx) - .show_copilot_suggestions; - - menu = menu.entry( - format!( - "{} Suggestions for {}", - if language_enabled { "Hide" } else { "Show" }, - language.name() - ), - None, - move |cx| toggle_copilot_for_language(language.clone(), fs.clone(), cx), - ); - } - - let settings = AllLanguageSettings::get_global(cx); - - if let Some(file) = &self.file { - let path = file.path().clone(); - let path_enabled = settings.copilot_enabled_for_path(&path); - - menu = menu.entry( - format!( - "{} Suggestions for This Path", - if path_enabled { "Hide" } else { "Show" } - ), - None, - move |cx| { - if let Some(workspace) = cx.window_handle().downcast::() { - if let Ok(workspace) = workspace.root_view(cx) { - let workspace = workspace.downgrade(); - cx.spawn(|cx| { - configure_disabled_globs( - workspace, - path_enabled.then_some(path.clone()), - cx, - ) - }) - .detach_and_log_err(cx); - } - } - }, - ); - } - - let globally_enabled = settings.copilot_enabled(None, None); - menu.entry( - if globally_enabled { - "Hide Suggestions for All Files" - } else { - "Show Suggestions for All Files" - }, - None, - move |cx| toggle_copilot_globally(fs.clone(), cx), - ) - .separator() - .link( - "Copilot Settings", - OpenBrowser { - url: COPILOT_SETTINGS_URL.to_string(), - } - .boxed_clone(), - ) - .action("Sign Out", SignOut.boxed_clone()) - }) - } - - pub fn update_enabled(&mut self, editor: View, cx: &mut ViewContext) { - let editor = editor.read(cx); - let snapshot = editor.buffer().read(cx).snapshot(cx); - let suggestion_anchor = editor.selections.newest_anchor().start; - let language = snapshot.language_at(suggestion_anchor); - let file = snapshot.file_at(suggestion_anchor).cloned(); - self.editor_enabled = { - let file = file.as_ref(); - Some( - file.map(|file| !file.is_private()).unwrap_or(true) - && all_language_settings(file, cx) - .copilot_enabled(language, file.map(|file| file.path().as_ref())), - ) - }; - self.language = language.cloned(); - self.file = file; - - cx.notify() - } -} - -impl StatusItemView for CopilotButton { - fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext) { - if let Some(editor) = item.and_then(|item| item.act_as::(cx)) { - self.editor_subscription = Some(( - cx.observe(&editor, Self::update_enabled), - editor.entity_id().as_u64() as usize, - )); - self.update_enabled(editor, cx); - } else { - self.language = None; - self.editor_subscription = None; - self.editor_enabled = None; - } - cx.notify(); - } -} - -async fn configure_disabled_globs( - workspace: WeakView, - path_to_disable: Option>, - mut cx: AsyncWindowContext, -) -> Result<()> { - let settings_editor = workspace - .update(&mut cx, |_, cx| { - create_and_open_local_file(&paths::SETTINGS, cx, || { - settings::initial_user_settings_content().as_ref().into() - }) - })? - .await? - .downcast::() - .unwrap(); - - settings_editor.downgrade().update(&mut cx, |item, cx| { - let text = item.buffer().read(cx).snapshot(cx).text(); - - let settings = cx.global::(); - let edits = settings.edits_for_update::(&text, |file| { - let copilot = file.copilot.get_or_insert_with(Default::default); - let globs = copilot.disabled_globs.get_or_insert_with(|| { - settings - .get::(None) - .copilot - .disabled_globs - .iter() - .map(|glob| glob.glob().to_string()) - .collect() - }); - - if let Some(path_to_disable) = &path_to_disable { - globs.push(path_to_disable.to_string_lossy().into_owned()); - } else { - globs.clear(); - } - }); - - if !edits.is_empty() { - item.change_selections(Some(Autoscroll::newest()), cx, |selections| { - selections.select_ranges(edits.iter().map(|e| e.0.clone())); - }); - - // When *enabling* a path, don't actually perform an edit, just select the range. - if path_to_disable.is_some() { - item.edit(edits.iter().cloned(), cx); - } - } - })?; - - anyhow::Ok(()) -} - -fn toggle_copilot_globally(fs: Arc, cx: &mut AppContext) { - let show_copilot_suggestions = all_language_settings(None, cx).copilot_enabled(None, None); - update_settings_file::(fs, cx, move |file| { - file.defaults.show_copilot_suggestions = Some(!show_copilot_suggestions) - }); -} - -fn toggle_copilot_for_language(language: Arc, fs: Arc, cx: &mut AppContext) { - let show_copilot_suggestions = - all_language_settings(None, cx).copilot_enabled(Some(&language), None); - update_settings_file::(fs, cx, move |file| { - file.languages - .entry(language.name()) - .or_default() - .show_copilot_suggestions = Some(!show_copilot_suggestions); - }); -} - -fn hide_copilot(fs: Arc, cx: &mut AppContext) { - update_settings_file::(fs, cx, move |file| { - file.features.get_or_insert(Default::default()).copilot = Some(false); - }); -} - -pub fn initiate_sign_in(cx: &mut WindowContext) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; - let status = copilot.read(cx).status(); - let Some(workspace) = cx.window_handle().downcast::() else { - return; - }; - match status { - Status::Starting { task } => { - let Some(workspace) = cx.window_handle().downcast::() else { - return; - }; - - let Ok(workspace) = workspace.update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - "Copilot is starting...", - ), - cx, - ); - workspace.weak_handle() - }) else { - return; - }; - - cx.spawn(|mut cx| async move { - task.await; - if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() { - workspace - .update(&mut cx, |workspace, cx| match copilot.read(cx).status() { - Status::Authorized => workspace.show_toast( - Toast::new( - NotificationId::unique::(), - "Copilot has started!", - ), - cx, - ), - _ => { - workspace.dismiss_toast( - &NotificationId::unique::(), - cx, - ); - copilot - .update(cx, |copilot, cx| copilot.sign_in(cx)) - .detach_and_log_err(cx); - } - }) - .log_err(); - } - }) - .detach(); - } - _ => { - copilot.update(cx, |this, cx| this.sign_in(cx)).detach(); - workspace - .update(cx, |this, cx| { - this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx)); - }) - .ok(); - } - } -} diff --git a/crates/copilot_ui/src/copilot_ui.rs b/crates/copilot_ui/src/copilot_ui.rs deleted file mode 100644 index 63bd03102f..0000000000 --- a/crates/copilot_ui/src/copilot_ui.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod copilot_button; -mod copilot_completion_provider; -mod sign_in; - -pub use copilot_button::*; -pub use copilot_completion_provider::*; -pub use sign_in::*; diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index ef5c97a592..bbd215b23b 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1757,19 +1757,22 @@ impl Editor { self.completion_provider = Some(hub); } - pub fn set_inline_completion_provider( + pub fn set_inline_completion_provider( &mut self, - provider: Model, + provider: Option>, cx: &mut ViewContext, - ) { - self.inline_completion_provider = Some(RegisteredInlineCompletionProvider { - _subscription: cx.observe(&provider, |this, _, cx| { - if this.focus_handle.is_focused(cx) { - this.update_visible_inline_completion(cx); - } - }), - provider: Arc::new(provider), - }); + ) where + T: InlineCompletionProvider, + { + self.inline_completion_provider = + provider.map(|provider| RegisteredInlineCompletionProvider { + _subscription: cx.observe(&provider, |this, _, cx| { + if this.focus_handle.is_focused(cx) { + this.update_visible_inline_completion(cx); + } + }), + provider: Arc::new(provider), + }); self.refresh_inline_completion(false, cx); } @@ -2676,7 +2679,7 @@ impl Editor { } drop(snapshot); - let had_active_copilot_completion = this.has_active_inline_completion(cx); + let had_active_inline_completion = this.has_active_inline_completion(cx); this.change_selections(Some(Autoscroll::fit()), cx, |s| s.select(new_selections)); if brace_inserted { @@ -2692,7 +2695,7 @@ impl Editor { } } - if had_active_copilot_completion { + if had_active_inline_completion { this.refresh_inline_completion(true, cx); if !this.has_active_inline_completion(cx) { this.trigger_completion_on_input(&text, cx); @@ -4005,7 +4008,7 @@ impl Editor { if !self.show_inline_completions || !provider.is_enabled(&buffer, cursor_buffer_position, cx) { - self.clear_inline_completion(cx); + self.discard_inline_completion(cx); return None; } @@ -4207,13 +4210,6 @@ impl Editor { self.discard_inline_completion(cx); } - fn clear_inline_completion(&mut self, cx: &mut ViewContext) { - if let Some(old_completion) = self.active_inline_completion.take() { - self.splice_inlays(vec![old_completion.id], Vec::new(), cx); - } - self.discard_inline_completion(cx); - } - fn inline_completion_provider(&self) -> Option> { Some(self.inline_completion_provider.as_ref()?.provider.clone()) } @@ -9947,12 +9943,14 @@ impl Editor { .raw_user_settings() .get("vim_mode") == Some(&serde_json::Value::Bool(true)); - let copilot_enabled = all_language_settings(file, cx).copilot_enabled(None, None); + + let copilot_enabled = all_language_settings(file, cx).inline_completions.provider + == language::language_settings::InlineCompletionProvider::Copilot; let copilot_enabled_for_language = self .buffer .read(cx) .settings_at(0, cx) - .show_copilot_suggestions; + .show_inline_completions; let telemetry = project.read(cx).client().telemetry().clone(); telemetry.report_editor_event( diff --git a/crates/editor/src/inline_completion_provider.rs b/crates/editor/src/inline_completion_provider.rs index 31edf80623..2fb2cb608f 100644 --- a/crates/editor/src/inline_completion_provider.rs +++ b/crates/editor/src/inline_completion_provider.rs @@ -25,11 +25,11 @@ pub trait InlineCompletionProvider: 'static + Sized { ); fn accept(&mut self, cx: &mut ModelContext); fn discard(&mut self, cx: &mut ModelContext); - fn active_completion_text( - &self, + fn active_completion_text<'a>( + &'a self, buffer: &Model, cursor_position: language::Anchor, - cx: &AppContext, + cx: &'a AppContext, ) -> Option<&str>; } @@ -57,7 +57,7 @@ pub trait InlineCompletionProviderHandle { fn accept(&self, cx: &mut AppContext); fn discard(&self, cx: &mut AppContext); fn active_completion_text<'a>( - &self, + &'a self, buffer: &Model, cursor_position: language::Anchor, cx: &'a AppContext, @@ -110,7 +110,7 @@ where } fn active_completion_text<'a>( - &self, + &'a self, buffer: &Model, cursor_position: language::Anchor, cx: &'a AppContext, diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 4fe461981f..53ed5894de 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; @@ -5,8 +7,8 @@ use util::http::HttpClient; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; -pub async fn stream_generate_content( - client: &T, +pub async fn stream_generate_content( + client: Arc, api_url: &str, api_key: &str, request: GenerateContentRequest, diff --git a/crates/copilot_ui/Cargo.toml b/crates/inline_completion_button/Cargo.toml similarity index 88% rename from crates/copilot_ui/Cargo.toml rename to crates/inline_completion_button/Cargo.toml index 4bf3240aab..48acdb3ae1 100644 --- a/crates/copilot_ui/Cargo.toml +++ b/crates/inline_completion_button/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "copilot_ui" +name = "inline_completion_button" version = "0.1.0" edition = "2021" publish = false @@ -9,19 +9,18 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/copilot_ui.rs" +path = "src/inline_completion_button.rs" doctest = false [dependencies] anyhow.workspace = true -client.workspace = true copilot.workspace = true editor.workspace = true fs.workspace = true gpui.workspace = true language.workspace = true -menu.workspace = true settings.workspace = true +supermaven.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true diff --git a/crates/copilot_ui/LICENSE-GPL b/crates/inline_completion_button/LICENSE-GPL similarity index 100% rename from crates/copilot_ui/LICENSE-GPL rename to crates/inline_completion_button/LICENSE-GPL diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs new file mode 100644 index 0000000000..86f6945ac1 --- /dev/null +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -0,0 +1,510 @@ +use anyhow::Result; +use copilot::{Copilot, CopilotCodeVerification, Status}; +use editor::{scroll::Autoscroll, Editor}; +use fs::Fs; +use gpui::{ + div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement, + Render, Subscription, View, ViewContext, WeakView, WindowContext, +}; +use language::{ + language_settings::{ + self, all_language_settings, AllLanguageSettings, InlineCompletionProvider, + }, + File, Language, +}; +use settings::{update_settings_file, Settings, SettingsStore}; +use std::{path::Path, sync::Arc}; +use supermaven::{AccountStatus, Supermaven}; +use util::{paths, ResultExt}; +use workspace::{ + create_and_open_local_file, + item::ItemHandle, + notifications::NotificationId, + ui::{ + popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip, + }, + StatusItemView, Toast, Workspace, +}; +use zed_actions::OpenBrowser; + +const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot"; + +struct CopilotStartingToast; + +struct CopilotErrorToast; + +pub struct InlineCompletionButton { + editor_subscription: Option<(Subscription, usize)>, + editor_enabled: Option, + language: Option>, + file: Option>, + fs: Arc, +} + +enum SupermavenButtonStatus { + Ready, + Errored(String), + NeedsActivation(String), + Initializing, +} + +impl Render for InlineCompletionButton { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let all_language_settings = all_language_settings(None, cx); + + match all_language_settings.inline_completions.provider { + InlineCompletionProvider::None => return div(), + + InlineCompletionProvider::Copilot => { + let Some(copilot) = Copilot::global(cx) else { + return div(); + }; + let status = copilot.read(cx).status(); + + let enabled = self.editor_enabled.unwrap_or_else(|| { + all_language_settings.inline_completions_enabled(None, None) + }); + + let icon = match status { + Status::Error(_) => IconName::CopilotError, + Status::Authorized => { + if enabled { + IconName::Copilot + } else { + IconName::CopilotDisabled + } + } + _ => IconName::CopilotInit, + }; + + if let Status::Error(e) = status { + return div().child( + IconButton::new("copilot-error", icon) + .icon_size(IconSize::Small) + .on_click(cx.listener(move |_, _, cx| { + if let Some(workspace) = cx.window_handle().downcast::() + { + workspace + .update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new( + NotificationId::unique::(), + format!("Copilot can't be started: {}", e), + ) + .on_click("Reinstall Copilot", |cx| { + if let Some(copilot) = Copilot::global(cx) { + copilot + .update(cx, |copilot, cx| { + copilot.reinstall(cx) + }) + .detach(); + } + }), + cx, + ); + }) + .ok(); + } + })) + .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), + ); + } + let this = cx.view().clone(); + + div().child( + popover_menu("copilot") + .menu(move |cx| { + Some(match status { + Status::Authorized => { + this.update(cx, |this, cx| this.build_copilot_context_menu(cx)) + } + _ => this.update(cx, |this, cx| this.build_copilot_start_menu(cx)), + }) + }) + .anchor(AnchorCorner::BottomRight) + .trigger( + IconButton::new("copilot-icon", icon) + .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), + ), + ) + } + + InlineCompletionProvider::Supermaven => { + let Some(supermaven) = Supermaven::global(cx) else { + return div(); + }; + + let supermaven = supermaven.read(cx); + + let status = match supermaven { + Supermaven::Starting => SupermavenButtonStatus::Initializing, + Supermaven::FailedDownload { error } => { + SupermavenButtonStatus::Errored(error.to_string()) + } + Supermaven::Spawned(agent) => { + let account_status = agent.account_status.clone(); + match account_status { + AccountStatus::NeedsActivation { activate_url } => { + SupermavenButtonStatus::NeedsActivation(activate_url.clone()) + } + AccountStatus::Unknown => SupermavenButtonStatus::Initializing, + AccountStatus::Ready => SupermavenButtonStatus::Ready, + } + } + Supermaven::Error { error } => { + SupermavenButtonStatus::Errored(error.to_string()) + } + }; + + let icon = status.to_icon(); + let tooltip_text = status.to_tooltip(); + let this = cx.view().clone(); + + return div().child( + popover_menu("supermaven") + .menu(move |cx| match &status { + SupermavenButtonStatus::NeedsActivation(activate_url) => { + Some(ContextMenu::build(cx, |menu, _| { + let activate_url = activate_url.clone(); + menu.entry("Sign In", None, move |cx| { + cx.open_url(activate_url.as_str()) + }) + })) + } + SupermavenButtonStatus::Ready => Some( + this.update(cx, |this, cx| this.build_supermaven_context_menu(cx)), + ), + _ => None, + }) + .anchor(AnchorCorner::BottomRight) + .trigger( + IconButton::new("supermaven-icon", icon) + .tooltip(move |cx| Tooltip::text(tooltip_text.clone(), cx)), + ), + ); + } + } + } +} + +impl InlineCompletionButton { + pub fn new(fs: Arc, cx: &mut ViewContext) -> Self { + if let Some(copilot) = Copilot::global(cx) { + cx.observe(&copilot, |_, _, cx| cx.notify()).detach() + } + + cx.observe_global::(move |_, cx| cx.notify()) + .detach(); + + Self { + editor_subscription: None, + editor_enabled: None, + language: None, + file: None, + fs, + } + } + + pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext) -> View { + let fs = self.fs.clone(); + ContextMenu::build(cx, |menu, _| { + menu.entry("Sign In", None, initiate_sign_in).entry( + "Disable Copilot", + None, + move |cx| hide_copilot(fs.clone(), cx), + ) + }) + } + + pub fn build_language_settings_menu( + &self, + mut menu: ContextMenu, + cx: &mut WindowContext, + ) -> ContextMenu { + let fs = self.fs.clone(); + + if let Some(language) = self.language.clone() { + let fs = fs.clone(); + let language_enabled = language_settings::language_settings(Some(&language), None, cx) + .show_inline_completions; + + menu = menu.entry( + format!( + "{} Inline Completions for {}", + if language_enabled { "Hide" } else { "Show" }, + language.name() + ), + None, + move |cx| toggle_inline_completions_for_language(language.clone(), fs.clone(), cx), + ); + } + + let settings = AllLanguageSettings::get_global(cx); + + if let Some(file) = &self.file { + let path = file.path().clone(); + let path_enabled = settings.inline_completions_enabled_for_path(&path); + + menu = menu.entry( + format!( + "{} Inline Completions for This Path", + if path_enabled { "Hide" } else { "Show" } + ), + None, + move |cx| { + if let Some(workspace) = cx.window_handle().downcast::() { + if let Ok(workspace) = workspace.root_view(cx) { + let workspace = workspace.downgrade(); + cx.spawn(|cx| { + configure_disabled_globs( + workspace, + path_enabled.then_some(path.clone()), + cx, + ) + }) + .detach_and_log_err(cx); + } + } + }, + ); + } + + let globally_enabled = settings.inline_completions_enabled(None, None); + menu.entry( + if globally_enabled { + "Hide Inline Completions for All Files" + } else { + "Show Inline Completions for All Files" + }, + None, + move |cx| toggle_inline_completions_globally(fs.clone(), cx), + ) + } + + fn build_copilot_context_menu(&self, cx: &mut ViewContext) -> View { + ContextMenu::build(cx, |menu, cx| { + self.build_language_settings_menu(menu, cx) + .separator() + .link( + "Copilot Settings", + OpenBrowser { + url: COPILOT_SETTINGS_URL.to_string(), + } + .boxed_clone(), + ) + .action("Sign Out", copilot::SignOut.boxed_clone()) + }) + } + + fn build_supermaven_context_menu(&self, cx: &mut ViewContext) -> View { + ContextMenu::build(cx, |menu, cx| { + self.build_language_settings_menu(menu, cx).separator() + }) + } + + pub fn update_enabled(&mut self, editor: View, cx: &mut ViewContext) { + let editor = editor.read(cx); + let snapshot = editor.buffer().read(cx).snapshot(cx); + let suggestion_anchor = editor.selections.newest_anchor().start; + let language = snapshot.language_at(suggestion_anchor); + let file = snapshot.file_at(suggestion_anchor).cloned(); + self.editor_enabled = { + let file = file.as_ref(); + Some( + file.map(|file| !file.is_private()).unwrap_or(true) + && all_language_settings(file, cx).inline_completions_enabled( + language, + file.map(|file| file.path().as_ref()), + ), + ) + }; + self.language = language.cloned(); + self.file = file; + + cx.notify() + } +} + +impl StatusItemView for InlineCompletionButton { + fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext) { + if let Some(editor) = item.and_then(|item| item.act_as::(cx)) { + self.editor_subscription = Some(( + cx.observe(&editor, Self::update_enabled), + editor.entity_id().as_u64() as usize, + )); + self.update_enabled(editor, cx); + } else { + self.language = None; + self.editor_subscription = None; + self.editor_enabled = None; + } + cx.notify(); + } +} + +impl SupermavenButtonStatus { + fn to_icon(&self) -> IconName { + match self { + SupermavenButtonStatus::Ready => IconName::Supermaven, + SupermavenButtonStatus::Errored(_) => IconName::SupermavenError, + SupermavenButtonStatus::NeedsActivation(_) => IconName::SupermavenInit, + SupermavenButtonStatus::Initializing => IconName::SupermavenInit, + } + } + + fn to_tooltip(&self) -> String { + match self { + SupermavenButtonStatus::Ready => "Supermaven is ready".to_string(), + SupermavenButtonStatus::Errored(error) => format!("Supermaven error: {}", error), + SupermavenButtonStatus::NeedsActivation(_) => "Supermaven needs activation".to_string(), + SupermavenButtonStatus::Initializing => "Supermaven initializing".to_string(), + } + } +} + +async fn configure_disabled_globs( + workspace: WeakView, + path_to_disable: Option>, + mut cx: AsyncWindowContext, +) -> Result<()> { + let settings_editor = workspace + .update(&mut cx, |_, cx| { + create_and_open_local_file(&paths::SETTINGS, cx, || { + settings::initial_user_settings_content().as_ref().into() + }) + })? + .await? + .downcast::() + .unwrap(); + + settings_editor.downgrade().update(&mut cx, |item, cx| { + let text = item.buffer().read(cx).snapshot(cx).text(); + + let settings = cx.global::(); + let edits = settings.edits_for_update::(&text, |file| { + let copilot = file.inline_completions.get_or_insert_with(Default::default); + let globs = copilot.disabled_globs.get_or_insert_with(|| { + settings + .get::(None) + .inline_completions + .disabled_globs + .iter() + .map(|glob| glob.glob().to_string()) + .collect() + }); + + if let Some(path_to_disable) = &path_to_disable { + globs.push(path_to_disable.to_string_lossy().into_owned()); + } else { + globs.clear(); + } + }); + + if !edits.is_empty() { + item.change_selections(Some(Autoscroll::newest()), cx, |selections| { + selections.select_ranges(edits.iter().map(|e| e.0.clone())); + }); + + // When *enabling* a path, don't actually perform an edit, just select the range. + if path_to_disable.is_some() { + item.edit(edits.iter().cloned(), cx); + } + } + })?; + + anyhow::Ok(()) +} + +fn toggle_inline_completions_globally(fs: Arc, cx: &mut AppContext) { + let show_inline_completions = + all_language_settings(None, cx).inline_completions_enabled(None, None); + update_settings_file::(fs, cx, move |file| { + file.defaults.show_inline_completions = Some(!show_inline_completions) + }); +} + +fn toggle_inline_completions_for_language( + language: Arc, + fs: Arc, + cx: &mut AppContext, +) { + let show_inline_completions = + all_language_settings(None, cx).inline_completions_enabled(Some(&language), None); + update_settings_file::(fs, cx, move |file| { + file.languages + .entry(language.name()) + .or_default() + .show_inline_completions = Some(!show_inline_completions); + }); +} + +fn hide_copilot(fs: Arc, cx: &mut AppContext) { + update_settings_file::(fs, cx, move |file| { + file.features.get_or_insert(Default::default()).copilot = Some(false); + }); +} + +pub fn initiate_sign_in(cx: &mut WindowContext) { + let Some(copilot) = Copilot::global(cx) else { + return; + }; + let status = copilot.read(cx).status(); + let Some(workspace) = cx.window_handle().downcast::() else { + return; + }; + match status { + Status::Starting { task } => { + let Some(workspace) = cx.window_handle().downcast::() else { + return; + }; + + let Ok(workspace) = workspace.update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new( + NotificationId::unique::(), + "Copilot is starting...", + ), + cx, + ); + workspace.weak_handle() + }) else { + return; + }; + + cx.spawn(|mut cx| async move { + task.await; + if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() { + workspace + .update(&mut cx, |workspace, cx| match copilot.read(cx).status() { + Status::Authorized => workspace.show_toast( + Toast::new( + NotificationId::unique::(), + "Copilot has started!", + ), + cx, + ), + _ => { + workspace.dismiss_toast( + &NotificationId::unique::(), + cx, + ); + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .detach_and_log_err(cx); + } + }) + .log_err(); + } + }) + .detach(); + } + _ => { + copilot.update(cx, |this, cx| this.sign_in(cx)).detach(); + workspace + .update(cx, |this, cx| { + this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx)); + }) + .ok(); + } + } +} diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index bea5344be2..537816b983 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -51,8 +51,8 @@ pub fn all_language_settings<'a>( /// The settings for all languages. #[derive(Debug, Clone)] pub struct AllLanguageSettings { - /// The settings for GitHub Copilot. - pub copilot: CopilotSettings, + /// The inline completion settings. + pub inline_completions: InlineCompletionSettings, defaults: LanguageSettings, languages: HashMap, LanguageSettings>, pub(crate) file_types: HashMap, Vec>, @@ -101,9 +101,9 @@ pub struct LanguageSettings { /// - `"!"` - A language server ID prefixed with a `!` will be disabled. /// - `"..."` - A placeholder to refer to the **rest** of the registered language servers for this language. pub language_servers: Vec>, - /// Controls whether Copilot provides suggestion immediately (true) - /// or waits for a `copilot::Toggle` (false). - pub show_copilot_suggestions: bool, + /// Controls whether inline completions are shown immediately (true) + /// or manually by triggering `editor::ShowInlineCompletion` (false). + pub show_inline_completions: bool, /// Whether to show tabs and spaces in the editor. pub show_whitespaces: ShowWhitespaceSetting, /// Whether to start a new line with a comment when a previous line is a comment as well. @@ -165,12 +165,23 @@ impl LanguageSettings { } } -/// The settings for [GitHub Copilot](https://github.com/features/copilot). +/// The provider that supplies inline completions. +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum InlineCompletionProvider { + None, + #[default] + Copilot, + Supermaven, +} + +/// The settings for inline completions, such as [GitHub Copilot](https://github.com/features/copilot) +/// or [Supermaven](https://supermaven.com). #[derive(Clone, Debug, Default)] -pub struct CopilotSettings { - /// Whether Copilot is enabled. - pub feature_enabled: bool, - /// A list of globs representing files that Copilot should be disabled for. +pub struct InlineCompletionSettings { + /// The provider that supplies inline completions. + pub provider: InlineCompletionProvider, + /// A list of globs representing files that inline completions should be disabled for. pub disabled_globs: Vec, } @@ -180,9 +191,9 @@ pub struct AllLanguageSettingsContent { /// The settings for enabling/disabling features. #[serde(default)] pub features: Option, - /// The settings for GitHub Copilot. - #[serde(default)] - pub copilot: Option, + /// The inline completion settings. + #[serde(default, alias = "copilot")] + pub inline_completions: Option, /// The default language settings. #[serde(flatten)] pub defaults: LanguageSettingsContent, @@ -277,12 +288,12 @@ pub struct LanguageSettingsContent { /// Default: ["..."] #[serde(default)] pub language_servers: Option>>, - /// Controls whether Copilot provides suggestion immediately (true) - /// or waits for a `copilot::Toggle` (false). + /// Controls whether inline completions are shown immediately (true) + /// or manually by triggering `editor::ShowInlineCompletion` (false). /// /// Default: true - #[serde(default)] - pub show_copilot_suggestions: Option, + #[serde(default, alias = "show_copilot_suggestions")] + pub show_inline_completions: Option, /// Whether to show tabs and spaces in the editor. #[serde(default)] pub show_whitespaces: Option, @@ -314,10 +325,10 @@ pub struct LanguageSettingsContent { pub code_actions_on_format: Option>, } -/// The contents of the GitHub Copilot settings. -#[derive(Clone, Debug, PartialEq, Default, Serialize, Deserialize, JsonSchema)] -pub struct CopilotSettingsContent { - /// A list of globs representing files that Copilot should be disabled for. +/// The contents of the inline completion settings. +#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)] +pub struct InlineCompletionSettingsContent { + /// A list of globs representing files that inline completions should be disabled for. #[serde(default)] pub disabled_globs: Option>, } @@ -328,6 +339,8 @@ pub struct CopilotSettingsContent { pub struct FeaturesContent { /// Whether the GitHub Copilot feature is enabled. pub copilot: Option, + /// Determines which inline completion provider to use. + pub inline_completion_provider: Option, } /// Controls the soft-wrapping behavior in the editor. @@ -475,29 +488,29 @@ impl AllLanguageSettings { &self.defaults } - /// Returns whether GitHub Copilot is enabled for the given path. - pub fn copilot_enabled_for_path(&self, path: &Path) -> bool { + /// Returns whether inline completions are enabled for the given path. + pub fn inline_completions_enabled_for_path(&self, path: &Path) -> bool { !self - .copilot + .inline_completions .disabled_globs .iter() .any(|glob| glob.is_match(path)) } - /// Returns whether GitHub Copilot is enabled for the given language and path. - pub fn copilot_enabled(&self, language: Option<&Arc>, path: Option<&Path>) -> bool { - if !self.copilot.feature_enabled { - return false; - } - + /// Returns whether inline completions are enabled for the given language and path. + pub fn inline_completions_enabled( + &self, + language: Option<&Arc>, + path: Option<&Path>, + ) -> bool { if let Some(path) = path { - if !self.copilot_enabled_for_path(path) { + if !self.inline_completions_enabled_for_path(path) { return false; } } self.language(language.map(|l| l.name()).as_deref()) - .show_copilot_suggestions + .show_inline_completions } } @@ -551,13 +564,13 @@ impl settings::Settings for AllLanguageSettings { languages.insert(language_name.clone(), language_settings); } - let mut copilot_enabled = default_value + let mut copilot_enabled = default_value.features.as_ref().and_then(|f| f.copilot); + let mut inline_completion_provider = default_value .features .as_ref() - .and_then(|f| f.copilot) - .ok_or_else(Self::missing_default)?; - let mut copilot_globs = default_value - .copilot + .and_then(|f| f.inline_completion_provider); + let mut completion_globs = default_value + .inline_completions .as_ref() .and_then(|c| c.disabled_globs.as_ref()) .ok_or_else(Self::missing_default)?; @@ -565,14 +578,21 @@ impl settings::Settings for AllLanguageSettings { let mut file_types: HashMap, Vec> = HashMap::default(); for user_settings in sources.customizations() { if let Some(copilot) = user_settings.features.as_ref().and_then(|f| f.copilot) { - copilot_enabled = copilot; + copilot_enabled = Some(copilot); + } + if let Some(provider) = user_settings + .features + .as_ref() + .and_then(|f| f.inline_completion_provider) + { + inline_completion_provider = Some(provider); } if let Some(globs) = user_settings - .copilot + .inline_completions .as_ref() .and_then(|f| f.disabled_globs.as_ref()) { - copilot_globs = globs; + completion_globs = globs; } // A user's global settings override the default global settings and @@ -601,9 +621,15 @@ impl settings::Settings for AllLanguageSettings { } Ok(Self { - copilot: CopilotSettings { - feature_enabled: copilot_enabled, - disabled_globs: copilot_globs + inline_completions: InlineCompletionSettings { + provider: if let Some(provider) = inline_completion_provider { + provider + } else if copilot_enabled.unwrap_or(true) { + InlineCompletionProvider::Copilot + } else { + InlineCompletionProvider::None + }, + disabled_globs: completion_globs .iter() .filter_map(|g| Some(globset::Glob::new(g).ok()?.compile_matcher())) .collect(), @@ -714,8 +740,8 @@ fn merge_settings(settings: &mut LanguageSettings, src: &LanguageSettingsContent ); merge(&mut settings.language_servers, src.language_servers.clone()); merge( - &mut settings.show_copilot_suggestions, - src.show_copilot_suggestions, + &mut settings.show_inline_completions, + src.show_inline_completions, ); merge(&mut settings.show_whitespaces, src.show_whitespaces); merge( diff --git a/crates/language_tools/Cargo.toml b/crates/language_tools/Cargo.toml index 6d0a1199b3..d85f5a6e52 100644 --- a/crates/language_tools/Cargo.toml +++ b/crates/language_tools/Cargo.toml @@ -15,6 +15,7 @@ doctest = false [dependencies] anyhow.workspace = true collections.workspace = true +copilot.workspace = true editor.workspace = true futures.workspace = true gpui.workspace = true @@ -26,7 +27,6 @@ settings.workspace = true theme.workspace = true tree-sitter.workspace = true ui.workspace = true -util.workspace = true workspace.workspace = true [dev-dependencies] diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index a35d8b33e5..28a27aac60 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -1,4 +1,5 @@ use collections::{HashMap, VecDeque}; +use copilot::Copilot; use editor::{actions::MoveToEnd, Editor, EditorEvent}; use futures::{channel::mpsc, StreamExt}; use gpui::{ @@ -7,11 +8,10 @@ use gpui::{ View, ViewContext, VisualContext, WeakModel, WindowContext, }; use language::{LanguageServerId, LanguageServerName}; -use lsp::IoKind; +use lsp::{IoKind, LanguageServer}; use project::{search::SearchQuery, Project}; use std::{borrow::Cow, sync::Arc}; use ui::{popover_menu, prelude::*, Button, Checkbox, ContextMenu, Label, Selection}; -use util::maybe; use workspace::{ item::{Item, ItemHandle, TabContentParams}, searchable::{SearchEvent, SearchableItem, SearchableItemHandle}, @@ -24,17 +24,21 @@ const MAX_STORED_LOG_ENTRIES: usize = 2000; pub struct LogStore { projects: HashMap, ProjectState>, - io_tx: mpsc::UnboundedSender<(WeakModel, LanguageServerId, IoKind, String)>, + language_servers: HashMap, + copilot_log_subscription: Option, + _copilot_subscription: Option, + io_tx: mpsc::UnboundedSender<(LanguageServerId, IoKind, String)>, } struct ProjectState { - servers: HashMap, _subscriptions: [gpui::Subscription; 2], } struct LanguageServerState { + name: LanguageServerName, log_messages: VecDeque, rpc_state: Option, + project: Option>, _io_logs_subscription: Option, _lsp_logs_subscription: Option, } @@ -109,15 +113,55 @@ pub fn init(cx: &mut AppContext) { impl LogStore { pub fn new(cx: &mut ModelContext) -> Self { let (io_tx, mut io_rx) = mpsc::unbounded(); + + let copilot_subscription = Copilot::global(cx).map(|copilot| { + let copilot = &copilot; + cx.subscribe( + copilot, + |this, copilot, copilot_event, cx| match copilot_event { + copilot::Event::CopilotLanguageServerStarted => { + if let Some(server) = copilot.read(cx).language_server() { + let server_id = server.server_id(); + let weak_this = cx.weak_model(); + this.copilot_log_subscription = + Some(server.on_notification::( + move |params, mut cx| { + weak_this + .update(&mut cx, |this, cx| { + this.add_language_server_log( + server_id, + ¶ms.message, + cx, + ); + }) + .ok(); + }, + )); + this.add_language_server( + None, + LanguageServerName(Arc::from("copilot")), + server.clone(), + cx, + ); + } + } + }, + ) + }); + let this = Self { + copilot_log_subscription: None, + _copilot_subscription: copilot_subscription, projects: HashMap::default(), + language_servers: HashMap::default(), io_tx, }; + cx.spawn(|this, mut cx| async move { - while let Some((project, server_id, io_kind, message)) = io_rx.next().await { + while let Some((server_id, io_kind, message)) = io_rx.next().await { if let Some(this) = this.upgrade() { this.update(&mut cx, |this, cx| { - this.on_io(project, server_id, io_kind, &message, cx); + this.on_io(server_id, io_kind, &message, cx); })?; } } @@ -132,20 +176,32 @@ impl LogStore { self.projects.insert( project.downgrade(), ProjectState { - servers: HashMap::default(), _subscriptions: [ cx.observe_release(project, move |this, _, _| { this.projects.remove(&weak_project); + this.language_servers + .retain(|_, state| state.project.as_ref() != Some(&weak_project)); }), cx.subscribe(project, |this, project, event, cx| match event { project::Event::LanguageServerAdded(id) => { - this.add_language_server(&project, *id, cx); + let read_project = project.read(cx); + if let Some((server, adapter)) = read_project + .language_server_for_id(*id) + .zip(read_project.language_server_adapter_for_id(*id)) + { + this.add_language_server( + Some(&project.downgrade()), + adapter.name.clone(), + server, + cx, + ); + } } project::Event::LanguageServerRemoved(id) => { - this.remove_language_server(&project, *id, cx); + this.remove_language_server(*id, cx); } project::Event::LanguageServerLog(id, message) => { - this.add_language_server_log(&project, *id, message, cx); + this.add_language_server_log(*id, message, cx); } _ => {} }), @@ -154,74 +210,69 @@ impl LogStore { ); } + fn get_language_server_state( + &mut self, + id: LanguageServerId, + ) -> Option<&mut LanguageServerState> { + self.language_servers.get_mut(&id) + } + fn add_language_server( &mut self, - project: &Model, - id: LanguageServerId, + project: Option<&WeakModel>, + name: LanguageServerName, + server: Arc, cx: &mut ModelContext, ) -> Option<&mut LanguageServerState> { - let project_state = self.projects.get_mut(&project.downgrade())?; - let server_state = project_state.servers.entry(id).or_insert_with(|| { - cx.notify(); - LanguageServerState { - rpc_state: None, - log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), - _io_logs_subscription: None, - _lsp_logs_subscription: None, - } - }); + let server_state = self + .language_servers + .entry(server.server_id()) + .or_insert_with(|| { + cx.notify(); + LanguageServerState { + name, + rpc_state: None, + project: project.cloned(), + log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), + _io_logs_subscription: None, + _lsp_logs_subscription: None, + } + }); - let server = project.read(cx).language_server_for_id(id); - if let Some(server) = server.as_deref() { - if server.has_notification_handler::() { - // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. - return Some(server_state); - } + if server.has_notification_handler::() { + // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. + return Some(server_state); } - let weak_project = project.downgrade(); let io_tx = self.io_tx.clone(); - server_state._io_logs_subscription = server.as_ref().map(|server| { - server.on_io(move |io_kind, message| { - io_tx - .unbounded_send((weak_project.clone(), id, io_kind, message.to_string())) - .ok(); - }) - }); + let server_id = server.server_id(); + server_state._io_logs_subscription = Some(server.on_io(move |io_kind, message| { + io_tx + .unbounded_send((server_id, io_kind, message.to_string())) + .ok(); + })); let this = cx.handle().downgrade(); - let weak_project = project.downgrade(); - server_state._lsp_logs_subscription = server.map(|server| { - let server_id = server.server_id(); - server.on_notification::({ + server_state._lsp_logs_subscription = + Some(server.on_notification::({ move |params, mut cx| { - if let Some((project, this)) = weak_project.upgrade().zip(this.upgrade()) { + if let Some(this) = this.upgrade() { this.update(&mut cx, |this, cx| { - this.add_language_server_log(&project, server_id, ¶ms.message, cx); + this.add_language_server_log(server_id, ¶ms.message, cx); }) .ok(); } } - }) - }); + })); Some(server_state) } fn add_language_server_log( &mut self, - project: &Model, id: LanguageServerId, message: &str, cx: &mut ModelContext, ) -> Option<()> { - let language_server_state = match self - .projects - .get_mut(&project.downgrade())? - .servers - .get_mut(&id) - { - Some(existing_state) => existing_state, - None => self.add_language_server(&project, id, cx)?, - }; + let language_server_state = self.get_language_server_state(id)?; let log_lines = &mut language_server_state.log_messages; while log_lines.len() >= MAX_STORED_LOG_ENTRIES { @@ -238,38 +289,43 @@ impl LogStore { Some(()) } - fn remove_language_server( - &mut self, - project: &Model, - id: LanguageServerId, - cx: &mut ModelContext, - ) -> Option<()> { - let project_state = self.projects.get_mut(&project.downgrade())?; - project_state.servers.remove(&id); + fn remove_language_server(&mut self, id: LanguageServerId, cx: &mut ModelContext) { + self.language_servers.remove(&id); cx.notify(); - Some(()) } - fn server_logs( - &self, - project: &Model, - server_id: LanguageServerId, - ) -> Option<&VecDeque> { - let weak_project = project.downgrade(); - let project_state = self.projects.get(&weak_project)?; - let server_state = project_state.servers.get(&server_id)?; - Some(&server_state.log_messages) + fn server_logs(&self, server_id: LanguageServerId) -> Option<&VecDeque> { + Some(&self.language_servers.get(&server_id)?.log_messages) + } + + fn server_ids_for_project<'a>( + &'a self, + project: &'a WeakModel, + ) -> impl Iterator + 'a { + [].into_iter() + .chain(self.language_servers.iter().filter_map(|(id, state)| { + if state.project.as_ref() == Some(project) { + return Some(*id); + } else { + None + } + })) + .chain(self.language_servers.iter().filter_map(|(id, state)| { + if state.project.is_none() { + return Some(*id); + } else { + None + } + })) } fn enable_rpc_trace_for_language_server( &mut self, - project: &Model, server_id: LanguageServerId, ) -> Option<&mut LanguageServerRpcState> { - let weak_project = project.downgrade(); - let project_state = self.projects.get_mut(&weak_project)?; - let server_state = project_state.servers.get_mut(&server_id)?; - let rpc_state = server_state + let rpc_state = self + .language_servers + .get_mut(&server_id)? .rpc_state .get_or_insert_with(|| LanguageServerRpcState { rpc_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), @@ -280,20 +336,14 @@ impl LogStore { pub fn disable_rpc_trace_for_language_server( &mut self, - project: &Model, server_id: LanguageServerId, - _: &mut ModelContext, ) -> Option<()> { - let project = project.downgrade(); - let project_state = self.projects.get_mut(&project)?; - let server_state = project_state.servers.get_mut(&server_id)?; - server_state.rpc_state.take(); + self.language_servers.get_mut(&server_id)?.rpc_state.take(); Some(()) } fn on_io( &mut self, - project: WeakModel, language_server_id: LanguageServerId, io_kind: IoKind, message: &str, @@ -303,18 +353,14 @@ impl LogStore { IoKind::StdOut => true, IoKind::StdIn => false, IoKind::StdErr => { - let project = project.upgrade()?; let message = format!("stderr: {}", message.trim()); - self.add_language_server_log(&project, language_server_id, &message, cx); + self.add_language_server_log(language_server_id, &message, cx); return Some(()); } }; let state = self - .projects - .get_mut(&project)? - .servers - .get_mut(&language_server_id)? + .get_language_server_state(language_server_id)? .rpc_state .as_mut()?; let kind = if is_received { @@ -360,42 +406,40 @@ impl LspLogView { ) -> Self { let server_id = log_store .read(cx) - .projects - .get(&project.downgrade()) - .and_then(|project| project.servers.keys().copied().next()); - let model_changes_subscription = cx.observe(&log_store, |this, store, cx| { - maybe!({ - let project_state = store.read(cx).projects.get(&this.project.downgrade())?; - if let Some(current_lsp) = this.current_server_id { - if !project_state.servers.contains_key(¤t_lsp) { - if let Some(server) = project_state.servers.iter().next() { - if this.is_showing_rpc_trace { - this.show_rpc_trace_for_server(*server.0, cx) - } else { - this.show_logs_for_server(*server.0, cx) - } - } else { - this.current_server_id = None; - this.editor.update(cx, |editor, cx| { - editor.set_read_only(false); - editor.clear(cx); - editor.set_read_only(true); - }); - cx.notify(); - } - } - } else { - if let Some(server) = project_state.servers.iter().next() { + .language_servers + .iter() + .find(|(_, server)| server.project == Some(project.downgrade())) + .map(|(id, _)| *id); + + let weak_project = project.downgrade(); + let model_changes_subscription = cx.observe(&log_store, move |this, store, cx| { + let first_server_id_for_project = + store.read(cx).server_ids_for_project(&weak_project).next(); + if let Some(current_lsp) = this.current_server_id { + if !store.read(cx).language_servers.contains_key(¤t_lsp) { + if let Some(server_id) = first_server_id_for_project { if this.is_showing_rpc_trace { - this.show_rpc_trace_for_server(*server.0, cx) + this.show_rpc_trace_for_server(server_id, cx) } else { - this.show_logs_for_server(*server.0, cx) + this.show_logs_for_server(server_id, cx) } + } else { + this.current_server_id = None; + this.editor.update(cx, |editor, cx| { + editor.set_read_only(false); + editor.clear(cx); + editor.set_read_only(true); + }); + cx.notify(); } } - - Some(()) - }); + } else if let Some(server_id) = first_server_id_for_project { + if this.is_showing_rpc_trace { + this.show_rpc_trace_for_server(server_id, cx) + } else { + this.show_logs_for_server(server_id, cx) + } + } cx.notify(); }); @@ -477,14 +521,14 @@ impl LspLogView { pub(crate) fn menu_items<'a>(&'a self, cx: &'a AppContext) -> Option> { let log_store = self.log_store.read(cx); - let state = log_store.projects.get(&self.project.downgrade())?; + let mut rows = self .project .read(cx) .language_servers() .filter_map(|(server_id, language_server_name, worktree_id)| { let worktree = self.project.read(cx).worktree_for_id(worktree_id, cx)?; - let state = state.servers.get(&server_id)?; + let state = log_store.language_servers.get(&server_id)?; Some(LogMenuItem { server_id, server_name: language_server_name, @@ -501,7 +545,7 @@ impl LspLogView { .read(cx) .supplementary_language_servers() .filter_map(|(&server_id, (name, _))| { - let state = state.servers.get(&server_id)?; + let state = log_store.language_servers.get(&server_id)?; Some(LogMenuItem { server_id, server_name: name.clone(), @@ -514,6 +558,27 @@ impl LspLogView { }) }), ) + .chain( + log_store + .language_servers + .iter() + .filter_map(|(server_id, state)| { + if state.project.is_none() { + Some(LogMenuItem { + server_id: *server_id, + server_name: state.name.clone(), + worktree_root_name: "supplementary".to_string(), + rpc_trace_enabled: state.rpc_state.is_some(), + rpc_trace_selected: self.is_showing_rpc_trace + && self.current_server_id == Some(*server_id), + logs_selected: !self.is_showing_rpc_trace + && self.current_server_id == Some(*server_id), + }) + } else { + None + } + }), + ) .collect::>(); rows.sort_by_key(|row| row.server_id); rows.dedup_by_key(|row| row.server_id); @@ -524,7 +589,7 @@ impl LspLogView { let log_contents = self .log_store .read(cx) - .server_logs(&self.project, server_id) + .server_logs(server_id) .map(log_contents); if let Some(log_contents) = log_contents { self.current_server_id = Some(server_id); @@ -544,7 +609,7 @@ impl LspLogView { ) { let rpc_log = self.log_store.update(cx, |log_store, _| { log_store - .enable_rpc_trace_for_language_server(&self.project, server_id) + .enable_rpc_trace_for_language_server(server_id) .map(|state| log_contents(&state.rpc_messages)) }); if let Some(rpc_log) = rpc_log { @@ -585,11 +650,11 @@ impl LspLogView { enabled: bool, cx: &mut ViewContext, ) { - self.log_store.update(cx, |log_store, cx| { + self.log_store.update(cx, |log_store, _| { if enabled { - log_store.enable_rpc_trace_for_language_server(&self.project, server_id); + log_store.enable_rpc_trace_for_language_server(server_id); } else { - log_store.disable_rpc_trace_for_language_server(&self.project, server_id, cx); + log_store.disable_rpc_trace_for_language_server(server_id); } }); if !enabled && Some(server_id) == self.current_server_id { diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 30766a7b6f..1d943bc080 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -30,7 +30,6 @@ async-trait.workspace = true client.workspace = true clock.workspace = true collections.workspace = true -copilot.workspace = true fs.workspace = true futures.workspace = true fuzzy.workspace = true diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 733a06172b..28c6182016 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -20,7 +20,6 @@ use client::{ }; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque}; -use copilot::Copilot; use debounced_delay::DebouncedDelay; use futures::{ channel::{ @@ -200,8 +199,6 @@ pub struct Project { _maintain_buffer_languages: Task<()>, _maintain_workspace_config: Task>, terminals: Terminals, - copilot_lsp_subscription: Option, - copilot_log_subscription: Option, current_lsp_settings: HashMap, LspSettings>, node: Option>, default_prettier: DefaultPrettier, @@ -685,8 +682,6 @@ impl Project { let (tx, rx) = mpsc::unbounded(); cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx)) .detach(); - let copilot_lsp_subscription = - Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx)); let tasks = Inventory::new(cx); Self { @@ -735,8 +730,6 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, - copilot_lsp_subscription, - copilot_log_subscription: None, current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(), node: Some(node), default_prettier: DefaultPrettier::default(), @@ -823,8 +816,6 @@ impl Project { let (tx, rx) = mpsc::unbounded(); cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx)) .detach(); - let copilot_lsp_subscription = - Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx)); let mut this = Self { worktrees: Vec::new(), buffer_ordered_messages_tx: tx, @@ -891,8 +882,6 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, - copilot_lsp_subscription, - copilot_log_subscription: None, current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(), node: None, default_prettier: DefaultPrettier::default(), @@ -1184,17 +1173,6 @@ impl Project { self.restart_language_servers(worktree, language, cx); } - if self.copilot_lsp_subscription.is_none() { - if let Some(copilot) = Copilot::global(cx) { - for buffer in self.opened_buffers.values() { - if let Some(buffer) = buffer.upgrade() { - self.register_buffer_with_copilot(&buffer, cx); - } - } - self.copilot_lsp_subscription = Some(subscribe_for_copilot_events(&copilot, cx)); - } - } - cx.notify(); } @@ -2351,7 +2329,7 @@ impl Project { self.detect_language_for_buffer(buffer, cx); self.register_buffer_with_language_servers(buffer, cx); - self.register_buffer_with_copilot(buffer, cx); + // self.register_buffer_with_copilot(buffer, cx); cx.observe_release(buffer, |this, buffer, cx| { if let Some(file) = File::from_dyn(buffer.file()) { if file.is_local() { @@ -2500,15 +2478,15 @@ impl Project { }); } - fn register_buffer_with_copilot( - &self, - buffer_handle: &Model, - cx: &mut ModelContext, - ) { - if let Some(copilot) = Copilot::global(cx) { - copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); - } - } + // fn register_buffer_with_copilot( + // &self, + // buffer_handle: &Model, + // cx: &mut ModelContext, + // ) { + // if let Some(copilot) = Copilot::global(cx) { + // copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); + // } + // } async fn send_buffer_ordered_messages( this: WeakModel, @@ -10475,43 +10453,6 @@ async fn search_ignored_entry( } } -fn subscribe_for_copilot_events( - copilot: &Model, - cx: &mut ModelContext<'_, Project>, -) -> gpui::Subscription { - cx.subscribe( - copilot, - |project, copilot, copilot_event, cx| match copilot_event { - copilot::Event::CopilotLanguageServerStarted => { - match copilot.read(cx).language_server() { - Some((name, copilot_server)) => { - // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. - if !copilot_server.has_notification_handler::() { - let new_server_id = copilot_server.server_id(); - let weak_project = cx.weak_model(); - let copilot_log_subscription = copilot_server - .on_notification::( - move |params, mut cx| { - weak_project.update(&mut cx, |_, cx| { - cx.emit(Event::LanguageServerLog( - new_server_id, - params.message, - )); - }).ok(); - }, - ); - project.supplementary_language_servers.insert(new_server_id, (name.clone(), Arc::clone(copilot_server))); - project.copilot_log_subscription = Some(copilot_log_subscription); - cx.emit(Event::LanguageServerAdded(new_server_id)); - } - } - None => debug_panic!("Received Copilot language server started event, but no language server is running"), - } - } - }, - ) -} - fn glob_literal_prefix(glob: &str) -> &str { let mut literal_end = 0; for (i, part) in glob.split(path::MAIN_SEPARATOR).enumerate() { diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 3dfa9508dc..5f8af8e1f0 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -207,7 +207,7 @@ message Envelope { GetCachedEmbeddings get_cached_embeddings = 189; GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; ComputeEmbeddings compute_embeddings = 191; - ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max + ComputeEmbeddingsResponse compute_embeddings_response = 192; UpdateChannelMessage update_channel_message = 170; ChannelMessageUpdate channel_message_update = 171; @@ -238,7 +238,10 @@ message Envelope { ValidateDevServerProjectRequest validate_dev_server_project_request = 194; DeleteDevServer delete_dev_server = 195; OpenNewBuffer open_new_buffer = 196; - DeleteDevServerProject delete_dev_server_project = 197; // Current max + DeleteDevServerProject delete_dev_server_project = 197; + + GetSupermavenApiKey get_supermaven_api_key = 198; + GetSupermavenApiKeyResponse get_supermaven_api_key_response = 199; // current max } reserved 158 to 161; @@ -2084,3 +2087,9 @@ message LspResponse { GetCodeActionsResponse get_code_actions_response = 2; } } + +message GetSupermavenApiKey {} + +message GetSupermavenApiKeyResponse { + string api_key = 1; +} diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 966a24ead9..d011f1d1d2 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -201,6 +201,8 @@ messages!( (GetProjectSymbolsResponse, Background), (GetReferences, Background), (GetReferencesResponse, Background), + (GetSupermavenApiKey, Background), + (GetSupermavenApiKeyResponse, Background), (GetTypeDefinition, Background), (GetTypeDefinitionResponse, Background), (GetImplementation, Background), @@ -360,6 +362,7 @@ request_messages!( (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), (GetReferences, GetReferencesResponse), + (GetSupermavenApiKey, GetSupermavenApiKeyResponse), (GetTypeDefinition, GetTypeDefinitionResponse), (GetUsers, UsersResponse), (IncomingCall, Ack), diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml new file mode 100644 index 0000000000..4abbcd4a43 --- /dev/null +++ b/crates/supermaven/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "supermaven" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/supermaven.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +client.workspace = true +collections.workspace = true +editor.workspace = true +gpui.workspace = true +futures.workspace = true +language.workspace = true +log.workspace = true +postage.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +supermaven_api.workspace = true +smol.workspace = true +ui.workspace = true +util.workspace = true + +[dev-dependencies] +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } diff --git a/crates/supermaven/src/messages.rs b/crates/supermaven/src/messages.rs new file mode 100644 index 0000000000..9082e00d60 --- /dev/null +++ b/crates/supermaven/src/messages.rs @@ -0,0 +1,152 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SetApiKey { + pub api_key: String, +} + +// Outbound messages +#[derive(Debug, Serialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum OutboundMessage { + SetApiKey(SetApiKey), + StateUpdate(StateUpdateMessage), + #[allow(dead_code)] + UseFreeVersion, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct StateUpdateMessage { + pub new_id: String, + pub updates: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum StateUpdate { + FileUpdate(FileUpdateMessage), + CursorUpdate(CursorPositionUpdateMessage), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct FileUpdateMessage { + pub path: String, + pub content: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct CursorPositionUpdateMessage { + pub path: String, + pub offset: usize, +} + +// Inbound messages coming in on stdout + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum ResponseItem { + // A completion + Text { text: String }, + // Vestigial message type from old versions -- safe to ignore + Del { text: String }, + // Be able to delete whitespace prior to the cursor, likely for the rest of the completion + Dedent { text: String }, + // When the completion is over + End, + // Got the closing parentheses and shouldn't show any more after + Barrier, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenResponse { + pub state_id: String, + pub items: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenMetadataMessage { + pub dust_strings: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenTaskUpdateMessage { + pub task: String, + pub status: TaskStatus, + pub percent_complete: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TaskStatus { + InProgress, + Complete, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenActiveRepoMessage { + pub repo_simple_name: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum SupermavenPopupAction { + OpenUrl { label: String, url: String }, + NoOp { label: String }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct SupermavenPopupMessage { + pub message: String, + pub actions: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "camelCase")] +pub struct ActivationRequest { + pub activate_url: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenSetMessage { + pub key: String, + pub value: serde_json::Value, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum ServiceTier { + FreeNoLicense, + #[serde(other)] + Unknown, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum SupermavenMessage { + Response(SupermavenResponse), + Metadata(SupermavenMetadataMessage), + Apology { + message: Option, + }, + ActivationRequest(ActivationRequest), + ActivationSuccess, + Passthrough { + passthrough: Box, + }, + Popup(SupermavenPopupMessage), + TaskStatus(SupermavenTaskUpdateMessage), + ActiveRepo(SupermavenActiveRepoMessage), + ServiceTier { + service_tier: ServiceTier, + }, + + Set(SupermavenSetMessage), + #[serde(other)] + Unknown, +} diff --git a/crates/supermaven/src/supermaven.rs b/crates/supermaven/src/supermaven.rs new file mode 100644 index 0000000000..c432116357 --- /dev/null +++ b/crates/supermaven/src/supermaven.rs @@ -0,0 +1,345 @@ +mod messages; +mod supermaven_completion_provider; + +pub use supermaven_completion_provider::*; + +use anyhow::{Context as _, Result}; +#[allow(unused_imports)] +use client::{proto, Client}; +use collections::BTreeMap; + +use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt}; +use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel}; +use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset}; +use messages::*; +use postage::watch; +use serde::{Deserialize, Serialize}; +use settings::SettingsStore; +use smol::{ + io::AsyncWriteExt, + process::{Child, ChildStdin, ChildStdout, Command}, +}; +use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc}; +use ui::prelude::*; +use util::ResultExt; + +pub fn init(client: Arc, cx: &mut AppContext) { + let supermaven = cx.new_model(|_| Supermaven::Starting); + Supermaven::set_global(supermaven.clone(), cx); + + let mut provider = all_language_settings(None, cx).inline_completions.provider; + if provider == language::language_settings::InlineCompletionProvider::Supermaven { + supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx)); + } + + cx.observe_global::(move |cx| { + let new_provider = all_language_settings(None, cx).inline_completions.provider; + if new_provider != provider { + provider = new_provider; + if provider == language::language_settings::InlineCompletionProvider::Supermaven { + supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx)); + } else { + supermaven.update(cx, |supermaven, _cx| supermaven.stop()); + } + } + }) + .detach(); +} + +pub enum Supermaven { + Starting, + FailedDownload { error: anyhow::Error }, + Spawned(SupermavenAgent), + Error { error: anyhow::Error }, +} + +#[derive(Clone)] +pub enum AccountStatus { + Unknown, + NeedsActivation { activate_url: String }, + Ready, +} + +#[derive(Clone)] +struct SupermavenGlobal(Model); + +impl Global for SupermavenGlobal {} + +impl Supermaven { + pub fn global(cx: &AppContext) -> Option> { + cx.try_global::() + .map(|model| model.0.clone()) + } + + pub fn set_global(supermaven: Model, cx: &mut AppContext) { + cx.set_global(SupermavenGlobal(supermaven)); + } + + pub fn start(&mut self, client: Arc, cx: &mut ModelContext) { + if let Self::Starting = self { + cx.spawn(|this, mut cx| async move { + let binary_path = + supermaven_api::get_supermaven_agent_path(client.http_client()).await?; + + this.update(&mut cx, |this, cx| { + if let Self::Starting = this { + *this = + Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?); + } + anyhow::Ok(()) + }) + }) + .detach_and_log_err(cx) + } + } + + pub fn stop(&mut self) { + *self = Self::Starting; + } + + pub fn is_enabled(&self) -> bool { + matches!(self, Self::Spawned { .. }) + } + + pub fn complete( + &mut self, + buffer: &Model, + cursor_position: Anchor, + cx: &AppContext, + ) -> Option { + if let Self::Spawned(agent) = self { + let buffer_id = buffer.entity_id(); + let buffer = buffer.read(cx); + let path = buffer + .file() + .and_then(|file| Some(file.as_local()?.abs_path(cx))) + .unwrap_or_else(|| PathBuf::from("untitled")) + .to_string_lossy() + .to_string(); + let content = buffer.text(); + let offset = cursor_position.to_offset(buffer); + let state_id = agent.next_state_id; + agent.next_state_id.0 += 1; + + let (updates_tx, mut updates_rx) = watch::channel(); + postage::stream::Stream::try_recv(&mut updates_rx).unwrap(); + + agent.states.insert( + state_id, + SupermavenCompletionState { + buffer_id, + range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer), + completion: Vec::new(), + text: String::new(), + updates_tx, + }, + ); + let _ = agent + .outgoing_tx + .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage { + new_id: state_id.0.to_string(), + updates: vec![ + StateUpdate::FileUpdate(FileUpdateMessage { + path: path.clone(), + content, + }), + StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }), + ], + })); + + Some(SupermavenCompletion { + id: state_id, + updates: updates_rx, + }) + } else { + None + } + } + + pub fn completion( + &self, + id: SupermavenCompletionStateId, + ) -> Option<&SupermavenCompletionState> { + if let Self::Spawned(agent) = self { + agent.states.get(&id) + } else { + None + } + } +} + +pub struct SupermavenAgent { + _process: Child, + next_state_id: SupermavenCompletionStateId, + states: BTreeMap, + outgoing_tx: mpsc::UnboundedSender, + _handle_outgoing_messages: Task>, + _handle_incoming_messages: Task>, + pub account_status: AccountStatus, + service_tier: Option, + #[allow(dead_code)] + client: Arc, +} + +impl SupermavenAgent { + fn new( + binary_path: PathBuf, + client: Arc, + cx: &mut ModelContext, + ) -> Result { + let mut process = Command::new(&binary_path) + .arg("stdio") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .kill_on_drop(true) + .spawn() + .context("failed to start the binary")?; + + let stdin = process + .stdin + .take() + .context("failed to get stdin for process")?; + let stdout = process + .stdout + .take() + .context("failed to get stdout for process")?; + + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + + cx.spawn({ + let client = client.clone(); + let outgoing_tx = outgoing_tx.clone(); + move |this, mut cx| async move { + let mut status = client.status(); + while let Some(status) = status.next().await { + if status.is_connected() { + let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key; + outgoing_tx + .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key })) + .ok(); + this.update(&mut cx, |this, cx| { + if let Supermaven::Spawned(this) = this { + this.account_status = AccountStatus::Ready; + cx.notify(); + } + })?; + break; + } + } + return anyhow::Ok(()); + } + }) + .detach(); + + Ok(Self { + _process: process, + next_state_id: SupermavenCompletionStateId::default(), + states: BTreeMap::default(), + outgoing_tx, + _handle_outgoing_messages: cx + .spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)), + _handle_incoming_messages: cx + .spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)), + account_status: AccountStatus::Unknown, + service_tier: None, + client, + }) + } + + async fn handle_outgoing_messages( + mut outgoing: mpsc::UnboundedReceiver, + mut stdin: ChildStdin, + ) -> Result<()> { + while let Some(message) = outgoing.next().await { + let bytes = serde_json::to_vec(&message)?; + stdin.write_all(&bytes).await?; + stdin.write_all(&[b'\n']).await?; + } + Ok(()) + } + + async fn handle_incoming_messages( + this: WeakModel, + stdout: ChildStdout, + mut cx: AsyncAppContext, + ) -> Result<()> { + const MESSAGE_PREFIX: &str = "SM-MESSAGE "; + + let stdout = BufReader::new(stdout); + let mut lines = stdout.lines(); + while let Some(line) = lines.next().await { + let Some(line) = line.context("failed to read line from stdout").log_err() else { + continue; + }; + let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else { + continue; + }; + let Some(message) = serde_json::from_str::(&line) + .with_context(|| format!("failed to deserialize line from stdout: {:?}", line)) + .log_err() + else { + continue; + }; + + this.update(&mut cx, |this, _cx| { + if let Supermaven::Spawned(this) = this { + this.handle_message(message); + } + Task::ready(anyhow::Ok(())) + })? + .await?; + } + + Ok(()) + } + + fn handle_message(&mut self, message: SupermavenMessage) { + match message { + SupermavenMessage::ActivationRequest(request) => { + self.account_status = match request.activate_url { + Some(activate_url) => AccountStatus::NeedsActivation { + activate_url: activate_url.clone(), + }, + None => AccountStatus::Ready, + }; + } + SupermavenMessage::ServiceTier { service_tier } => { + self.service_tier = Some(service_tier); + } + SupermavenMessage::Response(response) => { + let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap()); + if let Some(state) = self.states.get_mut(&state_id) { + for item in &response.items { + if let ResponseItem::Text { text } = item { + state.text.push_str(text); + } + } + state.completion.extend(response.items); + *state.updates_tx.borrow_mut() = (); + } + } + SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough), + _ => { + log::warn!("unhandled message: {:?}", message); + } + } + } +} + +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +pub struct SupermavenCompletionStateId(usize); + +#[allow(dead_code)] +pub struct SupermavenCompletionState { + buffer_id: EntityId, + range: Range, + completion: Vec, + text: String, + updates_tx: watch::Sender<()>, +} + +pub struct SupermavenCompletion { + pub id: SupermavenCompletionStateId, + pub updates: watch::Receiver<()>, +} diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs new file mode 100644 index 0000000000..8dc06bfac0 --- /dev/null +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -0,0 +1,131 @@ +use crate::{Supermaven, SupermavenCompletionStateId}; +use anyhow::Result; +use editor::{Direction, InlineCompletionProvider}; +use futures::StreamExt as _; +use gpui::{AppContext, Model, ModelContext, Task}; +use language::{ + language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset, +}; +use std::time::Duration; + +pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); + +pub struct SupermavenCompletionProvider { + supermaven: Model, + completion_id: Option, + pending_refresh: Task>, +} + +impl SupermavenCompletionProvider { + pub fn new(supermaven: Model) -> Self { + Self { + supermaven, + completion_id: None, + pending_refresh: Task::ready(Ok(())), + } + } +} + +impl InlineCompletionProvider for SupermavenCompletionProvider { + fn is_enabled(&self, buffer: &Model, cursor_position: Anchor, cx: &AppContext) -> bool { + if !self.supermaven.read(cx).is_enabled() { + return false; + } + + let buffer = buffer.read(cx); + let file = buffer.file(); + let language = buffer.language_at(cursor_position); + let settings = all_language_settings(file, cx); + settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) + } + + fn refresh( + &mut self, + buffer_handle: Model, + cursor_position: Anchor, + debounce: bool, + cx: &mut ModelContext, + ) { + let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| { + supermaven.complete(&buffer_handle, cursor_position, cx) + }) else { + return; + }; + + self.pending_refresh = cx.spawn(|this, mut cx| async move { + if debounce { + cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; + } + + while let Some(()) = completion.updates.next().await { + this.update(&mut cx, |this, cx| { + this.completion_id = Some(completion.id); + cx.notify(); + })?; + } + Ok(()) + }); + } + + fn cycle( + &mut self, + _buffer: Model, + _cursor_position: Anchor, + _direction: Direction, + _cx: &mut ModelContext, + ) { + // todo!("cycling") + } + + fn accept(&mut self, _cx: &mut ModelContext) { + self.pending_refresh = Task::ready(Ok(())); + self.completion_id = None; + } + + fn discard(&mut self, _cx: &mut ModelContext) { + self.pending_refresh = Task::ready(Ok(())); + self.completion_id = None; + } + + fn active_completion_text<'a>( + &'a self, + buffer: &Model, + cursor_position: Anchor, + cx: &'a AppContext, + ) -> Option<&'a str> { + let completion_id = self.completion_id?; + let buffer = buffer.read(cx); + let cursor_offset = cursor_position.to_offset(buffer); + let completion = self.supermaven.read(cx).completion(completion_id)?; + + let mut completion_range = completion.range.to_offset(buffer); + + let prefix_len = common_prefix( + buffer.chars_for_range(completion_range.clone()), + completion.text.chars(), + ); + completion_range.start += prefix_len; + let suffix_len = common_prefix( + buffer.reversed_chars_for_range(completion_range.clone()), + completion.text[prefix_len..].chars().rev(), + ); + completion_range.end = completion_range.end.saturating_sub(suffix_len); + + let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len]; + if completion_range.is_empty() + && completion_range.start == cursor_offset + && !completion_text.trim().is_empty() + { + Some(completion_text) + } else { + None + } + } +} + +fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { + a.zip(b) + .take_while(|(a, b)| a == b) + .map(|(a, _)| a.len_utf8()) + .sum() +} diff --git a/crates/supermaven_api/Cargo.toml b/crates/supermaven_api/Cargo.toml new file mode 100644 index 0000000000..69b6965283 --- /dev/null +++ b/crates/supermaven_api/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "supermaven_api" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/supermaven_api.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +futures.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +util.workspace = true diff --git a/crates/supermaven_api/src/supermaven_api.rs b/crates/supermaven_api/src/supermaven_api.rs new file mode 100644 index 0000000000..9d55bc5413 --- /dev/null +++ b/crates/supermaven_api/src/supermaven_api.rs @@ -0,0 +1,291 @@ +use anyhow::{anyhow, Context, Result}; +use futures::io::BufReader; +use futures::{AsyncReadExt, Future}; +use serde::{Deserialize, Serialize}; +use smol::fs::{self, File}; +use smol::stream::StreamExt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use util::http::{AsyncBody, HttpClient, Request as HttpRequest}; +use util::paths::SUPERMAVEN_DIR; + +#[derive(Serialize)] +pub struct GetExternalUserRequest { + pub id: String, +} + +#[derive(Serialize)] +pub struct CreateExternalUserRequest { + pub id: String, + pub email: String, +} + +#[derive(Serialize)] +pub struct DeleteExternalUserRequest { + pub id: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateExternalUserResponse { + pub api_key: String, +} + +#[derive(Deserialize)] +pub struct SupermavenApiError { + pub message: String, +} + +pub struct SupermavenBinary {} + +pub struct SupermavenAdminApi { + admin_api_key: String, + api_url: String, + http_client: Arc, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenDownloadResponse { + pub download_url: String, + pub version: u64, + pub sha256_hash: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenUser { + id: String, + email: String, + api_key: String, +} + +impl SupermavenAdminApi { + pub fn new(admin_api_key: String, http_client: Arc) -> Self { + Self { + admin_api_key, + api_url: "https://supermaven.com/api/".to_string(), + http_client, + } + } + + pub async fn try_get_user( + &self, + request: GetExternalUserRequest, + ) -> Result> { + let uri = format!("{}external-user/{}", &self.api_url, &request.id); + + let request = HttpRequest::get(&uri).header("Authorization", self.admin_api_key.clone()); + + let mut response = self + .http_client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to get Supermaven API Key".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + if error.message == "User not found" { + return Ok(None); + } else { + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + } else if response.status().is_server_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + let body_str = std::str::from_utf8(&body)?; + + Ok(Some( + serde_json::from_str::(body_str) + .with_context(|| "Unable to parse Supermaven user response".to_string())?, + )) + } + + pub async fn try_create_user( + &self, + request: CreateExternalUserRequest, + ) -> Result { + let uri = format!("{}external-user", &self.api_url); + + let request = HttpRequest::post(&uri) + .header("Authorization", self.admin_api_key.clone()) + .body(AsyncBody::from(serde_json::to_vec(&request)?))?; + + let mut response = self + .http_client + .send(request) + .await + .with_context(|| "Unable to create Supermaven API Key".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let body_str = std::str::from_utf8(&body)?; + + if !response.status().is_success() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + serde_json::from_str::(body_str) + .with_context(|| "Unable to parse Supermaven API Key response".to_string()) + } + + pub async fn try_delete_user(&self, request: DeleteExternalUserRequest) -> Result<()> { + let uri = format!("{}external-user/{}", &self.api_url, &request.id); + + let request = HttpRequest::delete(&uri).header("Authorization", self.admin_api_key.clone()); + + let mut response = self + .http_client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to delete Supermaven User".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + if error.message == "User not found" { + return Ok(()); + } else { + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + } else if response.status().is_server_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + Ok(()) + } + + pub async fn try_get_or_create_user( + &self, + request: CreateExternalUserRequest, + ) -> Result { + let get_user_request = GetExternalUserRequest { + id: request.id.clone(), + }; + + match self.try_get_user(get_user_request).await? { + None => self.try_create_user(request).await, + Some(SupermavenUser { api_key, .. }) => Ok(CreateExternalUserResponse { api_key }), + } + } +} + +pub async fn latest_release( + client: Arc, + platform: &str, + arch: &str, +) -> Result { + let uri = format!( + "https://supermaven.com/api/download-path?platform={}&arch={}", + platform, arch + ); + + // Download is not authenticated + let request = HttpRequest::get(&uri); + + let mut response = client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to acquire Supermaven Agent".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() || response.status().is_server_error() { + let body_str = std::str::from_utf8(&body)?; + let error: SupermavenApiError = serde_json::from_str(body_str)?; + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + + serde_json::from_slice::(&body) + .with_context(|| "Unable to parse Supermaven Agent response".to_string()) +} + +pub fn version_path(version: u64) -> PathBuf { + SUPERMAVEN_DIR.join(format!("sm-agent-{}", version)) +} + +pub async fn has_version(version_path: &Path) -> bool { + fs::metadata(version_path) + .await + .map_or(false, |m| m.is_file()) +} + +pub fn get_supermaven_agent_path( + client: Arc, +) -> impl Future> { + async move { + fs::create_dir_all(&*SUPERMAVEN_DIR) + .await + .with_context(|| { + format!( + "Could not create Supermaven Agent Directory at {:?}", + &*SUPERMAVEN_DIR + ) + })?; + + let platform = match std::env::consts::OS { + "macos" => "darwin", + "windows" => "windows", + "linux" => "linux", + _ => return Err(anyhow!("unsupported platform")), + }; + + let arch = match std::env::consts::ARCH { + "x86_64" => "amd64", + "aarch64" => "arm64", + _ => return Err(anyhow!("unsupported architecture")), + }; + + let download_info = latest_release(client.clone(), platform, arch).await?; + + let binary_path = version_path(download_info.version); + + if has_version(&binary_path).await { + return Ok(binary_path); + } + + let request = HttpRequest::get(&download_info.download_url); + + let mut response = client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to download Supermaven Agent".to_string())?; + + let mut file = File::create(&binary_path) + .await + .with_context(|| format!("Unable to create file at {:?}", binary_path))?; + + futures::io::copy(BufReader::new(response.body_mut()), &mut file) + .await + .with_context(|| format!("Unable to write binary to file at {:?}", binary_path))?; + + #[cfg(not(windows))] + { + file.set_permissions(::from_mode( + 0o755, + )) + .await?; + } + + let mut old_binary_paths = fs::read_dir(&*SUPERMAVEN_DIR).await?; + while let Some(old_binary_path) = old_binary_paths.next().await { + let old_binary_path = old_binary_path?; + if old_binary_path.path() != binary_path { + fs::remove_file(old_binary_path.path()).await?; + } + } + + Ok(binary_path) + } +} diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index bc05a8f3d3..9c9e05d6b6 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -155,6 +155,10 @@ pub enum IconName { Space, Split, Spinner, + Supermaven, + SupermavenDisabled, + SupermavenError, + SupermavenInit, Tab, Terminal, Trash, @@ -261,6 +265,10 @@ impl IconName { IconName::Space => "icons/space.svg", IconName::Split => "icons/split.svg", IconName::Spinner => "icons/spinner.svg", + IconName::Supermaven => "icons/supermaven.svg", + IconName::SupermavenDisabled => "icons/supermaven_disabled.svg", + IconName::SupermavenError => "icons/supermaven_error.svg", + IconName::SupermavenInit => "icons/supermaven_init.svg", IconName::Tab => "icons/tab.svg", IconName::Terminal => "icons/terminal.svg", IconName::Trash => "icons/trash.svg", diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index 205ea72f0a..feb7c19535 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -52,6 +52,7 @@ lazy_static::lazy_static! { pub static ref EXTENSIONS_DIR: PathBuf = SUPPORT_DIR.join("extensions"); pub static ref LANGUAGES_DIR: PathBuf = SUPPORT_DIR.join("languages"); pub static ref COPILOT_DIR: PathBuf = SUPPORT_DIR.join("copilot"); + pub static ref SUPERMAVEN_DIR: PathBuf = SUPPORT_DIR.join("supermaven"); pub static ref DEFAULT_PRETTIER_DIR: PathBuf = SUPPORT_DIR.join("prettier"); pub static ref DB_DIR: PathBuf = SUPPORT_DIR.join("db"); pub static ref CRASHES_DIR: Option = cfg!(target_os = "macos") diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index c18a09673f..e747072cde 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -17,7 +17,7 @@ test-support = [] [dependencies] anyhow.workspace = true client.workspace = true -copilot_ui.workspace = true +inline_completion_button.workspace = true db.workspace = true extensions_ui.workspace = true fuzzy.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index e6a2a53f2e..3ae07cda68 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -2,7 +2,6 @@ mod base_keymap_picker; mod base_keymap_setting; use client::{telemetry::Telemetry, TelemetrySettings}; -use copilot_ui; use db::kvp::KEY_VALUE_STORE; use gpui::{ svg, AnyElement, AppContext, EventEmitter, FocusHandle, FocusableView, InteractiveElement, @@ -143,7 +142,7 @@ impl Render for WelcomePage { this.telemetry.report_app_event( "welcome page: sign in to copilot".to_string(), ); - copilot_ui::initiate_sign_in(cx); + inline_completion_button::initiate_sign_in(cx); })), ) .child( diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 9a9f40020a..a8130fe5df 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -35,7 +35,6 @@ collab_ui.workspace = true collections.workspace = true command_palette.workspace = true copilot.workspace = true -copilot_ui.workspace = true db.workspace = true diagnostics.workspace = true editor.workspace = true @@ -51,6 +50,7 @@ go_to_line.workspace = true gpui.workspace = true headless.workspace = true image_viewer.workspace = true +inline_completion_button.workspace = true install_cli.workspace = true isahc.workspace = true journal.workspace = true @@ -83,6 +83,7 @@ settings.workspace = true simplelog = "0.9" smol.workspace = true tab_switcher.workspace = true +supermaven.workspace = true task.workspace = true tasks_ui.workspace = true telemetry_events.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 9850a2f603..3b2e96965e 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -9,16 +9,14 @@ mod zed; use anyhow::{anyhow, Context as _, Result}; use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{parse_zed_link, telemetry::Telemetry, Client, DevServerToken, UserStore}; +use client::{parse_zed_link, Client, DevServerToken, UserStore}; use collab_ui::channel_view::ChannelView; -use copilot::Copilot; -use copilot_ui::CopilotCompletionProvider; use db::kvp::KEY_VALUE_STORE; -use editor::{Editor, EditorMode}; +use editor::Editor; use env_logger::Builder; use fs::RealFs; use futures::{future, StreamExt}; -use gpui::{App, AppContext, AsyncAppContext, Context, Task, ViewContext, VisualContext}; +use gpui::{App, AppContext, AsyncAppContext, Context, Task, VisualContext}; use image_viewer; use language::LanguageRegistry; use log::LevelFilter; @@ -55,6 +53,8 @@ use zed::{ OpenListener, OpenRequest, }; +use crate::zed::inline_completion_registry; + #[cfg(feature = "mimalloc")] #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -270,17 +270,20 @@ fn init_ui(args: Args) { editor::init(cx); image_viewer::init(cx); diagnostics::init(cx); + + // Initialize each completion provider. Settings are used for toggling between them. copilot::init( copilot_language_server_id, client.http_client(), node_runtime.clone(), cx, ); + supermaven::init(client.clone(), cx); assistant::init(client.clone(), cx); assistant2::init(client.clone(), cx); - init_inline_completion_provider(client.telemetry().clone(), cx); + inline_completion_registry::init(client.telemetry().clone(), cx); extension::init( fs.clone(), @@ -888,45 +891,3 @@ fn watch_file_types(fs: Arc, cx: &mut AppContext) { #[cfg(not(debug_assertions))] fn watch_file_types(_fs: Arc, _cx: &mut AppContext) {} - -fn init_inline_completion_provider(telemetry: Arc, cx: &mut AppContext) { - if let Some(copilot) = Copilot::global(cx) { - cx.observe_new_views(move |editor: &mut Editor, cx: &mut ViewContext| { - if editor.mode() == EditorMode::Full { - // We renamed some of these actions to not be copilot-specific, but that - // would have not been backwards-compatible. So here we are re-registering - // the actions with the old names to not break people's keymaps. - editor - .register_action(cx.listener( - |editor, _: &copilot::Suggest, cx: &mut ViewContext| { - editor.show_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, _: &copilot::NextSuggestion, cx: &mut ViewContext| { - editor.next_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext| { - editor.previous_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, - _: &editor::actions::AcceptPartialCopilotSuggestion, - cx: &mut ViewContext| { - editor.accept_partial_inline_completion(&Default::default(), cx); - }, - )); - - let provider = cx.new_model(|_| { - CopilotCompletionProvider::new(copilot.clone()) - .with_telemetry(telemetry.clone()) - }); - editor.set_inline_completion_provider(provider, cx) - } - }) - .detach(); - } -} diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 6c0f155ce2..14cc9febd2 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1,4 +1,5 @@ mod app_menus; +pub mod inline_completion_registry; mod only_instance; mod open_listener; @@ -127,7 +128,10 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { }) .detach(); - let copilot = cx.new_view(|cx| copilot_ui::CopilotButton::new(app_state.fs.clone(), cx)); + let inline_completion_button = cx.new_view(|cx| { + inline_completion_button::InlineCompletionButton::new(app_state.fs.clone(), cx) + }); + let diagnostic_summary = cx.new_view(|cx| diagnostics::items::DiagnosticIndicator::new(workspace, cx)); let activity_indicator = @@ -140,7 +144,7 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { workspace.status_bar().update(cx, |status_bar, cx| { status_bar.add_left_item(diagnostic_summary, cx); status_bar.add_left_item(activity_indicator, cx); - status_bar.add_right_item(copilot, cx); + status_bar.add_right_item(inline_completion_button, cx); status_bar.add_right_item(active_buffer_language, cx); status_bar.add_right_item(vim_mode_indicator, cx); status_bar.add_right_item(cursor_position, cx); diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs new file mode 100644 index 0000000000..7ea50322a3 --- /dev/null +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -0,0 +1,126 @@ +use std::{cell::RefCell, rc::Rc, sync::Arc}; + +use client::telemetry::Telemetry; +use collections::HashMap; +use copilot::{Copilot, CopilotCompletionProvider}; +use editor::{Editor, EditorMode}; +use gpui::{AnyWindowHandle, AppContext, Context, ViewContext, WeakView}; +use language::language_settings::all_language_settings; +use settings::SettingsStore; +use supermaven::{Supermaven, SupermavenCompletionProvider}; + +pub fn init(telemetry: Arc, cx: &mut AppContext) { + let editors: Rc, AnyWindowHandle>>> = Rc::default(); + cx.observe_new_views({ + let editors = editors.clone(); + let telemetry = telemetry.clone(); + move |editor: &mut Editor, cx: &mut ViewContext| { + if editor.mode() != EditorMode::Full { + return; + } + + register_backward_compatible_actions(editor, cx); + + let editor_handle = cx.view().downgrade(); + cx.on_release({ + let editor_handle = editor_handle.clone(); + let editors = editors.clone(); + move |_, _, _| { + editors.borrow_mut().remove(&editor_handle); + } + }) + .detach(); + editors + .borrow_mut() + .insert(editor_handle, cx.window_handle()); + let provider = all_language_settings(None, cx).inline_completions.provider; + assign_inline_completion_provider(editor, provider, &telemetry, cx); + } + }) + .detach(); + + let mut provider = all_language_settings(None, cx).inline_completions.provider; + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, cx| { + _ = editor.update(cx, |editor, cx| { + assign_inline_completion_provider(editor, provider, &telemetry, cx); + }) + }); + } + + cx.observe_global::(move |cx| { + let new_provider = all_language_settings(None, cx).inline_completions.provider; + if new_provider != provider { + provider = new_provider; + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, cx| { + _ = editor.update(cx, |editor, cx| { + assign_inline_completion_provider(editor, provider, &telemetry, cx); + }) + }); + } + } + }) + .detach(); +} + +fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut ViewContext) { + // We renamed some of these actions to not be copilot-specific, but that + // would have not been backwards-compatible. So here we are re-registering + // the actions with the old names to not break people's keymaps. + editor + .register_action(cx.listener( + |editor, _: &copilot::Suggest, cx: &mut ViewContext| { + editor.show_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, _: &copilot::NextSuggestion, cx: &mut ViewContext| { + editor.next_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext| { + editor.previous_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, + _: &editor::actions::AcceptPartialCopilotSuggestion, + cx: &mut ViewContext| { + editor.accept_partial_inline_completion(&Default::default(), cx); + }, + )); +} + +fn assign_inline_completion_provider( + editor: &mut Editor, + provider: language::language_settings::InlineCompletionProvider, + telemetry: &Arc, + cx: &mut ViewContext, +) { + match provider { + language::language_settings::InlineCompletionProvider::None => {} + language::language_settings::InlineCompletionProvider::Copilot => { + if let Some(copilot) = Copilot::global(cx) { + if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + if buffer.read(cx).file().is_some() { + copilot.update(cx, |copilot, cx| { + copilot.register_buffer(&buffer, cx); + }); + } + } + let provider = cx.new_model(|_| { + CopilotCompletionProvider::new(copilot).with_telemetry(telemetry.clone()) + }); + editor.set_inline_completion_provider(Some(provider), cx); + } + } + language::language_settings::InlineCompletionProvider::Supermaven => { + if let Some(supermaven) = Supermaven::global(cx) { + let provider = cx.new_model(|_| SupermavenCompletionProvider::new(supermaven)); + editor.set_inline_completion_provider(Some(provider), cx); + } + } + } +}