diff --git a/Cargo.lock b/Cargo.lock index 5957cc406f..3855cec38c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -856,6 +856,7 @@ dependencies = [ "lipsum", "log", "lsp", + "nanoid", "opentelemetry", "opentelemetry-otlp", "parking_lot", @@ -2761,6 +2762,15 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" +[[package]] +name = "nanoid" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8" +dependencies = [ + "rand 0.8.3", +] + [[package]] name = "native-tls" version = "0.2.10" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index db8386330a..e7a836fabb 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -27,6 +27,7 @@ envy = "0.4.2" futures = "0.3" lazy_static = "1.4" lipsum = { version = "0.8", optional = true } +nanoid = "0.4" opentelemetry = { version = "0.17", features = ["rt-tokio"] } opentelemetry-otlp = { version = "0.10", features = ["tls-roots"] } parking_lot = "0.11.1" diff --git a/crates/collab/migrations/20220518151305_add_invites_to_users.sql b/crates/collab/migrations/20220518151305_add_invites_to_users.sql new file mode 100644 index 0000000000..4811cdf46b --- /dev/null +++ b/crates/collab/migrations/20220518151305_add_invites_to_users.sql @@ -0,0 +1,7 @@ +ALTER TABLE users +ADD invite_code VARCHAR(64), +ADD invite_count INTEGER NOT NULL DEFAULT 0, +ADD inviter_id INTEGER REFERENCES users (id), +ADD created_at TIMESTAMP NOT NULL DEFAULT NOW(); + +CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 056f94ecfe..c326a52061 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use futures::StreamExt; +use nanoid::nanoid; use serde::Serialize; pub use sqlx::postgres::PgPoolOptions as DbOptions; use sqlx::{types::Uuid, FromRow}; @@ -17,6 +18,10 @@ pub trait Db: Send + Sync { async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; async fn destroy_user(&self, id: UserId) -> Result<()>; + async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>; + async fn get_invite_code(&self, id: UserId) -> Result>; + async fn redeem_invite_code(&self, code: &str, login: &str) -> Result; + async fn get_contacts(&self, id: UserId) -> Result>; async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result; async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; @@ -189,6 +194,103 @@ impl Db for PostgresDb { .map(drop)?) } + // invite codes + + async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> { + let mut tx = self.pool.begin().await?; + sqlx::query( + " + UPDATE users + SET invite_code = $1 + WHERE id = $2 AND invite_code IS NULL + ", + ) + .bind(nanoid!(16)) + .bind(id) + .execute(&mut tx) + .await?; + sqlx::query( + " + UPDATE users + SET invite_count = $1 + WHERE id = $2 + ", + ) + .bind(count) + .bind(id) + .execute(&mut tx) + .await?; + tx.commit().await?; + Ok(()) + } + + async fn get_invite_code(&self, id: UserId) -> Result> { + let result: Option<(String, i32)> = sqlx::query_as( + " + SELECT invite_code, invite_count + FROM users + WHERE id = $1 AND invite_code IS NOT NULL + ", + ) + .bind(id) + .fetch_optional(&self.pool) + .await?; + if let Some((code, count)) = result { + Ok(Some((code, count.try_into()?))) + } else { + Ok(None) + } + } + + async fn redeem_invite_code(&self, code: &str, login: &str) -> Result { + let mut tx = self.pool.begin().await?; + + let inviter_id: UserId = sqlx::query_scalar( + " + UPDATE users + SET invite_count = invite_count - 1 + WHERE + invite_code = $1 AND + invite_count > 0 + RETURNING id + ", + ) + .bind(code) + .fetch_optional(&mut tx) + .await? + .ok_or_else(|| anyhow!("invite code not found"))?; + let invitee_id = sqlx::query_scalar( + " + INSERT INTO users + (github_login, admin, inviter_id) + VALUES + ($1, 'f', $2) + RETURNING id + ", + ) + .bind(login) + .bind(inviter_id) + .fetch_one(&mut tx) + .await + .map(UserId)?; + + sqlx::query( + " + INSERT INTO contacts + (user_id_a, user_id_b, a_to_b, should_notify, accepted) + VALUES + ($1, $2, 't', 't', 't') + ", + ) + .bind(inviter_id) + .bind(invitee_id) + .execute(&mut tx) + .await?; + + tx.commit().await?; + Ok(invitee_id) + } + // contacts async fn get_contacts(&self, user_id: UserId) -> Result> { @@ -1198,6 +1300,144 @@ pub mod tests { } } + #[tokio::test(flavor = "multi_thread")] + async fn test_invite_codes() { + let postgres = TestDb::postgres().await; + let db = postgres.db(); + let user1 = db.create_user("user-1", false).await.unwrap(); + + // Initially, user 1 has no invite code + assert_eq!(db.get_invite_code(user1).await.unwrap(), None); + + // User 1 creates an invite code that can be used twice. + db.set_invite_count(user1, 2).await.unwrap(); + let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 2); + + // User 2 redeems the invite code and becomes a contact of user 1. + let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap(); + let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [ + Contact::Accepted { + user_id: user1, + should_notify: false + }, + Contact::Accepted { + user_id: user2, + should_notify: true + } + ] + ); + assert_eq!( + db.get_contacts(user2).await.unwrap(), + [ + Contact::Accepted { + user_id: user1, + should_notify: false + }, + Contact::Accepted { + user_id: user2, + should_notify: false + } + ] + ); + + // User 3 redeems the invite code and becomes a contact of user 1. + let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap(); + let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 0); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [ + Contact::Accepted { + user_id: user1, + should_notify: false + }, + Contact::Accepted { + user_id: user2, + should_notify: true + }, + Contact::Accepted { + user_id: user3, + should_notify: true + } + ] + ); + assert_eq!( + db.get_contacts(user3).await.unwrap(), + [ + Contact::Accepted { + user_id: user1, + should_notify: false + }, + Contact::Accepted { + user_id: user3, + should_notify: false + }, + ] + ); + + // Trying to reedem the code for the third time results in an error. + db.redeem_invite_code(&invite_code, "user-4") + .await + .unwrap_err(); + + // Invite count can be updated after the code has been created. + db.set_invite_count(user1, 2).await.unwrap(); + let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 + assert_eq!(invite_count, 2); + + // User 4 can now redeem the invite code and becomes a contact of user 1. + let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap(); + let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [ + Contact::Accepted { + user_id: user1, + should_notify: false + }, + Contact::Accepted { + user_id: user2, + should_notify: true + }, + Contact::Accepted { + user_id: user3, + should_notify: true + }, + Contact::Accepted { + user_id: user4, + should_notify: true + } + ] + ); + assert_eq!( + db.get_contacts(user4).await.unwrap(), + [ + Contact::Accepted { + user_id: user1, + should_notify: false + }, + Contact::Accepted { + user_id: user4, + should_notify: false + }, + ] + ); + + // An existing user cannot redeem invite codes. + db.redeem_invite_code(&invite_code, "user-2") + .await + .unwrap_err(); + let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); + } + pub struct TestDb { pub db: Option>, pub url: String, @@ -1348,6 +1588,22 @@ pub mod tests { unimplemented!() } + // invite codes + + async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> { + unimplemented!() + } + + async fn get_invite_code(&self, _id: UserId) -> Result> { + unimplemented!() + } + + async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result { + unimplemented!() + } + + // contacts + async fn get_contacts(&self, id: UserId) -> Result> { self.background.simulate_random_delay().await; let mut contacts = vec![Contact::Accepted {