assistant: Require user to accept TOS for cloud provider (#16111)

This adds the requirement for users to accept the terms of service the
first time they send a message with the Cloud provider.

Once this is out and in a nightly, we need to add the check to the
server side too, to authenticate access to the models.

Demo:


https://github.com/user-attachments/assets/0edebf74-8120-4fa2-b801-bb76f04e8a17



Release Notes:

- N/A
This commit is contained in:
Thorsten Ball 2024-08-12 17:43:35 +02:00 committed by GitHub
parent 98f314ba21
commit fbb533b3e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 297 additions and 9 deletions

View File

@ -490,6 +490,7 @@ impl AssistantPanel {
}
language_model::Event::ProviderStateChanged => {
this.ensure_authenticated(cx);
cx.notify()
}
language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
@ -1712,6 +1713,7 @@ pub struct ContextEditor {
assistant_panel: WeakView<AssistantPanel>,
error_message: Option<SharedString>,
debug_inspector: Option<ContextInspector>,
show_accept_terms: bool,
}
const DEFAULT_TAB_TITLE: &str = "New Context";
@ -1772,6 +1774,7 @@ impl ContextEditor {
assistant_panel,
error_message: None,
debug_inspector: None,
show_accept_terms: false,
};
this.update_message_headers(cx);
this.insert_slash_command_output_sections(sections, cx);
@ -1804,6 +1807,16 @@ impl ContextEditor {
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
self.show_accept_terms = true;
cx.notify();
return;
}
if !self.apply_active_workflow_step(cx) {
self.error_message = None;
self.send_to_model(cx);
@ -3388,7 +3401,14 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let disabled = self.show_accept_terms
&& provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx));
ButtonLike::new("send_button")
.disabled(disabled)
.style(style)
.when_some(tooltip, |button, tooltip| {
button.tooltip(move |_| tooltip.clone())
@ -3437,6 +3457,15 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
impl Render for ContextEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let accept_terms = if self.show_accept_terms {
provider
.as_ref()
.and_then(|provider| provider.render_accept_terms(cx))
} else {
None
};
v_flex()
.key_context("ContextEditor")
.capture_action(cx.listener(ContextEditor::cancel))
@ -3455,6 +3484,21 @@ impl Render for ContextEditor {
.bg(cx.theme().colors().editor_background)
.child(self.editor.clone()),
)
.when_some(accept_terms, |this, element| {
this.child(
div()
.absolute()
.right_4()
.bottom_10()
.max_w_96()
.py_2()
.px_3()
.elevation_2(cx)
.bg(cx.theme().colors().surface_background)
.occlude()
.child(element),
)
})
.child(
h_flex().flex_none().relative().child(
h_flex()

View File

@ -1,5 +1,6 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result};
use chrono::Duration;
use futures::{stream::BoxStream, StreamExt};
use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
use parking_lot::Mutex;
@ -162,6 +163,11 @@ impl FakeServer {
return Ok(*message.downcast().unwrap());
}
let accepted_tos_at = chrono::Utc::now()
.checked_sub_signed(Duration::hours(5))
.expect("failed to build accepted_tos_at")
.timestamp() as u64;
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
self.respond(
message
@ -172,6 +178,7 @@ impl FakeServer {
metrics_id: "the-metrics-id".into(),
staff: false,
flags: Default::default(),
accepted_tos_at: Some(accepted_tos_at),
},
);
continue;

View File

@ -1,5 +1,6 @@
use super::{proto, Client, Status, TypedEnvelope};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags::FeatureFlagAppExt;
use futures::{channel::mpsc, Future, StreamExt};
@ -94,6 +95,7 @@ pub struct UserStore {
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_plan: Option<proto::Plan>,
current_user: watch::Receiver<Option<Arc<User>>>,
accepted_tos_at: Option<Option<DateTime<Utc>>>,
contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>,
@ -150,6 +152,7 @@ impl UserStore {
by_github_login: Default::default(),
current_user: current_user_rx,
current_plan: None,
accepted_tos_at: None,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
participant_indices: Default::default(),
@ -189,9 +192,10 @@ impl UserStore {
} else {
break;
};
let fetch_metrics_id =
let fetch_private_user_info =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
let (user, info) =
futures::join!(fetch_user, fetch_private_user_info);
cx.update(|cx| {
if let Some(info) = info {
@ -202,9 +206,17 @@ impl UserStore {
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
staff,
)
);
this.update(cx, |this, _| {
this.set_current_user_accepted_tos_at(
info.accepted_tos_at,
);
})
} else {
anyhow::Ok(())
}
})?;
})??;
current_user_tx.send(user).await.ok();
@ -680,6 +692,39 @@ impl UserStore {
self.current_user.clone()
}
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
self.accepted_tos_at
.map(|accepted_tos_at| accepted_tos_at.is_some())
}
pub fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
if self.current_user().is_none() {
return Task::ready(Err(anyhow!("no current user")));
};
let client = self.client.clone();
cx.spawn(move |this, mut cx| async move {
if let Some(client) = client.upgrade() {
let response = client
.request(proto::AcceptTermsOfService {})
.await
.context("error accepting tos")?;
this.update(&mut cx, |this, _| {
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at))
})
} else {
Err(anyhow!("client not found"))
}
})
}
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
self.accepted_tos_at = Some(
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
);
}
fn load_users(
&mut self,
request: impl RequestMessage<Response = UsersResponse>,

View File

@ -9,7 +9,8 @@ CREATE TABLE "users" (
"connected_once" BOOLEAN NOT NULL DEFAULT false,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"metrics_id" TEXT,
"github_user_id" INTEGER
"github_user_id" INTEGER,
"accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE
);
CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");

View File

@ -0,0 +1 @@
ALTER TABLE users ADD accepted_tos_at TIMESTAMP WITHOUT TIME ZONE;

View File

@ -225,6 +225,26 @@ impl Database {
.await
}
/// Sets "accepted_tos_at" on the user to the given timestamp.
pub async fn set_user_accepted_tos_at(
&self,
id: UserId,
accepted_tos_at: Option<DateTime>,
) -> Result<()> {
self.transaction(|tx| async move {
user::Entity::update_many()
.filter(user::Column::Id.eq(id))
.set(user::ActiveModel {
accepted_tos_at: ActiveValue::set(accepted_tos_at),
..Default::default()
})
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// hard delete the user.
pub async fn destroy_user(&self, id: UserId) -> Result<()> {
self.transaction(|tx| async move {

View File

@ -18,6 +18,7 @@ pub struct Model {
pub connected_once: bool,
pub metrics_id: Uuid,
pub created_at: DateTime,
pub accepted_tos_at: Option<DateTime>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@ -10,6 +10,7 @@ mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
mod user_tests;
use crate::migrations::run_database_migrations;

View File

@ -0,0 +1,45 @@
use chrono::Utc;
use crate::{
db::{Database, NewUserParams},
test_both_dbs,
};
use std::sync::Arc;
test_both_dbs!(
test_accepted_tos,
test_accepted_tos_postgres,
test_accepted_tos_sqlite
);
async fn test_accepted_tos(db: &Arc<Database>) {
let user_id = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".to_string(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_none());
let accepted_tos_at = Utc::now().naive_utc();
db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at))
.await
.unwrap();
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_some());
assert_eq!(user.accepted_tos_at, Some(accepted_tos_at));
db.set_user_accepted_tos_at(user_id, None).await.unwrap();
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_none());
}

View File

@ -31,6 +31,7 @@ use axum::{
routing::get,
Extension, Router, TypedHeader,
};
use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
@ -604,6 +605,7 @@ impl Server {
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
.add_request_handler(user_handler(get_llm_api_token))
.add_request_handler(user_handler(accept_terms_of_service))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
@ -4882,6 +4884,25 @@ async fn get_private_user_info(
metrics_id,
staff: user.admin,
flags,
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
})?;
Ok(())
}
/// Accept the terms of service (tos) on behalf of the current user
async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>,
session: UserSession,
) -> Result<()> {
let db = session.db().await;
let accepted_tos_at = Utc::now();
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
.await?;
response.send(proto::AcceptTermsOfServiceResponse {
accepted_tos_at: accepted_tos_at.timestamp() as u64,
})?;
Ok(())
}

View File

@ -9,7 +9,9 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
};
pub use model::*;
use project::Fs;
use proto::Plan;
@ -114,6 +116,12 @@ pub trait LanguageModelProvider: 'static {
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
fn must_accept_terms(&self, _cx: &AppContext) -> bool {
false
}
fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
None
}
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}

