diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9d5218965b..03380a60a6 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -490,6 +490,7 @@ impl AssistantPanel { } language_model::Event::ProviderStateChanged => { this.ensure_authenticated(cx); + cx.notify() } language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { @@ -1712,6 +1713,7 @@ pub struct ContextEditor { assistant_panel: WeakView, error_message: Option, debug_inspector: Option, + show_accept_terms: bool, } const DEFAULT_TAB_TITLE: &str = "New Context"; @@ -1772,6 +1774,7 @@ impl ContextEditor { assistant_panel, error_message: None, debug_inspector: None, + show_accept_terms: false, }; this.update_message_headers(cx); this.insert_slash_command_output_sections(sections, cx); @@ -1804,6 +1807,16 @@ impl ContextEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { + let provider = LanguageModelRegistry::read_global(cx).active_provider(); + if provider + .as_ref() + .map_or(false, |provider| provider.must_accept_terms(cx)) + { + self.show_accept_terms = true; + cx.notify(); + return; + } + if !self.apply_active_workflow_step(cx) { self.error_message = None; self.send_to_model(cx); @@ -3388,7 +3401,14 @@ impl ContextEditor { None => (ButtonStyle::Filled, None), }; + let provider = LanguageModelRegistry::read_global(cx).active_provider(); + let disabled = self.show_accept_terms + && provider + .as_ref() + .map_or(false, |provider| provider.must_accept_terms(cx)); + ButtonLike::new("send_button") + .disabled(disabled) .style(style) .when_some(tooltip, |button, tooltip| { button.tooltip(move |_| tooltip.clone()) @@ -3437,6 +3457,15 @@ impl EventEmitter for ContextEditor {} impl Render for ContextEditor { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let provider = LanguageModelRegistry::read_global(cx).active_provider(); + let accept_terms = if self.show_accept_terms { + provider + .as_ref() + .and_then(|provider| provider.render_accept_terms(cx)) + } else { + None + }; + v_flex() .key_context("ContextEditor") .capture_action(cx.listener(ContextEditor::cancel)) @@ -3455,6 +3484,21 @@ impl Render for ContextEditor { .bg(cx.theme().colors().editor_background) .child(self.editor.clone()), ) + .when_some(accept_terms, |this, element| { + this.child( + div() + .absolute() + .right_4() + .bottom_10() + .max_w_96() + .py_2() + .px_3() + .elevation_2(cx) + .bg(cx.theme().colors().surface_background) + .occlude() + .child(element), + ) + }) .child( h_flex().flex_none().relative().child( h_flex() diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 5e8ad2181c..bc39661e29 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,5 +1,6 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{anyhow, Result}; +use chrono::Duration; use futures::{stream::BoxStream, StreamExt}; use gpui::{BackgroundExecutor, Context, Model, TestAppContext}; use parking_lot::Mutex; @@ -162,6 +163,11 @@ impl FakeServer { return Ok(*message.downcast().unwrap()); } + let accepted_tos_at = chrono::Utc::now() + .checked_sub_signed(Duration::hours(5)) + .expect("failed to build accepted_tos_at") + .timestamp() as u64; + if message.is::>() { self.respond( message @@ -172,6 +178,7 @@ impl FakeServer { metrics_id: "the-metrics-id".into(), staff: false, flags: Default::default(), + accepted_tos_at: Some(accepted_tos_at), }, ); continue; diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 6464ffb0cb..d816b5af12 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,5 +1,6 @@ use super::{proto, Client, Status, TypedEnvelope}; use anyhow::{anyhow, Context, Result}; +use chrono::{DateTime, Utc}; use collections::{hash_map::Entry, HashMap, HashSet}; use feature_flags::FeatureFlagAppExt; use futures::{channel::mpsc, Future, StreamExt}; @@ -94,6 +95,7 @@ pub struct UserStore { update_contacts_tx: mpsc::UnboundedSender, current_plan: Option, current_user: watch::Receiver>>, + accepted_tos_at: Option>>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -150,6 +152,7 @@ impl UserStore { by_github_login: Default::default(), current_user: current_user_rx, current_plan: None, + accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), participant_indices: Default::default(), @@ -189,9 +192,10 @@ impl UserStore { } else { break; }; - let fetch_metrics_id = + let fetch_private_user_info = client.request(proto::GetPrivateUserInfo {}).log_err(); - let (user, info) = futures::join!(fetch_user, fetch_metrics_id); + let (user, info) = + futures::join!(fetch_user, fetch_private_user_info); cx.update(|cx| { if let Some(info) = info { @@ -202,9 +206,17 @@ impl UserStore { client.telemetry.set_authenticated_user_info( Some(info.metrics_id.clone()), staff, - ) + ); + + this.update(cx, |this, _| { + this.set_current_user_accepted_tos_at( + info.accepted_tos_at, + ); + }) + } else { + anyhow::Ok(()) } - })?; + })??; current_user_tx.send(user).await.ok(); @@ -680,6 +692,39 @@ impl UserStore { self.current_user.clone() } + pub fn current_user_has_accepted_terms(&self) -> Option { + self.accepted_tos_at + .map(|accepted_tos_at| accepted_tos_at.is_some()) + } + + pub fn accept_terms_of_service(&mut self, cx: &mut ModelContext) -> Task> { + if self.current_user().is_none() { + return Task::ready(Err(anyhow!("no current user"))); + }; + + let client = self.client.clone(); + cx.spawn(move |this, mut cx| async move { + if let Some(client) = client.upgrade() { + let response = client + .request(proto::AcceptTermsOfService {}) + .await + .context("error accepting tos")?; + + this.update(&mut cx, |this, _| { + this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)) + }) + } else { + Err(anyhow!("client not found")) + } + }) + } + + fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option) { + self.accepted_tos_at = Some( + accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)), + ); + } + fn load_users( &mut self, request: impl RequestMessage, diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 01b18e8fd7..f3ab451fda 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -9,7 +9,8 @@ CREATE TABLE "users" ( "connected_once" BOOLEAN NOT NULL DEFAULT false, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "metrics_id" TEXT, - "github_user_id" INTEGER + "github_user_id" INTEGER, + "accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE ); CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); diff --git a/crates/collab/migrations/20240812073542_add_accepted_tos_at.sql b/crates/collab/migrations/20240812073542_add_accepted_tos_at.sql new file mode 100644 index 0000000000..43fa0e7bbd --- /dev/null +++ b/crates/collab/migrations/20240812073542_add_accepted_tos_at.sql @@ -0,0 +1 @@ +ALTER TABLE users ADD accepted_tos_at TIMESTAMP WITHOUT TIME ZONE; diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index 60c4aa8e3c..447946b7b2 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -225,6 +225,26 @@ impl Database { .await } + /// Sets "accepted_tos_at" on the user to the given timestamp. + pub async fn set_user_accepted_tos_at( + &self, + id: UserId, + accepted_tos_at: Option, + ) -> Result<()> { + self.transaction(|tx| async move { + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .set(user::ActiveModel { + accepted_tos_at: ActiveValue::set(accepted_tos_at), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(()) + }) + .await + } + /// hard delete the user. pub async fn destroy_user(&self, id: UserId) -> Result<()> { self.transaction(|tx| async move { diff --git a/crates/collab/src/db/tables/user.rs b/crates/collab/src/db/tables/user.rs index a801e6383e..100bb11650 100644 --- a/crates/collab/src/db/tables/user.rs +++ b/crates/collab/src/db/tables/user.rs @@ -18,6 +18,7 @@ pub struct Model { pub connected_once: bool, pub metrics_id: Uuid, pub created_at: DateTime, + pub accepted_tos_at: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index c1976fd9c5..6705a5c832 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -10,6 +10,7 @@ mod extension_tests; mod feature_flag_tests; mod message_tests; mod processed_stripe_event_tests; +mod user_tests; use crate::migrations::run_database_migrations; diff --git a/crates/collab/src/db/tests/user_tests.rs b/crates/collab/src/db/tests/user_tests.rs new file mode 100644 index 0000000000..e2ef1eeba4 --- /dev/null +++ b/crates/collab/src/db/tests/user_tests.rs @@ -0,0 +1,45 @@ +use chrono::Utc; + +use crate::{ + db::{Database, NewUserParams}, + test_both_dbs, +}; +use std::sync::Arc; + +test_both_dbs!( + test_accepted_tos, + test_accepted_tos_postgres, + test_accepted_tos_sqlite +); + +async fn test_accepted_tos(db: &Arc) { + let user_id = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "user1".to_string(), + github_user_id: 1, + }, + ) + .await + .unwrap() + .user_id; + + let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); + assert!(user.accepted_tos_at.is_none()); + + let accepted_tos_at = Utc::now().naive_utc(); + db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at)) + .await + .unwrap(); + + let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); + assert!(user.accepted_tos_at.is_some()); + assert_eq!(user.accepted_tos_at, Some(accepted_tos_at)); + + db.set_user_accepted_tos_at(user_id, None).await.unwrap(); + + let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); + assert!(user.accepted_tos_at.is_none()); +} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index f6f9798351..08c725eb5b 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -31,6 +31,7 @@ use axum::{ routing::get, Extension, Router, TypedHeader, }; +use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; @@ -604,6 +605,7 @@ impl Server { .add_message_handler(user_message_handler(update_followers)) .add_request_handler(user_handler(get_private_user_info)) .add_request_handler(user_handler(get_llm_api_token)) + .add_request_handler(user_handler(accept_terms_of_service)) .add_message_handler(user_message_handler(acknowledge_channel_message)) .add_message_handler(user_message_handler(acknowledge_buffer_version)) .add_request_handler(user_handler(get_supermaven_api_key)) @@ -4882,6 +4884,25 @@ async fn get_private_user_info( metrics_id, staff: user.admin, flags, + accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64), + })?; + Ok(()) +} + +/// Accept the terms of service (tos) on behalf of the current user +async fn accept_terms_of_service( + _request: proto::AcceptTermsOfService, + response: Response, + session: UserSession, +) -> Result<()> { + let db = session.db().await; + + let accepted_tos_at = Utc::now(); + db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc())) + .await?; + + response.send(proto::AcceptTermsOfServiceResponse { + accepted_tos_at: accepted_tos_at.timestamp() as u64, })?; Ok(()) } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 90ced4d9bc..9377dea178 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -9,7 +9,9 @@ pub mod settings; use anyhow::Result; use client::{Client, UserStore}; use futures::{future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext}; +use gpui::{ + AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext, +}; pub use model::*; use project::Fs; use proto::Plan; @@ -114,6 +116,12 @@ pub trait LanguageModelProvider: 'static { fn is_authenticated(&self, cx: &AppContext) -> bool; fn authenticate(&self, cx: &mut AppContext) -> Task>; fn configuration_view(&self, cx: &mut WindowContext) -> AnyView; + fn must_accept_terms(&self, _cx: &AppContext) -> bool { + false + } + fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option { + None + } fn reset_credentials(&self, cx: &mut AppContext) -> Task>; } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 27a4f4f8b7..1f1ae92956 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -9,7 +9,10 @@ use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADE use collections::BTreeMap; use feature_flags::{FeatureFlagAppExt, LanguageModels}; use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task}; +use gpui::{ + AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext, + Subscription, Task, +}; use http_client::{AsyncBody, HttpClient, Method, Response}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -62,6 +65,7 @@ pub struct State { client: Arc, user_store: Model, status: client::Status, + accept_terms: Option>>, _subscription: Subscription, } @@ -77,6 +81,26 @@ impl State { this.update(&mut cx, |_, cx| cx.notify()) }) } + + fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool { + self.user_store + .read(cx) + .current_user_has_accepted_terms() + .unwrap_or(false) + } + + fn accept_terms_of_service(&mut self, cx: &mut ModelContext) { + let user_store = self.user_store.clone(); + self.accept_terms = Some(cx.spawn(move |this, mut cx| async move { + let _ = user_store + .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))? + .await; + this.update(&mut cx, |this, cx| { + this.accept_terms = None; + cx.notify() + }) + })); + } } impl CloudLanguageModelProvider { @@ -88,6 +112,7 @@ impl CloudLanguageModelProvider { client: client.clone(), user_store, status, + accept_terms: None, _subscription: cx.observe_global::(|_, cx| { cx.notify(); }), @@ -223,6 +248,57 @@ impl LanguageModelProvider for CloudLanguageModelProvider { .into() } + fn must_accept_terms(&self, cx: &AppContext) -> bool { + !self.state.read(cx).has_accepted_terms_of_service(cx) + } + + fn render_accept_terms(&self, cx: &mut WindowContext) -> Option { + let state = self.state.read(cx); + + let terms = [( + "anthropic_terms_of_service", + "Anthropic Terms of Service", + "https://www.anthropic.com/legal/consumer-terms", + )] + .map(|(id, label, url)| { + Button::new(id, label) + .style(ButtonStyle::Subtle) + .icon(IconName::ExternalLink) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, cx| cx.open_url(url)) + }); + + if state.has_accepted_terms_of_service(cx) { + None + } else { + let disabled = state.accept_terms.is_some(); + Some( + v_flex() + .child(Label::new("Terms & Conditions").weight(FontWeight::SEMIBOLD)) + .child("Please read and accept the terms and conditions of Zed AI and our provider partners to continue.") + .child(v_flex().m_2().gap_1().children(terms)) + .child( + h_flex().justify_end().mt_1().child( + Button::new("accept_terms", "Accept") + .disabled(disabled) + .on_click({ + let state = self.state.downgrade(); + move |_, cx| { + state + .update(cx, |state, cx| { + state.accept_terms_of_service(cx) + }) + .ok(); + } + }), + ), + ) + .into_any(), + ) + } + } + fn reset_credentials(&self, _cx: &mut AppContext) -> Task> { Task::ready(Ok(())) } @@ -766,6 +842,7 @@ impl Render for ConfigurationView { let is_connected = !self.state.read(cx).is_signed_out(); let plan = self.state.read(cx).user_store.read(cx).current_plan(); + let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx); let is_pro = plan == Some(proto::Plan::ZedPro); @@ -773,6 +850,11 @@ impl Render for ConfigurationView { v_flex() .gap_3() .max_w_4_5() + .when(must_accept_terms, |this| { + this.child(Label::new( + "You must accept the terms of service to use this provider.", + )) + }) .child(Label::new( if is_pro { "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro." diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 995ebcd341..299337d6c2 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -49,7 +49,7 @@ message Envelope { GetDefinition get_definition = 32; GetDefinitionResponse get_definition_response = 33; GetDeclaration get_declaration = 237; - GetDeclarationResponse get_declaration_response = 238; // current max + GetDeclarationResponse get_declaration_response = 238; GetTypeDefinition get_type_definition = 34; GetTypeDefinitionResponse get_type_definition_response = 35; @@ -130,6 +130,8 @@ message Envelope { GetPrivateUserInfoResponse get_private_user_info_response = 103; UpdateUserPlan update_user_plan = 234; UpdateDiffBase update_diff_base = 104; + AcceptTermsOfService accept_terms_of_service = 239; + AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; // current max OnTypeFormatting on_type_formatting = 105; OnTypeFormattingResponse on_type_formatting_response = 106; @@ -270,7 +272,7 @@ message Envelope { AddWorktreeResponse add_worktree_response = 223; GetLlmToken get_llm_token = 235; - GetLlmTokenResponse get_llm_token_response = 236; // current max + GetLlmTokenResponse get_llm_token_response = 236; } reserved 158 to 161; @@ -1692,6 +1694,7 @@ message GetPrivateUserInfoResponse { string metrics_id = 1; bool staff = 2; repeated string flags = 3; + optional uint64 accepted_tos_at = 4; } enum Plan { @@ -1703,6 +1706,12 @@ message UpdateUserPlan { Plan plan = 1; } +message AcceptTermsOfService {} + +message AcceptTermsOfServiceResponse { + uint64 accepted_tos_at = 1; +} + // Entities message ViewId { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 139ee8fdf9..7ca48b1369 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -187,6 +187,8 @@ impl fmt::Display for PeerId { } messages!( + (AcceptTermsOfService, Foreground), + (AcceptTermsOfServiceResponse, Foreground), (Ack, Foreground), (AckBufferOperation, Background), (AckChannelMessage, Background), @@ -409,6 +411,7 @@ messages!( ); request_messages!( + (AcceptTermsOfService, AcceptTermsOfServiceResponse), (ApplyCodeAction, ApplyCodeActionResponse), ( ApplyCompletionAdditionalEdits,