From c19083a35c89a22395595f8934c117a14943ed24 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 29 May 2024 11:55:40 +0200 Subject: [PATCH] Fix bug when loading bucket from database --- crates/collab/src/rate_limiter.rs | 46 ++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/crates/collab/src/rate_limiter.rs b/crates/collab/src/rate_limiter.rs index c78dec4b24..844a3af949 100644 --- a/crates/collab/src/rate_limiter.rs +++ b/crates/collab/src/rate_limiter.rs @@ -62,7 +62,7 @@ impl RateLimiter { let mut bucket = self .buckets .entry(bucket_key.clone()) - .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now)); + .or_insert_with(|| RateBucket::new::(now)); if bucket.value_mut().allow(now) { self.dirty_buckets.insert(bucket_key); @@ -72,19 +72,19 @@ impl RateLimiter { } } - async fn load_bucket( + async fn load_bucket( &self, user_id: UserId, ) -> Result, Error> { Ok(self .db - .get_rate_bucket(user_id, K::db_name()) + .get_rate_bucket(user_id, T::db_name()) .await? - .map(|saved_bucket| RateBucket { - capacity: K::capacity(), - refill_time_per_token: K::refill_duration(), - token_count: saved_bucket.token_count as usize, - last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc), + .map(|saved_bucket| { + RateBucket::from_db::( + saved_bucket.token_count as usize, + DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc), + ) })) } @@ -124,15 +124,24 @@ struct RateBucket { } impl RateBucket { - fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self { - RateBucket { - capacity, - token_count: capacity, - refill_time_per_token: refill_duration / capacity as i32, + fn new(now: DateTimeUtc) -> Self { + Self { + capacity: T::capacity(), + token_count: T::capacity(), + refill_time_per_token: T::refill_duration() / T::capacity() as i32, last_refill: now, } } + fn from_db(token_count: usize, last_refill: DateTimeUtc) -> Self { + Self { + capacity: T::capacity(), + token_count, + refill_time_per_token: T::refill_duration() / T::capacity() as i32, + last_refill, + } + } + fn allow(&mut self, now: DateTimeUtc) -> bool { self.refill(now); if self.token_count > 0 { @@ -252,6 +261,17 @@ mod tests { .check_internal::(user_1, now) .await .unwrap_err(); + + // After 1s, user 1 can make another request before being rate-limited again. + now += Duration::seconds(1); + rate_limiter + .check_internal::(user_1, now) + .await + .unwrap(); + rate_limiter + .check_internal::(user_1, now) + .await + .unwrap_err(); } struct RateLimitA;