View File

@ -9,7 +9,10 @@ use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADE
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LanguageModels};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -62,6 +65,7 @@ pub struct State {
client: Arc<Client>,
user_store: Model<UserStore>,
status: client::Status,
accept_terms: Option<Task<Result<()>>>,
_subscription: Subscription,
}
@ -77,6 +81,26 @@ impl State {
this.update(&mut cx, |_, cx| cx.notify())
})
}
fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
self.user_store
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false)
}
fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
let user_store = self.user_store.clone();
self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
let _ = user_store
.update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
.await;
this.update(&mut cx, |this, cx| {
this.accept_terms = None;
cx.notify()
})
}));
}
}
impl CloudLanguageModelProvider {
@ -88,6 +112,7 @@ impl CloudLanguageModelProvider {
client: client.clone(),
user_store,
status,
accept_terms: None,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
@ -223,6 +248,57 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
.into()
}
fn must_accept_terms(&self, cx: &AppContext) -> bool {
!self.state.read(cx).has_accepted_terms_of_service(cx)
}
fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
let state = self.state.read(cx);
let terms = [(
"anthropic_terms_of_service",
"Anthropic Terms of Service",
"https://www.anthropic.com/legal/consumer-terms",
)]
.map(|(id, label, url)| {
Button::new(id, label)
.style(ButtonStyle::Subtle)
.icon(IconName::ExternalLink)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, cx| cx.open_url(url))
});
if state.has_accepted_terms_of_service(cx) {
None
} else {
let disabled = state.accept_terms.is_some();
Some(
v_flex()
.child(Label::new("Terms & Conditions").weight(FontWeight::SEMIBOLD))
.child("Please read and accept the terms and conditions of Zed AI and our provider partners to continue.")
.child(v_flex().m_2().gap_1().children(terms))
.child(
h_flex().justify_end().mt_1().child(
Button::new("accept_terms", "Accept")
.disabled(disabled)
.on_click({
let state = self.state.downgrade();
move |_, cx| {
state
.update(cx, |state, cx| {
state.accept_terms_of_service(cx)
})
.ok();
}
}),
),
)
.into_any(),
)
}
}
fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
@ -766,6 +842,7 @@ impl Render for ConfigurationView {
let is_connected = !self.state.read(cx).is_signed_out();
let plan = self.state.read(cx).user_store.read(cx).current_plan();
let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx);
let is_pro = plan == Some(proto::Plan::ZedPro);
@ -773,6 +850,11 @@ impl Render for ConfigurationView {
v_flex()
.gap_3()
.max_w_4_5()
.when(must_accept_terms, |this| {
this.child(Label::new(
"You must accept the terms of service to use this provider.",
))
})
.child(Label::new(
if is_pro {
"You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."

View File

@ -49,7 +49,7 @@ message Envelope {
GetDefinition get_definition = 32;
GetDefinitionResponse get_definition_response = 33;
GetDeclaration get_declaration = 237;
GetDeclarationResponse get_declaration_response = 238; // current max
GetDeclarationResponse get_declaration_response = 238;
GetTypeDefinition get_type_definition = 34;
GetTypeDefinitionResponse get_type_definition_response = 35;
@ -130,6 +130,8 @@ message Envelope {
GetPrivateUserInfoResponse get_private_user_info_response = 103;
UpdateUserPlan update_user_plan = 234;
UpdateDiffBase update_diff_base = 104;
AcceptTermsOfService accept_terms_of_service = 239;
AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; // current max
OnTypeFormatting on_type_formatting = 105;
OnTypeFormattingResponse on_type_formatting_response = 106;
@ -270,7 +272,7 @@ message Envelope {
AddWorktreeResponse add_worktree_response = 223;
GetLlmToken get_llm_token = 235;
GetLlmTokenResponse get_llm_token_response = 236; // current max
GetLlmTokenResponse get_llm_token_response = 236;
}
reserved 158 to 161;
@ -1692,6 +1694,7 @@ message GetPrivateUserInfoResponse {
string metrics_id = 1;
bool staff = 2;
repeated string flags = 3;
optional uint64 accepted_tos_at = 4;
}
enum Plan {
@ -1703,6 +1706,12 @@ message UpdateUserPlan {
Plan plan = 1;
}
message AcceptTermsOfService {}
message AcceptTermsOfServiceResponse {
uint64 accepted_tos_at = 1;
}
// Entities
message ViewId {

View File

@ -187,6 +187,8 @@ impl fmt::Display for PeerId {
}
messages!(
(AcceptTermsOfService, Foreground),
(AcceptTermsOfServiceResponse, Foreground),
(Ack, Foreground),
(AckBufferOperation, Background),
(AckChannelMessage, Background),
@ -409,6 +411,7 @@ messages!(
);
request_messages!(
(AcceptTermsOfService, AcceptTermsOfServiceResponse),
(ApplyCodeAction, ApplyCodeActionResponse),
(
ApplyCompletionAdditionalEdits,