Add development credentials provider (#11505)

This PR adds a new development credentials provider for the purpose of
streamlining local development against production collab.

## Problem

Today if you want to run a development build of Zed against the
production collab server, you need to either:

1. Enter your keychain password every time in order to retrieve your
saved credentials
2. Re-authenticate with zed.dev every time
    - This can get annoying as you need to pop out into a browser window
- I've also seen cases where if you re-auth too many times in a row
GitHub will make you confirm the authentication, as it looks suspicious

## Solution

This PR decouples the concept of the credentials provider from the
keychain, and adds a new development credentials provider to address
this specific case.

Now when running a development build of Zed and the
`ZED_DEVELOPMENT_AUTH` environment variable is set to a non-empty value,
the credentials will be saved to disk instead of the system keychain.

While this is not as secure as storing them in the system keychain,
since it is only used for development the tradeoff seems acceptable for
the resulting improvement in UX.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-05-07 13:59:18 -04:00 committed by GitHub
parent c260f7d5ac
commit d64106e01b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 203 additions and 57 deletions

View File

@ -16,39 +16,38 @@ doctest = false
test-support = ["clock/test-support", "collections/test-support", "gpui/test-support", "rpc/test-support"]
[dependencies]
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
collections.workspace = true
gpui.workspace = true
util.workspace = true
release_channel.workspace = true
rpc.workspace = true
text.workspace = true
settings.workspace = true
feature_flags.workspace = true
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { version = "0.16", features = ["async-std", "async-native-tls"] }
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
collections.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
lazy_static.workspace = true
log.workspace = true
once_cell = "1.19.0"
parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
release_channel.workspace = true
rpc.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
sha2.workspace = true
smol.workspace = true
sysinfo.workspace = true
telemetry_events.workspace = true
tempfile.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http = "0.8"
url.workspace = true
util.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View File

@ -30,6 +30,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources, SettingsStore};
use std::fmt;
use std::pin::Pin;
use std::{
any::TypeId,
convert::TryFrom,
@ -65,6 +66,13 @@ impl fmt::Display for DevServerToken {
lazy_static! {
static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
/// An environment variable whose presence indicates that the development auth
/// provider should be used.
///
/// Only works in development. Setting this environment variable in other release
/// channels is a no-op.
pub static ref ZED_DEVELOPMENT_AUTH: bool =
std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
.ok()
.and_then(|s| if s.is_empty() { None } else { Some(s) });
@ -161,6 +169,7 @@ pub struct Client {
peer: Arc<Peer>,
http: Arc<HttpClientWithUrl>,
telemetry: Arc<Telemetry>,
credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
state: RwLock<ClientState>,
#[allow(clippy::type_complexity)]
@ -298,6 +307,32 @@ impl Credentials {
}
}
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain).
trait CredentialsProvider {
/// Reads the credentials from the provider.
fn read_credentials<'a>(
&'a self,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;
/// Writes the credentials to the provider.
fn write_credentials<'a>(
&'a self,
user_id: u64,
access_token: String,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
/// Deletes the credentials from the provider.
fn delete_credentials<'a>(
&'a self,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
impl Default for ClientState {
fn default() -> Self {
Self {
@ -443,11 +478,27 @@ impl Client {
http: Arc<HttpClientWithUrl>,
cx: &mut AppContext,
) -> Arc<Self> {
let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
| None => false,
};
let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
if use_zed_development_auth {
Arc::new(DevelopmentCredentialsProvider {
path: util::paths::CONFIG_DIR.join("development_auth"),
})
} else {
Arc::new(KeychainCredentialsProvider)
};
Arc::new(Self {
id: AtomicU64::new(0),
peer: Peer::new(0),
telemetry: Telemetry::new(clock, http.clone(), cx),
http,
credentials_provider,
state: Default::default(),
#[cfg(any(test, feature = "test-support"))]
@ -763,8 +814,11 @@ impl Client {
}
}
pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
read_credentials_from_keychain(cx).await.is_some()
pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
self.credentials_provider
.read_credentials(cx)
.await
.is_some()
}
pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
@ -775,7 +829,7 @@ impl Client {
#[async_recursion(?Send)]
pub async fn authenticate_and_connect(
self: &Arc<Self>,
try_keychain: bool,
try_provider: bool,
cx: &AsyncAppContext,
) -> anyhow::Result<()> {
let was_disconnected = match *self.status().borrow() {
@ -796,12 +850,13 @@ impl Client {
self.set_status(Status::Reauthenticating, cx)
}
let mut read_from_keychain = false;
let mut read_from_provider = false;
let mut credentials = self.state.read().credentials.clone();
if credentials.is_none() && try_keychain {
credentials = read_credentials_from_keychain(cx).await;
read_from_keychain = credentials.is_some();
if credentials.is_none() && try_provider {
credentials = self.credentials_provider.read_credentials(cx).await;
read_from_provider = credentials.is_some();
}
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
@ -838,9 +893,9 @@ impl Client {
match connection {
Ok(conn) => {
self.state.write().credentials = Some(credentials.clone());
if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
if let Credentials::User{user_id, access_token} = credentials {
write_credentials_to_keychain(user_id, access_token, cx).await.log_err();
self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
}
}
@ -854,8 +909,8 @@ impl Client {
}
Err(EstablishConnectionError::Unauthorized) => {
self.state.write().credentials.take();
if read_from_keychain {
delete_credentials_from_keychain(cx).await.log_err();
if read_from_provider {
self.credentials_provider.delete_credentials(cx).await.log_err();
self.set_status(Status::SignedOut, cx);
self.authenticate_and_connect(false, cx).await
} else {
@ -1264,8 +1319,11 @@ impl Client {
self.state.write().credentials = None;
self.disconnect(&cx);
if self.has_keychain_credentials(cx).await {
delete_credentials_from_keychain(cx).await.log_err();
if self.has_credentials(cx).await {
self.credentials_provider
.delete_credentials(cx)
.await
.log_err();
}
}
@ -1465,41 +1523,128 @@ impl Client {
}
}
async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
if IMPERSONATE_LOGIN.is_some() {
return None;
}
let (user_id, access_token) = cx
.update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
.log_err()?
.await
.log_err()??;
Some(Credentials::User {
user_id: user_id.parse().ok()?,
access_token: String::from_utf8(access_token).ok()?,
})
}
async fn write_credentials_to_keychain(
#[derive(Serialize, Deserialize)]
struct DevelopmentCredentials {
user_id: u64,
access_token: String,
cx: &AsyncAppContext,
) -> Result<()> {
cx.update(move |cx| {
cx.write_credentials(
&ClientSettings::get_global(cx).server_url,
&user_id.to_string(),
access_token.as_bytes(),
)
})?
.await
}
async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
.await
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
path: PathBuf,
}
impl CredentialsProvider for DevelopmentCredentialsProvider {
fn read_credentials<'a>(
&'a self,
_cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
async move {
if IMPERSONATE_LOGIN.is_some() {
return None;
}
let json = std::fs::read(&self.path).log_err()?;
let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
Some(Credentials::User {
user_id: credentials.user_id,
access_token: credentials.access_token,
})
}
.boxed_local()
}
fn write_credentials<'a>(
&'a self,
user_id: u64,
access_token: String,
_cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
let json = serde_json::to_string(&DevelopmentCredentials {
user_id,
access_token,
})?;
std::fs::write(&self.path, json)?;
Ok(())
}
.boxed_local()
}
fn delete_credentials<'a>(
&'a self,
_cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
}
}
/// A credentials provider that stores credentials in the system keychain.
struct KeychainCredentialsProvider;
impl CredentialsProvider for KeychainCredentialsProvider {
fn read_credentials<'a>(
&'a self,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
async move {
if IMPERSONATE_LOGIN.is_some() {
return None;
}
let (user_id, access_token) = cx
.update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
.log_err()?
.await
.log_err()??;
Some(Credentials::User {
user_id: user_id.parse().ok()?,
access_token: String::from_utf8(access_token).ok()?,
})
}
.boxed_local()
}
fn write_credentials<'a>(
&'a self,
user_id: u64,
access_token: String,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
cx.update(move |cx| {
cx.write_credentials(
&ClientSettings::get_global(cx).server_url,
&user_id.to_string(),
access_token.as_bytes(),
)
})?
.await
}
.boxed_local()
}
fn delete_credentials<'a>(
&'a self,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
.await
}
.boxed_local()
}
}
/// prefix for the zed:// url scheme

View File

@ -542,10 +542,12 @@ fn handle_open_request(
async fn authenticate(client: Arc<Client>, cx: &AsyncAppContext) -> Result<()> {
if stdout_is_a_pty() {
if client::IMPERSONATE_LOGIN.is_some() {
if *client::ZED_DEVELOPMENT_AUTH {
client.authenticate_and_connect(true, &cx).await?;
} else if client::IMPERSONATE_LOGIN.is_some() {
client.authenticate_and_connect(false, &cx).await?;
}
} else if client.has_keychain_credentials(&cx).await {
} else if client.has_credentials(&cx).await {
client.authenticate_and_connect(true, &cx).await?;
}
Ok::<_, anyhow::Error>(())