Supermaven (#10788)

Adds a supermaven provider for completions. There are various other
refactors amidst this branch, primarily to make copilot no longer a
dependency of project as well as show LSP Logs for global LSPs like
copilot properly.

This feature is not enabled by default. We're going to seek to refine it
in the coming weeks.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Kyle Kelley 2024-05-03 12:50:42 -07:00 committed by GitHub
parent 610968815c
commit 6563330239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
47 changed files with 2242 additions and 827 deletions

97
Cargo.lock generated
View File

@ -2316,6 +2316,7 @@ dependencies = [
"sha2 0.10.7",
"sqlx",
"subtle",
"supermaven_api",
"telemetry_events",
"text",
"theme",
@ -2512,30 +2513,10 @@ dependencies = [
"async-compression",
"async-std",
"async-tar",
"client",
"clock",
"collections",
"command_palette_hooks",
"fs",
"futures 0.3.28",
"gpui",
"language",
"lsp",
"node_runtime",
"parking_lot",
"rpc",
"serde",
"settings",
"smol",
"util",
]
[[package]]
name = "copilot_ui"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"copilot",
"editor",
"fs",
"futures 0.3.28",
@ -2544,14 +2525,18 @@ dependencies = [
"language",
"lsp",
"menu",
"node_runtime",
"parking_lot",
"project",
"rpc",
"serde",
"serde_json",
"settings",
"smol",
"theme",
"ui",
"util",
"workspace",
"zed_actions",
]
[[package]]
@ -5143,6 +5128,30 @@ dependencies = [
"syn 2.0.59",
]
[[package]]
name = "inline_completion_button"
version = "0.1.0"
dependencies = [
"anyhow",
"copilot",
"editor",
"fs",
"futures 0.3.28",
"gpui",
"indoc",
"language",
"lsp",
"project",
"serde_json",
"settings",
"supermaven",
"theme",
"ui",
"util",
"workspace",
"zed_actions",
]
[[package]]
name = "inotify"
version = "0.9.6"
@ -5548,6 +5557,7 @@ dependencies = [
"anyhow",
"client",
"collections",
"copilot",
"editor",
"env_logger",
"futures 0.3.28",
@ -7422,7 +7432,6 @@ dependencies = [
"client",
"clock",
"collections",
"copilot",
"env_logger",
"fs",
"futures 0.3.28",
@ -9594,6 +9603,43 @@ dependencies = [
"rayon",
]
[[package]]
name = "supermaven"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"collections",
"editor",
"env_logger",
"futures 0.3.28",
"gpui",
"language",
"log",
"postage",
"project",
"serde",
"serde_json",
"settings",
"smol",
"supermaven_api",
"theme",
"ui",
"util",
]
[[package]]
name = "supermaven_api"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.28",
"serde",
"serde_json",
"smol",
"util",
]
[[package]]
name = "sval"
version = "2.8.0"
@ -11798,12 +11844,12 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"copilot_ui",
"db",
"editor",
"extensions_ui",
"fuzzy",
"gpui",
"inline_completion_button",
"install_cli",
"picker",
"project",
@ -12683,7 +12729,6 @@ dependencies = [
"collections",
"command_palette",
"copilot",
"copilot_ui",
"db",
"dev_server_projects",
"diagnostics",
@ -12700,6 +12745,7 @@ dependencies = [
"gpui",
"headless",
"image_viewer",
"inline_completion_button",
"install_cli",
"isahc",
"journal",
@ -12730,6 +12776,7 @@ dependencies = [
"settings",
"simplelog",
"smol",
"supermaven",
"tab_switcher",
"task",
"tasks_ui",

View File

@ -20,7 +20,6 @@ members = [
"crates/command_palette",
"crates/command_palette_hooks",
"crates/copilot",
"crates/copilot_ui",
"crates/db",
"crates/diagnostics",
"crates/editor",
@ -42,6 +41,7 @@ members = [
"crates/gpui_macros",
"crates/headless",
"crates/image_viewer",
"crates/inline_completion_button",
"crates/install_cli",
"crates/journal",
"crates/language",
@ -86,6 +86,8 @@ members = [
"crates/storybook",
"crates/sum_tree",
"crates/tab_switcher",
"crates/supermaven",
"crates/supermaven_api",
"crates/terminal",
"crates/terminal_view",
"crates/text",
@ -159,7 +161,6 @@ color = { path = "crates/color" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
copilot = { path = "crates/copilot" }
copilot_ui = { path = "crates/copilot_ui" }
db = { path = "crates/db" }
diagnostics = { path = "crates/diagnostics" }
editor = { path = "crates/editor" }
@ -180,6 +181,7 @@ gpui_macros = { path = "crates/gpui_macros" }
headless = { path = "crates/headless" }
install_cli = { path = "crates/install_cli" }
image_viewer = { path = "crates/image_viewer" }
inline_completion_button = { path = "crates/inline_completion_button" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
language_selector = { path = "crates/language_selector" }
@ -220,6 +222,8 @@ settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }
sqlez = { path = "crates/sqlez" }
sqlez_macros = { path = "crates/sqlez_macros" }
supermaven = { path = "crates/supermaven" }
supermaven_api = { path = "crates/supermaven_api"}
story = { path = "crates/story" }
storybook = { path = "crates/storybook" }
sum_tree = { path = "crates/sum_tree" }

View File

@ -0,0 +1,8 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3.30859 13.0703C3.80693 13.0703 4.21094 12.6663 4.21094 12.168C4.21094 11.6696 3.80693 11.2656 3.30859 11.2656C2.81025 11.2656 2.40625 11.6696 2.40625 12.168C2.40625 12.6663 2.81025 13.0703 3.30859 13.0703Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M6.53516 8.03849L4.10799 12.6055L2.51562 11.7584L4.94279 7.19141L6.53516 8.03849Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.38281 2.62443L4.93916 7.19141L3.33594 6.34432L5.77959 1.77734L7.38281 2.62443Z" fill="black"/>
<path d="M6.5625 3.08984C7.06084 3.08984 7.46484 2.68585 7.46484 2.1875C7.46484 1.68915 7.06084 1.28516 6.5625 1.28516C6.06416 1.28516 5.66016 1.68915 5.66016 2.1875C5.66016 2.68585 6.06416 3.08984 6.5625 3.08984Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M10.882 1.31204C11.2842 1.41224 11.5664 1.7732 11.5664 2.18737V12.168H9.76084V5.8056L8.12938 8.87176L6.53516 8.02471L9.86653 1.76385C10.0611 1.39816 10.4799 1.21184 10.882 1.31204Z" fill="black"/>
<path d="M10.6641 13.0703C11.1624 13.0703 11.5664 12.6663 11.5664 12.168C11.5664 11.6696 11.1624 11.2656 10.6641 11.2656C10.1657 11.2656 9.76172 11.6696 9.76172 12.168C9.76172 12.6663 10.1657 13.0703 10.6641 13.0703Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -0,0 +1,15 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g opacity="0.5">
<path d="M3.78125 14.9375C4.35078 14.9375 4.8125 14.4758 4.8125 13.9062C4.8125 13.3367 4.35078 12.875 3.78125 12.875C3.21172 12.875 2.75 13.3367 2.75 13.9062C2.75 14.4758 3.21172 14.9375 3.78125 14.9375Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.46875 9.18684L4.69484 14.4062L2.875 13.4382L5.64891 8.21875L7.46875 9.18684Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.4375 2.99935L5.64475 8.21875L3.8125 7.25066L6.60525 2.03125L8.4375 2.99935Z" fill="white"/>
<path d="M7.5 3.53125C8.06953 3.53125 8.53125 3.06954 8.53125 2.5C8.53125 1.93046 8.06953 1.46875 7.5 1.46875C6.93047 1.46875 6.46875 1.93046 6.46875 2.5C6.46875 3.06954 6.93047 3.53125 7.5 3.53125Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.4366 1.49947C12.8962 1.61399 13.2188 2.02651 13.2188 2.49985V13.9063H11.1552V6.63497L9.29072 10.1392L7.46875 9.17109L11.276 2.01583C11.4984 1.59789 11.977 1.38496 12.4366 1.49947Z" fill="white"/>
<path d="M12.1875 14.9375C12.757 14.9375 13.2188 14.4758 13.2188 13.9062C13.2188 13.3367 12.757 12.875 12.1875 12.875C11.618 12.875 11.1562 13.3367 11.1562 13.9062C11.1562 14.4758 11.618 14.9375 12.1875 14.9375Z" fill="white"/>
</g>
<g>
<path d="M0.906311 6.42261L1.75155 4.60999L15.3462 10.9493L14.5009 12.7619L0.906311 6.42261Z" fill="white"/>
<circle cx="14.7841" cy="11.7906" r="1" transform="rotate(-65 14.7841 11.7906)" fill="white"/>
<circle cx="1.32893" cy="5.51631" r="1" transform="rotate(-65 1.32893 5.51631)" fill="white"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -0,0 +1,11 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g opacity="0.5">
<path d="M3.78125 14.9375C4.35078 14.9375 4.8125 14.4758 4.8125 13.9062C4.8125 13.3367 4.35078 12.875 3.78125 12.875C3.21172 12.875 2.75 13.3367 2.75 13.9062C2.75 14.4758 3.21172 14.9375 3.78125 14.9375Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.46875 9.18684L4.69484 14.4062L2.875 13.4382L5.64891 8.21875L7.46875 9.18684Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.4375 2.99935L5.64475 8.21875L3.8125 7.25066L6.60525 2.03125L8.4375 2.99935Z" fill="white"/>
<path d="M7.5 3.53125C8.06953 3.53125 8.53125 3.06954 8.53125 2.5C8.53125 1.93046 8.06953 1.46875 7.5 1.46875C6.93047 1.46875 6.46875 1.93046 6.46875 2.5C6.46875 3.06954 6.93047 3.53125 7.5 3.53125Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.4366 1.49947C12.8962 1.61399 13.2188 2.02651 13.2188 2.49985V13.9063H11.1552V6.63497L9.29072 10.1392L7.46875 9.17109L11.276 2.01583C11.4984 1.59789 11.977 1.38496 12.4366 1.49947Z" fill="white"/>
<path d="M12.1875 14.9375C12.757 14.9375 13.2188 14.4758 13.2188 13.9062C13.2188 13.3367 12.757 12.875 12.1875 12.875C11.618 12.875 11.1562 13.3367 11.1562 13.9062C11.1562 14.4758 11.618 14.9375 12.1875 14.9375Z" fill="white"/>
</g>
<path fill-rule="evenodd" clip-rule="evenodd" d="M14.6847 15.9265C14.7823 16.0241 14.9406 16.0241 15.0382 15.9265L15.9259 15.0387C16.0235 14.9411 16.0235 14.7828 15.9259 14.6851L14.2408 12.9999L15.9259 11.3146C16.0236 11.217 16.0236 11.0587 15.9259 10.961L15.0382 10.0733C14.9406 9.97561 14.7823 9.97561 14.6847 10.0733L12.9996 11.7585L11.3145 10.0732C11.2169 9.97559 11.0586 9.97559 10.9609 10.0732L10.0732 10.961C9.97559 11.0587 9.97559 11.217 10.0732 11.3146L11.7584 12.9999L10.0732 14.6851C9.97562 14.7828 9.97562 14.9411 10.0732 15.0387L10.9609 15.9265C11.0586 16.0242 11.2169 16.0242 11.3145 15.9265L12.9996 14.2413L14.6847 15.9265Z" fill="white"/>
</svg>

After

Width:  |  Height:  |  Size: 1.9 KiB

View File

@ -0,0 +1,11 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g opacity="0.5">
<path d="M3.78125 14.9375C4.35078 14.9375 4.8125 14.4758 4.8125 13.9062C4.8125 13.3367 4.35078 12.875 3.78125 12.875C3.21172 12.875 2.75 13.3367 2.75 13.9062C2.75 14.4758 3.21172 14.9375 3.78125 14.9375Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.46875 9.18684L4.69484 14.4062L2.875 13.4382L5.64891 8.21875L7.46875 9.18684Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.4375 2.99935L5.64475 8.21875L3.8125 7.25066L6.60525 2.03125L8.4375 2.99935Z" fill="white"/>
<path d="M7.5 3.53125C8.06953 3.53125 8.53125 3.06954 8.53125 2.5C8.53125 1.93046 8.06953 1.46875 7.5 1.46875C6.93047 1.46875 6.46875 1.93046 6.46875 2.5C6.46875 3.06954 6.93047 3.53125 7.5 3.53125Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.4366 1.49947C12.8962 1.61399 13.2188 2.02651 13.2188 2.49985V13.9063H11.1552V6.63497L9.29072 10.1392L7.46875 9.17109L11.276 2.01583C11.4984 1.59789 11.977 1.38496 12.4366 1.49947Z" fill="white"/>
<path d="M12.1875 14.9375C12.757 14.9375 13.2188 14.4758 13.2188 13.9062C13.2188 13.3367 12.757 12.875 12.1875 12.875C11.618 12.875 11.1562 13.3367 11.1562 13.9062C11.1562 14.4758 11.618 14.9375 12.1875 14.9375Z" fill="white"/>
</g>
<circle cx="13" cy="13" r="3" fill="white"/>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -12,8 +12,8 @@
"base_keymap": "VSCode",
// Features that can be globally enabled or disabled
"features": {
// Show Copilot icon in status bar
"copilot": true
// Which inline completion provider to use.
"inline_completion_provider": "copilot"
},
// The name of a font to use for rendering text in the editor
"buffer_font_family": "Zed Mono",

View File

@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::{convert::TryFrom, sync::Arc};
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
@ -141,7 +141,7 @@ pub enum TextDelta {
}
pub async fn stream_completion(
client: &dyn HttpClient,
client: Arc<dyn HttpClient>,
api_url: &str,
api_key: &str,
request: Request,

View File

@ -39,6 +39,7 @@ live_kit_server.workspace = true
log.workspace = true
nanoid.workspace = true
open_ai.workspace = true
supermaven_api.workspace = true
parking_lot.workspace = true
prometheus = "0.13"
prost.workspace = true

View File

@ -172,6 +172,11 @@ spec:
secretKeyRef:
name: slack
key: panics_webhook
- name: SUPERMAVEN_ADMIN_API_KEY
valueFrom:
secretKeyRef:
name: supermaven
key: api_key
- name: INVITE_LINK_PREFIX
value: ${INVITE_LINK_PREFIX}
- name: RUST_BACKTRACE

View File

@ -0,0 +1,2 @@
use anyhow::{anyhow, Result};
use rpc::proto;

View File

@ -138,6 +138,7 @@ pub struct Config {
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub supermaven_admin_api_key: Option<Arc<str>>,
}
impl Config {

View File

@ -34,6 +34,7 @@ pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
use sha2::Digest;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use futures::{
channel::oneshot,
@ -148,7 +149,8 @@ struct Session {
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
http_client: IsahcHttpClient,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
rate_limiter: Arc<RateLimiter>,
_executor: Executor,
}
@ -189,6 +191,14 @@ impl Session {
}
}
fn is_staff(&self) -> bool {
match &self.principal {
Principal::User(user) => user.admin,
Principal::Impersonated { .. } => true,
Principal::DevServer(_) => false,
}
}
fn dev_server_id(&self) -> Option<DevServerId> {
match &self.principal {
Principal::User(_) | Principal::Impersonated { .. } => None,
@ -233,6 +243,14 @@ impl UserSession {
pub fn user_id(&self) -> UserId {
self.0.user_id().unwrap()
}
pub fn email(&self) -> Option<String> {
match &self.0.principal {
Principal::User(user) => user.email_address.clone(),
Principal::Impersonated { user, .. } => user.email_address.clone(),
Principal::DevServer(..) => None,
}
}
}
impl Deref for UserSession {
@ -561,6 +579,7 @@ impl Server {
.add_request_handler(user_handler(get_private_user_info))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
@ -938,13 +957,22 @@ impl Server {
tracing::info!("connection opened");
let http_client = match IsahcHttpClient::new() {
Ok(http_client) => http_client,
Ok(http_client) => Arc::new(http_client),
Err(error) => {
tracing::error!(?error, "failed to create HTTP client");
return;
}
};
let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
Some(Arc::new(SupermavenAdminApi::new(
supermaven_admin_api_key.to_string(),
http_client.clone(),
)))
} else {
None
};
let session = Session {
principal: principal.clone(),
connection_id,
@ -955,6 +983,7 @@ impl Server {
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
_executor: executor.clone(),
supermaven_client,
};
if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
@ -4210,7 +4239,7 @@ async fn complete_with_open_ai(
api_key: Arc<str>,
) -> Result<()> {
let mut completion_stream = open_ai::stream_completion(
&session.http_client,
session.http_client.as_ref(),
OPEN_AI_API_URL,
&api_key,
crate::ai::language_model_request_to_open_ai(request)?,
@ -4274,7 +4303,7 @@ async fn complete_with_google_ai(
api_key: Arc<str>,
) -> Result<()> {
let mut stream = google_ai::stream_generate_content(
&session.http_client,
session.http_client.clone(),
google_ai::API_URL,
api_key.as_ref(),
crate::ai::language_model_request_to_google_ai(request)?,
@ -4358,7 +4387,7 @@ async fn complete_with_anthropic(
.collect();
let mut stream = anthropic::stream_completion(
&session.http_client,
session.http_client.clone(),
"https://api.anthropic.com",
&api_key,
anthropic::Request {
@ -4482,7 +4511,7 @@ async fn count_tokens_with_language_model(
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
let tokens_response = google_ai::count_tokens(
&session.http_client,
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
crate::ai::count_tokens_request_to_google_ai(request)?,
@ -4530,7 +4559,7 @@ async fn compute_embeddings(
let embeddings = match request.model.as_str() {
"openai/text-embedding-3-small" => {
open_ai::embed(
&session.http_client,
session.http_client.as_ref(),
OPEN_AI_API_URL,
&api_key,
OpenAiEmbeddingModel::TextEmbedding3Small,
@ -4602,6 +4631,37 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<()
}
}
/// Get a Supermaven API key for the user
async fn get_supermaven_api_key(
_request: proto::GetSupermavenApiKey,
response: Response<proto::GetSupermavenApiKey>,
session: UserSession,
) -> Result<()> {
let user_id: String = session.user_id().to_string();
if !session.is_staff() {
return Err(anyhow!("supermaven not enabled for this account"))?;
}
let email = session
.email()
.ok_or_else(|| anyhow!("user must have an email"))?;
let supermaven_admin_api = session
.supermaven_client
.as_ref()
.ok_or_else(|| anyhow!("supermaven not configured"))?;
let result = supermaven_admin_api
.try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
.await?;
response.send(proto::GetSupermavenApiKeyResponse {
api_key: result.api_key,
})?;
Ok(())
}
/// Start receiving chat updates for a channel
async fn join_channel_chat(
request: proto::JoinChannelChat,

View File

@ -655,6 +655,7 @@ impl TestServer {
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
supermaven_admin_api_key: None,
},
})
}

View File

@ -27,28 +27,38 @@ anyhow.workspace = true
async-compression.workspace = true
async-tar.workspace = true
collections.workspace = true
client.workspace = true
command_palette_hooks.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
lsp.workspace = true
menu.workspace = true
node_runtime.workspace = true
parking_lot.workspace = true
project.workspace = true
serde.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
[target.'cfg(windows)'.dependencies]
async-std = { version = "1.12.0", features = ["unstable"] }
[dev-dependencies]
clock.workspace = true
indoc.workspace = true
serde_json.workspace = true
collections = { workspace = true, features = ["test-support"] }
fs = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }

View File

@ -1,4 +1,7 @@
mod copilot_completion_provider;
pub mod request;
mod sign_in;
use anyhow::{anyhow, Context as _, Result};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
@ -10,9 +13,9 @@ use gpui::{
ModelContext, Task, WeakModel,
};
use language::{
language_settings::{all_language_settings, language_settings},
point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language,
LanguageServerName, PointUtf16, ToPointUtf16,
language_settings::{all_language_settings, language_settings, InlineCompletionProvider},
point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
ToPointUtf16,
};
use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId};
use node_runtime::NodeRuntime;
@ -32,6 +35,9 @@ use util::{
fs::remove_matching, github::latest_github_release, http::HttpClient, maybe, paths, ResultExt,
};
pub use copilot_completion_provider::CopilotCompletionProvider;
pub use sign_in::CopilotCodeVerification;
actions!(
copilot,
[
@ -144,7 +150,6 @@ impl CopilotServer {
}
struct RunningCopilotServer {
name: LanguageServerName,
lsp: Arc<LanguageServer>,
sign_in_status: SignInStatus,
registered_buffers: HashMap<EntityId, RegisteredBuffer>,
@ -354,7 +359,9 @@ impl Copilot {
let server_id = self.server_id;
let http = self.http.clone();
let node_runtime = self.node_runtime.clone();
if all_language_settings(None, cx).copilot_enabled(None, None) {
if all_language_settings(None, cx).inline_completions.provider
== InlineCompletionProvider::Copilot
{
if matches!(self.server, CopilotServer::Disabled) {
let start_task = cx
.spawn(move |this, cx| {
@ -393,7 +400,6 @@ impl Copilot {
http: http.clone(),
node_runtime,
server: CopilotServer::Running(RunningCopilotServer {
name: LanguageServerName(Arc::from("copilot")),
lsp: Arc::new(server),
sign_in_status: SignInStatus::Authorized,
registered_buffers: Default::default(),
@ -467,7 +473,6 @@ impl Copilot {
match server {
Ok((server, status)) => {
this.server = CopilotServer::Running(RunningCopilotServer {
name: LanguageServerName(Arc::from("copilot")),
lsp: server,
sign_in_status: SignInStatus::SignedOut,
registered_buffers: Default::default(),
@ -607,9 +612,9 @@ impl Copilot {
cx.background_executor().spawn(start_task)
}
pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc<LanguageServer>)> {
pub fn language_server(&self) -> Option<&Arc<LanguageServer>> {
if let CopilotServer::Running(server) = &self.server {
Some((&server.name, &server.lsp))
Some(&server.lsp)
} else {
None
}

View File

@ -1,10 +1,12 @@
use crate::{Completion, Copilot};
use anyhow::Result;
use client::telemetry::Telemetry;
use copilot::Copilot;
use editor::{Direction, InlineCompletionProvider};
use gpui::{AppContext, EntityId, Model, ModelContext, Task};
use language::language_settings::AllLanguageSettings;
use language::{language_settings::all_language_settings, Buffer, OffsetRangeExt, ToOffset};
use language::{
language_settings::{all_language_settings, AllLanguageSettings},
Buffer, OffsetRangeExt, ToOffset,
};
use settings::Settings;
use std::{path::Path, sync::Arc, time::Duration};
@ -13,7 +15,7 @@ pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
pub struct CopilotCompletionProvider {
cycled: bool,
buffer_id: Option<EntityId>,
completions: Vec<copilot::Completion>,
completions: Vec<Completion>,
active_completion_index: usize,
file_extension: Option<String>,
pending_refresh: Task<Result<()>>,
@ -42,11 +44,11 @@ impl CopilotCompletionProvider {
self
}
fn active_completion(&self) -> Option<&copilot::Completion> {
fn active_completion(&self) -> Option<&Completion> {
self.completions.get(self.active_completion_index)
}
fn push_completion(&mut self, new_completion: copilot::Completion) {
fn push_completion(&mut self, new_completion: Completion) {
for completion in &self.completions {
if completion.text == new_completion.text && completion.range == new_completion.range {
return;
@ -71,7 +73,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider {
let file = buffer.file();
let language = buffer.language_at(cursor_position);
let settings = all_language_settings(file, cx);
settings.copilot_enabled(language.as_ref(), file.map(|f| f.path().as_ref()))
settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()))
}
fn refresh(
@ -196,7 +198,10 @@ impl InlineCompletionProvider for CopilotCompletionProvider {
fn discard(&mut self, cx: &mut ModelContext<Self>) {
let settings = AllLanguageSettings::get_global(cx);
if !settings.copilot.feature_enabled {
let copilot_enabled = settings.inline_completions_enabled(None, None);
if !copilot_enabled {
return;
}
@ -298,7 +303,9 @@ mod tests {
)
.await;
let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot));
cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx));
cx.update_editor(|editor, cx| {
editor.set_inline_completion_provider(Some(copilot_provider), cx)
});
// When inserting, ensure autocompletion is favored over Copilot suggestions.
cx.set_state(indoc! {"
@ -318,7 +325,7 @@ mod tests {
);
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "one.copilot1".into(),
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)),
..Default::default()
@ -360,7 +367,7 @@ mod tests {
);
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "one.copilot1".into(),
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)),
..Default::default()
@ -393,7 +400,7 @@ mod tests {
);
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "one.copilot1".into(),
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)),
..Default::default()
@ -426,7 +433,7 @@ mod tests {
// After debouncing, new Copilot completions should be requested.
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "one.copilot2".into(),
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 5)),
..Default::default()
@ -503,7 +510,7 @@ mod tests {
});
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: " let x = 4;".into(),
range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)),
..Default::default()
@ -553,7 +560,9 @@ mod tests {
)
.await;
let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot));
cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx));
cx.update_editor(|editor, cx| {
editor.set_inline_completion_provider(Some(copilot_provider), cx)
});
// Setup the editor with a completion request.
cx.set_state(indoc! {"
@ -573,7 +582,7 @@ mod tests {
);
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "one.copilot1".into(),
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)),
..Default::default()
@ -615,7 +624,7 @@ mod tests {
);
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "one.123. copilot\n 456".into(),
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)),
..Default::default()
@ -675,7 +684,9 @@ mod tests {
)
.await;
let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot));
cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx));
cx.update_editor(|editor, cx| {
editor.set_inline_completion_provider(Some(copilot_provider), cx)
});
cx.set_state(indoc! {"
one
@ -685,7 +696,7 @@ mod tests {
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "two.foo()".into(),
range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)),
..Default::default()
@ -756,13 +767,13 @@ mod tests {
let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot));
editor
.update(cx, |editor, cx| {
editor.set_inline_completion_provider(copilot_provider, cx)
editor.set_inline_completion_provider(Some(copilot_provider), cx)
})
.unwrap();
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "b = 2 + a".into(),
range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 5)),
..Default::default()
@ -788,7 +799,7 @@ mod tests {
handle_copilot_completion_request(
&copilot_lsp,
vec![copilot::request::Completion {
vec![crate::request::Completion {
text: "d = 4 + c".into(),
range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 6)),
..Default::default()
@ -833,7 +844,7 @@ mod tests {
async fn test_copilot_disabled_globs(executor: BackgroundExecutor, cx: &mut TestAppContext) {
init_test(cx, |settings| {
settings
.copilot
.inline_completions
.get_or_insert(Default::default())
.disabled_globs = Some(vec![".env*".to_string()]);
});
@ -888,15 +899,15 @@ mod tests {
let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot));
editor
.update(cx, |editor, cx| {
editor.set_inline_completion_provider(copilot_provider, cx)
editor.set_inline_completion_provider(Some(copilot_provider), cx)
})
.unwrap();
let mut copilot_requests = copilot_lsp
.handle_request::<copilot::request::GetCompletions, _, _>(
.handle_request::<crate::request::GetCompletions, _, _>(
move |_params, _cx| async move {
Ok(copilot::request::GetCompletionsResult {
completions: vec![copilot::request::Completion {
Ok(crate::request::GetCompletionsResult {
completions: vec![crate::request::Completion {
text: "next line".into(),
range: lsp::Range::new(
lsp::Position::new(1, 0),
@ -931,21 +942,21 @@ mod tests {
fn handle_copilot_completion_request(
lsp: &lsp::FakeLanguageServer,
completions: Vec<copilot::request::Completion>,
completions_cycling: Vec<copilot::request::Completion>,
completions: Vec<crate::request::Completion>,
completions_cycling: Vec<crate::request::Completion>,
) {
lsp.handle_request::<copilot::request::GetCompletions, _, _>(move |_params, _cx| {
lsp.handle_request::<crate::request::GetCompletions, _, _>(move |_params, _cx| {
let completions = completions.clone();
async move {
Ok(copilot::request::GetCompletionsResult {
Ok(crate::request::GetCompletionsResult {
completions: completions.clone(),
})
}
});
lsp.handle_request::<copilot::request::GetCompletionsCycling, _, _>(move |_params, _cx| {
lsp.handle_request::<crate::request::GetCompletionsCycling, _, _>(move |_params, _cx| {
let completions_cycling = completions_cycling.clone();
async move {
Ok(copilot::request::GetCompletionsResult {
Ok(crate::request::GetCompletionsResult {
completions: completions_cycling.clone(),
})
}

View File

@ -1,4 +1,4 @@
use copilot::{request::PromptUserDeviceFlow, Copilot, Status};
use crate::{request::PromptUserDeviceFlow, Copilot, Status};
use gpui::{
div, svg, AppContext, ClipboardItem, DismissEvent, Element, EventEmitter, FocusHandle,
FocusableView, InteractiveElement, IntoElement, Model, MouseDownEvent, ParentElement, Render,
@ -26,7 +26,7 @@ impl EventEmitter<DismissEvent> for CopilotCodeVerification {}
impl ModalView for CopilotCodeVerification {}
impl CopilotCodeVerification {
pub(crate) fn new(copilot: &Model<Copilot>, cx: &mut ViewContext<Self>) -> Self {
pub fn new(copilot: &Model<Copilot>, cx: &mut ViewContext<Self>) -> Self {
let status = copilot.read(cx).status();
Self {
status,

View File

@ -1,403 +0,0 @@
use crate::sign_in::CopilotCodeVerification;
use anyhow::Result;
use copilot::{Copilot, SignOut, Status};
use editor::{scroll::Autoscroll, Editor};
use fs::Fs;
use gpui::{
div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement,
Render, Subscription, View, ViewContext, WeakView, WindowContext,
};
use language::{
language_settings::{self, all_language_settings, AllLanguageSettings},
File, Language,
};
use settings::{update_settings_file, Settings, SettingsStore};
use std::{path::Path, sync::Arc};
use util::{paths, ResultExt};
use workspace::notifications::NotificationId;
use workspace::{
create_and_open_local_file,
item::ItemHandle,
ui::{
popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip,
},
StatusItemView, Toast, Workspace,
};
use zed_actions::OpenBrowser;
const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot";
struct CopilotStartingToast;
struct CopilotErrorToast;
pub struct CopilotButton {
editor_subscription: Option<(Subscription, usize)>,
editor_enabled: Option<bool>,
language: Option<Arc<Language>>,
file: Option<Arc<dyn File>>,
fs: Arc<dyn Fs>,
}
impl Render for CopilotButton {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let all_language_settings = all_language_settings(None, cx);
if !all_language_settings.copilot.feature_enabled {
return div();
}
let Some(copilot) = Copilot::global(cx) else {
return div();
};
let status = copilot.read(cx).status();
let enabled = self
.editor_enabled
.unwrap_or_else(|| all_language_settings.copilot_enabled(None, None));
let icon = match status {
Status::Error(_) => IconName::CopilotError,
Status::Authorized => {
if enabled {
IconName::Copilot
} else {
IconName::CopilotDisabled
}
}
_ => IconName::CopilotInit,
};
if let Status::Error(e) = status {
return div().child(
IconButton::new("copilot-error", icon)
.icon_size(IconSize::Small)
.on_click(cx.listener(move |_, _, cx| {
if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
workspace
.update(cx, |workspace, cx| {
workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotErrorToast>(),
format!("Copilot can't be started: {}", e),
)
.on_click(
"Reinstall Copilot",
|cx| {
if let Some(copilot) = Copilot::global(cx) {
copilot
.update(cx, |copilot, cx| {
copilot.reinstall(cx)
})
.detach();
}
},
),
cx,
);
})
.ok();
}
}))
.tooltip(|cx| Tooltip::text("GitHub Copilot", cx)),
);
}
let this = cx.view().clone();
div().child(
popover_menu("copilot")
.menu(move |cx| match status {
Status::Authorized => {
Some(this.update(cx, |this, cx| this.build_copilot_menu(cx)))
}
_ => Some(this.update(cx, |this, cx| this.build_copilot_start_menu(cx))),
})
.anchor(AnchorCorner::BottomRight)
.trigger(
IconButton::new("copilot-icon", icon)
.tooltip(|cx| Tooltip::text("GitHub Copilot", cx)),
),
)
}
}
impl CopilotButton {
pub fn new(fs: Arc<dyn Fs>, cx: &mut ViewContext<Self>) -> Self {
if let Some(copilot) = Copilot::global(cx) {
cx.observe(&copilot, |_, _, cx| cx.notify()).detach()
}
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
.detach();
Self {
editor_subscription: None,
editor_enabled: None,
language: None,
file: None,
fs,
}
}
pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext<Self>) -> View<ContextMenu> {
let fs = self.fs.clone();
ContextMenu::build(cx, |menu, _| {
menu.entry("Sign In", None, initiate_sign_in).entry(
"Disable Copilot",
None,
move |cx| hide_copilot(fs.clone(), cx),
)
})
}
pub fn build_copilot_menu(&mut self, cx: &mut ViewContext<Self>) -> View<ContextMenu> {
let fs = self.fs.clone();
ContextMenu::build(cx, move |mut menu, cx| {
if let Some(language) = self.language.clone() {
let fs = fs.clone();
let language_enabled =
language_settings::language_settings(Some(&language), None, cx)
.show_copilot_suggestions;
menu = menu.entry(
format!(
"{} Suggestions for {}",
if language_enabled { "Hide" } else { "Show" },
language.name()
),
None,
move |cx| toggle_copilot_for_language(language.clone(), fs.clone(), cx),
);
}
let settings = AllLanguageSettings::get_global(cx);
if let Some(file) = &self.file {
let path = file.path().clone();
let path_enabled = settings.copilot_enabled_for_path(&path);
menu = menu.entry(
format!(
"{} Suggestions for This Path",
if path_enabled { "Hide" } else { "Show" }
),
None,
move |cx| {
if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
if let Ok(workspace) = workspace.root_view(cx) {
let workspace = workspace.downgrade();
cx.spawn(|cx| {
configure_disabled_globs(
workspace,
path_enabled.then_some(path.clone()),
cx,
)
})
.detach_and_log_err(cx);
}
}
},
);
}
let globally_enabled = settings.copilot_enabled(None, None);
menu.entry(
if globally_enabled {
"Hide Suggestions for All Files"
} else {
"Show Suggestions for All Files"
},
None,
move |cx| toggle_copilot_globally(fs.clone(), cx),
)
.separator()
.link(
"Copilot Settings",
OpenBrowser {
url: COPILOT_SETTINGS_URL.to_string(),
}
.boxed_clone(),
)
.action("Sign Out", SignOut.boxed_clone())
})
}
pub fn update_enabled(&mut self, editor: View<Editor>, cx: &mut ViewContext<Self>) {
let editor = editor.read(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
let suggestion_anchor = editor.selections.newest_anchor().start;
let language = snapshot.language_at(suggestion_anchor);
let file = snapshot.file_at(suggestion_anchor).cloned();
self.editor_enabled = {
let file = file.as_ref();
Some(
file.map(|file| !file.is_private()).unwrap_or(true)
&& all_language_settings(file, cx)
.copilot_enabled(language, file.map(|file| file.path().as_ref())),
)
};
self.language = language.cloned();
self.file = file;
cx.notify()
}
}
impl StatusItemView for CopilotButton {
fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext<Self>) {
if let Some(editor) = item.and_then(|item| item.act_as::<Editor>(cx)) {
self.editor_subscription = Some((
cx.observe(&editor, Self::update_enabled),
editor.entity_id().as_u64() as usize,
));
self.update_enabled(editor, cx);
} else {
self.language = None;
self.editor_subscription = None;
self.editor_enabled = None;
}
cx.notify();
}
}
async fn configure_disabled_globs(
workspace: WeakView<Workspace>,
path_to_disable: Option<Arc<Path>>,
mut cx: AsyncWindowContext,
) -> Result<()> {
let settings_editor = workspace
.update(&mut cx, |_, cx| {
create_and_open_local_file(&paths::SETTINGS, cx, || {
settings::initial_user_settings_content().as_ref().into()
})
})?
.await?
.downcast::<Editor>()
.unwrap();
settings_editor.downgrade().update(&mut cx, |item, cx| {
let text = item.buffer().read(cx).snapshot(cx).text();
let settings = cx.global::<SettingsStore>();
let edits = settings.edits_for_update::<AllLanguageSettings>(&text, |file| {
let copilot = file.copilot.get_or_insert_with(Default::default);
let globs = copilot.disabled_globs.get_or_insert_with(|| {
settings
.get::<AllLanguageSettings>(None)
.copilot
.disabled_globs
.iter()
.map(|glob| glob.glob().to_string())
.collect()
});
if let Some(path_to_disable) = &path_to_disable {
globs.push(path_to_disable.to_string_lossy().into_owned());
} else {
globs.clear();
}
});
if !edits.is_empty() {
item.change_selections(Some(Autoscroll::newest()), cx, |selections| {
selections.select_ranges(edits.iter().map(|e| e.0.clone()));
});
// When *enabling* a path, don't actually perform an edit, just select the range.
if path_to_disable.is_some() {
item.edit(edits.iter().cloned(), cx);
}
}
})?;
anyhow::Ok(())
}
fn toggle_copilot_globally(fs: Arc<dyn Fs>, cx: &mut AppContext) {
let show_copilot_suggestions = all_language_settings(None, cx).copilot_enabled(None, None);
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
file.defaults.show_copilot_suggestions = Some(!show_copilot_suggestions)
});
}
fn toggle_copilot_for_language(language: Arc<Language>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
let show_copilot_suggestions =
all_language_settings(None, cx).copilot_enabled(Some(&language), None);
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
file.languages
.entry(language.name())
.or_default()
.show_copilot_suggestions = Some(!show_copilot_suggestions);
});
}
fn hide_copilot(fs: Arc<dyn Fs>, cx: &mut AppContext) {
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
file.features.get_or_insert(Default::default()).copilot = Some(false);
});
}
pub fn initiate_sign_in(cx: &mut WindowContext) {
let Some(copilot) = Copilot::global(cx) else {
return;
};
let status = copilot.read(cx).status();
let Some(workspace) = cx.window_handle().downcast::<Workspace>() else {
return;
};
match status {
Status::Starting { task } => {
let Some(workspace) = cx.window_handle().downcast::<Workspace>() else {
return;
};
let Ok(workspace) = workspace.update(cx, |workspace, cx| {
workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStartingToast>(),
"Copilot is starting...",
),
cx,
);
workspace.weak_handle()
}) else {
return;
};
cx.spawn(|mut cx| async move {
task.await;
if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() {
workspace
.update(&mut cx, |workspace, cx| match copilot.read(cx).status() {
Status::Authorized => workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStartingToast>(),
"Copilot has started!",
),
cx,
),
_ => {
workspace.dismiss_toast(
&NotificationId::unique::<CopilotStartingToast>(),
cx,
);
copilot
.update(cx, |copilot, cx| copilot.sign_in(cx))
.detach_and_log_err(cx);
}
})
.log_err();
}
})
.detach();
}
_ => {
copilot.update(cx, |this, cx| this.sign_in(cx)).detach();
workspace
.update(cx, |this, cx| {
this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx));
})
.ok();
}
}
}

View File

@ -1,7 +0,0 @@
pub mod copilot_button;
mod copilot_completion_provider;
mod sign_in;
pub use copilot_button::*;
pub use copilot_completion_provider::*;
pub use sign_in::*;

View File

@ -1757,19 +1757,22 @@ impl Editor {
self.completion_provider = Some(hub);
}
pub fn set_inline_completion_provider(
pub fn set_inline_completion_provider<T>(
&mut self,
provider: Model<impl InlineCompletionProvider>,
provider: Option<Model<T>>,
cx: &mut ViewContext<Self>,
) {
self.inline_completion_provider = Some(RegisteredInlineCompletionProvider {
_subscription: cx.observe(&provider, |this, _, cx| {
if this.focus_handle.is_focused(cx) {
this.update_visible_inline_completion(cx);
}
}),
provider: Arc::new(provider),
});
) where
T: InlineCompletionProvider,
{
self.inline_completion_provider =
provider.map(|provider| RegisteredInlineCompletionProvider {
_subscription: cx.observe(&provider, |this, _, cx| {
if this.focus_handle.is_focused(cx) {
this.update_visible_inline_completion(cx);
}
}),
provider: Arc::new(provider),
});
self.refresh_inline_completion(false, cx);
}
@ -2676,7 +2679,7 @@ impl Editor {
}
drop(snapshot);
let had_active_copilot_completion = this.has_active_inline_completion(cx);
let had_active_inline_completion = this.has_active_inline_completion(cx);
this.change_selections(Some(Autoscroll::fit()), cx, |s| s.select(new_selections));
if brace_inserted {
@ -2692,7 +2695,7 @@ impl Editor {
}
}
if had_active_copilot_completion {
if had_active_inline_completion {
this.refresh_inline_completion(true, cx);
if !this.has_active_inline_completion(cx) {
this.trigger_completion_on_input(&text, cx);
@ -4005,7 +4008,7 @@ impl Editor {
if !self.show_inline_completions
|| !provider.is_enabled(&buffer, cursor_buffer_position, cx)
{
self.clear_inline_completion(cx);
self.discard_inline_completion(cx);
return None;
}
@ -4207,13 +4210,6 @@ impl Editor {
self.discard_inline_completion(cx);
}
fn clear_inline_completion(&mut self, cx: &mut ViewContext<Self>) {
if let Some(old_completion) = self.active_inline_completion.take() {
self.splice_inlays(vec![old_completion.id], Vec::new(), cx);
}
self.discard_inline_completion(cx);
}
fn inline_completion_provider(&self) -> Option<Arc<dyn InlineCompletionProviderHandle>> {
Some(self.inline_completion_provider.as_ref()?.provider.clone())
}
@ -9947,12 +9943,14 @@ impl Editor {
.raw_user_settings()
.get("vim_mode")
== Some(&serde_json::Value::Bool(true));
let copilot_enabled = all_language_settings(file, cx).copilot_enabled(None, None);
let copilot_enabled = all_language_settings(file, cx).inline_completions.provider
== language::language_settings::InlineCompletionProvider::Copilot;
let copilot_enabled_for_language = self
.buffer
.read(cx)
.settings_at(0, cx)
.show_copilot_suggestions;
.show_inline_completions;
let telemetry = project.read(cx).client().telemetry().clone();
telemetry.report_editor_event(

View File

@ -25,11 +25,11 @@ pub trait InlineCompletionProvider: 'static + Sized {
);
fn accept(&mut self, cx: &mut ModelContext<Self>);
fn discard(&mut self, cx: &mut ModelContext<Self>);
fn active_completion_text(
&self,
fn active_completion_text<'a>(
&'a self,
buffer: &Model<Buffer>,
cursor_position: language::Anchor,
cx: &AppContext,
cx: &'a AppContext,
) -> Option<&str>;
}
@ -57,7 +57,7 @@ pub trait InlineCompletionProviderHandle {
fn accept(&self, cx: &mut AppContext);
fn discard(&self, cx: &mut AppContext);
fn active_completion_text<'a>(
&self,
&'a self,
buffer: &Model<Buffer>,
cursor_position: language::Anchor,
cx: &'a AppContext,
@ -110,7 +110,7 @@ where
}
fn active_completion_text<'a>(
&self,
&'a self,
buffer: &Model<Buffer>,
cursor_position: language::Anchor,
cx: &'a AppContext,

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
@ -5,8 +7,8 @@ use util::http::HttpClient;
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content<T: HttpClient>(
client: &T,
pub async fn stream_generate_content(
client: Arc<dyn HttpClient>,
api_url: &str,
api_key: &str,
request: GenerateContentRequest,

View File

@ -1,5 +1,5 @@
[package]
name = "copilot_ui"
name = "inline_completion_button"
version = "0.1.0"
edition = "2021"
publish = false
@ -9,19 +9,18 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
path = "src/copilot_ui.rs"
path = "src/inline_completion_button.rs"
doctest = false
[dependencies]
anyhow.workspace = true
client.workspace = true
copilot.workspace = true
editor.workspace = true
fs.workspace = true
gpui.workspace = true
language.workspace = true
menu.workspace = true
settings.workspace = true
supermaven.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true

View File

@ -0,0 +1,510 @@
use anyhow::Result;
use copilot::{Copilot, CopilotCodeVerification, Status};
use editor::{scroll::Autoscroll, Editor};
use fs::Fs;
use gpui::{
div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement,
Render, Subscription, View, ViewContext, WeakView, WindowContext,
};
use language::{
language_settings::{
self, all_language_settings, AllLanguageSettings, InlineCompletionProvider,
},
File, Language,
};
use settings::{update_settings_file, Settings, SettingsStore};
use std::{path::Path, sync::Arc};
use supermaven::{AccountStatus, Supermaven};
use util::{paths, ResultExt};
use workspace::{
create_and_open_local_file,
item::ItemHandle,
notifications::NotificationId,
ui::{
popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip,
},
StatusItemView, Toast, Workspace,
};
use zed_actions::OpenBrowser;
const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot";
struct CopilotStartingToast;
struct CopilotErrorToast;
pub struct InlineCompletionButton {
editor_subscription: Option<(Subscription, usize)>,
editor_enabled: Option<bool>,
language: Option<Arc<Language>>,
file: Option<Arc<dyn File>>,
fs: Arc<dyn Fs>,
}
enum SupermavenButtonStatus {
Ready,
Errored(String),
NeedsActivation(String),
Initializing,
}
impl Render for InlineCompletionButton {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let all_language_settings = all_language_settings(None, cx);
match all_language_settings.inline_completions.provider {
InlineCompletionProvider::None => return div(),
InlineCompletionProvider::Copilot => {
let Some(copilot) = Copilot::global(cx) else {
return div();
};
let status = copilot.read(cx).status();
let enabled = self.editor_enabled.unwrap_or_else(|| {
all_language_settings.inline_completions_enabled(None, None)
});
let icon = match status {
Status::Error(_) => IconName::CopilotError,
Status::Authorized => {
if enabled {
IconName::Copilot
} else {
IconName::CopilotDisabled
}
}
_ => IconName::CopilotInit,
};
if let Status::Error(e) = status {
return div().child(
IconButton::new("copilot-error", icon)
.icon_size(IconSize::Small)
.on_click(cx.listener(move |_, _, cx| {
if let Some(workspace) = cx.window_handle().downcast::<Workspace>()
{
workspace
.update(cx, |workspace, cx| {
workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotErrorToast>(),
format!("Copilot can't be started: {}", e),
)
.on_click("Reinstall Copilot", |cx| {
if let Some(copilot) = Copilot::global(cx) {
copilot
.update(cx, |copilot, cx| {
copilot.reinstall(cx)
})
.detach();
}
}),
cx,
);
})
.ok();
}
}))
.tooltip(|cx| Tooltip::text("GitHub Copilot", cx)),
);
}
let this = cx.view().clone();
div().child(
popover_menu("copilot")
.menu(move |cx| {
Some(match status {
Status::Authorized => {
this.update(cx, |this, cx| this.build_copilot_context_menu(cx))
}
_ => this.update(cx, |this, cx| this.build_copilot_start_menu(cx)),
})
})
.anchor(AnchorCorner::BottomRight)
.trigger(
IconButton::new("copilot-icon", icon)
.tooltip(|cx| Tooltip::text("GitHub Copilot", cx)),
),
)
}
InlineCompletionProvider::Supermaven => {
let Some(supermaven) = Supermaven::global(cx) else {
return div();
};
let supermaven = supermaven.read(cx);
let status = match supermaven {
Supermaven::Starting => SupermavenButtonStatus::Initializing,
Supermaven::FailedDownload { error } => {
SupermavenButtonStatus::Errored(error.to_string())
}
Supermaven::Spawned(agent) => {
let account_status = agent.account_status.clone();
match account_status {
AccountStatus::NeedsActivation { activate_url } => {
SupermavenButtonStatus::NeedsActivation(activate_url.clone())
}
AccountStatus::Unknown => SupermavenButtonStatus::Initializing,
AccountStatus::Ready => SupermavenButtonStatus::Ready,
}
}
Supermaven::Error { error } => {
SupermavenButtonStatus::Errored(error.to_string())
}
};
let icon = status.to_icon();
let tooltip_text = status.to_tooltip();
let this = cx.view().clone();
return div().child(
popover_menu("supermaven")
.menu(move |cx| match &status {
SupermavenButtonStatus::NeedsActivation(activate_url) => {
Some(ContextMenu::build(cx, |menu, _| {
let activate_url = activate_url.clone();
menu.entry("Sign In", None, move |cx| {
cx.open_url(activate_url.as_str())
})
}))
}
SupermavenButtonStatus::Ready => Some(
this.update(cx, |this, cx| this.build_supermaven_context_menu(cx)),
),
_ => None,
})
.anchor(AnchorCorner::BottomRight)
.trigger(
IconButton::new("supermaven-icon", icon)
.tooltip(move |cx| Tooltip::text(tooltip_text.clone(), cx)),
),
);
}
}
}
}
impl InlineCompletionButton {
pub fn new(fs: Arc<dyn Fs>, cx: &mut ViewContext<Self>) -> Self {
if let Some(copilot) = Copilot::global(cx) {
cx.observe(&copilot, |_, _, cx| cx.notify()).detach()
}
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
.detach();
Self {
editor_subscription: None,
editor_enabled: None,
language: None,
file: None,
fs,
}
}
pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext<Self>) -> View<ContextMenu> {
let fs = self.fs.clone();
ContextMenu::build(cx, |menu, _| {
menu.entry("Sign In", None, initiate_sign_in).entry(
"Disable Copilot",
None,
move |cx| hide_copilot(fs.clone(), cx),
)
})
}
pub fn build_language_settings_menu(
&self,
mut menu: ContextMenu,
cx: &mut WindowContext,
) -> ContextMenu {
let fs = self.fs.clone();
if let Some(language) = self.language.clone() {
let fs = fs.clone();
let language_enabled = language_settings::language_settings(Some(&language), None, cx)
.show_inline_completions;
menu = menu.entry(
format!(
"{} Inline Completions for {}",
if language_enabled { "Hide" } else { "Show" },
language.name()
),
None,
move |cx| toggle_inline_completions_for_language(language.clone(), fs.clone(), cx),
);
}
let settings = AllLanguageSettings::get_global(cx);
if let Some(file) = &self.file {
let path = file.path().clone();
let path_enabled = settings.inline_completions_enabled_for_path(&path);
menu = menu.entry(
format!(
"{} Inline Completions for This Path",
if path_enabled { "Hide" } else { "Show" }
),
None,
move |cx| {
if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
if let Ok(workspace) = workspace.root_view(cx) {
let workspace = workspace.downgrade();
cx.spawn(|cx| {
configure_disabled_globs(
workspace,
path_enabled.then_some(path.clone()),
cx,
)
})
.detach_and_log_err(cx);
}
}
},
);
}
let globally_enabled = settings.inline_completions_enabled(None, None);
menu.entry(
if globally_enabled {
"Hide Inline Completions for All Files"
} else {
"Show Inline Completions for All Files"
},
None,
move |cx| toggle_inline_completions_globally(fs.clone(), cx),
)
}
fn build_copilot_context_menu(&self, cx: &mut ViewContext<Self>) -> View<ContextMenu> {
ContextMenu::build(cx, |menu, cx| {
self.build_language_settings_menu(menu, cx)
.separator()
.link(
"Copilot Settings",
OpenBrowser {
url: COPILOT_SETTINGS_URL.to_string(),
}
.boxed_clone(),
)
.action("Sign Out", copilot::SignOut.boxed_clone())
})
}
fn build_supermaven_context_menu(&self, cx: &mut ViewContext<Self>) -> View<ContextMenu> {
ContextMenu::build(cx, |menu, cx| {
self.build_language_settings_menu(menu, cx).separator()
})
}
pub fn update_enabled(&mut self, editor: View<Editor>, cx: &mut ViewContext<Self>) {
let editor = editor.read(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
let suggestion_anchor = editor.selections.newest_anchor().start;
let language = snapshot.language_at(suggestion_anchor);
let file = snapshot.file_at(suggestion_anchor).cloned();
self.editor_enabled = {
let file = file.as_ref();
Some(
file.map(|file| !file.is_private()).unwrap_or(true)
&& all_language_settings(file, cx).inline_completions_enabled(
language,
file.map(|file| file.path().as_ref()),
),
)
};
self.language = language.cloned();
self.file = file;
cx.notify()
}
}
impl StatusItemView for InlineCompletionButton {
fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext<Self>) {
if let Some(editor) = item.and_then(|item| item.act_as::<Editor>(cx)) {
self.editor_subscription = Some((
cx.observe(&editor, Self::update_enabled),
editor.entity_id().as_u64() as usize,
));
self.update_enabled(editor, cx);
} else {
self.language = None;
self.editor_subscription = None;
self.editor_enabled = None;
}
cx.notify();
}
}
impl SupermavenButtonStatus {
fn to_icon(&self) -> IconName {
match self {
SupermavenButtonStatus::Ready => IconName::Supermaven,
SupermavenButtonStatus::Errored(_) => IconName::SupermavenError,
SupermavenButtonStatus::NeedsActivation(_) => IconName::SupermavenInit,
SupermavenButtonStatus::Initializing => IconName::SupermavenInit,
}
}
fn to_tooltip(&self) -> String {
match self {
SupermavenButtonStatus::Ready => "Supermaven is ready".to_string(),
SupermavenButtonStatus::Errored(error) => format!("Supermaven error: {}", error),
SupermavenButtonStatus::NeedsActivation(_) => "Supermaven needs activation".to_string(),
SupermavenButtonStatus::Initializing => "Supermaven initializing".to_string(),
}
}
}
async fn configure_disabled_globs(
workspace: WeakView<Workspace>,
path_to_disable: Option<Arc<Path>>,
mut cx: AsyncWindowContext,
) -> Result<()> {
let settings_editor = workspace
.update(&mut cx, |_, cx| {
create_and_open_local_file(&paths::SETTINGS, cx, || {
settings::initial_user_settings_content().as_ref().into()
})
})?
.await?
.downcast::<Editor>()
.unwrap();
settings_editor.downgrade().update(&mut cx, |item, cx| {
let text = item.buffer().read(cx).snapshot(cx).text();
let settings = cx.global::<SettingsStore>();
let edits = settings.edits_for_update::<AllLanguageSettings>(&text, |file| {
let copilot = file.inline_completions.get_or_insert_with(Default::default);
let globs = copilot.disabled_globs.get_or_insert_with(|| {
settings
.get::<AllLanguageSettings>(None)
.inline_completions
.disabled_globs
.iter()
.map(|glob| glob.glob().to_string())
.collect()
});
if let Some(path_to_disable) = &path_to_disable {
globs.push(path_to_disable.to_string_lossy().into_owned());
} else {
globs.clear();
}
});
if !edits.is_empty() {
item.change_selections(Some(Autoscroll::newest()), cx, |selections| {
selections.select_ranges(edits.iter().map(|e| e.0.clone()));
});
// When *enabling* a path, don't actually perform an edit, just select the range.
if path_to_disable.is_some() {
item.edit(edits.iter().cloned(), cx);
}
}
})?;
anyhow::Ok(())
}
fn toggle_inline_completions_globally(fs: Arc<dyn Fs>, cx: &mut AppContext) {
let show_inline_completions =
all_language_settings(None, cx).inline_completions_enabled(None, None);
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
file.defaults.show_inline_completions = Some(!show_inline_completions)
});
}
fn toggle_inline_completions_for_language(
language: Arc<Language>,
fs: Arc<dyn Fs>,
cx: &mut AppContext,
) {
let show_inline_completions =
all_language_settings(None, cx).inline_completions_enabled(Some(&language), None);
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
file.languages
.entry(language.name())
.or_default()
.show_inline_completions = Some(!show_inline_completions);
});
}
fn hide_copilot(fs: Arc<dyn Fs>, cx: &mut AppContext) {
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
file.features.get_or_insert(Default::default()).copilot = Some(false);
});
}
pub fn initiate_sign_in(cx: &mut WindowContext) {
let Some(copilot) = Copilot::global(cx) else {
return;
};
let status = copilot.read(cx).status();
let Some(workspace) = cx.window_handle().downcast::<Workspace>() else {
return;
};
match status {
Status::Starting { task } => {
let Some(workspace) = cx.window_handle().downcast::<Workspace>() else {
return;
};
let Ok(workspace) = workspace.update(cx, |workspace, cx| {
workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStartingToast>(),
"Copilot is starting...",
),
cx,
);
workspace.weak_handle()
}) else {
return;
};
cx.spawn(|mut cx| async move {
task.await;
if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() {
workspace
.update(&mut cx, |workspace, cx| match copilot.read(cx).status() {
Status::Authorized => workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStartingToast>(),
"Copilot has started!",
),
cx,
),
_ => {
workspace.dismiss_toast(
&NotificationId::unique::<CopilotStartingToast>(),
cx,
);
copilot
.update(cx, |copilot, cx| copilot.sign_in(cx))
.detach_and_log_err(cx);
}
})
.log_err();
}
})
.detach();
}
_ => {
copilot.update(cx, |this, cx| this.sign_in(cx)).detach();
workspace
.update(cx, |this, cx| {
this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx));
})
.ok();
}
}
}

View File

@ -51,8 +51,8 @@ pub fn all_language_settings<'a>(
/// The settings for all languages.
#[derive(Debug, Clone)]
pub struct AllLanguageSettings {
/// The settings for GitHub Copilot.
pub copilot: CopilotSettings,
/// The inline completion settings.
pub inline_completions: InlineCompletionSettings,
defaults: LanguageSettings,
languages: HashMap<Arc<str>, LanguageSettings>,
pub(crate) file_types: HashMap<Arc<str>, Vec<String>>,
@ -101,9 +101,9 @@ pub struct LanguageSettings {
/// - `"!<language_server_id>"` - A language server ID prefixed with a `!` will be disabled.
/// - `"..."` - A placeholder to refer to the **rest** of the registered language servers for this language.
pub language_servers: Vec<Arc<str>>,
/// Controls whether Copilot provides suggestion immediately (true)
/// or waits for a `copilot::Toggle` (false).
pub show_copilot_suggestions: bool,
/// Controls whether inline completions are shown immediately (true)
/// or manually by triggering `editor::ShowInlineCompletion` (false).
pub show_inline_completions: bool,
/// Whether to show tabs and spaces in the editor.
pub show_whitespaces: ShowWhitespaceSetting,
/// Whether to start a new line with a comment when a previous line is a comment as well.
@ -165,12 +165,23 @@ impl LanguageSettings {
}
}
/// The settings for [GitHub Copilot](https://github.com/features/copilot).
/// The provider that supplies inline completions.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum InlineCompletionProvider {
None,
#[default]
Copilot,
Supermaven,
}
/// The settings for inline completions, such as [GitHub Copilot](https://github.com/features/copilot)
/// or [Supermaven](https://supermaven.com).
#[derive(Clone, Debug, Default)]
pub struct CopilotSettings {
/// Whether Copilot is enabled.
pub feature_enabled: bool,
/// A list of globs representing files that Copilot should be disabled for.
pub struct InlineCompletionSettings {
/// The provider that supplies inline completions.
pub provider: InlineCompletionProvider,
/// A list of globs representing files that inline completions should be disabled for.
pub disabled_globs: Vec<GlobMatcher>,
}
@ -180,9 +191,9 @@ pub struct AllLanguageSettingsContent {
/// The settings for enabling/disabling features.
#[serde(default)]
pub features: Option<FeaturesContent>,
/// The settings for GitHub Copilot.
#[serde(default)]
pub copilot: Option<CopilotSettingsContent>,
/// The inline completion settings.
#[serde(default, alias = "copilot")]
pub inline_completions: Option<InlineCompletionSettingsContent>,
/// The default language settings.
#[serde(flatten)]
pub defaults: LanguageSettingsContent,
@ -277,12 +288,12 @@ pub struct LanguageSettingsContent {
/// Default: ["..."]
#[serde(default)]
pub language_servers: Option<Vec<Arc<str>>>,
/// Controls whether Copilot provides suggestion immediately (true)
/// or waits for a `copilot::Toggle` (false).
/// Controls whether inline completions are shown immediately (true)
/// or manually by triggering `editor::ShowInlineCompletion` (false).
///
/// Default: true
#[serde(default)]
pub show_copilot_suggestions: Option<bool>,
#[serde(default, alias = "show_copilot_suggestions")]
pub show_inline_completions: Option<bool>,
/// Whether to show tabs and spaces in the editor.
#[serde(default)]
pub show_whitespaces: Option<ShowWhitespaceSetting>,
@ -314,10 +325,10 @@ pub struct LanguageSettingsContent {
pub code_actions_on_format: Option<HashMap<String, bool>>,
}
/// The contents of the GitHub Copilot settings.
#[derive(Clone, Debug, PartialEq, Default, Serialize, Deserialize, JsonSchema)]
pub struct CopilotSettingsContent {
/// A list of globs representing files that Copilot should be disabled for.
/// The contents of the inline completion settings.
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct InlineCompletionSettingsContent {
/// A list of globs representing files that inline completions should be disabled for.
#[serde(default)]
pub disabled_globs: Option<Vec<String>>,
}
@ -328,6 +339,8 @@ pub struct CopilotSettingsContent {
pub struct FeaturesContent {
/// Whether the GitHub Copilot feature is enabled.
pub copilot: Option<bool>,
/// Determines which inline completion provider to use.
pub inline_completion_provider: Option<InlineCompletionProvider>,
}
/// Controls the soft-wrapping behavior in the editor.
@ -475,29 +488,29 @@ impl AllLanguageSettings {
&self.defaults
}
/// Returns whether GitHub Copilot is enabled for the given path.
pub fn copilot_enabled_for_path(&self, path: &Path) -> bool {
/// Returns whether inline completions are enabled for the given path.
pub fn inline_completions_enabled_for_path(&self, path: &Path) -> bool {
!self
.copilot
.inline_completions
.disabled_globs
.iter()
.any(|glob| glob.is_match(path))
}
/// Returns whether GitHub Copilot is enabled for the given language and path.
pub fn copilot_enabled(&self, language: Option<&Arc<Language>>, path: Option<&Path>) -> bool {
if !self.copilot.feature_enabled {
return false;
}
/// Returns whether inline completions are enabled for the given language and path.
pub fn inline_completions_enabled(
&self,
language: Option<&Arc<Language>>,
path: Option<&Path>,
) -> bool {
if let Some(path) = path {
if !self.copilot_enabled_for_path(path) {
if !self.inline_completions_enabled_for_path(path) {
return false;
}
}
self.language(language.map(|l| l.name()).as_deref())
.show_copilot_suggestions
.show_inline_completions
}
}
@ -551,13 +564,13 @@ impl settings::Settings for AllLanguageSettings {
languages.insert(language_name.clone(), language_settings);
}
let mut copilot_enabled = default_value
let mut copilot_enabled = default_value.features.as_ref().and_then(|f| f.copilot);
let mut inline_completion_provider = default_value
.features
.as_ref()
.and_then(|f| f.copilot)
.ok_or_else(Self::missing_default)?;
let mut copilot_globs = default_value
.copilot
.and_then(|f| f.inline_completion_provider);
let mut completion_globs = default_value
.inline_completions
.as_ref()
.and_then(|c| c.disabled_globs.as_ref())
.ok_or_else(Self::missing_default)?;
@ -565,14 +578,21 @@ impl settings::Settings for AllLanguageSettings {
let mut file_types: HashMap<Arc<str>, Vec<String>> = HashMap::default();
for user_settings in sources.customizations() {
if let Some(copilot) = user_settings.features.as_ref().and_then(|f| f.copilot) {
copilot_enabled = copilot;
copilot_enabled = Some(copilot);
}
if let Some(provider) = user_settings
.features
.as_ref()
.and_then(|f| f.inline_completion_provider)
{
inline_completion_provider = Some(provider);
}
if let Some(globs) = user_settings
.copilot
.inline_completions
.as_ref()
.and_then(|f| f.disabled_globs.as_ref())
{
copilot_globs = globs;
completion_globs = globs;
}
// A user's global settings override the default global settings and
@ -601,9 +621,15 @@ impl settings::Settings for AllLanguageSettings {
}
Ok(Self {
copilot: CopilotSettings {
feature_enabled: copilot_enabled,
disabled_globs: copilot_globs
inline_completions: InlineCompletionSettings {
provider: if let Some(provider) = inline_completion_provider {
provider
} else if copilot_enabled.unwrap_or(true) {
InlineCompletionProvider::Copilot
} else {
InlineCompletionProvider::None
},
disabled_globs: completion_globs
.iter()
.filter_map(|g| Some(globset::Glob::new(g).ok()?.compile_matcher()))
.collect(),
@ -714,8 +740,8 @@ fn merge_settings(settings: &mut LanguageSettings, src: &LanguageSettingsContent
);
merge(&mut settings.language_servers, src.language_servers.clone());
merge(
&mut settings.show_copilot_suggestions,
src.show_copilot_suggestions,
&mut settings.show_inline_completions,
src.show_inline_completions,
);
merge(&mut settings.show_whitespaces, src.show_whitespaces);
merge(

View File

@ -15,6 +15,7 @@ doctest = false
[dependencies]
anyhow.workspace = true
collections.workspace = true
copilot.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
@ -26,7 +27,6 @@ settings.workspace = true
theme.workspace = true
tree-sitter.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
[dev-dependencies]

View File

@ -1,4 +1,5 @@
use collections::{HashMap, VecDeque};
use copilot::Copilot;
use editor::{actions::MoveToEnd, Editor, EditorEvent};
use futures::{channel::mpsc, StreamExt};
use gpui::{
@ -7,11 +8,10 @@ use gpui::{
View, ViewContext, VisualContext, WeakModel, WindowContext,
};
use language::{LanguageServerId, LanguageServerName};
use lsp::IoKind;
use lsp::{IoKind, LanguageServer};
use project::{search::SearchQuery, Project};
use std::{borrow::Cow, sync::Arc};
use ui::{popover_menu, prelude::*, Button, Checkbox, ContextMenu, Label, Selection};
use util::maybe;
use workspace::{
item::{Item, ItemHandle, TabContentParams},
searchable::{SearchEvent, SearchableItem, SearchableItemHandle},
@ -24,17 +24,21 @@ const MAX_STORED_LOG_ENTRIES: usize = 2000;
pub struct LogStore {
projects: HashMap<WeakModel<Project>, ProjectState>,
io_tx: mpsc::UnboundedSender<(WeakModel<Project>, LanguageServerId, IoKind, String)>,
language_servers: HashMap<LanguageServerId, LanguageServerState>,
copilot_log_subscription: Option<lsp::Subscription>,
_copilot_subscription: Option<gpui::Subscription>,
io_tx: mpsc::UnboundedSender<(LanguageServerId, IoKind, String)>,
}
struct ProjectState {
servers: HashMap<LanguageServerId, LanguageServerState>,
_subscriptions: [gpui::Subscription; 2],
}
struct LanguageServerState {
name: LanguageServerName,
log_messages: VecDeque<String>,
rpc_state: Option<LanguageServerRpcState>,
project: Option<WeakModel<Project>>,
_io_logs_subscription: Option<lsp::Subscription>,
_lsp_logs_subscription: Option<lsp::Subscription>,
}
@ -109,15 +113,55 @@ pub fn init(cx: &mut AppContext) {
impl LogStore {
pub fn new(cx: &mut ModelContext<Self>) -> Self {
let (io_tx, mut io_rx) = mpsc::unbounded();
let copilot_subscription = Copilot::global(cx).map(|copilot| {
let copilot = &copilot;
cx.subscribe(
copilot,
|this, copilot, copilot_event, cx| match copilot_event {
copilot::Event::CopilotLanguageServerStarted => {
if let Some(server) = copilot.read(cx).language_server() {
let server_id = server.server_id();
let weak_this = cx.weak_model();
this.copilot_log_subscription =
Some(server.on_notification::<copilot::request::LogMessage, _>(
move |params, mut cx| {
weak_this
.update(&mut cx, |this, cx| {
this.add_language_server_log(
server_id,
&params.message,
cx,
);
})
.ok();
},
));
this.add_language_server(
None,
LanguageServerName(Arc::from("copilot")),
server.clone(),
cx,
);
}
}
},
)
});
let this = Self {
copilot_log_subscription: None,
_copilot_subscription: copilot_subscription,
projects: HashMap::default(),
language_servers: HashMap::default(),
io_tx,
};
cx.spawn(|this, mut cx| async move {
while let Some((project, server_id, io_kind, message)) = io_rx.next().await {
while let Some((server_id, io_kind, message)) = io_rx.next().await {
if let Some(this) = this.upgrade() {
this.update(&mut cx, |this, cx| {
this.on_io(project, server_id, io_kind, &message, cx);
this.on_io(server_id, io_kind, &message, cx);
})?;
}
}
@ -132,20 +176,32 @@ impl LogStore {
self.projects.insert(
project.downgrade(),
ProjectState {
servers: HashMap::default(),
_subscriptions: [
cx.observe_release(project, move |this, _, _| {
this.projects.remove(&weak_project);
this.language_servers
.retain(|_, state| state.project.as_ref() != Some(&weak_project));
}),
cx.subscribe(project, |this, project, event, cx| match event {
project::Event::LanguageServerAdded(id) => {
this.add_language_server(&project, *id, cx);
let read_project = project.read(cx);
if let Some((server, adapter)) = read_project
.language_server_for_id(*id)
.zip(read_project.language_server_adapter_for_id(*id))
{
this.add_language_server(
Some(&project.downgrade()),
adapter.name.clone(),
server,
cx,
);
}
}
project::Event::LanguageServerRemoved(id) => {
this.remove_language_server(&project, *id, cx);
this.remove_language_server(*id, cx);
}
project::Event::LanguageServerLog(id, message) => {
this.add_language_server_log(&project, *id, message, cx);
this.add_language_server_log(*id, message, cx);
}
_ => {}
}),
@ -154,74 +210,69 @@ impl LogStore {
);
}
fn get_language_server_state(
&mut self,
id: LanguageServerId,
) -> Option<&mut LanguageServerState> {
self.language_servers.get_mut(&id)
}
fn add_language_server(
&mut self,
project: &Model<Project>,
id: LanguageServerId,
project: Option<&WeakModel<Project>>,
name: LanguageServerName,
server: Arc<LanguageServer>,
cx: &mut ModelContext<Self>,
) -> Option<&mut LanguageServerState> {
let project_state = self.projects.get_mut(&project.downgrade())?;
let server_state = project_state.servers.entry(id).or_insert_with(|| {
cx.notify();
LanguageServerState {
rpc_state: None,
log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES),
_io_logs_subscription: None,
_lsp_logs_subscription: None,
}
});
let server_state = self
.language_servers
.entry(server.server_id())
.or_insert_with(|| {
cx.notify();
LanguageServerState {
name,
rpc_state: None,
project: project.cloned(),
log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES),
_io_logs_subscription: None,
_lsp_logs_subscription: None,
}
});
let server = project.read(cx).language_server_for_id(id);
if let Some(server) = server.as_deref() {
if server.has_notification_handler::<lsp::notification::LogMessage>() {
// Another event wants to re-add the server that was already added and subscribed to, avoid doing it again.
return Some(server_state);
}
if server.has_notification_handler::<lsp::notification::LogMessage>() {
// Another event wants to re-add the server that was already added and subscribed to, avoid doing it again.
return Some(server_state);
}
let weak_project = project.downgrade();
let io_tx = self.io_tx.clone();
server_state._io_logs_subscription = server.as_ref().map(|server| {
server.on_io(move |io_kind, message| {
io_tx
.unbounded_send((weak_project.clone(), id, io_kind, message.to_string()))
.ok();
})
});
let server_id = server.server_id();
server_state._io_logs_subscription = Some(server.on_io(move |io_kind, message| {
io_tx
.unbounded_send((server_id, io_kind, message.to_string()))
.ok();
}));
let this = cx.handle().downgrade();
let weak_project = project.downgrade();
server_state._lsp_logs_subscription = server.map(|server| {
let server_id = server.server_id();
server.on_notification::<lsp::notification::LogMessage, _>({
server_state._lsp_logs_subscription =
Some(server.on_notification::<lsp::notification::LogMessage, _>({
move |params, mut cx| {
if let Some((project, this)) = weak_project.upgrade().zip(this.upgrade()) {
if let Some(this) = this.upgrade() {
this.update(&mut cx, |this, cx| {
this.add_language_server_log(&project, server_id, &params.message, cx);
this.add_language_server_log(server_id, &params.message, cx);
})
.ok();
}
}
})
});
}));
Some(server_state)
}
fn add_language_server_log(
&mut self,
project: &Model<Project>,
id: LanguageServerId,
message: &str,
cx: &mut ModelContext<Self>,
) -> Option<()> {
let language_server_state = match self
.projects
.get_mut(&project.downgrade())?
.servers
.get_mut(&id)
{
Some(existing_state) => existing_state,
None => self.add_language_server(&project, id, cx)?,
};
let language_server_state = self.get_language_server_state(id)?;
let log_lines = &mut language_server_state.log_messages;
while log_lines.len() >= MAX_STORED_LOG_ENTRIES {
@ -238,38 +289,43 @@ impl LogStore {
Some(())
}
fn remove_language_server(
&mut self,
project: &Model<Project>,
id: LanguageServerId,
cx: &mut ModelContext<Self>,
) -> Option<()> {
let project_state = self.projects.get_mut(&project.downgrade())?;
project_state.servers.remove(&id);
fn remove_language_server(&mut self, id: LanguageServerId, cx: &mut ModelContext<Self>) {
self.language_servers.remove(&id);
cx.notify();
Some(())
}
fn server_logs(
&self,
project: &Model<Project>,
server_id: LanguageServerId,
) -> Option<&VecDeque<String>> {
let weak_project = project.downgrade();
let project_state = self.projects.get(&weak_project)?;
let server_state = project_state.servers.get(&server_id)?;
Some(&server_state.log_messages)
fn server_logs(&self, server_id: LanguageServerId) -> Option<&VecDeque<String>> {
Some(&self.language_servers.get(&server_id)?.log_messages)
}
fn server_ids_for_project<'a>(
&'a self,
project: &'a WeakModel<Project>,
) -> impl Iterator<Item = LanguageServerId> + 'a {
[].into_iter()
.chain(self.language_servers.iter().filter_map(|(id, state)| {
if state.project.as_ref() == Some(project) {
return Some(*id);
} else {
None
}
}))
.chain(self.language_servers.iter().filter_map(|(id, state)| {
if state.project.is_none() {
return Some(*id);
} else {
None
}
}))
}
fn enable_rpc_trace_for_language_server(
&mut self,
project: &Model<Project>,
server_id: LanguageServerId,
) -> Option<&mut LanguageServerRpcState> {
let weak_project = project.downgrade();
let project_state = self.projects.get_mut(&weak_project)?;
let server_state = project_state.servers.get_mut(&server_id)?;
let rpc_state = server_state
let rpc_state = self
.language_servers
.get_mut(&server_id)?
.rpc_state
.get_or_insert_with(|| LanguageServerRpcState {
rpc_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES),
@ -280,20 +336,14 @@ impl LogStore {
pub fn disable_rpc_trace_for_language_server(
&mut self,
project: &Model<Project>,
server_id: LanguageServerId,
_: &mut ModelContext<Self>,
) -> Option<()> {
let project = project.downgrade();
let project_state = self.projects.get_mut(&project)?;
let server_state = project_state.servers.get_mut(&server_id)?;
server_state.rpc_state.take();
self.language_servers.get_mut(&server_id)?.rpc_state.take();
Some(())
}
fn on_io(
&mut self,
project: WeakModel<Project>,
language_server_id: LanguageServerId,
io_kind: IoKind,
message: &str,
@ -303,18 +353,14 @@ impl LogStore {
IoKind::StdOut => true,
IoKind::StdIn => false,
IoKind::StdErr => {
let project = project.upgrade()?;
let message = format!("stderr: {}", message.trim());
self.add_language_server_log(&project, language_server_id, &message, cx);
self.add_language_server_log(language_server_id, &message, cx);
return Some(());
}
};
let state = self
.projects
.get_mut(&project)?
.servers
.get_mut(&language_server_id)?
.get_language_server_state(language_server_id)?
.rpc_state
.as_mut()?;
let kind = if is_received {
@ -360,42 +406,40 @@ impl LspLogView {
) -> Self {
let server_id = log_store
.read(cx)
.projects
.get(&project.downgrade())
.and_then(|project| project.servers.keys().copied().next());
let model_changes_subscription = cx.observe(&log_store, |this, store, cx| {
maybe!({
let project_state = store.read(cx).projects.get(&this.project.downgrade())?;
if let Some(current_lsp) = this.current_server_id {
if !project_state.servers.contains_key(&current_lsp) {
if let Some(server) = project_state.servers.iter().next() {
if this.is_showing_rpc_trace {
this.show_rpc_trace_for_server(*server.0, cx)
} else {
this.show_logs_for_server(*server.0, cx)
}
} else {
this.current_server_id = None;
this.editor.update(cx, |editor, cx| {
editor.set_read_only(false);
editor.clear(cx);
editor.set_read_only(true);
});
cx.notify();
}
}
} else {
if let Some(server) = project_state.servers.iter().next() {
.language_servers
.iter()
.find(|(_, server)| server.project == Some(project.downgrade()))
.map(|(id, _)| *id);
let weak_project = project.downgrade();
let model_changes_subscription = cx.observe(&log_store, move |this, store, cx| {
let first_server_id_for_project =
store.read(cx).server_ids_for_project(&weak_project).next();
if let Some(current_lsp) = this.current_server_id {
if !store.read(cx).language_servers.contains_key(&current_lsp) {
if let Some(server_id) = first_server_id_for_project {
if this.is_showing_rpc_trace {
this.show_rpc_trace_for_server(*server.0, cx)
this.show_rpc_trace_for_server(server_id, cx)
} else {
this.show_logs_for_server(*server.0, cx)
this.show_logs_for_server(server_id, cx)
}
} else {
this.current_server_id = None;
this.editor.update(cx, |editor, cx| {
editor.set_read_only(false);
editor.clear(cx);
editor.set_read_only(true);
});
cx.notify();
}
}
Some(())
});
} else if let Some(server_id) = first_server_id_for_project {
if this.is_showing_rpc_trace {
this.show_rpc_trace_for_server(server_id, cx)
} else {
this.show_logs_for_server(server_id, cx)
}
}
cx.notify();
});
@ -477,14 +521,14 @@ impl LspLogView {
pub(crate) fn menu_items<'a>(&'a self, cx: &'a AppContext) -> Option<Vec<LogMenuItem>> {
let log_store = self.log_store.read(cx);
let state = log_store.projects.get(&self.project.downgrade())?;
let mut rows = self
.project
.read(cx)
.language_servers()
.filter_map(|(server_id, language_server_name, worktree_id)| {
let worktree = self.project.read(cx).worktree_for_id(worktree_id, cx)?;
let state = state.servers.get(&server_id)?;
let state = log_store.language_servers.get(&server_id)?;
Some(LogMenuItem {
server_id,
server_name: language_server_name,
@ -501,7 +545,7 @@ impl LspLogView {
.read(cx)
.supplementary_language_servers()
.filter_map(|(&server_id, (name, _))| {
let state = state.servers.get(&server_id)?;
let state = log_store.language_servers.get(&server_id)?;
Some(LogMenuItem {
server_id,
server_name: name.clone(),
@ -514,6 +558,27 @@ impl LspLogView {
})
}),
)
.chain(
log_store
.language_servers
.iter()
.filter_map(|(server_id, state)| {
if state.project.is_none() {
Some(LogMenuItem {
server_id: *server_id,
server_name: state.name.clone(),
worktree_root_name: "supplementary".to_string(),
rpc_trace_enabled: state.rpc_state.is_some(),
rpc_trace_selected: self.is_showing_rpc_trace
&& self.current_server_id == Some(*server_id),
logs_selected: !self.is_showing_rpc_trace
&& self.current_server_id == Some(*server_id),
})
} else {
None
}
}),
)
.collect::<Vec<_>>();
rows.sort_by_key(|row| row.server_id);
rows.dedup_by_key(|row| row.server_id);
@ -524,7 +589,7 @@ impl LspLogView {
let log_contents = self
.log_store
.read(cx)
.server_logs(&self.project, server_id)
.server_logs(server_id)
.map(log_contents);
if let Some(log_contents) = log_contents {
self.current_server_id = Some(server_id);
@ -544,7 +609,7 @@ impl LspLogView {
) {
let rpc_log = self.log_store.update(cx, |log_store, _| {
log_store
.enable_rpc_trace_for_language_server(&self.project, server_id)
.enable_rpc_trace_for_language_server(server_id)
.map(|state| log_contents(&state.rpc_messages))
});
if let Some(rpc_log) = rpc_log {
@ -585,11 +650,11 @@ impl LspLogView {
enabled: bool,
cx: &mut ViewContext<Self>,
) {
self.log_store.update(cx, |log_store, cx| {
self.log_store.update(cx, |log_store, _| {
if enabled {
log_store.enable_rpc_trace_for_language_server(&self.project, server_id);
log_store.enable_rpc_trace_for_language_server(server_id);
} else {
log_store.disable_rpc_trace_for_language_server(&self.project, server_id, cx);
log_store.disable_rpc_trace_for_language_server(server_id);
}
});
if !enabled && Some(server_id) == self.current_server_id {

View File

@ -30,7 +30,6 @@ async-trait.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
copilot.workspace = true
fs.workspace = true
futures.workspace = true
fuzzy.workspace = true

View File

@ -20,7 +20,6 @@ use client::{
};
use clock::ReplicaId;
use collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque};
use copilot::Copilot;
use debounced_delay::DebouncedDelay;
use futures::{
channel::{
@ -200,8 +199,6 @@ pub struct Project {
_maintain_buffer_languages: Task<()>,
_maintain_workspace_config: Task<Result<()>>,
terminals: Terminals,
copilot_lsp_subscription: Option<gpui::Subscription>,
copilot_log_subscription: Option<lsp::Subscription>,
current_lsp_settings: HashMap<Arc<str>, LspSettings>,
node: Option<Arc<dyn NodeRuntime>>,
default_prettier: DefaultPrettier,
@ -685,8 +682,6 @@ impl Project {
let (tx, rx) = mpsc::unbounded();
cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx))
.detach();
let copilot_lsp_subscription =
Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx));
let tasks = Inventory::new(cx);
Self {
@ -735,8 +730,6 @@ impl Project {
terminals: Terminals {
local_handles: Vec::new(),
},
copilot_lsp_subscription,
copilot_log_subscription: None,
current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(),
node: Some(node),
default_prettier: DefaultPrettier::default(),
@ -823,8 +816,6 @@ impl Project {
let (tx, rx) = mpsc::unbounded();
cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx))
.detach();
let copilot_lsp_subscription =
Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx));
let mut this = Self {
worktrees: Vec::new(),
buffer_ordered_messages_tx: tx,
@ -891,8 +882,6 @@ impl Project {
terminals: Terminals {
local_handles: Vec::new(),
},
copilot_lsp_subscription,
copilot_log_subscription: None,
current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(),
node: None,
default_prettier: DefaultPrettier::default(),
@ -1184,17 +1173,6 @@ impl Project {
self.restart_language_servers(worktree, language, cx);
}
if self.copilot_lsp_subscription.is_none() {
if let Some(copilot) = Copilot::global(cx) {
for buffer in self.opened_buffers.values() {
if let Some(buffer) = buffer.upgrade() {
self.register_buffer_with_copilot(&buffer, cx);
}
}
self.copilot_lsp_subscription = Some(subscribe_for_copilot_events(&copilot, cx));
}
}
cx.notify();
}
@ -2351,7 +2329,7 @@ impl Project {
self.detect_language_for_buffer(buffer, cx);
self.register_buffer_with_language_servers(buffer, cx);
self.register_buffer_with_copilot(buffer, cx);
// self.register_buffer_with_copilot(buffer, cx);
cx.observe_release(buffer, |this, buffer, cx| {
if let Some(file) = File::from_dyn(buffer.file()) {
if file.is_local() {
@ -2500,15 +2478,15 @@ impl Project {
});
}
fn register_buffer_with_copilot(
&self,
buffer_handle: &Model<Buffer>,
cx: &mut ModelContext<Self>,
) {
if let Some(copilot) = Copilot::global(cx) {
copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx));
}
}
// fn register_buffer_with_copilot(
// &self,
// buffer_handle: &Model<Buffer>,
// cx: &mut ModelContext<Self>,
// ) {
// if let Some(copilot) = Copilot::global(cx) {
// copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx));
// }
// }
async fn send_buffer_ordered_messages(
this: WeakModel<Self>,
@ -10475,43 +10453,6 @@ async fn search_ignored_entry(
}
}
fn subscribe_for_copilot_events(
copilot: &Model<Copilot>,
cx: &mut ModelContext<'_, Project>,
) -> gpui::Subscription {
cx.subscribe(
copilot,
|project, copilot, copilot_event, cx| match copilot_event {
copilot::Event::CopilotLanguageServerStarted => {
match copilot.read(cx).language_server() {
Some((name, copilot_server)) => {
// Another event wants to re-add the server that was already added and subscribed to, avoid doing it again.
if !copilot_server.has_notification_handler::<copilot::request::LogMessage>() {
let new_server_id = copilot_server.server_id();
let weak_project = cx.weak_model();
let copilot_log_subscription = copilot_server
.on_notification::<copilot::request::LogMessage, _>(
move |params, mut cx| {
weak_project.update(&mut cx, |_, cx| {
cx.emit(Event::LanguageServerLog(
new_server_id,
params.message,
));
}).ok();
},
);
project.supplementary_language_servers.insert(new_server_id, (name.clone(), Arc::clone(copilot_server)));
project.copilot_log_subscription = Some(copilot_log_subscription);
cx.emit(Event::LanguageServerAdded(new_server_id));
}
}
None => debug_panic!("Received Copilot language server started event, but no language server is running"),
}
}
},
)
}
fn glob_literal_prefix(glob: &str) -> &str {
let mut literal_end = 0;
for (i, part) in glob.split(path::MAIN_SEPARATOR).enumerate() {

View File

@ -207,7 +207,7 @@ message Envelope {
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max
ComputeEmbeddingsResponse compute_embeddings_response = 192;
UpdateChannelMessage update_channel_message = 170;
ChannelMessageUpdate channel_message_update = 171;
@ -238,7 +238,10 @@ message Envelope {
ValidateDevServerProjectRequest validate_dev_server_project_request = 194;
DeleteDevServer delete_dev_server = 195;
OpenNewBuffer open_new_buffer = 196;
DeleteDevServerProject delete_dev_server_project = 197; // Current max
DeleteDevServerProject delete_dev_server_project = 197;
GetSupermavenApiKey get_supermaven_api_key = 198;
GetSupermavenApiKeyResponse get_supermaven_api_key_response = 199; // current max
}
reserved 158 to 161;
@ -2084,3 +2087,9 @@ message LspResponse {
GetCodeActionsResponse get_code_actions_response = 2;
}
}
message GetSupermavenApiKey {}
message GetSupermavenApiKeyResponse {
string api_key = 1;
}

View File

@ -201,6 +201,8 @@ messages!(
(GetProjectSymbolsResponse, Background),
(GetReferences, Background),
(GetReferencesResponse, Background),
(GetSupermavenApiKey, Background),
(GetSupermavenApiKeyResponse, Background),
(GetTypeDefinition, Background),
(GetTypeDefinitionResponse, Background),
(GetImplementation, Background),
@ -360,6 +362,7 @@ request_messages!(
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),
(GetReferences, GetReferencesResponse),
(GetSupermavenApiKey, GetSupermavenApiKeyResponse),
(GetTypeDefinition, GetTypeDefinitionResponse),
(GetUsers, UsersResponse),
(IncomingCall, Ack),

View File

@ -0,0 +1,41 @@
[package]
name = "supermaven"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/supermaven.rs"
doctest = false
[dependencies]
anyhow.workspace = true
client.workspace = true
collections.workspace = true
editor.workspace = true
gpui.workspace = true
futures.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
supermaven_api.workspace = true
smol.workspace = true
ui.workspace = true
util.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }

View File

@ -0,0 +1,152 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SetApiKey {
pub api_key: String,
}
// Outbound messages
#[derive(Debug, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum OutboundMessage {
SetApiKey(SetApiKey),
StateUpdate(StateUpdateMessage),
#[allow(dead_code)]
UseFreeVersion,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StateUpdateMessage {
pub new_id: String,
pub updates: Vec<StateUpdate>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum StateUpdate {
FileUpdate(FileUpdateMessage),
CursorUpdate(CursorPositionUpdateMessage),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct FileUpdateMessage {
pub path: String,
pub content: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct CursorPositionUpdateMessage {
pub path: String,
pub offset: usize,
}
// Inbound messages coming in on stdout
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ResponseItem {
// A completion
Text { text: String },
// Vestigial message type from old versions -- safe to ignore
Del { text: String },
// Be able to delete whitespace prior to the cursor, likely for the rest of the completion
Dedent { text: String },
// When the completion is over
End,
// Got the closing parentheses and shouldn't show any more after
Barrier,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SupermavenResponse {
pub state_id: String,
pub items: Vec<ResponseItem>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SupermavenMetadataMessage {
pub dust_strings: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SupermavenTaskUpdateMessage {
pub task: String,
pub status: TaskStatus,
pub percent_complete: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
InProgress,
Complete,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SupermavenActiveRepoMessage {
pub repo_simple_name: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SupermavenPopupAction {
OpenUrl { label: String, url: String },
NoOp { label: String },
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct SupermavenPopupMessage {
pub message: String,
pub actions: Vec<SupermavenPopupAction>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "camelCase")]
pub struct ActivationRequest {
pub activate_url: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SupermavenSetMessage {
pub key: String,
pub value: serde_json::Value,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ServiceTier {
FreeNoLicense,
#[serde(other)]
Unknown,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SupermavenMessage {
Response(SupermavenResponse),
Metadata(SupermavenMetadataMessage),
Apology {
message: Option<String>,
},
ActivationRequest(ActivationRequest),
ActivationSuccess,
Passthrough {
passthrough: Box<SupermavenMessage>,
},
Popup(SupermavenPopupMessage),
TaskStatus(SupermavenTaskUpdateMessage),
ActiveRepo(SupermavenActiveRepoMessage),
ServiceTier {
service_tier: ServiceTier,
},
Set(SupermavenSetMessage),
#[serde(other)]
Unknown,
}

View File

@ -0,0 +1,345 @@
mod messages;
mod supermaven_completion_provider;
pub use supermaven_completion_provider::*;
use anyhow::{Context as _, Result};
#[allow(unused_imports)]
use client::{proto, Client};
use collections::BTreeMap;
use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel};
use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset};
use messages::*;
use postage::watch;
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::{
io::AsyncWriteExt,
process::{Child, ChildStdin, ChildStdout, Command},
};
use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let supermaven = cx.new_model(|_| Supermaven::Starting);
Supermaven::set_global(supermaven.clone(), cx);
let mut provider = all_language_settings(None, cx).inline_completions.provider;
if provider == language::language_settings::InlineCompletionProvider::Supermaven {
supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
}
cx.observe_global::<SettingsStore>(move |cx| {
let new_provider = all_language_settings(None, cx).inline_completions.provider;
if new_provider != provider {
provider = new_provider;
if provider == language::language_settings::InlineCompletionProvider::Supermaven {
supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
} else {
supermaven.update(cx, |supermaven, _cx| supermaven.stop());
}
}
})
.detach();
}
pub enum Supermaven {
Starting,
FailedDownload { error: anyhow::Error },
Spawned(SupermavenAgent),
Error { error: anyhow::Error },
}
#[derive(Clone)]
pub enum AccountStatus {
Unknown,
NeedsActivation { activate_url: String },
Ready,
}
#[derive(Clone)]
struct SupermavenGlobal(Model<Supermaven>);
impl Global for SupermavenGlobal {}
impl Supermaven {
pub fn global(cx: &AppContext) -> Option<Model<Self>> {
cx.try_global::<SupermavenGlobal>()
.map(|model| model.0.clone())
}
pub fn set_global(supermaven: Model<Self>, cx: &mut AppContext) {
cx.set_global(SupermavenGlobal(supermaven));
}
pub fn start(&mut self, client: Arc<Client>, cx: &mut ModelContext<Self>) {
if let Self::Starting = self {
cx.spawn(|this, mut cx| async move {
let binary_path =
supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
this.update(&mut cx, |this, cx| {
if let Self::Starting = this {
*this =
Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
}
anyhow::Ok(())
})
})
.detach_and_log_err(cx)
}
}
pub fn stop(&mut self) {
*self = Self::Starting;
}
pub fn is_enabled(&self) -> bool {
matches!(self, Self::Spawned { .. })
}
pub fn complete(
&mut self,
buffer: &Model<Buffer>,
cursor_position: Anchor,
cx: &AppContext,
) -> Option<SupermavenCompletion> {
if let Self::Spawned(agent) = self {
let buffer_id = buffer.entity_id();
let buffer = buffer.read(cx);
let path = buffer
.file()
.and_then(|file| Some(file.as_local()?.abs_path(cx)))
.unwrap_or_else(|| PathBuf::from("untitled"))
.to_string_lossy()
.to_string();
let content = buffer.text();
let offset = cursor_position.to_offset(buffer);
let state_id = agent.next_state_id;
agent.next_state_id.0 += 1;
let (updates_tx, mut updates_rx) = watch::channel();
postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
agent.states.insert(
state_id,
SupermavenCompletionState {
buffer_id,
range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer),
completion: Vec::new(),
text: String::new(),
updates_tx,
},
);
let _ = agent
.outgoing_tx
.unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
new_id: state_id.0.to_string(),
updates: vec![
StateUpdate::FileUpdate(FileUpdateMessage {
path: path.clone(),
content,
}),
StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
],
}));
Some(SupermavenCompletion {
id: state_id,
updates: updates_rx,
})
} else {
None
}
}
pub fn completion(
&self,
id: SupermavenCompletionStateId,
) -> Option<&SupermavenCompletionState> {
if let Self::Spawned(agent) = self {
agent.states.get(&id)
} else {
None
}
}
}
pub struct SupermavenAgent {
_process: Child,
next_state_id: SupermavenCompletionStateId,
states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
_handle_outgoing_messages: Task<Result<()>>,
_handle_incoming_messages: Task<Result<()>>,
pub account_status: AccountStatus,
service_tier: Option<ServiceTier>,
#[allow(dead_code)]
client: Arc<Client>,
}
impl SupermavenAgent {
fn new(
binary_path: PathBuf,
client: Arc<Client>,
cx: &mut ModelContext<Supermaven>,
) -> Result<Self> {
let mut process = Command::new(&binary_path)
.arg("stdio")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.context("failed to start the binary")?;
let stdin = process
.stdin
.take()
.context("failed to get stdin for process")?;
let stdout = process
.stdout
.take()
.context("failed to get stdout for process")?;
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
cx.spawn({
let client = client.clone();
let outgoing_tx = outgoing_tx.clone();
move |this, mut cx| async move {
let mut status = client.status();
while let Some(status) = status.next().await {
if status.is_connected() {
let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
outgoing_tx
.unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
.ok();
this.update(&mut cx, |this, cx| {
if let Supermaven::Spawned(this) = this {
this.account_status = AccountStatus::Ready;
cx.notify();
}
})?;
break;
}
}
return anyhow::Ok(());
}
})
.detach();
Ok(Self {
_process: process,
next_state_id: SupermavenCompletionStateId::default(),
states: BTreeMap::default(),
outgoing_tx,
_handle_outgoing_messages: cx
.spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)),
_handle_incoming_messages: cx
.spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)),
account_status: AccountStatus::Unknown,
service_tier: None,
client,
})
}
async fn handle_outgoing_messages(
mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
mut stdin: ChildStdin,
) -> Result<()> {
while let Some(message) = outgoing.next().await {
let bytes = serde_json::to_vec(&message)?;
stdin.write_all(&bytes).await?;
stdin.write_all(&[b'\n']).await?;
}
Ok(())
}
async fn handle_incoming_messages(
this: WeakModel<Supermaven>,
stdout: ChildStdout,
mut cx: AsyncAppContext,
) -> Result<()> {
const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
let stdout = BufReader::new(stdout);
let mut lines = stdout.lines();
while let Some(line) = lines.next().await {
let Some(line) = line.context("failed to read line from stdout").log_err() else {
continue;
};
let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
continue;
};
let Some(message) = serde_json::from_str::<SupermavenMessage>(&line)
.with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
.log_err()
else {
continue;
};
this.update(&mut cx, |this, _cx| {
if let Supermaven::Spawned(this) = this {
this.handle_message(message);
}
Task::ready(anyhow::Ok(()))
})?
.await?;
}
Ok(())
}
fn handle_message(&mut self, message: SupermavenMessage) {
match message {
SupermavenMessage::ActivationRequest(request) => {
self.account_status = match request.activate_url {
Some(activate_url) => AccountStatus::NeedsActivation {
activate_url: activate_url.clone(),
},
None => AccountStatus::Ready,
};
}
SupermavenMessage::ServiceTier { service_tier } => {
self.service_tier = Some(service_tier);
}
SupermavenMessage::Response(response) => {
let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
if let Some(state) = self.states.get_mut(&state_id) {
for item in &response.items {
if let ResponseItem::Text { text } = item {
state.text.push_str(text);
}
}
state.completion.extend(response.items);
*state.updates_tx.borrow_mut() = ();
}
}
SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
_ => {
log::warn!("unhandled message: {:?}", message);
}
}
}
}
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
#[allow(dead_code)]
pub struct SupermavenCompletionState {
buffer_id: EntityId,
range: Range<Anchor>,
completion: Vec<ResponseItem>,
text: String,
updates_tx: watch::Sender<()>,
}
pub struct SupermavenCompletion {
pub id: SupermavenCompletionStateId,
pub updates: watch::Receiver<()>,
}

View File

@ -0,0 +1,131 @@
use crate::{Supermaven, SupermavenCompletionStateId};
use anyhow::Result;
use editor::{Direction, InlineCompletionProvider};
use futures::StreamExt as _;
use gpui::{AppContext, Model, ModelContext, Task};
use language::{
language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset,
};
use std::time::Duration;
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
pub struct SupermavenCompletionProvider {
supermaven: Model<Supermaven>,
completion_id: Option<SupermavenCompletionStateId>,
pending_refresh: Task<Result<()>>,
}
impl SupermavenCompletionProvider {
pub fn new(supermaven: Model<Supermaven>) -> Self {
Self {
supermaven,
completion_id: None,
pending_refresh: Task::ready(Ok(())),
}
}
}
impl InlineCompletionProvider for SupermavenCompletionProvider {
fn is_enabled(&self, buffer: &Model<Buffer>, cursor_position: Anchor, cx: &AppContext) -> bool {
if !self.supermaven.read(cx).is_enabled() {
return false;
}
let buffer = buffer.read(cx);
let file = buffer.file();
let language = buffer.language_at(cursor_position);
let settings = all_language_settings(file, cx);
settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()))
}
fn refresh(
&mut self,
buffer_handle: Model<Buffer>,
cursor_position: Anchor,
debounce: bool,
cx: &mut ModelContext<Self>,
) {
let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
supermaven.complete(&buffer_handle, cursor_position, cx)
}) else {
return;
};
self.pending_refresh = cx.spawn(|this, mut cx| async move {
if debounce {
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
}
while let Some(()) = completion.updates.next().await {
this.update(&mut cx, |this, cx| {
this.completion_id = Some(completion.id);
cx.notify();
})?;
}
Ok(())
});
}
fn cycle(
&mut self,
_buffer: Model<Buffer>,
_cursor_position: Anchor,
_direction: Direction,
_cx: &mut ModelContext<Self>,
) {
// todo!("cycling")
}
fn accept(&mut self, _cx: &mut ModelContext<Self>) {
self.pending_refresh = Task::ready(Ok(()));
self.completion_id = None;
}
fn discard(&mut self, _cx: &mut ModelContext<Self>) {
self.pending_refresh = Task::ready(Ok(()));
self.completion_id = None;
}
fn active_completion_text<'a>(
&'a self,
buffer: &Model<Buffer>,
cursor_position: Anchor,
cx: &'a AppContext,
) -> Option<&'a str> {
let completion_id = self.completion_id?;
let buffer = buffer.read(cx);
let cursor_offset = cursor_position.to_offset(buffer);
let completion = self.supermaven.read(cx).completion(completion_id)?;
let mut completion_range = completion.range.to_offset(buffer);
let prefix_len = common_prefix(
buffer.chars_for_range(completion_range.clone()),
completion.text.chars(),
);
completion_range.start += prefix_len;
let suffix_len = common_prefix(
buffer.reversed_chars_for_range(completion_range.clone()),
completion.text[prefix_len..].chars().rev(),
);
completion_range.end = completion_range.end.saturating_sub(suffix_len);
let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len];
if completion_range.is_empty()
&& completion_range.start == cursor_offset
&& !completion_text.trim().is_empty()
{
Some(completion_text)
} else {
None
}
}
}
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
a.zip(b)
.take_while(|(a, b)| a == b)
.map(|(a, _)| a.len_utf8())
.sum()
}

View File

@ -0,0 +1,21 @@
[package]
name = "supermaven_api"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/supermaven_api.rs"
doctest = false
[dependencies]
anyhow.workspace = true
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
util.workspace = true

View File

@ -0,0 +1,291 @@
use anyhow::{anyhow, Context, Result};
use futures::io::BufReader;
use futures::{AsyncReadExt, Future};
use serde::{Deserialize, Serialize};
use smol::fs::{self, File};
use smol::stream::StreamExt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use util::http::{AsyncBody, HttpClient, Request as HttpRequest};
use util::paths::SUPERMAVEN_DIR;
#[derive(Serialize)]
pub struct GetExternalUserRequest {
pub id: String,
}
#[derive(Serialize)]
pub struct CreateExternalUserRequest {
pub id: String,
pub email: String,
}
#[derive(Serialize)]
pub struct DeleteExternalUserRequest {
pub id: String,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateExternalUserResponse {
pub api_key: String,
}
#[derive(Deserialize)]
pub struct SupermavenApiError {
pub message: String,
}
pub struct SupermavenBinary {}
pub struct SupermavenAdminApi {
admin_api_key: String,
api_url: String,
http_client: Arc<dyn HttpClient>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SupermavenDownloadResponse {
pub download_url: String,
pub version: u64,
pub sha256_hash: String,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SupermavenUser {
id: String,
email: String,
api_key: String,
}
impl SupermavenAdminApi {
pub fn new(admin_api_key: String, http_client: Arc<dyn HttpClient>) -> Self {
Self {
admin_api_key,
api_url: "https://supermaven.com/api/".to_string(),
http_client,
}
}
pub async fn try_get_user(
&self,
request: GetExternalUserRequest,
) -> Result<Option<SupermavenUser>> {
let uri = format!("{}external-user/{}", &self.api_url, &request.id);
let request = HttpRequest::get(&uri).header("Authorization", self.admin_api_key.clone());
let mut response = self
.http_client
.send(request.body(AsyncBody::default())?)
.await
.with_context(|| "Unable to get Supermaven API Key".to_string())?;
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if response.status().is_client_error() {
let error: SupermavenApiError = serde_json::from_slice(&body)?;
if error.message == "User not found" {
return Ok(None);
} else {
return Err(anyhow!("Supermaven API error: {}", error.message));
}
} else if response.status().is_server_error() {
let error: SupermavenApiError = serde_json::from_slice(&body)?;
return Err(anyhow!("Supermaven API server error").context(error.message));
}
let body_str = std::str::from_utf8(&body)?;
Ok(Some(
serde_json::from_str::<SupermavenUser>(body_str)
.with_context(|| "Unable to parse Supermaven user response".to_string())?,
))
}
pub async fn try_create_user(
&self,
request: CreateExternalUserRequest,
) -> Result<CreateExternalUserResponse> {
let uri = format!("{}external-user", &self.api_url);
let request = HttpRequest::post(&uri)
.header("Authorization", self.admin_api_key.clone())
.body(AsyncBody::from(serde_json::to_vec(&request)?))?;
let mut response = self
.http_client
.send(request)
.await
.with_context(|| "Unable to create Supermaven API Key".to_string())?;
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
if !response.status().is_success() {
let error: SupermavenApiError = serde_json::from_slice(&body)?;
return Err(anyhow!("Supermaven API server error").context(error.message));
}
serde_json::from_str::<CreateExternalUserResponse>(body_str)
.with_context(|| "Unable to parse Supermaven API Key response".to_string())
}
pub async fn try_delete_user(&self, request: DeleteExternalUserRequest) -> Result<()> {
let uri = format!("{}external-user/{}", &self.api_url, &request.id);
let request = HttpRequest::delete(&uri).header("Authorization", self.admin_api_key.clone());
let mut response = self
.http_client
.send(request.body(AsyncBody::default())?)
.await
.with_context(|| "Unable to delete Supermaven User".to_string())?;
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if response.status().is_client_error() {
let error: SupermavenApiError = serde_json::from_slice(&body)?;
if error.message == "User not found" {
return Ok(());
} else {
return Err(anyhow!("Supermaven API error: {}", error.message));
}
} else if response.status().is_server_error() {
let error: SupermavenApiError = serde_json::from_slice(&body)?;
return Err(anyhow!("Supermaven API server error").context(error.message));
}
Ok(())
}
pub async fn try_get_or_create_user(
&self,
request: CreateExternalUserRequest,
) -> Result<CreateExternalUserResponse> {
let get_user_request = GetExternalUserRequest {
id: request.id.clone(),
};
match self.try_get_user(get_user_request).await? {
None => self.try_create_user(request).await,
Some(SupermavenUser { api_key, .. }) => Ok(CreateExternalUserResponse { api_key }),
}
}
}
pub async fn latest_release(
client: Arc<dyn HttpClient>,
platform: &str,
arch: &str,
) -> Result<SupermavenDownloadResponse> {
let uri = format!(
"https://supermaven.com/api/download-path?platform={}&arch={}",
platform, arch
);
// Download is not authenticated
let request = HttpRequest::get(&uri);
let mut response = client
.send(request.body(AsyncBody::default())?)
.await
.with_context(|| "Unable to acquire Supermaven Agent".to_string())?;
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if response.status().is_client_error() || response.status().is_server_error() {
let body_str = std::str::from_utf8(&body)?;
let error: SupermavenApiError = serde_json::from_str(body_str)?;
return Err(anyhow!("Supermaven API error: {}", error.message));
}
serde_json::from_slice::<SupermavenDownloadResponse>(&body)
.with_context(|| "Unable to parse Supermaven Agent response".to_string())
}
pub fn version_path(version: u64) -> PathBuf {
SUPERMAVEN_DIR.join(format!("sm-agent-{}", version))
}
pub async fn has_version(version_path: &Path) -> bool {
fs::metadata(version_path)
.await
.map_or(false, |m| m.is_file())
}
pub fn get_supermaven_agent_path(
client: Arc<dyn HttpClient>,
) -> impl Future<Output = Result<PathBuf>> {
async move {
fs::create_dir_all(&*SUPERMAVEN_DIR)
.await
.with_context(|| {
format!(
"Could not create Supermaven Agent Directory at {:?}",
&*SUPERMAVEN_DIR
)
})?;
let platform = match std::env::consts::OS {
"macos" => "darwin",
"windows" => "windows",
"linux" => "linux",
_ => return Err(anyhow!("unsupported platform")),
};
let arch = match std::env::consts::ARCH {
"x86_64" => "amd64",
"aarch64" => "arm64",
_ => return Err(anyhow!("unsupported architecture")),
};
let download_info = latest_release(client.clone(), platform, arch).await?;
let binary_path = version_path(download_info.version);
if has_version(&binary_path).await {
return Ok(binary_path);
}
let request = HttpRequest::get(&download_info.download_url);
let mut response = client
.send(request.body(AsyncBody::default())?)
.await
.with_context(|| "Unable to download Supermaven Agent".to_string())?;
let mut file = File::create(&binary_path)
.await
.with_context(|| format!("Unable to create file at {:?}", binary_path))?;
futures::io::copy(BufReader::new(response.body_mut()), &mut file)
.await
.with_context(|| format!("Unable to write binary to file at {:?}", binary_path))?;
#[cfg(not(windows))]
{
file.set_permissions(<fs::Permissions as fs::unix::PermissionsExt>::from_mode(
0o755,
))
.await?;
}
let mut old_binary_paths = fs::read_dir(&*SUPERMAVEN_DIR).await?;
while let Some(old_binary_path) = old_binary_paths.next().await {
let old_binary_path = old_binary_path?;
if old_binary_path.path() != binary_path {
fs::remove_file(old_binary_path.path()).await?;
}
}
Ok(binary_path)
}
}

View File

@ -155,6 +155,10 @@ pub enum IconName {
Space,
Split,
Spinner,
Supermaven,
SupermavenDisabled,
SupermavenError,
SupermavenInit,
Tab,
Terminal,
Trash,
@ -261,6 +265,10 @@ impl IconName {
IconName::Space => "icons/space.svg",
IconName::Split => "icons/split.svg",
IconName::Spinner => "icons/spinner.svg",
IconName::Supermaven => "icons/supermaven.svg",
IconName::SupermavenDisabled => "icons/supermaven_disabled.svg",
IconName::SupermavenError => "icons/supermaven_error.svg",
IconName::SupermavenInit => "icons/supermaven_init.svg",
IconName::Tab => "icons/tab.svg",
IconName::Terminal => "icons/terminal.svg",
IconName::Trash => "icons/trash.svg",

View File

@ -52,6 +52,7 @@ lazy_static::lazy_static! {
pub static ref EXTENSIONS_DIR: PathBuf = SUPPORT_DIR.join("extensions");
pub static ref LANGUAGES_DIR: PathBuf = SUPPORT_DIR.join("languages");
pub static ref COPILOT_DIR: PathBuf = SUPPORT_DIR.join("copilot");
pub static ref SUPERMAVEN_DIR: PathBuf = SUPPORT_DIR.join("supermaven");
pub static ref DEFAULT_PRETTIER_DIR: PathBuf = SUPPORT_DIR.join("prettier");
pub static ref DB_DIR: PathBuf = SUPPORT_DIR.join("db");
pub static ref CRASHES_DIR: Option<PathBuf> = cfg!(target_os = "macos")

View File

@ -17,7 +17,7 @@ test-support = []
[dependencies]
anyhow.workspace = true
client.workspace = true
copilot_ui.workspace = true
inline_completion_button.workspace = true
db.workspace = true
extensions_ui.workspace = true
fuzzy.workspace = true

View File

@ -2,7 +2,6 @@ mod base_keymap_picker;
mod base_keymap_setting;
use client::{telemetry::Telemetry, TelemetrySettings};
use copilot_ui;
use db::kvp::KEY_VALUE_STORE;
use gpui::{
svg, AnyElement, AppContext, EventEmitter, FocusHandle, FocusableView, InteractiveElement,
@ -143,7 +142,7 @@ impl Render for WelcomePage {
this.telemetry.report_app_event(
"welcome page: sign in to copilot".to_string(),
);
copilot_ui::initiate_sign_in(cx);
inline_completion_button::initiate_sign_in(cx);
})),
)
.child(

View File

@ -35,7 +35,6 @@ collab_ui.workspace = true
collections.workspace = true
command_palette.workspace = true
copilot.workspace = true
copilot_ui.workspace = true
db.workspace = true
diagnostics.workspace = true
editor.workspace = true
@ -51,6 +50,7 @@ go_to_line.workspace = true
gpui.workspace = true
headless.workspace = true
image_viewer.workspace = true
inline_completion_button.workspace = true
install_cli.workspace = true
isahc.workspace = true
journal.workspace = true
@ -83,6 +83,7 @@ settings.workspace = true
simplelog = "0.9"
smol.workspace = true
tab_switcher.workspace = true
supermaven.workspace = true
task.workspace = true
tasks_ui.workspace = true
telemetry_events.workspace = true

View File

@ -9,16 +9,14 @@ mod zed;
use anyhow::{anyhow, Context as _, Result};
use clap::{command, Parser};
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
use client::{parse_zed_link, telemetry::Telemetry, Client, DevServerToken, UserStore};
use client::{parse_zed_link, Client, DevServerToken, UserStore};
use collab_ui::channel_view::ChannelView;
use copilot::Copilot;
use copilot_ui::CopilotCompletionProvider;
use db::kvp::KEY_VALUE_STORE;
use editor::{Editor, EditorMode};
use editor::Editor;
use env_logger::Builder;
use fs::RealFs;
use futures::{future, StreamExt};
use gpui::{App, AppContext, AsyncAppContext, Context, Task, ViewContext, VisualContext};
use gpui::{App, AppContext, AsyncAppContext, Context, Task, VisualContext};
use image_viewer;
use language::LanguageRegistry;
use log::LevelFilter;
@ -55,6 +53,8 @@ use zed::{
OpenListener, OpenRequest,
};
use crate::zed::inline_completion_registry;
#[cfg(feature = "mimalloc")]
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
@ -270,17 +270,20 @@ fn init_ui(args: Args) {
editor::init(cx);
image_viewer::init(cx);
diagnostics::init(cx);
// Initialize each completion provider. Settings are used for toggling between them.
copilot::init(
copilot_language_server_id,
client.http_client(),
node_runtime.clone(),
cx,
);
supermaven::init(client.clone(), cx);
assistant::init(client.clone(), cx);
assistant2::init(client.clone(), cx);
init_inline_completion_provider(client.telemetry().clone(), cx);
inline_completion_registry::init(client.telemetry().clone(), cx);
extension::init(
fs.clone(),
@ -888,45 +891,3 @@ fn watch_file_types(fs: Arc<dyn fs::Fs>, cx: &mut AppContext) {
#[cfg(not(debug_assertions))]
fn watch_file_types(_fs: Arc<dyn fs::Fs>, _cx: &mut AppContext) {}
fn init_inline_completion_provider(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
if let Some(copilot) = Copilot::global(cx) {
cx.observe_new_views(move |editor: &mut Editor, cx: &mut ViewContext<Editor>| {
if editor.mode() == EditorMode::Full {
// We renamed some of these actions to not be copilot-specific, but that
// would have not been backwards-compatible. So here we are re-registering
// the actions with the old names to not break people's keymaps.
editor
.register_action(cx.listener(
|editor, _: &copilot::Suggest, cx: &mut ViewContext<Editor>| {
editor.show_inline_completion(&Default::default(), cx);
},
))
.register_action(cx.listener(
|editor, _: &copilot::NextSuggestion, cx: &mut ViewContext<Editor>| {
editor.next_inline_completion(&Default::default(), cx);
},
))
.register_action(cx.listener(
|editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext<Editor>| {
editor.previous_inline_completion(&Default::default(), cx);
},
))
.register_action(cx.listener(
|editor,
_: &editor::actions::AcceptPartialCopilotSuggestion,
cx: &mut ViewContext<Editor>| {
editor.accept_partial_inline_completion(&Default::default(), cx);
},
));
let provider = cx.new_model(|_| {
CopilotCompletionProvider::new(copilot.clone())
.with_telemetry(telemetry.clone())
});
editor.set_inline_completion_provider(provider, cx)
}
})
.detach();
}
}

View File

@ -1,4 +1,5 @@
mod app_menus;
pub mod inline_completion_registry;
mod only_instance;
mod open_listener;
@ -127,7 +128,10 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
})
.detach();
let copilot = cx.new_view(|cx| copilot_ui::CopilotButton::new(app_state.fs.clone(), cx));
let inline_completion_button = cx.new_view(|cx| {
inline_completion_button::InlineCompletionButton::new(app_state.fs.clone(), cx)
});
let diagnostic_summary =
cx.new_view(|cx| diagnostics::items::DiagnosticIndicator::new(workspace, cx));
let activity_indicator =
@ -140,7 +144,7 @@ pub fn initialize_workspace(app_state: Arc<AppState>, cx: &mut AppContext) {
workspace.status_bar().update(cx, |status_bar, cx| {
status_bar.add_left_item(diagnostic_summary, cx);
status_bar.add_left_item(activity_indicator, cx);
status_bar.add_right_item(copilot, cx);
status_bar.add_right_item(inline_completion_button, cx);
status_bar.add_right_item(active_buffer_language, cx);
status_bar.add_right_item(vim_mode_indicator, cx);
status_bar.add_right_item(cursor_position, cx);

View File

@ -0,0 +1,126 @@
use std::{cell::RefCell, rc::Rc, sync::Arc};
use client::telemetry::Telemetry;
use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider};
use editor::{Editor, EditorMode};
use gpui::{AnyWindowHandle, AppContext, Context, ViewContext, WeakView};
use language::language_settings::all_language_settings;
use settings::SettingsStore;
use supermaven::{Supermaven, SupermavenCompletionProvider};
pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
let editors: Rc<RefCell<HashMap<WeakView<Editor>, AnyWindowHandle>>> = Rc::default();
cx.observe_new_views({
let editors = editors.clone();
let telemetry = telemetry.clone();
move |editor: &mut Editor, cx: &mut ViewContext<Editor>| {
if editor.mode() != EditorMode::Full {
return;
}
register_backward_compatible_actions(editor, cx);
let editor_handle = cx.view().downgrade();
cx.on_release({
let editor_handle = editor_handle.clone();
let editors = editors.clone();
move |_, _, _| {
editors.borrow_mut().remove(&editor_handle);
}
})
.detach();
editors
.borrow_mut()
.insert(editor_handle, cx.window_handle());
let provider = all_language_settings(None, cx).inline_completions.provider;
assign_inline_completion_provider(editor, provider, &telemetry, cx);
}
})
.detach();
let mut provider = all_language_settings(None, cx).inline_completions.provider;
for (editor, window) in editors.borrow().iter() {
_ = window.update(cx, |_window, cx| {
_ = editor.update(cx, |editor, cx| {
assign_inline_completion_provider(editor, provider, &telemetry, cx);
})
});
}
cx.observe_global::<SettingsStore>(move |cx| {
let new_provider = all_language_settings(None, cx).inline_completions.provider;
if new_provider != provider {
provider = new_provider;
for (editor, window) in editors.borrow().iter() {
_ = window.update(cx, |_window, cx| {
_ = editor.update(cx, |editor, cx| {
assign_inline_completion_provider(editor, provider, &telemetry, cx);
})
});
}
}
})
.detach();
}
fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut ViewContext<Editor>) {
// We renamed some of these actions to not be copilot-specific, but that
// would have not been backwards-compatible. So here we are re-registering
// the actions with the old names to not break people's keymaps.
editor
.register_action(cx.listener(
|editor, _: &copilot::Suggest, cx: &mut ViewContext<Editor>| {
editor.show_inline_completion(&Default::default(), cx);
},
))
.register_action(cx.listener(
|editor, _: &copilot::NextSuggestion, cx: &mut ViewContext<Editor>| {
editor.next_inline_completion(&Default::default(), cx);
},
))
.register_action(cx.listener(
|editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext<Editor>| {
editor.previous_inline_completion(&Default::default(), cx);
},
))
.register_action(cx.listener(
|editor,
_: &editor::actions::AcceptPartialCopilotSuggestion,
cx: &mut ViewContext<Editor>| {
editor.accept_partial_inline_completion(&Default::default(), cx);
},
));
}
fn assign_inline_completion_provider(
editor: &mut Editor,
provider: language::language_settings::InlineCompletionProvider,
telemetry: &Arc<Telemetry>,
cx: &mut ViewContext<Editor>,
) {
match provider {
language::language_settings::InlineCompletionProvider::None => {}
language::language_settings::InlineCompletionProvider::Copilot => {
if let Some(copilot) = Copilot::global(cx) {
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
if buffer.read(cx).file().is_some() {
copilot.update(cx, |copilot, cx| {
copilot.register_buffer(&buffer, cx);
});
}
}
let provider = cx.new_model(|_| {
CopilotCompletionProvider::new(copilot).with_telemetry(telemetry.clone())
});
editor.set_inline_completion_provider(Some(provider), cx);
}
}
language::language_settings::InlineCompletionProvider::Supermaven => {
if let Some(supermaven) = Supermaven::global(cx) {
let provider = cx.new_model(|_| SupermavenCompletionProvider::new(supermaven));
editor.set_inline_completion_provider(Some(provider), cx);
}
}
}
}