diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index 2e8acd6c7b..fd0909a2f4 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -1,3 +1,4 @@ +use crate::db::ExtensionVersionConstraints; use crate::{db::NewExtensionVersion, AppState, Error, Result}; use anyhow::{anyhow, Context as _}; use aws_sdk_s3::presigning::PresigningConfig; @@ -10,14 +11,16 @@ use axum::{ }; use collections::HashMap; use rpc::{ExtensionApiManifest, GetExtensionsResponse}; +use semantic_version::SemanticVersion; use serde::Deserialize; use std::{sync::Arc, time::Duration}; use time::PrimitiveDateTime; -use util::ResultExt; +use util::{maybe, ResultExt}; pub fn router() -> Router { Router::new() .route("/extensions", get(get_extensions)) + .route("/extensions/updates", get(get_extension_updates)) .route("/extensions/:extension_id", get(get_extension_versions)) .route( "/extensions/:extension_id/download", @@ -48,9 +51,7 @@ async fn get_extensions( .map(|s| s.split(',').map(|s| s.trim()).collect::>()); let extensions = if let Some(extension_ids) = extension_ids { - app.db - .get_extensions_by_ids(&extension_ids, params.max_schema_version) - .await? + app.db.get_extensions_by_ids(&extension_ids, None).await? } else { app.db .get_extensions(params.filter.as_deref(), params.max_schema_version, 500) @@ -60,6 +61,34 @@ async fn get_extensions( Ok(Json(GetExtensionsResponse { data: extensions })) } +#[derive(Debug, Deserialize)] +struct GetExtensionUpdatesParams { + ids: String, + min_schema_version: i32, + max_schema_version: i32, + min_wasm_api_version: SemanticVersion, + max_wasm_api_version: SemanticVersion, +} + +async fn get_extension_updates( + Extension(app): Extension>, + Query(params): Query, +) -> Result> { + let constraints = ExtensionVersionConstraints { + schema_versions: params.min_schema_version..=params.max_schema_version, + wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version, + }; + + let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::>(); + + let extensions = app + .db + .get_extensions_by_ids(&extension_ids, Some(&constraints)) + .await?; + + Ok(Json(GetExtensionsResponse { data: extensions })) +} + #[derive(Debug, Deserialize)] struct GetExtensionVersionsParams { extension_id: String, @@ -79,15 +108,31 @@ async fn get_extension_versions( #[derive(Debug, Deserialize)] struct DownloadLatestExtensionParams { extension_id: String, + min_schema_version: Option, + max_schema_version: Option, + min_wasm_api_version: Option, + max_wasm_api_version: Option, } async fn download_latest_extension( Extension(app): Extension>, Path(params): Path, ) -> Result { + let constraints = maybe!({ + let min_schema_version = params.min_schema_version?; + let max_schema_version = params.max_schema_version?; + let min_wasm_api_version = params.min_wasm_api_version?; + let max_wasm_api_version = params.max_wasm_api_version?; + + Some(ExtensionVersionConstraints { + schema_versions: min_schema_version..=max_schema_version, + wasm_api_versions: min_wasm_api_version..=max_wasm_api_version, + }) + }); + let extension = app .db - .get_extension(¶ms.extension_id) + .get_extension(¶ms.extension_id, constraints.as_ref()) .await? .ok_or_else(|| anyhow!("unknown extension"))?; download_extension( diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 637a8c31f5..0527e070ea 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -21,11 +21,13 @@ use sea_orm::{ FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, }; -use serde::{ser::Error as _, Deserialize, Serialize, Serializer}; +use semantic_version::SemanticVersion; +use serde::{Deserialize, Serialize}; use sqlx::{ migrate::{Migrate, Migration, MigrationSource}, Connection, }; +use std::ops::RangeInclusive; use std::{ fmt::Write as _, future::Future, @@ -36,7 +38,7 @@ use std::{ sync::Arc, time::Duration, }; -use time::{format_description::well_known::iso8601, PrimitiveDateTime}; +use time::PrimitiveDateTime; use tokio::sync::{Mutex, OwnedMutexGuard}; #[cfg(test)] @@ -730,20 +732,7 @@ pub struct NewExtensionVersion { pub published_at: PrimitiveDateTime, } -pub fn serialize_iso8601( - datetime: &PrimitiveDateTime, - serializer: S, -) -> Result { - const SERDE_CONFIG: iso8601::EncodedConfig = iso8601::Config::DEFAULT - .set_year_is_six_digits(false) - .set_time_precision(iso8601::TimePrecision::Second { - decimal_digits: None, - }) - .encode(); - - datetime - .assume_utc() - .format(&time::format_description::well_known::Iso8601::) - .map_err(S::Error::custom)? - .serialize(serializer) +pub struct ExtensionVersionConstraints { + pub schema_versions: RangeInclusive, + pub wasm_api_versions: RangeInclusive, } diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index fc3def1d6d..d6938fd776 100644 --- a/crates/collab/src/db/queries/extensions.rs +++ b/crates/collab/src/db/queries/extensions.rs @@ -1,5 +1,8 @@ +use std::str::FromStr; + use chrono::Utc; use sea_orm::sea_query::IntoCondition; +use util::ResultExt; use super::*; @@ -32,23 +35,83 @@ impl Database { pub async fn get_extensions_by_ids( &self, ids: &[&str], - max_schema_version: i32, + constraints: Option<&ExtensionVersionConstraints>, ) -> Result> { self.transaction(|tx| async move { - let condition = Condition::all() - .add( - extension::Column::LatestVersion - .into_expr() - .eq(extension_version::Column::Version.into_expr()), - ) - .add(extension::Column::ExternalId.is_in(ids.iter().copied())) - .add(extension_version::Column::SchemaVersion.lte(max_schema_version)); + let extensions = extension::Entity::find() + .filter(extension::Column::ExternalId.is_in(ids.iter().copied())) + .all(&*tx) + .await?; - self.get_extensions_where(condition, None, &tx).await + let mut max_versions = self + .get_latest_versions_for_extensions(&extensions, constraints, &tx) + .await?; + + Ok(extensions + .into_iter() + .filter_map(|extension| { + let (version, _) = max_versions.remove(&extension.id)?; + Some(metadata_from_extension_and_version(extension, version)) + }) + .collect()) }) .await } + async fn get_latest_versions_for_extensions( + &self, + extensions: &[extension::Model], + constraints: Option<&ExtensionVersionConstraints>, + tx: &DatabaseTransaction, + ) -> Result> { + let mut versions = extension_version::Entity::find() + .filter( + extension_version::Column::ExtensionId + .is_in(extensions.iter().map(|extension| extension.id)), + ) + .stream(tx) + .await?; + + let mut max_versions = + HashMap::::default(); + while let Some(version) = versions.next().await { + let version = version?; + let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err() + else { + continue; + }; + + if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) { + if max_extension_version > &extension_version { + continue; + } + } + + if let Some(constraints) = constraints { + if !constraints + .schema_versions + .contains(&version.schema_version) + { + continue; + } + + if let Some(wasm_api_version) = version.wasm_api_version.as_ref() { + if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() { + if !constraints.wasm_api_versions.contains(&version) { + continue; + } + } else { + continue; + } + } + } + + max_versions.insert(version.extension_id, (version, extension_version)); + } + + Ok(max_versions) + } + /// Returns all of the versions for the extension with the given ID. pub async fn get_extension_versions( &self, @@ -88,22 +151,26 @@ impl Database { .collect()) } - pub async fn get_extension(&self, extension_id: &str) -> Result> { + pub async fn get_extension( + &self, + extension_id: &str, + constraints: Option<&ExtensionVersionConstraints>, + ) -> Result> { self.transaction(|tx| async move { let extension = extension::Entity::find() .filter(extension::Column::ExternalId.eq(extension_id)) - .filter( - extension::Column::LatestVersion - .into_expr() - .eq(extension_version::Column::Version.into_expr()), - ) - .inner_join(extension_version::Entity) - .select_also(extension_version::Entity) .one(&*tx) - .await?; + .await? + .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?; - Ok(extension.and_then(|(extension, version)| { - Some(metadata_from_extension_and_version(extension, version?)) + let extensions = [extension]; + let mut versions = self + .get_latest_versions_for_extensions(&extensions, constraints, &tx) + .await?; + let [extension] = extensions; + + Ok(versions.remove(&extension.id).map(|(max_version, _)| { + metadata_from_extension_and_version(extension, max_version) })) }) .await diff --git a/crates/collab/src/db/tests/extension_tests.rs b/crates/collab/src/db/tests/extension_tests.rs index 49e94e24d5..b91570c494 100644 --- a/crates/collab/src/db/tests/extension_tests.rs +++ b/crates/collab/src/db/tests/extension_tests.rs @@ -1,4 +1,5 @@ use super::Database; +use crate::db::ExtensionVersionConstraints; use crate::{ db::{queries::extensions::convert_time_to_chrono, ExtensionMetadata, NewExtensionVersion}, test_both_dbs, @@ -278,3 +279,108 @@ async fn test_extensions(db: &Arc) { ] ); } + +test_both_dbs!( + test_extensions_by_id, + test_extensions_by_id_postgres, + test_extensions_by_id_sqlite +); + +async fn test_extensions_by_id(db: &Arc) { + let versions = db.get_known_extension_versions().await.unwrap(); + assert!(versions.is_empty()); + + let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + assert!(extensions.is_empty()); + + let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); + let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time()); + + let t0_chrono = convert_time_to_chrono(t0); + + db.insert_extension_versions( + &[ + ( + "ext1", + vec![ + NewExtensionVersion { + name: "Extension 1".into(), + version: semver::Version::parse("0.0.1").unwrap(), + description: "an extension".into(), + authors: vec!["max".into()], + repository: "ext1/repo".into(), + schema_version: 1, + wasm_api_version: Some("0.0.4".into()), + published_at: t0, + }, + NewExtensionVersion { + name: "Extension 1".into(), + version: semver::Version::parse("0.0.2").unwrap(), + description: "a good extension".into(), + authors: vec!["max".into()], + repository: "ext1/repo".into(), + schema_version: 1, + wasm_api_version: Some("0.0.4".into()), + published_at: t0, + }, + NewExtensionVersion { + name: "Extension 1".into(), + version: semver::Version::parse("0.0.3").unwrap(), + description: "a real good extension".into(), + authors: vec!["max".into(), "marshall".into()], + repository: "ext1/repo".into(), + schema_version: 1, + wasm_api_version: Some("0.0.5".into()), + published_at: t0, + }, + ], + ), + ( + "ext2", + vec![NewExtensionVersion { + name: "Extension 2".into(), + version: semver::Version::parse("0.2.0").unwrap(), + description: "a great extension".into(), + authors: vec!["marshall".into()], + repository: "ext2/repo".into(), + schema_version: 0, + wasm_api_version: None, + published_at: t0, + }], + ), + ] + .into_iter() + .collect(), + ) + .await + .unwrap(); + + let extensions = db + .get_extensions_by_ids( + &["ext1"], + Some(&ExtensionVersionConstraints { + schema_versions: 1..=1, + wasm_api_versions: "0.0.1".parse().unwrap()..="0.0.4".parse().unwrap(), + }), + ) + .await + .unwrap(); + + assert_eq!( + extensions, + &[ExtensionMetadata { + id: "ext1".into(), + manifest: rpc::ExtensionApiManifest { + name: "Extension 1".into(), + version: "0.0.2".into(), + authors: vec!["max".into()], + description: Some("a good extension".into()), + repository: "ext1/repo".into(), + schema_version: Some(1), + wasm_api_version: Some("0.0.4".into()), + }, + published_at: t0_chrono, + download_count: 0, + }] + ); +} diff --git a/crates/extension/Cargo.toml b/crates/extension/Cargo.toml index e9b02c72dd..df02174e1e 100644 --- a/crates/extension/Cargo.toml +++ b/crates/extension/Cargo.toml @@ -12,10 +12,6 @@ workspace = true path = "src/extension_store.rs" doctest = false -[[bin]] -name = "extension_json_schemas" -path = "src/extension_json_schemas.rs" - [dependencies] anyhow.workspace = true async-compression.workspace = true diff --git a/crates/extension/src/extension_json_schemas.rs b/crates/extension/src/extension_json_schemas.rs deleted file mode 100644 index b46e72fce6..0000000000 --- a/crates/extension/src/extension_json_schemas.rs +++ /dev/null @@ -1,17 +0,0 @@ -use language::LanguageConfig; -use schemars::schema_for; -use theme::ThemeFamilyContent; - -fn main() { - let theme_family_schema = schema_for!(ThemeFamilyContent); - let language_config_schema = schema_for!(LanguageConfig); - - println!( - "{}", - serde_json::to_string_pretty(&theme_family_schema).unwrap() - ); - println!( - "{}", - serde_json::to_string_pretty(&language_config_schema).unwrap() - ); -} diff --git a/crates/extension/src/extension_store.rs b/crates/extension/src/extension_store.rs index 904d3cd25d..72001a0f73 100644 --- a/crates/extension/src/extension_store.rs +++ b/crates/extension/src/extension_store.rs @@ -36,6 +36,7 @@ use node_runtime::NodeRuntime; use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; use settings::Settings; +use std::ops::RangeInclusive; use std::str::FromStr; use std::{ cmp::Ordering, @@ -51,7 +52,10 @@ use util::{ paths::EXTENSIONS_DIR, ResultExt, }; -use wasm_host::{wit::is_supported_wasm_api_version, WasmExtension, WasmHost}; +use wasm_host::{ + wit::{is_supported_wasm_api_version, wasm_api_version_range}, + WasmExtension, WasmHost, +}; pub use extension_manifest::{ ExtensionLibraryKind, ExtensionManifest, GrammarManifestEntry, OldExtensionManifest, @@ -64,6 +68,11 @@ const FS_WATCH_LATENCY: Duration = Duration::from_millis(100); /// The current extension [`SchemaVersion`] supported by Zed. const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1); +/// Returns the [`SchemaVersion`] range that is compatible with this version of Zed. +pub fn schema_version_range() -> RangeInclusive { + SchemaVersion::ZERO..=CURRENT_SCHEMA_VERSION +} + /// Returns whether the given extension version is compatible with this version of Zed. pub fn is_version_compatible(extension_version: &ExtensionMetadata) -> bool { let schema_version = extension_version.manifest.schema_version.unwrap_or(0); @@ -412,15 +421,15 @@ impl ExtensionStore { query.push(("filter", search)); } - self.fetch_extensions_from_api("/extensions", query, cx) + self.fetch_extensions_from_api("/extensions", &query, cx) } pub fn fetch_extensions_with_update_available( &mut self, cx: &mut ModelContext, ) -> Task>> { - let version = CURRENT_SCHEMA_VERSION.to_string(); - let mut query = vec![("max_schema_version", version.as_str())]; + let schema_versions = schema_version_range(); + let wasm_api_versions = wasm_api_version_range(); let extension_settings = ExtensionSettings::get_global(cx); let extension_ids = self .extension_index @@ -430,9 +439,20 @@ impl ExtensionStore { .filter(|id| extension_settings.should_auto_update(id)) .collect::>() .join(","); - query.push(("ids", &extension_ids)); - - let task = self.fetch_extensions_from_api("/extensions", query, cx); + let task = self.fetch_extensions_from_api( + "/extensions/updates", + &[ + ("min_schema_version", &schema_versions.start().to_string()), + ("max_schema_version", &schema_versions.end().to_string()), + ( + "min_wasm_api_version", + &wasm_api_versions.start().to_string(), + ), + ("max_wasm_api_version", &wasm_api_versions.end().to_string()), + ("ids", &extension_ids), + ], + cx, + ); cx.spawn(move |this, mut cx| async move { let extensions = task.await?; this.update(&mut cx, |this, _cx| { @@ -456,7 +476,7 @@ impl ExtensionStore { extension_id: &str, cx: &mut ModelContext, ) -> Task>> { - self.fetch_extensions_from_api(&format!("/extensions/{extension_id}"), Vec::new(), cx) + self.fetch_extensions_from_api(&format!("/extensions/{extension_id}"), &[], cx) } pub fn check_for_updates(&mut self, cx: &mut ModelContext) { @@ -500,7 +520,7 @@ impl ExtensionStore { fn fetch_extensions_from_api( &self, path: &str, - query: Vec<(&str, &str)>, + query: &[(&str, &str)], cx: &mut ModelContext<'_, ExtensionStore>, ) -> Task>> { let url = self.http_client.build_zed_api_url(path, &query); @@ -614,9 +634,23 @@ impl ExtensionStore { ) { log::info!("installing extension {extension_id} latest version"); + let schema_versions = schema_version_range(); + let wasm_api_versions = wasm_api_version_range(); + let Some(url) = self .http_client - .build_zed_api_url(&format!("/extensions/{extension_id}/download"), &[]) + .build_zed_api_url( + &format!("/extensions/{extension_id}/download"), + &[ + ("min_schema_version", &schema_versions.start().to_string()), + ("max_schema_version", &schema_versions.end().to_string()), + ( + "min_wasm_api_version", + &wasm_api_versions.start().to_string(), + ), + ("max_wasm_api_version", &wasm_api_versions.end().to_string()), + ], + ) .log_err() else { return; diff --git a/crates/extension/src/wasm_host/wit.rs b/crates/extension/src/wasm_host/wit.rs index a4790deec1..da14f04664 100644 --- a/crates/extension/src/wasm_host/wit.rs +++ b/crates/extension/src/wasm_host/wit.rs @@ -5,6 +5,7 @@ use super::{wasm_engine, WasmState}; use anyhow::{Context, Result}; use language::LspAdapterDelegate; use semantic_version::SemanticVersion; +use std::ops::RangeInclusive; use std::sync::Arc; use wasmtime::{ component::{Component, Instance, Linker, Resource}, @@ -30,7 +31,13 @@ fn wasi_view(state: &mut WasmState) -> &mut WasmState { /// Returns whether the given Wasm API version is supported by the Wasm host. pub fn is_supported_wasm_api_version(version: SemanticVersion) -> bool { - since_v0_0_1::MIN_VERSION <= version && version <= latest::MAX_VERSION + wasm_api_version_range().contains(&version) +} + +/// Returns the Wasm API version range that is supported by the Wasm host. +#[inline(always)] +pub fn wasm_api_version_range() -> RangeInclusive { + since_v0_0_1::MIN_VERSION..=latest::MAX_VERSION } pub enum Extension { diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index 1649230f5f..5a1278b0f7 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -587,12 +587,11 @@ impl ExtensionsPage { .disabled(disabled) .on_click(cx.listener({ let extension_id = extension.id.clone(); - let version = extension.manifest.version.clone(); move |this, _, cx| { this.telemetry .report_app_event("extensions: install extension".to_string()); ExtensionStore::global(cx).update(cx, |store, cx| { - store.install_extension(extension_id.clone(), version.clone(), cx) + store.install_latest_extension(extension_id.clone(), cx) }); } })),