diff --git a/Cargo.lock b/Cargo.lock index 3855cec38c..5b465d8131 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -851,6 +851,7 @@ dependencies = [ "envy", "futures", "gpui", + "hyper", "language", "lazy_static", "lipsum", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index e7a836fabb..5ebde4a37d 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -25,6 +25,7 @@ base64 = "0.13" clap = { version = "3.1", features = ["derive"], optional = true } envy = "0.4.2" futures = "0.3" +hyper = "0.14" lazy_static = "1.4" lipsum = { version = "0.8", optional = true } nanoid = "0.4" diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 09b0d1af07..b61043109b 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -91,7 +91,8 @@ fn hash_access_token(token: &str) -> Result { None, params, &SaltString::generate(thread_rng()), - )? + ) + .map_err(anyhow::Error::new)? .to_string()) } @@ -105,6 +106,6 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result Result { - let hash = PasswordHash::new(hash)?; + let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?; Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok()) } diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 8550666ac5..7463e3483d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,6 +1,7 @@ -use crate::Result; +use crate::{Error, Result}; use anyhow::{anyhow, Context}; use async_trait::async_trait; +use axum::http::StatusCode; use futures::StreamExt; use nanoid::nanoid; use serde::Serialize; @@ -237,7 +238,7 @@ impl Db for PostgresDb { .fetch_optional(&self.pool) .await?; if let Some((code, count)) = result { - Ok(Some((code, count.try_into()?))) + Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) } else { Ok(None) } @@ -246,7 +247,7 @@ impl Db for PostgresDb { async fn redeem_invite_code(&self, code: &str, login: &str) -> Result { let mut tx = self.pool.begin().await?; - let inviter_id: UserId = sqlx::query_scalar( + let inviter_id: Option = sqlx::query_scalar( " UPDATE users SET invite_count = invite_count - 1 @@ -258,8 +259,30 @@ impl Db for PostgresDb { ) .bind(code) .fetch_optional(&mut tx) - .await? - .ok_or_else(|| anyhow!("invite code not found"))?; + .await?; + + let inviter_id = match inviter_id { + Some(inviter_id) => inviter_id, + None => { + if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1") + .bind(code) + .fetch_optional(&mut tx) + .await? + .is_some() + { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))? + } else { + Err(Error::Http( + StatusCode::NOT_FOUND, + "invite code not found".to_string(), + ))? + } + } + }; + let invitee_id = sqlx::query_scalar( " INSERT INTO users diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index a751c229df..57490ce0bc 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -88,6 +88,18 @@ impl From for Error { } } +impl From for Error { + fn from(error: axum::Error) -> Self { + Self::Internal(error.into()) + } +} + +impl From for Error { + fn from(error: hyper::Error) -> Self { + Self::Internal(error.into()) + } +} + impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { match self {