Add an extensions API to the collaboration server (#7807)

This PR adds a REST API to the collab server for searching and
downloading extensions. Previously, we had implemented this API in
zed.dev directly, but this implementation is better, because we use the
collab database to store the download counts for extensions.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Conrad <conrad@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-02-15 12:53:57 -08:00 committed by GitHub
parent bdc2558eac
commit e1ae0d46da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 1755 additions and 174 deletions

View File

@ -35,6 +35,9 @@ jobs:
submodules: "recursive"
fetch-depth: 0
- name: Remove untracked files
run: git clean -df
- name: Set up default .cargo/config.toml
run: cp ./.cargo/ci-config.toml ~/.cargo/config.toml

View File

@ -45,8 +45,18 @@ jobs:
submodules: "recursive"
fetch-depth: 0
- name: Install cargo nextest
shell: bash -euxo pipefail {0}
run: |
cargo install cargo-nextest
- name: Limit target directory size
shell: bash -euxo pipefail {0}
run: script/clear-target-dir-if-larger-than 100
- name: Run tests
uses: ./.github/actions/run_tests
shell: bash -euxo pipefail {0}
run: cargo nextest run --package collab --no-fail-fast
publish:
name: Publish collab server image
@ -90,22 +100,26 @@ jobs:
- name: Sign into Kubernetes
run: doctl kubernetes cluster kubeconfig save --expiry-seconds 600 ${{ secrets.CLUSTER_NAME }}
- name: Determine namespace
- name: Start rollout
run: |
set -eu
if [[ $GITHUB_REF_NAME = "collab-production" ]]; then
echo "Deploying collab:$GITHUB_SHA to production"
echo "KUBE_NAMESPACE=production" >> $GITHUB_ENV
export ZED_KUBE_NAMESPACE=production
elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
echo "Deploying collab:$GITHUB_SHA to staging"
echo "KUBE_NAMESPACE=staging" >> $GITHUB_ENV
export ZED_KUBE_NAMESPACE=staging
else
echo "cowardly refusing to deploy from an unknown branch"
exit 1
fi
- name: Start rollout
run: kubectl -n "$KUBE_NAMESPACE" set image deployment/collab collab=registry.digitalocean.com/zed/collab:${GITHUB_SHA}
echo "Deploying collab:$GITHUB_SHA to $ZED_KUBE_NAMESPACE"
- name: Wait for rollout to finish
run: kubectl -n "$KUBE_NAMESPACE" rollout status deployment/collab
source script/lib/deploy-helpers.sh
export_vars_for_environment $ZED_KUBE_NAMESPACE
export ZED_DO_CERTIFICATE_ID=$(doctl compute certificate list --format ID --no-header)
export ZED_IMAGE_ID="registry.digitalocean.com/zed/collab:${GITHUB_SHA}"
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/collab --watch
echo "deployed collab.template.yml to ${ZED_KUBE_NAMESPACE}"

5
.gitignore vendored
View File

@ -5,12 +5,8 @@
.DS_Store
/plugins/bin
/script/node_modules
/styles/node_modules
/styles/src/types/zed.ts
/crates/theme/schemas/theme.json
/crates/collab/static/styles.css
/crates/collab/.admins.json
/vendor/bin
/assets/*licenses.md
**/venv
.build
@ -25,3 +21,4 @@ DerivedData/
**/*.db
.pytest_cache
.venv
.blob_store

745
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,91 +1,91 @@
[workspace]
members = [
"crates/activity_indicator",
"crates/ai",
"crates/assets",
"crates/assistant",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
"crates/call",
"crates/channel",
"crates/cli",
"crates/client",
"crates/clock",
"crates/collab",
"crates/collab_ui",
"crates/collections",
"crates/command_palette",
"crates/copilot",
"crates/copilot_ui",
"crates/db",
"crates/diagnostics",
"crates/editor",
"crates/extension",
"crates/extensions_ui",
"crates/feature_flags",
"crates/feedback",
"crates/file_finder",
"crates/fs",
"crates/fsevent",
"crates/fuzzy",
"crates/git",
"crates/go_to_line",
"crates/gpui",
"crates/gpui_macros",
"crates/install_cli",
"crates/journal",
"crates/language",
"crates/language_selector",
"crates/language_tools",
"crates/live_kit_client",
"crates/live_kit_server",
"crates/lsp",
"crates/markdown_preview",
"crates/media",
"crates/menu",
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
"crates/outline",
"crates/picker",
"crates/plugin",
"crates/plugin_macros",
"crates/prettier",
"crates/project",
"crates/project_panel",
"crates/project_symbols",
"crates/quick_action_bar",
"crates/recent_projects",
"crates/refineable",
"crates/refineable/derive_refineable",
"crates/release_channel",
"crates/rich_text",
"crates/rope",
"crates/rpc",
"crates/search",
"crates/semantic_index",
"crates/settings",
"crates/snippet",
"crates/sqlez",
"crates/sqlez_macros",
"crates/story",
"crates/storybook",
"crates/sum_tree",
"crates/terminal",
"crates/terminal_view",
"crates/text",
"crates/theme",
"crates/theme_importer",
"crates/theme_selector",
"crates/ui",
"crates/util",
"crates/vcs_menu",
"crates/vim",
"crates/welcome",
"crates/workspace",
"crates/zed",
"crates/zed_actions",
"crates/activity_indicator",
"crates/ai",
"crates/assets",
"crates/assistant",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
"crates/call",
"crates/channel",
"crates/cli",
"crates/client",
"crates/clock",
"crates/collab",
"crates/collab_ui",
"crates/collections",
"crates/command_palette",
"crates/copilot",
"crates/copilot_ui",
"crates/db",
"crates/diagnostics",
"crates/editor",
"crates/extension",
"crates/extensions_ui",
"crates/feature_flags",
"crates/feedback",
"crates/file_finder",
"crates/fs",
"crates/fsevent",
"crates/fuzzy",
"crates/git",
"crates/go_to_line",
"crates/gpui",
"crates/gpui_macros",
"crates/install_cli",
"crates/journal",
"crates/language",
"crates/language_selector",
"crates/language_tools",
"crates/live_kit_client",
"crates/live_kit_server",
"crates/lsp",
"crates/markdown_preview",
"crates/media",
"crates/menu",
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
"crates/outline",
"crates/picker",
"crates/plugin",
"crates/plugin_macros",
"crates/prettier",
"crates/project",
"crates/project_panel",
"crates/project_symbols",
"crates/quick_action_bar",
"crates/recent_projects",
"crates/refineable",
"crates/refineable/derive_refineable",
"crates/release_channel",
"crates/rich_text",
"crates/rope",
"crates/rpc",
"crates/search",
"crates/semantic_index",
"crates/settings",
"crates/snippet",
"crates/sqlez",
"crates/sqlez_macros",
"crates/story",
"crates/storybook",
"crates/sum_tree",
"crates/terminal",
"crates/terminal_view",
"crates/text",
"crates/theme",
"crates/theme_importer",
"crates/theme_selector",
"crates/ui",
"crates/util",
"crates/vcs_menu",
"crates/vim",
"crates/welcome",
"crates/workspace",
"crates/zed",
"crates/zed_actions",
]
default-members = ["crates/zed"]
resolver = "2"
@ -191,8 +191,8 @@ globset = "0.4"
indoc = "1"
# We explicitly disable a http2 support in isahc.
isahc = { version = "1.7.2", default-features = false, features = [
"static-curl",
"text-decoding",
"static-curl",
"text-decoding",
] }
lazy_static = "1.4.0"
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
@ -208,12 +208,13 @@ regex = "1.5"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
rust-embed = { version = "8.0", features = ["include-exclude"] }
schemars = "0.8"
semver = { version = "1.0" }
serde = { version = "1.0", features = ["derive", "rc"] }
serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] }
serde_json_lenient = { version = "0.1", features = [
"preserve_order",
"raw_value",
"preserve_order",
"raw_value",
] }
serde_repr = "0.1"
smallvec = { version = "1.6", features = ["union"] }
@ -223,7 +224,11 @@ sysinfo = "0.29.10"
tempfile = "3.9.0"
thiserror = "1.0.29"
tiktoken-rs = "0.5.7"
time = { version = "0.3", features = ["serde", "serde-well-known"] }
time = { version = "0.3", features = [
"serde",
"serde-well-known",
"formatting",
] }
toml = "0.5"
tree-sitter = { version = "0.20", features = ["wasm"] }
tree-sitter-astro = { git = "https://github.com/virchau13/tree-sitter-astro.git", rev = "e924787e12e8a03194f36a113290ac11d6dc10f3" }

View File

@ -1,2 +1,3 @@
collab: cd crates/collab && RUST_LOG=${RUST_LOG:-warn,collab=info} cargo run serve
collab: RUST_LOG=${RUST_LOG:-warn,collab=info} cargo run --package=collab serve
livekit: livekit-server --dev
blob_store: MINIO_ROOT_USER=the-blob-store-access-key MINIO_ROOT_PASSWORD=the-blob-store-secret-key minio server .blob_store

View File

@ -7,6 +7,11 @@ ZED_ENVIRONMENT = "development"
LIVE_KIT_SERVER = "http://localhost:7880"
LIVE_KIT_KEY = "devkey"
LIVE_KIT_SECRET = "secret"
BLOB_STORE_ACCESS_KEY = "the-blob-store-access-key"
BLOB_STORE_SECRET_KEY = "the-blob-store-secret-key"
BLOB_STORE_BUCKET = "the-extensions-bucket"
BLOB_STORE_URL = "http://127.0.0.1:9000"
BLOB_STORE_REGION = "the-region"
# RUST_LOG=info
# LOG_JSON=true

View File

@ -15,9 +15,11 @@ name = "seed"
required-features = ["seed-support"]
[dependencies]
anyhow.workspace = true
async-tungstenite = "0.16"
axum = { version = "0.5", features = ["json", "headers", "ws"] }
anyhow.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }
async-tungstenite = "0.16"
axum-extra = { version = "0.3", features = ["erased-json"] }
base64 = "0.13"
chrono.workspace = true
@ -40,13 +42,26 @@ rand.workspace = true
reqwest = { version = "0.11", features = ["json"], optional = true }
rpc.workspace = true
scrypt = "0.7"
sea-orm = { version = "0.12.x", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
sea-orm = { version = "0.12.x", features = [
"sqlx-postgres",
"postgres-array",
"runtime-tokio-rustls",
"with-uuid",
] }
semver.workspace = true
serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
sha-1 = "0.9"
smallvec.workspace = true
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
sqlx = { version = "0.7", features = [
"runtime-tokio-rustls",
"postgres",
"json",
"time",
"uuid",
"any",
] }
text.workspace = true
time.workspace = true
tokio = { version = "1", features = ["full"] }

View File

@ -105,6 +105,31 @@ spec:
secretKeyRef:
name: livekit
key: secret
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:
name: blob-store
key: access_key
- name: BLOB_STORE_SECRET_KEY
valueFrom:
secretKeyRef:
name: blob-store
key: secret_key
- name: BLOB_STORE_URL
valueFrom:
secretKeyRef:
name: blob-store
key: url
- name: BLOB_STORE_REGION
valueFrom:
secretKeyRef:
name: blob-store
key: region
- name: BLOB_STORE_BUCKET
valueFrom:
secretKeyRef:
name: blob-store
key: bucket
- name: INVITE_LINK_PREFIX
value: ${INVITE_LINK_PREFIX}
- name: RUST_BACKTRACE

View File

@ -353,3 +353,25 @@ CREATE TABLE contributors (
signed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (user_id)
);
CREATE TABLE extensions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
external_id TEXT NOT NULL,
name TEXT NOT NULL,
latest_version TEXT NOT NULL,
total_download_count INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE extension_versions (
extension_id INTEGER REFERENCES extensions(id),
version TEXT NOT NULL,
published_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
authors TEXT NOT NULL,
repository TEXT NOT NULL,
description TEXT NOT NULL,
download_count INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (extension_id, version)
);
CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");

View File

@ -0,0 +1,22 @@
CREATE TABLE IF NOT EXISTS extensions (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
external_id TEXT NOT NULL,
latest_version TEXT NOT NULL,
total_download_count BIGINT NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS extension_versions (
extension_id INTEGER REFERENCES extensions(id),
version TEXT NOT NULL,
published_at TIMESTAMP NOT NULL DEFAULT now(),
authors TEXT NOT NULL,
repository TEXT NOT NULL,
description TEXT NOT NULL,
download_count BIGINT NOT NULL DEFAULT 0,
PRIMARY KEY(extension_id, version)
);
CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
CREATE INDEX "trigram_index_extensions_name" ON "extensions" USING GIN(name gin_trgm_ops);
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");

View File

@ -1,3 +1,5 @@
mod extensions;
use crate::{
auth,
db::{ContributorSelector, User, UserId},
@ -20,6 +22,8 @@ use std::sync::Arc;
use tower::ServiceBuilder;
use tracing::instrument;
pub use extensions::fetch_extensions_from_blob_store_periodically;
pub fn routes(rpc_server: Arc<rpc::Server>, state: Arc<AppState>) -> Router<Body> {
Router::new()
.route("/user", get(get_authenticated_user))
@ -28,6 +32,7 @@ pub fn routes(rpc_server: Arc<rpc::Server>, state: Arc<AppState>) -> Router<Body
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.route("/contributors", get(get_contributors).post(add_contributor))
.route("/contributor", get(check_is_contributor))
.merge(extensions::router())
.layer(
ServiceBuilder::new()
.layer(Extension(state))

View File

@ -0,0 +1,237 @@
use crate::{
db::{ExtensionMetadata, NewExtensionVersion},
executor::Executor,
AppState, Error, Result,
};
use anyhow::{anyhow, Context as _};
use aws_sdk_s3::presigning::PresigningConfig;
use axum::{
extract::{Path, Query},
response::Redirect,
routing::get,
Extension, Json, Router,
};
use collections::HashMap;
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Duration};
use time::PrimitiveDateTime;
use util::ResultExt;
pub fn router() -> Router {
Router::new()
.route("/extensions", get(get_extensions))
.route(
"/extensions/:extension_id/:version/download",
get(download_extension),
)
}
#[derive(Debug, Deserialize)]
struct GetExtensionsParams {
filter: Option<String>,
}
#[derive(Debug, Deserialize)]
struct DownloadExtensionParams {
extension_id: String,
version: String,
}
#[derive(Debug, Serialize)]
struct GetExtensionsResponse {
pub data: Vec<ExtensionMetadata>,
}
#[derive(Deserialize)]
struct ExtensionManifest {
name: String,
version: String,
description: Option<String>,
authors: Vec<String>,
repository: String,
}
async fn get_extensions(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetExtensionsParams>,
) -> Result<Json<GetExtensionsResponse>> {
let extensions = app.db.get_extensions(params.filter.as_deref(), 30).await?;
Ok(Json(GetExtensionsResponse { data: extensions }))
}
async fn download_extension(
Extension(app): Extension<Arc<AppState>>,
Path(params): Path<DownloadExtensionParams>,
) -> Result<Redirect> {
let Some((blob_store_client, bucket)) = app
.blob_store_client
.clone()
.zip(app.config.blob_store_bucket.clone())
else {
Err(Error::Http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let DownloadExtensionParams {
extension_id,
version,
} = params;
let version_exists = app
.db
.record_extension_download(&extension_id, &version)
.await?;
if !version_exists {
Err(Error::Http(
StatusCode::NOT_FOUND,
"unknown extension version".into(),
))?;
}
let url = blob_store_client
.get_object()
.bucket(bucket)
.key(format!(
"extensions/{extension_id}/{version}/archive.tar.gz"
))
.presigned(PresigningConfig::expires_in(EXTENSION_DOWNLOAD_URL_LIFETIME).unwrap())
.await
.map_err(|e| anyhow!("failed to create presigned extension download url {e}"))?;
Ok(Redirect::temporary(url.uri()))
}
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
let Some(blob_store_client) = app_state.blob_store_client.clone() else {
log::info!("no blob store client");
return;
};
let Some(blob_store_bucket) = app_state.config.blob_store_bucket.clone() else {
log::info!("no blob store bucket");
return;
};
executor.spawn_detached({
let executor = executor.clone();
async move {
loop {
fetch_extensions_from_blob_store(
&blob_store_client,
&blob_store_bucket,
&app_state,
)
.await
.log_err();
executor.sleep(EXTENSION_FETCH_INTERVAL).await;
}
}
});
}
async fn fetch_extensions_from_blob_store(
blob_store_client: &aws_sdk_s3::Client,
blob_store_bucket: &String,
app_state: &Arc<AppState>,
) -> anyhow::Result<()> {
let list = blob_store_client
.list_objects()
.bucket(blob_store_bucket)
.prefix("extensions/")
.send()
.await?;
let objects = list
.contents
.ok_or_else(|| anyhow!("missing bucket contents"))?;
let mut published_versions = HashMap::<&str, Vec<&str>>::default();
for object in &objects {
let Some(key) = object.key.as_ref() else {
continue;
};
let mut parts = key.split('/');
let Some(_) = parts.next().filter(|part| *part == "extensions") else {
continue;
};
let Some(extension_id) = parts.next() else {
continue;
};
let Some(version) = parts.next() else {
continue;
};
published_versions
.entry(extension_id)
.or_default()
.push(version);
}
let known_versions = app_state.db.get_known_extension_versions().await?;
let mut new_versions = HashMap::<&str, Vec<NewExtensionVersion>>::default();
let empty = Vec::new();
for (extension_id, published_versions) in published_versions {
let known_versions = known_versions.get(extension_id).unwrap_or(&empty);
for published_version in published_versions {
if known_versions
.binary_search_by_key(&published_version, String::as_str)
.is_err()
{
let object = blob_store_client
.get_object()
.bucket(blob_store_bucket)
.key(format!(
"extensions/{extension_id}/{published_version}/manifest.json"
))
.send()
.await?;
let manifest_bytes = object
.body
.collect()
.await
.map(|data| data.into_bytes())
.with_context(|| format!("failed to download manifest for extension {extension_id} version {published_version}"))?
.to_vec();
let manifest = serde_json::from_slice::<ExtensionManifest>(&manifest_bytes)
.with_context(|| format!("invalid manifest for extension {extension_id} version {published_version}: {}", String::from_utf8_lossy(&manifest_bytes)))?;
let published_at = object.last_modified.ok_or_else(|| anyhow!("missing last modified timestamp for extension {extension_id} version {published_version}"))?;
let published_at =
time::OffsetDateTime::from_unix_timestamp_nanos(published_at.as_nanos())?;
let published_at = PrimitiveDateTime::new(published_at.date(), published_at.time());
let version = semver::Version::parse(&manifest.version).with_context(|| {
format!(
"invalid version for extension {extension_id} version {published_version}"
)
})?;
new_versions
.entry(extension_id)
.or_default()
.push(NewExtensionVersion {
name: manifest.name,
version,
description: manifest.description.unwrap_or_default(),
authors: manifest.authors,
repository: manifest.repository,
published_at,
});
}
}
}
app_state
.db
.insert_extension_versions(&new_versions)
.await?;
Ok(())
}

View File

@ -1,12 +1,8 @@
#[cfg(test)]
pub mod tests;
#[cfg(test)]
pub use tests::TestDb;
mod ids;
mod queries;
mod tables;
#[cfg(test)]
pub mod tests;
use crate::{executor::Executor, Error, Result};
use anyhow::anyhow;
@ -25,7 +21,7 @@ use sea_orm::{
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
};
use serde::{Deserialize, Serialize};
use serde::{ser::Error as _, Deserialize, Serialize, Serializer};
use sqlx::{
migrate::{Migrate, Migration, MigrationSource},
Connection,
@ -40,13 +36,17 @@ use std::{
sync::Arc,
time::Duration,
};
pub use tables::*;
use time::{format_description::well_known::iso8601, PrimitiveDateTime};
use tokio::sync::{Mutex, OwnedMutexGuard};
#[cfg(test)]
pub use tests::TestDb;
pub use ids::*;
pub use queries::contributors::ContributorSelector;
pub use sea_orm::ConnectOptions;
pub use tables::user::Model as User;
pub use tables::*;
/// Database gives you a handle that lets you access the database.
/// It handles pooling internally.
@ -717,3 +717,42 @@ pub struct WorktreeSettingsFile {
pub path: String,
pub content: String,
}
pub struct NewExtensionVersion {
pub name: String,
pub version: semver::Version,
pub description: String,
pub authors: Vec<String>,
pub repository: String,
pub published_at: PrimitiveDateTime,
}
#[derive(Debug, Serialize, PartialEq)]
pub struct ExtensionMetadata {
pub id: String,
pub name: String,
pub version: String,
pub authors: Vec<String>,
pub repository: String,
#[serde(serialize_with = "serialize_iso8601")]
pub published_at: PrimitiveDateTime,
pub download_count: u64,
}
pub fn serialize_iso8601<S: Serializer>(
datetime: &PrimitiveDateTime,
serializer: S,
) -> Result<S::Ok, S::Error> {
const SERDE_CONFIG: iso8601::EncodedConfig = iso8601::Config::DEFAULT
.set_year_is_six_digits(false)
.set_time_precision(iso8601::TimePrecision::Second {
decimal_digits: None,
})
.encode();
datetime
.assume_utc()
.format(&time::format_description::well_known::Iso8601::<SERDE_CONFIG>)
.map_err(S::Error::custom)?
.serialize(serializer)
}

View File

@ -85,6 +85,7 @@ id_type!(SignupId);
id_type!(UserId);
id_type!(ChannelBufferCollaboratorId);
id_type!(FlagId);
id_type!(ExtensionId);
id_type!(NotificationId);
id_type!(NotificationKindId);

View File

@ -5,6 +5,7 @@ pub mod buffers;
pub mod channels;
pub mod contacts;
pub mod contributors;
pub mod extensions;
pub mod messages;
pub mod notifications;
pub mod projects;

View File

@ -0,0 +1,205 @@
use super::*;
impl Database {
pub async fn get_extensions(
&self,
filter: Option<&str>,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
let mut condition = Condition::all();
if let Some(filter) = filter {
let fuzzy_name_filter = Self::fuzzy_like_string(filter);
condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
}
let extensions = extension::Entity::find()
.filter(condition)
.order_by_desc(extension::Column::TotalDownloadCount)
.order_by_asc(extension::Column::Id)
.limit(Some(limit as u64))
.filter(
extension::Column::LatestVersion
.into_expr()
.eq(extension_version::Column::Version.into_expr()),
)
.inner_join(extension_version::Entity)
.select_also(extension_version::Entity)
.all(&*tx)
.await?;
Ok(extensions
.into_iter()
.filter_map(|(extension, latest_version)| {
let version = latest_version?;
Some(ExtensionMetadata {
id: extension.external_id,
name: extension.name,
version: version.version,
authors: version
.authors
.split(',')
.map(|author| author.trim().to_string())
.collect::<Vec<_>>(),
repository: version.repository,
published_at: version.published_at,
download_count: extension.total_download_count as u64,
})
})
.collect())
})
.await
}
pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
self.transaction(|tx| async move {
let mut extension_external_ids_by_id = HashMap::default();
let mut rows = extension::Entity::find().stream(&*tx).await?;
while let Some(row) = rows.next().await {
let row = row?;
extension_external_ids_by_id.insert(row.id, row.external_id);
}
drop(rows);
let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
HashMap::default();
let mut rows = extension_version::Entity::find().stream(&*tx).await?;
while let Some(row) = rows.next().await {
let row = row?;
let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
continue;
};
let versions = known_versions_by_extension_id
.entry(extension_id.clone())
.or_default();
if let Err(ix) = versions.binary_search(&row.version) {
versions.insert(ix, row.version);
}
}
drop(rows);
Ok(known_versions_by_extension_id)
})
.await
}
pub async fn insert_extension_versions(
&self,
versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
) -> Result<()> {
self.transaction(|tx| async move {
for (external_id, versions) in versions_by_extension_id {
if versions.is_empty() {
continue;
}
let latest_version = versions
.iter()
.max_by_key(|version| &version.version)
.unwrap();
let insert = extension::Entity::insert(extension::ActiveModel {
name: ActiveValue::Set(latest_version.name.clone()),
external_id: ActiveValue::Set(external_id.to_string()),
id: ActiveValue::NotSet,
latest_version: ActiveValue::Set(latest_version.version.to_string()),
total_download_count: ActiveValue::NotSet,
})
.on_conflict(
OnConflict::columns([extension::Column::ExternalId])
.update_column(extension::Column::ExternalId)
.to_owned(),
);
let extension = if tx.support_returning() {
insert.exec_with_returning(&*tx).await?
} else {
// Sqlite
insert.exec_without_returning(&*tx).await?;
extension::Entity::find()
.filter(extension::Column::ExternalId.eq(*external_id))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to insert extension"))?
};
extension_version::Entity::insert_many(versions.iter().map(|version| {
extension_version::ActiveModel {
extension_id: ActiveValue::Set(extension.id),
published_at: ActiveValue::Set(version.published_at),
version: ActiveValue::Set(version.version.to_string()),
authors: ActiveValue::Set(version.authors.join(", ")),
repository: ActiveValue::Set(version.repository.clone()),
description: ActiveValue::Set(version.description.clone()),
download_count: ActiveValue::NotSet,
}
}))
.on_conflict(OnConflict::new().do_nothing().to_owned())
.exec_without_returning(&*tx)
.await?;
if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
if db_version >= latest_version.version {
continue;
}
}
let mut extension = extension.into_active_model();
extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
extension.name = ActiveValue::set(latest_version.name.clone());
extension::Entity::update(extension).exec(&*tx).await?;
}
Ok(())
})
.await
}
pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
self.transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryId {
Id,
}
let extension_id: Option<ExtensionId> = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension))
.select_only()
.column(extension::Column::Id)
.into_values::<_, QueryId>()
.one(&*tx)
.await?;
let Some(extension_id) = extension_id else {
return Ok(false);
};
extension_version::Entity::update_many()
.col_expr(
extension_version::Column::DownloadCount,
extension_version::Column::DownloadCount.into_expr().add(1),
)
.filter(
extension_version::Column::ExtensionId
.eq(extension_id)
.and(extension_version::Column::Version.eq(version)),
)
.exec(&*tx)
.await?;
extension::Entity::update_many()
.col_expr(
extension::Column::TotalDownloadCount,
extension::Column::TotalDownloadCount.into_expr().add(1),
)
.filter(extension::Column::Id.eq(extension_id))
.exec(&*tx)
.await?;
Ok(true)
})
.await
}
}

View File

@ -10,6 +10,8 @@ pub mod channel_message;
pub mod channel_message_mention;
pub mod contact;
pub mod contributor;
pub mod extension;
pub mod extension_version;
pub mod feature_flag;
pub mod follower;
pub mod language_server;

View File

@ -0,0 +1,27 @@
use crate::db::ExtensionId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "extensions")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: ExtensionId,
pub external_id: String,
pub name: String,
pub latest_version: String,
pub total_download_count: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_one = "super::extension_version::Entity")]
LatestVersion,
}
impl Related<super::extension_version::Entity> for Entity {
fn to() -> RelationDef {
Relation::LatestVersion.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,36 @@
use crate::db::ExtensionId;
use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "extension_versions")]
pub struct Model {
#[sea_orm(primary_key)]
pub extension_id: ExtensionId,
#[sea_orm(primary_key)]
pub version: String,
pub published_at: PrimitiveDateTime,
pub authors: String,
pub repository: String,
pub description: String,
pub download_count: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::extension::Entity",
from = "Column::ExtensionId",
to = "super::extension::Column::Id"
on_condition = r#"super::extension::Column::LatestVersion.into_expr().eq(Column::Version.into_expr())"#
)]
Extension,
}
impl Related<super::extension::Entity> for Entity {
fn to() -> RelationDef {
Relation::Extension.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -2,6 +2,7 @@ mod buffer_tests;
mod channel_tests;
mod contributor_tests;
mod db_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;

View File

@ -0,0 +1,219 @@
use super::Database;
use crate::{
db::{ExtensionMetadata, NewExtensionVersion},
test_both_dbs,
};
use std::sync::Arc;
use time::{OffsetDateTime, PrimitiveDateTime};
test_both_dbs!(
test_extensions,
test_extensions_postgres,
test_extensions_sqlite
);
async fn test_extensions(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
let extensions = db.get_extensions(None, 5).await.unwrap();
assert!(extensions.is_empty());
let t0 = OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
let t0 = PrimitiveDateTime::new(t0.date(), t0.time());
db.insert_extension_versions(
&[
(
"ext1",
vec![
NewExtensionVersion {
name: "Extension 1".into(),
version: semver::Version::parse("0.0.1").unwrap(),
description: "an extension".into(),
authors: vec!["max".into()],
repository: "ext1/repo".into(),
published_at: t0,
},
NewExtensionVersion {
name: "Extension One".into(),
version: semver::Version::parse("0.0.2").unwrap(),
description: "a good extension".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
},
],
),
(
"ext2",
vec![NewExtensionVersion {
name: "Extension Two".into(),
version: semver::Version::parse("0.2.0").unwrap(),
description: "a great extension".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
}],
),
]
.into_iter()
.collect(),
)
.await
.unwrap();
let versions = db.get_known_extension_versions().await.unwrap();
assert_eq!(
versions,
[
("ext1".into(), vec!["0.0.1".into(), "0.0.2".into()]),
("ext2".into(), vec!["0.2.0".into()])
]
.into_iter()
.collect()
);
// The latest version of each extension is returned.
let extensions = db.get_extensions(None, 5).await.unwrap();
assert_eq!(
extensions,
&[
ExtensionMetadata {
id: "ext1".into(),
name: "Extension One".into(),
version: "0.0.2".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
download_count: 0,
},
ExtensionMetadata {
id: "ext2".into(),
name: "Extension Two".into(),
version: "0.2.0".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
download_count: 0
},
]
);
// Record extensions being downloaded.
for _ in 0..7 {
assert!(db.record_extension_download("ext2", "0.0.2").await.unwrap());
}
for _ in 0..3 {
assert!(db.record_extension_download("ext1", "0.0.1").await.unwrap());
}
for _ in 0..2 {
assert!(db.record_extension_download("ext1", "0.0.2").await.unwrap());
}
// Record download returns false if the extension does not exist.
assert!(!db
.record_extension_download("no-such-extension", "0.0.2")
.await
.unwrap());
// Extensions are returned in descending order of total downloads.
let extensions = db.get_extensions(None, 5).await.unwrap();
assert_eq!(
extensions,
&[
ExtensionMetadata {
id: "ext2".into(),
name: "Extension Two".into(),
version: "0.2.0".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
download_count: 7
},
ExtensionMetadata {
id: "ext1".into(),
name: "Extension One".into(),
version: "0.0.2".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
download_count: 5,
},
]
);
// Add more extensions, including a new version of `ext1`, and backfilling
// an older version of `ext2`.
db.insert_extension_versions(
&[
(
"ext1",
vec![NewExtensionVersion {
name: "Extension One".into(),
version: semver::Version::parse("0.0.3").unwrap(),
description: "a real good extension".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
}],
),
(
"ext2",
vec![NewExtensionVersion {
name: "Extension Two".into(),
version: semver::Version::parse("0.1.0").unwrap(),
description: "an old extension".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
}],
),
]
.into_iter()
.collect(),
)
.await
.unwrap();
let versions = db.get_known_extension_versions().await.unwrap();
assert_eq!(
versions,
[
(
"ext1".into(),
vec!["0.0.1".into(), "0.0.2".into(), "0.0.3".into()]
),
("ext2".into(), vec!["0.1.0".into(), "0.2.0".into()])
]
.into_iter()
.collect()
);
let extensions = db.get_extensions(None, 5).await.unwrap();
assert_eq!(
extensions,
&[
ExtensionMetadata {
id: "ext2".into(),
name: "Extension Two".into(),
version: "0.2.0".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
download_count: 7
},
ExtensionMetadata {
id: "ext1".into(),
name: "Extension One".into(),
version: "0.0.3".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
download_count: 5,
},
]
);
}

View File

@ -3,7 +3,8 @@ use std::fs;
pub fn load_dotenv() -> anyhow::Result<()> {
let env: toml::map::Map<String, toml::Value> = toml::de::from_str(
&fs::read_to_string("./.env.toml").map_err(|_| anyhow!("no .env.toml file found"))?,
&fs::read_to_string("./crates/collab/.env.toml")
.map_err(|_| anyhow!("no .env.toml file found"))?,
)?;
for (key, value) in env {

View File

@ -8,11 +8,14 @@ pub mod rpc;
#[cfg(test)]
mod tests;
use anyhow::anyhow;
use aws_config::{BehaviorVersion, Region};
use axum::{http::StatusCode, response::IntoResponse};
use db::Database;
use executor::Executor;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
pub type Result<T, E = Error> = std::result::Result<T, E>;
@ -100,6 +103,11 @@ pub struct Config {
pub live_kit_secret: Option<String>,
pub rust_log: Option<String>,
pub log_json: Option<bool>,
pub blob_store_url: Option<String>,
pub blob_store_region: Option<String>,
pub blob_store_access_key: Option<String>,
pub blob_store_secret_key: Option<String>,
pub blob_store_bucket: Option<String>,
pub zed_environment: Arc<str>,
}
@ -118,6 +126,7 @@ pub struct MigrateConfig {
pub struct AppState {
pub db: Arc<Database>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub config: Config,
}
@ -146,8 +155,44 @@ impl AppState {
let this = Self {
db: Arc::new(db),
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
config,
};
Ok(Arc::new(this))
}
}
async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
let keys = aws_sdk_s3::config::Credentials::new(
config
.blob_store_access_key
.clone()
.ok_or_else(|| anyhow!("missing blob_store_access_key"))?,
config
.blob_store_secret_key
.clone()
.ok_or_else(|| anyhow!("missing blob_store_secret_key"))?,
None,
None,
"env",
);
let s3_config = aws_config::defaults(BehaviorVersion::latest())
.endpoint_url(
config
.blob_store_url
.as_ref()
.ok_or_else(|| anyhow!("missing blob_store_url"))?,
)
.region(Region::new(
config
.blob_store_region
.clone()
.ok_or_else(|| anyhow!("missing blob_store_region"))?,
))
.credentials_provider(keys)
.load()
.await;
Ok(aws_sdk_s3::Client::new(&s3_config))
}

View File

@ -1,6 +1,9 @@
use anyhow::anyhow;
use axum::{routing::get, Extension, Router};
use collab::{db, env, executor::Executor, AppState, Config, MigrateConfig, Result};
use collab::{
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
Config, MigrateConfig, Result,
};
use db::Database;
use std::{
env::args,
@ -50,6 +53,8 @@ async fn main() -> Result<()> {
let rpc_server = collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
rpc_server.start().await?;
fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
let app = collab::api::routes(rpc_server.clone(), state.clone())
.merge(collab::rpc::routes(rpc_server.clone()))
.merge(

View File

@ -479,6 +479,7 @@ impl TestServer {
Arc::new(AppState {
db: test_db.db().clone(),
live_kit_client: Some(Arc::new(fake_server.create_api_client())),
blob_store_client: None,
config: Config {
http_port: 0,
database_url: "".into(),
@ -491,6 +492,11 @@ impl TestServer {
rust_log: None,
log_json: None,
zed_environment: "test".into(),
blob_store_url: None,
blob_store_region: None,
blob_store_access_key: None,
blob_store_secret_key: None,
blob_store_bucket: None,
},
})
}

View File

@ -3,6 +3,10 @@
echo "installing foreman..."
which foreman > /dev/null || brew install foreman
echo "installing minio..."
which minio > /dev/null || brew install minio/stable/minio
mkdir -p .blob_store/the-extensions-bucket
echo "creating database..."
script/sqlx database create

View File

@ -1,5 +1,4 @@
#!/bin/bash
set -e
cd crates/collab
cargo run --quiet --package=collab --features seed-support --bin seed -- $@